In [None]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import json
import pandas as pd
from datetime import datetime
from copy import deepcopy
import dill
import wandb
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import transforms
import lovely_tensors as lt
from nnfabrik.builder import get_data
from neuralpredictors.data.transforms import NeuroNormalizer
from neuralpredictors.measures.modules import PoissonLoss
from lurz2020.models.models import se2d_fullgaussian2d
from lurz2020.utility.measures import get_correlations

import csng
from csng.CNN_Decoder import CNN_Decoder
from csng.utils import crop, plot_comparison, standardize, normalize, get_mean_and_std, count_parameters, plot_losses
from csng.losses import SSIMLoss, MSELossWithCrop, Loss
from csng.data import MixedBatchLoader
from csng.data import MixedBatchLoader
from csng.readins import (
    MultiReadIn,
    HypernetReadIn,
    ConvReadIn,
    AttentionReadIn,
    FCReadIn,
    AutoEncoderReadIn,
    Conv1dReadIn,
)

# from models import MultiReadIn, ConvReadIn, Loss
from L2O_Decoder import L2O_Decoder, L2O_Shallow_Decoder
from data_utils import get_mouse_v1_data

# Apply the lovely_tensors monkey patch (presumably changes how tensors are displayed).
lt.monkey_patch()

# Root folder of the Sensorium 2022 mouse V1 data. The DATA_PATH environment variable
# must be set before starting the kernel (a KeyError is raised otherwise).
DATA_PATH = os.path.join(os.environ["DATA_PATH"], "mouse_v1_sensorium22")
print(f"{DATA_PATH=}")

# Tell wandb which notebook the run belongs to. This is set through os.environ instead of
# `%env "NAME" "value"`: the %env magic does not strip quotes, so the quoted form defined
# a variable literally named '"WANDB_NOTEBOOK_NAME"' and left the real one unset.
os.environ["WANDB_NOTEBOOK_NAME"] = "l2o_decoder.ipynb"
wandb.login()

In [None]:
# Global experiment configuration; later cells extend it (e.g. "data", "decoder", "run_name").
config = dict(
    data=dict(
        mixing_strategy="parallel_min",  # needed only with multiple base dataloaders
    ),
    device="cuda" if torch.cuda.is_available() else "cpu",
    seed=0,
    # window (rows, cols) used when cropping stimuli for the loss and the comparison plots
    crop_win=(slice(7, 29), slice(15, 51)),
    # set "wandb" to None to disable Weights & Biases logging
    wandb=dict(
        project="CSNG",
        group="sensorium_2022",
    ),
)

print("... Running on {} ...".format(config["device"]))

In [None]:
# Seed NumPy, PyTorch and Python's `random` (in that order) with the configured seed.
for seed_fn in (np.random.seed, torch.manual_seed, random.seed):
    seed_fn(config["seed"])

## Data

In [None]:
# Placeholder; replaced by the real dataloaders once the data config is prepared.
dataloaders = {}

### Mouse V1 dataset (Sensorium 2022)

In [None]:
### prep data config
# Dataset archives from https://gin.g-node.org/cajal/Sensorium2022/src/master
# (commented-out entries are not loaded).
filenames = [
    # "static26872-17-20-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip", # mouse 1
    # "static27204-5-13-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip", # sensorium+ (mouse 2)
    "static21067-10-18-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip", # pretraining (mouse 3)
    "static22846-10-16-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip", # pretraining (mouse 4)
    "static23343-5-17-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip", # pretraining (mouse 5)
    "static23656-14-22-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip", # pretraining (mouse 6)
    "static23964-4-22-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip", # pretraining (mouse 7)
]
# resolve every archive name against the data root
filenames = [os.path.join(DATA_PATH, f_name) for f_name in filenames]

config["data"].update({
    "paths": filenames,
    "dataset_fn": "sensorium.datasets.static_loaders",
    "dataset_config": {
        "paths": filenames,
        "normalize": True,
        "scale": 0.25, # 256x144 -> 64x36
        "include_behavior": False,
        "add_behavior_as_channels": False,
        "include_eye_position": True,
        # Images are excluded from the loader's normalization because they are normalized
        # manually afterwards. The literal used to list "exclude" twice (None, then
        # ["images"]); only the last value took effect, so the dead entry was removed.
        "exclude": ["images"],
        "file_tree": True,
        "cuda": False,
        # "batch_size": 32,
        "batch_size": 7,
        "seed": config["seed"],
        "use_cache": False,
    },
    "normalize_neuron_coords": True,
})

