In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes and edits to them are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
cd /notebooks/

/notebooks


In [3]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch import nn
from tqdm.notebook import trange, tqdm
from digital_twin_library import SubsetSampler, corr
import seaborn as sns
from neuralpredictors.data.datasets import FileTreeDataset
from neuralpredictors.measures.modules import PoissonLoss
from torch.utils.data import DataLoader

from neuralpredictors.data.transforms import (
    ToTensor,
    NeuroNormalizer,
    ScaleInputs,
)


# Improving the model

In [4]:
# Prefer the first CUDA GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

## Cleaning up the code

Configure the dataset, samplers and loaders

In [5]:
# Dataset directory, relative to the current working directory.
root_dir = 'data/static21067-10-18-GrayImageNet-94c6ff995dac583098847cfecd43e7b6'
dat = FileTreeDataset(root_dir, 'images', 'responses')

# Transforms are appended to the dataset in list order: the normalizer first,
# then input scaling by 0.25, then tensor conversion.
# NOTE(review): the flag given to ToTensor is "CUDA available" — presumably it
# moves tensors to the GPU; confirm against neuralpredictors.
transforms = [
    NeuroNormalizer(dat),
    ScaleInputs(scale=0.25),
    ToTensor(torch.cuda.is_available()),
]
dat.transforms.extend(transforms)

# One sampler per tier listed in `dat.trial_info.tiers`; only training is shuffled.
train_sampler, test_sampler, val_sampler = (
    SubsetSampler(dat.trial_info.tiers == tier_name, shuffle=do_shuffle)
    for tier_name, do_shuffle in (('train', True), ('test', False), ('validation', False))
)

# All three loaders read from the same dataset and differ only in their sampler.
train_loader, val_loader, test_loader = (
    DataLoader(dat, sampler=tier_sampler, batch_size=64)
    for tier_sampler in (train_sampler, val_sampler, test_sampler)
)

Write helper functions to reduce boilerplate code. 

