In [None]:
import os
import random
import torch
import pickle
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
import lovely_tensors as lt
lt.monkey_patch()

from csng.utils.data import crop
from csng.data import get_dataloaders, get_sample_data
from csng.models.readins import MEIReadIn
from csng.utils.mix import seed_all, update_config_paths, update_config_keys_to_value, plot_comparison
from csng.models.utils.gan import init_decoder as init_gan_decoder
from csng.models.utils.cnn import init_decoder as init_cnn_decoder
from csng.models.utils.gan import train
from csng.utils.comparison import eval_decoder
from csng.losses import get_metrics
from csng.brainreader_mouse.encoder import get_encoder as get_encoder_brainreader
from csng.mouse_v1.encoder import get_encoder as get_encoder_sensorium_mouse_v1
from csng.cat_v1.encoder import get_encoder as get_encoder_cat_v1

### set paths
DATA_PATH = os.environ["DATA_PATH"]
DATA_PATH_CAT_V1 = os.path.join(DATA_PATH, "cat_V1_spiking_model", "50K_single_trial_dataset")
DATA_PATH_MOUSE_V1 = os.path.join(DATA_PATH, "mouse_v1_sensorium22")
DATA_PATH_BRAINREADER = os.path.join(DATA_PATH, "brainreader")
print(f"{DATA_PATH=}\n{DATA_PATH_CAT_V1=}\n{DATA_PATH_MOUSE_V1=}\n{DATA_PATH_BRAINREADER=}")

!nvidia-smi

In [None]:
### setup config
# Global notebook configuration; the per-dataset entries are added to config["data"] below.
config = dict(
    device=os.environ["DEVICE"],
    seed=0,
    data=dict(mixing_strategy="sequential"),
    # per-dataset window passed to `crop` when plotting comparisons
    crop_wins=dict(
        cat_v1=(20, 20),
        mouse_v1=(22, 36),
        brainreader_mouse=(36, 64),
    ),
)

print(f"... Running on {config['device']} ...")
seed_all(config["seed"])

# Data

## Load source dataset(s)

In [None]:
### prep data config
# brainreader mouse data: device and mixing strategy are inherited from the global config
config["data"]["brainreader_mouse"] = dict(
    device=config["device"],
    mixing_strategy=config["data"]["mixing_strategy"],
    max_batches=None,
    data_dir=os.path.join(DATA_PATH_BRAINREADER, "data"),
    batch_size=8,
    # sessions=list(range(1, 7)),
    sessions=[6],
    resize_stim_to=(36, 64),
    normalize_stim=True,
    normalize_resp=False,
    div_resp_by_std=True,
    clamp_neg_resp=False,
    additional_keys=None,
    avg_test_resp=True,
    train_datapoint_idxs_to_use=None,
    # train_datapoint_idxs_to_use=np.random.default_rng(seed=config["seed"]).choice(4500, size=int(4500 * 0.5), replace=False),
)

### cat v1 data
# config["data"]["cat_v1"] = {
#     "dataset_config": {
#         "train_path": os.path.join(DATA_PATH_CAT_V1, "datasets", "train"),
#         "val_path": os.path.join(DATA_PATH_CAT_V1, "datasets", "val"),
#         "test_path": os.path.join(DATA_PATH_CAT_V1, "datasets", "test"),
#         "image_size": [50, 50],
#         "crop": False,
#         "batch_size": 4,
#         "stim_keys": ("stim",),
#         "resp_keys": ("exc_resp", "inh_resp"),
#         "return_coords": True,
#         "return_ori": False,
#         "coords_ori_filepath": os.path.join(DATA_PATH_CAT_V1, "pos_and_ori.pkl"),
#         "cached": False,
#         "stim_normalize_mean": 46.143,
#         "stim_normalize_std": 24.960,
#         # "resp_normalize_mean": None, # don't center responses
#         "resp_normalize_std": torch.load(
#             os.path.join(DATA_PATH_CAT_V1, "responses_std.pt")
#         ),
#         "clamp_neg_resp": False,
#     },
# }