In [None]:
def get_normalized_mouse_v1_data(config):
    """Build the mouse V1 dataloaders and prepend an input-scaling transform.

    Calls ``get_mouse_v1_data(config=config)`` and, for every dataset reachable through
    the "test" loader, inserts a ``NeuroNormalizer`` (``inputs_mean=0``, ``inputs_std=255``;
    behavior, responses and eye position excluded) at position 0 of its transform list.

    Note (from the original author): train, val and test all share one underlying
    dataset, so adding the transform through one split is enough.

    Returns:
        The ``(dataloaders, neuron_coords)`` pair produced by ``get_mouse_v1_data``.
    """
    dls, neuron_coords = get_mouse_v1_data(config=config)

    test_loader = dls["mouse_v1"]["test"]
    for d_idx in range(test_loader.n_dataloaders):
        dataset = test_loader.datasets[d_idx]
        input_scaler = NeuroNormalizer(
            data=dataset,
            exclude=["behavior", "responses", "eye_position"],
            inputs_mean=0,
            inputs_std=255,
        )
        # goes first so that it runs before the transforms already attached to the dataset
        dataset.transforms.insert(0, input_scaler)

    return dls, neuron_coords

In [None]:
### get dataloaders and cell coordinates (stimuli get the extra input-scaling transform)
dataloaders, neuron_coords = get_normalized_mouse_v1_data(config=config)

In [None]:
### show data
# Take one batch from the first validation dataloader; `stim`, `resp`, `pupil_center`
# and `sample_data_key` are reused by later cells as the fixed visualisation sample.
val_loaders = dataloaders["mouse_v1"]["val"]
sample_data_key = val_loaders.data_keys[0]
datapoint = next(iter(val_loaders.dataloaders[0]))
stim, resp = datapoint.images, datapoint.responses
pupil_center = datapoint.pupil_center


def _approx_n_samples(split):
    """Return (number of batches) x (configured batch size), summed over the split's dataloaders.

    NOTE(review): presumably an over-estimate whenever a last batch is not full.
    """
    batch_size = config["data"]["dataset_config"]["batch_size"]
    return sum(len(dl) * batch_size for dl in dataloaders["mouse_v1"][split].dataloaders)


sample_coords = neuron_coords[sample_data_key]
print(
    f"Training dataset:\t {_approx_n_samples('train')} samples"
    f"\nValidation dataset:\t {_approx_n_samples('val')} samples"
    f"\nTest dataset:\t\t {_approx_n_samples('test')} samples"
    f"\nTest (no resp) dataset:\t {_approx_n_samples('test_no_resp')} samples"

    "\n\nstimuli:"
    f"\n  {stim.shape}"
    f"\n  min={stim.min().item():.3f}  max={stim.max().item():.3f}"
    f"\n  mean={stim.mean().item():.3f}  std={stim.std().item():.3f}"
    "\nresponses:"
    f"\n  {resp.shape}"
    f"\n  min={resp.min().item():.3f}  max={resp.max().item():.3f}"
    f"\n  mean={resp.mean().item():.3f}  std={resp.std().item():.3f}"
    "\nneuronal coordinates:"
    f"\n  {sample_coords.shape}"
    f"\n  min={sample_coords.min():.3f}  max={sample_coords.max():.3f}"
    f"\n  mean={sample_coords.mean():.3f}  std={sample_coords.std():.3f}"
)

fig, (ax_stim, ax_resp) = plt.subplots(1, 2, figsize=(16, 6))

# left panel: first stimulus of the batch
ax_stim.imshow(stim[0].squeeze().unsqueeze(-1).cpu(), cmap="gray")

