In [None]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import json
import pandas as pd
from datetime import datetime
from copy import deepcopy
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import transforms
import lovely_tensors as lt

import csng
from csng.GAN import GAN
from csng.utils import plot_comparison, standardize, normalize, get_mean_and_std, count_parameters
from csng.losses import SSIMLoss, MSELossWithCrop

from mypkg.visualization import LivePlot

from data import (
    prepare_v1_dataloaders,
    SyntheticDataset,
    BatchPatchesDataLoader,
    MixedBatchLoader,
    PerSampleStoredDataset,
)

# Enable lovely_tensors' summarised tensor printing (monkey-patches torch).
lt.monkey_patch()

# Root of this project's data. Requires the DATA_PATH environment variable
# (os.environ[...] raises KeyError when it is not set).
DATA_PATH = os.path.join(os.environ["DATA_PATH"], "cat_V1_spiking_model")
print(f"{DATA_PATH=}")

In [None]:
# Global run configuration. Later cells add "data"/"decoder" sections, and when
# saving is enabled the whole dict is dumped to config.json (in this key order).
config = {
    "data": {
        "mixing_strategy": "parallel_min", # needed only with multiple base dataloaders
    },
    # (row_slice, col_slice) applied to the last two stimulus dims by `crop_stim`;
    # also passed as the SSIM loss window.
    "stim_crop_win": (slice(15, 35), slice(15, 35)),
    # If True, validation always uses the real V1 "val" split, even when training
    # on a mix of datasets (see `get_dataloaders`).
    "only_v1_data_eval": True,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "seed": 0,
}

print(f"... Running on {config['device']} ...")

In [None]:
# Seed every RNG the notebook draws from: PyTorch (model init, rand_like label
# noise, DataLoader shuffling), NumPy, and the stdlib `random` module.
# `random.seed` (returns None) stays last so the cell displays nothing.
torch.manual_seed(config["seed"])
np.random.seed(config["seed"])
random.seed(config["seed"])

In [None]:
if config["stim_crop_win"] is not None:
    def crop_stim(x):
        """Index the last two dims of `x` with the slices in ``config["stim_crop_win"]``."""
        win = config["stim_crop_win"]
        return x[..., win[0], win[1]]
else:
    def crop_stim(x):
        """No crop window configured: return `x` unchanged."""
        return x

## Data

In [None]:
# Registry of dataset name -> {"train"/"val"/"test": loader}; filled by the cells below.
dataloaders = {}

### V1 dataset (spiking model of cat V1)

In [None]:
# Keyword arguments for `prepare_v1_dataloaders` (the dict is splatted into it).
config["data"]["v1_data"] = {
    "train_path": os.path.join(DATA_PATH, "datasets", "train"),
    "val_path": os.path.join(DATA_PATH, "datasets", "val"),
    # NOTE(review): the test split is a raw pickle; only load files you trust.
    "test_path": os.path.join(DATA_PATH, "orig", "raw", "test.pickle"),
    "image_size": [50, 50],
    "crop": False,
    # "crop": True,
    # "batch_size": 64,
    "batch_size": 20,
    # Hard-coded stimulus normalisation constants (presumably computed on the
    # training stimuli — TODO confirm).
    "stim_normalize_mean": 46.236,
    "stim_normalize_std": 21.196,
    # Response normalisation statistics, loaded from .npy files that, per their
    # names, were computed on the training dataset.
    "resp_normalize_mean": torch.from_numpy(np.load(
        os.path.join(DATA_PATH, "responses_mean_from_training_dataset.npy")
    )).float(),
    "resp_normalize_std": torch.from_numpy(np.load(
        os.path.join(DATA_PATH, "responses_std_from_training_dataset.npy")
    )).float(),
}

In [None]:
### get data loaders
# Result is indexed by split name below ("train" and "val" are used in this notebook).
dataloaders["v1_data"] = prepare_v1_dataloaders(**config["data"]["v1_data"])

