# Main training notebook

Lowest val loss: 0.022405 after 3 epochs

Lowest train loss: 0.02135

## Imports

In [None]:
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models import resnet50, ResNet50_Weights

from torchvision.ops import FeaturePyramidNetwork

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset

from nilearn import datasets
from nilearn import plotting

from tqdm import tqdm

import numpy as np

import os

import matplotlib.pyplot as plt

from utils.dataset import Dataset
from utils.model import RegressionHead

## Hyperparameters

In [None]:
# Optimiser settings
lr = 0.002        # learning rate handed to Adam
l2 = 0            # weight decay (L2 penalty) handed to Adam

# Schedule / batching
batch_size = 32   # mini-batch size used by both data loaders
EPOCHS = 5        # number of passes over the training split

## Load dataset

In [None]:
# Load the dataset and carve out a held-out validation split for monitoring
# generalisation during training.
dataset = Dataset("../../data/subj01")

# Keep the historical 9000-sample training split; the validation split takes
# whatever remains (841 for the 9841-sample dataset this notebook was written
# against) instead of hard-coding a second number that must add up exactly.
n_train = 9000
n_val = len(dataset) - n_train
if n_val <= 0:
    raise ValueError(
        "Dataset has " + str(len(dataset)) + " samples; need more than "
        + str(n_train) + " to leave a validation split"
    )

# A seeded generator makes train/val membership identical across kernel
# restarts, so losses from different runs are comparable.
split_rng = torch.Generator().manual_seed(0)
train_set, val_set = torch.utils.data.random_split(
    dataset, [n_train, n_val], generator=split_rng
)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
# Validation order does not affect the averaged loss; leaving it unshuffled
# makes `next(iter(val_loader))` return the same batch every time.
val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False)

## Feature extractor + trainable regression head instantiation

In [None]:
# Use the GPU when one is available; fall back to CPU so the notebook still
# runs (slowly) on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# loading pretrained model (ImageNet weights); used only as a fixed backbone
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

# Names of every Conv2d in the backbone, printed as a reference for choosing
# `return_nodes` below.
layer_names = [
    name for name, layer in model.named_modules() if isinstance(layer, nn.Conv2d)
]

print(layer_names)

feature_extractor = create_feature_extractor(
    model,
    return_nodes=["layer1.1.conv1", "layer2.0.conv1", "layer3.0.conv1", "layer4.0.conv1"],
).to(device)

# The backbone is never optimised (the optimizer only receives
# head.parameters()), so put it in eval mode: otherwise BatchNorm normalises
# with per-batch statistics and keeps updating its running statistics on every
# forward pass, silently changing the "pretrained" features.
feature_extractor.eval()
# No gradients are needed for weights that are never updated.
feature_extractor.requires_grad_(False)

# Feature pyramid over the four tapped layers (64/128/256/512 channels in,
# 256 channels out). It is randomly initialised and, like the backbone, is not
# passed to the optimizer, so it is frozen as well.
fpn = FeaturePyramidNetwork([64, 128, 256, 512], 256).to(device)
fpn.requires_grad_(False)


# instantiating trainable head (the only module that is optimised)
head = RegressionHead().to(device)

## Instantiating loss function + optimizer

In [None]:
# Auxiliary objective: pushes the spread of the predictions towards the spread
# of the targets.
def squared_spread_loss():
    """Build a loss callable that compares prediction and target spread.

    Returns:
        A function ``(output, target) -> tensor`` that returns the squared
        difference between the standard deviation of ``output`` and the
        standard deviation of ``target``, each taken over all elements.
    """
    def spread_penalty(output, target):
        spread_gap = output.std() - target.std()
        return spread_gap.pow(2)

    return spread_penalty

In [None]:
# Weights for the two loss terms. Only `mse_weight` is applied by the training
# loop at the moment; the spread term is commented out there.
mse_weight, spread_weight = 0.5, 0.5

# Primary loss: mean-squared error between predictions and targets.
criterion = nn.MSELoss()

# Auxiliary loss: squared difference between prediction and target std.
auxiliary_spread_loss = squared_spread_loss()

# Adam over the regression head's parameters only; `l2` is the weight decay.
optimizer = torch.optim.Adam(params=head.parameters(), weight_decay=l2, lr=lr)

## Selecting ROI vertices

In [None]:
# Form parameters: the hemisphere and ROI chosen here determine which target
# vertices are trained on and visualised.
img = 0 #@param
hemisphere = 'left' #@param ['left', 'right'] {allow-input: true}
roi = "EBA" #@param ["V1v", "V1d", "V2v", "V2d", "V3v", "V3d", "hV4", "EBA", "FBA-1", "FBA-2", "mTL-bodies", "OFA", "FFA-1", "FFA-2", "mTL-faces", "aTL-faces", "OPA", "PPA", "RSC", "OWFA", "VWFA-1", "VWFA-2", "mfs-words", "mTL-words", "early", "midventral", "midlateral", "midparietal", "ventral", "lateral", "parietal"] {allow-input: true}

