In [None]:
import os
import random
import torch
from nnfabrik.builder import get_data

import numpy as np
import matplotlib.pyplot as plt
import lovely_tensors as lt

# Install lovely_tensors' patched tensor reprs for more readable tensor printing.
lt.monkey_patch()

# Root directory of the Sensorium 2022 data. Fail early with an actionable message instead of a
# bare KeyError when the variable is missing from the kernel's environment
# (export it before launching Jupyter, or set it with `%env DATA_PATH=...`).
if "DATA_PATH" not in os.environ:
    raise KeyError(
        "Environment variable DATA_PATH is not set; point it at the directory "
        "that contains 'mouse_v1_sensorium22'."
    )
DATA_PATH = os.path.join(os.environ["DATA_PATH"], "mouse_v1_sensorium22")
print(f"{DATA_PATH=}")

In [None]:
# Run-wide settings: use the GPU whenever PyTorch can see one, and a single seed reused below.
config = dict(
    device="cuda" if torch.cuda.is_available() else "cpu",
    seed=0,
)

print(f"... Running on {config['device']} ...")

In [None]:
# Seed every RNG used downstream from the one config value so runs are reproducible:
# Python's `random`, NumPy's legacy global RNG, and PyTorch.
random.seed(config["seed"])
np.random.seed(config["seed"])
# The trailing ';' keeps the returned torch.Generator's repr out of the cell output.
torch.manual_seed(config["seed"]);

## Data

In [None]:
# Sensorium 2022 static recordings, from https://gin.g-node.org/cajal/Sensorium2022/src/master.
# Uncomment the recording(s) to load; `data_key` below is derived from the FIRST active entry.
filenames = [
    # "static26872-17-20-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip", # mouse 1
    # "static27204-5-13-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip", # sensorium+ (mouse 2)
    # "static21067-10-18-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip", # pretraining (mouse 3)
    # "static22846-10-16-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip", # pretraining (mouse 4)
    "static23343-5-17-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip", # pretraining (mouse 5)
    # "static23656-14-22-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip", # pretraining (mouse 6)
    # "static23964-4-22-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip", # pretraining (mouse 7)
]
assert filenames, "Uncomment at least one recording in `filenames`."
# Absolute paths, built as a new list rather than by index-assigning into the old one.
filenames = [os.path.join(DATA_PATH, f_name) for f_name in filenames]

config["data"] = {
    "paths": filenames,
    "dataset_fn": "sensorium.datasets.static_loaders",
    "dataset_config": {
        "paths": filenames,
        "normalize": True,
        "scale": 0.25, # 256x144 -> 64x36
        "include_behavior": False,
        "add_behavior_as_channels": False,
        "include_eye_position": False,
        "exclude": None,
        "file_tree": True,
        "cuda": False,
        "batch_size": 128,
        "seed": config["seed"],
    },
}

# Derive the dataset key from the first selected file so it cannot drift out of sync with it:
# the first three "-"-separated fields after the "static" prefix,
# e.g. "static23343-5-17-GrayImageNet-....zip" -> "23343-5-17".
data_key = "-".join(
    os.path.basename(filenames[0]).replace("static", "", 1).split("-")[:3]
)
_dataloaders = get_data(config["data"]["dataset_fn"], config["data"]["dataset_config"])
# Readable failure if the loader files this recording under a different key.
assert data_key in _dataloaders["train"], (
    f"{data_key=} not found; loaded keys: {sorted(_dataloaders['train'])}"
)
dataloaders = {
    "mouse_v1": {
        "train": _dataloaders["train"][data_key],
        "val": _dataloaders["validation"][data_key],
        "test": _dataloaders["test"][data_key],
        "test_no_resp": _dataloaders["final_test"][data_key],
    }
}

In [None]:
### show data
# Summarize one validation batch (stimuli, responses) and the neuron coordinates, then plot
# the first sample's stimulus next to its responses.
val_loader = dataloaders["mouse_v1"]["val"]
stim, resp = next(iter(val_loader))
cell_motor_coordinates = dataloaders["mouse_v1"]["train"].dataset.neurons.cell_motor_coordinates

# Count what this loader actually yields. For a plain DataLoader the default sampler spans the
# whole dataset, so this equals len(dataset). If the tier loaders share one dataset and pick
# their tier through a subset sampler (presumably the case for the sensorium loaders -
# confirm), len(val_loader.dataset) would count every tier, not just validation.
n_val_samples = len(val_loader.sampler)

print(
    f"{n_val_samples} validation samples"
    "\n\nstimuli:"
    f"\n  {stim.shape}"
    f"\n  min={stim.min().item():.3f}  max={stim.max().item():.3f}"
    f"\n  mean={stim.mean().item():.3f}  std={stim.std().item():.3f}"
    "\nresponses:"
    f"\n  {resp.shape}"
    f"\n  min={resp.min().item():.3f}  max={resp.max().item():.3f}"
    f"\n  mean={resp.mean().item():.3f}  std={resp.std().item():.3f}"
    "\ncell coordinates:"
    f"\n  {cell_motor_coordinates.shape}"
    f"\n  min={cell_motor_coordinates.min():.3f}  max={cell_motor_coordinates.max():.3f}"
    f"\n  mean={cell_motor_coordinates.mean():.3f}  std={cell_motor_coordinates.std():.3f}"
)

# Lay the first sample's responses out as the most nearly square n_rows x n_cols grid that
# tiles them exactly (7334 neurons -> 38 x 193), instead of hard-coding one recording's size.
n_neurons = resp[0].numel()
n_rows = max((r for r in range(1, int(np.sqrt(n_neurons)) + 1) if n_neurons % r == 0), default=1)
n_cols = n_neurons // n_rows

fig, (ax_stim, ax_resp) = plt.subplots(1, 2, figsize=(16, 6))
ax_stim.imshow(stim[0].squeeze().cpu(), cmap="gray")
ax_stim.set_title("stimulus (first validation sample)")
ax_resp.imshow(resp[0].reshape(n_rows, n_cols).cpu(), cmap="gray")
ax_resp.set_title(f"responses (first validation sample): {n_neurons} neurons as {n_rows}x{n_cols}")

plt.show()