In [None]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import json
import pandas as pd
from datetime import datetime
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import transforms
import lovely_tensors as lt

import csng
from csng.CNN_Decoder import CNN_Decoder
from csng.GAN import GAN
from L2O_Decoder import L2O_Decoder
from csng.utils import plot_comparison, standardize, normalize, get_mean_and_std, count_parameters
from csng.losses import SSIMLoss, MSELossWithCrop

# from orig_data import prepare_spiking_data_loaders
from data import prepare_v1_dataloaders, SyntheticDataset, BatchPatchesDataLoader, MixedBatchLoader, PerSampleStoredDataset

# Directory holding the cat V1 spiking-model data; its parent comes from the
# DATA_PATH environment variable (a KeyError is raised when it is not set).
DATA_PATH = os.path.join(
    os.environ["DATA_PATH"],
    "cat_V1_spiking_model",
)

# Apply the lovely_tensors monkey patch (see the lovely_tensors docs for what it changes).
lt.monkey_patch()

print(f"{DATA_PATH=}")

In [None]:
# Notebook-wide settings. Later cells add more entries under config["data"].
config = dict(
    data=dict(
        mixing_strategy="parallel_min",  # needed only with multiple base dataloaders
    ),
    # (rows, cols) window used to crop stimuli before comparing them
    stim_crop_win=(slice(15, 35), slice(15, 35)),
    device="cuda" if torch.cuda.is_available() else "cpu",
    seed=0,
)

print(f"... Running on {config['device']} ...")

In [None]:
### reproducibility: seed the PyTorch, NumPy and Python RNGs with the configured seed
# (random.seed stays last: it returns None, so the cell displays no output)
torch.manual_seed(config["seed"])
np.random.seed(config["seed"])
random.seed(config["seed"])

In [None]:
def _make_crop_stim(crop_win):
    """Build the stimulus-cropping function for a fixed crop window.

    The window is bound here, once, instead of being looked up in the global
    ``config`` on every call. The model-loading cell further down rebinds
    ``config`` to checkpoint configs, so a late lookup could silently crop
    with a different (or missing) window than the one chosen in this cell.

    Args:
        crop_win: ``None`` for no cropping, or an indexable pair whose items
            0 and 1 index the last two dimensions of the input.

    Returns:
        A one-argument function: identity when ``crop_win`` is ``None``,
        otherwise ``x -> x[..., crop_win[0], crop_win[1]]``.
    """
    if crop_win is None:
        return lambda x: x

    rows, cols = crop_win[0], crop_win[1]

    def _crop(x):
        # Crop the last two dimensions; leading dimensions are left untouched.
        return x[..., rows, cols]

    return _crop


crop_stim = _make_crop_stim(config["stim_crop_win"])

## Data

In [None]:
### settings for the V1 data loaders (unpacked as keyword arguments into prepare_v1_dataloaders)
config["data"]["v1_data"] = dict(
    train_path=os.path.join(DATA_PATH, "datasets", "train"),
    val_path=os.path.join(DATA_PATH, "datasets", "val"),
    test_path=os.path.join(DATA_PATH, "orig", "raw", "test.pickle"),
    image_size=[50, 50],
    crop=False,  # alternative tried: True
    batch_size=20,  # alternative tried: 64
    stim_normalize_mean=46.236,
    stim_normalize_std=21.196,
    # response normalization statistics, loaded from .npy files and cast to float32
    resp_normalize_mean=torch.from_numpy(
        np.load(os.path.join(DATA_PATH, "responses_mean_from_training_dataset.npy"))
    ).float(),
    resp_normalize_std=torch.from_numpy(
        np.load(os.path.join(DATA_PATH, "responses_std_from_training_dataset.npy"))
    ).float(),
)

In [None]:
### build the V1 data loaders from the settings stored in config["data"]["v1_data"]
v1_dataloaders = prepare_v1_dataloaders(
    **config["data"]["v1_data"],
)

In [None]:
### show data: print summary statistics of one validation batch and plot its first sample
stim, resp = next(iter(v1_dataloaders["val"]))
print(
    f"{stim.shape=}, {resp.shape=}"
    f"\n{stim.min()=}, {stim.max()=}"
    f"\n{resp.min()=}, {resp.max()=}"
    f"\n{stim.mean()=}, {stim.std()=}"
    f"\n{resp.mean()=}, {resp.std()=}"
)

# three panels side by side: full stimulus, cropped stimulus, response viewed as 100x100
fig = plt.figure(figsize=(14, 6))
for position, image in (
    (131, stim[0]),
    (132, crop_stim(stim[0])),
    (133, resp[0].view(100, 100)),
):
    ax = fig.add_subplot(position)
    # drop singleton dims, then append a trailing channel axis for imshow
    ax.imshow(image.squeeze().unsqueeze(-1), cmap="gray")