# right panel: first response vector, folded into a 2D grid using the smallest
# divisor of the number of neurons in [30, 150); skipped if there is none
n_neurons = resp.shape[-1]
reshape_to = next(
    ((n_rows, n_neurons // n_rows) for n_rows in range(30, 150) if n_neurons % n_rows == 0),
    None,
)
if reshape_to is not None:
    ax_resp.imshow(resp[0].view(reshape_to).squeeze(0).unsqueeze(-1).cpu(), cmap="gray")

plt.show()

## Encoder

In [None]:
### load encoder
print("Loading encoder...")

### load pretrained encoder ckpt
encoder_ckpt_path = os.path.join(DATA_PATH, "models", "encoder.pt")
encoder_ckpt = torch.load(encoder_ckpt_path, map_location=config["device"])
encoder_cfg = encoder_ckpt["config"]  # the config the encoder was stored with

### get temporary dataloaders for the encoder
_dataloaders = get_data(
    encoder_cfg["data"]["dataset_fn"],
    encoder_cfg["data"]["dataset_config"],
)

### init encoder
encoder = se2d_fullgaussian2d(
    **encoder_cfg["encoder"]["model_config"],
    dataloaders=_dataloaders,
    seed=encoder_cfg["seed"],
)
encoder = encoder.float()
# strict=True: every key of the stored state dict has to match the freshly built model
encoder.load_state_dict(encoder_ckpt["encoder_state"], strict=True)
encoder.to(config["device"])
encoder.eval()

In [None]:
### validate encoder is working (corr on val set should be ~ 0.32)
corr_kwargs = dict(device=config["device"], as_dict=False, per_neuron=False)
corr_report = (
    "Correlation (train set):      {:.3f}"
    "\nCorrelation (validation set): {:.3f}"
    "\nCorrelation (test set):       {:.3f}"
)

# 1) on the dataloaders built from the encoder's own checkpoint config
train_correlation = get_correlations(encoder, _dataloaders["train"], **corr_kwargs)
validation_correlation = get_correlations(encoder, _dataloaders["validation"], **corr_kwargs)
test_correlation = get_correlations(encoder, _dataloaders["test"], **corr_kwargs)

print(corr_report.format(train_correlation, validation_correlation, test_correlation))

### validate w/ my data (TODO: normalizing in L2O_Decoder, so this is not needed atm)
# 2) on this notebook's dataloaders, re-keyed as {data_key: dataloader} per split
mouse_v1 = dataloaders["mouse_v1"]
t = dict(zip(mouse_v1["train"].data_keys, mouse_v1["train"].dataloaders))
v = dict(zip(mouse_v1["val"].data_keys, mouse_v1["val"].dataloaders))
te = dict(zip(mouse_v1["test"].data_keys, mouse_v1["test"].dataloaders))

### validate encoder is working (corr on val set should be ~ 0.32)
train_correlation = get_correlations(encoder, t, **corr_kwargs)
validation_correlation = get_correlations(encoder, v, **corr_kwargs)
test_correlation = get_correlations(encoder, te, **corr_kwargs)

print(corr_report.format(train_correlation, validation_correlation, test_correlation))

In [None]:
### delete dataloaders
# `_dataloaders` are the temporary loaders created from the encoder checkpoint's config;
# they were only needed to build and sanity-check the encoder.
del _dataloaders

## Decoder

In [None]:
def train(model, dataloader, config, verbose=True):
    """Run one pass over `dataloader` with the model in training mode and return the mean loss.

    No backward pass or optimizer step is performed here. The model is called with
    ``additional_core_inp["train"] = True`` and the ground-truth stimulus
    (NOTE(review): presumably the core updates its own parameters internally - confirm
    against L2O_Decoder). The loss computed here is used for logging only.

    Args:
        model: callable module that also provides ``train()``,
            ``set_additional_loss(inp=..., out=...)`` and ``get_additional_loss(data_key=...)``.
        dataloader: sized iterable of dicts
            ``{data_key: (stim, resp, neuron_coords, pupil_center)}``.
        config: experiment config; reads ``config["decoder"]["n_steps"]`` and
            ``config["decoder"]["model"]["core_config"]["stim_loss_fn"]``.
        verbose: if True, print the progress every 100 batches.

    Returns:
        The loss (stimulus loss + additional loss, averaged over the data keys of a batch)
        averaged over all batches, as a plain Python number.
    """
    model.train()
    train_loss = 0
    n_batches = len(dataloader)
    stim_loss_fn = config["decoder"]["model"]["core_config"]["stim_loss_fn"]

    ### run
    for batch_idx, b in enumerate(dataloader):
        loss = 0

        ### combine from all data keys
        for data_key, (stim, resp, neuron_coords, pupil_center) in b.items():
            ### train
            stim_pred, _ = model(
                x=resp,
                data_key=data_key,
                neuron_coords=neuron_coords,
                pupil_center=pupil_center,
                additional_core_inp=dict(
                    train=True,
                    resp=resp,
                    stim=stim,
                    neuron_coords=neuron_coords,
                    pupil_center=pupil_center,
                    data_key=data_key,
                    n_steps=config["decoder"]["n_steps"],
                    x_hat_history_iters=None,
                ),
            )

            ### log
            loss += stim_loss_fn(stim_pred, stim, phase="train", data_key=data_key).item()
            model.set_additional_loss(
                inp={
                    "resp": resp,
                    "stim": stim,
                    "neuron_coords": neuron_coords,
                    "pupil_center": pupil_center,
                    "data_key": data_key,
                }, out={
                    "stim_pred": stim_pred,
                },
            )
            additional_loss = model.get_additional_loss(data_key=data_key)
            # The additional loss may be a tensor. Adding it as-is would turn the running
            # loss into a tensor (possibly carrying an autograd graph) that is then
            # accumulated over the epoch and stored in the history, so it is converted to
            # a plain number first.
            if hasattr(additional_loss, "item"):
                additional_loss = additional_loss.item()
            loss += additional_loss

        loss /= len(b)
        train_loss += loss

        if verbose and batch_idx % 100 == 0:
            print(f"Training progress: [{batch_idx}/{n_batches} ({100. * batch_idx / n_batches:.0f}%)]"
                  f"  Loss: {loss:.6f}")

    train_loss /= n_batches
    return train_loss

In [None]:
def val(model, dataloader, loss_fn, config, only_data_keys=None):
    """Evaluate `model` on `dataloader` and return the mean loss per data key and in total.

    Gradients are not disabled here (NOTE(review): the core may rely on autograd during
    inference - confirm before wrapping this in ``torch.no_grad()``).

    Args:
        model: callable module providing ``eval()``; called with ``train=False`` and
            without the ground-truth stimulus.
        dataloader: iterable of dicts
            ``{data_key: (stim, resp, neuron_coords, pupil_center)}``.
        loss_fn: ``loss_fn(stim_pred, stim)`` returning an object with ``.item()``.
        config: experiment config; reads ``config["decoder"]["n_steps"]``.
        only_data_keys: optional container of data keys to evaluate; all other keys
            are skipped.

    Returns:
        dict with one entry per evaluated data key (loss averaged over the batches that
        contained the key) and ``"total"``: the per-batch mean over the evaluated keys,
        averaged over the batches in which at least one key was evaluated. ``"total"``
        stays 0 if nothing was evaluated.
    """
    model.eval()
    val_losses = {"total": 0}
    denom_data_keys = {}
    n_evaluated_batches = 0
    for b in dataloader:
        batch_losses = []

        ### combine from all data keys
        for data_key, (stim, resp, neuron_coords, pupil_center) in b.items():
            if only_data_keys is not None and data_key not in only_data_keys:
                continue

            stim_pred, _ = model(
                x=resp,
                data_key=data_key,
                neuron_coords=neuron_coords,
                pupil_center=pupil_center,
                additional_core_inp=dict(
                    train=False,
                    stim=None,
                    resp=resp,
                    neuron_coords=neuron_coords,
                    pupil_center=pupil_center,
                    data_key=data_key,
                    n_steps=config["decoder"]["n_steps"],
                    x_hat_history_iters=None,
                ),
            )
            loss = loss_fn(stim_pred, stim).item()
            val_losses[data_key] = val_losses.get(data_key, 0) + loss
            denom_data_keys[data_key] = denom_data_keys.get(data_key, 0) + 1
            batch_losses.append(loss)

        # Average over the keys that were actually evaluated. Dividing by len(b) instead
        # underestimated the total whenever `only_data_keys` filtered some keys out.
        if batch_losses:
            n_evaluated_batches += 1
            for loss in batch_losses:
                val_losses["total"] += loss / len(batch_losses)

    if n_evaluated_batches > 0:
        val_losses["total"] /= n_evaluated_batches
    for k in denom_data_keys:
        val_losses[k] /= denom_data_keys[k]
    return val_losses

In [None]:
# Decoder configuration:
#   "model"  - keyword arguments for MultiReadIn (per-dataset read-ins + the shared core)
#   "loss"   - configuration handed to `Loss(...)` right after this literal
#   the remaining keys control the training loop and checkpointing
config["decoder"] = {
    "model": {
        # one read-in entry per training dataset / data key
        "readins_config": [
            {
                "data_key": data_key,
                "in_shape": d.n_neurons,
                "layers": [
                    (ConvReadIn, {
                        "shift_coords": True,
                        "learn_grid": False,
                        "grid_l1_reg": 3e-3,
                        "in_channels_group_size": 2,
                        "pointwise_conv_config": {
                            "in_channels": d.n_neurons,
                            # NOTE(review): same value as core_config["resp_shape"] below -
                            # presumably the two have to stay in sync; confirm
                            "out_channels": 86,
                            "act_fn": nn.LeakyReLU,
                            "bias": False,
                            "batch_norm": True,
                            "dropout": 0.,
                        },
                        "gauss_blur": True,
                        "gauss_blur_kernel_size": 7,
                        # "gauss_blur_sigma": "fixed", # "fixed", "single", "per_neuron"
                        "gauss_blur_sigma": "per_neuron", # "fixed", "single", "per_neuron"
                        "gauss_blur_sigma_init": 1.5,
                    }),
                ],
            } for d, data_key in zip(dataloaders["mouse_v1"]["train"].datasets, dataloaders["mouse_v1"]["train"].data_keys)
        ],
        # "core_cls": L2O_Decoder,
        "core_cls": L2O_Decoder,
        "core_config": {
            # the pretrained encoder loaded in the "Encoder" section
            "encoder": encoder.float(),
            # "resp_shape": resp.shape[1:],
            "resp_shape": [86],
            "stim_shape": (1, 36, 64),
            "in_shape": (4, 36, 64),
            # "reconstruction_init_method": "zero",
            # "resp_layers_cfg": {
            #     "layers": [
            #         ("fc", 288),
            #         ("unflatten", 1, (2, 9, 16)),
            #         ("deconv", 64, 9, 2, 4),
            #         ("deconv", 32, 7, 2, 3),
            #         ("deconv", 1, 4, 1, 0),
            #     ],
            #     "act_fn": nn.ReLU(),
            #     "out_act_fn": nn.Sigmoid(),
            #     "dropout": 0.2,
            #     "batch_norm": True,
            # },
            # "reconstruction_init_method": "resp_layers",
            "resp_layers_cfg": {
                "layers": [
                    # ("deconv", 64, 9, 2, 4),
                    # ("deconv", 32, 7, 2, 3),
                    # ("deconv", 1, 4, 1, 0),
                    ("deconv", 64, 7, 2, 3),
                    ("deconv", 32, 6, 2, 2),
                    # ("deconv", 32, 5, 1, 2),
                    ("deconv", 1, 3, 1, 0),
                ],
                "act_fn": nn.ReLU(),
                "out_act_fn": nn.Sigmoid(),
                "dropout": 0.2,
                "batch_norm": True,
            },
            "reconstruction_init_method": "resp_layers",
            "act_fn": nn.ReLU(),
            # "reconstruction_layers_cfg": {
            #     "layers": [
            #         ("conv", 64, 5, 1, 2),
            #         ("conv", 64, 5, 1, 2),
            #         ("conv", 32, 5, 1, 2),
            #         ("conv", 1, 3, 1, 1)
            #     ],
            #     "batch_norm": True,
            #     "dropout": 0.2,
            #     "act_fn": nn.LeakyReLU,
            #     "out_act_fn": nn.Identity,
            # },
            # NOTE: "stim_loss_fn" is not set in this literal; it is assigned right after
            # the dict is created (a `Loss` built from config["decoder"]["loss"]).
            # "stim_loss_fn": MSELossWithCrop(config["stim_crop_win"]),
            # "stim_loss_fn": SSIMLoss(
            #     log_loss=True,
            #     inp_normalized=False,
            #     inp_standardized=True,
            # ),
            # "stim_loss_fn": lambda x_hat, x: 0.9 * ssim_loss(x_hat, x) + 0.1 * F.mse_loss(x_hat, x),
            # "resp_loss_fn": nn.MSELoss(),
            "resp_loss_fn": PoissonLoss(avg=True),
            "opter_cls": torch.optim.Adam,
            "opter_kwargs": {
                "lr": 0.001,
            },
            "unroll": 4,
            "preproc_grad": True,
            "device": config["device"],
        },
    },
    # configuration of the stimulus loss wrapper (`Loss`)
    "loss": {
        "loss_fn": SSIMLoss(
            window=config["crop_win"],
            log_loss=True,
            inp_normalized=False,
            inp_standardized=True,
        ),
        "l1_reg_mul": 0,
        "l2_reg_mul": 1e-5,
        "con_reg_mul": 0,
        # "con_reg_loss_fn": SSIMLoss(
        #     window=config["crop_win"],
        #     log_loss=True,
        #     inp_normalized=True,
        #     inp_standardized=False,
        # ),
        "encoder": None,
        # "encoder": encoder,
    },
    # number of epochs run by the training cell (added on top of any loaded history)
    "n_epochs": 200,
    # forwarded to the core as additional_core_inp["n_steps"] in train() / val()
    "n_steps": 8,
    # None = start from scratch; otherwise a dict with "run_name" and "ckpt_path"
    "load_ckpt": None,
    # "load_ckpt": {
    #     "run_name": "2024-02-25_17-07-04",
    #     "ckpt_path": os.path.join(DATA_PATH, "models", "l2o", "2024-02-25_17-07-04", "decoder.pt"),
    # },
    # whether the config, sample plots and checkpoints are written to disk
    "save_run": True,
}

### build the decoder
# The stimulus loss is created without a model reference first, because the model
# does not exist yet; the reference is filled in once the decoder has been built.
core_config = config["decoder"]["model"]["core_config"]
core_config["stim_loss_fn"] = Loss(model=None, config=config["decoder"]["loss"])
decoder = MultiReadIn(**config["decoder"]["model"]).to(config["device"])
decoder.core.stim_loss_fn.model = decoder
# keep the config pointing at the loss instance the decoder actually holds
core_config["stim_loss_fn"] = decoder.core.stim_loss_fn

### validation loss: SSIM on the cropped window, computed on normalized inputs
sl = SSIMLoss(
    window=config["crop_win"],
    log_loss=True,
    inp_normalized=True,
    inp_standardized=False,
)


def val_loss_fn(y_pred, y):
    """Apply `sl` to `normalize(y_pred)` and `normalize(y)`."""
    return sl(normalize(y_pred), normalize(y))

In [None]:
### prepare checkpointing
config["run_name"] = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
if not config["decoder"]["save_run"]:
    def make_sample_path(epoch, prefix):
        """Nothing is saved for this run, so there is no path for sample plots."""
        return None

    print("WARNING: Not saving the run and the config.")
else:
    ### save config
    config["dir"] = os.path.join(DATA_PATH, "models", "l2o", config["run_name"])
    os.makedirs(config["dir"], exist_ok=True)
    with open(os.path.join(config["dir"], "config.json"), "w") as f:
        # default=str: anything json cannot serialize is stored as its string form
        json.dump(config, f, indent=4, default=str)
    for subdir in ("samples", "ckpt"):
        os.makedirs(os.path.join(config["dir"], subdir), exist_ok=True)

    def make_sample_path(epoch, prefix):
        """Path of the stimulus-comparison plot written for `epoch`."""
        return os.path.join(
            config["dir"], "samples", f"{prefix}stim_comparison_{epoch}e.png"
        )

    print(f"Run name: {config['run_name']}\nRun dir: {config['dir']}")

### wandb
if config["wandb"]:
    # tag the run with the class names of the core and of the first read-in layer
    core_cls = config["decoder"]["model"]["core_cls"]
    first_readin_cls = config["decoder"]["model"]["readins_config"][0]["layers"][0][0]
    wdb_run = wandb.init(
        **config["wandb"],
        name=config["run_name"],
        config=config,
        tags=[core_cls.__name__, first_readin_cls.__name__],
        notes=None,
    )
    wdb_run.watch(decoder)

In [None]:
### load ckpt
if config["decoder"]["load_ckpt"] is not None:
    run_name = config["decoder"]["load_ckpt"]["run_name"] # "2023-09-24_18-49-50"
    # NOTE(security): the checkpoint is unpickled with dill, which can execute arbitrary
    # code - only load checkpoints from trusted sources.
    ckpt = torch.load(config["decoder"]["load_ckpt"]["ckpt_path"], map_location=config["device"], pickle_module=dill)

    # The stored config replaces the current one, so config["dir"] and config["run_name"]
    # now refer to the loaded run.
    history = ckpt["history"]
    config = ckpt["config"]
    best = ckpt["best"]

    # Rebuild the decoder from the stored config and restore its weights.
    # NOTE(review): `decoder` is rebound here; a wandb watch registered before this
    # cell still refers to the previous decoder object.
    # decoder = L2O_Decoder(**config["decoder"]["model"]).to(config["device"])
    decoder = MultiReadIn(**config["decoder"]["model"]).to(config["device"])
    decoder.core.stim_loss_fn.model = decoder
    config["decoder"]["model"]["core_config"]["stim_loss_fn"] = decoder.core.stim_loss_fn
    decoder.load_state_dict(ckpt["decoder"])

    # Checkpoints also store the optimizer state under "opter"; restore it so that a
    # resumed run does not restart the optimizer from scratch.
    # NOTE(review): assumes the core creates `opter` on construction - skipped if not.
    opter = getattr(decoder.core, "opter", None)
    if opter is not None and "opter" in ckpt:
        opter.load_state_dict(ckpt["opter"])

    make_sample_path = lambda epoch, prefix: os.path.join(
        config["dir"], "samples", f"{prefix}stim_comparison_{epoch}e.png"
    )
else:
    # fresh run: empty loss history and no best model yet
    history = {"train_loss": [], "val_loss": []}
    best = {"val_loss": np.inf, "epoch": 0, "model": None}

In [None]:
### print model
# Smoke test: push the sample batch through the decoder (3 steps, eval-style call) and
# print the shape of the reconstruction, followed by the parameter count.
resp_on_device = resp.to(config["device"])
sample_out = decoder(
    resp_on_device,
    data_key=sample_data_key,
    neuron_coords=neuron_coords[sample_data_key].to(config["device"]),
    pupil_center=pupil_center.to(config["device"]),
    additional_core_inp={
        "resp": resp_on_device,
        "stim": stim.to(config["device"]),
        "data_key": sample_data_key,
        "n_steps": 3,
        "train": False,
        "x_hat_history_iters": None,
    },
)
print(sample_out[0].shape)
print("Number of parameters: {}".format(count_parameters(decoder)))

decoder

In [None]:
### train
# continue after the epochs already recorded in `history` (non-empty when a ckpt was loaded)
s = len(history["train_loss"])
e = s + config["decoder"]["n_epochs"]
for epoch in range(s, e):
    print(f"[{epoch + 1}/{e}]")

    ### train and val (the dataloaders are rebuilt at the start of every epoch)
    dls, neuron_coords = get_normalized_mouse_v1_data(config=config)
    train_dataloader = dls["mouse_v1"]["train"]
    val_dataloader = dls["mouse_v1"]["val"]
    train_loss = train(model=decoder, dataloader=train_dataloader, config=config)
    val_losses = val(
        model=decoder,
        dataloader=val_dataloader,
        loss_fn=val_loss_fn,
        # loss_fn=config["decoder"]["model"]["stim_loss_fn"],
        config=config,
    )

    ### save best model
    if val_losses["total"] < best["val_loss"]:
        best.update(
            val_loss=val_losses["total"],
            epoch=epoch,
            model=deepcopy(decoder.state_dict()),
        )

    ### log
    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_losses["total"])
    if config["wandb"]:
        # commit=False: the same wandb step is committed by the reconstruction log below
        wdb_run.log({"train_loss": train_loss, "val_loss": val_losses["total"]}, commit=False)
    per_key_report = "".join(
        f", {data_key}: {loss:.4f}"
        for data_key, loss in val_losses.items()
        if data_key != "total"
    )
    print(f"{val_losses['total']=:.4f}" + per_key_report)

    ### plot sample reconstructions
    sample_out = decoder(
        x=resp.to(config["device"]),
        data_key=sample_data_key,
        neuron_coords=neuron_coords[sample_data_key].to(config["device"]),
        pupil_center=pupil_center.to(config["device"]),
        additional_core_inp=dict(
            resp=resp.to(config["device"]),
            stim=stim.to(config["device"]),
            data_key=sample_data_key,
            n_steps=config["decoder"]["n_steps"],
            train=False,
            x_hat_history_iters=None,
        ),
    )
    stim_pred = sample_out[0].detach()
    fig = plot_comparison(
        target=crop(stim[:8], config["crop_win"]).cpu(),
        pred=crop(stim_pred[:8], config["crop_win"]).cpu(),
        save_to=make_sample_path(epoch, ""),
    )
    if config["wandb"]:
        wdb_run.log({"val_stim_reconstruction": fig})

    ### every 5th epoch (except epoch 0): plot losses and, if the run is saved, checkpoint
    if epoch % 5 == 0 and epoch > 0:
        save_run = config["decoder"]["save_run"]

        ### plot losses
        losses_path = os.path.join(config["dir"], f"losses_{epoch}.png") if save_run else None
        plot_losses(history=history, epoch=epoch, save_to=losses_path)

        ### ckpt
        if save_run:
            ckpt_state = {
                "decoder": decoder.state_dict(),
                "opter": decoder.core.opter.state_dict(),
                "history": history,
                "config": config,
                "best": best,
            }
            torch.save(
                ckpt_state,
                os.path.join(config["dir"], "ckpt", f"decoder_{epoch}.pt"),
                pickle_module=dill,
            )

In [None]:
### final evaluation + logging + saving
print(f"Best val loss: {best['val_loss']:.4f} at epoch {best['epoch']}")

### save final ckpt
if config["decoder"]["save_run"]:
    torch.save({
        "decoder": decoder.state_dict(),
        "opter": decoder.core.opter.state_dict(),
        "history": history,
        "config": config,
        "best": best,
    }, os.path.join(config["dir"], f"decoder.pt"), pickle_module=dill)

### eval on test set w/ current params
# The test data is built with get_normalized_mouse_v1_data, i.e. with the same stimulus
# scaling as the train/val data in the training loop (plain get_mouse_v1_data would
# skip that transform).
print("Evaluating on test set with current model...")
dls, neuron_coords = get_normalized_mouse_v1_data(config=config)
test_loss_curr = val(
    model=decoder,
    dataloader=dls["mouse_v1"]["test"],
    loss_fn=val_loss_fn,
    config=config,
)
print(f"  Test loss (current model): {test_loss_curr['total']:.4f}")

### load best model
# best["model"] is still None if no epoch improved on the initial best val loss
# (e.g. zero epochs were run); the current parameters are kept in that case.
if best["model"] is not None:
    decoder.load_state_dict(best["model"])
else:
    print("WARNING: No best model recorded; using the current parameters instead.")

### eval on test set w/ best params
print("Evaluating on test set with best model...")
dls, neuron_coords = get_normalized_mouse_v1_data(config=config)
test_loss_final = val(
    model=decoder,
    dataloader=dls["mouse_v1"]["test"],
    loss_fn=val_loss_fn,
    config=config,
)
print(f"  Test loss (best model): {test_loss_final['total']:.4f}")


### plot reconstructions of the final model (the best parameters are already loaded)
# `recon_history` is the decoder's second output, requested here for every step 1..n_steps
stim_pred_best, recon_history = decoder(
    x=resp.to(config["device"]),
    data_key=sample_data_key,
    neuron_coords=neuron_coords[sample_data_key].to(config["device"]),
    pupil_center=pupil_center.to(config["device"]),
    additional_core_inp=dict(
        resp=resp.to(config["device"]),
        stim=stim.to(config["device"]),
        data_key=sample_data_key,
        n_steps=config["decoder"]["n_steps"],
        train=False,
        # x_hat_history_iters=None,
        x_hat_history_iters=list(range(1, config["decoder"]["n_steps"] + 1)),
    ),
)
stim_pred_best = stim_pred_best.detach()
fig = plot_comparison(
    target=crop(stim[:8], config["crop_win"]).cpu(),
    pred=crop(stim_pred_best[:8], config["crop_win"]).cpu(),
    save_to=os.path.join(config["dir"], "stim_comparison_best.png") if config["decoder"]["save_run"] else None,
)

### log
if config["wandb"]:
    wandb.run.summary["best_val_loss"] = best["val_loss"]
    wandb.run.summary["best_epoch"] = best["epoch"]
    wandb.run.summary["curr_test_loss"] = test_loss_curr["total"]
    wandb.run.summary["final_test_loss"] = test_loss_final["total"]
    wandb.run.summary["best_reconstruction"] = fig

### save/delete wandb run
if config["wandb"]:
    if input("Delete run with 'd', save with anything else: ") == "d":
        print("Deleting wandb run...")
        api = wandb.Api()
        # take entity and project from the run itself instead of a hard-coded user name
        run = api.run(f"{wdb_run.entity}/{wdb_run.project}/{wdb_run.id}")
        run.delete()
    else:
        wdb_run.finish()

### plot losses
plot_losses(
    history=history,
    save_to=None if not config["decoder"]["save_run"] else os.path.join(config["dir"], f"losses.png"),
)

In [None]:
# index (within the sample batch) of the stimulus to inspect
idx = 2

# ground-truth stimulus
plt.imshow(stim[idx].squeeze().cpu().detach().numpy(), cmap="gray")
plt.show()

# Per-step reconstructions. They come from `recon_history` (second output of the decoder
# call in the final-evaluation cell); the training `history` only holds "train_loss" and
# "val_loss", so indexing it with "x_hat_history" raised a KeyError.
# NOTE(review): assumes `recon_history` is either a dict with an "x_hat_history" entry
# or directly the sequence of reconstructions - confirm against L2O_Decoder.
x_hat_history = recon_history["x_hat_history"] if isinstance(recon_history, dict) else recon_history
for x_hat in x_hat_history:
    plt.imshow(x_hat[idx].squeeze().cpu().detach().numpy(), cmap="gray")
    plt.show()