In [None]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import json
import pandas as pd
from datetime import datetime
from copy import deepcopy
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import transforms
import lovely_tensors as lt

import csng
from csng.CNN_Decoder import CNN_Decoder
from csng.utils import plot_comparison, standardize, normalize, get_mean_and_std, count_parameters
from csng.losses import SSIMLoss, SSIMLossWithCrop, MSELossWithCrop

from data import prepare_v1_dataloaders, SyntheticDataset, BatchPatchesDataLoader, MixedBatchLoader, PerSampleStoredDataset
from L2O_Decoder import L2O_Decoder

lt.monkey_patch()  # lovely_tensors: compact, human-readable tensor reprs

# Root of the cat-V1 spiking-model data; the base directory comes from the DATA_PATH env var.
if "DATA_PATH" not in os.environ:
    raise RuntimeError(
        "Environment variable DATA_PATH is not set; point it to the directory "
        "that contains `cat_V1_spiking_model`."
    )
DATA_PATH = os.path.join(os.environ["DATA_PATH"], "cat_V1_spiking_model")
print(f"{DATA_PATH=}")

In [None]:
# Global run configuration (dumped to the run's config.json when the run is saved).
config = dict(
    data=dict(
        mixing_strategy="parallel_min",  # only used when several base dataloaders are mixed
    ),
    # crop window applied to the last two dims of the stimuli (also passed to the SSIM losses)
    stim_crop_win=(slice(15, 35), slice(15, 35)),
    # validate on the V1 validation split only, even when training on mixed data
    only_v1_data_eval=True,
    device="cuda" if torch.cuda.is_available() else "cpu",
    seed=0,
)

print(f"... Running on {config['device']} ...")

In [None]:
# Seed every RNG used in this notebook (Python, NumPy, PyTorch) for reproducibility.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(config["seed"])

In [None]:
if config["stim_crop_win"] is None:
    def crop_stim(x):
        """No crop window configured: return `x` unchanged."""
        return x
else:
    def crop_stim(x):
        """Crop the last two dims of `x` with the (row, col) slices in `config["stim_crop_win"]`."""
        win = config["stim_crop_win"]
        return x[..., win[0], win[1]]

## Data

In [None]:
dataloaders = {}  # dataset name -> {split name -> DataLoader}

### V1 dataset (spiking model of cat V1)

In [None]:
# Keyword arguments for `prepare_v1_dataloaders` (data from the spiking model of cat V1).
config["data"]["v1_data"] = {
    "train_path": os.path.join(DATA_PATH, "datasets", "train"),
    "val_path": os.path.join(DATA_PATH, "datasets", "val"),
    "test_path": os.path.join(DATA_PATH, "orig", "raw", "test.pickle"),
    "image_size": [50, 50],
    # with crop=False the training loop crops targets/predictions itself via `crop_stim`
    "crop": False,
    # "crop": True,
    # "batch_size": 48,
    "batch_size": 30,

    ### stim normalization
    # "stim_normalize_mean": 46.236,
    # "stim_normalize_std": 21.196,
    ### stim standardization (important for setting the right SSIM loss args)
    # mean 0 / std 100 only rescales the stimuli (no centering) — presumably (x - mean) / std
    "stim_normalize_mean": 0,
    "stim_normalize_std": 100,

    # response statistics precomputed on the training split (presumably per neuron — TODO confirm)
    "resp_normalize_mean": torch.from_numpy(np.load(
        os.path.join(DATA_PATH, "responses_mean_from_training_dataset.npy")
    )).float(),
    "resp_normalize_std": torch.from_numpy(np.load(
        os.path.join(DATA_PATH, "responses_std_from_training_dataset.npy")
    )).float(),
}

