In [None]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import json
import pickle
from datetime import datetime
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import transforms
import lovely_tensors as lt

import csng
from csng.utils import plot_comparison, standardize, normalize, get_mean_and_std, count_parameters, RunningStats
from csng.losses import SSIMLoss, MSELossWithCrop

from data import (
    prepare_v1_dataloaders,
    SyntheticDataset,
    BatchPatchesDataLoader,
    MixedBatchLoader,
    PerSampleStoredDataset,
)

# Activate lovely_tensors' monkey patch (changes how tensors are printed in this notebook).
lt.monkey_patch()

# Root directory of the cat V1 spiking-model data. The parent directory is read
# from the DATA_PATH environment variable; a KeyError is raised if it is unset.
DATA_PATH = os.path.join(os.environ["DATA_PATH"], "cat_V1_spiking_model")
print(f"{DATA_PATH=}")

In [None]:
# Notebook-wide configuration. Later cells add entries under config["data"].
config = dict(
    data=dict(
        mixing_strategy="parallel_min",  # needed only with multiple base dataloaders
    ),
    # (row, column) slices applied to the last two stimulus dimensions by `crop_stim`.
    stim_crop_win=(slice(15, 35), slice(15, 35)),
    only_v1_data_eval=True,
    # Prefer the GPU when one is visible; replace with "cpu" to force CPU execution.
    device="cuda" if torch.cuda.is_available() else "cpu",
    seed=0,
)

print(f"... Running on {config['device']} ...")

In [None]:
# Seed NumPy, PyTorch and Python's `random` with the same configured seed.
for seed_rng in (np.random.seed, torch.manual_seed, random.seed):
    seed_rng(config["seed"])

In [None]:
# Pick the cropping helper once, based on whether a crop window is configured.
if config["stim_crop_win"] is None:
    def crop_stim(x):
        """Return `x` unchanged (no crop window configured)."""
        return x
else:
    def crop_stim(x):
        """Crop the last two dimensions of `x` to the configured window.

        The window is read from `config["stim_crop_win"]` at call time.
        """
        window = config["stim_crop_win"]
        return x[..., window[0], window[1]]

## Data

In [None]:
# Keyword arguments for `prepare_v1_dataloaders`; the whole dict is unpacked
# into that call when the V1 dataloaders are built.
config["data"]["v1_data"] = {
    # Train/val point to directories under DATA_PATH/datasets; test points to a
    # single pickle file under DATA_PATH/orig/raw.
    "train_path": os.path.join(DATA_PATH, "datasets", "train"),
    "val_path": os.path.join(DATA_PATH, "datasets", "val"),
    "test_path": os.path.join(DATA_PATH, "orig", "raw", "test.pickle"),
    "image_size": [50, 50],
    "crop": False,
    # "crop": True,
    "batch_size": 64,
    # Scalar stimulus normalization constants.
    # NOTE(review): presumably computed on the training stimuli - confirm.
    "stim_normalize_mean": 46.236,
    "stim_normalize_std": 21.196,
    # Response normalization statistics, loaded from .npy files in DATA_PATH
    # and converted to float32 tensors.
    "resp_normalize_mean": torch.from_numpy(np.load(
        os.path.join(DATA_PATH, "responses_mean_from_training_dataset.npy")
    )).float(),
    "resp_normalize_std": torch.from_numpy(np.load(
        os.path.join(DATA_PATH, "responses_std_from_training_dataset.npy")
    )).float(),
}

In [None]:
### get data loaders
# The result is indexed by split name further down ("val" for the preview,
# `data_part` for the synthetic-data generation).
v1_dataloaders = prepare_v1_dataloaders(**config["data"]["v1_data"])