In [None]:
def train_epoch(model, loader, optimizer, loss_fn, log=False):
    """
    Run one optimisation pass over every batch in ``loader``.

    Args:
        model (torch.nn.Module): Model to train; switched to training mode.
        loader: Iterable yielding ``(images, responses)`` pairs.
        optimizer (torch.optim.Optimizer): Optimizer stepped once per batch.
        loss_fn (callable): Called as ``loss_fn(outputs, responses)``; its result
            must support ``.backward()``.
        log (bool): When true, the loader is wrapped in ``tqdm`` for a progress bar.

    Returns:
        The same ``model`` object that was passed in.
    """
    model.train()
    batches = tqdm(loader) if log else loader
    for inputs, targets in batches:
        optimizer.zero_grad()
        batch_loss = loss_fn(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()
    return model

def get_correlations(model, loader):
    """
    Correlate the model's predictions with the recorded responses over a loader.

    The model is switched to evaluation mode (and left in it). All batches are
    stacked row-wise before ``corr`` is applied along axis 0.

    Args:
        model (torch.nn.Module): The trained model.
        loader (torch.utils.data.DataLoader): Yields ``(images, responses)`` batches.

    Returns:
        The result of ``corr(resp, pred, dim=0)``. Callers in this notebook call
        ``.mean()`` on it, so it is presumably an array with one correlation per
        response column rather than a single float — confirm against
        ``digital_twin_library.corr``.
    """
    resp, pred = [], []
    model.eval()
    # Inference only: disable autograd so no computation graph is built per batch.
    with torch.no_grad():
        for images, responses in loader:
            outputs = model(images)
            resp.append(responses.cpu().detach().numpy())
            pred.append(outputs.cpu().detach().numpy())
    resp = np.vstack(resp)
    pred = np.vstack(pred)
    return corr(resp, pred, dim=0)
    

# Let's make the model convolutional

In [None]:
class GaussianReadout(nn.Module):
    """
    Readout that samples a feature map at one learned 2-D position per output unit.

    Each of the ``output_dim`` units owns a position mean (``pos_mean``), a 2x2
    matrix that shapes the positional noise used during training
    (``pos_sqrt_cov``), one weight per input channel (``linear``) and a bias.

    Args:
        output_dim (int): Number of output units.
        channels (int): Number of channels in the incoming feature map.
        spatial_scale (float): ``pos_mean`` is initialised uniformly in
            ``[-spatial_scale, spatial_scale)``.
        std_scale (float): ``pos_sqrt_cov`` is initialised uniformly in
            ``[-std_scale, std_scale)``.
    """

    def __init__(self, output_dim, channels, spatial_scale, std_scale):
        super().__init__()
        self.pos_mean = nn.Parameter(torch.zeros(output_dim, 1, 2))
        self.pos_sqrt_cov = nn.Parameter(torch.zeros(output_dim, 2, 2))

        self.linear = nn.Parameter(torch.zeros(output_dim, channels))
        self.bias = nn.Parameter(torch.zeros(output_dim))

        # Initialise through nn.init (runs under no_grad) instead of poking `.data`.
        # The order of the two random draws is kept so seeded runs are unchanged.
        nn.init.uniform_(self.pos_sqrt_cov, -std_scale, std_scale)
        nn.init.uniform_(self.pos_mean, -spatial_scale, spatial_scale)
        # Start every unit as a plain average over channels.
        nn.init.constant_(self.linear, 1.0 / channels)

    def grid_positions(self, batch_size):
        """
        Return the sampling grid, shaped ``(batch_size, output_dim, 1, 2)``.

        In training mode the positions are ``pos_mean`` plus standard-normal
        noise transformed by ``pos_sqrt_cov``; in eval mode they are ``pos_mean``.
        Positions are clipped to ``[-1, 1]``. The noise has no batch dimension,
        so a single draw is shared by every item in the batch.
        """
        if self.training:
            # Draw noise directly with the parameter's device and dtype.
            z = torch.randn_like(self.pos_mean)
            grid = self.pos_mean + torch.einsum('nuk,njk->nuj', z, self.pos_sqrt_cov)
        else:
            grid = self.pos_mean
        grid = torch.clip(grid, -1, 1)
        # expand() broadcasts over the batch without copying the grid.
        return grid.expand(batch_size, -1, -1, -1)

    def forward(self, x):
        """
        Map a 4-D feature map ``(batch, channels, height, width)`` to
        ``(batch, output_dim)`` responses.
        """
        batch_size = x.shape[0]
        grid = self.grid_positions(batch_size)

        # grid_sample yields batch x channels x units x 1; drop the trailing axis.
        x = torch.nn.functional.grid_sample(x, grid, align_corners=False).squeeze(-1)
        # Per-unit weighted sum over channels, plus the per-unit bias.
        x = torch.einsum('bcn,nc->bn', x, self.linear) + self.bias.view(1, -1)
        return x
    
class ConvModel(nn.Module):
    """
    Convolutional core followed by a ``GaussianReadout`` and a softplus.

    The core takes single-channel input and stacks ``layers`` blocks of
    Conv2d -> BatchNorm2d -> activation. The convolutions use no padding.

    Args:
        layers (int): Total number of conv blocks. The first block is always
            built, so values below 1 still give one block.
        input_kern (int): Kernel size of the first convolution.
        hidden_kern (int): Kernel size of every later convolution.
        hidden_channels (int): Channel count produced by every convolution.
        output_dim (int): Number of readout units.
        spatial_scale (float): Passed to ``GaussianReadout``.
        std_scale (float): Passed to ``GaussianReadout``.
    """

    def __init__(self, layers, input_kern, hidden_kern, hidden_channels, output_dim, spatial_scale=0.1, std_scale=0.5):
        super().__init__()

        # NOTE(review): the first block uses SiLU while later blocks use ELU —
        # confirm the mix is intentional.
        core_layers = [nn.Conv2d(1, hidden_channels, input_kern), nn.BatchNorm2d(hidden_channels), nn.SiLU()]

        for _ in range(layers - 1):
            core_layers.extend([
                nn.Conv2d(hidden_channels, hidden_channels, hidden_kern),
                nn.BatchNorm2d(hidden_channels),
                nn.ELU(),
            ])
        self.core = nn.Sequential(*core_layers)

        self.readout = GaussianReadout(output_dim, hidden_channels, spatial_scale=spatial_scale, std_scale=std_scale)

    def forward(self, x):
        """Run the core, then the readout, then softplus so outputs are positive."""
        x = self.core(x)
        x = self.readout(x)

        return nn.functional.softplus(x)


In [None]:
# Build the convolutional model and place it on the selected device.
# NOTE(review): output_dim=8372 is presumably the number of response columns
# in this dataset — confirm, or derive it from the data.
model_m = ConvModel(layers=3, input_kern=11, hidden_kern=5, hidden_channels=32, output_dim=8372).to(device)
loss = PoissonLoss()  # Poisson loss module from neuralpredictors
optimizer = torch.optim.Adam(model_m.parameters(), lr=1e-3)

epochs = 50
for epoch in trange(epochs):
    # train_epoch returns the model it was given, so `model` aliases `model_m`.
    model = train_epoch(model_m, train_loader, optimizer, loss)
    # Report correlations only on every third epoch (0, 3, 6, ...).
    if epoch % 3 != 0:
        continue
    train_corrs = get_correlations(model_m, train_loader)
    val_corrs = get_correlations(model_m, val_loader)
    print(f'Epoch [{epoch+1}/{epochs}], Validation correlation: {val_corrs.mean():.4f}, Training correlation: {train_corrs.mean():.4f}')

Let's plot the correlations

In [None]:
# Recompute correlations for the trained model: training split first, then validation.
train_corrs, val_corrs = [
    get_correlations(model_m, split_loader)
    for split_loader in (train_loader, val_loader)
]


In [None]:
# Overlay histograms of the validation and training correlations on one axis.
sns.set_context('notebook', font_scale=1.5)
# The style context only wraps figure/axes creation.
with sns.axes_style('ticks'):
    fig, ax = plt.subplots(figsize=(8, 6))
sns.histplot(val_corrs, kde=False, ax=ax, color=sns.xkcd_rgb['denim blue'], label='Validation')
sns.histplot(train_corrs, kde=False, ax=ax, color='orange', label='Train')
ax.set_xlabel('Correlation')
# The `label=` arguments only become visible once a legend is drawn.
ax.legend();