In [None]:
### get data statistics
### (disabled; uncomment to recompute dataset statistics, e.g. for the `*_normalize_*`
### values in `config["data"]["v1_data"]`)
# NOTE(review): these loaders presumably already apply the configured normalization,
# so the stats would describe normalized data — confirm before reusing them.
# data_loaders = prepare_v1_dataloaders(**config["data"]["v1_data"])
# data_stats = get_mean_and_std(dataset=data_loaders["train"].dataset, verbose=True)
# for k in data_stats:
#     for ks in data_stats[k]:
#         print(f"{k}.{ks}: {data_stats[k][ks]}")

In [None]:
### get data loaders (a dict keyed by split, e.g. "train" / "val")
dataloaders["v1_data"] = prepare_v1_dataloaders(
    **config["data"]["v1_data"],
)

In [None]:
### show data
# One V1 validation batch; `stim` / `resp` are reused later for the sample reconstructions.
stim, resp = next(iter(dataloaders["v1_data"]["val"]))
print(
    f"{stim.shape=}, {resp.shape=}"
    f"\n{stim.min()=}, {stim.max()=}"
    f"\n{resp.min()=}, {resp.max()=}"
    f"\n{stim.mean()=}, {stim.std()=}"
    f"\n{resp.mean()=}, {resp.std()=}"
)

# first sample: full stimulus, stimulus cropped to `stim_crop_win`, responses reshaped to 100x100
fig, axes = plt.subplots(1, 3, figsize=(14, 6))
sample_panels = [
    ("stimulus", stim[0]),
    ("cropped stimulus", crop_stim(stim[0])),
    ("responses (reshaped to 100x100)", resp[0].view(100, 100)),
]
for ax, (title, img) in zip(axes, sample_panels):
    ax.imshow(img.squeeze(), cmap="gray")  # drop singleton dims -> 2D image
    ax.set_title(title)

plt.show()

### Synthetic data (stimuli from a different dataset -> encoder -> predicted neuronal responses)

In [None]:
# response statistics of the synthetic dataset (presumably per neuron — TODO confirm)
resp_mean = torch.from_numpy(np.load(os.path.join(DATA_PATH, "synthetic_data", "responses_mean.npy"))).float()
resp_std = torch.from_numpy(np.load(os.path.join(DATA_PATH, "synthetic_data", "responses_std.npy"))).float()

# kwargs for `PerSampleStoredDataset` ("dataset") and `DataLoader` ("dataloader"),
# shared by the train/val/test splits built in the next cell
config["data"]["syn_data"] = {
    "dataset": {
        ### stim normalization
        # "stim_transform": transforms.Normalize(
        #     mean=114.457,
        #     std=51.356,
        # ),
        ### stim standardization (important to choose for SSIM loss)
        # (x - 0) / 255: presumably maps 8-bit pixel values to [0, 1] — confirm stimulus range
        "stim_transform": transforms.Normalize(
            mean=0,
            std=255,
        ),

        # responses are normalized with the synthetic-data statistics loaded above
        "resp_transform": csng.utils.Normalize(
            mean=resp_mean,
            std=resp_std,
        ),
    },
    "dataloader": {
        "batch_size": 10,
        "shuffle": True,  # NOTE(review): also applies to the val/test loaders
    }
}

In [None]:
# One stored dataset per split, all sharing the same transforms from the config.
syn_datasets = {
    split: PerSampleStoredDataset(
        dataset_dir=os.path.join(DATA_PATH, "synthetic_data", "processed", split),
        **config["data"]["syn_data"]["dataset"]
    )
    for split in ("train", "val", "test")
}

# Wrap every split in a DataLoader with the shared dataloader kwargs.
dataloaders["syn_data"] = {
    split: DataLoader(
        dataset=split_dataset,
        **config["data"]["syn_data"]["dataloader"],
    )
    for split, split_dataset in syn_datasets.items()
}

In [None]:
### calculate statistics (disabled; uncomment to recompute)

### for stimuli
# syn_stats = get_mean_and_std(dataset=syn_datasets["train"], verbose=True)
# syn_stats

### for responses
# from csng.utils import RunningStats

