# Intro to 2D CNNs
A description and demo notebook to go through creating a 2D CNN and using it with dummy data

## 0. Setting up dummy data

Before learning how to use the cores, let's create a dummy data images. This data will be similar to a batch of images.

Throughout the notebook we will refer to the elements of this shape in the following manner:

[1] is the number of channels (can be input, hidden, output)

[144] is the height of image or feature maps

[256] is the height of image or feature maps

[32] is the batch size, which is not as relevant for understanding the material in this notebook.

In [1]:
# To access to neuropixel_predictor
import sys
import os
sys.path.append('../')

# Basic imports
import warnings
import random

# Essential imports
import numpy as np
import torch

In [2]:
warnings.filterwarnings("ignore", category=UserWarning)
device = "cuda" if torch.cuda.is_available() else "cpu"
random_seed = 42

## 1. Setting Up the data

In [3]:
IMAGE_WIDTH = 36
IMAGE_HEIGHT = 22

images = torch.ones(32, 1, IMAGE_WIDTH, IMAGE_HEIGHT)


## 2. Using Stacked 2D Core

In [4]:
stacked2dcore_config = {
    # core args
    'input_channels': 1,
    'input_kern': 7,
    'hidden_kern': 5,
    'hidden_channels': 64,
    'layers': 3,
    'stack': -1,
    'pad_input': True,
    'batch_norm': False
}

In [5]:
# from neuralpredictors.layers.cores import Stacked2dCore
# from neuralpredictors.utils import get_module_output

# core = Stacked2dCore(input_channels=1,
#                      hidden_channels=64,
#                      input_kern=9,
#                      hidden_kern=7)

# in_shape_dict = {k: get_module_output(core, in_shape)[1:] for k, in_shape in in_shapes_dict.items()}



from neuropixel_predictor.layers.cores import Stacked2dCore 

stacked2d_core = Stacked2dCore(**stacked2dcore_config)
stacked2d_core