In [None]:
### show data
# One validation batch; `stim_sample` / `resp_sample` are reused later to size the
# decoder and to plot reconstructions during training.
stim_sample, resp_sample = next(iter(dataloaders["v1_data"]["val"]))
print(
    f"{stim_sample.shape=}, {resp_sample.shape=}"
    f"\n{stim_sample.min()=}, {stim_sample.max()=}"
    f"\n{resp_sample.min()=}, {resp_sample.max()=}"
    f"\n{stim_sample.mean()=}, {stim_sample.std()=}"
    f"\n{resp_sample.mean()=}, {resp_sample.std()=}"
)

fig, axes = plt.subplots(1, 3, figsize=(14, 6))
axes[0].imshow(stim_sample.cpu()[0].squeeze(), cmap="gray")
axes[0].set_title("Stimulus")

axes[1].imshow(crop_stim(stim_sample.cpu()[0]).squeeze(), cmap="gray")
axes[1].set_title("Cropped stimulus")

# Responses are shown as a 100x100 grid; this requires 10,000 values per sample.
axes[2].imshow(resp_sample.cpu()[0].reshape(100, 100), cmap="gray")
axes[2].set_title("Responses (100x100)")

plt.show()

### Synthetic data (different stimuli dataset -> encoder -> neuronal responses)

In [None]:
# Response normalisation statistics for the synthetic dataset. These names are
# rebound in the V1-encoder section further down; the transforms built here keep
# references to the tensors loaded now.
resp_mean = torch.from_numpy(np.load(os.path.join(DATA_PATH, "synthetic_data", "responses_mean.npy"))).float()
resp_std = torch.from_numpy(np.load(os.path.join(DATA_PATH, "synthetic_data", "responses_std.npy"))).float()

config["data"]["syn_data"] = {
    # kwargs for PerSampleStoredDataset (besides `dataset_dir`)
    "dataset": {
        # Hard-coded stimulus mean/std for the synthetic stimuli.
        "stim_transform": transforms.Normalize(
            mean=114.457,
            std=51.356,
        ),
        "resp_transform": csng.utils.Normalize(
            mean=resp_mean,
            std=resp_std,
        ),
    },
    # kwargs for torch DataLoader (besides `dataset`)
    "dataloader": {
        "batch_size": 20,
        "shuffle": True,
    }
}

In [None]:
# One PerSampleStoredDataset per split, all sharing the same transforms; built in
# train -> val -> test order.
syn_datasets = {
    split: PerSampleStoredDataset(
        dataset_dir=os.path.join(DATA_PATH, "synthetic_data", "processed", split),
        **config["data"]["syn_data"]["dataset"]
    )
    for split in ("train", "val", "test")
}

# Wrap every split in a DataLoader with the shared loader kwargs.
dataloaders["syn_data"] = {
    split: DataLoader(
        dataset=split_dataset,
        **config["data"]["syn_data"]["dataloader"],
    )
    for split, split_dataset in syn_datasets.items()
}

In [None]:
### show data
# One validation batch from the synthetic dataset.
syn_stim_sample, syn_resp_sample = next(iter(dataloaders["syn_data"]["val"]))
print(
    f"{syn_stim_sample.shape=}, {syn_resp_sample.shape=}"
    f"\n{syn_stim_sample.min()=}, {syn_stim_sample.max()=}"
    f"\n{syn_resp_sample.min()=}, {syn_resp_sample.max()=}"
    f"\n{syn_stim_sample.mean()=}, {syn_stim_sample.std()=}"
    f"\n{syn_resp_sample.mean()=}, {syn_resp_sample.std()=}"
)

fig, axes = plt.subplots(1, 3, figsize=(10, 6))
axes[0].imshow(syn_stim_sample.cpu()[0].squeeze(), cmap="gray")
axes[0].set_title("Stimulus")

axes[1].imshow(crop_stim(syn_stim_sample.cpu()[0]).squeeze(), cmap="gray")
axes[1].set_title("Cropped stimulus")

# Responses are shown as a 100x100 grid; this requires 10,000 values per sample.
axes[2].imshow(syn_resp_sample.cpu()[0].reshape(100, 100), cmap="gray")
axes[2].set_title("Responses (100x100)")

