In [1]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import json
import pickle
from datetime import datetime
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import transforms
import lovely_tensors as lt

import csng
from csng.utils import plot_comparison, standardize, normalize, get_mean_and_std, count_parameters, RunningStats
from csng.losses import SSIMLoss, MSELossWithCrop

from data import (
    prepare_50k_v1_dataloaders,
    SyntheticDataset,
    BatchPatchesDataLoader,
    MixedBatchLoader,
    PerSampleStoredDataset,
)

# Switch tensor reprs to lovely-tensors' compact summaries for the rest of the notebook.
lt.monkey_patch()

# Dataset root: <$DATA_PATH>/cat_V1_spiking_model/50K_single_trial_dataset.
# Raises KeyError when the DATA_PATH environment variable is not set.
DATA_PATH = os.path.join(
    os.environ["DATA_PATH"],
    "cat_V1_spiking_model",
    "50K_single_trial_dataset",
)
print(f"{DATA_PATH=}")

DATA_PATH='/media/jsobotka/ext_ssd/csng_data/cat_V1_spiking_model/50K_single_trial_dataset'


In [2]:
# Global experiment configuration read by the cells below.
config = dict(
    data=dict(
        mixing_strategy="parallel_min",  # needed only with multiple base dataloaders
    ),
    # 20x20 window (indices 15..34 on each of the last two axes) used by crop_stim.
    stim_crop_win=(slice(15, 35), slice(15, 35)),
    only_v1_data_eval=True,
    # Use the GPU when one is visible; replace with "cpu" to force CPU execution.
    device="cuda" if torch.cuda.is_available() else "cpu",
    seed=0,
)

print(f"... Running on {config['device']} ...")

... Running on cuda ...


In [3]:
# Seed PyTorch, NumPy and Python's own RNG from the single value in the config.
# random.seed stays last: it returns None, so the cell produces no output.
torch.manual_seed(config["seed"])
np.random.seed(config["seed"])
random.seed(config["seed"])

In [4]:
# Which variant of crop_stim exists is decided once, when this cell runs.
if config["stim_crop_win"] is None:
    def crop_stim(x):
        """Return ``x`` unchanged (no crop window is configured)."""
        return x
else:
    def crop_stim(x):
        """Index the last two axes of ``x`` with ``config["stim_crop_win"]``.

        The window is looked up at call time, so later changes to the config
        entry are picked up by this function.
        """
        win = config["stim_crop_win"]
        return x[..., win[0], win[1]]

## Split data

In [None]:
### config
subdirs = ["train", "val", "test"]
train_ratio = 0.8
val_ratio = 0.125

# sorted listing, so the positional split done in the next cell is reproducible
all_samples = os.listdir(os.path.join(DATA_PATH, "single_trial"))
all_samples.sort()
total_samples = len(all_samples)

# train/val counts are truncated to integers; whatever remains becomes the test split
train_samples = int(train_ratio * total_samples)
val_samples = int(val_ratio * total_samples)
test_samples = total_samples - (train_samples + val_samples)
print(f"{train_samples=}, {val_samples=}, {test_samples=}")

# one output directory per split under <DATA_PATH>/datasets
for subdir in subdirs:
    os.makedirs(os.path.join(DATA_PATH, "datasets", subdir), exist_ok=True)

In [None]:
import shutil

### split into subfolders
# NOTE: destructive - every source sample directory is deleted once its pickle is written.
for sample_idx, sample_name in enumerate(all_samples):
    # position in the sorted listing decides the split:
    # first `train_samples` -> train, next `val_samples` -> val, the rest -> test
    if sample_idx >= train_samples + val_samples:
        subdir = subdirs[2]
    elif sample_idx >= train_samples:
        subdir = subdirs[1]
    else:
        subdir = subdirs[0]

    ### load the stimulus and both response arrays of this sample
    stim = np.load(os.path.join(DATA_PATH, "single_trial", sample_name, "stimulus.npy"))
    exc_resp = np.load(os.path.join(DATA_PATH, "single_trial", sample_name, "V1_Exc_L23.npy"))
    inh_resp = np.load(os.path.join(DATA_PATH, "single_trial", sample_name, "V1_Inh_L23.npy"))

    ### bundle them into a single pickle inside the chosen split directory
    with open(os.path.join(DATA_PATH, "datasets", subdir, f"{sample_name}.pickle"), "wb") as f:
        pickle.dump(dict(stim=stim, exc_resp=exc_resp, inh_resp=inh_resp), f)

    ### drop the original sample directory
    shutil.rmtree(os.path.join(DATA_PATH, "single_trial", sample_name))