# stats = RunningStats(num_components=10000, lib="torch", device="cuda")
# for i, (s, r) in enumerate(dataloaders["syn_data"]["train"]):
#     stats.update(r)
#     if i % 200 == 0:
#         print(f"{i}: {r.mean()=} {r.std()=} {stats.get_mean()=} {stats.get_std()=}")

### save
# NOTE(review): these .pt files are not the ones loaded for normalization
# (`synthetic_data/responses_mean.npy` / `responses_std.npy`), and the datasets presumably
# already apply `resp_transform` — confirm before reusing these stats.
# torch.save(stats.get_mean(), os.path.join(DATA_PATH, "responses_mean_from_syn_dataset.pt"))
# torch.save(stats.get_std(), os.path.join(DATA_PATH, "responses_std_from_syn_dataset.pt"))

In [None]:
### show data
# One synthetic validation batch (inspection only).
syn_stim, syn_resp = next(iter(dataloaders["syn_data"]["val"]))
print(
    f"{syn_stim.shape=}, {syn_resp.shape=}"
    f"\n{syn_stim.min()=}, {syn_stim.max()=}"
    f"\n{syn_resp.min()=}, {syn_resp.max()=}"
    f"\n{syn_stim.mean()=}, {syn_stim.std()=}"
    f"\n{syn_resp.mean()=}, {syn_resp.std()=}"
)

# first sample: full stimulus, stimulus cropped to `stim_crop_win`, responses reshaped to 100x100
syn_stim_cpu, syn_resp_cpu = syn_stim.cpu(), syn_resp.cpu()
fig, axes = plt.subplots(1, 3, figsize=(10, 6))
sample_panels = [
    ("stimulus", syn_stim_cpu[0]),
    ("cropped stimulus", crop_stim(syn_stim_cpu[0])),
    ("responses (reshaped to 100x100)", syn_resp_cpu[0].view(100, 100)),
]
for ax, (title, img) in zip(axes, sample_panels):
    ax.imshow(img.squeeze(), cmap="gray")  # drop singleton dims -> 2D image
    ax.set_title(title)

plt.show()

## Decoder

In [None]:
def train(model, dataloader, config, verbose=True):
    """Run one training epoch of the decoder over ``dataloader``.

    ``model.run_batch(train=True, ...)`` is called on every batch. The loss computed
    here is used only for logging; no backward pass happens in this function
    (presumably the update is done inside ``run_batch`` — TODO confirm in L2O_Decoder).

    Args:
        model: decoder exposing ``train()`` and ``run_batch(...)``.
        dataloader: iterable of ``(stim, resp)`` batches that supports ``len()``.
        config: run config; reads ``config["device"]``, ``config["decoder"]["n_steps"]``
            and ``config["decoder"]["model"]["stim_loss_fn"]``.
        verbose: print progress every 100 batches.

    Returns:
        Mean per-batch loss as a ``float``; ``nan`` if ``dataloader`` is empty.
    """
    model.train()
    device = config["device"]
    n_steps = config["decoder"]["n_steps"]
    loss_fn = config["decoder"]["model"]["stim_loss_fn"]
    n_batches = len(dataloader)
    train_loss = 0.0

    for batch_idx, (stim, resp) in enumerate(dataloader):
        ### data
        stim = stim.to(device)
        resp = resp.to(device)

        ### train
        stim_pred, _ = model.run_batch(
            train=True,
            stim=stim,
            resp=resp,
            n_steps=n_steps,
            x_hat_history_iters=None,
        )

        ### log (no autograd graph needed for a logging-only loss)
        with torch.no_grad():
            loss = loss_fn(stim_pred, stim)
        train_loss += loss.item()
        if verbose and batch_idx % 100 == 0:
            print(f"Training progress: [{batch_idx}/{n_batches} ({100. * batch_idx / n_batches:.0f}%)]"
                  f"  Loss: {loss.item():.6f}")

    if n_batches == 0:
        return float("nan")
    return train_loss / n_batches

