In [None]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import json
import pandas as pd
from datetime import datetime
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import transforms
import lovely_tensors as lt

import csng
from csng.CNN_Decoder import CNN_Decoder
from csng.utils import plot_comparison, standardize, normalize, get_mean_and_std, count_parameters
from csng.losses import SSIMLoss, MSELossWithCrop

from data import (
    prepare_v1_dataloaders,
    SyntheticDataset,
    BatchPatchesDataLoader,
    MixedBatchLoader,
    PerSampleStoredDataset,
)

# Enable lovely_tensors' patching (presumably a more compact tensor repr when printing — see lovely_tensors docs).
lt.monkey_patch()

# Root directory of all data used below. The DATA_PATH environment variable must be set;
# a missing variable raises KeyError here.
DATA_PATH = os.path.join(os.environ["DATA_PATH"], "cat_V1_spiking_model")
print(f"{DATA_PATH=}")

In [None]:
# Global run configuration. Later cells add config["data"]["v1_data"],
# config["data"]["syn_data"] and config["decoder"] to this dict.
config = {
    "data": {
        "mixing_strategy": "parallel_min", # needed only with multiple base dataloaders
    },
    # (row, column) slices selecting a 20x20 window; used by `crop_stim` and passed to the loss
    "stim_crop_win": (slice(15, 35), slice(15, 35)),
    # when both V1 and synthetic data are configured, validate on V1 data only
    "only_v1_data_eval": True,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    # seed applied to numpy, torch and random in the next cell
    "seed": 0,
}

print(f"... Running on {config['device']} ...")

In [None]:
### seed every random number generator used in this notebook (NumPy, PyTorch, Python)
seed = config["seed"]
for seed_fn in (np.random.seed, torch.manual_seed, random.seed):
    seed_fn(seed)

In [None]:
### `crop_stim` crops stimuli to the configured window, or is the identity when no window is set
if config["stim_crop_win"] is None:
    def crop_stim(x):
        """Return `x` unchanged (no crop window configured)."""
        return x
else:
    def crop_stim(x):
        """Crop the last two dimensions of `x` with the slices in config["stim_crop_win"]."""
        # the window is looked up at call time, so later changes to `config` take effect
        win = config["stim_crop_win"]
        return x[..., win[0], win[1]]

## Data

In [None]:
# Keyword arguments for `prepare_v1_dataloaders` (expanded with ** in the "get data loaders" cell).
config["data"]["v1_data"] = {
    "train_path": os.path.join(DATA_PATH, "datasets", "train"),
    "val_path": os.path.join(DATA_PATH, "datasets", "val"),
    "test_path": os.path.join(DATA_PATH, "orig", "raw", "test.pickle"),
    "image_size": [50, 50],
    # the training loop additionally plots cropped reconstructions only when this is False
    "crop": False,
    # "crop": True,
    # "batch_size": 64,
    "batch_size": 20,
    # normalization constants for the stimuli
    # NOTE(review): presumably computed on the training set — confirm
    "stim_normalize_mean": 46.236,
    "stim_normalize_std": 21.196,
    # response statistics loaded from .npy files and converted to float32 tensors
    "resp_normalize_mean": torch.from_numpy(np.load(
        os.path.join(DATA_PATH, "responses_mean_from_training_dataset.npy")
    )).float(),
    "resp_normalize_std": torch.from_numpy(np.load(
        os.path.join(DATA_PATH, "responses_std_from_training_dataset.npy")
    )).float(),
}

In [None]:
### get data statistics
# v1_dataloaders = prepare_v1_dataloaders(**config["data"]["v1_data"])
# data_stats = get_mean_and_std(dataset=v1_dataloaders["train"].dataset, verbose=True)
# for k in data_stats:
#     for ks in data_stats[k]:
#         print(f"{k}.{ks}: {data_stats[k][ks]}")

In [None]:
### get data loaders
# The result is indexed by split name; this notebook uses its "train" and "val" entries.
v1_dataloaders = prepare_v1_dataloaders(**config["data"]["v1_data"])

In [None]:
### show data
### NOTE: `stim` and `resp` stay in the notebook namespace and are reused by later cells
stim, resp = next(iter(v1_dataloaders["val"]))
print(
    f"{stim.shape=}, {resp.shape=}"
    f"\n{stim.min()=}, {stim.max()=}"
    f"\n{resp.min()=}, {resp.max()=}"
    f"\n{stim.mean()=}, {stim.std()=}"
    f"\n{resp.mean()=}, {resp.std()=}"
)

