# Assignment 5 - Modeling in Neuroscience

# Creating the Data

Note: we first tried to obtain the data through the siibra API, but ran into several problems along the way. Above all, we could not find relevant data for all of our regions (V1, V2, V4, IT, MT, LIP). We only managed to retrieve some information about cell density and structural connectivity for V1 in the human brain (it is possible that we were querying the API incorrectly). In the end we therefore decided to create a synthetic connectivity matrix based on the biological literature, so that our models preserve the neuroanatomical inductive biases.

In [None]:
import numpy as np
import pandas as pd

# we found that there are 6 visual areas common across the three species
# V1, V2, V4, IT, MT, LIP

# for the ventral stream: V1 -> V2 -> V4 -> IT
# for the dorsal stream: V1 -> V2 -> MT -> LIP
# so V1 and V2 are shared, after which the streams split into two paths

regions = ['V1', 'V2', 'V4', 'IT', 'MT', 'LIP']

# we assumed the connection strengths between the above mentioned regions based on the following studies:

# Felleman, D. J., & Van Essen, D. C. (1991). Distributed hierarchical processing in the primate cerebral cortex. Cerebral cortex (New York, N.Y. : 1991), 1(1), 1–47. https://doi.org/10.1093/cercor/1.1.1-a
# Kaas, J. H. (2001). The organization of sensory cortex. Current Opinion in Neurobiology, 11(4), 498–504. https://doi.org/10.1016/S0959-4388(00)00240-3

def get_connectivity_matrix(species):
    """
    Create connectivity matrix for visual pathways.
    """
    connectivity = np.zeros((6, 6))
    
    if species == 'human':    
        connectivity[0, 1] = 0.90 # connection between V1 and V2 (ventral and dorsal stream shared)
        connectivity[1, 2] = 0.75 # connection between V2 and V4 (ventral stream)
        connectivity[2, 3] = 0.80 # connection between V4 and IT (ventral stream)
        connectivity[1, 4] = 0.70 # connection between V2 and MT (dorsal stream)
        connectivity[4, 5] = 0.75 # connection between MT and LIP (dorsal stream)
        
    elif species == 'marmoset':
        connectivity[0, 1] = 0.85 # V1→V2
        connectivity[1, 2] = 0.70 # V2→V4
        connectivity[2, 3] = 0.75 # V4→IT
        connectivity[1, 4] = 0.65 # V2→MT
        connectivity[4, 5] = 0.70 # MT→LIP
        
    elif species == 'mouse':
        connectivity[0, 1] = 0.75 # V1→V2
        connectivity[1, 2] = 0.60 # V2→V4
        connectivity[2, 3] = 0.65 # V4→IT
        connectivity[1, 4] = 0.55 # V2→MT
        connectivity[4, 5] = 0.60 # MT→LIP
        
        # in addition, the mouse gets two direct shortcut connections from V1
        connectivity[0, 2] = 0.30  # V1→V4 shortcut
        connectivity[0, 4] = 0.35  # V1→MT shortcut
    
    return connectivity

# generate data for each species
species_data = {}
for species in ['human', 'marmoset', 'mouse']:
    species_data[species] = get_connectivity_matrix(species)
    print(f"\n{species.upper()} connectivity:")  # upper-case the species name in the header
    df = pd.DataFrame(species_data[species], index=regions, columns=regions)
    print(df.round(2))


HUMAN connectivity:
      V1   V2    V4   IT   MT   LIP
V1   0.0  0.9  0.00  0.0  0.0  0.00
V2   0.0  0.0  0.75  0.0  0.7  0.00
V4   0.0  0.0  0.00  0.8  0.0  0.00
IT   0.0  0.0  0.00  0.0  0.0  0.00
MT   0.0  0.0  0.00  0.0  0.0  0.75
LIP  0.0  0.0  0.00  0.0  0.0  0.00

MARMOSET connectivity:
      V1    V2   V4    IT    MT  LIP
V1   0.0  0.85  0.0  0.00  0.00  0.0
V2   0.0  0.00  0.7  0.00  0.65  0.0
V4   0.0  0.00  0.0  0.75  0.00  0.0
IT   0.0  0.00  0.0  0.00  0.00  0.0
MT   0.0  0.00  0.0  0.00  0.00  0.7
LIP  0.0  0.00  0.0  0.00  0.00  0.0

MOUSE connectivity:
      V1    V2   V4    IT    MT  LIP
V1   0.0  0.75  0.3  0.00  0.35  0.0
V2   0.0  0.00  0.6  0.00  0.55  0.0
V4   0.0  0.00  0.0  0.65  0.00  0.0
IT   0.0  0.00  0.0  0.00  0.00  0.0
MT   0.0  0.00  0.0  0.00  0.00  0.6
LIP  0.0  0.00  0.0  0.00  0.00  0.0
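
The matrices above are meant to be purely feedforward: rows are source regions, columns are target regions, and every nonzero entry sits above the diagonal. A minimal sketch of how this could be checked is given below. The helper names `is_feedforward` and `list_edges` are hypothetical and not part of the notebook, and the mouse matrix is copied by hand from the printed output.

```python
import numpy as np

regions = ['V1', 'V2', 'V4', 'IT', 'MT', 'LIP']

# Mouse matrix, copied from the printed output (rows = source, columns = target)
mouse = np.zeros((6, 6))
mouse[0, 1] = 0.75  # V1 -> V2
mouse[0, 2] = 0.30  # V1 -> V4 shortcut
mouse[0, 4] = 0.35  # V1 -> MT shortcut
mouse[1, 2] = 0.60  # V2 -> V4
mouse[1, 4] = 0.55  # V2 -> MT
mouse[2, 3] = 0.65  # V4 -> IT
mouse[4, 5] = 0.60  # MT -> LIP


def is_feedforward(conn):
    """Return True if every nonzero entry lies strictly above the diagonal."""
    return bool(np.allclose(conn, np.triu(conn, k=1)))


def list_edges(conn, names):
    """Return (source, target, strength) for every nonzero connection."""
    rows, cols = np.nonzero(conn)
    return [(names[i], names[j], float(conn[i, j])) for i, j in zip(rows, cols)]


edges = list_edges(mouse, regions)
print("feedforward:", is_feedforward(mouse))
for src, dst, w in edges:
    print(f"{src} -> {dst}: {w:.2f}")
```

With this region ordering, an upper-triangular matrix means there are no feedback or recurrent connections, which matches the one-directional streams described in the comments.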


In [None]:
pathway_data = [] # list to hold pathway strengths for each species in a format suitable for a DataFrame

for species in ['human', 'marmoset', 'mouse']:
    conn = species_data[species]
    # Region indices: V1=0, V2=1, V4=2, IT=3, MT=4, LIP=5    
    pathway_data.append({
        'species': species,
        'V1_to_V2': conn[0, 1],
        'V2_to_V4': conn[1, 2],
        'V4_to_IT': conn[2, 3],
        'V2_to_MT': conn[1, 4],
        'MT_to_LIP': conn[4, 5],
    })

pathway_df = pd.DataFrame(pathway_data)
print(pathway_df)

    species  V1_to_V2  V2_to_V4  V4_to_IT  V2_to_MT  MT_to_LIP
0     human      0.90      0.75      0.80      0.70       0.75
1  marmoset      0.85      0.70      0.75      0.65       0.70
2     mouse      0.75      0.60      0.65      0.55       0.60
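
Because the CNN in the next section multiplies activations by these strengths one after another along each stream, one rough way to compare species is the product of the strengths along a path. The sketch below computes that product from the values in the table above. The names `stream_gain`, `VENTRAL` and `DORSAL` are hypothetical, and the product ignores the learned convolution weights and biases, so it is only a summary of the synthetic inputs and not a prediction of model behaviour.

```python
# Pathway strengths copied from the table above
pathway_strengths = {
    'human':    {'V1_to_V2': 0.90, 'V2_to_V4': 0.75, 'V4_to_IT': 0.80,
                 'V2_to_MT': 0.70, 'MT_to_LIP': 0.75},
    'marmoset': {'V1_to_V2': 0.85, 'V2_to_V4': 0.70, 'V4_to_IT': 0.75,
                 'V2_to_MT': 0.65, 'MT_to_LIP': 0.70},
    'mouse':    {'V1_to_V2': 0.75, 'V2_to_V4': 0.60, 'V4_to_IT': 0.65,
                 'V2_to_MT': 0.55, 'MT_to_LIP': 0.60},
}

VENTRAL = ('V1_to_V2', 'V2_to_V4', 'V4_to_IT')   # V1 -> V2 -> V4 -> IT
DORSAL = ('V1_to_V2', 'V2_to_MT', 'MT_to_LIP')   # V1 -> V2 -> MT -> LIP


def stream_gain(strengths, steps):
    """Multiply the connection strengths along one stream."""
    gain = 1.0
    for step in steps:
        gain *= strengths[step]
    return gain


gains = {
    species: {
        'ventral': stream_gain(strengths, VENTRAL),
        'dorsal': stream_gain(strengths, DORSAL),
    }
    for species, strengths in pathway_strengths.items()
}

for species, g in gains.items():
    print(f"{species}: ventral={g['ventral']:.4f}, dorsal={g['dorsal']:.4f}")
```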


# CNN model definition

In [47]:
import torch
import torch.nn as nn

class SpeciesVisualCNN(nn.Module):

    def __init__(self, species_name, connectivity):
        super().__init__()
        self.species = species_name
        
        # weights between areas are based on the connectivity strength between regions (from previous cell)
        self.w_v1_v2 = connectivity[0, 1]
        self.w_v2_v4 = connectivity[1, 2] 
        self.w_v4_it = connectivity[2, 3]
        self.w_v2_mt = connectivity[1, 4]
        self.w_mt_lip = connectivity[4, 5]
        
        # the model shares two early convolutional layers (V1 and V2), adapted for the FashionMNIST dataset (1 input channel, 28x28 images)
        self.v1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)  # 28x28 -> 28x28
        self.v2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)  # spatial size unchanged (pooling happens in forward)
        self.pool = nn.MaxPool2d(2, 2)  # applied after V1, after V2, and after the last layer of each stream (IT, LIP)
        
        # then it splits into two parallel paths:
        # ventral stream, namely V4 and IT 
        self.v4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.it = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
         
        # at the end we have a fully connected layer for classification
        self.fc_ventral = nn.Linear(256 * 3 * 3, 10)  # 10 FashionMNIST classes, 256 is the number of features in the last layer and 3*3 is the size after all pooling
        
        # dorsal stream: MT and LIP (we include it, but we do not use it for FashionMNIST, as it is not a spatial task)
        self.mt = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.lip = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
         
        # at the end we have a fully connected layer for regression
        self.fc_dorsal = nn.Linear(256 * 3 * 3, 2)  # the output is just (x, y) center position, 256 is the number of features in the last layer and 3*3 is the size after pooling
        
    def forward(self, x):
        v1_out = torch.relu(self.v1(x))
        v1_out = self.pool(v1_out) * self.w_v1_v2  # reduce spatial dimensions to 14x14 and weight by V1->V2 strength

        v2_out = torch.relu(self.v2(v1_out))
        v2_out = self.pool(v2_out)  # reduce spatial dimensions to 7x7
        
        v4_out = torch.relu(self.v4(v2_out)) * self.w_v2_v4  # weight by V2->V4 strength
        it_out = torch.relu(self.it(v4_out)) * self.w_v4_it  # weight by V4->IT strength
        it_out = self.pool(it_out)  # 3x3
        it_flat = it_out.view(it_out.size(0), -1)  # flatten
        
        object_class = self.fc_ventral(it_flat) # final classification output

        mt_out = torch.relu(self.mt(v2_out)) * self.w_v2_mt # weight by V2->MT strength
        lip_out = torch.relu(self.lip(mt_out)) * self.w_mt_lip # weight by MT->LIP strength
        lip_out = self.pool(lip_out)  # 3x3 
        lip_flat = lip_out.view(lip_out.size(0), -1) # Flatten
        
        spatial_info = self.fc_dorsal(lip_flat) # final regression output
        
        return object_class, spatial_info # return both outputs so we can know what is on the image (the ventral stream) and where it is located (the dorsal stream)
    
# we create a model for each species
models = {}
for species in ['human', 'marmoset', 'mouse']:
    models[species] = SpeciesVisualCNN(species, species_data[species])
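
The `256 * 3 * 3` input size of the two fully connected layers follows from the layer settings above: the 3x3 convolutions with padding 1 keep the spatial size, and each 2x2 max pooling halves it, rounding down. The sketch below retraces that arithmetic in plain Python. The helpers `conv_out` and `pool_out` are hypothetical names that implement the standard output-size formulas, assuming the default floor rounding of `nn.MaxPool2d`.

```python
def conv_out(size, kernel=3, stride=1, padding=1):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel=2, stride=2):
    """Spatial output size of max pooling (floor rounding) along one dimension."""
    return (size - kernel) // stride + 1


size = 28                          # FashionMNIST input: 28x28
size = pool_out(conv_out(size))    # V1 conv + pool
after_v1 = size
size = pool_out(conv_out(size))    # V2 conv + pool
after_v2 = size
size = conv_out(size)              # V4 (or MT) conv, no pooling
size = pool_out(conv_out(size))    # IT (or LIP) conv + pool
after_stream = size

flat_features = 256 * after_stream * after_stream
print(after_v1, after_v2, after_stream, flat_features)
```

The same arithmetic applies to both streams, since MT and LIP use the same kernel size, stride and padding as V4 and IT.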

# Training the model on the Fashion-MNIST dataset

## Data preprocessing and dataset generation

In [None]:
# now we need to train the model so that it can actually learn to recognize objects and their locations
# for that we use the FashionMNIST dataset from Kaggle (https://www.kaggle.com/datasets/zalando-research/fashionmnist?resource=download)
# first we prepare the data: we create a PyTorch Dataset and DataLoader for model training, reshape each flattened row back to a 28x28 image, and normalize the pixel values to [0, 1] (for more stable training)

import torch
from torch.utils.data import Dataset, DataLoader

train_df = pd.read_csv('Fashion_MNIST/fashion-mnist_train.csv')
test_df = pd.read_csv('Fashion_MNIST/fashion-mnist_test.csv')

class FashionMNISTDataset(Dataset):
    def __init__(self, df):
        self.labels = df['label'].values
        self.pixels = df.drop('label', axis=1).values  # 784 pixel values
        
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, idx):
        # reshape the flattened row of 784 pixels (28*28) back into a 28x28 image
        pixels = self.pixels[idx].reshape(28, 28)
        # normalize to [0, 1] for more stable training
        pixels = pixels / 255.0
        # add a channel dimension -> (1, 28, 28); the images are grayscale, and PyTorch expects batches shaped (batch_size, channels, height, width)
        image = torch.tensor(pixels, dtype=torch.float32).unsqueeze(0)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        
        return image, label


train_dataset = FashionMNISTDataset(train_df)
test_dataset = FashionMNISTDataset(test_df)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) # shuffle for training
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False) 

print("\nDatasets created:")
print(f"  Train: {len(train_dataset)} samples")
print(f"  Test: {len(test_dataset)} samples")

print(train_dataset[0][0].shape)  # the image tensor (1, 28, 28)
print(train_dataset[0][1].item()) # the label (0-9)


Datasets created:
  Train: 60000 samples
  Test: 10000 samples
torch.Size([1, 28, 28])
2
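
The per-sample reshape and normalization in `__getitem__` can also be illustrated on a whole array at once. The sketch below uses a small random stand-in for the pixel columns of the CSV (the array `flat_pixels` is synthetic, not the real data) and shows the resulting shape, dtype and value range. Doing this once for the full array, for example in `__init__`, would be one possible design choice, although the notebook does it per item.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in for df.drop('label', axis=1).values:
# 4 flattened images with integer pixel values in 0..255
flat_pixels = rng.integers(0, 256, size=(4, 784))

# Reshape every row to (channels, height, width) and scale to [0, 1]
images = flat_pixels.reshape(-1, 1, 28, 28).astype(np.float32) / 255.0

print(images.shape, images.dtype)
print(float(images.min()), float(images.max()))
```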


## Actual training loop

In [46]:
import torch.optim as optim
from tqdm import tqdm

def train_model(model, train_loader, test_loader, epochs=10, learning_rate=0.001):
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    
    criterion = nn.CrossEntropyLoss() # cross-entropy loss for multi-class classification
    optimizer = optim.Adam(model.parameters(), lr=learning_rate) # Adam optimizer
    
    # dictionary to track metrics
    history = {
        'train_loss': [], 'train_acc': [],
        'test_loss': [], 'test_acc': []
    }
    
    for epoch in range(epochs):
        model.train() # set model to training mode
        train_loss = 0.0 # cumulative loss
        correct = 0 # number of correct predictions
        total = 0 # total number of samples
        
        for images, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}'): # iterate over batches
            images, labels = images.to(device), labels.to(device) # move to device
            
            # Forward pass
            obj_out, spatial_out = model(images)  # obj_out is the ventral stream output (object classification); spatial_out (dorsal stream) is ignored here
            loss = criterion(obj_out, labels) # compute loss
            
            # Backward pass
            optimizer.zero_grad() # reset the gradients to zero before backpropagation so that they do not accumulate 
            loss.backward() # backpropagate the loss
            optimizer.step() # update the weights
            
            train_loss += loss.item() # accumulate loss
            _, predicted = obj_out.max(1) # index of the largest logit = predicted class of each sample
            correct += (predicted == labels).sum().item() # count correct predictions
            total += labels.size(0) # accumulate total samples
        
        train_loss /= len(train_loader) # average loss over batches
        train_acc = 100.0 * correct / total # training accuracy
        
        model.eval() # set model to evaluation mode
        test_loss = 0.0 # cumulative test loss
        correct = 0 # number of correct predictions
        total = 0 # total number of samples
        
        with torch.no_grad(): # we don't need to compute gradients during evaluation
            for images, labels in test_loader: # we iterate over test batches
                images, labels = images.to(device), labels.to(device) # move to device
                obj_out, _ = model(images) # forward pass, ignore spatial output
                loss = criterion(obj_out, labels) # compute loss
                 
                test_loss += loss.item() # accumulate loss
                _, predicted = obj_out.max(1) # index of the largest logit = predicted class
                correct += (predicted == labels).sum().item() # count correct predictions
                total += labels.size(0) # accumulate total samples
        
        test_loss /= len(test_loader) # average loss over batches
        test_acc = 100.0 * correct / total # test accuracy
        
        # we append to the history dictionary
        history['train_loss'].append(train_loss) 
        history['train_acc'].append(train_acc)
        history['test_loss'].append(test_loss)
        history['test_acc'].append(test_acc)
        
        print(f'Epoch {epoch+1}: Train Loss={train_loss:.4f}, Train Acc={train_acc:.2f}%, Test Loss={test_loss:.4f}, Test Acc={test_acc:.2f}%')
    
    return history


results = {}
for species in ['human', 'marmoset', 'mouse']: # train each species model with a learning rate 0.001
    print(f"Training the {species} model")
    model = models[species]
    history = train_model(model, train_loader, test_loader, epochs=5, learning_rate=0.001)
    results[species] = history

Training the human model


Epoch 1/5: 100%|██████████| 938/938 [09:05<00:00,  1.72it/s]


Epoch 1: Train Loss=0.3088, Train Acc=88.76%, Test Loss=0.2805, Test Acc=89.88%


Epoch 2/5: 100%|██████████| 938/938 [08:48<00:00,  1.77it/s]


Epoch 2: Train Loss=0.2547, Train Acc=90.75%, Test Loss=0.2373, Test Acc=91.06%


Epoch 3/5: 100%|██████████| 938/938 [08:50<00:00,  1.77it/s]


Epoch 3: Train Loss=0.2243, Train Acc=91.75%, Test Loss=0.2169, Test Acc=91.89%


Epoch 4/5: 100%|██████████| 938/938 [09:30<00:00,  1.64it/s]


Epoch 4: Train Loss=0.1972, Train Acc=92.73%, Test Loss=0.2116, Test Acc=92.04%


Epoch 5/5: 100%|██████████| 938/938 [09:02<00:00,  1.73it/s]


Epoch 5: Train Loss=0.1732, Train Acc=93.59%, Test Loss=0.2062, Test Acc=92.50%
Training the marmoset model


Epoch 1/5: 100%|██████████| 938/938 [09:07<00:00,  1.71it/s]


Epoch 1: Train Loss=0.5136, Train Acc=81.19%, Test Loss=0.3336, Test Acc=88.20%


Epoch 2/5: 100%|██████████| 938/938 [08:44<00:00,  1.79it/s]


Epoch 2: Train Loss=0.3101, Train Acc=88.70%, Test Loss=0.2607, Test Acc=90.33%


Epoch 3/5: 100%|██████████| 938/938 [09:26<00:00,  1.66it/s]


Epoch 3: Train Loss=0.2628, Train Acc=90.51%, Test Loss=0.2417, Test Acc=90.93%


Epoch 4/5: 100%|██████████| 938/938 [09:08<00:00,  1.71it/s]


Epoch 4: Train Loss=0.2309, Train Acc=91.50%, Test Loss=0.2434, Test Acc=90.79%


Epoch 5/5: 100%|██████████| 938/938 [09:26<00:00,  1.66it/s]


Epoch 5: Train Loss=0.2065, Train Acc=92.41%, Test Loss=0.2244, Test Acc=91.52%
Training the mouse model


Epoch 1/5: 100%|██████████| 938/938 [09:20<00:00,  1.67it/s]


Epoch 1: Train Loss=0.5065, Train Acc=81.39%, Test Loss=0.3511, Test Acc=87.16%


Epoch 2/5: 100%|██████████| 938/938 [07:55<00:00,  1.97it/s]


Epoch 2: Train Loss=0.3097, Train Acc=88.72%, Test Loss=0.2704, Test Acc=89.88%


Epoch 3/5: 100%|██████████| 938/938 [09:22<00:00,  1.67it/s]


Epoch 3: Train Loss=0.2647, Train Acc=90.44%, Test Loss=0.2394, Test Acc=91.31%


Epoch 4/5: 100%|██████████| 938/938 [09:43<00:00,  1.61it/s]


Epoch 4: Train Loss=0.2366, Train Acc=91.42%, Test Loss=0.2334, Test Acc=91.31%


Epoch 5/5: 100%|██████████| 938/938 [08:15<00:00,  1.89it/s]


Epoch 5: Train Loss=0.2106, Train Acc=92.28%, Test Loss=0.2194, Test Acc=91.94%
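
To compare the three runs at a glance, the per-epoch test accuracies can be summarized as below. This is a minimal sketch: the numbers are copied by hand from the logs above (in the notebook they would come from `results[species]['test_acc']`), and the names `test_acc`, `summary` and `ranking` are hypothetical. Each species was trained once here, so the small gaps between them should be read with caution.

```python
# Test accuracy (%) per epoch, copied from the training logs above
test_acc = {
    'human':    [89.88, 91.06, 91.89, 92.04, 92.50],
    'marmoset': [88.20, 90.33, 90.93, 90.79, 91.52],
    'mouse':    [87.16, 89.88, 91.31, 91.31, 91.94],
}

# Final accuracy, best accuracy, and the (1-based) epoch where the best was first reached
summary = {
    species: {
        'final': accs[-1],
        'best': max(accs),
        'best_epoch': accs.index(max(accs)) + 1,
    }
    for species, accs in test_acc.items()
}

# Species ordered by final test accuracy, highest first
ranking = sorted(test_acc, key=lambda s: test_acc[s][-1], reverse=True)

for species in ranking:
    s = summary[species]
    print(f"{species}: final={s['final']:.2f}%, best={s['best']:.2f}% (epoch {s['best_epoch']})")
```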