In [None]:
### show data
# Grab one validation batch and print its shapes and basic value statistics.
stim, resp = next(iter(v1_dataloaders["val"]))
print(
    f"{stim.shape=}, {resp.shape=}"
    f"\n{stim.min()=}, {stim.max()=}"
    f"\n{resp.min()=}, {resp.max()=}"
    f"\n{stim.mean()=}, {stim.std()=}"
    f"\n{resp.mean()=}, {resp.std()=}"
)

# Left to right: full stimulus, cropped stimulus, and the response vector of
# the first sample viewed as a 100 x 100 image.
panels = (
    stim[0].squeeze(),
    crop_stim(stim[0]).squeeze(),
    resp[0].view(100, 100),
)
fig, axes = plt.subplots(1, 3, figsize=(14, 6))
for ax, panel in zip(axes, panels):
    # A trailing channel axis is added so imshow receives an (H, W, 1) image.
    ax.imshow(panel.unsqueeze(-1), cmap="gray")

plt.show()

## Encoder

In [None]:
### load encoder
# Builds the encoder model, loads pretrained weights into it, and puts it on
# config["device"] in eval mode. The resulting `encoder` is used further down
# to generate responses for the synthetic dataset.
from data_orig import prepare_spiking_data_loaders
from lurz2020.models.models import se2d_fullgaussian2d

print("Loading encoder...")

### config only for the encoder
# Separate from config["data"]["v1_data"]: same paths, image size and crop
# flag, but a batch size of 32 and no normalization entries.
spiking_data_loaders_config = {
    "train_path": os.path.join(DATA_PATH, "datasets", "train"),
    "val_path": os.path.join(DATA_PATH, "datasets", "val"),
    "test_path": os.path.join(DATA_PATH, "orig", "raw", "test.pickle"),
    "image_size": [50, 50],
    "crop": False,
    "batch_size": 32,
}
# Hyperparameters forwarded to `se2d_fullgaussian2d`.
# NOTE(review): these must match the architecture of the checkpoint loaded
# below, because the state dict is loaded with strict=True.
encoder_config = {
    "init_mu_range": 0.55,
    "init_sigma": 0.4,
    "input_kern": 19,
    "hidden_kern": 17,
    "hidden_channels": 32,
    "gamma_input": 1.0,
    "gamma_readout": 2.439,
    "grid_mean_predictor": None,
    "layers": 5
}

### encoder
# The dataloaders are only needed by the model constructor and are deleted
# right after it returns.
# NOTE(review): presumably the constructor uses them to infer input/output
# shapes - confirm.
_dataloaders = prepare_spiking_data_loaders(**spiking_data_loaders_config)
encoder = se2d_fullgaussian2d(
    **encoder_config,
    dataloaders=_dataloaders,
    seed=2,
).float()
del _dataloaders

### load pretrained core
# Despite the variable name, this state dict is loaded into the whole encoder
# with strict=True, so it must contain exactly the encoder's keys.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True if the installed torch supports it).
pretrained_core = torch.load(
    os.path.join(DATA_PATH, "models", "spiking_scratch_tunecore_68Y_model.pth"),
    map_location=config["device"],
)
encoder.load_state_dict(pretrained_core, strict=True)
encoder.to(config["device"])
# Inference only: switch to eval mode. As the last expression of the cell, the
# returned module is also displayed by the notebook.
encoder.eval()

## Create the synthetic data

In [None]:
# Which split of `v1_dataloaders` to run through the encoder.
data_part = "test" ## TODO: when "train", turn off shuffle
# When True, response statistics are accumulated during generation and written
# to .npy files afterwards. The file names contain only the transform name,
# not `data_part`, so statistics from different splits share the same files.
save_stats = False

