In [None]:
import os
import random
import numpy as np
import pickle
import matplotlib.pyplot as plt
import json
import pandas as pd
from datetime import datetime
from copy import deepcopy
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import transforms
import lovely_tensors as lt
from nnfabrik.builder import get_data
from lurz2020.datasets.mouse_loaders import static_loaders
from lurz2020.models.models import se2d_fullgaussian2d
from lurz2020.training.trainers import standard_trainer as trainer
from lurz2020.utility.measures import get_correlations, get_fraction_oracles

import csng
from csng.CNN_Decoder import CNN_Decoder
from csng.utils import RunningStats, plot_comparison, standardize, normalize, get_mean_and_std, count_parameters
from csng.losses import SSIMLoss, MSELossWithCrop
from csng.data import MixedBatchLoader

from models import MultiReadInCNN
from data_utils import get_mouse_v1_data, PerSampleStoredDataset

# Dataset root: the DATA_PATH environment variable has to be set, otherwise the
# lookup below raises a KeyError.
DATA_PATH = os.path.join(os.environ["DATA_PATH"], "mouse_v1_sensorium22")

# Apply the lovely_tensors monkey patch (see the package docs for what it changes).
lt.monkey_patch()

print(f"{DATA_PATH=}")

In [None]:
### global run configuration; dataset-specific entries are added to config["data"] later
config = {}
config["data"] = {
    # needed only with multiple base dataloaders
    "mixing_strategy": "sequential",
}
# use the GPU whenever PyTorch can see one
config["device"] = "cuda" if torch.cuda.is_available() else "cpu"
config["seed"] = 0

print(f"... Running on {config['device']} ...")

In [None]:
### seed NumPy, PyTorch and Python's `random` with the configured seed
for _seed_rng in (np.random.seed, torch.manual_seed, random.seed):
    _seed_rng(config["seed"])

## Data

In [None]:
# empty container for the dataloaders (filled in by the dataset cells below)
dataloaders = {}

### Mouse V1 dataset (Sensorium 2022)

In [None]:
### prep data config
# Archives from https://gin.g-node.org/cajal/Sensorium2022/src/master
# (uncomment an entry to include that recording).
filenames = [
    # "static26872-17-20-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip", # mouse 1
    # "static27204-5-13-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip", # sensorium+ (mouse 2)
    # "static21067-10-18-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip", # pretraining (mouse 3)
    # "static22846-10-16-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip", # pretraining (mouse 4)
    # "static23343-5-17-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip", # pretraining (mouse 5)
    # "static23656-14-22-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip", # pretraining (mouse 6)
    "static23964-4-22-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip", # pretraining (mouse 7)
]
# turn the bare archive names into paths below DATA_PATH
filenames = [os.path.join(DATA_PATH, f_name) for f_name in filenames]

# the same list of paths is stored at the top level and inside the loader config
config["data"]["paths"] = filenames
config["data"]["dataset_fn"] = "sensorium.datasets.static_loaders"
config["data"]["dataset_config"] = {
    "paths": filenames,
    "normalize": True,
    "scale": 0.25, # 256x144 -> 64x36
    "include_behavior": False,
    "add_behavior_as_channels": False,
    "include_eye_position": True,
    "exclude": None,
    "file_tree": True,
    "cuda": False,
    "batch_size": 128,
    "seed": config["seed"],
    "use_cache": False,
}
config["data"]["normalize_neuron_coords"] = True

In [None]:
## get dataloaders and cell coordinates
# The cells below index `dataloaders` as dataloaders["mouse_v1"][<split>] and
# `neuron_coords` by data key.
dataloaders, neuron_coords = get_mouse_v1_data(config)

In [None]:
### show data
def _approx_num_samples(part):
    """Approximate number of samples in one split of the mouse V1 data.

    Sums `len(dl) * batch_size` over the split's dataloaders, i.e. number of
    batches times the configured batch size, so the value overestimates the
    true count whenever a loader ends with a partial batch.
    """
    batch_size = config["data"]["dataset_config"]["batch_size"]
    return sum(len(dl) * batch_size for dl in dataloaders["mouse_v1"][part].dataloaders)