# # mouse v1 data
# config["data"]["mouse_v1"] = {
#     "dataset_fn": "sensorium.datasets.static_loaders",
#     "dataset_config": {
#         "paths": [ # from https://gin.g-node.org/cajal/Sensorium2022/src/master
#             os.path.join(DATA_PATH_MOUSE_V1, "static21067-10-18-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip"), # M-1
#             # os.path.join(DATA_PATH_MOUSE_V1, "static22846-10-16-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip"), # M-2
#             # os.path.join(DATA_PATH_MOUSE_V1, "static23343-5-17-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip"), # M-3
#             # os.path.join(DATA_PATH_MOUSE_V1, "static23656-14-22-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip"), # M-4
#             # os.path.join(DATA_PATH_MOUSE_V1, "static23964-4-22-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip"), # M-5
#         ],
#         "normalize": True,
#         "z_score_responses": False,
#         "scale": 0.25, # 256x144 -> 64x36
#         "include_behavior": False,
#         "add_behavior_as_channels": False,
#         "include_eye_position": True,
#         "exclude": None,
#         "file_tree": True,
#         "cuda": "cuda" in config["device"],
#         # "batch_size": 5,
#         "batch_size": 16,
#         "drop_last": True,
#         "use_cache": False,
#         "train_datapoint_idxs_to_use": None,
#         # "train_datapoint_idxs_to_use": np.random.default_rng(seed=config["seed"]).choice(4473, size=int(4473 * 0.5), replace=False),
#     },
#     "crop_win": (22, 36),
#     "skip_train": False,
#     "skip_val": False,
#     "skip_test": False,
#     "normalize_neuron_coords": True,
#     "average_test_multitrial": True,
#     "save_test_multitrial": True,
#     "test_batch_size": 7,
#     "neuron_coords_to_use": None, # if None, uses the neuron coordinates from the dataset
#     "device": config["device"],
# }

In [None]:
# Build the dataloaders and print a short per-tier / per-dataset summary.
dls, neuron_coords = get_dataloaders(config)
for tier, data_dict in dls.items():
    print(f"{tier}:")
    for data_name, dl in data_dict.items():
        n_batches = len(dl)
        print(f"  {data_name}: {n_batches} batches")

        data_keys_str = ", ".join(dl.data_keys)
        print(f"    data keys: {data_keys_str}")

        # NOTE(review): every entry uses the combined loader's batch count (len(dl)),
        # not the inner loader's own length — confirm this is intended when several
        # sessions are loaded.
        dataset_sizes = [str(n_batches * inner_dl.batch_size) for inner_dl in dl.dataloaders]
        print(f"    size of datasets: {', '.join(dataset_sizes)}")

In [None]:
### get sample data
# Draw one sample batch from the validation tier; it is reused by the decoding cells below.
s = get_sample_data(dls=dls, config=config, sample_from_tier="val")
resp, stim = s["resp"], s["stim"]
sample_dataset, sample_data_key = s["sample_dataset"], s["sample_data_key"]
print(f"{sample_dataset=}, {sample_data_key=}, {resp.shape=}, {stim.shape=}")

---
# Decoding models

## Inverted Encoder (InvEnc)

In [None]:
### load encoder
# Load the brainreader encoder checkpoint and make sure it is in inference mode.
encoder = get_encoder_brainreader(
    device=config["device"],
    eval_mode=True,
    ckpt_path=os.path.join(DATA_PATH, "models", "encoder_b6.pt"),
)
encoder.training = False
encoder.eval()

## Generative Adversarial Network (GAN)

In [None]:
### config for model to load
config["decoder"] = dict(
    load_ckpt=dict(
        ckpt_path=os.path.join(DATA_PATH, "models", "gan", "2025-02-27_18-49-52", "decoder.pt"),
        load_only_core=False,  # set to True if you want to keep only the core and reset the readins
        load_best=True,  # load best model (val. loss during pretraining)
        load_opter_state=False,  # don't load optimizer state for fine-tuning
        load_history=False,  # reset training history for fine-tuning
        reset_best=True,  # reset best-model tracking for fine-tuning
    ),
)

In [None]:
### utility functions
def merge_configs_fn(cfg, ckpt_cfg):
    """Merge the config stored in a checkpoint with the current notebook config.

    The checkpoint config is taken as the base. Its ``decoder.load_ckpt`` entry is
    replaced by the current one (``ckpt_cfg`` is modified in place for this), its
    ``"device"`` keys are set to the current device via ``update_config_keys_to_value``,
    and its paths are updated for ``DATA_PATH`` via ``update_config_paths``.

    Args:
        cfg: Current config; must provide ``cfg["decoder"]["load_ckpt"]`` and ``cfg["device"]``.
        ckpt_cfg: Config loaded from the checkpoint.

    Returns:
        Tuple ``(merged_cfg, ckpt_cfg)``: the result of ``update_config_paths`` and the
        device-updated checkpoint config it was given.
    """
    current_load_ckpt = cfg["decoder"]["load_ckpt"]
    current_device = cfg["device"]

    ckpt_cfg["decoder"]["load_ckpt"] = current_load_ckpt
    ckpt_cfg = update_config_keys_to_value(ckpt_cfg, "device", current_device)
    merged_cfg = update_config_paths(config=ckpt_cfg, new_data_path=DATA_PATH)
    return merged_cfg, ckpt_cfg