# One entry per synthetic dataset variant. Keys, as used by the generation loop:
#   "name"       - suffix of the saved statistics files
#   "stim"       - callable applied to the stimulus batch before the encoder
#   "resp"       - callable applied to the encoder output
#   "save_dir"   - directory that receives one pickle file per sample
#   "sample_idx" - running counter used as the file name; incremented in place
#   "stats"      - RunningStats updated with the responses when save_stats is True
trans_to_apply = [
    {
        "name": "original",
        "stim": lambda x: x,
        "resp": lambda x: x,
        "save_dir": os.path.join(DATA_PATH, "synthetic_data_v1_encoder", data_part),
        "sample_idx": 0,
        "stats": RunningStats(num_components=10000, lib="torch", device=config["device"]),
    },
    # NOTE(review): the entries below are disabled. Before enabling them, check:
    #   * torch.randn(x.shape) allocates on the CPU; if the responses live on
    #     config["device"], torch.randn_like(x) avoids a device mismatch.
    #   * the callables receive whole batches; if stimuli are (B, C, H, W),
    #     dims (1, 2) in torch.rot90 address (C, H) rather than (H, W) - confirm.
    # { ### noise to resp
    #     "name": "01noise_resp",
    #     "stim": lambda x: x,
    #     "resp": lambda x: x + torch.randn(x.shape) * 0.1,
    #     "save_dir": os.path.join(DATA_PATH, "synthetic_data_v1_encoder", data_part + "_01noise_resp"),
    #     "sample_idx": 0,
    #     "stats": RunningStats(num_components=10000, lib="torch", device=config["device"]),
    # },
    # { ### flip stim
    #     "name": "flip_stim",
    #     "stim": lambda x: x.flip(2),
    #     "resp": lambda x: x,
    #     "save_dir": os.path.join(DATA_PATH, "synthetic_data_v1_encoder", data_part + "_flip_stim"),
    #     "sample_idx": 0,
    #     "stats": RunningStats(num_components=10000, lib="torch", device=config["device"]),
    # },
    # { ### rotate stim (1 x 90 degrees)
    #     "name": "01rotate_stim",
    #     "stim": lambda x: torch.rot90(x, 1, (1, 2)),
    #     "resp": lambda x: x,
    #     "save_dir": os.path.join(DATA_PATH, "synthetic_data_v1_encoder", data_part + "_01rotate_stim"),
    #     "sample_idx": 0,
    #     "stats": RunningStats(num_components=10000, lib="torch", device=config["device"]),
    # },
    # { ### rotate stim (2 x 90 degrees)
    #     "name": "02rotate_stim",
    #     "stim": lambda x: torch.rot90(x, 2, (1, 2)),
    #     "resp": lambda x: x,
    #     "save_dir": os.path.join(DATA_PATH, "synthetic_data_v1_encoder", data_part + "_02rotate_stim"),
    #     "sample_idx": 0,
    #     "stats": RunningStats(num_components=10000, lib="torch", device=config["device"]),
    # },
    # { ### rotate stim (3 x 90 degrees)
    #     "name": "03rotate_stim",
    #     "stim": lambda x: torch.rot90(x, 3, (1, 2)),
    #     "resp": lambda x: x,
    #     "save_dir": os.path.join(DATA_PATH, "synthetic_data_v1_encoder", data_part + "_03rotate_stim"),
    #     "sample_idx": 0,
    #     "stats": RunningStats(num_components=10000, lib="torch", device=config["device"]),
    # },
]

### create dirs
# Only warns when a target directory already holds files; nothing is deleted,
# so files from an earlier run with the same index are overwritten later and
# any others are left in place.
for tran_to_apply in trans_to_apply:
    if os.path.exists(tran_to_apply["save_dir"]) and len(os.listdir(tran_to_apply["save_dir"])) > 0:
        print(f"[WARNING]: {tran_to_apply['save_dir']} already exists and is not empty.")
    os.makedirs(tran_to_apply["save_dir"], exist_ok=True)

In [None]:
n_batches = len(v1_dataloaders[data_part])

