In [82]:
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models import resnet50, ResNet50_Weights

from torchvision.ops import FeaturePyramidNetwork

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset

from tqdm import tqdm

import numpy as np

import os

import matplotlib.pyplot as plt

from utils.dataset import Dataset

from scipy.stats import pearsonr as corr

from sklearn.decomposition import IncrementalPCA
from sklearn.linear_model import LinearRegression


In [2]:
EPOCHS = 100
lr = 0.0001
batch_size = 300
l2 = 0

In [3]:
data = Dataset('../../data/subj01')

train_set, val_set = torch.utils.data.random_split(data, [0.8, 0.2])

train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=True)

Loading dataset sample names...
Training images: 9841
Test images: 159

LH training fMRI data shape:
(9841, 19004)
(Training stimulus images × LH vertices)

RH training fMRI data shape:
(9841, 20544)
(Training stimulus images × RH vertices)


In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

alexnet = torch.hub.load('pytorch/vision:v0.10.0', 'alexnet')
alexnet.to(device) # send the alexnet to the chosen device ('cpu' or 'cuda')
alexnet.eval() # set the alexnet to evaluation mode, since you are not training it

Using cache found in /home/ubuntu/.cache/torch/hub/pytorch_vision_v0.10.0


AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [5]:
model_layer = "features.2" #@param ["features.2", "features.5", "features.7", "features.9", "features.12", "classifier.2", "classifier.5", "classifier.6"] {allow-input: true}
feature_extractor = create_feature_extractor(alexnet, return_nodes=[model_layer])

In [24]:
#lf_fmri = []

def fit_pca(feature_extractor, dataloader):
    # Define PCA parameters
    pca = IncrementalPCA(n_components=100, batch_size=batch_size)

    # Fit PCA to batch
    for _, d in tqdm(enumerate(dataloader), total=len(dataloader)):
        if _ == len(dataloader)-1:
            break
        #lh_fmri += [d[1]]
        # Extract features
        ft = feature_extractor(d[0].to(device))
        # Flatten the features
        ft = torch.hstack([torch.flatten(l, start_dim=1) for l in ft.values()])
        # Fit PCA to batch
        pca.partial_fit(ft.detach().cpu().numpy())

    return pca

In [25]:
pca = fit_pca(feature_extractor, train_loader)

 96%|█████████▋| 26/27 [02:51<00:06,  6.58s/it]


In [79]:
def extract_features(feature_extractor, dataloader, pca):
    lh_fmri = []
    features = []
    for _, d in tqdm(enumerate(dataloader), total=len(dataloader)):
        if _ == len(dataloader)-1:
            break
        # Extract features
        ft = feature_extractor(d[0].to(device))
        lh_fmri += [d[1].cpu().detach().numpy()]
        # Flatten the features
        ft = torch.hstack([torch.flatten(l, start_dim=1) for l in ft.values()])
        # Apply PCA transform
        ft = pca.transform(ft.cpu().detach().numpy())
        features += [ft]
    return (np.vstack(features), np.vstack(lh_fmri))

In [80]:
features_train, labels_train = extract_features(feature_extractor, train_loader, pca)
features_val, labels_val = extract_features(feature_extractor, val_loader, pca)

 96%|█████████▋| 26/27 [01:18<00:03,  3.01s/it]
 86%|████████▌ | 6/7 [00:21<00:03,  3.64s/it]


In [None]:
del pca

In [84]:
reg_lh = LinearRegression().fit(features_train, labels_train)
lh_fmri_val_pred = reg_lh.predict(features_val)

In [87]:
from scipy.stats import pearsonr as corr

# Empty correlation array of shape: (LH vertices)
lh_correlation = np.zeros(lh_fmri_val_pred.shape[1])
# Correlate each predicted LH vertex with the corresponding ground truth vertex
for v in tqdm(range(lh_fmri_val_pred.shape[1])):
    lh_correlation[v] = corr(lh_fmri_val_pred[:,v], labels_val[:,v])[0]

score = np.median(lh_correlation) * 100
print(score)

100%|██████████| 19004/19004 [00:01<00:00, 10330.17it/s]

21.263217314225553



