In [None]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import json
import pandas as pd
from datetime import datetime
from copy import deepcopy
import dill
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import transforms
import lovely_tensors as lt
import wandb
from nnfabrik.builder import get_data

import csng
from csng.CNN_Decoder import CNN_Decoder
from csng.utils import crop, plot_comparison, standardize, normalize, get_mean_and_std, count_parameters, plot_losses
from csng.losses import SSIMLoss, MultiSSIMLoss, Loss, CroppedLoss
from csng.data import MixedBatchLoader
from csng.readins import (
    MultiReadIn,
    HypernetReadIn,
    ConvReadIn,
    AttentionReadIn,
    FCReadIn,
    AutoEncoderReadIn,
    Conv1dReadIn,
    LocalizedFCReadIn,
    MEIReadIn,
)

from encoder import get_encoder
from data_utils import (
    get_mouse_v1_data,
    append_syn_dataloaders,
    append_data_aug_dataloaders,
    RespGaussianNoise,
)
from cnn_decoder_utils import train, val, get_all_data

### patch torch tensors with lovely_tensors and locate the Sensorium 2022 data directory
### (the DATA_PATH environment variable must be set, otherwise this raises a KeyError)
lt.monkey_patch()
DATA_PATH = os.path.join(os.environ["DATA_PATH"], "mouse_v1_sensorium22")
print(f"{DATA_PATH=}")

### tell wandb which notebook this is
### (set through os.environ: `%env "VAR" "val"` keeps the quotes literally, i.e., it created
###  a variable named `"WANDB_NOTEBOOK_NAME"` (with quotes) and left the real one unset)
os.environ["WANDB_NOTEBOOK_NAME"] = "cnn_decoder.ipynb"
wandb.login()

In [None]:
### global run configuration; the "data" and "decoder" sub-configs are filled in by later cells
config = {
    "data": {
        "mixing_strategy": "parallel_min", # needed only with multiple base dataloaders
    },
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    # the same seed is used for NumPy, PyTorch, `random` and the dataloaders
    "seed": 0,
    # window passed to `crop(...)` and to the loss functions
    # NOTE(review): (22, 36) presumably means (height, width) of the cropped stimulus - confirm in csng.utils.crop
    # "crop_win": None,
    # "crop_win": (slice(7, 29), slice(15, 51)),
    "crop_win": (22, 36),
    # keyword arguments forwarded to `wandb.init`; a falsy value disables wandb logging
    # "wandb": None,
    "wandb": {
        "project": "CSNG",
        "group": "sensorium_2022",
    },
}
print(f"... Running on {config['device']} ...")

In [None]:
### seed NumPy, PyTorch and Python's `random` (in this order) with the configured seed
for seed_fn in (np.random.seed, torch.manual_seed, random.seed):
    seed_fn(config["seed"])

## Data

In [None]:
### start with no dataloaders and with every optional data sub-config switched off;
### the data sections below overwrite the entries they use
dataloaders = {}
for data_part in ("mouse_v1", "syn_dataset_config", "data_augmentation"):
    config["data"][data_part] = None

### Mouse V1 dataset (Sensorium 2022)

In [None]:
### prep data config
### ("dataset_fn" names the loader function and "dataset_config" holds its settings; the remaining
###  keys are options of this project's `get_mouse_v1_data` - their exact semantics are defined in data_utils)
config["data"]["mouse_v1"] = {
    "dataset_fn": "sensorium.datasets.static_loaders",
    "dataset_config": {
        # one zip archive per recording; only the five "pretraining" mice are enabled
        "paths": [ # from https://gin.g-node.org/cajal/Sensorium2022/src/master
            # os.path.join(DATA_PATH, "static26872-17-20-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip"), # mouse 1
            # os.path.join(DATA_PATH, "static27204-5-13-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip"), # sensorium+ (mouse 2)
            os.path.join(DATA_PATH, "static21067-10-18-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip"), # pretraining (mouse 3)
            os.path.join(DATA_PATH, "static22846-10-16-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip"), # pretraining (mouse 4)
            os.path.join(DATA_PATH, "static23343-5-17-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip"), # pretraining (mouse 5)
            os.path.join(DATA_PATH, "static23656-14-22-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip"), # pretraining (mouse 6)
            os.path.join(DATA_PATH, "static23964-4-22-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip"), # pretraining (mouse 7)
        ],
        "normalize": True,
        "scale": 0.25, # 256x144 -> 64x36
        "include_behavior": False,
        "add_behavior_as_channels": False,
        # the batches are later read through `.pupil_center`
        "include_eye_position": True,
        "exclude": None,
        "file_tree": True,
        # True only when the configured device string contains "cuda"
        "cuda": "cuda" in config["device"],
        "batch_size": 7,
        "seed": config["seed"],
        "use_cache": False,
    },
    "skip_train": False,
    "skip_val": False,
    "skip_test": False,
    "normalize_neuron_coords": True,
    "average_test_multitrial": True,
    "save_test_multitrial": True,
    "test_batch_size": 7,
    "device": config["device"],
}

In [None]:
### get dataloaders and cell coordinates
### (`dataloaders` is indexed below as dataloaders["mouse_v1"][<data part>] and `neuron_coords`
###  by data key; this call replaces the empty `dataloaders` placeholder created above)
dataloaders, neuron_coords = get_mouse_v1_data(config["data"])