# Generate the synthetic dataset: every stimulus batch of the selected split is
# passed through each configured transform and the encoder, and each resulting
# (stimulus, response) pair is written to its own pickle file in the
# transform's "save_dir", named by the transform's running "sample_idx".
# The counters live in `trans_to_apply`, so re-run the cell that defines it
# before re-running this one, otherwise numbering continues where it stopped.
with torch.no_grad():
    ### run
    for batch_idx, (stim, _) in enumerate(v1_dataloaders[data_part]):
        # Move the untouched batch to the device once per batch.
        stim = stim.to(config["device"])

        for tran_to_apply in trans_to_apply:
            # Keep the transformed batch under its own name: reassigning `stim`
            # here would feed each transform the output of the previous one
            # whenever more than one entry is configured.
            stim_t = tran_to_apply["stim"](stim)

            ### forward
            resp = encoder(stim_t)
            resp = tran_to_apply["resp"](resp)
            if save_stats:
                tran_to_apply["stats"].update(resp)

            ### save
            for i in range(stim_t.shape[0]):
                sample_path = os.path.join(tran_to_apply["save_dir"], f"{tran_to_apply['sample_idx']}.pickle")
                with open(sample_path, "wb") as f:
                    # clone(): when the tensors are already on the CPU, `x[i].cpu()`
                    # is still a view of the batch, and pickling a view stores
                    # its whole backing storage in every sample file.
                    pickle.dump({
                        "stim": stim_t[i].cpu().clone(),
                        "resp": resp[i].cpu().clone(),
                    }, f)
                tran_to_apply["sample_idx"] += 1

        ### log
        if batch_idx % 50 == 0:
            print(f"Batch {batch_idx}/{n_batches}")

## save stats
# Write the accumulated response mean and std of every transform as .npy files
# next to the generated sample directories.
if save_stats:
    stats_dir = os.path.join(DATA_PATH, "synthetic_data_v1_encoder")
    for tran_to_apply in trans_to_apply:
        running = tran_to_apply["stats"]
        mean_path = os.path.join(stats_dir, f"responses_mean_{tran_to_apply['name']}.npy")
        std_path = os.path.join(stats_dir, f"responses_std_{tran_to_apply['name']}.npy")
        np.save(mean_path, running.get_mean().cpu().numpy())
        np.save(std_path, running.get_std().cpu().numpy())

In [None]:
### load
# Directory that holds the generated samples and the saved response statistics.
synthetic_root = os.path.join(DATA_PATH, "synthetic_data_v1_encoder")

# Response statistics saved for the "original" transform, as float32 tensors.
resp_mean = torch.from_numpy(np.load(os.path.join(synthetic_root, "responses_mean_original.npy"))).float()
resp_std = torch.from_numpy(np.load(os.path.join(synthetic_root, "responses_std_original.npy"))).float()

# Stimuli are returned as stored; responses are normalized with the statistics above.
# NOTE(review): the directory is hard-coded to "train", independent of the
# `data_part` used by the generation cell - confirm this is intended.
resp_normalizer = csng.utils.Normalize(mean=resp_mean, std=resp_std)
dataset = PerSampleStoredDataset(
    dataset_dir=os.path.join(synthetic_root, "train"),
    stim_transform=lambda x: x,
    resp_transform=resp_normalizer,
)

In [None]:
# Sequential (unshuffled) batches of 64 samples from the stored synthetic dataset.
dataloader = DataLoader(dataset, batch_size=64, shuffle=False)

In [None]:
# Accumulate running statistics over the stored stimuli, with each stimulus
# flattened to one row.
# NOTE(review): num_components = 50 * 50 presumably mirrors the configured
# image_size of [50, 50] - confirm.
stats = RunningStats(num_components=50 * 50, lib="torch", device="cpu")
for b, (stim_batch, _resp_batch) in enumerate(dataloader):
    flat_stim = stim_batch.view(stim_batch.shape[0], -1)
    stats.update(flat_stim)
    if b % 50 == 0:
        print(f"Batch {b} processed")

In [None]:
# Display the accumulated stimulus mean and std as the cell output.
stats.get_mean(), stats.get_std()