### one panel each: full stimulus, cropped stimulus, responses viewed as a 100x100 image
panels = (
    (131, stim[0].squeeze()),
    (132, crop_stim(stim[0]).squeeze()),
    (133, resp[0].view(100, 100).squeeze(0)),
)
fig = plt.figure(figsize=(14, 6))
for subplot_pos, image in panels:
    ax = fig.add_subplot(subplot_pos)
    ax.imshow(image.unsqueeze(-1), cmap="gray")

plt.show()

### Synthetic data (generated using the Encoder)

In [None]:
# Per-component response statistics of the synthetic dataset, loaded as float32 tensors.
resp_mean = torch.from_numpy(np.load(os.path.join(DATA_PATH, "responses_mean_from_syn_dataset.npy"))).float()
resp_std = torch.from_numpy(np.load(os.path.join(DATA_PATH, "responses_std_from_syn_dataset.npy"))).float()

config["data"]["syn_data"] = {
    # keyword arguments passed to every PerSampleStoredDataset (train/val/test)
    "dataset": {
        # NOTE(review): these constants differ from the V1 stimulus statistics; presumably
        # they were computed on the synthetic stimuli — confirm
        "stim_transform": transforms.Normalize(
            mean=114.457,
            std=51.356,
        ),
        "resp_transform": csng.utils.Normalize(
            mean=resp_mean,
            std=resp_std,
        ),
    },
    # keyword arguments passed to every torch DataLoader (train/val/test);
    # note that "shuffle" therefore also applies to the val and test loaders
    "dataloader": {
        "batch_size": 20,
        "shuffle": True,
    }
}

In [None]:
### one dataset per split, read from <DATA_PATH>/synthetic_data/processed/<split>
syn_datasets = {
    split: PerSampleStoredDataset(
        dataset_dir=os.path.join(DATA_PATH, "synthetic_data", "processed", split),
        **config["data"]["syn_data"]["dataset"]
    )
    for split in ("train", "val", "test")
}

### one dataloader per split, all built with the same dataloader settings
syn_dataloaders = {
    split: DataLoader(
        dataset=split_dataset,
        **config["data"]["syn_data"]["dataloader"],
    )
    for split, split_dataset in syn_datasets.items()
}

In [None]:
### calculate statistics

### for stimuli
# syn_stats = get_mean_and_std(dataset=syn_dataset, verbose=True)
# syn_stats

### for responses
# from csng.utils import RunningStats

# stats = RunningStats(num_components=10000, lib="torch", device="cuda")
# for i, (s, r) in enumerate(syn_dataloader):
#     stats.update(r)
#     if i % 200 == 0:
#         print(f"{i}: {r.mean()=} {r.std()=} {stats.get_mean()=} {stats.get_std()=}")

### save
# torch.save(stats.get_mean(), os.path.join(DATA_PATH, "responses_mean_from_syn_dataset.pt"))
# torch.save(stats.get_std(), os.path.join(DATA_PATH, "responses_std_from_syn_dataset.pt"))


### generate preprocessed synthetic data
# import pickle
# target_dir = os.path.join(DATA_PATH, "synthetic_data", "processed")

# for data_split in ("train", "val", "test"):
#     print(data_split)
#     ### get the whole batch from dataloaders and save to disk
#     sample_idx = 0
#     for stim, resp in syn_dataloaders[data_split]:
#         if sample_idx % 2000 == 0:
#             print("  ", sample_idx)
        
#         for i in range(stim.shape[0]):
#             sample_idx += 1
#             save_to = os.path.join(target_dir, data_split, f"{sample_idx}.pickle")

#             data = {"stim": stim[i].cpu(), "resp": resp[i].cpu()}
#             with open(save_to, "wb") as f:
#                 pickle.dump(data, f)

In [None]:
### show data
syn_stim, syn_resp = next(iter(syn_dataloaders["val"]))
print(
    f"{syn_stim.shape=}, {syn_resp.shape=}"
    f"\n{syn_stim.min()=}, {syn_stim.max()=}"
    f"\n{syn_resp.min()=}, {syn_resp.max()=}"
    f"\n{syn_stim.mean()=}, {syn_stim.std()=}"
    f"\n{syn_resp.mean()=}, {syn_resp.std()=}"
)

### one panel each: full stimulus, cropped stimulus, responses viewed as a 100x100 image
syn_panels = (
    (131, syn_stim.cpu()[0].squeeze()),
    (132, crop_stim(syn_stim.cpu()[0]).squeeze()),
    (133, syn_resp.cpu()[0].view(100, 100).squeeze(0)),
)
fig = plt.figure(figsize=(10, 6))
for subplot_pos, image in syn_panels:
    ax = fig.add_subplot(subplot_pos)
    ax.imshow(image.unsqueeze(-1), cmap="gray")