# a single validation batch serves as the example for the printout and the plots
sample_data_key = dataloaders["mouse_v1"]["val"].data_keys[0]
datapoint = next(iter(dataloaders["mouse_v1"]["val"].dataloaders[0]))
stim, resp = datapoint.images, datapoint.responses
pupil_center = datapoint.pupil_center
sample_coords = neuron_coords[sample_data_key]
print(
    f"Training dataset:\t {_approx_num_samples('train')} samples"
    f"\nValidation dataset:\t {_approx_num_samples('val')} samples"
    f"\nTest dataset:\t\t {_approx_num_samples('test')} samples"
    f"\nTest (no resp) dataset:\t {_approx_num_samples('test_no_resp')} samples"

    "\n\nstimuli:"
    f"\n  {stim.shape}"
    f"\n  min={stim.min().item():.3f}  max={stim.max().item():.3f}"
    f"\n  mean={stim.mean().item():.3f}  std={stim.std().item():.3f}"
    "\nresponses:"
    f"\n  {resp.shape}"
    f"\n  min={resp.min().item():.3f}  max={resp.max().item():.3f}"
    f"\n  mean={resp.mean().item():.3f}  std={resp.std().item():.3f}"
    "\nneuronal coordinates:"
    f"\n  {sample_coords.shape}"
    f"\n  min={sample_coords.min():.3f}  max={sample_coords.max():.3f}"
    f"\n  mean={sample_coords.mean():.3f}  std={sample_coords.std():.3f}"
)

fig, (ax_stim, ax_resp) = plt.subplots(1, 2, figsize=(16, 6))

# left panel: first stimulus of the batch, with a trailing channel axis for imshow
ax_stim.imshow(stim[0].squeeze().unsqueeze(-1).cpu(), cmap="gray")
ax_stim.set_title("stimulus")

