In [2]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm, trange
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Import functions and classes from the libraries
from fundamental_library import *
from digital_twin_library import ConvModel, get_correlations, PoissonLoss

# Check for GPU availability
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

images_path = '/project/subiculum/data/images_uint8.npy'
v1_responses_path = '/project/subiculum/data/V1_Data.mat'

# Load Images to np.array
images=np.load(images_path)
# Load responses and preprocess them

v1_responses,_,_ = load_mat_file(v1_responses_path)
v1_responses = preprocess_responses(v1_responses)

#define Dataset
train_ratio=0.6
val_ratio=0.2

dataset = NeuralDataset(images, v1_responses)
total_size = len(dataset)
train_size = int(train_ratio * total_size)
val_size = int(val_ratio * total_size)
test_size = total_size - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Initialize the model
model = ConvModel(
    layers=5, 
    input_kern=11,
    hidden_kern=5, 
    hidden_channels=32, 
    output_dim=13, 
    spatial_scale=0.05, 
    std_scale=0.2)

model = model.to(device)

# Define loss function and optimizer
poisson_loss = PoissonLoss()
gamma = 1e-4
loss_fn = lambda outputs, targets: poisson_loss(outputs, targets) + gamma * model.regularizer()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Define the learning rate schedule
lr_scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=3, verbose=True)

# Define early stopping criteria
early_stopping_patience = 5
early_stopping_counter = 0
best_val_loss = float('-inf')

# Define the number of epochs
epochs = 100

def my_train_epoch(model, loader, optimizer, loss_fn):
    model.train()
    train_loss = 0.0
    for images, responses in loader:
        images, responses = images.to(device), responses.to(device)  # Move data to the appropriate device
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, responses)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    return train_loss / len(loader)

for epoch in trange(epochs):
    # Training loop
    loss = my_train_epoch(model, train_loader, optimizer, loss_fn)
    
    # Validation loop
    with torch.no_grad():
        val_corrs = get_correlations(model, val_loader, device)
    validation_correlation = val_corrs.mean()
    
    # Update learning rate schedule
    lr_scheduler.step(validation_correlation)
    
    # Print training and validation losses
    print(f'Epoch [{epoch+1}/{epochs}], validation correlation: {validation_correlation:.4f}')
    
    # Check for early stopping
    if validation_correlation > best_val_loss:
        best_val_loss = validation_correlation
        early_stopping_counter = 0
    else:
        early_stopping_counter += 1
        if early_stopping_counter >= early_stopping_patience:
            print('Early stopping triggered!')
            break

# Evaluation
def evaluate_model(model, test_loader, device):
    model.eval()
    test_loss = 0.0
    all_preds = []
    all_resps = []
    loss_fn = nn.PoissonNLLLoss(log_input=False)
    
    with torch.no_grad():
        for images, responses in test_loader:
            images, responses = images.to(device), responses.to(device)  # Move data to the appropriate device
            outputs = model(images)
            loss = loss_fn(outputs, responses)
            test_loss += loss.item()
            all_preds.append(outputs.cpu().numpy())
            all_resps.append(responses.cpu().numpy())
    
    test_loss /= len(test_loader)
    print(f'Test Loss: {test_loss:.4f}')
    
    all_preds = np.vstack(all_preds)
    all_resps = np.vstack(all_resps)
    correlation = get_correlations(model, test_loader, device)
    print(f'Test Correlation: {correlation.mean():.4f}')  # Print mean correlation
    
    return all_preds, all_resps

# Evaluate the model
evaluate_model(model, test_loader, device)


Using device: cuda:0


HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))

Epoch [1/100], validation correlation: 0.0052
Epoch [2/100], validation correlation: 0.0091
Epoch [3/100], validation correlation: 0.0344
Epoch [4/100], validation correlation: 0.0109
Epoch [5/100], validation correlation: 0.0309
Epoch [6/100], validation correlation: 0.0310
Epoch     7: reducing learning rate of group 0 to 1.0000e-04.
Epoch [7/100], validation correlation: 0.0293
Epoch [8/100], validation correlation: 0.0424
Epoch [9/100], validation correlation: 0.0449
Epoch [10/100], validation correlation: 0.0455
Epoch [11/100], validation correlation: 0.0456
Epoch [12/100], validation correlation: 0.0463
Epoch [13/100], validation correlation: 0.0452
Epoch [14/100], validation correlation: 0.0415
Epoch [15/100], validation correlation: 0.0448
Epoch    16: reducing learning rate of group 0 to 1.0000e-05.
Epoch [16/100], validation correlation: 0.0434
Epoch [17/100], validation correlation: 0.0459
Early stopping triggered!

Test Loss: 0.2720
Test Correlation: 0.0853


(array([[0.01008549, 0.03810545, 0.011841  , ..., 0.15358616, 0.03044046,
         0.04027349],
        [0.01106965, 0.03842531, 0.0137359 , ..., 0.16363904, 0.03013641,
         0.04329807],
        [0.01355773, 0.04494735, 0.02007333, ..., 0.24951576, 0.02880695,
         0.05608019],
        ...,
        [0.01234926, 0.04182239, 0.01383986, ..., 0.16693483, 0.03733826,
         0.04467087],
        [0.01547256, 0.0563703 , 0.02432122, ..., 0.30664867, 0.03362843,
         0.06763098],
        [0.04139634, 0.10930552, 0.04767177, ..., 0.24041812, 0.07216546,
         0.10983258]], dtype=float32),
 array([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 1., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]], dtype=float32))