plt.show()

### Synthetic data (generated using the Encoder)

In [None]:
# Response normalization statistics of the synthetic dataset (float32 tensors).
resp_mean = torch.from_numpy(
    np.load(os.path.join(DATA_PATH, "responses_mean_from_syn_dataset.npy"))
).float()
resp_std = torch.from_numpy(
    np.load(os.path.join(DATA_PATH, "responses_std_from_syn_dataset.npy"))
).float()

# "dataset" holds keyword arguments for PerSampleStoredDataset,
# "dataloader" holds keyword arguments for torch's DataLoader.
config["data"]["syn_data"] = dict(
    dataset=dict(
        stim_transform=transforms.Normalize(mean=114.457, std=51.356),
        resp_transform=csng.utils.Normalize(mean=resp_mean, std=resp_std),
    ),
    dataloader=dict(batch_size=20, shuffle=True),
)

In [None]:
# One stored dataset per split; all splits share the same transforms.
syn_datasets = {
    split: PerSampleStoredDataset(
        dataset_dir=os.path.join(DATA_PATH, "synthetic_data", "processed", split),
        **config["data"]["syn_data"]["dataset"],
    )
    for split in ("train", "val", "test")
}

# One DataLoader per split, all built with the same dataloader settings
# (note that these settings shuffle every split, not only "train").
syn_dataloaders = {
    split: DataLoader(
        dataset=split_dataset,
        **config["data"]["syn_data"]["dataloader"],
    )
    for split, split_dataset in syn_datasets.items()
}

In [None]:
### show data: print summary statistics of one synthetic validation batch and plot its first sample
syn_stim, syn_resp = next(iter(syn_dataloaders["val"]))
print(
    f"{syn_stim.shape=}, {syn_resp.shape=}"
    f"\n{syn_stim.min()=}, {syn_stim.max()=}"
    f"\n{syn_resp.min()=}, {syn_resp.max()=}"
    f"\n{syn_stim.mean()=}, {syn_stim.std()=}"
    f"\n{syn_resp.mean()=}, {syn_resp.std()=}"
)

# three panels side by side: full stimulus, cropped stimulus, response viewed as 100x100
fig = plt.figure(figsize=(10, 6))
for position, image in (
    (131, syn_stim.cpu()[0]),
    (132, crop_stim(syn_stim.cpu()[0])),
    (133, syn_resp.cpu()[0].view(100, 100)),
):
    ax = fig.add_subplot(position)
    # drop singleton dims, then append a trailing channel axis for imshow
    ax.imshow(image.squeeze().unsqueeze(-1), cmap="gray")

plt.show()

## Compare decoders

In [None]:
def eval(model, dataloader, loss_fn, normalize_decoded, config, device="cpu"):
    """Return the mean per-batch loss of ``model`` over ``dataloader``.

    NOTE: this name shadows the builtin ``eval``; it is kept because other
    cells call the function by this name.

    Args:
        model: decoder to evaluate. A model whose class is named
            ``"L2O_Decoder"`` is run through ``model.run_batch(...)``; any
            other model is called directly as ``model(resp)``.
        dataloader: sized iterable yielding ``(stim, resp)`` batches.
        loss_fn: callable ``loss_fn(stim_pred, stim)`` returning a scalar
            tensor (``.item()`` is called on the result).
        normalize_decoded: if true, the prediction is passed through
            ``normalize`` before the loss is computed.
        config: only read on the L2O path (``config["decoder"]["n_steps"]``).
        device: device the batches are moved to.

    Returns:
        Sum of the per-batch losses divided by ``len(dataloader)``.

    Raises:
        ZeroDivisionError: if ``dataloader`` is empty.
    """
    model.eval()
    is_l2o = model.__class__.__name__ == "L2O_Decoder"
    total_loss = 0.0

    for stim, resp in dataloader:
        stim = stim.to(device)
        resp = resp.to(device)

        if resp.ndim == 3:
            # 3-D responses (trial dimension at index 1, as in the test V1 dataset):
            # take only the first trial (alternative: resp.mean(dim=1))
            resp = resp[:, 0, :]

        if is_l2o:
            # Autograd stays enabled on this path: the original code had
            # `torch.no_grad()` commented out, presumably because run_batch
            # needs gradients internally -- TODO confirm.
            stim_pred, _ = model.run_batch(
                train=False,
                stim=None,
                resp=resp,
                n_steps=config["decoder"]["n_steps"],
                x_hat_history_iters=None,
            )
        else:
            # Plain forward pass: no graph is needed for evaluation.
            with torch.no_grad():
                stim_pred = model(resp)

        # Only the scalar loss value is kept, so the loss needs no graph either.
        with torch.no_grad():
            if normalize_decoded:
                stim_pred = normalize(stim_pred)
            total_loss += loss_fn(stim_pred, stim).item()

    return total_loss / len(dataloader)