# right panel: first response vector laid out as a 2D grid; the grid height is
# the smallest value in [30, 150) that divides the number of neurons
n_neurons = resp.shape[-1]
reshape_to = next(
    ((height, n_neurons // height) for height in range(30, 150) if n_neurons % height == 0),
    None,
)
if reshape_to is not None:
    ax_resp.imshow(resp[0].view(reshape_to).squeeze(0).unsqueeze(-1).cpu(), cmap="gray")
    ax_resp.set_title("responses (reshaped for display)")
else:
    # make the empty panel self-explanatory instead of leaving it blank
    ax_resp.set_title("responses not shown (no grid height in [30, 150) divides the neuron count)")

plt.show()

## Encoder

In [None]:
### load encoder
print("Loading encoder...")

### load pretrained encoder ckpt
# NOTE(review): torch.load deserializes with pickle unless weights-only loading
# is in effect, so only point this at a checkpoint from a trusted source.
encoder_ckpt_path = os.path.join(DATA_PATH, "models", "encoder.pt")
encoder_ckpt = torch.load(encoder_ckpt_path, map_location=config["device"])
encoder_config = encoder_ckpt["config"]

### get temporary dataloaders for the encoder
# built from the dataset settings stored in the checkpoint, not from `config`
_dataloaders = get_data(
    encoder_config["data"]["dataset_fn"],
    encoder_config["data"]["dataset_config"],
)

### init encoder
encoder = se2d_fullgaussian2d(
    **encoder_config["encoder"]["model_config"],
    dataloaders=_dataloaders,
    seed=encoder_config["seed"],
)
encoder = encoder.float()
# strict=True: the checkpoint's state has to match the model's parameters exactly
encoder.load_state_dict(encoder_ckpt["encoder_state"], strict=True)
encoder = encoder.to(config["device"])
encoder.eval()

In [None]:
### validate encoder is working (corr on val set should be ~ 0.32)
def _encoder_correlation(split):
    """Return `get_correlations` of the encoder on one split of `_dataloaders`.

    The call uses as_dict=False and per_neuron=False, as for every split below.
    """
    return get_correlations(
        encoder,
        _dataloaders[split],
        device=config["device"],
        as_dict=False,
        per_neuron=False,
    )


train_correlation = _encoder_correlation("train")
validation_correlation = _encoder_correlation("validation")
test_correlation = _encoder_correlation("test")

print(
    f"Correlation (train set):      {train_correlation:.3f}"
    f"\nCorrelation (validation set): {validation_correlation:.3f}"
    f"\nCorrelation (test set):       {test_correlation:.3f}"
)

## Create synthetic data

In [None]:
### choose which split to turn into synthetic data
data_part = "test"
# NOTE(review): a later cell loads responses_std_original.npy, which this
# notebook only writes when save_stats is True.
save_stats = False

# exactly one data key (recording) is processed per run
part_data_keys = dataloaders["mouse_v1"][data_part].data_keys
assert len(part_data_keys) == 1, "Create synthetic datasets one by one."
data_key = part_data_keys[0]

print(f"{data_part=}  {data_key=}  {save_stats=}")

### transformations used when generating the synthetic data
# Every entry holds:
#   "name"       - label used in the file names of the saved response statistics
#   "stim"       - callable applied to the stimuli before they go through the encoder
#   "resp"       - callable applied to the encoder output
#   "save_dir"   - directory receiving one pickle file per sample
#   "sample_idx" - running index that names the next sample file
#   "stats"      - RunningStats updated with the responses when save_stats is set
synthetic_data_root = os.path.join(DATA_PATH, "synthetic_data_mouse_v1_encoder")
n_encoder_outputs = encoder.readout[data_key].outdims

trans_to_apply = [
    {
        "name": "original",
        "stim": lambda x: x,
        "resp": lambda x: x,
        "save_dir": os.path.join(synthetic_data_root, data_key, data_part),
        "sample_idx": 0,
        "stats": RunningStats(num_components=n_encoder_outputs, lib="torch", device=config["device"]),
    },
    # NOTE(review): the disabled templates below differ from the active entry -
    # their save_dir omits data_key and num_components is hard-coded to 10000;
    # adapt both before enabling any of them.
    # { ### noise to resp
    #     "name": "01noise_resp",
    #     "stim": lambda x: x,
    #     "resp": lambda x: x + torch.randn(x.shape) * 0.1,
    #     "save_dir": os.path.join(DATA_PATH, "synthetic_data_mouse_v1_encoder", data_part + "_01noise_resp"),
    #     "sample_idx": 0,
    #     "stats": RunningStats(num_components=10000, lib="torch", device=config["device"]),
    # },
    # { ### flip stim
    #     "name": "flip_stim",
    #     "stim": lambda x: x.flip(2),
    #     "resp": lambda x: x,
    #     "save_dir": os.path.join(DATA_PATH, "synthetic_data_mouse_v1_encoder", data_part + "_flip_stim"),
    #     "sample_idx": 0,
    #     "stats": RunningStats(num_components=10000, lib="torch", device=config["device"]),
    # },
    # { ### rotate stim
    #     "name": "01rotate_stim",
    #     "stim": lambda x: torch.rot90(x, 1, (1, 2)),
    #     "resp": lambda x: x,
    #     "save_dir": os.path.join(DATA_PATH, "synthetic_data_mouse_v1_encoder", data_part + "_01rotate_stim"),
    #     "sample_idx": 0,
    #     "stats": RunningStats(num_components=10000, lib="torch", device=config["device"]),
    # },
    # { ### rotate stim
    #     "name": "02rotate_stim",
    #     "stim": lambda x: torch.rot90(x, 2, (1, 2)),
    #     "resp": lambda x: x,
    #     "save_dir": os.path.join(DATA_PATH, "synthetic_data_mouse_v1_encoder", data_part + "_02rotate_stim"),
    #     "sample_idx": 0,
    #     "stats": RunningStats(num_components=10000, lib="torch", device=config["device"]),
    # },
    # { ### rotate stim
    #     "name": "03rotate_stim",
    #     "stim": lambda x: torch.rot90(x, 3, (1, 2)),
    #     "resp": lambda x: x,
    #     "save_dir": os.path.join(DATA_PATH, "synthetic_data_mouse_v1_encoder", data_part + "_03rotate_stim"),
    #     "sample_idx": 0,
    #     "stats": RunningStats(num_components=10000, lib="torch", device=config["device"]),
    # },
]

### create dirs
# Warn (but carry on) when a target directory already holds files, then make
# sure every target directory exists.
for tran_to_apply in trans_to_apply:
    target_dir = tran_to_apply["save_dir"]
    has_content = os.path.exists(target_dir) and len(os.listdir(target_dir)) > 0
    if has_content:
        print(f"[WARNING]: {tran_to_apply['save_dir']} already exists and is not empty.")
    os.makedirs(target_dir, exist_ok=True)

In [None]:
### generate the synthetic dataset
# For every batch of the selected split: apply each configured transformation to
# the stimuli, run the encoder on the result, and store every sample as its own
# pickle file ("stim", "resp", "pupil_center") in the transformation's save_dir.
def _compact_cpu_copy(t):
    """Return a CPU copy of `t` that owns only its own elements.

    A slice of a tensor that already lives on the CPU is a view of the whole
    batch, and pickling such a view serializes the batch's entire storage.
    Cloning keeps each per-sample file limited to that sample.
    """
    return t.cpu().clone()


n_batches = len(dataloaders["mouse_v1"][data_part])
synthetic_key_dir = os.path.join(DATA_PATH, "synthetic_data_mouse_v1_encoder", data_key)
# coordinates delivered with the most recent batch; kept under its own name so
# the `neuron_coords` dict returned by get_mouse_v1_data is not overwritten
last_neuron_coords = None

with torch.no_grad():
    ### run
    for batch_idx, b in enumerate(dataloaders["mouse_v1"][data_part]):
        for _data_key, (batch_stim, _unused, batch_neuron_coords, pupil_center) in b.items():
            assert _data_key == data_key, f"Data key mismatch: {data_key} vs. {_data_key}"
            last_neuron_coords = batch_neuron_coords
            batch_stim = batch_stim.to(config["device"])

            for tran_to_apply in trans_to_apply:
                # always transform the untouched batch, so that transformations
                # do not stack on top of each other
                stim = tran_to_apply["stim"](batch_stim)

                ### forward
                resp = encoder(stim, data_key=data_key)
                resp = tran_to_apply["resp"](resp)
                if save_stats:
                    tran_to_apply["stats"].update(resp)

                ### save
                for i in range(stim.shape[0]):
                    sample_path = os.path.join(tran_to_apply["save_dir"], f"{tran_to_apply['sample_idx']}.pickle")
                    with open(sample_path, "wb") as f:
                        pickle.dump({
                            "stim": _compact_cpu_copy(stim[i]),
                            "resp": _compact_cpu_copy(resp[i]),
                            "pupil_center": _compact_cpu_copy(pupil_center[i]),
                        }, f)
                    tran_to_apply["sample_idx"] += 1

        ### log
        if batch_idx % 50 == 0:
            print(f"Batch {batch_idx}/{n_batches}")

## save stats of responses
if save_stats:
    for tran_to_apply in trans_to_apply:
        np.save(
            os.path.join(synthetic_key_dir, f"responses_mean_{tran_to_apply['name']}.npy"),
            tran_to_apply["stats"].get_mean().cpu().numpy(),
        )
        np.save(
            os.path.join(synthetic_key_dir, f"responses_std_{tran_to_apply['name']}.npy"),
            tran_to_apply["stats"].get_std().cpu().numpy(),
        )

## save neuron_coords
p = os.path.join(synthetic_key_dir, f"neuron_coords.npy")
if last_neuron_coords is None:
    print("[WARNING]: no batches were processed, neuron_coords were not saved.")
else:
    np.save(p, last_neuron_coords.cpu().numpy())

In [None]:
### load
# Response statistics of the "original" synthetic data. Only the std is used;
# the mean stays disabled:
# resp_mean = torch.from_numpy(np.load(os.path.join(DATA_PATH, "synthetic_data_mouse_v1_encoder", data_key, f"responses_mean_original.npy"))).float()
resp_std_path = os.path.join(DATA_PATH, "synthetic_data_mouse_v1_encoder", data_key, f"responses_std_original.npy")
resp_std = torch.from_numpy(np.load(resp_std_path)).float()

# Divisor for the normalization: the std, floored at 1% of the mean std so that
# entries with a very small std are not divided by a value close to zero.
thres = 0.01 * resp_std.mean()
idx = resp_std <= thres
div_by = torch.where(idx, thres, resp_std)

# Responses are only scaled (mean=0); stimuli are passed through unchanged.
resp_normalize = csng.utils.Normalize(
    # mean=resp_mean,
    mean=0,
    std=div_by,
)
dataset = PerSampleStoredDataset(
    dataset_dir=os.path.join(DATA_PATH, "synthetic_data_mouse_v1_encoder", data_key, data_part),
    stim_transform=lambda x: x, # stim is already normalized
    resp_transform=resp_normalize,
)

In [None]:
# Iterate over the stored synthetic dataset in its stored order (no shuffling),
# 64 samples per batch.
dataloader = DataLoader(dataset, batch_size=64, shuffle=False)

In [None]:
### accumulate running statistics of the responses yielded by `dataloader`
# The tracker is created from the first batch, so its size comes from the data
# being processed and not from a `resp` variable left over from an earlier cell.
stats = None
for b, (s, r, pc) in enumerate(dataloader):
    if stats is None:
        stats = RunningStats(num_components=r.shape[-1], lib="torch", device="cpu")
    # stats.update(s.view(s.shape[0], -1))
    stats.update(r)
    if b % 50 == 0:
        print(f"Batch {b} processed")

if stats is None:
    print("[WARNING]: dataloader yielded no batches, no statistics were computed.")

In [None]:
# display the mean and std accumulated by `stats` in the previous cell
stats.get_mean(), stats.get_std()