plt.show()

### Synthetic data (V1 data stimuli -> encoder -> neuronal responses)

In [None]:
### load
# Response statistics for the V1-encoder synthetic data. This rebinds the
# `resp_mean` / `resp_std` names used above; transforms created earlier keep the
# tensors they were given.
resp_mean = torch.from_numpy(np.load(os.path.join(DATA_PATH, "synthetic_data_v1_encoder", "responses_mean_original.npy"))).float()
resp_std = torch.from_numpy(np.load(os.path.join(DATA_PATH, "synthetic_data_v1_encoder", "responses_std_original.npy"))).float()

config["data"]["syn_data_v1_enc"] = {
    "dataset": {
        # mean=0, std=1 makes this an identity normalisation; presumably the stored
        # stimuli are already normalised — TODO confirm.
        "stim_transform": transforms.Normalize(
            mean=0,
            std=1,
        ),
        "resp_transform": csng.utils.Normalize(
            mean=resp_mean,
            std=resp_std,
        ),
    },
    "dataloader": {
        "batch_size": 20,
        "shuffle": True,
    }
}

In [None]:
# One PerSampleStoredDataset per split (note: no "processed" subdirectory here),
# built in train -> val -> test order.
syn_datasets_v1_encoder = {
    split: PerSampleStoredDataset(
        dataset_dir=os.path.join(DATA_PATH, "synthetic_data_v1_encoder", split),
        **config["data"]["syn_data_v1_enc"]["dataset"]
    )
    for split in ("train", "val", "test")
}

# Wrap every split in a DataLoader with the shared loader kwargs.
dataloaders["syn_data_v1_enc"] = {
    split: DataLoader(
        dataset=split_dataset,
        **config["data"]["syn_data_v1_enc"]["dataloader"],
    )
    for split, split_dataset in syn_datasets_v1_encoder.items()
}

In [None]:
### show data
# One validation batch from the V1-encoder synthetic dataset.
syn_stim_v1_enc, syn_resp_v1_enc = next(iter(dataloaders["syn_data_v1_enc"]["val"]))
print(
    f"{syn_stim_v1_enc.shape=}, {syn_resp_v1_enc.shape=}"
    f"\n{syn_stim_v1_enc.min()=}, {syn_stim_v1_enc.max()=}"
    f"\n{syn_resp_v1_enc.min()=}, {syn_resp_v1_enc.max()=}"
    f"\n{syn_stim_v1_enc.mean()=}, {syn_stim_v1_enc.std()=}"
    f"\n{syn_resp_v1_enc.mean()=}, {syn_resp_v1_enc.std()=}"
)

fig, axes = plt.subplots(1, 3, figsize=(10, 6))
axes[0].imshow(syn_stim_v1_enc.cpu()[0].squeeze(), cmap="gray")
axes[0].set_title("Stimulus")

axes[1].imshow(crop_stim(syn_stim_v1_enc.cpu()[0]).squeeze(), cmap="gray")
axes[1].set_title("Cropped stimulus")

# Responses are shown as a 100x100 grid; this requires 10,000 values per sample.
axes[2].imshow(syn_resp_v1_enc.cpu()[0].reshape(100, 100), cmap="gray")
axes[2].set_title("Responses (100x100)")

plt.show()

## Decoder