In [None]:
# Runs to compare, given as (plot label, run name) pairs. Each label maps to its
# own dict, which the model-loading cell later fills with losses, history and config.
# Labels containing "l2o" / "gan" (case-insensitive) are loaded as L2O / GAN
# decoders there; every other label is loaded as a CNN decoder.
# Commented-out pairs are alternative runs; uncomment to include them.
runs_to_compare = {
    label: {"run_name": run_name}
    for label, run_name in (
        # ("0% - CNN-S", "..."),
        ("0%", "2023-08-06_20-13-46"),  # alternative label: "0% - CNN"
        # ("0% - CNN-L", "2023-08-10_00-02-52"),
        # ("25% - CNN", "2023-08-07_18-28-52"),
        # ("25% - CNN-L", "2023-08-14_23-24-26"),
        # ("50% - CNN-S", "2023-08-17_23-03-04"),
        # ("50% - CNN", "2023-08-07_08-50-53"),
        # ("50% - CNN-L", "2023-08-09_00-03-18"),
        # ("50% - CNN-L", "2023-10-02_10-09-20"),  # G from GAN
        # ("75% - CNN", "2023-08-07_08-54-37"),
        # ("75% - CNN-L", "2023-08-09_23-42-36"),
        # ("100% - CNN", "2023-08-09_00-08-49"),

        # ("0% - GAN", "2023-08-26_16-34-36"),
        # ("50% - GAN", "2023-08-30_09-07-13"),

        # ("0% - L2O", "2023-08-21_23-07-49"),
        # ("25% - L2O", "2023-09-05_19-45-22"),
        # ("50% - L2O", "2023-09-02_15-09-52"),

        ("25%", "2023-08-25_09-09-51"),
        ("50%", "2023-08-25_09-07-46"),
        ("75%", "2023-09-14_20-14-24"),

        # ("test", "2023-10-02_10-09-20"),
    )
}

# Loss functions used to score the decoders, keyed by display name.
# The two SSIM variants differ only in the `log_loss` flag.
loss_fns = {
    loss_name: SSIMLoss(
        window=config["stim_crop_win"],
        log_loss=use_log,
        inp_normalized=True,
    )
    for loss_name, use_log in (("Log SSIM Loss", True), ("SSIM Loss", False))
}
# MSE between the cropped, standardized prediction and the cropped, standardized target
loss_fns["MSE Loss"] = lambda x_hat, x: F.mse_loss(
    standardize(crop_stim(x_hat)),
    standardize(crop_stim(x)),
)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
### config for collecting results
plot_losses = False
plot_reconstructions = False
rerun_l2o_val_loss_during_training = True
rerun_all_val_losses = False