plt.show()

## Decoder

In [None]:
def train(model, dataloader, opter, loss_fn, config, l1_reg_mul=0, l2_reg_mul=0, verbose=True):
    """Train `model` for one pass over `dataloader` and return the mean batch loss.

    Args:
        model: module mapping responses to predicted stimuli.
        dataloader: sized iterable yielding `(stim, resp)` batches.
        opter: optimizer over the parameters of `model`.
        loss_fn: callable `loss_fn(stim_pred, stim)` returning a scalar tensor.
        config: dict; only `config["device"]` is read here.
        l1_reg_mul: multiplier of the L1 penalty on weights (0 disables it).
        l2_reg_mul: multiplier of the squared-L2 penalty on weights (0 disables it).
        verbose: if True, print the progress every 100 batches.

    Returns:
        Sum of the per-batch losses (including the regularization terms) divided
        by `len(dataloader)`.
    """
    model.train()
    train_loss = 0
    n_batches = len(dataloader)

    ### parameters that get regularized: trainable ones whose name contains "weight"
    ### (the selection does not change between batches, so it is done once)
    reg_params = [
        p for n, p in model.named_parameters()
        if p.requires_grad and "weight" in n
    ]

    ### run
    for batch_idx, (stim, resp) in enumerate(dataloader):
        ### data
        stim = stim.to(config["device"])
        resp = resp.to(config["device"])

        ### train
        opter.zero_grad()
        stim_pred = model(resp)
        loss = loss_fn(stim_pred, stim)

        ### regularization
        if l1_reg_mul != 0:
            l1_reg = sum(p.abs().sum() for p in reg_params)
            loss = loss + l1_reg_mul * l1_reg
        if l2_reg_mul != 0:
            l2_reg = sum(p.pow(2.0).sum() for p in reg_params)
            loss = loss + l2_reg_mul * l2_reg

        loss.backward()
        opter.step()

        ### log
        train_loss += loss.item()
        if verbose and batch_idx % 100 == 0:
            print(f"Training progress: [{batch_idx}/{n_batches} ({100. * batch_idx / n_batches:.0f}%)]"
                  f"  Loss: {loss.item():.6f}")

    train_loss /= n_batches
    return train_loss

In [None]:
def val(model, dataloader, loss_fn, config):
    """Evaluate `model` on `dataloader` and return the mean batch loss.

    The model is switched to eval mode and no gradients are tracked. Each
    `(stim, resp)` batch is moved to `config["device"]`, and the sum of
    `loss_fn(model(resp), stim)` over batches is divided by `len(dataloader)`.
    """
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for stim, resp in dataloader:
            target = stim.to(config["device"])
            inputs = resp.to(config["device"])
            batch_losses.append(loss_fn(model(inputs), target).item())

    return sum(batch_losses) / len(dataloader)

In [None]:
def get_dataloaders(config, v1_dataloaders, syn_dataloaders, only_v1_data_eval=True):
    """Select the train and validation dataloaders according to `config["data"]`.

    - Only "v1_data" configured: the V1 train/val dataloaders.
    - Only "syn_data" configured: the synthetic train/val dataloaders.
    - Both configured: training mixes both sources with a `MixedBatchLoader`;
      validation uses V1 data only if `only_v1_data_eval`, otherwise it is mixed too.

    Returns:
        Tuple `(train_dataloader, val_dataloader)`.

    Raises:
        ValueError: if neither "v1_data" nor "syn_data" is configured.
    """
    data_cfg = config["data"]
    has_v1 = "v1_data" in data_cfg
    has_syn = "syn_data" in data_cfg

    if not (has_v1 or has_syn):
        raise ValueError("No data to train on.")
    if not has_syn:
        return v1_dataloaders["train"], v1_dataloaders["val"]
    if not has_v1:
        return syn_dataloaders["train"], syn_dataloaders["val"]

    def mix(split):
        ### combine the V1 and the synthetic dataloader of the given split
        return MixedBatchLoader(
            dataloaders=[v1_dataloaders[split], syn_dataloaders[split]],
            mixing_strategy=data_cfg["mixing_strategy"],
            device=config["device"],
        )

    train_dataloader = mix("train")
    val_dataloader = v1_dataloaders["val"] if only_v1_data_eval else mix("val")
    return train_dataloader, val_dataloader