In [None]:
### show data
### (`stim`, `resp`, `pupil_center` and `sample_data_key` of this test batch are reused by later cells)
sample_data_key = dataloaders["mouse_v1"]["test"].data_keys[0]
datapoint = next(iter(dataloaders["mouse_v1"]["test"].dataloaders[0]))
stim, resp = datapoint.images, datapoint.responses
pupil_center = datapoint.pupil_center
### NOTE: len(dl) * dl.batch_size counts whole batches, so the sample counts are exact only
### when the number of samples is divisible by the batch size
print(
    f"Training dataset:\t {sum(len(dl) * dl.batch_size for dl in dataloaders['mouse_v1']['train'].dataloaders)} samples"
    f"\nValidation dataset:\t {sum(len(dl) * dl.batch_size for dl in dataloaders['mouse_v1']['val'].dataloaders)} samples"
    f"\nTest dataset:\t\t {sum(len(dl) * dl.batch_size for dl in dataloaders['mouse_v1']['test'].dataloaders)} samples"
    f"\nTest (no resp) dataset:\t {sum(len(dl) * dl.batch_size for dl in dataloaders['mouse_v1']['test_no_resp'].dataloaders)} samples"

    "\n\nstimuli:"
    f"\n  {stim.shape}"
    f"\n  min={stim.min().item():.3f}  max={stim.max().item():.3f}"
    f"\n  mean={stim.mean().item():.3f}  std={stim.std().item():.3f}"
    "\nresponses:"
    f"\n  {resp.shape}"
    f"\n  min={resp.min().item():.3f}  max={resp.max().item():.3f}"
    f"\n  mean={resp.mean().item():.3f}  std={resp.std().item():.3f}"
    "\nneuronal coordinates:"
    f"\n  {neuron_coords[sample_data_key].shape}"
    f"\n  min={neuron_coords[sample_data_key].min():.3f}  max={neuron_coords[sample_data_key].max():.3f}"
    f"\n  mean={neuron_coords[sample_data_key].mean():.3f}  std={neuron_coords[sample_data_key].std():.3f}"
)

### plot sample data: full stimulus | cropped stimulus | spatially binned responses
sample_idx = 0

fig = plt.figure(figsize=(16, 6))
ax = fig.add_subplot(131)
ax.imshow(stim[sample_idx].squeeze().unsqueeze(-1).cpu(), cmap="gray")

ax = fig.add_subplot(132)
ax.imshow(crop(stim[sample_idx].cpu(), config["crop_win"]).squeeze().unsqueeze(-1), cmap="gray")

### bin the neuronal responses based on their neuron coordinates and sum within each bin -> 2D grid of vals
coords = neuron_coords[sample_data_key]
H, W = stim.shape[-2:] # the size of the grid
n_x_bins, n_y_bins = 32, 18 # number of bins in each dimension
min_x, max_x, min_y, max_y = coords[:,0].min().item(), coords[:,0].max().item(), coords[:,1].min().item(), coords[:,1].max().item()
x_bins = torch.linspace(min_x, max_x, n_x_bins + 1)
y_bins = torch.linspace(min_y, max_y, n_y_bins + 1)
### the bins are half-open [lo, hi): push the last edge to +inf so that neurons sitting exactly
### on the max coordinate fall into the last bin instead of being dropped from every bin
x_bins[-1] = y_bins[-1] = float("inf")

xs, ys = coords[:,0].cpu(), coords[:,1].cpu()
### the column masks do not depend on the row -> compute them once
in_x_bin = [(x_bins[j] <= xs) & (xs < x_bins[j + 1]) for j in range(n_x_bins)]
binned_resp = torch.zeros(n_y_bins, n_x_bins)
for i in range(n_y_bins):
    in_y_bin = (y_bins[i] <= ys) & (ys < y_bins[i + 1])
    for j in range(n_x_bins):
        ### sum the responses of the neurons in the bin (row = y bin, column = x bin)
        binned_resp[i,j] = resp[sample_idx, in_y_bin & in_x_bin[j]].sum(0)
ax = fig.add_subplot(133)
ax.imshow(binned_resp.squeeze().cpu(), cmap="gray")
plt.show()

### Synthetic dataset (different image stimuli -> encoder -> responses)

In [None]:
### append synthetic data
### (the synthetic dataloaders are added to the "train" part only)
config["data"]["syn_dataset_config"] = {
    # session ids of the five recordings enabled in the mouse_v1 config (cf. the zip file names)
    "data_keys": [
        "21067-10-18",
        "22846-10-16",
        "23343-5-17",
        "23656-14-22",
        "23964-4-22",
    ],
    "batch_size": 7,
    "append_data_parts": ["train"],
    # "data_key_prefix": "syn",
    "data_key_prefix": None, # the same data key as the original (real) data
    # NOTE(review): presumably resolved relative to the data directory inside `append_syn_dataloaders` - confirm
    "dir_name": "synthetic_data_mouse_v1_encoder_new_stimuli",
    "device": config["device"],
}

dataloaders = append_syn_dataloaders(dataloaders, config=config["data"]["syn_dataset_config"])

In [None]:
### show data
### (the last train dataloader is expected to be a synthetic one - it was appended by the cell above -
###  and its batches unpack into (stimuli, responses, pupil center))
syn_stim, syn_resp, syn_pupil_center = next(iter(dataloaders["mouse_v1"]["train"].dataloaders[-1]))
syn_sample_data_key = dataloaders["mouse_v1"]["train"].data_keys[-1]
### NOTE: len(dl) * dl.batch_size counts whole batches, so the sample counts are exact only
### when the number of samples is divisible by the batch size
print(
    f"Training dataset:\t {sum(len(dl) * dl.batch_size for dl in dataloaders['mouse_v1']['train'].dataloaders)} samples"
    f"\nValidation dataset:\t {sum(len(dl) * dl.batch_size for dl in dataloaders['mouse_v1']['val'].dataloaders)} samples"
    f"\nTest dataset:\t\t {sum(len(dl) * dl.batch_size for dl in dataloaders['mouse_v1']['test'].dataloaders)} samples"
    f"\nTest (no resp) dataset:\t {sum(len(dl) * dl.batch_size for dl in dataloaders['mouse_v1']['test_no_resp'].dataloaders)} samples"

    "\n\nstimuli:"
    f"\n  {syn_stim.shape}"
    f"\n  min={syn_stim.min().item():.3f}  max={syn_stim.max().item():.3f}"
    f"\n  mean={syn_stim.mean().item():.3f}  std={syn_stim.std().item():.3f}"
    "\nresponses:"
    f"\n  {syn_resp.shape}"
    f"\n  min={syn_resp.min().item():.3f}  max={syn_resp.max().item():.3f}"
    f"\n  mean={syn_resp.mean().item():.3f}  std={syn_resp.std().item():.3f}"
    "\nNeuron coordinates:"
    f"\n  {neuron_coords[syn_sample_data_key].shape}"
    f"\n  min={neuron_coords[syn_sample_data_key].min():.3f}  max={neuron_coords[syn_sample_data_key].max():.3f}"
    f"\n  mean={neuron_coords[syn_sample_data_key].mean():.3f}  std={neuron_coords[syn_sample_data_key].std():.3f}"
    "\nPupil center:"
    f"\n  {syn_pupil_center.shape}"
    f"\n  min={syn_pupil_center.min().item():.3f}  max={syn_pupil_center.max().item():.3f}"
    f"\n  mean={syn_pupil_center.mean().item():.3f}  std={syn_pupil_center.std().item():.3f}"
)