### load models
# For every run: load its checkpoint, optionally recompute the validation-loss
# history from the per-epoch checkpoints, restore the best weights, evaluate on
# the V1 test set with every loss in `loss_fns`, and store the results back
# into `runs_to_compare[k]`.
for k in runs_to_compare.keys():
    run_name = runs_to_compare[k]["run_name"]
    print(f"Loading {k} model (run name: {run_name})...")

    # The decoder type is encoded in the label; "l2o" takes precedence over "gan".
    is_l2o = "l2o" in k.lower()
    is_gan = (not is_l2o) and "gan" in k.lower()

    ### load ckpt
    # GAN runs are stored under models/gan/<run_name>, all others under models/<run_name>.
    if is_gan:
        run_dir = os.path.join(DATA_PATH, "models", "gan", run_name)
        decoder_cls = GAN
    else:
        run_dir = os.path.join(DATA_PATH, "models", run_name)
        decoder_cls = L2O_Decoder if is_l2o else CNN_Decoder

    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    ckpt = torch.load(os.path.join(run_dir, "decoder.pt"), map_location=device)

    # Keep the checkpoint's config in its own name: rebinding the notebook-wide
    # `config` here would change what earlier cells (e.g. `crop_stim`) see and
    # would make this cell depend on how often it has been run.
    run_config = ckpt["config"]
    decoder = decoder_cls(**run_config["decoder"]["model"]).to(device)

    history = ckpt["history"]
    # Use the notebook's crop window, i.e. the one `loss_fns` was built with.
    run_config["stim_crop_win"] = config["stim_crop_win"]
    best = ckpt["best"]

    ### rerun L2O val loss with the same loss function as other decoders use
    if rerun_all_val_losses or (rerun_l2o_val_loss_during_training and is_l2o):
        print("  Rerunning val loss...")
        history["val_loss"] = []
        # Per-epoch checkpoints are looked up next to the run's decoder.pt
        # (previously the "gan" subfolder was missing here, so GAN reruns
        # would find no files and record NaN for every epoch).
        ckpt_dir = os.path.join(run_dir, "ckpt")

        for epoch in range(run_config["decoder"]["n_epochs"]):
            ckpt_filepath = os.path.join(ckpt_dir, f"decoder_{epoch}.pt")
            if not os.path.exists(ckpt_filepath):
                # no checkpoint saved for this epoch -> leave a gap in the curve
                history["val_loss"].append(np.nan)
                continue

            epoch_ckpt = torch.load(ckpt_filepath, map_location=device)
            decoder.load_state_dict(epoch_ckpt["decoder"])
            decoder.eval()
            val_loss = eval(
                model=decoder,
                dataloader=v1_dataloaders["val"],
                loss_fn=loss_fns["Log SSIM Loss"],
                normalize_decoded=True,
                config=run_config,
                device=device,
            )
            history["val_loss"].append(val_loss)
            print(f"{epoch} ", end="")
        print()

    ### load best model
    # `best` either holds the state dict under "model"/"decoder" or is the state dict itself
    if "model" in best.keys():
        decoder.load_state_dict(best["model"])
    elif "decoder" in best.keys():
        decoder.load_state_dict(best["decoder"])
    else:
        decoder.load_state_dict(best)
    decoder.eval()

    ### plot losses
    if plot_losses:
        fig = plt.figure(figsize=(10, 6))
        ax = fig.add_subplot(111)
        ax.plot(history["train_loss"], label="train")
        ax.plot(history["val_loss"], label="val")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Loss")
        ax.set_ylim(0, None)
        ax.legend()
        plt.show()

    ### plot reconstructions of the final model
    # NOTE: `stim` / `resp` are the validation batch fetched in the "show data" cell above.
    if plot_reconstructions:
        if is_l2o:
            stim_pred, x_hat_history = decoder.run_batch(
                train=False,
                stim=None,
                resp=resp.to(device),
                n_steps=run_config["decoder"]["n_steps"],
                x_hat_history_iters=None,
            )
            stim_pred_best = normalize(stim_pred.detach().cpu())
        else:
            stim_pred_best = decoder(resp.to(device)).detach().cpu()
        plot_comparison(
            target=crop_stim(stim[:8]).cpu(),
            pred=crop_stim(stim_pred_best[:8]).cpu(),
            # save_to=os.path.join(run_config["dir"], "stim_comparison_best.png")
        )

    ### eval
    # Only L2O predictions are normalized before the loss is computed.
    test_losses = {
        loss_fn_name: eval(
            model=decoder,
            dataloader=v1_dataloaders["test"],
            loss_fn=loss_fn,
            normalize_decoded=is_l2o,
            config=run_config,
            device=device,
        )
        for loss_fn_name, loss_fn in loss_fns.items()
    }

    ### save
    runs_to_compare[k]["test_losses"] = test_losses
    for metric in history.keys():
        runs_to_compare[k][metric] = history[metric]
    runs_to_compare[k]["config"] = run_config
    runs_to_compare[k]["best_val_loss"] = best["val_loss"]

In [None]:
def autolabel(rects, ax=None):
    """Attach a text label above each bar in *rects*, displaying its height.

    Adapted from
    https://matplotlib.org/3.1.1/gallery/lines_bars_and_markers/barchart.html

    Args:
        rects: iterable of bar patches (e.g. the container returned by
            ``ax.bar``); each must provide ``get_height()``, ``get_x()`` and
            ``get_width()``.
        ax: axes to annotate. When omitted, each bar is annotated on the axes
            it belongs to (``rect.axes``), so the function no longer depends
            on a global variable named ``ax`` existing at call time.
    """
    for rect in rects:
        target_ax = rect.axes if ax is None else ax
        height = rect.get_height()
        target_ax.annotate(
            f"{height:.3f}",
            # anchor at the top centre of the bar
            xy=(rect.get_x() + rect.get_width() / 2, height),
            xytext=(0, 3),  # 3 points vertical offset
            textcoords="offset points",
            ha='center', va='bottom',
            fontsize=13,
            rotation=90,
        )