# Both fields accept free-text input, so validate them up front instead of
# failing later with an unrelated NameError / missing-file error.
if hemisphere not in ('left', 'right'):
    raise ValueError("hemisphere must be 'left' or 'right', got " + repr(hemisphere))

# ROI class (mask file family) -> ROI names it contains
roi_classes = {
    'prf-visualrois': ["V1v", "V1d", "V2v", "V2d", "V3v", "V3d", "hV4"],
    'floc-bodies': ["EBA", "FBA-1", "FBA-2", "mTL-bodies"],
    'floc-faces': ["OFA", "FFA-1", "FFA-2", "mTL-faces", "aTL-faces"],
    'floc-places': ["OPA", "PPA", "RSC"],
    'floc-words': ["OWFA", "VWFA-1", "VWFA-2", "mfs-words", "mTL-words"],
    'streams': ["early", "midventral", "midlateral", "midparietal", "ventral", "lateral", "parietal"],
}

# Define the ROI class based on the selected ROI
roi_class = next(
    (class_name for class_name, names in roi_classes.items() if roi in names),
    None,
)
if roi_class is None:
    raise ValueError("Unknown ROI: " + repr(roi))

# Load the ROI brain surface maps
roi_masks_dir = os.path.join("../../data/subj01/", 'roi_masks')
hemi_prefix = hemisphere[0] + 'h.'  # 'lh.' or 'rh.'
challenge_roi_class_dir = os.path.join(
    roi_masks_dir, hemi_prefix + roi_class + '_challenge_space.npy')
fsaverage_roi_class_dir = os.path.join(
    roi_masks_dir, hemi_prefix + roi_class + '_fsaverage_space.npy')
roi_map_dir = os.path.join(roi_masks_dir, 'mapping_' + roi_class + '.npy')
challenge_roi_class = np.load(challenge_roi_class_dir)
fsaverage_roi_class = np.load(fsaverage_roi_class_dir)
# NOTE(review): allow_pickle=True unpickles the file contents, which can run
# arbitrary code; only load mapping files from a trusted source.
roi_map = np.load(roi_map_dir, allow_pickle=True).item()

# Select the vertices corresponding to the ROI of interest: find the label
# that the mapping assigns to `roi`, then build 0/1 membership masks in both
# spaces.
matching_labels = [label for label, name in roi_map.items() if name == roi]
if not matching_labels:
    raise ValueError(repr(roi) + " is not present in " + roi_map_dir)
roi_mapping = matching_labels[0]
challenge_roi = np.asarray(challenge_roi_class == roi_mapping, dtype=int)
fsaverage_roi = np.asarray(fsaverage_roi_class == roi_mapping, dtype=int)

## Main training loop

In [None]:
# Train the regression head for EPOCHS epochs, evaluating on the validation
# split after each one. Per-epoch mean losses are collected in
# `all_train_loss_vals` / `all_val_loss_vals` for plotting.
#
# The auxiliary spread term (spread_weight * auxiliary_spread_loss) is
# currently disabled, so every reported loss is mse_weight * MSE.

# Column indices of the target vertices that belong to the selected ROI.
# Loop-invariant, so compute once and keep on `device`.
roi_vertex_idx = torch.as_tensor(np.where(challenge_roi)[0], device=device)

all_train_loss_vals = []
all_val_loss_vals = []

for epoch in range(EPOCHS):

    print("\nEpoch " + str(epoch))

    # ---- training pass ----
    train_loss = 0
    count = 0

    # setting to train mode for gradient calculations
    head.train()

    for inputs, targets in tqdm(train_loader, total=len(train_loader)):
        inputs = inputs.to(device)
        targets = targets.to(device)[:, roi_vertex_idx]  # selecting proper vertices based on ROI

        optimizer.zero_grad()

        # feature extractor backbone + FPN: the optimizer only holds
        # head.parameters(), so no autograd graph is needed here. Skipping it
        # leaves the head's gradients unchanged and saves memory and compute.
        with torch.no_grad():
            features = fpn(feature_extractor(inputs))

        # trainable head
        outputs = head(features)

        loss = mse_weight * criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss = train_loss + loss.item()
        count += 1

    torch.cuda.empty_cache() # frees up memory for val

    all_train_loss_vals += [train_loss/count]

    print("Train loss: " + str(train_loss/count))

    # ---- validation pass ----
    val_loss = 0
    count = 0

    head.eval()

    # No parameter is updated during validation, so disable autograd entirely.
    with torch.no_grad():
        for inputs, targets in tqdm(val_loader, total=len(val_loader)):
            inputs = inputs.to(device)
            targets = targets.to(device)[:, roi_vertex_idx]  # selecting proper vertices based on ROI

            # feature extractor backbone -> FPN -> trainable head
            outputs = head(fpn(feature_extractor(inputs)))

            loss = mse_weight * criterion(outputs, targets)

            val_loss = val_loss + loss.item()
            count += 1

    torch.cuda.empty_cache()

    all_val_loss_vals += [val_loss/count]

    print("Val loss: " + str(val_loss/count))