In [None]:
def val(model, dataloader, loss_fn, config):
    """Average ``loss_fn(stim_pred, stim)`` of ``model`` over ``dataloader`` in eval mode.

    The decoder only sees the responses (``stim=None``); the ground-truth stimuli
    are used only to score the reconstruction.

    Args:
        model: decoder exposing ``eval()`` and ``run_batch(...)``.
        dataloader: iterable of ``(stim, resp)`` batches that supports ``len()``.
        loss_fn: callable ``(stim_pred, stim) -> scalar tensor``.
        config: run config; reads ``config["device"]`` and ``config["decoder"]["n_steps"]``.

    Returns:
        Mean per-batch loss (``float``).
    """
    model.eval()
    device = config["device"]
    batch_losses = []
    for stim, resp in dataloader:
        stim, resp = stim.to(device), resp.to(device)
        stim_pred = model.run_batch(
            train=False,
            stim=None,
            resp=resp,
            n_steps=config["decoder"]["n_steps"],
            x_hat_history_iters=None,
        )[0]
        batch_losses.append(loss_fn(stim_pred, stim).item())
    return sum(batch_losses) / len(dataloader)

In [None]:
def get_dataloaders(config, dataloaders, use_data_names, only_v1_data_eval=True):
    """Build the (train, val) loaders for the datasets named in ``use_data_names``.

    A single dataset is used directly; several datasets are combined with
    ``MixedBatchLoader`` using ``config["data"]["mixing_strategy"]``.
    With ``only_v1_data_eval`` the validation loader is always the V1 val split.

    Raises:
        ValueError: if ``use_data_names`` is empty.
    """
    val_dataloader = dataloaders["v1_data"]["val"] if only_v1_data_eval else None

    selected = [dataloaders[name] for name in use_data_names]
    if not selected:
        raise ValueError("No data to train on.")

    if len(selected) == 1:
        (only,) = selected
        train_dataloader = only["train"]
        if not only_v1_data_eval:
            val_dataloader = only["val"]
        return train_dataloader, val_dataloader

    def _mix(split):
        # one loader drawing batches from the given split of every selected dataset
        return MixedBatchLoader(
            dataloaders=[dl[split] for dl in selected],
            mixing_strategy=config["data"]["mixing_strategy"],
            device=config["device"],
        )

    train_dataloader = _mix("train")
    if not only_v1_data_eval:
        val_dataloader = _mix("val")
    return train_dataloader, val_dataloader

In [None]:
### load encoder
# Builds the encoder architecture and loads pretrained weights into it.
# NOTE(review): these imports would normally live in the top import cell.
from data_orig import prepare_spiking_data_loaders
from lurz2020.models.models import se2d_fullgaussian2d

print("Loading encoder...")

### config only for the encoder
# (same paths/image_size/crop as `config["data"]["v1_data"]`, different batch_size)
spiking_data_loaders_config = {
    "train_path": os.path.join(DATA_PATH, "datasets", "train"),
    "val_path": os.path.join(DATA_PATH, "datasets", "val"),
    "test_path": os.path.join(DATA_PATH, "orig", "raw", "test.pickle"),
    "image_size": [50, 50],
    "crop": False,
    "batch_size": 32,
}
# architecture hyperparameters; must match the checkpoint below (loaded with strict=True)
encoder_config = {
    "init_mu_range": 0.55,
    "init_sigma": 0.4,
    "input_kern": 19,
    "hidden_kern": 17,
    "hidden_channels": 32,
    "gamma_input": 1.0,
    "gamma_readout": 2.439,
    "grid_mean_predictor": None,
    "layers": 5
}

### encoder
# the builder receives dataloaders (presumably to infer input/output sizes — TODO confirm);
# they are deleted right after construction
_dataloaders = prepare_spiking_data_loaders(**spiking_data_loaders_config)
encoder = se2d_fullgaussian2d(
    **encoder_config,
    dataloaders=_dataloaders,
    seed=2,  # NOTE(review): may reseed global RNGs, overriding the notebook seed — confirm
).float()
del _dataloaders