In [None]:
def train(model, dataloader, opter, loss_fn, config, l1_reg_mul=0, l2_reg_mul=0, verbose=True):
    """Run one training epoch of a decoder that predicts stimuli from responses.

    Args:
        model: module called as ``model(resp)``; put into train mode here.
        dataloader: iterable of ``(stim, resp)`` batches that supports ``len()``.
        opter: optimizer over ``model``'s parameters.
        loss_fn: callable ``loss_fn(stim_pred, stim)`` returning a scalar tensor.
        config: dict; only ``config["device"]`` is read.
        l1_reg_mul: weight of an L1 penalty on trainable parameters whose name
            contains "weight" (0 disables it).
        l2_reg_mul: weight of a squared-L2 penalty on the same parameters.
        verbose: print progress every 100 batches.

    Returns:
        Mean per-batch loss (regularisation included), averaged over
        ``len(dataloader)``; ``nan`` if the dataloader is empty.
    """
    model.train()
    train_loss = 0.0
    n_batches = len(dataloader)

    ### run
    for batch_idx, (stim, resp) in enumerate(dataloader):
        ### data
        stim = stim.to(config["device"])
        resp = resp.to(config["device"])

        ### train
        opter.zero_grad()
        stim_pred = model(resp)
        loss = loss_fn(stim_pred, stim)

        ### regularization (only parameters named "...weight...")
        if l1_reg_mul != 0 or l2_reg_mul != 0:
            reg_params = [
                p for n, p in model.named_parameters()
                if p.requires_grad and "weight" in n
            ]
            # Out-of-place additions: avoid mutating the loss tensor returned by loss_fn.
            if l1_reg_mul != 0:
                loss = loss + l1_reg_mul * sum(p.abs().sum() for p in reg_params)
            if l2_reg_mul != 0:
                loss = loss + l2_reg_mul * sum(p.pow(2.0).sum() for p in reg_params)

        loss.backward()
        opter.step()

        ### log
        train_loss += loss.item()
        if verbose and batch_idx % 100 == 0:
            print(f"Training progress: [{batch_idx}/{n_batches} ({100. * batch_idx / n_batches:.0f}%)]"
                  f"  Loss: {loss.item():.6f}")

    if n_batches == 0:
        # Nothing to average over (previously a ZeroDivisionError).
        return float("nan")
    return train_loss / n_batches

In [None]:
def val(model, dataloader, loss_fn, config):
    """Compute the mean per-batch loss of `model` on `dataloader` without gradients.

    Args:
        model: module called as ``model(resp)``.
        dataloader: iterable of ``(stim, resp)`` batches that supports ``len()``.
        loss_fn: callable ``loss_fn(stim_pred, stim)`` returning a scalar tensor.
        config: dict; only ``config["device"]`` is read.

    Returns:
        Sum of per-batch losses divided by ``len(dataloader)``; ``nan`` if the
        dataloader is empty.

    Note:
        Leaves ``model`` in eval mode; callers that keep training must call
        ``model.train()`` again.
    """
    model.eval()
    n_batches = len(dataloader)
    if n_batches == 0:
        # Nothing to average over (previously a ZeroDivisionError).
        return float("nan")

    val_loss = 0.0
    with torch.no_grad():
        for stim, resp in dataloader:
            stim = stim.to(config["device"])
            resp = resp.to(config["device"])

            stim_pred = model(resp)
            val_loss += loss_fn(stim_pred, stim).item()

    return val_loss / n_batches

In [None]:
def get_dataloaders(config, dataloaders, use_data_names, only_v1_data_eval=True):
    """Select the (train, val) loaders for one epoch.

    Training uses the "train" split of every dataset in `use_data_names`; with
    more than one dataset they are combined by a MixedBatchLoader configured from
    ``config["data"]["mixing_strategy"]`` and ``config["device"]``. Validation
    uses ``dataloaders["v1_data"]["val"]`` when `only_v1_data_eval` is truthy,
    otherwise the "val" split(s) of the selected datasets, combined the same way.

    Raises:
        ValueError: if `use_data_names` is empty.
    """
    # Resolved first so a missing "v1_data" entry fails before anything else.
    val_dataloader = dataloaders["v1_data"]["val"] if only_v1_data_eval else None

    selected = [dataloaders[name] for name in use_data_names]
    if not selected:
        raise ValueError("No data to train on.")

    if len(selected) == 1:
        (single,) = selected
        if not only_v1_data_eval:
            val_dataloader = single["val"]
        return single["train"], val_dataloader

    def _mix(split):
        # Combine the given split of every selected dataset into one loader.
        return MixedBatchLoader(
            dataloaders=[dl[split] for dl in selected],
            mixing_strategy=config["data"]["mixing_strategy"],
            device=config["device"],
        )

    train_dataloader = _mix("train")
    if not only_v1_data_eval:
        val_dataloader = _mix("val")
    return train_dataloader, val_dataloader

