#Install necessary packages

In [None]:
!pip install nilearn==0.9.2

#Import libraries

In [13]:
import os
import numpy as np
from pathlib import Path
from PIL import Image
import matplotlib
from matplotlib import pyplot as plt
from nilearn import datasets, plotting
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms

#Mount Google Drive

In [None]:
# Colab-only: mount Google Drive; the data and submission folders below live under it.
from google.colab import drive
drive.mount('/content/drive/', force_remount=True)
# Root folder of the challenge data; argObj joins it with a 'subjXX' sub-folder.
data_dir = '/content/drive/MyDrive/algonauts_2023_tutorial_data'
# Root folder for challenge submissions; argObj derives a per-subject sub-folder path from it.
parent_submission_dir = '/content/drive/MyDrive/algonauts_2023_challenge_submission'

#Select device

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

#Import data

In [16]:
# Subject number to work on (1-8); the trailing #@param annotation exposes it as a Colab form field.
subj = 1 #@param ["1", "2", "3", "4", "5", "6", "7", "8"] {type:"raw", allow-input: true}

In [17]:
class argObj:
  """Bundle the per-subject data and submission paths.

  Attributes set on construction:
    subj: the subject number formatted with the '02' format spec (1 -> '01').
    data_dir: ``data_dir`` joined with ``'subj' + subj``.
    parent_submission_dir: the submission root, stored unchanged.
    subject_submission_dir: ``parent_submission_dir`` joined with ``'subj' + subj``.
  """

  def __init__(self, data_dir, parent_submission_dir, subj):
    subject_tag = f'{subj:02}'
    subject_folder = 'subj' + subject_tag

    self.subj = subject_tag
    self.data_dir = os.path.join(data_dir, subject_folder)
    self.parent_submission_dir = parent_submission_dir
    self.subject_submission_dir = os.path.join(parent_submission_dir, subject_folder)

# Resolve the data and submission paths for the selected subject.
args = argObj(data_dir, parent_submission_dir, subj)

In [None]:
# Load the right-hemisphere training fMRI array of the selected subject.
fmri_dir = os.path.join(args.data_dir, 'training_split', 'training_fmri')
rh_fmri = np.load(os.path.join(fmri_dir, 'rh_training_fmri.npy'))

# Print the heading, the array shape and the axis legend, one per line.
print('\nRH training fMRI data shape:',
      rh_fmri.shape,
      '(Training stimulus images × RH vertices)',
      sep='\n')

In [None]:
# Folders holding the training and test stimulus images of the selected subject
train_img_dir = os.path.join(args.data_dir, 'training_split', 'training_images')
test_img_dir = os.path.join(args.data_dir, 'test_split', 'test_images')

# Create lists with all training and test image file names, sorted
train_img_list = sorted(os.listdir(train_img_dir))
test_img_list = sorted(os.listdir(test_img_dir))
print(f'Training images: {len(train_img_list)}')
print(f'Test images: {len(test_img_list)}')

#Train, Validation and Test Split

In [None]:
# Seed NumPy's global RNG so the shuffled split below is reproducible
rand_seed = 5
np.random.seed(rand_seed)

# Number of stimulus images that make up 90% of the training data
num_train = int(np.round(len(train_img_list) / 100 * 90))
# Shuffled indices of all training stimulus images
idxs = np.arange(len(train_img_list))
np.random.shuffle(idxs)
# The first 90% of the shuffled indices form the training partition,
# the remaining 10% form the validation partition
idxs_train = idxs[:num_train]
idxs_val = idxs[num_train:]
# The test stimulus images are neither shuffled nor split
idxs_test = np.arange(len(test_img_list))

print(f'Training stimulus images: {len(idxs_train)}')
print(f'\nValidation stimulus images: {len(idxs_val)}')
print(f'\nTest stimulus images: {len(idxs_test)}')

#Dataloader

In [21]:
# Custom dataset over a selected subset of image files
class ImageDataset(Dataset):
    """Serve RGB images from a subset of ``imgs_paths``.

    Args:
        imgs_paths: sequence of image paths; converted to a NumPy array.
        idxs: index (array) selecting which entries of ``imgs_paths`` to keep.
        transform: callable applied to each loaded image; when falsy the
            PIL image is returned untouched.
    """

    def __init__(self, imgs_paths, idxs, transform):
        all_paths = np.array(imgs_paths)
        self.imgs_paths = all_paths[idxs]
        self.transform = transform

    def __len__(self):
        return len(self.imgs_paths)

    def __getitem__(self, idx):
        image = Image.open(self.imgs_paths[idx]).convert('RGB')
        if not self.transform:
            return image
        # The transformed result is moved to the module-level ``device``.
        return self.transform(image).to(device)

#Transfer Learning and Training Loop

In [None]:

# Preprocessing applied to every stimulus image
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # resize the image to 224x224 pixels
    transforms.ToTensor(),  # convert the image to a PyTorch tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # normalize the color channels
])

