In [None]:
import os
import random
import numpy as np
import json
import pickle
import shutil
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import transforms
import lovely_tensors as lt
import csng
from csng.utils.mix import RunningStats
from csng.utils.data import crop, normalize
from csng.cat_v1.data import get_cat_v1_dataloaders

# Patch torch tensors through lovely_tensors (presumably for more compact
# tensor summaries when printing -- confirm against the library docs).
lt.monkey_patch()

# The dataset root is resolved from the DATA_PATH environment variable, which
# must be set before running the notebook (a KeyError is raised otherwise).
_data_root = os.environ["DATA_PATH"]
DATA_PATH = os.path.join(_data_root, "cat_V1_spiking_model", "50K_single_trial_dataset")
print(f"{DATA_PATH=}")

In [None]:
# Global notebook configuration. The "data" sub-dict is extended with the
# dataloader settings further down in the notebook.
config = dict(
    data=dict(
        mixing_strategy="sequential",  # needed only with multiple base dataloaders
    ),
    crop_win=(slice(15, 35), slice(15, 35)),
    only_v1_data_eval=True,
    device="cuda" if torch.cuda.is_available() else "cpu",
    seed=0,
)

print(f"... Running on {config['device']} ...")

# Seed every RNG used in this notebook (NumPy, PyTorch, Python's `random`),
# in that order, with the same configured seed.
for _seed_fn in (np.random.seed, torch.manual_seed, random.seed):
    _seed_fn(config["seed"])

## Split data

In [None]:
### config
# Names of the split folders and the fraction of samples assigned to the
# first one; the remainder goes to the second one.
subdirs = ["train", "val"]
train_ratio = 0.8

# Sorted listing of the raw single-trial samples, so that the split below is
# determined by sample name rather than by directory listing order.
_single_trial_dir = os.path.join(DATA_PATH, "single_trial")
all_samples = os.listdir(_single_trial_dir)
all_samples.sort()

total_samples = len(all_samples)
train_samples = int(train_ratio * total_samples)
val_samples = total_samples - train_samples
print(f"{train_samples=}, {val_samples=}")

# Create one output folder per split under <DATA_PATH>/datasets.
_datasets_root = os.path.join(DATA_PATH, "datasets")
for subdir in subdirs:
    os.makedirs(os.path.join(_datasets_root, subdir), exist_ok=True)

In [None]:
### split into subfolders
# Samples are assigned by their position in the sorted `all_samples` list:
# the first `train_samples` entries go to subdirs[0], all remaining ones to
# subdirs[1]. Since train_samples + val_samples == total_samples there is no
# third split, so no fallback to a `subdirs[2]` entry is needed (it would be
# unreachable, and an IndexError if it were ever hit).
src_root = os.path.join(DATA_PATH, "single_trial")
for sample_idx, sample_name in enumerate(all_samples):
    subdir = subdirs[0] if sample_idx < train_samples else subdirs[1]

    ### load the stimulus and both response arrays of this sample
    sample_dir = os.path.join(src_root, sample_name)
    stim = np.load(os.path.join(sample_dir, "stimulus.npy"))
    exc_resp = np.load(os.path.join(sample_dir, "V1_Exc_L23.npy"))
    inh_resp = np.load(os.path.join(sample_dir, "V1_Inh_L23.npy"))

    ### save everything as a single pickle inside the split's folder
    with open(os.path.join(DATA_PATH, "datasets", subdir, f"{sample_name}.pickle"), "wb") as f:
        pickle.dump({
            "stim": stim,
            "exc_resp": exc_resp,
            "inh_resp": inh_resp,
        }, f)

### remove previous directory
# Destructive: only reached when the loop above finished without raising,
# i.e. every sample has been re-saved as a pickle.
shutil.rmtree(src_root)

## Move and preprocess multi-trial test data

In [None]:
# Destination folder for the preprocessed test pickles.
target_dir = os.path.join(DATA_PATH, "datasets", "test")
os.makedirs(target_dir, exist_ok=True)

# Source folder of the raw multi-trial recordings; its entries are listed in
# sorted order, one entry per test sample.
test_data_path = os.path.join(
    DATA_PATH,
    "Dataset_multitrial",
    "Dic23data",
    "multitrial",
)
samples = os.listdir(test_data_path)
samples.sort()
print(f"{len(samples)=},  {target_dir=}")

In [None]:
### preprocess test multi-trial data and save as pickle
# Each sample directory holds one `stimulus.npy` plus one sub-directory per
# trial containing the excitatory and inhibitory responses. The responses of
# all trials are stacked along a new leading axis and stored together with
# the stimulus in a single pickle.
for sample_name in samples:
    sample_dir = os.path.join(test_data_path, sample_name)

    ### load the stimulus shared by all trials of this sample
    stim = np.load(os.path.join(sample_dir, "stimulus.npy"))

    ### collect the trial directories
    # os.listdir returns entries in arbitrary order, so sort them to make the
    # trial axis of the stacked arrays reproducible across runs/filesystems.
    # Only directories are treated as trials, which skips `stimulus.npy` and
    # any other stray file.
    trial_dir_names = sorted(
        name for name in os.listdir(sample_dir)
        if os.path.isdir(os.path.join(sample_dir, name))
    )
    if not trial_dir_names:
        raise ValueError(f"No trial directories found in {sample_dir}")

    ### load the responses of every trial
    all_exc_resp = []
    all_inh_resp = []
    for trial_dir_name in trial_dir_names:
        trial_dir = os.path.join(sample_dir, trial_dir_name)
        all_exc_resp.append(np.load(os.path.join(trial_dir, "V1_Exc_L23.npy")))
        all_inh_resp.append(np.load(os.path.join(trial_dir, "V1_Inh_L23.npy")))
    exc_resp = np.stack(all_exc_resp, axis=0)
    inh_resp = np.stack(all_inh_resp, axis=0)

    ### save as pickle
    with open(os.path.join(target_dir, f"{sample_name}.pickle"), "wb") as f:
        pickle.dump({
            "stim": stim,
            "exc_resp": exc_resp,
            "inh_resp": inh_resp,
        }, f)

    ### remove sample_name directory
    # Destructive: the raw data of this sample is deleted once its pickle has
    # been written.
    shutil.rmtree(sample_dir)