### load pretrained core
# despite the variable name, this is a state dict for the whole encoder (strict=True)
pretrained_core = torch.load(
    os.path.join(DATA_PATH, "models", "spiking_scratch_tunecore_68Y_model.pth"),
    map_location=config["device"],
)
encoder.load_state_dict(pretrained_core, strict=True)
encoder.to(config["device"])
encoder.eval()

In [None]:
# SSIM loss for inputs that are already normalized (flags differ from the training loss).
ssim_loss = SSIMLoss(
    inp_standardized=False,
    inp_normalized=True,
    log_loss=True,
    window=config["stim_crop_win"],
)


def val_loss_fn(stim_pred, stim):
    """Validation metric: `ssim_loss` on `normalize`d prediction and target."""
    return ssim_loss(normalize(stim_pred), normalize(stim))

In [None]:
# Decoder hyperparameters; everything under "model" is passed as kwargs to `L2O_Decoder`.
config["decoder"] = {
    "model": {
        # pretrained encoder from the cell above (presumably used by the decoder to compare
        # predicted vs. recorded responses via `resp_loss_fn` — TODO confirm)
        "encoder": encoder.float(),
        "resp_shape": (10000,),
        "stim_shape": (1, 50, 50),
        "in_shape": (4, 50, 50),  # NOTE(review): meaning of the 4 channels is defined in L2O_Decoder
        # layers mapping responses to an image (presumably the initial reconstruction,
        # see "reconstruction_init_method")
        "resp_layers_cfg": {
            "layers": [
                ("fc", 300),
                ("unflatten", 1, (3, 10, 10)),  # 300 = 3 * 10 * 10
                # assuming (out_channels, kernel, stride, padding) transposed convs,
                # the spatial size goes 10 -> 23 -> 47 -> 50, matching stim_shape
                ("deconv", 128, 7, 2, 1),
                ("deconv", 64, 5, 2, 1),
                ("deconv", 1, 4, 1, 0),
            ],
            "act_fn": nn.ReLU(),
            "out_act_fn": nn.Sigmoid(),
            "dropout": 0.2,
            "batch_norm": True,
        },
        "reconstruction_init_method": "resp_layers",
        "act_fn": nn.ReLU(),
        # "stim_loss_fn": MSELossWithCrop(config["stim_crop_win"]),
        # training loss (read by `train`); validation in the training loop uses `val_loss_fn`
        "stim_loss_fn": SSIMLoss(
            window=config["stim_crop_win"],
            log_loss=True,
            inp_normalized=False,
            inp_standardized=True,
        ),
        # "stim_loss_fn": lambda x_hat, x: 0.9 * ssim_loss(x_hat, x) + 0.1 * F.mse_loss(x_hat, x),
        "resp_loss_fn": nn.MSELoss(),
        "opter_cls": torch.optim.Adam,
        "opter_kwargs": {
            "lr": 0.0005,
        },
        "unroll": 1,
        "preproc_grad": True,
        "device": config["device"],
    },
    "n_epochs": 120,
    "n_steps": 1,  # passed as `n_steps` to `run_batch`
    "save_run": True,  # create a run dir with config.json, sample figures and checkpoints
}

decoder = L2O_Decoder(**config["decoder"]["model"]).to(config["device"])

In [None]:
### prepare checkpointing
if not config["decoder"]["save_run"]:
    def make_sample_path(epoch, prefix):
        """No run directory: sample figures are not saved."""
        return None

    print("WARNING: Not saving the run and the config.")