### plot sample data: full stimulus | cropped stimulus | spatially binned responses
sample_idx = 0

fig = plt.figure(figsize=(16, 6))
ax = fig.add_subplot(131)
ax.imshow(syn_stim[sample_idx].squeeze().unsqueeze(-1).cpu(), cmap="gray")

ax = fig.add_subplot(132)
ax.imshow(crop(syn_stim[sample_idx].cpu(), config["crop_win"]).squeeze().unsqueeze(-1), cmap="gray")

### bin the neuronal responses based on their neuron coordinates and sum within each bin -> 2D grid of vals
coords = neuron_coords[syn_sample_data_key]
H, W = syn_stim.shape[-2:] # the size of the grid
n_x_bins, n_y_bins = 32, 18 # number of bins in each dimension
min_x, max_x, min_y, max_y = coords[:,0].min().item(), coords[:,0].max().item(), coords[:,1].min().item(), coords[:,1].max().item()
x_bins = torch.linspace(min_x, max_x, n_x_bins + 1)
y_bins = torch.linspace(min_y, max_y, n_y_bins + 1)
### the bins are half-open [lo, hi): push the last edge to +inf so that neurons sitting exactly
### on the max coordinate fall into the last bin instead of being dropped from every bin
x_bins[-1] = y_bins[-1] = float("inf")

xs, ys = coords[:,0].cpu(), coords[:,1].cpu()
### the row masks do not depend on the column -> compute them once
in_y_bin = [(y_bins[j] <= ys) & (ys < y_bins[j + 1]) for j in range(n_y_bins)]
binned_resp = torch.zeros(n_y_bins, n_x_bins)
for i in range(n_x_bins):
    in_x_bin = (x_bins[i] <= xs) & (xs < x_bins[i + 1])
    for j in range(n_y_bins):
        ### sum the responses of the neurons in the bin (row = y bin, column = x bin)
        binned_resp[j,i] = syn_resp[sample_idx, in_x_bin & in_y_bin[j]].sum(0)
ax = fig.add_subplot(133)
ax.imshow(binned_resp.squeeze().cpu(), cmap="gray")
plt.show()

### Data augmentation

In [None]:
### build one response-noise transform per train dataset; its `noise_std` is twice the array
### stored in <DATA_PATH>/<dataset.dirname>/stats/responses_iqr.npy (moved to the configured device)
resp_noise_transforms = []
for dataset in dataloaders["mouse_v1"]["train"].datasets:
    resp_iqr = np.load(os.path.join(DATA_PATH, dataset.dirname, "stats", "responses_iqr.npy"))
    resp_noise_transforms.append(
        RespGaussianNoise(
            noise_std=2 * torch.from_numpy(resp_iqr).float().to(config["device"]),
            clip_min=0.0,
            # dynamic_mul_factor=0.05,
            # resp_fn="squared",
        )
    )

config["data"]["data_augmentation"] = {
    "data_transforms": [resp_noise_transforms],  # for synthetic data
    "append_data_parts": ["train"],
    "force_same_order": True,
    "seed": config["seed"],
}

In [None]:
### append the augmented dataloaders configured above (only to the data parts listed in "append_data_parts")
### NOTE: re-running this cell passes the already extended `dataloaders` in again
dataloaders = append_data_aug_dataloaders(
    dataloaders=dataloaders,
    config=config["data"]["data_augmentation"],
)

In [None]:
### debugging aid: fetch only the first batch of the combined train loader ...
for b in dataloaders["mouse_v1"]["train"]:
    break
### ... and display it next to the list of the underlying per-dataset dataloaders
dataloaders["mouse_v1"]["train"].dataloaders, b

In [None]:
### show data
### (augmented sample = first batch of the last train dataloader)
aug_sample_data_key = dataloaders["mouse_v1"]["train"].data_keys[-1]
aug_datapoint = next(iter(dataloaders["mouse_v1"]["train"].dataloaders[-1]))
aug_stim, aug_resp = aug_datapoint.images, aug_datapoint.responses
aug_pupil_center = aug_datapoint.pupil_center

