In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
from torch.nn import Parameter
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW

from tqdm import tqdm
import numpy as np
import pandas as pd
from utils import load_it_data, visualize_img
import h5py
import os
device = torch.device("mpc" if torch.cuda.is_available() else "cpu")

In [2]:
path_to_data = '' ## Insert the folder where the data is, if you download in the same folder as this notebook then leave it blank

stimulus_train, stimulus_val, stimulus_test, objects_train, objects_val, objects_test, spikes_train, spikes_val = load_it_data(path_to_data)

In [3]:
from sklearn.preprocessing import LabelEncoder

label_encoder = LabelEncoder()
objects_train_encoded = label_encoder.fit_transform(objects_train)

In [4]:
class SpikeData(Dataset):
    def __init__(self,stimulus,objects,label_encoder = None):
        self.stimulus = torch.tensor(stimulus)
        #self.objects = objects
        #self.labels = label_encoder.transform(objects)
        self.labels = torch.tensor(label_encoder.transform(objects))
        #self.spikes = spikes if spikes is not None else np.zeros((stimulus.shape[0], stimulus.shape[1], 1))
        
        self.number_of_class = np.unique(objects).shape[0] 
    def __len__(self):
        return len(self.stimulus)

    def __getitem__(self, idx):
        return self.stimulus[idx], self.labels[idx]

In [5]:
IT5dataloader_train = DataLoader(SpikeData(stimulus_train, objects_train, label_encoder=label_encoder), batch_size=64, shuffle=True)
IT5dataloader_val = DataLoader(SpikeData(stimulus_val,  objects_val, label_encoder=label_encoder), batch_size=64, shuffle=False)
IT5dataloader_test = DataLoader(SpikeData(stimulus_test, objects_test,label_encoder=label_encoder), batch_size=64, shuffle=False)

In [6]:
class ShallowCNN(nn.Module):
    def __init__(self, num_classes):
        super(ShallowCNN, self).__init__()

        # Conv block 1
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)   # (3, 224, 224) → (32, 224, 224)
        self.bn1 = nn.BatchNorm2d(32)
        
        # Conv block 2
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)  # (64, 112, 112)
        self.bn2 = nn.BatchNorm2d(64)
        
        # Conv block 3
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) # (128, 56, 56)
        self.bn3 = nn.BatchNorm2d(128)
        
        # Compute flattened size dynamically
        with torch.no_grad():
            dummy_input = torch.zeros(1, 3, 224, 224)
            x = self._forward_features(dummy_input)
            self.flattened_size = x.view(1, -1).shape[1]

        # FC layers
        self.fc1 = nn.Linear(self.flattened_size, 256)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, num_classes)

    def _forward_features(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2)   # (32, 112, 112)

        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2)   # (64, 56, 56)

        x = F.relu(self.bn3(self.conv3(x)))
        x = F.max_pool2d(x, 2)   # (128, 28, 28)

        return x

    def forward(self, x):
        x = self._forward_features(x)
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x
    
    @staticmethod # defines that the following function does not take self as input
    @torch.no_grad() # ensures that following function runs without tracking gradients
    # making the initialisation faster and more memory-efficient. The parameters remain trainable.
    def init_weights(module):
        # YOUR CODE HERE
        # raise NotImplementedError()

        # He initialization
        # Conv2d layers
        if isinstance(module, nn.Conv2d):
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        # Linear layers
        elif isinstance(module, nn.Linear):
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            nn.init.zeros_(module.bias)
    
        # Initialize BatchNorm weights to 1 and bias to 0
        elif isinstance(module, nn.BatchNorm2d):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

In [7]:
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=2):
    
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}") as pbar:
            for stimulus, labels in train_loader:
                stimulus, labels = stimulus.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = model(stimulus)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                pbar.set_postfix(loss=running_loss/len(train_loader))
                pbar.update(1)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")
        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for stimulus, labels in val_loader:
                stimulus, labels = stimulus.to(device), labels.to(device)
                outputs = model(stimulus)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
        print(f"Validation Loss: {val_loss/len(val_loader):.4f}")
        
    