# Pretrained VGG-19, moved to the selected device and switched to eval mode
model = models.vgg19(pretrained=True).to(device)
model.eval()
# Keep every top-level child module of the pretrained model except the last one
feature_extractor = nn.Sequential(*list(model.children())[:-1])

class LinearizingEncodingModel(nn.Module):
    """Three linear layers with optional activation, batch norm and dropout after the first two.

    Args:
        input_dim: size of the input feature vector.
        output_dim: size of the output vector.
        hidden_dim1, hidden_dim2: widths of the two hidden layers.
        activation1, activation2: callables returning an activation module
            (e.g. ``nn.ReLU``), or a falsy value for no activation.
        bnorm1, bnorm2: whether to add ``nn.BatchNorm1d`` after each hidden layer.
        dropout1, dropout2: whether to add ``nn.Dropout`` after each hidden layer.
        dropout_ratio1, dropout_ratio2: dropout probabilities; only used when
            the matching ``dropout`` flag is truthy.

    Optional layers are stored as ``activation<k>``, ``batchnorm<k>`` and
    ``dropout<k>`` (k = 1, 2) and only exist when enabled.
    """

    _OPTIONAL_PREFIXES = ('activation', 'batchnorm', 'dropout')

    def __init__(self, input_dim, output_dim, hidden_dim1, hidden_dim2, activation1, activation2,
                 bnorm1, bnorm2, dropout1, dropout_ratio1, dropout2, dropout_ratio2):
        super().__init__()
        # Sub-modules are registered in forward order: fc1, stage-1 extras, fc2, stage-2 extras, fc3.
        self.fc1 = nn.Linear(input_dim, hidden_dim1)
        self._add_optional_layers('1', hidden_dim1, activation1, bnorm1, dropout1, dropout_ratio1)
        self.fc2 = nn.Linear(hidden_dim1, hidden_dim2)
        self._add_optional_layers('2', hidden_dim2, activation2, bnorm2, dropout2, dropout_ratio2)
        self.fc3 = nn.Linear(hidden_dim2, output_dim)

    def _add_optional_layers(self, stage, width, activation, use_bnorm, use_dropout, dropout_ratio):
        """Attach the enabled activation / batch-norm / dropout modules for one hidden stage."""
        if activation:
            setattr(self, 'activation' + stage, activation())
        if use_bnorm:
            setattr(self, 'batchnorm' + stage, nn.BatchNorm1d(width))
        if use_dropout:
            setattr(self, 'dropout' + stage, nn.Dropout(dropout_ratio))

    def forward(self, x):
        """Apply fc1 and fc2, each followed by its enabled extras in the order activation, batch norm, dropout, then fc3."""
        for stage in ('1', '2'):
            x = getattr(self, 'fc' + stage)(x)
            for prefix in self._OPTIONAL_PREFIXES:
                layer = getattr(self, prefix + stage, None)
                if layer is not None:
                    x = layer(x)
        return self.fc3(x)

def train_linearizing_encoding_model(network, train_dataloader, train_targets, val_dataloader, val_targets, num_epochs,
                                     batch_size, loss_function, optimizer, learning_rate, weight_decay, save_name=None, patience=3):
    """Train ``network`` on features produced by the module-level ``feature_extractor``.

    Each image batch is moved to the module-level ``device``, passed through
    ``feature_extractor`` (without gradient tracking), flattened per sample and
    fed to ``network``. After every epoch the validation loss is computed and
    early stopping is evaluated.

    Args:
        network: the model to optimise.
        train_dataloader, val_dataloader: iterables of image batches. Batch ``i``
            is paired with ``targets[i * batch_size:(i + 1) * batch_size]``, so the
            loaders must iterate in target order (no shuffling) and use the same
            ``batch_size``.
        train_targets, val_targets: sliceable target arrays aligned with the loaders.
        num_epochs: maximum number of epochs.
        batch_size: batch size used by both loaders.
        loss_function: callable ``loss_function(outputs, targets)``.
        optimizer: optimizer factory called as
            ``optimizer(params, lr=..., weight_decay=...)``.
        learning_rate: learning rate passed to the optimizer.
        weight_decay: weight decay passed to the optimizer.
        save_name: when truthy, the final ``state_dict`` is saved to ``save_name + '.pt'``.
        patience: number of consecutive epochs without improvement before stopping.

    Returns:
        ``(train_losses, val_losses)``: lists with the mean batch loss of every
        completed epoch.
    """
    criterion = loss_function
    # Honour the caller's weight_decay instead of a hard-coded value.
    optimizer = optimizer(network.parameters(), lr=learning_rate, weight_decay=weight_decay)

    train_losses = []
    val_losses = []

    best_val_loss = float('inf')
    early_stopping_counter = 0

    for epoch in range(num_epochs):
        # The validation pass below switches to eval mode, so training mode
        # (dropout, batch-norm batch statistics) must be restored every epoch.
        network.train()
        training_loss = 0.0
        for index, data in enumerate(train_dataloader):
            # Only `network` is optimised; skip building an autograd graph through the backbone.
            with torch.no_grad():
                inputs = feature_extractor(data.to(device))
            inputs = inputs.view(inputs.size(0), -1)

            # Slicing past the end is clamped, so the last (short) batch needs no special case.
            start = index * batch_size
            targets_batch = torch.tensor(train_targets[start:start + batch_size]).to(device)

            optimizer.zero_grad()
            outputs = network(inputs)
            loss = criterion(outputs, targets_batch)
            loss.backward()
            optimizer.step()

            training_loss += loss.item()

        training_loss /= len(train_dataloader)
        train_losses.append(training_loss)
        print(f'Epoch {epoch + 1}/{num_epochs}, Training Loss: {training_loss:.4f}')

        network.eval()
        val_loss = 0.0
        with torch.no_grad():
            for index, data in enumerate(val_dataloader):
                inputs = feature_extractor(data.to(device))
                inputs = inputs.view(inputs.size(0), -1)

                start = index * batch_size
                targets_batch = torch.tensor(val_targets[start:start + batch_size]).to(device)

                outputs = network(inputs)
                loss = criterion(outputs, targets_batch)
                val_loss += loss.item()

        val_loss /= len(val_dataloader)
        val_losses.append(val_loss)
        print(f'Validation - Epoch {epoch + 1}/{num_epochs}, Validation Loss: {val_loss:.4f}')

        # Losses are compared after rounding to two decimals, so only an
        # improvement visible at that precision resets the counter.
        if round(val_loss, 2) < round(best_val_loss, 2):
            best_val_loss = val_loss
            early_stopping_counter = 0
        else:
            early_stopping_counter += 1

        # Check if early stopping criterion is met
        if early_stopping_counter >= patience:
            print(f'Early stopping triggered. No improvement in {patience} epochs.')
            break

    # Saves the weights as they are after the last executed epoch.
    if save_name:
        torch.save(network.state_dict(), save_name+'.pt')

    return train_losses, val_losses