In [None]:
# Decoder, optimizer, loss and training-loop settings.
config["decoder"] = {
    # keyword arguments passed to CNN_Decoder
    "model": {
        "resp_shape": (10000,),
        "stim_shape": (1, 50, 50),
        # layer specification interpreted by CNN_Decoder
        # NOTE(review): deconv tuples are presumably (type, out_channels, kernel_size, stride, padding)
        # — confirm against CNN_Decoder
        "layers": [
            ("fc", 384),  # CNN Baseline
            ("unflatten", 1, (6, 8, 8)),  # CNN Baseline
            # ("fc", 576),  # CNN Baseline large
            # ("unflatten", 1, (9, 8, 8)),  # CNN Baseline large
            ("deconv", 256, 7, 2, 0),
            ("deconv", 64, 5, 2, 0),
            ("deconv", 32, 4, 1, 0),
            # ("deconv", 64, 5, 1, 1),  # CNN Baseline large
            # ("deconv", 32, 4, 1, 1),  # CNN Baseline large
            ("deconv", 1, 3, 1, 0),
        ],
        "act_fn": nn.ReLU,
        "out_act_fn": nn.Identity,
        "dropout": 0.2,
        "batch_norm": True,
    },
    # optimizer class and its keyword arguments (instantiated right after this dict)
    "opter_cls": torch.optim.Adam,
    "opter_kwargs": {
        "lr": 0.003,
    },
    # loss: either a ready-made instance (as here) or a class that is instantiated without arguments
    # "loss_fn": nn.MSELoss(),
    # "loss_fn": MSELossWithCrop(window=config["stim_crop_win"]),
    "loss_fn": SSIMLoss(
        window=config["stim_crop_win"],
        log_loss=True,
        inp_normalized=True,
    ),
    # regularization multipliers forwarded to `train` (0 disables the respective term)
    "l1_reg_mul": 0,
    "l2_reg_mul": 1e-5,
    "n_epochs": 100,
    # whether to write the config, sample plots and checkpoints to disk
    "save_run": True,
}

### instantiate the decoder (on the configured device) and its optimizer
decoder = CNN_Decoder(**config["decoder"]["model"]).to(config["device"])
opter = config["decoder"]["opter_cls"](decoder.parameters(), **config["decoder"]["opter_kwargs"])

### the configured loss is either a class (instantiated here without arguments) or a
### ready-made instance; isinstance() also recognizes classes created by a custom
### metaclass, which the previous `type(...) == type` comparison missed
loss_fn = config["decoder"]["loss_fn"]
if isinstance(loss_fn, type):
    loss_fn = loss_fn()

In [None]:
### prepare checkpointing
if not config["decoder"]["save_run"]:
    def make_sample_path(epoch, prefix):
        """Saving is disabled, so there is no path for sample plots."""
        return None

    print("WARNING: Not saving the run and the config.")
else:
    ### every run gets its own directory named after the current timestamp
    run_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    config["dir"] = os.path.join(DATA_PATH, "models", run_name)
    os.makedirs(config["dir"], exist_ok=True)

    ### save config (values json cannot serialize are stored as their str())
    config_path = os.path.join(config["dir"], "config.json")
    with open(config_path, "w") as f:
        json.dump(config, f, indent=4, default=str)

    os.makedirs(os.path.join(config["dir"], "samples"), exist_ok=True)

    def make_sample_path(epoch, prefix):
        """Return the path of the sample plot for `epoch` inside the run's samples directory."""
        sample_name = f"{prefix}stim_comparison_{epoch}e.png"
        return os.path.join(config["dir"], "samples", sample_name)

    print(f"Run name: {run_name}\nRun dir: {config['dir']}")

In [None]:
### load ckpt
run_name = "2023-08-01_15-22-55"
### NOTE(review): torch.load unpickles arbitrary Python objects - only load checkpoints you trust.
### On recent PyTorch versions the default `weights_only=True` presumably rejects this checkpoint
### (it stores the whole config); pass `weights_only=False` there for trusted files.
### `map_location` allows restoring a checkpoint saved on a GPU on a CPU-only machine.
ckpt = torch.load(
    os.path.join(DATA_PATH, "models", run_name, "decoder.pt"),
    map_location=config["device"],
)

decoder.load_state_dict(ckpt["decoder"])
opter.load_state_dict(ckpt["opter"])
history = ckpt["history"]
best = ckpt["best"]

### adopt the stored config but keep the device of the current machine: `decoder` already
### lives on it, and the stored device may not be available here
current_device = config["device"]
config = ckpt["config"]
config["device"] = current_device

make_sample_path = lambda epoch, prefix: os.path.join(
    config["dir"], "samples", f"{prefix}stim_comparison_{epoch}e.png"
)