## Move multi-trial test data

In [5]:
# Multi-trial samples are written into the existing test split directory.
target_dir = os.path.join(DATA_PATH, "datasets", "test")

# Sorted names of the multi-trial sample directories.
samples = os.listdir(os.path.join(DATA_PATH, "Dataset_multitrial", "Dic23data", "multitrial"))
samples.sort()
print(f"{len(samples)=},  {target_dir=}")

len(samples)=250,  target_dir='/media/jsobotka/ext_ssd/csng_data/cat_V1_spiking_model/50K_single_trial_dataset/datasets/test'


In [16]:
import shutil

# NOTE: destructive - each multi-trial sample directory is deleted once its pickle is written.
for sample_name in samples:
    sample_dir = os.path.join(DATA_PATH, "Dataset_multitrial", "Dic23data", "multitrial", sample_name)

    ### load the stimulus stored next to the trial directories
    stim = np.load(os.path.join(sample_dir, "stimulus.npy"))

    ### collect the per-trial responses
    # os.listdir returns entries in arbitrary order, so sort (lexicographically) to make the
    # trial axis reproducible; keeping only directories skips stimulus.npy and any stray file
    trial_dir_names = sorted(
        entry for entry in os.listdir(sample_dir)
        if os.path.isdir(os.path.join(sample_dir, entry))
    )
    all_exc_resp = []
    all_inh_resp = []
    for trial_dir_name in trial_dir_names:
        exc_resp = np.load(os.path.join(sample_dir, trial_dir_name, "V1_Exc_L23.npy"))
        inh_resp = np.load(os.path.join(sample_dir, trial_dir_name, "V1_Inh_L23.npy"))
        all_exc_resp.append(exc_resp)
        all_inh_resp.append(inh_resp)

    # trials are stacked along a new leading axis; np.stack raises ValueError when no trial
    # was found, which happens before anything is written or deleted
    exc_resp = np.stack(all_exc_resp, axis=0)
    inh_resp = np.stack(all_inh_resp, axis=0)

    ### save as pickle
    with open(os.path.join(target_dir, f"{sample_name}.pickle"), "wb") as f:
        pickle.dump({
            "stim": stim,
            "exc_resp": exc_resp,
            "inh_resp": inh_resp,
        }, f)

    ### remove sample_name directory
    shutil.rmtree(sample_dir)

## Get data statistics

In [None]:
# Load one stored validation sample (the file name is hard-coded) for manual inspection.
# pickle.load executes arbitrary code from the file - only use it on files this notebook produced.
with open(os.path.join(DATA_PATH, "datasets", "val", "0000045000.pickle"), "rb") as f:
    da = pickle.load(f)

In [5]:
# Loader settings for the V1 dataset. They are stored under config["data"]["v1_data"] and
# unpacked as keyword arguments into prepare_50k_v1_dataloaders by the following cells.
# All normalization entries are commented out in this version of the settings.
config["data"]["v1_data"] = {
    # split directories under <DATA_PATH>/datasets
    "train_path": os.path.join(DATA_PATH, "datasets", "train"),
    "val_path": os.path.join(DATA_PATH, "datasets", "val"),
    "test_path": os.path.join(DATA_PATH, "datasets", "test"),
    "image_size": [50, 50],
    "crop": False,
    "batch_size": 1000,
    # keys of the per-sample pickles (as written by the splitting cells) to use as stimulus / responses
    "stim_keys": ("stim",),
    "resp_keys": ("exc_resp", "inh_resp"),
    # "stim_normalize_mean": 46.143,
    # "stim_normalize_std": 20.420,
    # "resp_normalize_mean": torch.from_numpy(np.load(
    #     os.path.join(DATA_PATH, "responses_mean.npy")
    # )).float(),
    # "resp_normalize_std": torch.from_numpy(np.load(
    #     os.path.join(DATA_PATH, "responses_std.npy")
    # )).float(),
}