#Model Training

In [None]:
# Training and validation targets, selected with the split indices
rh_fmri_train = rh_fmri[idxs_train]
rh_fmri_val = rh_fmri[idxs_val]


# Probe the feature extractor with one dummy image to find its output size
with torch.no_grad():
    sample_input = torch.zeros(1, 3, 224, 224).to(device)
    output = feature_extractor(sample_input)

# The hyperparameter values were selected using Optuna
input_dim = output.numel()  # flattened feature size of the single dummy image
output_dim = rh_fmri_train.shape[1]  # one output per column of the fMRI targets
hidden_dim1 = 6487
hidden_dim2 = 711
num_epochs = 50
activation1 = nn.ReLU
activation2 = nn.Tanh
bnorm1 = True
bnorm2 = True
dropout1 = False  # no first dropout layer is created, so dropout_ratio1 is unused
dropout_ratio1 = 0.4666007846830111
dropout2 = True
dropout_ratio2 = 0.098405552999689
learning_rate = 0.008417049414747052
optimizer = optim.SGD
weight_decay = 0.07356043913788857
loss_function = nn.MSELoss()
batch_size = 40

train_imgs_paths = sorted(Path(train_img_dir).iterdir())
test_imgs_paths = sorted(Path(test_img_dir).iterdir())
# Loaders are left unshuffled: the training loop pairs batch i with the i-th slice of the targets
train_imgs_dataloader = DataLoader(
    ImageDataset(train_imgs_paths, idxs_train, transform),
    batch_size=batch_size,
)
val_imgs_dataloader = DataLoader(
    ImageDataset(train_imgs_paths, idxs_val, transform),
    batch_size=batch_size,
)

network = LinearizingEncodingModel(
    input_dim=input_dim,
    output_dim=output_dim,
    hidden_dim1=hidden_dim1,
    hidden_dim2=hidden_dim2,
    activation1=activation1,
    activation2=activation2,
    bnorm1=bnorm1,
    bnorm2=bnorm2,
    dropout1=dropout1,
    dropout_ratio1=dropout_ratio1,
    dropout2=dropout2,
    dropout_ratio2=dropout_ratio2,
).to(device)
# NOTE(review): the save name says "vgg16" while the backbone loaded earlier is vgg19 — confirm the intended name
train_losses, val_losses = train_linearizing_encoding_model(
    network=network,
    train_dataloader=train_imgs_dataloader,
    train_targets=rh_fmri_train,
    val_dataloader=val_imgs_dataloader,
    val_targets=rh_fmri_val,
    num_epochs=num_epochs,
    batch_size=batch_size,
    loss_function=loss_function,
    optimizer=optimizer,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    save_name="vgg16_right_hemishpere",
    patience=2,
)


#Learning Curves

In [None]:
# Per-epoch training and validation loss on a shared axis (epochs numbered from 1)
fig, ax = plt.subplots()
ax.plot(range(1, len(train_losses) + 1), train_losses, label='Training Loss')
ax.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_ylim(0, 1)
ax.legend()
fig.savefig('vgg16_learning_curves_right_hemishpere.pdf', format='pdf')
plt.show()