Stacked2dCore(
  (_input_weights_regularizer): LaplaceL2(
    (laplace): Laplace()
  )
  (features): Sequential(
    (layer0): Sequential(
      (conv): Conv2d(1, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
      (nonlin): AdaptiveELU()
    )
    (layer1): Sequential(
      (conv): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (nonlin): AdaptiveELU()
    )
    (layer2): Sequential(
      (conv): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (nonlin): AdaptiveELU()
    )
  )
) [Stacked2dCore regularizers: gamma_hidden = 0|gamma_input = 0.0|skip = 0]

In [6]:
# Trying out on images
stacked2dcore_out = stacked2d_core(images)
print(stacked2dcore_out.shape)

torch.Size([32, 64, 36, 22])


## 3. Using Readout to attach Factorized Readout

In [7]:
from neuropixel_predictor.layers.readouts import FullFactorized2d, MultiReadoutBase

In [8]:
in_shapes_dict = {
    '2023-03-15_11-05': torch.Size([64, IMAGE_WIDTH, IMAGE_HEIGHT]),
    # '2023-03-15_15-23': torch.Size([64, 150, 200])
} 

n_neurons_dict = {
    '2023-03-15_11-05': 453,
    # '2023-03-15_15-23': 200
}

In [9]:
factorized_readout = MultiReadoutBase(
    in_shape_dict=in_shapes_dict,
    n_neurons_dict=n_neurons_dict,
    base_readout=FullFactorized2d,
    bias=True,
)

## 4. Invoke core and readout (dummy data)

In [10]:
# Forward pass with core
core_output = stacked2d_core(images)

# Forward pass with readout
readout_output_sample = factorized_readout(core_output, data_key="2023-03-15_11-05")

readout_output_sample

tensor([[ 8.0406e-04, -6.9038e-04, -2.6662e-03,  ...,  2.7474e-04,
          2.4679e-05,  1.2928e-05],
        [ 8.0406e-04, -6.9038e-04, -2.6662e-03,  ...,  2.7474e-04,
          2.4679e-05,  1.2928e-05],
        [ 8.0406e-04, -6.9038e-04, -2.6662e-03,  ...,  2.7474e-04,
          2.4679e-05,  1.2928e-05],
        ...,
        [ 8.0406e-04, -6.9038e-04, -2.6662e-03,  ...,  2.7474e-04,
          2.4679e-05,  1.2928e-05],
        [ 8.0406e-04, -6.9038e-04, -2.6662e-03,  ...,  2.7474e-04,
          2.4679e-05,  1.2928e-05],
        [ 8.0406e-04, -6.9038e-04, -2.6662e-03,  ...,  2.7474e-04,
          2.4679e-05,  1.2928e-05]], grad_fn=<AddBackward0>)

## 5. Testing with test data (from sinzlab)
Utility functions from https://github.com/sinzlab/sensorium

In [11]:
# %%capture 
# !pip install git+https://github.com/sinzlab/sensorium.git

**Defining some helper functions to extract the data**

In [12]:
# # The following are minimal adaptations of three utility functions found in nnfabrik that we need to initialise
# # the core and readouts later on.

# def get_data(dataset_fn, dataset_config):
#     """
#     See https://github.com/sinzlab/nnfabrik/blob/5b6e7379cb5724a787cdd482ee987b8bc0dfacf3/nnfabrik/builder.py#L87
#     for the original implementation and documentation if you are interested.
#     """
#     return dataset_fn(**dataset_config)

# def get_dims_for_loader_dict(dataloaders):
#     """
#     See https://github.com/sinzlab/nnfabrik/blob/5b6e7379cb5724a787cdd482ee987b8bc0dfacf3/nnfabrik/utility/nn_helpers.py#L39
#     for the original implementation and docstring if you are interested.
#     """
    
#     def get_io_dims(data_loader):
#         items = next(iter(data_loader))
#         if hasattr(items, "_asdict"):  # if it's a named tuple
#             items = items._asdict()

#         if hasattr(items, "items"):  # if dict like
#             return {k: v.shape for k, v in items.items()}
#         else:
#             return (v.shape for v in items)

#     return {k: get_io_dims(v) for k, v in dataloaders.items()}


# def set_random_seed(seed: int, deterministic: bool = True):
#     """
#     See https://github.com/sinzlab/nnfabrik/blob/5b6e7379cb5724a787cdd482ee987b8bc0dfacf3/nnfabrik/utility/nn_helpers.py#L53
#     for the original implementation and docstring if you are intereseted.
#     """
#     random.seed(seed)
#     np.random.seed(seed)
#     if deterministic:
#         torch.backends.cudnn.benchmark = False
#         torch.backends.cudnn.deterministic = True
#     torch.manual_seed(seed)  # this sets both CPU and CUDA seeds for PyTorch

**Loading the data**

In [13]:
# ## Load the data: you can modify this if you have stored it in another location
# from sensorium.datasets import static_loaders

# DATA_PATH = '/Users/tarek/Documents/UNI/Lab Rotations/Kremkow/Data/Test/'

# filenames = [
#     DATA_PATH + 'static21067-10-18-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip', 
#     DATA_PATH + 'static22846-10-16-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip'
#     ]

# dataset_config = {'paths': filenames,
#                  'normalize': True,
#                  'include_behavior': False,
#                  'include_eye_position': True,
#                  'batch_size': 32,
#                  'scale':1,
#                  'cuda': True if device == 'cuda' else False,
#                  }

# dataloaders = get_data(dataset_fn=static_loaders, dataset_config=dataset_config)

In [14]:
# datapoint = list(dataloaders['train']['21067-10-18'])[0]
# images = datapoint[0]
# responses = datapoint[1]
# images.shape, responses.shape

**Process the data step by step**

In [15]:
# # We only need the train dataloaders to extract the session keys (could also use test or validation for this)
# train_dataloaders = dataloaders["train"]

# # Obtain the named tuple fields from the first entry of the first dataloader in the dictionary
# example_batch = next(iter(list(train_dataloaders.values())[0]))
# in_name, out_name = (
#     list(example_batch.keys())[:2] if isinstance(example_batch, dict) else example_batch._fields[:2]
# )

# session_shape_dict = get_dims_for_loader_dict(train_dataloaders)
# input_channels = [v[in_name][1] for v in session_shape_dict.values()]

# core_input_channels = (
#     list(input_channels.values())[0]
#     if isinstance(input_channels, dict)
#     else input_channels[0]
# )

**Core: Define Config Params**

In [16]:
# stacked2dcore_config = {
#     # core args
#     'input_kern': 7,
#     'hidden_kern': 5,
#     'hidden_channels': 64,
#     'layers': 3,
#     'stack': -1,
#     'pad_input': True,
#     'gamma_input': 6.3831
# }

**Core: Setting up**

In [17]:
# set_random_seed(random_seed)
# core = Stacked2dCore(
#     input_channels=core_input_channels,
#     **stacked2dcore_config,
# )
# core

**Core: Example forward pass**

In [18]:
# print(f"Sample batch shape: {example_batch.images.shape} (batch size, in_channels, in_height, in_width)")

# with torch.no_grad():
#     core_output = core(example_batch.images)
    
# print(f"Core output shape: {core_output.shape} (batch_size, out_channels, out_height, out_width)")


**Readout: Test the factorized**

In [19]:
# with torch.no_grad():
#     readout_output_sample = factorized_readout(core_output, data_key="21067-10-18")


# print(f"Readout output shape: {readout_output_sample.shape} (batch_size, n_neurons)")

--------------
--------------
--------------
## 6. Testing with our data

In [20]:
from torch.utils.data import DataLoader

TRAINING_DATA_DIR = '/Users/tarek/Documents/UNI/Lab Rotations/Kremkow/Data/Training'
BATCH_SIZE = 32

def load_dataset(date):
    # 1. Load the images and responses
    training_images = np.load(os.path.join(TRAINING_DATA_DIR, "training_images_{}.npy".format(date)))
    training_responses = np.load(os.path.join(TRAINING_DATA_DIR, "training_responses_{}.npy".format(date)))
    
    test_images = np.load(os.path.join(TRAINING_DATA_DIR, "test_images_{}.npy".format(date)))
    test_responses = np.load(os.path.join(TRAINING_DATA_DIR, "test_responses_{}.npy".format(date)))
    
    training_images = training_images.reshape(training_images.shape[0], 1, IMAGE_WIDTH, IMAGE_HEIGHT)
    test_images = test_images.reshape(test_images.shape[0], 1, IMAGE_WIDTH, IMAGE_HEIGHT)
    
    # 2. Convert to tensors and typecast to float
    training_images = torch.from_numpy(training_images).float()
    test_images = torch.from_numpy(test_images).float()
    
    training_responses = torch.from_numpy(training_responses).float()
    test_responses = torch.from_numpy(test_responses).float()
    
    # 3. Zip training and test
    training_data = list(zip(training_images, training_responses))
    test_data = list(zip(test_images, test_responses))

    # 4. Convert to DataLoader
    train_dataloader = DataLoader(training_data, batch_size=BATCH_SIZE, shuffle=False)
    test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

    return train_dataloader, test_dataloader

# 1. Load data with date
train_dataloader_1, test_dataloader_1 = load_dataset("2022-12-20_15-08")
train_dataloader_2, test_dataloader_2 = load_dataset("2023-03-15_11-05")
train_dataloader_3, test_dataloader_3 = load_dataset("2023-03-15_15-23")

train_dataloaders = {
    "2022-12-20_15-08": train_dataloader_1,
    "2023-03-15_11-05": train_dataloader_2,
    "2023-03-15_15-23": train_dataloader_3,
}

test_dataloaders = {
    "2022-12-20_15-08": test_dataloader_1,
    "2023-03-15_11-05": test_dataloader_2,
    "2023-03-15_15-23": test_dataloader_3,
}


# 2. Load sample batch
images_batch, responses_batch = next(iter(train_dataloaders["2022-12-20_15-08"]))

# 5. Validate shape and types
images_batch.shape, responses_batch.shape, images_batch.type()

(torch.Size([32, 1, 36, 22]), torch.Size([32, 466]), 'torch.FloatTensor')

**Core: Define Config Params**

In [21]:
stacked2dcore_config = {
    # core args
    'input_channels': 1,
    'input_kern': 7,
    'hidden_kern': 5,
    'hidden_channels': 64,
    'layers': 3,
    'stack': -1,
    'pad_input': True,
    'batch_norm': False
}

**Core: Setting up**

In [22]:
core = Stacked2dCore(
    **stacked2dcore_config,
)
core

Stacked2dCore(
  (_input_weights_regularizer): LaplaceL2(
    (laplace): Laplace()
  )
  (features): Sequential(
    (layer0): Sequential(
      (conv): Conv2d(1, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
      (nonlin): AdaptiveELU()
    )
    (layer1): Sequential(
      (conv): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (nonlin): AdaptiveELU()
    )
    (layer2): Sequential(
      (conv): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (nonlin): AdaptiveELU()
    )
  )
) [Stacked2dCore regularizers: gamma_hidden = 0|gamma_input = 0.0|skip = 0]

**Core: Example forward pass**

In [23]:
print(f"Sample batch shape: {images_batch.shape} (batch size, in_channels, in_height, in_width)")

with torch.no_grad():
    core_output = core(images_batch)
    
print(f"Core output shape: {core_output.shape} (batch_size, out_channels, out_height, out_width)")


Sample batch shape: torch.Size([32, 1, 36, 22]) (batch size, in_channels, in_height, in_width)
Core output shape: torch.Size([32, 64, 36, 22]) (batch_size, out_channels, out_height, out_width)


**Readout: Setting up**

In [24]:
# Could also build the following programmatically
in_shapes_dict = {
    '2023-03-15_11-05': torch.Size([64, IMAGE_WIDTH, IMAGE_HEIGHT]),
    '2023-03-15_15-23': torch.Size([64, IMAGE_WIDTH, IMAGE_HEIGHT]),
    '2022-12-20_15-08': torch.Size([64, IMAGE_WIDTH, IMAGE_HEIGHT])
} 

n_neurons_dict = {
    '2023-03-15_11-05': 373,
    '2023-03-15_15-23': 388,
    '2022-12-20_15-08': 466
}

factorized_multi_readout = MultiReadoutBase(
    in_shape_dict=in_shapes_dict,
    n_neurons_dict=n_neurons_dict,
    base_readout=FullFactorized2d,
    bias=True
)

**Readout: Test the factorized**

In [25]:
with torch.no_grad():
    readout_output_sample = factorized_multi_readout(core_output, data_key="2023-03-15_11-05")


print(f"Readout output shape: {readout_output_sample.shape} (batch_size, n_neurons)")

Readout output shape: torch.Size([32, 373]) (batch_size, n_neurons)


-------------
## 7. Training the model
Adopted similar usage from https://github.com/sinzlab/nnsysident/blob/master/notebooks/tutorial_mouse_models.ipynb

In [26]:
from neuropixel_predictor.training.trainers import simplified_trainer
from neuropixel_predictor.layers.encoders import GeneralizedEncoderBase
import torch.nn as nn

poisson_loss = nn.PoissonNLLLoss(log_input=False, full=True)
mse_loss = nn.MSELoss()
device = 'mps'

# trainer_config_base = {"track_training": True,
#                        "device": device,
#                        "detach_core": False}

In [27]:
model = GeneralizedEncoderBase(
    core,
    factorized_multi_readout,
    elu=True
)

In [28]:
raise Error
# trained_model, training_history = simplified_trainer(
#     model=model,
#     train_loaders=train_dataloaders,
#     val_loaders=test_dataloaders,
#     loss_fn=poisson_loss,
#     device=device,
#     max_epochs=200,
#     lr=1e-4,
#     patience=10 # For the early stopping
# )

NameError: name 'Error' is not defined

**Plotting the training and test loss**

In [None]:
import matplotlib.pyplot as plt

def plot_training_history(history):
    """
    Plots the training and validation loss curves over epochs.
    Expects history = {'train_loss': [...], 'val_loss': [...]}.
    """
    epochs = range(1, len(history['train_loss']) + 1)

    plt.figure(figsize=(8, 5))
    plt.plot(epochs, history['train_loss'], label='Training Loss', linewidth=2)
    plt.plot(epochs, history['val_loss'], label='Validation Loss', linewidth=2)

    plt.title('Training and Validation Poisson Loss over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.tight_layout()
    plt.show()

In [None]:
plot_training_history(training_history)


-----------
# 8. Most Exciting Input

In [None]:
from neuropixel_predictor.mei.generate import generate_mei, plot_mei

device = 'mps'
steps = 500
neuron_idx = 55
mei_data_key = "2023-03-15_11-05"
image_shape = (1, 1, IMAGE_WIDTH, IMAGE_HEIGHT)

mei = generate_mei(
    model,
    mei_data_key,
    neuron_idx,
    image_shape,
    steps=steps,
    mode="cei",     # or "vei_plus", "vei_minus"
    device=device
)

plot_mei(mei, title="MEI - Neuron: {}, Steps: {}".format(neuron_idx, steps) )

-----------
# 9. Use MEI in the model

In [None]:
# 1. Fetch a stimulus image from the dataset
random_dataset_indx = torch.randint(len(train_dataloaders[mei_data_key]), (1,))
dataset_image = next(iter(train_dataloaders[mei_data_key]))[0].to(device)

# 2. Generate a completely random image
random_image = torch.randn(image_shape, device=device) * 0.1

# 3. Test dataset image, random image and MEI image and compare predictions
pred_random = model(random_image, data_key=mei_data_key)[0, 0, neuron_idx]
pred_dataset = model(dataset_image, data_key=mei_data_key)[0, 0, neuron_idx]
pred_mei = model(mei, data_key=mei_data_key)[0, 0, neuron_idx]

pred_random, pred_dataset, pred_mei

-----------
# 10. Sanity Checks

## 10.1 Baseline Poisson Loss

In [None]:
# 1. Reshape training responses and validation responses
training_responses = [datapoint[1] for datapoint in list(train_dataloaders["2022-12-20_15-08"])]
training_responses = torch.cat(training_responses, 0)

val_responses = [datapoint[1] for datapoint in list(test_dataloaders["2022-12-20_15-08"])]
val_responses = torch.cat(val_responses, 0)

# 2. Calculate mean
mean_rate = training_responses.mean(dim=0) # shape: (num_of_neurons)
mean_rate.shape

# 3. Repeat mean_rate for all validation samples
baseline_pred_val = mean_rate.expand(val_responses.shape[0], -1)

print("mean_rate of first neuron: ", mean_rate[0])
print("baseline_pred_val of first neuron (subset of 4): ", baseline_pred_val[:4, 0])

# 4. Define Poisson manually
def poisson_loss_manual(pred, target, eps=1e-8):
    return (pred - target * torch.log(pred + eps)).mean()

# 5. Compute baseline validation loss
manual_baseline_val_loss = poisson_loss_manual(baseline_pred_val, val_responses)
nn_baseline_val_loss = nn.PoissonNLLLoss(log_input=False, full=True)(baseline_pred_val, val_responses)

# print("Baseline validation Manual Poisson loss:", manual_baseline_val_loss.item())
print("Baseline validation NN Poisson loss:", nn_baseline_val_loss.item())


## 10.2 Overfit on small dataset

In [29]:
# 1. Create a subset of the first 100 samples
test_data_key = "2022-12-20_15-08"
subset_size = 100

small_training_images = np.load(os.path.join(TRAINING_DATA_DIR, "training_images_{}.npy".format(test_data_key)))[:subset_size]
small_training_responses = np.load(os.path.join(TRAINING_DATA_DIR, "training_responses_{}.npy".format(test_data_key)))[:subset_size]
small_training_images = small_training_images.reshape(small_training_images.shape[0], 1, IMAGE_WIDTH, IMAGE_HEIGHT)

# 2. Convert to tensors and typecast to float
small_training_images = torch.from_numpy(small_training_images).float()
small_training_responses = torch.from_numpy(small_training_responses).float()

# 3. Zip training and test
small_training_data = list(zip(small_training_images, small_training_responses))

# 4. Convert to DataLoader
small_train_dataloader = DataLoader(small_training_data, batch_size=32, shuffle=False)

len(small_training_data), len(small_train_dataloader), model

(100,
 4,
 GeneralizedEncoderBase(
   (core): Stacked2dCore(
     (_input_weights_regularizer): LaplaceL2(
       (laplace): Laplace()
     )
     (features): Sequential(
       (layer0): Sequential(
         (conv): Conv2d(1, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
         (nonlin): AdaptiveELU()
       )
       (layer1): Sequential(
         (conv): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
         (nonlin): AdaptiveELU()
       )
       (layer2): Sequential(
         (conv): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
         (nonlin): AdaptiveELU()
       )
     )
   ) [Stacked2dCore regularizers: gamma_hidden = 0|gamma_input = 0.0|skip = 0]
   
   (readout): MultiReadoutBase(
     (2023-03-15_11-05): FullFactorized2d (64 x 22 x 36 -> 373) with bias, normalized
     (2023-03-15_15-23): FullFactorized2d (64 x 22 x 36 -> 388) with bias, normalized
     (2022-12-20_15-08): FullFactorized2d (64 x 22 x 36 -> 466

In [None]:
# 1. Train on this small dataset
import torch.nn.functional as F
import torch.optim as optim

optimizer = optim.Adam(model.parameters(), lr=1e-3)
model.to(torch.device(device))

for epoch in range(400):
    for x_batch, y_batch in small_train_dataloader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)

        optimizer.zero_grad()
        pred = model(x_batch, data_key=test_data_key)

        loss = poisson_loss(pred, y_batch)
        loss.backward()
        
        optimizer.step()

    if epoch % 20 == 0:
        print(f"Epoch {epoch}: loss = {loss.item():.4f}")


Epoch 0: loss = 8.9026
Epoch 20: loss = 4.1703
Epoch 40: loss = 4.1042
Epoch 60: loss = 4.1004
Epoch 80: loss = 3.8533
Epoch 100: loss = 3.5608
Epoch 120: loss = 3.4044
Epoch 140: loss = 3.3937
Epoch 160: loss = 3.3504
Epoch 180: loss = 3.3298
Epoch 200: loss = 3.3209
Epoch 220: loss = 3.3200
Epoch 240: loss = 3.3232
Epoch 260: loss = 3.3435
Epoch 280: loss = 3.3187
Epoch 300: loss = 3.3159
Epoch 320: loss = 3.3212
Epoch 340: loss = 3.3112
Epoch 360: loss = 3.3192