In [None]:
# Decoder (GAN) configuration. "model" is splatted into csng.GAN: G maps a
# response batch to a stimulus, D scores (cropped) stimuli.
# Layer tuples are interpreted by csng.GAN; presumably
# (kind, out_channels, kernel, stride, padding) for conv/deconv — TODO confirm.
config["decoder"] = {
    "model": {
        "G_kwargs": {
            # input shape taken from the sample response batch (batch dim dropped)
            "in_shape": resp_sample.shape[1:],
            # "layers": [
            #     ("fc", 637),
            #     ("unflatten", 1, (13, 7, 7)),
            #     ("deconv", 256, 7, 2, 0),
            #     ("deconv", 128, 5, 2, 0),
            #     ("deconv", 64, 5, 1, 0),
            #     ("deconv", 64, 4, 1, 0),
            #     ("deconv", 1, 3, 1, 0),
            # ],
            "layers": [
                ("fc", 384),  # CNN Baseline
                ("unflatten", 1, (6, 8, 8)),  # CNN Baseline
                ("deconv", 256, 7, 2, 0),
                ("deconv", 128, 5, 2, 0),
                ("deconv", 64, 4, 1, 0),  # CNN Baseline
                ("deconv", 1, 3, 1, 0),
            ],
            "act_fn": nn.ReLU,
            # "out_act_fn": nn.Tanh,
            "out_act_fn": nn.Identity,
            "dropout": 0.2,
            "batch_norm": True,
        },
        "D_kwargs": {
            # D only ever sees stimuli after `crop_stim`
            "in_shape": crop_stim(stim_sample).shape[1:],
            "layers": [
                ("conv", 128, 4, 1, 2),
                # ("conv", 128, 4, 2, 1),
                ("conv", 128, 4, 1, 0),
                # ("conv", 64, 4, 1, 0),
                ("conv", 64, 4, 1, 0),
                ("conv", 32, 3, 1, 0),
                ("fc", 1),
            ],
            "act_fn": nn.ReLU,
            "out_act_fn": nn.Sigmoid,
            # "out_act_fn": nn.Identity,
            "dropout": 0.3,
            "batch_norm": True,
        },
        # presumably Adam(-style) kwargs (betas) — TODO confirm in csng.GAN
        "G_optim_kwargs": {"lr": 1e-4, "betas": (0.5, 0.999)},
        "D_optim_kwargs": {"lr": 5e-5, "betas": (0.5, 0.999)},
    },
    # Reconstruction loss for G; `window` restricts it to the crop window.
    "stim_loss_fn": SSIMLoss(
        window=config["stim_crop_win"],
        log_loss=True,
        inp_normalized=True,
    ),
    # NOTE: the GAN training loop below does not read l1/l2_reg_mul.
    # "l1_reg_mul": 0,
    # "l2_reg_mul": 5e-5,
    "n_epochs": 150,
    "save_run": True,
}

gan = GAN(**config["decoder"]["model"]).to(config["device"])

In [None]:
### prepare checkpointing
if not config["decoder"]["save_run"]:
    def make_sample_path(epoch, prefix):
        """Saving disabled: reconstruction plots are not written to disk."""
        return None

    print("WARNING: Not saving the run and the config.")
else:
    # Timestamped run directory with the serialised config plus "samples" and "ckpt" subdirs.
    run_name = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    config["dir"] = os.path.join(DATA_PATH, "models", "gan", run_name)
    os.makedirs(config["dir"], exist_ok=True)
    with open(os.path.join(config["dir"], "config.json"), "w") as f:
        # default=str stringifies non-JSON values (tensors, loss objects, classes).
        json.dump(config, f, indent=4, default=str)
    os.makedirs(os.path.join(config["dir"], "samples"), exist_ok=True)
    os.makedirs(os.path.join(config["dir"], "ckpt"), exist_ok=True)

    def make_sample_path(epoch, prefix):
        """Path of the reconstruction plot for `epoch` inside the run's samples dir."""
        return os.path.join(config["dir"], "samples", f"{prefix}stim_comparison_{epoch}e.png")

    print(f"Run name: {run_name}\nRun dir: {config['dir']}")

