In [1]:
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models import resnet50, ResNet50_Weights

from torchvision.ops import FeaturePyramidNetwork

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset

from tqdm import tqdm

import numpy as np

import os

import matplotlib.pyplot as plt

from utils.dataset import Dataset

from scipy.stats import pearsonr as corr

from sklearn.decomposition import IncrementalPCA
from sklearn.linear_model import LinearRegression


In [2]:
batch_size = 400

In [3]:
data = Dataset('../../data/subj08')
#test_data = Dataset('../../data/subj08', test=True)

train_set, val_set = torch.utils.data.random_split(data, [0.8, 0.2])

train_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)
#test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)

val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=True)

Loading dataset sample names...
Training images: 8779
Test images: 395

LH training fMRI data shape:
(8779, 18981)
(Training stimulus images × LH vertices)

RH training fMRI data shape:
(8779, 20530)
(Training stimulus images × RH vertices)


In [4]:
# loading pretrained model
device = torch.device("cuda")

model = resnet50(weights=ResNet50_Weights.DEFAULT)

layer_names = []

for name, layer in model.named_modules():
    if isinstance(layer, nn.Conv2d):
        layer_names += [name]

print(layer_names)

feature_extractor = create_feature_extractor(model, 
        return_nodes=["layer2.3.conv3"]).to(device)

feature_extractor.eval()

['conv1', 'layer1.0.conv1', 'layer1.0.conv2', 'layer1.0.conv3', 'layer1.0.downsample.0', 'layer1.1.conv1', 'layer1.1.conv2', 'layer1.1.conv3', 'layer1.2.conv1', 'layer1.2.conv2', 'layer1.2.conv3', 'layer2.0.conv1', 'layer2.0.conv2', 'layer2.0.conv3', 'layer2.0.downsample.0', 'layer2.1.conv1', 'layer2.1.conv2', 'layer2.1.conv3', 'layer2.2.conv1', 'layer2.2.conv2', 'layer2.2.conv3', 'layer2.3.conv1', 'layer2.3.conv2', 'layer2.3.conv3', 'layer3.0.conv1', 'layer3.0.conv2', 'layer3.0.conv3', 'layer3.0.downsample.0', 'layer3.1.conv1', 'layer3.1.conv2', 'layer3.1.conv3', 'layer3.2.conv1', 'layer3.2.conv2', 'layer3.2.conv3', 'layer3.3.conv1', 'layer3.3.conv2', 'layer3.3.conv3', 'layer3.4.conv1', 'layer3.4.conv2', 'layer3.4.conv3', 'layer3.5.conv1', 'layer3.5.conv2', 'layer3.5.conv3', 'layer4.0.conv1', 'layer4.0.conv2', 'layer4.0.conv3', 'layer4.0.downsample.0', 'layer4.1.conv1', 'layer4.1.conv2', 'layer4.1.conv3', 'layer4.2.conv1', 'layer4.2.conv2', 'layer4.2.conv3']


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Module(
    (0): Module(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Module(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=Fal

In [5]:
def fit_pca(feature_extractor, dataloader):
    # Define PCA parameters
    pca = IncrementalPCA(n_components=400, batch_size=batch_size)

    # Fit PCA to batch
    for _, d in tqdm(enumerate(dataloader), total=len(dataloader)):
        if _ == len(dataloader)-1:
            break
        with torch.no_grad():
            # Extract features
            ft = feature_extractor(d[0].to(device))
        # Flatten the features
        ft = torch.hstack([torch.flatten(l, start_dim=1) for l in ft.values()])
        # Fit PCA to batch
        pca.partial_fit(ft.detach().cpu().numpy())

    return pca

In [6]:
pca = fit_pca(feature_extractor, train_loader)

  self.noise_variance_ = explained_variance[self.n_components_ :].mean()
  ret = ret.dtype.type(ret / rcount)
  5%|▍         | 1/22 [00:23<08:10, 23.37s/it]

: 

: 

In [None]:
def extract_features(feature_extractor, dataloader, pca, right=False, test=False):
    fmri = []
    features = []
    for _, d in tqdm(enumerate(dataloader), total=len(dataloader)):
        with torch.no_grad():
            # Extract features
            if test == False:
                ft = feature_extractor(d[0].to(device))
                if right == False:
                    fmri += [d[1].cpu().detach().numpy()]
                else:
                    fmri += [d[2].cpu().detach().numpy()]
            else:
                ft = feature_extractor(d.to(device))
        # Flatten the features
        ft = torch.hstack([torch.flatten(l, start_dim=1) for l in ft.values()])
        # Apply PCA transform
        ft = pca.transform(ft.cpu().detach().numpy())
        features += [ft]
    if test == False:
        return (np.vstack(features), np.vstack(fmri))
    return np.vstack(features)

In [None]:
right = False

features_train, labels_train = extract_features(feature_extractor, train_loader, pca, right=right)
features_val, labels_val = extract_features(feature_extractor, val_loader, pca, right=right)

 32%|███▏      | 7/22 [00:45<01:40,  6.67s/it]

0


100%|██████████| 22/22 [02:19<00:00,  6.35s/it]
 40%|████      | 2/5 [00:11<00:17,  5.82s/it]

0


100%|██████████| 5/5 [00:25<00:00,  5.16s/it]


In [None]:
reg = LinearRegression().fit(features_train, labels_train)
fmri_test_pred = reg.predict(features_val)

In [10]:
from scipy.stats import pearsonr as corr

# Empty correlation array of shape: (LH vertices)
lh_correlation = np.zeros(fmri_test_pred.shape[1])
# Correlate each predicted LH vertex with the corresponding ground truth vertex
for v in tqdm(range(fmri_test_pred.shape[1])):
    lh_correlation[v] = corr(fmri_test_pred[:,v], labels_val[:,v])[0]

score = np.median(lh_correlation) * 100
print(score)

100%|██████████| 18981/18981 [00:03<00:00, 6169.47it/s]

33.99590956716545