In [None]:
### per-channel stimulus statistics over the training data
# Uses whatever settings config["data"]["v1_data"] currently holds.
v1_dataloaders = prepare_50k_v1_dataloaders(**config["data"]["v1_data"])
dataloader = torch.utils.data.DataLoader(v1_dataloaders["train"].dataset.dataset, batch_size=1000, shuffle=True)

# Accumulators are sized from the first batch instead of assuming a single channel
# (a fixed torch.zeros(1) raises IndexError as soon as inputs has more than one channel).
mean_inputs, std_inputs = None, None
n_images = 0
for inputs, _targets in dataloader:
    # inputs is indexed as (batch, channel, height, width)
    if mean_inputs is None:
        mean_inputs = torch.zeros(inputs.size(1))
        std_inputs = torch.zeros(inputs.size(1))
    # reduce over the two spatial axes -> (batch, channel), then sum over the batch;
    # summing per image (instead of averaging per batch) keeps a smaller last batch
    # from being over-weighted in the final average
    mean_inputs += inputs.mean((-1, -2)).sum(0).cpu()
    std_inputs += inputs.std((-1, -2)).sum(0).cpu()
    n_images += inputs.size(0)

if n_images == 0:
    raise ValueError("The training dataloader yielded no samples; cannot compute stimulus statistics.")

# mean_inputs: average of the per-image spatial means, per channel.
# std_inputs: average of the per-image spatial stds, per channel (not the std over the pooled pixels).
mean_inputs.div_(n_images)
std_inputs.div_(n_images)
mean_inputs, std_inputs

In [6]:
### running statistics of the responses in the training data
v1_dataloaders = prepare_50k_v1_dataloaders(**config["data"]["v1_data"])
dataloader = torch.utils.data.DataLoader(
    v1_dataloaders["train"].dataset.dataset,
    batch_size=1000,
    shuffle=True,
)

# one tracker for all response components, one for the first 37500, one for the remaining 9375
stats_all, stats_exc, stats_inh = (
    RunningStats(num_components=n_components, lib="torch", device="cpu")
    for n_components in (46875, 37500, 9375)
)

for i, (s, r) in enumerate(dataloader):
    stats_all.update(r)
    # columns [0, 37500) feed stats_exc, columns [37500, end) feed stats_inh
    stats_exc.update(r[:, :37500])
    stats_inh.update(r[:, 37500:])
    if i % 200 == 0:
        print(f"{i}: {r.mean()=} {r.std()=} {stats_all.get_mean()=} {stats_all.get_std()=}")

## save
# mean first, then std, for each tracker
for _stats, _mean_file, _std_file in (
    (stats_all, "responses_mean.pt", "responses_std.pt"),
    (stats_exc, "responses_exc_mean.pt", "responses_exc_std.pt"),
    (stats_inh, "responses_inh_mean.pt", "responses_inh_std.pt"),
):
    torch.save(_stats.get_mean(), os.path.join(DATA_PATH, _mean_file))
    torch.save(_stats.get_std(), os.path.join(DATA_PATH, _std_file))

Train dataset size: 45000. Validation dataset size: 5000. Test dataset size: 0.
0: r.mean()=tensor[] 3.475 r.std()=tensor[] 4.659 stats_all.get_mean()=tensor[46875] x∈[0.166, 30.312] μ=3.475 σ=2.705 stats_all.get_std()=tensor[46875] x∈[0.001, 0.281] μ=0.009 σ=0.017


## Data

In [35]:
# Loader settings for the V1 dataset, this time including normalization constants.
config["data"]["v1_data"] = dict(
    # split directories under <DATA_PATH>/datasets
    train_path=os.path.join(DATA_PATH, "datasets", "train"),
    val_path=os.path.join(DATA_PATH, "datasets", "val"),
    test_path=os.path.join(DATA_PATH, "datasets", "test"),
    image_size=[50, 50],
    crop=False,
    batch_size=1000,
    stim_keys=("stim",),
    resp_keys=("exc_resp", "inh_resp"),
    # stimulus normalization constants (hard-coded scalars)
    stim_normalize_mean=46.143,
    stim_normalize_std=20.420,
    # response normalization tensors, read from the .pt files saved by the statistics cell
    resp_normalize_mean=torch.load(os.path.join(DATA_PATH, "responses_mean.pt")),
    resp_normalize_std=torch.load(os.path.join(DATA_PATH, "responses_std.pt")),
)