In [None]:
### load ckpt (optional resume)
# Disabled by default. Loading replaces `gan`, `config` (including `config["dir"]`)
# and `make_sample_path`, so an unconditional load would make a fresh
# Restart & Run All write samples/checkpoints into the OLD run's directory.
# Set the run name (e.g. "2023-08-27_14-13-45") to resume.
# NOTE: the training-setup cell below re-initialises `history` and `best`.
resume_run_name = None  # e.g. "2023-08-27_14-13-45"
resume_ckpt_file = "decoder_15.pt"

if resume_run_name is not None:
    run_name = resume_run_name
    current_device = config["device"]
    # map_location lets a CUDA-saved checkpoint load on a CPU-only machine.
    # weights_only=False: the checkpoint pickles arbitrary objects (config with loss
    # object and layer classes). Only load checkpoints you created yourself.
    ckpt = torch.load(
        os.path.join(DATA_PATH, "models", "gan", run_name, "ckpt", resume_ckpt_file),
        map_location=current_device,
        weights_only=False,
    )

    history = ckpt["history"]
    config = ckpt["config"]
    # Keep the device available now, not the one recorded when the ckpt was saved.
    config["device"] = current_device
    best = ckpt["best"]

    gan = GAN(**config["decoder"]["model"]).to(config["device"])
    gan.load_state_dict(ckpt["decoder"])

    make_sample_path = lambda epoch, prefix: os.path.join(
        config["dir"], "samples", f"{prefix}stim_comparison_{epoch}e.png"
    )

In [None]:
### show model
# Run the shape check in eval mode. In train mode, dropout would be active and
# batch-norm layers would update their running statistics from this sample batch,
# even under no_grad. The previous mode is restored afterwards.
was_training = gan.training
gan.eval()
with torch.no_grad():
    print(
        f"Output shapes:"
        f"\n\tG: {gan(resp_sample.to(config['device'])).cpu().shape}"
        f"\n\tD: {gan.D(crop_stim(stim_sample).to(config['device'])).cpu().shape}"
    )
gan.train(was_training)

print(
    f"Number of parameters:"
    f"\n\tG: {count_parameters(gan.G)}"
    f"\n\tD: {count_parameters(gan.D)}"
)
gan

In [None]:
def plot_losses(history, save_to=None):
    """Plot the per-epoch validation loss and optionally save the figure.

    Args:
        history: dict whose "val_loss" list holds one value per finished epoch.
        save_to: optional path; if truthy, the figure is saved there.

    When the global ``config["decoder"]["save_run"]`` is set, the figure is also
    saved as ``losses_{epoch}.png`` in ``config["dir"]``. Here ``epoch`` is the
    index of the last recorded epoch, taken from `history`, so the function no
    longer depends on a global loop variable that may not exist.
    """
    val_losses = history["val_loss"]
    last_epoch = len(val_losses) - 1

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(val_losses, label="val")
    ax.set_title("Validation loss")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()

    if save_to:
        fig.savefig(save_to)
    ### save fig into the run directory
    if config["decoder"]["save_run"]:
        fig.savefig(os.path.join(config["dir"], f"losses_{last_epoch}.png"))

    plt.show()

In [None]:
log_freq = 100  # batches between progress prints / live-plot refreshes
# Train on every dataset registered in `dataloaders` (mixed when there are several).
data_names = list(dataloaders.keys())
print(f"{data_names=}")