In [None]:
### plot losses together

### config
to_plot = "val_loss"  # history key to plot: "train_loss" or "val_loss"
conv_win = 10  # moving-average window in epochs; None disables smoothing

### plot
fig = plt.figure(figsize=(12, 8))
ax = fig.add_subplot(111)

for k, run_dict in runs_to_compare.items():
    if k == "100%":
        continue

    vals_to_plot = np.asarray(run_dict[to_plot], dtype=float)
    # NaNs mark epochs without a value. They are detected with np.isnan rather
    # than `in` / `is np.nan`: those are identity tests and miss NaNs that come
    # from an unpickled history or from a NumPy array.
    if conv_win is not None and not np.isnan(vals_to_plot).any():
        # smooth with a moving average (only possible when there are no gaps)
        vals_to_plot = np.convolve(vals_to_plot, np.ones(conv_win) / conv_win, mode="valid")

    # plot only the available values, keeping their original positions on the x axis
    is_valid = ~np.isnan(vals_to_plot)
    ax.plot(
        np.flatnonzero(is_valid),
        vals_to_plot[is_valid],
        label=k,
        linewidth=3,
    )

if to_plot == "train_loss":
    ax.set_title("Training log SSIM loss (V1 data + x % of synthetic data)", fontsize=16)
elif to_plot == "val_loss":
    ax.set_title("Validation log SSIM loss on the V1 data", fontsize=16)
else:
    raise ValueError(f"Unknown loss type: {to_plot}")

ax.set_xlabel("Epoch", fontsize=15, labelpad=20)
ax.set_ylabel("Log SSIM loss", fontsize=15, labelpad=20)
ax.set_ylim(0.26, None)
# ax.set_xlim(0, 150)
ax.legend(
    # loc="upper right",
    # loc="upper center",
    loc="lower left",
    fontsize=13,
    frameon=False,
    # bbox_to_anchor=(1.0, 1.0),
    bbox_transform=ax.transAxes,
    title="% of syn. data in training",
    title_fontsize=15,
    ncol=4,
)
# increase width of legend lines
leg = ax.get_legend()
# `legendHandles` was renamed to `legend_handles` in Matplotlib 3.7; support both
legend_handles = getattr(leg, "legend_handles", None)
if legend_handles is None:
    legend_handles = leg.legendHandles
for legobj in legend_handles:
    legobj.set_linewidth(4.0)


# set larger font for x and y ticks
ax.tick_params(axis="both", which="major", labelsize=14)

# remove top and right spines
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)

plt.show()

In [None]:
### bar plot of test losses
fig = plt.figure(figsize=(18, 8))
ax = fig.add_subplot(111)

### grouped bar plot (run_dict["test_losses"] is a dict containing multiple losses)
bar_width = 0.9  # total width of one group of bars (one group per run)
losses_to_plot = [
    "SSIM Loss",
    "Log SSIM Loss",
    "MSE Loss",
]
colors = [
    "#1f77b4",
    "#ff7f0e",
    "#2ca02c",
    "#d62728",
]
n_losses = len(losses_to_plot)
sub_width = bar_width / n_losses  # width of a single bar within a group
for i, (k, run_dict) in enumerate(runs_to_compare.items()):
    for j, loss in enumerate(losses_to_plot):
        rects = ax.bar(
            # Centre the group on tick i for any number of losses
            # (the previous offset was only centred for exactly three).
            i + (j - (n_losses - 1) / 2) * sub_width,
            run_dict["test_losses"][loss],
            width=sub_width,
            # cycle through the palette instead of failing when there are more losses than colors
            color=colors[j % len(colors)],
        )
        autolabel(rects)

### add legend with color explanation
from matplotlib import patches as mpatches
ax.legend(
    handles=[
        mpatches.Patch(color=colors[i % len(colors)], label=loss)
        for i, loss in enumerate(losses_to_plot)
    ],
    loc="upper center",
    bbox_to_anchor=(0.5, 1.14),
    ncol=n_losses,
    fontsize=14,
    frameon=False,
)

ax.set_title(
    "Test Losses (test dataset only with V1 data)",
    fontsize=18,
    pad=70,
)
# one tick per run, labelled with the run's key
ax.set_xticks(range(len(runs_to_compare)))
ax.set_xticklabels(runs_to_compare.keys())
ax.tick_params(axis="both", which="major", labelsize=14)
ax.set_xlabel("Decoder", fontsize=14, labelpad=20)
ax.set_ylabel("Loss", fontsize=14, labelpad=20)
ax.set_ylim(0, None)

# remove top and right spines
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)

plt.show()