else:
    ### create a timestamped run directory and dump the config into it
    run_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    config["dir"] = os.path.join(DATA_PATH, "models", run_name)
    os.makedirs(config["dir"], exist_ok=True)
    with open(os.path.join(config["dir"], "config.json"), "w") as f:
        # values that are not JSON-serializable (modules, tensors, slices, ...) are written via str()
        json.dump(config, f, indent=4, default=str)
    for subdir in ("samples", "ckpt"):
        os.makedirs(os.path.join(config["dir"], subdir), exist_ok=True)

    def make_sample_path(epoch, prefix):
        """Path of the sample-reconstruction figure for `epoch` inside the run directory."""
        return os.path.join(config["dir"], "samples", f"{prefix}stim_comparison_{epoch}e.png")

    print(f"Run name: {run_name}\nRun dir: {config['dir']}")

In [None]:
### load ckpt (resume / inspect a previous run)
# Set `resume_run_name = None` to skip loading and keep the freshly built decoder.
resume_run_name = "2023-09-24_18-49-50"
resume_ckpt_file = "decoder_40.pt"

if resume_run_name is not None:
    # The checkpoint pickles the whole config (encoder, loss objects, optimizer class), so it
    # cannot be loaded with `weights_only=True`; only load trusted local files.
    ckpt = torch.load(
        os.path.join(DATA_PATH, "models", resume_run_name, "ckpt", resume_ckpt_file),
        map_location=config["device"],
        weights_only=False,
    )

    # NOTE(review): the training cell re-initialises `history` and `best`
    history = ckpt["history"]
    best = ckpt["best"]

    # keep running on the current device even if the checkpoint was saved on another one
    device = config["device"]
    config = ckpt["config"]
    config["device"] = device
    config["decoder"]["model"]["device"] = device

    decoder = L2O_Decoder(**config["decoder"]["model"]).to(config["device"])
    decoder.load_state_dict(ckpt["decoder"])

    # sample figures now go into the loaded run's directory
    make_sample_path = lambda epoch, prefix: os.path.join(
        config["dir"], "samples", f"{prefix}stim_comparison_{epoch}e.png"
    )

In [None]:
### sanity check: shape of one decoded batch, model size, and the model repr
decoded_shape = decoder.run_batch(
    stim=stim.to(config["device"]),
    resp=resp.to(config["device"]),
    n_steps=1,
    train=False,
    x_hat_history_iters=None,
)[0].shape
print(decoded_shape)
print(f"Number of parameters: {count_parameters(decoder)}")

decoder

In [None]:
def plot_losses(history, save_to=None, epoch=None):
    """Plot train/val loss curves and optionally save the figure.

    Args:
        history: dict with "train_loss" and "val_loss" lists (one entry per epoch).
        save_to: optional path; if given, the figure is also saved there.
        epoch: epoch index used in the auto-saved ``losses_{epoch}.png`` name.
            Defaults to the index of the last entry of ``history["train_loss"]``.
            This equals the current epoch in the training loop, which appends one
            entry per epoch before plotting; previously the function silently read
            a global ``epoch`` and failed outside the loop.

    When the notebook-global ``config["decoder"]["save_run"]`` is set, the figure is
    additionally saved as ``losses_{epoch}.png`` in ``config["dir"]``.
    """
    if epoch is None:
        epoch = len(history["train_loss"]) - 1

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(history["train_loss"], label="train")
    ax.plot(history["val_loss"], label="val")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title("Decoder losses")
    ax.legend()

    if save_to:
        fig.savefig(save_to)
    ### save fig into the run directory
    if config["decoder"]["save_run"]:
        fig.savefig(os.path.join(config["dir"], f"losses_{epoch}.png"))

    plt.show()

In [None]:
### all prepared datasets are used for training (mixed by `get_dataloaders` if more than one)
data_names = [*dataloaders]
print(f"{data_names=}")