# "val_loss" gets one entry per epoch; all other keys get one entry per training
# batch (losses, and mean |grad| of layers[0] / layers[-2] of D and G).
# NOTE(review): this re-initialises `history` and `best`, discarding anything a
# checkpoint-loading cell restored, so `s` below is always 0 here.
history = {k: [] for k in (
    "val_loss", "D_loss", "G_loss", "G_loss_stim", "G_loss_adv",
    "D_mean_abs_grad_first_layer", "D_mean_abs_grad_last_layer",
    "G_mean_abs_grad_first_layer", "G_mean_abs_grad_last_layer"
)}
live_plot = LivePlot(
    figsize=(22, 24),
    groups=[k for k in history.keys() if k not in ["val_loss"]],
    use_seaborn=True,
)
best = {"val_loss": np.inf, "epoch": 0, "decoder": None}
# start / end epoch indices for the training loop
s, e = len(history["val_loss"]), config["decoder"]["n_epochs"]

In [None]:
### train
# Squared-error (LSGAN-style) adversarial training with noisy labels, plus an SSIM
# reconstruction loss for G. Validation runs on `val_dataloader` after each epoch.
for epoch in range(s, e):
    # `val()` switches the model to eval mode at the end of every epoch, so training
    # mode (dropout, batch-norm statistic updates) must be re-enabled each epoch.
    gan.train()
    print(f"[{epoch + 1}/{config['decoder']['n_epochs']}]")

    ### get data (rebuilt every epoch)
    train_dataloader, val_dataloader = get_dataloaders(
        config=config,
        dataloaders=dataloaders,
        use_data_names=data_names,
        only_v1_data_eval=config["only_v1_data_eval"],
    )

    ### training epoch
    for batch_idx, (stim, resp) in enumerate(train_dataloader):
        resp = resp.to(config["device"])
        stim = stim.to(config["device"])

        ### update discriminator
        gan.D_optim.zero_grad()
        real_stim_pred = gan.D(crop_stim(stim))
        ### noisy "real" targets: uniform in [0.95, 1.0)
        noisy_real_stim_labels = torch.rand_like(real_stim_pred) * 0.05 + 0.95
        real_stim_loss = torch.mean((real_stim_pred - noisy_real_stim_labels)**2) / 2.

        stim_pred = gan.G(resp)
        # detach: the discriminator step must not backpropagate into G
        fake_stim_pred = gan.D(crop_stim(stim_pred.detach()))
        ### noisy "fake" targets: uniform in [0, 0.05)
        noisy_fake_stim_labels = torch.rand_like(fake_stim_pred) * 0.05
        fake_stim_loss = torch.mean((fake_stim_pred - noisy_fake_stim_labels)**2) / 2.

        D_loss = real_stim_loss + fake_stim_loss
        D_loss.backward()

        ### clip gradients element-wise to [-1, 1] (parameters without a grad are skipped)
        nn.utils.clip_grad_value_(gan.D.parameters(), clip_value=1.0)

        history["D_mean_abs_grad_first_layer"].append(torch.mean(torch.abs(gan.D.layers[0].weight.grad)).item())
        history["D_mean_abs_grad_last_layer"].append(torch.mean(torch.abs(gan.D.layers[-2].weight.grad)).item())
        gan.D_optim.step()

        ### update generator
        gan.G_optim.zero_grad()
        stim_pred = gan.G(resp)
        fake_stim_pred = gan.D(crop_stim(stim_pred))

        G_loss_adv = torch.mean((fake_stim_pred - 1.)**2)
        G_loss_stim = config["decoder"]["stim_loss_fn"](stim_pred, stim) / 2
        G_loss = G_loss_adv + G_loss_stim
        G_loss.backward()

        ### clip gradients element-wise to [-1, 1] (parameters without a grad are skipped)
        nn.utils.clip_grad_value_(gan.G.parameters(), clip_value=1.0)

        history["G_mean_abs_grad_first_layer"].append(torch.mean(torch.abs(gan.G.layers[0].weight.grad)).item())
        history["G_mean_abs_grad_last_layer"].append(torch.mean(torch.abs(gan.G.layers[-2].weight.grad)).item())
        gan.G_optim.step()

        ### log
        history["D_loss"].append(D_loss.item())
        history["G_loss"].append(G_loss.item())
        history["G_loss_stim"].append(G_loss_stim.item())
        history["G_loss_adv"].append(G_loss_adv.item())

        if batch_idx % log_freq == 0 and batch_idx > 0:
            print(
                f"[{epoch + 1}/{config['decoder']['n_epochs']}  {batch_idx * len(resp)}/{len(train_dataloader) * len(resp)}  "
                f"({100. * batch_idx / len(train_dataloader):.0f}%)]  "
                f"D-loss: {D_loss.item():.4f}  "
                f"G-loss: {G_loss.item():.4f}  "
                f"G-loss-stim: {G_loss_stim.item():.4f}  "
                f"G-loss-adv: {G_loss_adv.item():.4f}"
            )

            live_plot.update({
                k: history[k][-log_freq:] for k in history.keys()
                if k not in ["val_loss"]
            }, display=True)

    ### eval (leaves gan in eval mode)
    val_loss = val(
        model=gan,
        dataloader=val_dataloader,
        loss_fn=config["decoder"]["stim_loss_fn"],
        config=config,
    )

    ### save best model
    if val_loss < best["val_loss"]:
        best["val_loss"] = val_loss
        best["epoch"] = epoch
        best["decoder"] = deepcopy(gan.state_dict())

    ### log
    history["val_loss"].append(val_loss)
    print(f"{val_loss=:.4f}")

    ### plot reconstructions (eval mode from `val`, no autograd graph needed)
    with torch.no_grad():
        stim_pred = gan(resp_sample[:8].to(config["device"]))
    plot_comparison(
        target=crop_stim(stim_sample[:8]).cpu(),
        pred=crop_stim(stim_pred[:8]).cpu(),
        save_to=make_sample_path(epoch, "")
    )

    ### plot losses
    if epoch % 5 == 0 and epoch > 0:
        plot_losses(history=history)

        ### ckpt
        if config["decoder"]["save_run"]:
            torch.save({
                "decoder": gan.state_dict(),
                "history": history,
                "config": config,
                "best": best,
            }, os.path.join(config["dir"], "ckpt", f"decoder_{epoch}.pt"))
    