In [36]:
### get data loaders
# Build the loaders from the current config["data"]["v1_data"] settings; the result is
# indexed by split name elsewhere in this notebook ("train", "val", and data_part).
v1_dataloaders = prepare_50k_v1_dataloaders(**config["data"]["v1_data"])

Train dataset size: 45000. Validation dataset size: 5000. Test dataset size: 250.


In [37]:
### show data
# Take the first validation batch and print its shapes and basic value statistics.
stim, resp = next(iter(v1_dataloaders["val"]))
print(
    f"{stim.shape=}, {resp.shape=}"
    f"\n{stim.min()=}, {stim.max()=}"
    f"\n{resp.min()=}, {resp.max()=}"
    f"\n{stim.mean()=}, {stim.std()=}"
    f"\n{resp.mean()=}, {resp.std()=}"
)

# Three panels, left to right: full stimulus, cropped stimulus, and the response vector
# reshaped to a 125 x 375 image (hard-coded, so it needs exactly 46875 response values).
fig = plt.figure(figsize=(14, 6))
for position, image in enumerate(
    (
        stim[0].squeeze(),
        crop_stim(stim[0]).squeeze(),
        resp[0].view(125, 375).squeeze(0),
    ),
    start=1,
):
    ax = fig.add_subplot(1, 3, position)
    ax.imshow(image.unsqueeze(-1), cmap="gray")

plt.show()

RuntimeError: Caught RuntimeError in pin memory thread for device 0.
Original Traceback (most recent call last):
  File "/home/jsobotka/miniconda3/envs/dev/lib/python3.10/site-packages/torch/utils/data/_utils/pin_memory.py", line 34, in do_one_step
    data = pin_memory(data, device)
  File "/home/jsobotka/miniconda3/envs/dev/lib/python3.10/site-packages/torch/utils/data/_utils/pin_memory.py", line 70, in pin_memory
    return type(data)([pin_memory(sample, device) for sample in data])  # type: ignore[call-arg]
  File "/home/jsobotka/miniconda3/envs/dev/lib/python3.10/site-packages/torch/utils/data/_utils/pin_memory.py", line 70, in <listcomp>
    return type(data)([pin_memory(sample, device) for sample in data])  # type: ignore[call-arg]
  File "/home/jsobotka/miniconda3/envs/dev/lib/python3.10/site-packages/torch/utils/data/_utils/pin_memory.py", line 55, in pin_memory
    return data.pin_memory(device)
RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.



---

## Encoder

In [None]:
### load encoder
# Builds the se2d_fullgaussian2d encoder, loads a saved state dict into it, and leaves it on
# config["device"] in eval mode. The statement order matters: the model is constructed first
# (from dataloaders and a fixed seed) and only then are the saved weights loaded.
from data_orig import prepare_spiking_data_loaders
from lurz2020.models.models import se2d_fullgaussian2d

print("Loading encoder...")

### config only for the encoder
# Settings for the dataloaders handed to the encoder constructor. Note that test_path points to
# a single pickle under "orig/raw", unlike the per-split directories used for train/val.
spiking_data_loaders_config = {
    "train_path": os.path.join(DATA_PATH, "datasets", "train"),
    "val_path": os.path.join(DATA_PATH, "datasets", "val"),
    "test_path": os.path.join(DATA_PATH, "orig", "raw", "test.pickle"),
    "image_size": [50, 50],
    "crop": False,
    "batch_size": 32,
}
# Keyword arguments forwarded unchanged to se2d_fullgaussian2d.
# NOTE(review): these presumably have to match the architecture of the checkpoint loaded
# below, since it is loaded with strict=True - confirm before changing any value.
encoder_config = {
    "init_mu_range": 0.55,
    "init_sigma": 0.4,
    "input_kern": 19,
    "hidden_kern": 17,
    "hidden_channels": 32,
    "gamma_input": 1.0,
    "gamma_readout": 2.439,
    "grid_mean_predictor": None,
    "layers": 5
}