In [None]:
### sanity check: output shape for one batch of responses and the model size
### Run the forward pass in eval mode: in training mode dropout would be active and
### batch-norm running statistics would presumably be updated (even under no_grad),
### which would silently alter a freshly created or loaded model.
was_training = decoder.training
decoder.eval()
with torch.no_grad():
    print(decoder(resp.to(config["device"])).shape)
decoder.train(was_training)  # restore the previous mode
print(f"Number of parameters: {count_parameters(decoder)}")

decoder

In [None]:
def plot_losses(history, save_to=None):
    """Plot the training and validation loss curves.

    Args:
        history: dict with "train_loss" and "val_loss" sequences (one value per epoch).
        save_to: path to save the figure to; nothing is saved when it is None/empty.
    """
    fig = plt.figure(figsize=(12, 6))
    ax = fig.add_subplot(111)
    ax.plot(history["train_loss"], label="train")
    ax.plot(history["val_loss"], label="val")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()

    ### save fig
    ### The caller decides whether and where to save. The function intentionally does not
    ### read the notebook-level `config` / `epoch` variables, so it also works before the
    ### training loop has run and never writes a second, unrequested file.
    if save_to:
        fig.savefig(save_to)

    plt.show()

In [None]:
### train
### NOTE: `history` and `best` are (re-)initialized here, so running this cell always starts
### counting epochs from 0, even if a checkpoint was loaded before
history = {"train_loss": [], "val_loss": []}
best = {"val_loss": np.inf, "epoch": 0, "model": None}
s, e = len(history["train_loss"]), len(history["train_loss"]) + config["decoder"]["n_epochs"]
for epoch in range(s, e):
    print(f"[{epoch + 1}/{e}]")

    ### train and val (dataloaders are re-created for every epoch)
    train_dataloader, val_dataloader = get_dataloaders(
        config=config,
        v1_dataloaders=v1_dataloaders,
        syn_dataloaders=syn_dataloaders,
        only_v1_data_eval=config["only_v1_data_eval"],
    )
    train_loss = train(
        model=decoder,
        dataloader=train_dataloader,
        opter=opter,
        loss_fn=loss_fn,
        config=config,
        l1_reg_mul=config["decoder"]["l1_reg_mul"],
        l2_reg_mul=config["decoder"]["l2_reg_mul"],
    )
    val_loss = val(
        model=decoder,
        dataloader=val_dataloader,
        loss_fn=loss_fn,
        config=config,
    )

    ### save best model
    if val_loss < best["val_loss"]:
        best["val_loss"] = val_loss
        best["epoch"] = epoch
        ### state_dict() returns references to the live parameters/buffers, so take a
        ### copy - otherwise the "best" weights would keep changing with further training
        best["model"] = {k: v.detach().clone() for k, v in decoder.state_dict().items()}

    ### log
    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)
    print(f"{train_loss=:.4f}, {val_loss=:.4f}")

    ### plot reconstructions (`stim` and `resp` come from the "show data" cell;
    ### the decoder is in eval mode here because `val` ran last)
    with torch.no_grad():
        stim_pred = decoder(resp[:8].to(config["device"]))
    plot_comparison(target=stim[:8].cpu(), pred=stim_pred[:8].cpu(), save_to=make_sample_path(epoch, "no_crop_"))
    if "v1_data" in config["data"] and config["data"]["v1_data"]["crop"] == False:
        plot_comparison(target=crop_stim(stim[:8]).cpu(), pred=crop_stim(stim_pred[:8]).cpu(), save_to=make_sample_path(epoch, ""))

    ### plot losses
    if epoch % 10 == 0:
        plot_losses(
            history=history,
            save_to=None if not config["decoder"]["save_run"] else os.path.join(config["dir"], f"losses_{epoch}.png"),
        )

    ### ckpt (overwritten every epoch)
    if config["decoder"]["save_run"]:
        torch.save({
            "decoder": decoder.state_dict(),
            "opter": opter.state_dict(),
            "history": history,
            "config": config,
            "best": best,
        }, os.path.join(config["dir"], "decoder.pt"))

In [None]:
### plot reconstructions of the final model
decoder.load_state_dict(best["model"])
decoder.eval()  # make sure dropout / batch norm are in inference mode
with torch.no_grad():
    stim_pred_best = decoder(resp.to(config["device"])).cpu()
plot_comparison(
    target=crop_stim(stim[:8]).cpu(),
    pred=crop_stim(stim_pred_best[:8]).cpu(),
    ### config["dir"] only exists when the run is being saved
    save_to=os.path.join(config["dir"], "stim_comparison_best.png") if config["decoder"]["save_run"] else None,
)

### plot losses
plot_losses(
    history=history,
    save_to=None if not config["decoder"]["save_run"] else os.path.join(config["dir"], "losses_final.png"),
)