In [None]:
### load model
# Initialise the GAN decoder from the checkpoint described in config["decoder"]["load_ckpt"].
(cfg, gan, loss_fn, history, best, ckpt) = init_gan_decoder(
    config=config,
    merge_configs_fn=merge_configs_fn,
)

## Combined Inverted Encoder and Generative Adversarial Network (InvEnc-GAN)

In [None]:
from torch import nn
import torch.nn.functional as F
from torchvision.transforms import GaussianBlur
import featurevis
from featurevis import ops
from featurevis import utils as fvutils

In [None]:
def gan_resp_loss_fn(resp_pred, resp, x_hat, data_key=None, neuron_coords=None, pupil_center=None):
    """Decode predicted responses into a stimulus using the global ``decoder``.

    NOTE(review): despite its name this returns the decoder output, not a scalar loss,
    and ``resp`` and ``x_hat`` are unused. ``decoder`` is looked up as a global at call
    time; no visible cell defines it (the GAN decoder is bound to ``gan``), so calling
    this as-is presumably raises ``NameError`` — confirm before use.
    """
    # Earlier drafts, kept for reference:
    # stim_pred_from_resp_pred = decoder(resp_pred, data_key=data_key, neuron_coords=neuron_coords, pupil_center=pupil_center)
    # stim_pred_from_resp_gt = decoder(resp, data_key=data_key, neuron_coords=neuron_coords, pupil_center=pupil_center)
    decoder_kwargs = dict(data_key=data_key, neuron_coords=neuron_coords, pupil_center=pupil_center)
    stim_pred_from_resp_pred = decoder(resp_pred, **decoder_kwargs)
    return stim_pred_from_resp_pred