### encoder
# The dataloaders are only needed to construct the model and are deleted right afterwards.
_dataloaders = prepare_spiking_data_loaders(**spiking_data_loaders_config)
encoder = se2d_fullgaussian2d(
    **encoder_config,
    dataloaders=_dataloaders,
    seed=2,
).float()
del _dataloaders

### load pretrained core
# torch.load unpickles the file, so only load checkpoints from a trusted source.
# Despite the variable name, the loaded state dict is applied to the whole encoder;
# strict=True makes load_state_dict fail on any missing or unexpected key.
pretrained_core = torch.load(
    os.path.join(DATA_PATH, "models", "spiking_scratch_tunecore_68Y_model.pth"),
    map_location=config["device"],
)
encoder.load_state_dict(pretrained_core, strict=True)
encoder.to(config["device"])
encoder.eval()

## Create the synthetic data

In [None]:
# Which split of v1_dataloaders is pushed through the encoder in the generation cell.
data_part = "test" ## TODO: when "train", turn off shuffle
# When True, the generation cell updates each entry's RunningStats and saves mean/std as .npy.
save_stats = False

# One entry per synthetic-data variant. The generation cell reads these fields:
#   "name"       - used in the file names of the saved statistics
#   "stim"       - callable applied to the stimulus batch before the encoder
#   "resp"       - callable applied to the encoder output
#   "save_dir"   - directory receiving one pickle per sample
#   "sample_idx" - running counter used as the pickle file name, incremented per saved sample
#   "stats"      - RunningStats updated with the responses when save_stats is True
trans_to_apply = [
    {
        "name": "original",
        "stim": lambda x: x,
        "resp": lambda x: x,
        "save_dir": os.path.join(DATA_PATH, "synthetic_data_v1_encoder", data_part),
        "sample_idx": 0,
        "stats": RunningStats(num_components=10000, lib="torch", device=config["device"]),
    },
    # { ### noise to resp
    #     # NOTE(review): torch.randn(x.shape) is created on the CPU; if x lives on
    #     # config["device"], torch.randn_like(x) is likely needed - confirm before enabling.
    #     "name": "01noise_resp",
    #     "stim": lambda x: x,
    #     "resp": lambda x: x + torch.randn(x.shape) * 0.1,
    #     "save_dir": os.path.join(DATA_PATH, "synthetic_data_v1_encoder", data_part + "_01noise_resp"),
    #     "sample_idx": 0,
    #     "stats": RunningStats(num_components=10000, lib="torch", device=config["device"]),
    # },
    # { ### flip stim
    #     "name": "flip_stim",
    #     "stim": lambda x: x.flip(2),
    #     "resp": lambda x: x,
    #     "save_dir": os.path.join(DATA_PATH, "synthetic_data_v1_encoder", data_part + "_flip_stim"),
    #     "sample_idx": 0,
    #     "stats": RunningStats(num_components=10000, lib="torch", device=config["device"]),
    # },
    # { ### rotate stim
    #     "name": "01rotate_stim",
    #     "stim": lambda x: torch.rot90(x, 1, (1, 2)),
    #     "resp": lambda x: x,
    #     "save_dir": os.path.join(DATA_PATH, "synthetic_data_v1_encoder", data_part + "_01rotate_stim"),
    #     "sample_idx": 0,
    #     "stats": RunningStats(num_components=10000, lib="torch", device=config["device"]),
    # },
    # { ### rotate stim
    #     "name": "02rotate_stim",
    #     "stim": lambda x: torch.rot90(x, 2, (1, 2)),
    #     "resp": lambda x: x,
    #     "save_dir": os.path.join(DATA_PATH, "synthetic_data_v1_encoder", data_part + "_02rotate_stim"),
    #     "sample_idx": 0,
    #     "stats": RunningStats(num_components=10000, lib="torch", device=config["device"]),
    # },
    # { ### rotate stim
    #     "name": "03rotate_stim",
    #     "stim": lambda x: torch.rot90(x, 3, (1, 2)),
    #     "resp": lambda x: x,
    #     "save_dir": os.path.join(DATA_PATH, "synthetic_data_v1_encoder", data_part + "_03rotate_stim"),
    #     "sample_idx": 0,
    #     "stats": RunningStats(num_components=10000, lib="torch", device=config["device"]),
    # },
]