In [None]:
### train
# NOTE(review): `history` / `best` are (re)initialised here, so `s` below is always 0 and any
# history/best loaded by the checkpoint cell is discarded (the decoder weights are kept).
history = {"train_loss": [], "val_loss": []}
best = {"val_loss": np.inf, "epoch": 0, "model": None}
s, e = len(history["train_loss"]), len(history["train_loss"]) + config["decoder"]["n_epochs"]
for epoch in range(s, e):
    print(f"[{epoch + 1}/{e}]")

    ### train and val
    # loaders are rebuilt every epoch (a new MixedBatchLoader when several datasets are mixed)
    train_dataloader, val_dataloader = get_dataloaders(
        config=config,
        dataloaders=dataloaders,
        use_data_names=data_names,
        only_v1_data_eval=config["only_v1_data_eval"],
    )
    train_loss = train(
        model=decoder,
        dataloader=train_dataloader,
        config=config,
    )
    val_loss = val(
        model=decoder,
        dataloader=val_dataloader,
        # model selection uses `val_loss_fn` (SSIM on `normalize`d images), not the training loss
        loss_fn=val_loss_fn,
        config=config,
    )

    ### save best model (copy of the weights with the lowest validation loss so far)
    if val_loss < best["val_loss"]:
        best["val_loss"] = val_loss
        best["epoch"] = epoch
        best["model"] = deepcopy(decoder.state_dict())

    ### log
    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)
    print(f"{train_loss=:.4f}, {val_loss=:.4f}")

    ### plot sample reconstructions of the fixed V1 validation batch `stim` / `resp`
    stim_pred = decoder.run_batch(
        stim=stim.to(config["device"]),
        resp=resp.to(config["device"]),
        n_steps=config["decoder"]["n_steps"],
        train=False,
        x_hat_history_iters=None,
    )[0].detach()
    if "v1_data" in config["data"] and not config["data"]["v1_data"]["crop"]:
        # loaders presumably return uncropped stimuli -> crop target and prediction before plotting
        plot_comparison(target=crop_stim(stim[:8]).cpu(), pred=crop_stim(stim_pred[:8]).cpu(), save_to=make_sample_path(epoch, ""))
    else:
        plot_comparison(target=stim[:8].cpu(), pred=stim_pred[:8].cpu(), save_to=make_sample_path(epoch, "no_crop_"))

    ### plot losses (every 5 epochs)
    if epoch % 5 == 0 and epoch > 0:
        plot_losses(history=history)

        ### ckpt
        if config["decoder"]["save_run"]:
            torch.save({
                "decoder": decoder.state_dict(),
                "opter": decoder.opter.state_dict(),
                "history": history,
                "config": config,
                "best": best,
            }, os.path.join(config["dir"], "ckpt", f"decoder_{epoch}.pt"))

In [None]:
### score the last-epoch decoder on the last validation loader with the *training* loss
final_val_loss = val(
    model=decoder,
    dataloader=val_dataloader,
    loss_fn=config["decoder"]["model"]["stim_loss_fn"],
    config=config,
)
final_val_loss

In [None]:
print(f"Best val loss: {best['val_loss']:.4f} at epoch {best['epoch']}")

save_run = config["decoder"]["save_run"]  # `config["dir"]` is only set when the run is saved

### save final ckpt (weights of the last epoch, not of the best one)
if save_run:
    torch.save({
        "decoder": decoder.state_dict(),
        "opter": decoder.opter.state_dict(),
        "history": history,
        "config": config,
        "best": best,
    }, os.path.join(config["dir"], "decoder.pt"))

### plot reconstructions of the best model (lowest validation loss)
if best["model"] is not None:
    decoder.load_state_dict(best["model"])
else:
    print("WARNING: No best model recorded; plotting reconstructions of the current weights.")
stim_pred_best = decoder.run_batch(
    stim=stim.to(config["device"]),
    resp=resp.to(config["device"]),
    n_steps=config["decoder"]["n_steps"],
    train=False,
    x_hat_history_iters=None,
)[0].detach()
plot_comparison(
    target=crop_stim(stim[:8]).cpu(),
    pred=crop_stim(stim_pred_best[:8]).cpu(),
    save_to=os.path.join(config["dir"], "stim_comparison_best.png") if save_run else None,
)

### plot losses
plot_losses(
    history=history,
    save_to=os.path.join(config["dir"], "losses_final.png") if save_run else None,
)