In [1]:
from torch.utils.data import Dataset, DataLoader
import torch
from torchvision.models import resnet152, ResNet152_Weights
import torch.nn as nn
from torch.optim import NAdam

from sklearn.metrics import mean_squared_error
from sklearn.metrics import r2_score, explained_variance_score

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from utils import load_it_data, visualize_img
import matplotlib.pyplot as plt
import numpy as np
import collections
from tqdm import tqdm
import gdown
import os

output = "IT_data.h5"

if not os.path.exists(output):
    url = "https://drive.google.com/file/d/1s6caFNRpyR9m7ZM6XEv_e8mcXT3_PnHS/view?usp=share_link"
    gdown.download(url, output, quiet=False, fuzzy=True)

# Load the data

In [3]:
path_to_data = ""

(
    stimulus_train,
    stimulus_val,
    stimulus_test,
    objects_train,
    objects_val,
    objects_test,
    spikes_train,
    spikes_val,
) = load_it_data(path_to_data)

# Training / Evaluation

In [4]:

def train_model(model,scheduler, train_loader, validation_dataloader, criterion, optimizer, num_epochs=10, device='cuda', max_patience=20):
    model = model.to(device)

    model.train()

    best_loss = 1e3
    patience = 0
    with tqdm(total=num_epochs, desc=f"Epoch 0/{num_epochs}") as pbar:

        for epoch in tqdm(range(num_epochs), desc=f"Training epochs"):
            running_loss = 0.0
            val_loss = 0.0
            model.train()
            for inputs, labels in train_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                outputs = model(inputs)
                loss = criterion(outputs, labels)

                loss.backward()
                optimizer.step()
                if not scheduler is None:
                    scheduler.step()

                running_loss += loss.item() * inputs.size(0)
            model.eval()
            with torch.no_grad():
                for inputs, labels in validation_dataloader:
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    val_loss += loss.item() * inputs.size(0)

            epoch_loss = running_loss / len(train_loader.dataset)
            val_loss = val_loss / len(validation_dataloader.dataset)
            if val_loss<best_loss:
                best_model_state_dict = {k:v.detach().to('cpu') for k, v in model.state_dict().items()}
                best_model_state_dict = collections.OrderedDict(best_model_state_dict)
                best_loss = val_loss
                patience = 0
            else:
                patience+=1

            if patience == max_patience:
                break
            pbar.set_description(f"Epoch {epoch + 1}/{num_epochs}")
            pbar.set_postfix(train_loss=epoch_loss , val_loss=val_loss, patience=patience)
            pbar.update(1)
            pbar.refresh() 
            
    model.load_state_dict(best_model_state_dict)

    return model

In [5]:
def evaluate(model, validation_dataloader, device='cuda'):
    model.to(device)
    model.eval()
    predictions = []
    true_labels = []
    for inputs, labels in validation_dataloader:
        inputs = inputs.to(device)

        predictions.append(model(inputs).detach().cpu())
        true_labels.append(labels)
    y_pred = torch.cat(predictions, dim=0).numpy()
    y_true = torch.cat(true_labels, dim=0).numpy()
    
    mse = mean_squared_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    explained_variance = explained_variance_score(y_true, y_pred)
    
    print(f"MSE: {mse}")
    print("R2:", r2)
    print("Explained Variance:", explained_variance)

In [6]:
class CustomTensorDataset(Dataset):
    """TensorDataset with support of transforms."""

    def __init__(self, tensors, transform=None):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors
        self.transform = transform

    def __getitem__(self, index):
        x = self.tensors[0][index]

        if self.transform:
            x = self.transform(x)

        y = self.tensors[1][index]

        return x, y

    def __len__(self):
        return self.tensors[0].size(0)

In [7]:
class CustomLayer(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(CustomLayer, self).__init__()
        self.fc = nn.Linear(input_dim, output_dim)

        nn.init.kaiming_normal_(self.fc.weight)
        nn.init.zeros_(self.fc.bias)

    def forward(self, x):
        x = self.fc(x)

        return x

# Data driven approach

## Our best-performing model: pretrained Resnet-152

In [8]:
batch_size = 64

In [9]:
train_dataset = CustomTensorDataset(
    (
        torch.Tensor(stimulus_train),
        torch.Tensor(spikes_train),
    ),
    transform=ResNet152_Weights.IMAGENET1K_V1.transforms(),
)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [10]:
validation_dataset = CustomTensorDataset(
    (torch.Tensor(stimulus_val), torch.Tensor(spikes_val)),
    transform=ResNet152_Weights.IMAGENET1K_V1.transforms(),
)
validation_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=True)

In [11]:
model = resnet152(weights=ResNet152_Weights.IMAGENET1K_V1)

in_features = model.fc.in_features

custom_layer = CustomLayer(in_features, spikes_train.shape[1])
model.fc = custom_layer

criterion = nn.MSELoss()
num_epochs = 200
lr = 1e-3

optimizer = NAdam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer=optimizer,
    max_lr=lr,
    total_steps=num_epochs * len(train_loader),
    pct_start=0.1,
    anneal_strategy="linear",
    cycle_momentum=False,
    div_factor=1e2,
    final_div_factor=0.1,
)

model = train_model(
    model, scheduler, train_loader, validation_loader, criterion, optimizer, num_epochs=num_epochs
)

torch.save(model.state_dict(), "resnet152_best_model.pth")

model.eval()
print("##################")
evaluate(model, validation_loader)

Training epochs:  64%|██████▎   | 127/200 [51:30<29:36, 24.33s/it]atience=19, train_loss=0.00773, val_loss=0.0681]
Epoch 127/200:  64%|██████▎   | 127/200 [51:30<29:36, 24.33s/it, patience=19, train_loss=0.00773, val_loss=0.0681]


##################
MSE: 0.06672779470682144
R2: 0.43175606231408487
Explained Variance: 0.452819168922447