## Get data statistics

In [None]:
# Settings passed as keyword arguments to `get_cat_v1_dataloaders` in the
# statistics cells below.
_datasets_dir = os.path.join(DATA_PATH, "datasets")
config["data"]["v1_data"] = dict(
    train_path=os.path.join(_datasets_dir, "train"),
    val_path=os.path.join(_datasets_dir, "val"),
    test_path=os.path.join(_datasets_dir, "test"),
    image_size=[50, 50],
    crop=False,
    batch_size=1000,
    stim_keys=("stim",),
    resp_keys=("exc_resp", "inh_resp"),
)
# Normalization entries are left disabled here. The response statistics files
# referenced below are the ones written by the "responses" cell of this
# notebook; the stimulus values are presumably the result of the "image
# stimuli" cell (TODO confirm).
#     "stim_normalize_mean": 46.143,
#     "stim_normalize_std": 20.420,
#     "resp_normalize_mean": torch.load(
#         os.path.join(DATA_PATH, "responses_mean.pt")
#     ),
#     "resp_normalize_std": torch.load(
#         os.path.join(DATA_PATH, "responses_std.pt")
#     ),

In [None]:
### image stimuli
# Per-channel statistics of the training stimuli:
#   mean_inputs[c] -- average over all images of the per-image spatial mean
#   std_inputs[c]  -- average over all images of the per-image spatial std
#                     (note: NOT the std over the whole dataset)
# `inputs` is indexed as (batch, channel, height, width).
v1_dataloaders = get_cat_v1_dataloaders(**config["data"]["v1_data"])
dataloader = torch.utils.data.DataLoader(
    v1_dataloaders["train"].dataset,
    batch_size=config["data"]["v1_data"]["batch_size"],
    shuffle=True,
)

mean_inputs, std_inputs = None, None
num_images = 0
for inputs, _targets in dataloader:
    if mean_inputs is None:
        # Size the accumulators from the data instead of assuming a single
        # channel.
        num_channels = inputs.size(1)
        mean_inputs = torch.zeros(num_channels, device=inputs.device)
        std_inputs = torch.zeros(num_channels, device=inputs.device)
    # Reduce over the spatial dims -> (batch, channel), then sum over the
    # batch so every image carries the same weight even if the last batch is
    # smaller than the others.
    mean_inputs += inputs.mean((-1, -2)).sum(0)
    std_inputs += inputs.std((-1, -2)).sum(0)
    num_images += inputs.size(0)

if num_images == 0:
    raise ValueError("Training dataloader is empty; cannot compute stimulus statistics.")
mean_inputs.div_(num_images)
std_inputs.div_(num_images)
mean_inputs, std_inputs

In [None]:
### responses
# Running per-neuron statistics of the training responses, tracked for the
# full response vector and separately for its two populations. The first
# `num_exc` columns are treated as excitatory and the remaining `num_inh` as
# inhibitory (presumably matching the ("exc_resp", "inh_resp") order of
# `resp_keys` -- TODO confirm against the dataset implementation).
num_exc, num_inh = 37500, 9375

v1_dataloaders = get_cat_v1_dataloaders(**config["data"]["v1_data"])
dataloader = torch.utils.data.DataLoader(
    v1_dataloaders["train"].dataset,
    batch_size=1000,
    shuffle=True,
)

stats_all = RunningStats(num_components=num_exc + num_inh, lib="torch", device="cpu")
stats_exc = RunningStats(num_components=num_exc, lib="torch", device="cpu")
stats_inh = RunningStats(num_components=num_inh, lib="torch", device="cpu")

for i, (s, r) in enumerate(dataloader):
    stats_all.update(r)
    stats_exc.update(r[:, :num_exc])
    stats_inh.update(r[:, num_exc:])
    # periodic progress report: batch statistics next to the running ones
    if i % 200 == 0:
        print(f"{i}: {r.mean()=} {r.std()=} {stats_all.get_mean()=} {stats_all.get_std()=}")

### save
# (tracker, file for the mean, file for the std) -- written in this order
_stats_outputs = (
    (stats_all, "responses_mean.pt", "responses_std.pt"),
    (stats_exc, "responses_exc_mean.pt", "responses_exc_std.pt"),
    (stats_inh, "responses_inh_mean.pt", "responses_inh_std.pt"),
)
for _tracker, _mean_file, _std_file in _stats_outputs:
    torch.save(_tracker.get_mean(), os.path.join(DATA_PATH, _mean_file))
    torch.save(_tracker.get_std(), os.path.join(DATA_PATH, _std_file))