In [None]:
class InvertedEncoderDecoder(nn.Module):
    """Decode stimuli from responses by inverting an encoder with gradient descent.

    Starting from an initial image ``x_hat``, the image is optimised so that
    ``encoder(x_hat)`` matches the given responses under ``resp_loss_fn``. The decoder
    is only used to produce the initial image when ``stim_pred_init == "decoder"``.
    Both encoder and decoder are frozen and put into eval mode.

    Args:
        encoder: Module called as ``encoder(x, data_key=..., pupil_center=...)``.
        decoder: Module called as
            ``decoder(resp, data_key=..., neuron_coords=..., pupil_center=...)``.
        img_dims: Shape of a single decoded image (without the batch dimension).
        stim_pred_init: One of ``"zeros"``, ``"rand"``, ``"randn"``, ``"decoder"``.
        opter_cls: Optimiser class, instantiated as ``opter_cls([x_hat], **opter_config)``.
        opter_config: Keyword arguments for ``opter_cls``; ``None`` means ``{"lr": 50}``.
        n_steps: Number of optimisation steps per ``forward`` call.
        resp_loss_fn: ``(resp_pred, resp_target) -> scalar`` loss that is back-propagated.
        stim_loss_fn: ``(stim_pred, stim_target) -> scalar``; only used for monitoring
            when ``stim_target`` is passed to ``forward``.
        img_gauss_blur_config: Kwargs for ``GaussianBlur`` applied to the image, or ``None``.
        img_gauss_blur_freq: Apply the image blur every this many steps.
        img_grad_gauss_blur_config: Kwargs for ``GaussianBlur`` applied to the image
            gradient, or ``None``.
        img_grad_gauss_blur_freq: Apply the gradient blur every this many steps.
        device: Device on which ``x_hat`` is created for the zeros/rand/randn inits.

    After ``forward``: ``self.resp_pred`` holds the encoder output of the last step
    (``None`` if ``n_steps == 0``) and ``self.history`` the logged losses.
    """

    def __init__(
        self,
        encoder,
        decoder,
        img_dims=(1, 36, 64),
        stim_pred_init="zeros",
        opter_cls=torch.optim.SGD,
        opter_config=None,
        n_steps=1000,
        resp_loss_fn=lambda resp_pred, resp_target: F.mse_loss(resp_pred, resp_target, reduction="none").mean(-1).sum(),
        stim_loss_fn=lambda stim_pred, stim_target: F.mse_loss(stim_pred, stim_target, reduction="none").mean((-1, -2, -3)).sum(),
        img_gauss_blur_config=None,
        img_gauss_blur_freq=1,
        img_grad_gauss_blur_config=None,
        img_grad_gauss_blur_freq=1,
        device="cpu",
    ):
        super().__init__()
        self.encoder = encoder.requires_grad_(False)
        self.encoder.training = False
        self.encoder.eval()

        self.decoder = decoder.requires_grad_(False)
        self.decoder.training = False
        self.decoder.eval()

        self.stim_pred_init = stim_pred_init
        self.img_dims = img_dims
        self.opter_cls = opter_cls
        # Own copy: avoids a shared mutable default and later mutation by the caller.
        self.opter_config = {"lr": 50} if opter_config is None else dict(opter_config)
        self.n_steps = n_steps
        self.resp_loss_fn = resp_loss_fn
        self.stim_loss_fn = stim_loss_fn

        self.img_gauss_blur_config = img_gauss_blur_config
        self.img_gauss_blur_freq = img_gauss_blur_freq
        self.img_gauss_blur = None if img_gauss_blur_config is None else GaussianBlur(**img_gauss_blur_config)
        self.img_grad_gauss_blur_config = img_grad_gauss_blur_config
        self.img_grad_gauss_blur_freq = img_grad_gauss_blur_freq
        self.img_grad_gauss_blur = None if img_grad_gauss_blur_config is None else GaussianBlur(**img_grad_gauss_blur_config)

        self.resp_pred = None
        self.history = None

        self.device = device

    def _init_x_hat(self, B, resp=None, data_key=None, neuron_coords=None, pupil_center=None):
        """Create the initial image batch of size ``B`` as a leaf tensor requiring grad.

        Raises:
            ValueError: If ``self.stim_pred_init`` is not a known strategy.
        """
        ### init decoded img
        if self.stim_pred_init == "zeros":
            x_hat = torch.zeros((B, *self.img_dims), requires_grad=True, device=self.device)
        elif self.stim_pred_init == "rand":
            x_hat = torch.rand((B, *self.img_dims), requires_grad=True, device=self.device)
        elif self.stim_pred_init == "randn":
            x_hat = torch.randn((B, *self.img_dims), requires_grad=True, device=self.device)
        elif self.stim_pred_init == "decoder":
            # Only the decoder's output values are needed, so no graph is built.
            with torch.no_grad():
                x_hat = self.decoder(resp, data_key=data_key, neuron_coords=neuron_coords, pupil_center=pupil_center)
            x_hat = x_hat.detach().clone().requires_grad_(True)
        else:
            raise ValueError(f"Unknown stim_pred_init: {self.stim_pred_init}")
        return x_hat

    def forward(self, resp, data_key=None, neuron_coords=None, pupil_center=None, ckpt_config=None, stim_target=None):
        """Optimise an image batch so that the encoder reproduces ``resp``.

        Args:
            resp: Target responses with a leading batch dimension (``ndim > 1``).
            data_key: Forwarded to the encoder (and decoder for the "decoder" init).
            neuron_coords: Forwarded to the decoder for the "decoder" init.
            pupil_center: Forwarded to the encoder (and decoder for the "decoder" init).
            ckpt_config: Optional dict with ``"ckpt_dir"`` and ``"ckpt_freq"`` and an
                optional ``"plot_fn"(target=, pred=, save_to=)``. Every ``ckpt_freq``
                steps a checkpoint is written to ``<ckpt_dir>/<step>/ckpt.pt``.
            stim_target: Optional ground-truth stimuli. If given, ``stim_loss_fn`` is
                logged after every update and the best image so far is kept in
                ``history["best"]``.

        Returns:
            The detached image batch after the last optimisation step.
        """
        assert resp.ndim > 1, "resp should be at least 2d (batch_dim, neurons_dim)"

        ### init decoded img
        x_hat = self._init_x_hat(resp.size(0), resp=resp, data_key=data_key, neuron_coords=neuron_coords, pupil_center=pupil_center)

        ### optimize decoded img
        opter = self.opter_cls([x_hat], **self.opter_config)
        history = {"resp_loss": [], "stim_loss": [], "best": {"stim_loss": np.inf, "stim_pred": None}}
        resp_pred = None  # stays None when n_steps == 0
        # Monitoring copy on the image's device; the original is still passed to plot_fn.
        stim_target_dev = None if stim_target is None else stim_target.to(x_hat.device)

        # Gradients w.r.t. the image are needed even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            for step_i in range(self.n_steps):
                opter.zero_grad()

                resp_pred = self.encoder(x_hat, data_key=data_key, pupil_center=pupil_center)
                resp_loss = self.resp_loss_fn(resp_pred, resp)
                # NOTE: a decoder-based stimulus loss used to be computed here every step
                # but never entered the objective or the history; it was removed to save
                # one full decoder forward pass per step.
                resp_loss.backward()

                ### apply gaussian blur to gradients
                if self.img_grad_gauss_blur is not None and step_i % self.img_grad_gauss_blur_freq == 0:
                    x_hat.grad = self.img_grad_gauss_blur(x_hat.grad)

                ### update
                opter.step()
                if stim_target_dev is not None:
                    stim_loss = self.stim_loss_fn(x_hat.detach(), stim_target_dev)
                    history["stim_loss"].append(stim_loss.item())
                    if stim_loss.item() < history["best"]["stim_loss"]:
                        history["best"]["stim_loss"] = stim_loss.item()
                        history["best"]["stim_pred"] = x_hat.detach().clone()

                ### apply gaussian blur to image
                if self.img_gauss_blur is not None and step_i % self.img_gauss_blur_freq == 0:
                    with torch.no_grad():
                        x_hat.copy_(self.img_gauss_blur(x_hat))

                ### log (loss evaluated before this step's update)
                history["resp_loss"].append(resp_loss.item())

                ### ckpt
                if ckpt_config is not None and step_i % ckpt_config["ckpt_freq"] == 0:
                    curr_ckpt_dir = os.path.join(ckpt_config["ckpt_dir"], str(step_i))
                    # exist_ok: repeated calls with the same ckpt_dir overwrite instead of failing
                    os.makedirs(curr_ckpt_dir, exist_ok=True)
                    # Default pickle module; the previous `pickle_module=dill` referenced a
                    # name that was never imported.
                    torch.save({
                        "reconstruction": x_hat.detach().clone(),
                        "history": history,
                        "opter_state": opter.state_dict(),
                    }, os.path.join(curr_ckpt_dir, "ckpt.pt"))
                    plot_fn = ckpt_config.get("plot_fn", None)
                    if plot_fn is not None:
                        plot_fn(target=stim_target, pred=x_hat.detach(), save_to=os.path.join(curr_ckpt_dir, "stim_pred.png"))

        self.resp_pred = None if resp_pred is None else resp_pred.detach()
        self.history = history
        return x_hat.detach()

