# Main training notebook

## Imports

In [1]:
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models import resnet50, ResNet50_Weights

from torchvision.ops import FeaturePyramidNetwork

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset

from nilearn import datasets
from nilearn import plotting

from tqdm import tqdm

import numpy as np

import os

import matplotlib.pyplot as plt

from scipy.stats import pearsonr as corr

from utils.dataset_test import Dataset
from utils.dataset import ReferenceDataset
from utils.model import RegressionHead

In [2]:
# Which subject / hemisphere to evaluate. `hemisphere` is derived from `side`
# so the two can never disagree (e.g. side="left" paired with hemisphere="rh").
side = "right"
hemisphere = {"left": "lh", "right": "rh"}[side]
subject = "subj01"
# NOTE(review): equals the 159 test images reported by the loaders below, but it is
# not used in the cells shown here — confirm it is still needed.
batch_size = 159
# Output size of the regression head (number of predicted vertices).
# NOTE(review): presumably specific to this subject/hemisphere — re-check when
# changing `side` or `subject`.
vertex_count = 20544

## Load dataset

In [5]:
# Load the test set (`Dataset`) and the reference set for the chosen hemisphere
# (`ReferenceDataset`). No train/test split is performed in this cell: per the printed
# output, `ReferenceDataset` reports 9841 training + 159 test images and `Dataset`
# reports 159 test images.
data_dir = os.path.join("../../data", subject)
dataset = Dataset(data_dir)
reference_dataset = ReferenceDataset(data_dir, side=hemisphere)

# The whole test set is served as a single batch.
test_loader = torch.utils.data.DataLoader(dataset, batch_size=len(dataset), shuffle=False)
# shuffle=False keeps batch order deterministic, so "first batch" checks are reproducible.
reference_loader = torch.utils.data.DataLoader(reference_dataset, batch_size=32, shuffle=False)
# NOTE(review): `stats` is not used in the cells shown here — confirm it is needed downstream.
stats = reference_dataset.getStats()

Loading dataset sample names...
Test images: 159
Loading dataset sample names...
Training images: 9841
Test images: 159


## Feature extractor + regression head instantiation (head weights loaded from a saved checkpoint)

In [4]:
# Build the inference pipeline: pretrained ResNet-50 backbone -> FeaturePyramidNetwork
# -> regression head whose weights are loaded from disk.
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# Names of every Conv2d module — useful when choosing `return_nodes` below.
layer_names = [name for name, layer in model.named_modules() if isinstance(layer, nn.Conv2d)]
print(layer_names)

# FPN's in_channels_list must give the channel count of each returned node, in order.
feature_extractor = create_feature_extractor(
    model,
    return_nodes=["layer1.1.conv1", "layer2.0.conv1", "layer3.0.conv1", "layer4.0.conv1"],
).to(device)
fpn = FeaturePyramidNetwork([64, 128, 256, 512], 256).to(device)
# NOTE(review): the FPN is freshly (randomly) initialised here and no FPN weights are
# loaded — only the head's state_dict is restored below. Unless training used an FPN
# with identical weights, the head now receives different features than it was trained
# on. Load the FPN weights saved during training (or reproduce its init) — TODO confirm.

feature_extractor.eval()
fpn.eval()

# Instantiate the regression head and restore its trained weights.
head = RegressionHead(vertex_count).to(device)
checkpoint_path = os.path.join("saved_models", "5_epochs_aux_" + subject + "_" + side)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
head.load_state_dict(torch.load(checkpoint_path, map_location=device))
head.eval()

['conv1', 'layer1.0.conv1', 'layer1.0.conv2', 'layer1.0.conv3', 'layer1.0.downsample.0', 'layer1.1.conv1', 'layer1.1.conv2', 'layer1.1.conv3', 'layer1.2.conv1', 'layer1.2.conv2', 'layer1.2.conv3', 'layer2.0.conv1', 'layer2.0.conv2', 'layer2.0.conv3', 'layer2.0.downsample.0', 'layer2.1.conv1', 'layer2.1.conv2', 'layer2.1.conv3', 'layer2.2.conv1', 'layer2.2.conv2', 'layer2.2.conv3', 'layer2.3.conv1', 'layer2.3.conv2', 'layer2.3.conv3', 'layer3.0.conv1', 'layer3.0.conv2', 'layer3.0.conv3', 'layer3.0.downsample.0', 'layer3.1.conv1', 'layer3.1.conv2', 'layer3.1.conv3', 'layer3.2.conv1', 'layer3.2.conv2', 'layer3.2.conv3', 'layer3.3.conv1', 'layer3.3.conv2', 'layer3.3.conv3', 'layer3.4.conv1', 'layer3.4.conv2', 'layer3.4.conv3', 'layer3.5.conv1', 'layer3.5.conv2', 'layer3.5.conv3', 'layer4.0.conv1', 'layer4.0.conv2', 'layer4.0.conv3', 'layer4.0.downsample.0', 'layer4.1.conv1', 'layer4.1.conv2', 'layer4.1.conv3', 'layer4.2.conv1', 'layer4.2.conv2', 'layer4.2.conv3']


RegressionHead(
  (poolx2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (upsamplex2): Upsample(scale_factor=2.0, mode=nearest)
  (bnorm1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bnorm2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bnorm3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bnorm4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (final_conv1): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), bias=False)
  (bnorm5): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (final_conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), bias=False)
  (bnorm6): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (final_conv3): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), bias=False)
  (bnorm7): BatchNorm2d(1024, e

## Sanity check before saving files: MSE on the first reference batch

In [132]:
# Sanity check: MSE of the full pipeline (backbone -> FPN -> head) on the first
# reference batch. NOTE(review): the reference set appears to be training data
# ("Training images: 9841" above), so this is not a held-out metric.
loss = nn.MSELoss()
with torch.no_grad():
    # Only one batch is needed, so take it directly instead of enumerating the loader
    # and breaking after the first iteration (raises StopIteration if the loader is empty).
    image, label = next(iter(reference_loader))
    image = image.to(device)
    label = label.to(device)
    outputs = head(fpn(feature_extractor(image)))

    print(f"MSE on first reference batch: {loss(outputs, label).item():.6f}")