### create dirs
# Only warns when a target directory already has content: sample_idx restarts at 0,
# so existing files with the same names are overwritten by the generation cell.
for tran_to_apply in trans_to_apply:
    if os.path.exists(tran_to_apply["save_dir"]) and len(os.listdir(tran_to_apply["save_dir"])) > 0:
        print(f"[WARNING]: {tran_to_apply['save_dir']} already exists and is not empty.")
    os.makedirs(tran_to_apply["save_dir"], exist_ok=True)

In [None]:
### Push every batch of the selected split through the encoder and store one pickle per sample
### for each entry of trans_to_apply.
n_batches = len(v1_dataloaders[data_part])

with torch.no_grad():
    ### run
    for batch_idx, (stim, _) in enumerate(v1_dataloaders[data_part]):
        # move the batch to the target device once; every transform starts from this same tensor
        stim = stim.to(config["device"])

        for tran_to_apply in trans_to_apply:
            # keep the transformed batch in its own variable: overwriting `stim` here would hand
            # the next entry an already-transformed batch and silently chain the transforms
            stim_t = tran_to_apply["stim"](stim)

            ### forward
            resp = encoder(stim_t)
            resp = tran_to_apply["resp"](resp)
            if save_stats:
                tran_to_apply["stats"].update(resp)

            ### save
            for i in range(stim_t.shape[0]):
                sample_path = os.path.join(tran_to_apply["save_dir"], f"{tran_to_apply['sample_idx']}.pickle")
                with open(sample_path, "wb") as f:
                    pickle.dump({
                        # clone so each pickle holds only this sample: a tensor view is serialized
                        # together with its whole underlying storage, and .cpu() returns that same
                        # view of the batch when the data already lives on the CPU
                        "stim": stim_t[i].cpu().clone(),
                        "resp": resp[i].cpu().clone(),
                    }, f)
                tran_to_apply["sample_idx"] += 1

        ### log
        if batch_idx % 50 == 0:
            print(f"Batch {batch_idx}/{n_batches}")

## save stats
# one mean file and one std file per entry, named after the entry's "name"
if save_stats:
    for tran_to_apply in trans_to_apply:
        np.save(
            os.path.join(DATA_PATH, "synthetic_data_v1_encoder", f"responses_mean_{tran_to_apply['name']}.npy"),
            tran_to_apply["stats"].get_mean().cpu().numpy(),
        )
        np.save(
            os.path.join(DATA_PATH, "synthetic_data_v1_encoder", f"responses_std_{tran_to_apply['name']}.npy"),
            tran_to_apply["stats"].get_std().cpu().numpy(),
        )

In [None]:
### load
# Response statistics of the "original" variant, as float32 tensors.
resp_mean = torch.from_numpy(
    np.load(os.path.join(DATA_PATH, "synthetic_data_v1_encoder", "responses_mean_original.npy"))
).to(torch.float32)
resp_std = torch.from_numpy(
    np.load(os.path.join(DATA_PATH, "synthetic_data_v1_encoder", "responses_std_original.npy"))
).to(torch.float32)

# Stored synthetic training samples: stimuli pass through unchanged,
# responses go through csng.utils.Normalize with the statistics above.
dataset = PerSampleStoredDataset(
    dataset_dir=os.path.join(DATA_PATH, "synthetic_data_v1_encoder", "train"),
    stim_transform=lambda x: x,
    resp_transform=csng.utils.Normalize(mean=resp_mean, std=resp_std),
)

In [None]:
# Sequential (unshuffled) loader over the stored synthetic dataset, 64 samples per batch.
dataloader = DataLoader(dataset, batch_size=64, shuffle=False)

In [None]:
# Running statistics over the flattened stimuli of the synthetic dataset.
# NOTE(review): 50 * 50 components presumably assumes single-channel 50x50 stimuli - confirm.
stats = RunningStats(num_components=50 * 50, lib="torch", device="cpu")
for b, (s, r) in enumerate(dataloader):
    # flatten everything after the batch axis so each sample becomes one row
    stats.update(s.view(len(s), -1))
    if not b % 50:
        print(f"Batch {b} processed")

In [None]:
# Display the accumulated mean and std of the flattened stimuli as the cell output.
stats.get_mean(), stats.get_std()