In [None]:
# Sweep over learning rate and number of optimisation steps and plot each reconstruction.
for _lr in [5, 50, 150]:
    for _n_steps in [100, 800, 2000]:
        inv_enc_dec = InvertedEncoderDecoder(
            encoder=encoder,
            decoder=gan,
            img_dims=(1, 36, 64),
            stim_pred_init="zeros",
            opter_config={"lr": _lr},
            n_steps=_n_steps,
            img_grad_gauss_blur_config={"kernel_size": 13, "sigma": 1.},
            device=config["device"],
        )
        stim_pred = inv_enc_dec(
            resp.to(config["device"]),
            data_key=sample_data_key,
            neuron_coords=neuron_coords[sample_dataset],
            pupil_center=None,
        )  # (batch, n_channels, h, w)

        ### show comparison
        print(f"{_lr=}, {_n_steps=}")
        target_imgs = crop(stim[:8], config["crop_wins"][sample_dataset]).cpu()
        pred_imgs = crop(stim_pred[:8], config["crop_wins"][sample_dataset]).cpu()
        fig = plot_comparison(target=target_imgs, pred=pred_imgs, show=True)

In [None]:
# Final InvEnc-GAN model: zero-initialised image, SGD with lr=50, 1000 steps, blurred gradients.
inv_enc_dec = InvertedEncoderDecoder(
    encoder=encoder,
    decoder=gan,
    img_dims=(1, 36, 64),
    stim_pred_init="zeros",
    opter_config={"lr": 50},
    n_steps=1000,
    img_grad_gauss_blur_config={"kernel_size": 13, "sigma": 1},
    device=config["device"],
)

In [None]:
# Decode the sample responses; result is (batch, n_channels, h, w).
stim_pred = inv_enc_dec(
    resp.to(config["device"]),
    pupil_center=None,
    neuron_coords=neuron_coords[sample_dataset],
    data_key=sample_data_key,
)

In [None]:
### show comparison
# First 8 targets vs. reconstructions, both cropped to the dataset's window.
target_imgs = crop(stim[:8], config["crop_wins"][sample_dataset]).cpu()
pred_imgs = crop(stim_pred[:8], config["crop_wins"][sample_dataset]).cpu()
fig = plot_comparison(target=target_imgs, pred=pred_imgs, show=True)

In [None]:
### show comparison
# NOTE(review): this cell repeats the comparison plot from the cell directly above.
target_imgs = crop(stim[:8], config["crop_wins"][sample_dataset]).cpu()
pred_imgs = crop(stim_pred[:8], config["crop_wins"][sample_dataset]).cpu()
fig = plot_comparison(target=target_imgs, pred=pred_imgs, show=True)