In [None]:
# Plot the per-epoch mean train / validation loss and save the figure.
# savefig does not create missing directories, so make sure the target exists.
os.makedirs("./loss_graphs", exist_ok=True)

fig, ax = plt.subplots()
# The x-axis follows the recorded histories rather than EPOCHS, so an
# interrupted run can still be plotted.
ax.plot(range(len(all_train_loss_vals)), all_train_loss_vals, label="Train")
ax.plot(range(len(all_val_loss_vals)), all_val_loss_vals, label="Val")
ax.legend()
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
# File name follows the current configuration (identical to the previous
# hard-coded name for EPOCHS = 5, roi = "EBA").
fig.savefig(f"./loss_graphs/{EPOCHS}_epochs_roi_{roi}.png")

## Saving trained model

In [None]:
# Persist the trained weights. torch.save does not create missing directories.
os.makedirs("saved_models", exist_ok=True)

# Checkpoint name follows the current configuration (identical to the previous
# hard-coded name for EPOCHS = 5, roi = "EBA").
checkpoint_path = f"saved_models/{EPOCHS}_epochs_roi_{roi}"
torch.save(head.state_dict(), checkpoint_path)

# The FPN in this notebook is randomly initialised and sits between the
# backbone and the head, so the head checkpoint is only meaningful together
# with these exact FPN weights. Save them next to it.
torch.save(fpn.state_dict(), checkpoint_path + "_fpn")

## Visualizing training results

In [None]:
### original visualization script pulled from challenge provided notebook

# Shows measured vs. predicted responses for the first image of one validation
# batch, restricted to the selected ROI, on the fsaverage surface.

if hemisphere not in ('left', 'right'):
    raise ValueError("hemisphere must be 'left' or 'right', got " + repr(hemisphere))

# pulling sample: one (images, targets) batch from the validation loader
sample = next(iter(val_loader))

# Indices of the ROI vertices in fsaverage space and in challenge space
fsaverage_idx = np.where(fsaverage_roi)[0]
challenge_idx = np.where(challenge_roi)[0]

# Map the truth fMRI data onto the brain surface map
# (the mapping is identical for both hemispheres; the hemisphere only affects
# which mask files were loaded and which surface mesh is drawn)
fsaverage_response_truth = np.zeros(len(fsaverage_roi))
fsaverage_response_truth[fsaverage_idx] = sample[1][0][challenge_idx].numpy()

# Map the predicted fMRI data onto the brain surface map.
# The models stay on `device` (only the batch is moved) so the training cell
# can be re-run afterwards without a device mismatch; the result is brought
# back to the CPU for NumPy / plotting.
head.eval()
with torch.no_grad():
    visualize_output = feature_extractor(sample[0].to(device))
    visualize_output = fpn(visualize_output)
    visualize_output = head(visualize_output).detach().cpu()

fsaverage_response_predicted = np.zeros(len(fsaverage_roi))
fsaverage_response_predicted[fsaverage_idx] = visualize_output[0].numpy()

# Create the interactive brain surface maps (view1: truth, view2: prediction)
fsaverage = datasets.fetch_surf_fsaverage('fsaverage')
surface_view_kwargs = dict(
    surf_mesh=fsaverage['infl_'+hemisphere],
    bg_map=fsaverage['sulc_'+hemisphere],
    threshold=1e-14,
    cmap='cold_hot',
    colorbar=True,
    title=roi+', '+hemisphere+' hemisphere')

view1 = plotting.view_surf(surf_map=fsaverage_response_truth, **surface_view_kwargs)

view2 = plotting.view_surf(surf_map=fsaverage_response_predicted, **surface_view_kwargs)

In [None]:
# Display the interactive surface map built from the measured (truth) responses.
view1

In [None]:
# Display the interactive surface map built from the model's predicted responses.
view2

## Output distribution visualization

In [None]:
# Compare the predicted and measured ROI responses for one image.
# `visualize_output` was computed from the batch already held in `sample`, so
# that batch is reused here. Drawing a fresh batch would pair the prediction
# with the targets of a different image.
target_roi = sample[1][0][np.where(challenge_roi)[0]]
prediction_roi = visualize_output[0]

print("Output dims: " + str(prediction_roi.size()))
print("Target dims: " + str(target_roi.size()))

# The printed quantity is the square of the mean, so label it as such.
print("Output squared mean: " + str(torch.mean(prediction_roi) ** 2))
print("Target squared mean: " + str(torch.mean(target_roi) ** 2))

# Unweighted MSE between prediction and target for this single image
loss = criterion(prediction_roi, target_roi)

print("Loss: " + str(loss))

# Overlaid histograms of target and predicted values.
# savefig does not create missing directories, so make sure the target exists.
os.makedirs("histograms", exist_ok=True)

fig, ax = plt.subplots()
ax.hist(target_roi.numpy(), bins=100, label="target")
ax.hist(prediction_roi.numpy(), bins=100, label="model output")
ax.legend(loc="upper right")
# File name follows the current configuration (identical to the previous
# hard-coded name for EPOCHS = 5, roi = "EBA").
fig.savefig(f"histograms/{EPOCHS}_epochs_roi_{roi}.output.jpg")