### non-augmented counterpart = last dataloader of the first half of the train dataloaders
### NOTE(review): assumes the augmentation appended exactly one dataloader per existing one, and that
### "force_same_order" makes both loaders yield the same samples - confirm in append_data_aug_dataloaders
no_aug_datapoint = next(iter(dataloaders["mouse_v1"]["train"].dataloaders[len(dataloaders["mouse_v1"]["train"].dataloaders) // 2 - 1]))
no_aug_resp = no_aug_datapoint.responses

### plot sample data: full stimulus | cropped stimulus | binned responses before | after augmentation
sample_idx = 0

fig = plt.figure(figsize=(20, 6))
ax = fig.add_subplot(141)
ax.imshow(aug_stim[sample_idx].squeeze().unsqueeze(-1).cpu(), cmap="gray")

ax = fig.add_subplot(142)
ax.imshow(crop(aug_stim[sample_idx].cpu(), config["crop_win"]).squeeze().unsqueeze(-1), cmap="gray")

### bin the neuronal responses based on their neuron coordinates and sum within each bin -> 2D grid of vals
### (both panels share the neuron coordinates, so the bin edges and masks are computed only once)
coords = neuron_coords[aug_sample_data_key]
H, W = aug_stim.shape[-2:] # the size of the grid
n_x_bins, n_y_bins = 32, 18 # number of bins in each dimension
min_x, max_x, min_y, max_y = coords[:,0].min().item(), coords[:,0].max().item(), coords[:,1].min().item(), coords[:,1].max().item()
x_bins = torch.linspace(min_x, max_x, n_x_bins + 1)
y_bins = torch.linspace(min_y, max_y, n_y_bins + 1)
### the bins are half-open [lo, hi): push the last edge to +inf so that neurons sitting exactly
### on the max coordinate fall into the last bin instead of being dropped from every bin
x_bins[-1] = y_bins[-1] = float("inf")

xs, ys = coords[:,0].cpu(), coords[:,1].cpu()
in_x_bin = [(x_bins[j] <= xs) & (xs < x_bins[j + 1]) for j in range(n_x_bins)]
in_y_bin = [(y_bins[i] <= ys) & (ys < y_bins[i + 1]) for i in range(n_y_bins)]

for subplot_pos, title, resp_batch in (
    (143, "Responses before augmentation", no_aug_resp),
    (144, "Responses after augmentation", aug_resp),
):
    binned_resp = torch.zeros(n_y_bins, n_x_bins)
    for i in range(n_y_bins):
        for j in range(n_x_bins):
            ### sum the responses of the neurons in the bin (row = y bin, column = x bin)
            binned_resp[i,j] = resp_batch[sample_idx, in_y_bin[i] & in_x_bin[j]].sum(0)
    ax = fig.add_subplot(subplot_pos)
    ax.set_title(title)
    ax.imshow(binned_resp.squeeze().cpu(), cmap="gray")
plt.show()

## Encoder

In [None]:
### load the encoder from its checkpoint onto the configured device, requested in eval mode
encoder_ckpt_path = os.path.join(DATA_PATH, "models", "encoder_sens22.pth")
# encoder_ckpt_path = os.path.join(DATA_PATH, "models", "encoder_sens22_no_shifter.pth")
encoder = get_encoder(ckpt_path=encoder_ckpt_path, device=config["device"], eval_mode=True)

## Decoder

In [None]:
### Decoder config: "model" holds one read-in config per data key of the train dataloaders plus the
### shared core (class + config); the remaining keys configure the optimizer, the loss, the number of
### epochs, checkpoint loading and whether the run is saved.
### The commented-out entries are alternative read-ins / layers kept for quick switching.
config["decoder"] = {
    "model": {
        "readins_config": [
            {
                "data_key": data_key,
                # NOTE(review): presumably the number of neurons of this data key - confirm the layout of neuron_coords
                "in_shape": n_coords.shape[-2],
                "decoding_objective_config": None,
                # "decoding_objective_config": {
                #     "decoder_cls": FCReadIn,
                #     "decoder_config": {
                #         "in_shape": 68*9*16,
                #         "layers_config": [("fc", 264), ("fc", d.n_neurons),],
                #         "act_fn": nn.LeakyReLU,
                #         "out_act_fn": nn.Identity,
                #         "dropout": 0.0,
                #         "batch_norm": False,
                #     },
                #     "loss_fn": nn.MSELoss(),
                # },
                # read-in layers as (class, kwargs) pairs; only ConvReadIn is currently enabled
                "layers": [
                    # ("fc", 432),
                    # ("unflatten", 1, (3, 9, 16)),

                    # (LocalizedFCReadIn, {
                    #     "in_shape": d.n_neurons,
                    #     "layers": [
                    #         {"n_bins": 20, "reduce_by": 3},
                    #         {"n_bins": 12, "reduce_by": 2},
                    #         {"n_bins": 7, "reduce_by": 2},
                    #         {"n_bins": 2, "reduce_by": 2},
                    #     ],
                    #     "out_config": {
                    #         "shape": (3, 9, 16),
                    #         "method": "linear",
                    #     },
                    #     "act_fn": nn.LeakyReLU,
                    #     "out_act_fn": nn.Identity,
                    #     "dropout": 0.15,
                    #     "batch_norm": True,
                    # }),

                    # (AttentionReadIn, {
                    #     "in_shape": d.n_neurons,
                    #     "shift_coords": True,
                    #     "shifter_net_layers": [
                    #         ("fc", 10),
                    #         ("fc", 10),
                    #         ("fc", 2),
                    #     ],
                    #     "shifter_net_act_fn": nn.LeakyReLU,
                    #     "shifter_net_out_act_fn": nn.Tanh,
                    #     "attn_config": {
                    #         "layers": 1,
                    #         "token_neurons": 20,
                    #         "dim_head": 256,
                    #         "dropout": 0.1,
                    #         "attn_num_heads": 1,
                    #     },
                    #     "attn_interleave_config": {
                    #         "layers": [
                    #             ("fc", 512),
                    #             ("act_fn", nn.ReLU),
                    #             ("dropout", 0.15),
                    #             ("fc", 256)
                    #         ],
                    #         "after_last": True,
                    #     },
                    #     "neuron_embed_dim": 16,
                    #     "conv_out_config": {
                    #         "out_channels": 64,
                    #         "kernel_size": 5,
                    #         "stride": 1,
                    #         "padding": 2,
                    #         "bias": False,
                    #         "batch_norm": True,
                    #         "act_fn": nn.LeakyReLU,
                    #     },
                    # }),

                    (ConvReadIn, {
                        "shift_coords": False,
                        "learn_grid": True,
                        "grid_l1_reg": 8e-3,
                        "in_channels_group_size": 1,
                        # "grid_net_config": {
                        #     "in_channels": 32, # x, y, z, resp
                        #     "layers_config": [("fc", 64), ("fc", 64), ("fc", 16*9)],
                        #     "act_fn": nn.LeakyReLU,
                        #     "out_act_fn": nn.Identity,
                        #     "dropout": 0.1,
                        #     "batch_norm": False,
                        # },
                        "pointwise_conv_config": {
                            "in_channels": n_coords.shape[-2],
                            # NOTE(review): looks like this has to equal the core's "resp_shape" (256) - confirm
                            "out_channels": 256,
                            "act_fn": nn.LeakyReLU,
                            "bias": False,
                            "batch_norm": True,
                            # "dropout": 0.15,
                        },
                        "gauss_blur": False,
                        "gauss_blur_kernel_size": 7,
                        "gauss_blur_sigma": "fixed", # "fixed", "single", "per_neuron"
                        # "gauss_blur_sigma": "per_neuron", # "fixed", "single", "per_neuron"
                        "gauss_blur_sigma_init": 1.5,
                        "neuron_emb_dim": None,
                    }),

                    # (MEIReadIn, {
                    #     "meis_path": os.path.join(DATA_PATH, "meis", data_key,  "meis.pt"),
                    #     "n_neurons": n_coords.shape[-2],
                    #     "mei_resize_method": "resize",
                    #     "mei_target_shape": (22, 36),
                    #     "pointwise_conv_config": {
                    #         "out_channels": 256,
                    #         "bias": False,
                    #         "batch_norm": True,
                    #         "act_fn": nn.Identity,
                    #     },
                    #     "ctx_net_config": {
                    #         "in_channels": 3, # resp, x, y
                    #         "layers_config": [("fc", 32), ("fc", 128), ("fc", 22*36)],
                    #         "act_fn": nn.LeakyReLU,
                    #         "out_act_fn": nn.Identity,
                    #         "dropout": 0.1,
                    #         "batch_norm": True,
                    #     },
                    #     "shift_coords": False,
                    #     "device": config["device"],
                    # }),

                    # (FCReadIn, {
                    #     "in_shape": n_coords.shape[-2],
                    #     "layers_config": [
                    #         ("fc", 432),
                    #         ("unflatten", 1, (3, 9, 16)),
                    #     ],
                    #     # "act_fn": nn.LeakyReLU,
                    #     "act_fn": nn.GELU,
                    #     "out_act_fn": nn.Identity,
                    #     # "batch_norm": True,
                    #     "batch_norm": False,
                    #     "layer_norm": True,
                    #     # "dropout": 0.3,
                    #     "dropout": 0.3,
                    #     "l2_reg_mul": 1e-3,
                    #     "out_channels": 3,
                    # }),

                    # (AutoEncoderReadIn, {
                    #     "loss_mul": 1,
                    #     "encoder_config": {
                    #         "in_shape": d.n_neurons,
                    #         "layers_config": [
                    #             ("fc", 288),
                    #             ("unflatten", 1, (2, 9, 16)),
                    #         ],
                    #         "act_fn": nn.LeakyReLU,
                    #         "out_act_fn": nn.Identity,
                    #         "batch_norm": True,
                    #         "dropout": 0.2,
                    #         "out_channels": 2,
                    #     },
                    #     "decoder_config": {
                    #         "layers_config": [
                    #             ("fc", 312),
                    #             ("fc", d.n_neurons),
                    #         ],
                    #         "act_fn": nn.LeakyReLU,
                    #         "out_act_fn": nn.Identity,
                    #         "batch_norm": False,
                    #         "dropout": 0.1,
                    #     },
                    # }),

                    # (Conv1dReadIn, {
                    #     # "in_shape": d.n_neurons,
                    #     "in_shape": 1,
                    #     "out_channels": 2,
                    #     "layers_config": [
                    #         ("conv1d", 64, 7, 3, 3),
                    #         ("conv1d", 32, 7, 3, 3),
                    #         ("conv1d", 16, 5, 2, 2),
                    #         ("conv1d", 8, 4, 2, 1),
                    #         ("flatten", 1, -1, 1632),
                    #         ("fc", 288),
                    #         ("unflatten", 1, (2, 9, 16)),
                    #     ],
                    #     "act_fn": nn.ReLU,
                    #     "out_act_fn": nn.Identity,
                    #     "batch_norm": True,
                    #     "dropout": 0.15,
                    # }),

                    # (HypernetReadIn, {
                    #     "n_neurons": d.n_neurons,
                    #     "hypernet_layers": [
                    #         # ("fc", 40),
                    #         ("fc", 64),
                    #         ("fc", 64),
                    #         ("fc", 1152),
                    #     ],
                    #     "hypernet_act_fn": nn.LeakyReLU,
                    #     # "hypernet_act_fn": nn.Tanh,
                    #     "hypernet_out_act_fn": nn.Identity,
                    #     "hypernet_dropout": 0.,
                    #     "hypernet_batch_norm": False,
                    #     "hypernet_init": "normal",
                    #     "hypernet_init_kwargs": {
                    #         "mean": 0,
                    #         "std": 1/(d.n_neurons*1152),
                    #     },
                    #     "hypernet_neuron_embed_dim": 32,
                    #     "target_in_shape": d.n_neurons,
                    #     "target_layers": [
                    #         ("fc", 1152),
                    #         ("unflatten", 1, (8, 9, 16)),
                    #     ],
                    #     # "target_act_fn": nn.LeakyReLU,
                    #     "target_act_fn": nn.Identity,
                    #     "target_out_act_fn": nn.Identity,
                    #     "target_dropout": 0.15,
                    #     "target_out_layer_norm": True,
                    #     "shift_coords": True,
                    #     "shifter_net_layers": [
                    #         ("fc", 10),
                    #         ("fc", 10),
                    #         ("fc", 2),
                    #     ],
                    #     "shifter_net_act_fn": nn.LeakyReLU,
                    #     "shifter_net_out_act_fn": nn.Tanh,
                    # }),

                ],
            } for data_key, n_coords in dataloaders["mouse_v1"]["train"].neuron_coords.items()
        ],
        "core_cls": CNN_Decoder,
        "core_config": {
            "resp_shape": [256],
            # `stim` is the sample test batch loaded in the "show data" cell; the batch dimension is dropped
            "stim_shape": list(stim.shape[1:]),
            # NOTE(review): the tuples are presumably (type, out_channels, kernel_size, stride, padding) - confirm in CNN_Decoder
            "layers": [
                ### for conv_readin
                # ("deconv", 256, 5, 2, 2),
                ("deconv", 256, 7, 2, 2),
                # ("deconv", 128, 7, 2, 1),
                # ("deconv", 64, 5, 2, 2),
                
                ("deconv", 128, 5, 1, 2),
                # ("deconv", 64, 5, 1, 1),

                # ("deconv", 64, 4, 1, 1),
                ("deconv", 64, 5, 1, 2),
                # ("deconv", 32, 4, 1, 1),

                # ("deconv", 64, 3, 1, 1),
                ("deconv", 64, 4, 1, 1),
                # ("deconv", 32, 4, 1, 1),

                # ("deconv", 64, 4, 1, 1),
                # ("deconv", 64, 3, 1, 1),
                ("deconv", 32, 3, 1, 1),

                ("deconv", 1, 3, 1, 0),

                # ### for MEIReadin
                # ("conv", 256, 7, 1, 3),
                # ("conv", 128, 5, 1, 2),
                # ("conv", 64, 3, 1, 1),
                # ("conv", 64, 3, 1, 1),
                # ("conv", 1, 3, 1, 1),

                ### for attn_readin
                # ("deconv", 64, 7, 2, 3),
                # ("deconv", 32, 4, 1, 2),
                # ("deconv", 1, 3, 1, 0),
            ],
            "act_fn": nn.ReLU,
            "out_act_fn": nn.Identity,
            "dropout": 0.3,
            "batch_norm": True,
        },
    },
    # the optimizer is built as opter_cls(decoder.parameters(), **opter_kwargs)
    "opter_cls": torch.optim.Adam,
    "opter_kwargs": {
        "lr": 3e-4,
        # "weight_decay": 1e-3,
    },
    # passed as a whole to `Loss(model=decoder, config=...)`
    "loss": {
        # "loss_fn": CroppedLoss(window=config["crop_win"], loss_fn=nn.MSELoss(), normalize=False, standardize=False),
        # "loss_fn": MultiSSIMLoss(
        "loss_fn": SSIMLoss(
            window=config["crop_win"],
            log_loss=True,
            inp_normalized=True,
            inp_standardized=False,
        ),
        "l1_reg_mul": 0,
        "l2_reg_mul": 1e-5,
        "con_reg_mul": 0,
        # "con_reg_mul": 1,
        # "con_reg_loss_fn": MultiSSIMLoss(
        "con_reg_loss_fn": SSIMLoss(
            window=config["crop_win"],
            log_loss=True,
            inp_normalized=True,
            inp_standardized=False,
        ),
        # "con_reg_loss_fn": CroppedLoss(window=config["crop_win"], loss_fn=nn.MSELoss(), normalize=False, standardize=False),
        "encoder": None,
        # "encoder": get_encoder(
        #     ckpt_path=os.path.join(DATA_PATH, "models", "encoder_sens22.pth"),
        #     device=config["device"],
        #     eval_mode=True,
        #     # ckpt_path=os.path.join(DATA_PATH, "models", "encoder_sens22_no_shifter.pth"),
        # ),
    },
    # the training loop runs from len(history["train_loss"]) up to this epoch count
    "n_epochs": 200,
    # None -> initialize from scratch; otherwise a dict with the keys shown in the commented example
    "load_ckpt": None,
    # "load_ckpt": {
    #     "load_only_core": False,
    #     # "load_only_core": True,
    #     "ckpt_path": os.path.join(
    #         # DATA_PATH, "models", "cat_v1_pretraining", "2024-02-27_19-17-39", "decoder.pt"),
    #         DATA_PATH, "models", "cnn", "2024-03-20_17-51-50", "ckpt", "decoder_55.pt"),
    #     "resume_checkpointing": True,
    #     "resume_wandb_id": "ufhjka2b"
    # },
    # True -> the config, samples and checkpoints are written to a run directory under DATA_PATH/models/cnn
    "save_run": True,
}

In [None]:
### initialize (and load ckpt if needed)
### Three paths, each of which defines `decoder`, `opter`, `loss_fn`, `history` and `best`:
###   1. load only the core weights from a checkpoint (fresh read-ins, optimizer and history),
###   2. resume the whole model (weights, optimizer state, history and best ckpt),
###   3. initialize everything from scratch.
if config["decoder"]["load_ckpt"] is not None:
    print(f"[INFO] Loading checkpoint from {config['decoder']['load_ckpt']['ckpt_path']}...")
    ### NOTE(security): the checkpoint is unpickled with dill, which can execute arbitrary code
    ### -> load only checkpoints from trusted sources
    ckpt = torch.load(config["decoder"]["load_ckpt"]["ckpt_path"], map_location=config["device"], pickle_module=dill)

    if config["decoder"]["load_ckpt"]["load_only_core"]:
        print("[INFO] Loading only the core of the model (no history, no best ckpt)...")

        ### init decoder (load only the core): the core class and config must match the checkpoint
        config["decoder"]["model"]["core_cls"] = ckpt["config"]["decoder"]["model"]["core_cls"]
        config["decoder"]["model"]["core_config"] = ckpt["config"]["decoder"]["model"]["core_config"]
        decoder = MultiReadIn(**config["decoder"]["model"]).to(config["device"])
        ### drop every read-in entry of the best checkpoint; strict=False tolerates the missing keys
        core_state_dict = {k: v for k, v in ckpt["best"]["model"].items() if "readin" not in k}
        decoder.load_state_dict(core_state_dict, strict=False)

        ### init the rest
        opter = config["decoder"]["opter_cls"](decoder.parameters(), **config["decoder"]["opter_kwargs"])
        loss_fn = Loss(model=decoder, config=config["decoder"]["loss"])
        history = {"train_loss": [], "val_loss": []}
        best = {"val_loss": np.inf, "epoch": 0, "model": None}
    else:
        print("[INFO] Loading the whole model (the latest - not the BEST; with history and best ckpt)...")
        history, config["decoder"]["model"], best = ckpt["history"], ckpt["config"]["decoder"]["model"], ckpt["best"]

        ### overwrite config?
        ### NOTE: `input` blocks a non-interactive "Restart & Run All" until it is answered
        if input("[WARNING] Do you want to overwrite the config with the one from the checkpoint? (y/n): ") == "y":
            config = ckpt["config"]

        decoder = MultiReadIn(**config["decoder"]["model"]).to(config["device"])
        decoder.load_state_dict(ckpt["decoder"])

        opter = config["decoder"]["opter_cls"](decoder.parameters(), **config["decoder"]["opter_kwargs"])
        opter.load_state_dict(ckpt["opter"])
        loss_fn = Loss(model=decoder, config=config["decoder"]["loss"])
else:
    print("[INFO] Initializing the model from scratch...")
    decoder = MultiReadIn(**config["decoder"]["model"]).to(config["device"])
    opter = config["decoder"]["opter_cls"](decoder.parameters(), **config["decoder"]["opter_kwargs"])
    loss_fn = Loss(model=decoder, config=config["decoder"]["loss"])

    history = {"train_loss": [], "val_loss": []}
    best = {"val_loss": np.inf, "epoch": 0, "model": None}

In [None]:
### print model and check the size of the predicted stimuli on the sample test batch
target_shape = crop(stim, config["crop_win"]).shape

### probe in eval mode and restore the previous mode afterwards: the config enables dropout and
### batch norm, so a train-mode forward pass would be noisy and would presumably update the
### normalization statistics with a test batch before the training even starts
was_training = decoder.training
decoder.eval()
try:
    with torch.no_grad():
        stim_pred = decoder(resp.to(config["device"]), data_key=sample_data_key, neuron_coords=neuron_coords[sample_data_key], pupil_center=pupil_center.to(config["device"]))
finally:
    decoder.train(was_training)

if stim_pred.shape != target_shape:
    print(f"[WARNING] Stimulus prediction shape {stim_pred.shape} does not match stimulus shape {target_shape}.")
    ### a larger prediction is tolerated, a smaller one is not
    assert stim_pred.shape[-2] >= target_shape[-2] \
        and stim_pred.shape[-1] >= target_shape[-1], \
        f"Stimulus prediction {tuple(stim_pred.shape)} is smaller than the (cropped) stimulus {tuple(target_shape)}."
print(stim_pred.shape)
del stim_pred

### parameter counts: whole model, core and read-ins (total and per data key), with their shares
n_params_total = count_parameters(decoder)
n_params_core = count_parameters(decoder.core)
n_params_readins = count_parameters(decoder.readins)
print(
    f"Number of parameters:"
    f"\n  whole model: {n_params_total}"
    f"\n  core: {n_params_core} ({n_params_core / n_params_total * 100:.2f}%)"
    f"\n  readins: {n_params_readins} ({n_params_readins / n_params_total * 100:.2f}%)"
    f"\n    ({', '.join([f'{k}: {count_parameters(v)} [{count_parameters(v) / n_params_total * 100:.2f}%]' for k, v in decoder.readins.items()])})"
)

decoder

In [None]:
### prepare checkpointing and wandb logging
def _run_sample_path(epoch, prefix):
    """Return the path of the stimulus-comparison figure of `epoch` inside the run directory."""
    return os.path.join(
        config["dir"], "samples", f"{prefix}stim_comparison_{epoch}e.png"
    )

if config["decoder"]["load_ckpt"] is None \
    or config["decoder"]["load_ckpt"]["resume_checkpointing"] is False:
    ### new run, named by the current timestamp
    config["run_name"] = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    if config["decoder"]["save_run"]:
        ### save config (objects that are not JSON serializable are stored as their str())
        config["dir"] = os.path.join(DATA_PATH, "models", "cnn", config["run_name"])
        os.makedirs(config["dir"], exist_ok=True)
        with open(os.path.join(config["dir"], "config.json"), "w") as f:
            json.dump(config, f, indent=4, default=str)
        os.makedirs(os.path.join(config["dir"], "samples"), exist_ok=True)
        os.makedirs(os.path.join(config["dir"], "ckpt"), exist_ok=True)
        make_sample_path = _run_sample_path
        print(f"Run name: {config['run_name']}\nRun dir: {config['dir']}")
    else:
        make_sample_path = lambda epoch, prefix: None
        print("[WARNING] Not saving the run and the config.")
else:
    ### continue in the run directory of the loaded checkpoint
    config["run_name"] = ckpt["config"]["run_name"]
    config["dir"] = ckpt["config"]["dir"]
    make_sample_path = _run_sample_path
    print(f"Checkpointing resumed - Run name: {config['run_name']}\nRun dir: {config['dir']}")

### wandb logging
### (`wdb_run` is always defined and stays None when wandb is disabled)
wdb_run = None
if not config["wandb"]:
    print("[WARNING] Not using wandb.")
elif config["decoder"]["load_ckpt"] is None \
    or config["decoder"]["load_ckpt"].get("resume_wandb_id") is None:
    ### new wandb run tagged with the core class and the first read-in layer
    ### (a layer may be given as a class or as a plain string such as "fc")
    first_readin_layer = config["decoder"]["model"]["readins_config"][0]["layers"][0][0]
    wdb_run = wandb.init(**config["wandb"], name=config["run_name"], config=config,
        tags=[
            config["decoder"]["model"]["core_cls"].__name__,
            getattr(first_readin_layer, "__name__", str(first_readin_layer)),
        ],
        notes=None)
    wdb_run.watch(decoder)
else:
    ### resume the existing wandb run (resume="must" fails if the run id does not exist)
    wdb_run = wandb.init(**config["wandb"], name=config["run_name"], config=config, id=config["decoder"]["load_ckpt"]["resume_wandb_id"], resume="must")

In [None]:
### train
# Losses are plotted and a checkpoint is written every `ckpt_every` epochs
# (0-indexed, skipping epoch 0).
ckpt_every = 5

# Start from the number of already-recorded epochs so a resumed history continues
# where it left off.
s, e = len(history["train_loss"]), config["decoder"]["n_epochs"]
for epoch in range(s, e):
    print(f"[{epoch + 1}/{e}]")

    ### train and val
    # Dataloaders are rebuilt at the start of every epoch.
    dls, neuron_coords = get_all_data(config=config)
    train_dataloader, val_dataloader = dls["mouse_v1"]["train"], dls["mouse_v1"]["val"]
    train_loss = train(
        model=decoder,
        dataloader=train_dataloader,
        opter=opter,
        loss_fn=loss_fn,
    )
    val_losses = val(
        model=decoder,
        dataloader=val_dataloader,
        loss_fn=loss_fn,
    )

    ### save best model
    if val_losses["total"] < best["val_loss"]:
        best["val_loss"] = val_losses["total"]
        best["epoch"] = epoch
        # deepcopy: state_dict() returns references to the live parameters.
        best["model"] = deepcopy(decoder.state_dict())

    ### log
    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_losses["total"])
    # commit=False: the step is committed by the reconstruction log call below.
    if config["wandb"]: wdb_run.log({"train_loss": train_loss, "val_loss": val_losses["total"]}, commit=False)
    print(f"{train_loss=:.4f}, {val_losses['total']=:.4f}", end="")
    for data_key, loss in val_losses.items():
        if data_key != "total":
            print(f", {data_key}: {loss:.4f}", end="")
    print("")

    ### plot reconstructions
    # Inference only: skip building the autograd graph (same as the shape-check cell).
    with torch.no_grad():
        stim_pred = decoder(
            resp[:8].to(config["device"]),
            data_key=sample_data_key,
            neuron_coords=neuron_coords[sample_data_key],
            pupil_center=pupil_center[:8].to(config["device"]),
        ).detach()
    fig = plot_comparison(target=crop(stim[:8], config["crop_win"]).cpu(), pred=crop(stim_pred[:8], config["crop_win"]).cpu(), save_to=make_sample_path(epoch, ""))
    if config["wandb"]: wdb_run.log({"val_stim_reconstruction": fig})

    ### plot losses + save ckpt
    if epoch % ckpt_every == 0 and epoch > 0:
        plot_losses(history=history, epoch=epoch, save_to=os.path.join(config["dir"], f"losses_{epoch}.png") if config["decoder"]["save_run"] else None)

        ### ckpt
        if config["decoder"]["save_run"]:
            torch.save({
                "decoder": decoder.state_dict(),
                "opter": opter.state_dict(),
                "history": history,
                "config": config,
                "best": best,
            }, os.path.join(config["dir"], "ckpt", f"decoder_{epoch}.pt"), pickle_module=dill)

In [None]:
### final evaluation + logging + saving
def _predict_sample_batch():
    """Decode the sample batch (resp / pupil_center from an earlier cell) with the decoder's current weights.

    Uses the notebook-level `neuron_coords` as it is at call time and returns the
    prediction on the CPU. Runs under no_grad since this is inference only.
    """
    with torch.no_grad():
        return decoder(
            resp.to(config["device"]),
            data_key=sample_data_key,
            neuron_coords=neuron_coords[sample_data_key],
            pupil_center=pupil_center.to(config["device"]),
        ).detach().cpu()


print(f"Best val loss: {best['val_loss']:.4f} at epoch {best['epoch']}")

### save final ckpt
if config["decoder"]["save_run"]:
    torch.save({
        "decoder": decoder.state_dict(),
        "opter": opter.state_dict(),
        "history": history,
        "config": config,
        "best": best,
    }, os.path.join(config["dir"], "decoder.pt"), pickle_module=dill)

### eval on test set w/ current params
print("Evaluating on test set with current model...")
dls, neuron_coords = get_all_data(config=config)
test_loss_curr = val(
    model=decoder,
    dataloader=dls["mouse_v1"]["test"],
    loss_fn=loss_fn,
)
print(f"  Test loss (current model): {test_loss_curr['total']:.4f}")

stim_pred_curr = _predict_sample_batch()
fig = plot_comparison(
    target=crop(stim[:8], config["crop_win"]).cpu(),
    pred=crop(stim_pred_curr[:8], config["crop_win"]).cpu(),
    save_to=os.path.join(config["dir"], "stim_comparison_latest_model.png") if config["decoder"]["save_run"] else None,
    pred_title="Reconstructed (latest model)"
)


### load best model
if best["model"] is None:
    # No epoch improved on the initial val loss (e.g. no epoch was run), so there
    # are no best weights to restore; the current weights are evaluated instead.
    print("[WARNING] No best model recorded - keeping the current weights.")
else:
    decoder.load_state_dict(best["model"])

### eval on test set w/ best params
print("Evaluating on test set with best model...")
dls, neuron_coords = get_all_data(config=config)
test_loss_final = val(
    model=decoder,
    dataloader=dls["mouse_v1"]["test"],
    loss_fn=loss_fn,
)
print(f"  Test loss (best model): {test_loss_final['total']:.4f}")

### plot reconstructions of the final model
stim_pred_best = _predict_sample_batch()
fig = plot_comparison(
    target=crop(stim[:8], config["crop_win"]).cpu(),
    pred=crop(stim_pred_best[:8], config["crop_win"]).cpu(),
    save_to=os.path.join(config["dir"], "stim_comparison_best.png") if config["decoder"]["save_run"] else None,
)

### log
if config["wandb"]:
    wandb.run.summary["best_val_loss"] = best["val_loss"]
    wandb.run.summary["best_epoch"] = best["epoch"]
    wandb.run.summary["curr_test_loss"] = test_loss_curr["total"]
    wandb.run.summary["final_test_loss"] = test_loss_final["total"]
    wandb.run.summary["best_reconstruction"] = fig

### save/delete wandb run
if config["wandb"]:
    if input("Delete run with 'd', save with anything else: ") == "d":
        print("Deleting wandb run...")
        api = wandb.Api()
        # Use the entity the run was actually logged under instead of a hard-coded user name.
        run = api.run(f"{wdb_run.entity}/{config['wandb']['project']}/{wdb_run.id}")
        run.delete()
    else:
        wdb_run.finish()

### plot losses
plot_losses(
    history=history,
    save_to=None if not config["decoder"]["save_run"] else os.path.join(config["dir"], "losses.png"),
)

In [None]:
### show training reconstructions
# Take one batch from the first underlying training dataloader.
# NOTE(review): the decoder is called with `sample_data_key`, so this assumes
# dataloaders[0] belongs to that same data key - confirm.
datapoint_training = next(iter(dls["mouse_v1"]["train"].dataloaders[0]))
stim_training, resp_training, pupil_center_training = datapoint_training.images, datapoint_training.responses, datapoint_training.pupil_center
# Inference only: skip building the autograd graph (same as the shape-check cell).
with torch.no_grad():
    stim_training_pred_best = decoder(
        resp_training.to(config["device"]),
        data_key=sample_data_key,
        neuron_coords=neuron_coords[sample_data_key],
        pupil_center=pupil_center_training.to(config["device"]),
    ).detach().cpu()
# Compare the first 8 targets and predictions, both cropped to the configured window.
plot_comparison(
    target=crop(stim_training[:8], config["crop_win"]).cpu(),
    pred=crop(stim_training_pred_best[:8], config["crop_win"]).cpu(),
    save_to=os.path.join(config["dir"], "training_stim_comparison_best.png") if config["decoder"]["save_run"] else None,
)