In [8]:
loss_fn = nn.CrossEntropyLoss()
model = ShallowCNN(num_classes=64)
model.apply(ShallowCNN.init_weights)
model.to(device)
optimizer = AdamW(model.parameters(), lr=0.005, weight_decay=1e-5)
train_model(model, IT5dataloader_train, IT5dataloader_val, loss_fn, optimizer, num_epochs=2)

Epoch 1/2: 100%|██████████| 41/41 [01:39<00:00,  2.44s/it, loss=104]


Epoch [1/2], Loss: 104.1130
Validation Loss: 4.1589


Epoch 2/2: 100%|██████████| 41/41 [01:36<00:00,  2.35s/it, loss=4.16]


Epoch [2/2], Loss: 4.1605
Validation Loss: 4.1592


In [10]:
from collections import defaultdict
layer_outputs = defaultdict(list)
all_labels = []

def save_output(name):
    def hook(module, input, output):
        # Save the detached output for this batch
        layer_outputs[name].append(output.detach().cpu())
    return hook

model.eval()

# Register forward hooks for the layers you want
model.conv1.register_forward_hook(save_output('conv1'))
model.conv2.register_forward_hook(save_output('conv2'))
model.conv3.register_forward_hook(save_output('conv3'))
model.fc1.register_forward_hook(save_output('fc1')) 

with torch.no_grad():
    for images, labels in IT5dataloader_val:  # all batches
        _ = model(images)  # triggers the hooks
        all_labels.append(labels.cpu())

In [11]:
layer_matrices = {}
for layer_name, outputs in layer_outputs.items():
    # outputs is a list of [batch_size, C, H, W] or [batch_size, D]
    flat = [out.view(out.size(0), -1) for out in outputs]  # flatten per batch
    layer_matrix = torch.cat(flat, dim=0)  # (n_samples, n_features)
    layer_matrices[layer_name] = layer_matrix

# Also flatten all labels
all_labels = torch.cat(all_labels, dim=0)  # shape: (n_samples,)

In [12]:
layer_matrices['conv1']

tensor([[-0.0321, -0.1428, -0.1428,  ..., -0.0802, -0.0802, -0.0733],
        [-0.0321, -0.1428, -0.1428,  ..., -0.0802, -0.0802, -0.0733],
        [-0.0321, -0.1428, -0.1428,  ..., -0.0802, -0.0802, -0.0733],
        ...,
        [-0.0321, -0.1428, -0.1428,  ..., -0.0802, -0.0802, -0.0733],
        [-0.0321, -0.1428, -0.1428,  ..., -0.0802, -0.0802, -0.0733],
        [-0.0321, -0.1428, -0.1428,  ..., -0.0802, -0.0802, -0.0733]])

In [None]:
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.feature_selection import r_regression
from sklearn.metrics import explained_variance_score
from sklearn.model_selection import cross_validate,KFold

scaler = StandardScaler()
pca = PCA(n_components=1000)

X_train = scaler.fit_transform(layer_matrices['conv1'])
X_train.shape
#X_train = pca.fit_transform(X_train)


In [None]:
from sklearn.metrics import make_scorer
#from scipy.stats import pearsonr
# Define a scoring function compatible with make_scorer
def pearson_score(y_true, y_pred):
    scores = []
    for i in range(y_true.shape[1]):
        r = np.corrcoef(y_true[:, i], y_pred[:, i])[0, 1]
        if np.isnan(r):
            r = 0
        if np.isinf(r):
            r = 0
        scores.append(r)
    return np.mean(scores)

# Wrap it as a scorer (greater_is_better=True is default)
pearson_scorer = make_scorer(pearson_score, greater_is_better=True)

In [None]:
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV, GroupKFold

# Ridge model
ridge = Ridge(max_iter=10000, tol=1e-6)

# Grid of alpha values
param_grid = {'alpha': np.logspace(-4, 4, 10)}

# GroupKFold ensures the same label distribution in each fold
cv = GroupKFold(n_splits=10)

# Wrap with GridSearchCV, pass groups
grid = GridSearchCV(ridge, param_grid, cv=cv, scoring=pearson_scorer, n_jobs=-1)

# Fit with groups controlling label distribution
grid.fit(X_train_pca, spikes_train, groups=objects_train_encoded)

# Results
print("Best alpha:", grid.best_params_['alpha'])
print("Best pearson correlation:", grid.best_score_)