In [None]:
print(f"Best val loss: {best['val_loss']:.4f} at epoch {best['epoch']}")

# `config["dir"]` only exists when the run is being saved.
save_run = config["decoder"]["save_run"]

### save final (last-epoch weights, before switching to the best ones below)
if save_run:
    torch.save({
        "decoder": gan.state_dict(),
        "history": history,
        "config": config,
        "best": best,
    }, os.path.join(config["dir"], "decoder.pt"))

### plot reconstructions of the best model (lowest validation loss)
if best["decoder"] is not None:
    gan.load_state_dict(best["decoder"])
else:
    print("WARNING: No best model recorded; showing the current weights.")
gan.eval()
with torch.no_grad():
    stim_pred_best = gan(resp_sample.to(config["device"])).cpu()
plot_comparison(
    target=crop_stim(stim_sample[:8]).cpu(),
    pred=crop_stim(stim_pred_best[:8]).cpu(),
    save_to=os.path.join(config["dir"], "stim_comparison_best.png") if save_run else None,
)

### plot losses
plot_losses(
    history=history,
    save_to=os.path.join(config["dir"], "losses.png") if save_run else None,
)

In [None]:
# Redraw the live plot of the per-batch training metrics.
live_plot.draw()

In [None]:
### one panel per tracked quantity, 3 per row
n_axes = len(history.keys())
n_cols = 3
n_rows = -(-n_axes // n_cols)  # ceil division: no empty trailing row
fig = plt.figure(figsize=(24, 22))

for k_i, k in enumerate(history.keys()):
    ax = fig.add_subplot(n_rows, n_cols, k_i + 1)
    ax.plot(history[k])
    # "val_loss" is logged once per epoch; every other key once per training batch.
    ax.set_xlabel("Epoch" if k == "val_loss" else "Batch")
    ax.set_title(k)

plt.show()