In [1]:
import os
import numpy as np
from datetime import datetime
import dill
import torch
from torch import nn
import torch.nn.functional as F
import lovely_tensors as lt
lt.monkey_patch()

import csng
from csng.models.inverted_encoder import InvertedEncoder, InvertedEncoderBrainreader
from csng.models.ensemble import EnsembleInvEnc
from csng.utils.mix import seed_all
from csng.utils.data import standardize, normalize, crop
from csng.utils.comparison import find_best_ckpt, load_decoder_from_ckpt, plot_reconstructions, plot_metrics, eval_decoder
from csng.losses import get_metrics
from csng.data import get_dataloaders, get_sample_data
from csng.brainreader_mouse.encoder import get_encoder as get_encoder_brainreader

In [2]:
### set paths
DATA_PATH = os.environ["DATA_PATH"]
DATA_PATH_CAT_V1 = os.path.join(DATA_PATH, "cat_V1_spiking_model", "50K_single_trial_dataset")
DATA_PATH_MOUSE_V1 = os.path.join(DATA_PATH, "mouse_v1_sensorium22")
DATA_PATH_BRAINREADER = os.path.join(DATA_PATH, "brainreader")

In [None]:
### global config
config = {
    "device": os.environ["DEVICE"],
    "seed": 0,
    "data": {
        "mixing_strategy": "sequential", # needed only with multiple base dataloaders
        "max_training_batches": None,
    },
    "crop_wins": dict(),
}

### brainreader mouse data
config["data"]["brainreader_mouse"] = {
    "device": config["device"],
    "mixing_strategy": config["data"]["mixing_strategy"],
    "max_batches": None,
    "data_dir": os.path.join(DATA_PATH_BRAINREADER, "data"),
    # "batch_size": 1,
    "batch_size": 16,
    # "sessions": list(range(1, 3)),
    "sessions": [6],
    "resize_stim_to": (36, 64),
    "normalize_stim": True,
    "normalize_resp": False,
    "div_resp_by_std": True,
    "clamp_neg_resp": False,
    "additional_keys": None,
    "avg_test_resp": True,

    # "neuron_coords": {
    #     "1": encoder.decoders[0].encoder.readout["6"].mu[0,:,0].detach().clone(),
    # }
}

config["data"]["mouse_v1"] = {
    "dataset_fn": "sensorium.datasets.static_loaders",
    "dataset_config": {
        "paths": [ # from https://gin.g-node.org/cajal/Sensorium2022/src/master
            os.path.join(DATA_PATH_MOUSE_V1, "static21067-10-18-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip"), # M-1
            # os.path.join(DATA_PATH_MOUSE_V1, "static22846-10-16-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip"), # M-2
            # os.path.join(DATA_PATH_MOUSE_V1, "static23343-5-17-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip"), # M-3
            # os.path.join(DATA_PATH_MOUSE_V1, "static23656-14-22-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip"), # M-4
            # os.path.join(DATA_PATH_MOUSE_V1, "static23964-4-22-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip"), # M-5
        ],
        "normalize": True,
        "scale": 0.25, # 256x144 -> 64x36
        "include_behavior": False,
        "add_behavior_as_channels": False,
        "include_eye_position": True,
        "exclude": None,
        "file_tree": True,
        "cuda": "cuda" in config["device"],
        "batch_size": 5,
        "seed": config["seed"],
        "use_cache": False,
    },
    "crop_win": (22, 36),
    "skip_train": False,
    "skip_val": False,
    "skip_test": False,
    "normalize_neuron_coords": True,
    "average_test_multitrial": True,
    "save_test_multitrial": True,
    "test_batch_size": 7,
    "device": config["device"],
}

# add crop_wins for brainreader mouse data
dls, neuro_coords = get_dataloaders(config=config)

In [5]:
b = next(iter(dls["val"]["brainreader_mouse"]))

In [70]:
_b

[{'data_key': '21067-10-18',
  'stim': tensor[5, 1, 36, 64] n=11520 x∈[-1.699, 2.261] μ=0.183 σ=1.052 cuda:0,
  'resp': tensor[5, 8372] n=41860 x∈[-5.978e-10, 21.945] μ=0.343 σ=1.016 cuda:0,
  'neuron_coords': tensor[8372, 3] n=25116 x∈[-1.000, 1.000] μ=0.002 σ=0.593 cuda:0,
  'pupil_center': tensor[5, 2] n=10 x∈[0.027, 0.847] μ=0.416 σ=0.263 cuda:0 [[0.472, 0.027], [0.847, 0.214], [0.356, 0.296], [0.364, 0.612], [0.777, 0.193]]}]

In [64]:
b

[{'data_key': '6',
  'stim': tensor[16, 1, 36, 64] n=36864 x∈[-2.292, 2.874] μ=-0.125 σ=1.186 cuda:0,
  'resp': tensor[16, 8587] n=137392 x∈[-1.083e-08, 33.292] μ=0.425 σ=1.045 cuda:0,
  'neuron_coords': tensor[8587, 2] n=17174 x∈[-0.983, 0.869] μ=0.123 σ=0.321 cuda:0,
  'pupil_center': None}]

In [80]:
enc_ckpt = torch.load(os.path.join(DATA_PATH, "models", "encoder_ball.pt"), pickle_module=dill)
enc_ckpt["model"].readout["1.mu"]

AttributeError: 'collections.OrderedDict' object has no attribute 'readout'

In [5]:
encoder = EnsembleInvEnc(
    encoder_paths=[
        os.path.join(DATA_PATH, "models", "encoder_ball.pt"),
    ],
    encoder_config={
        "img_dims": (1, 36, 64),
        "stim_pred_init": "randn",
        "lr": 1000,
        "n_steps": 1000,
        "img_grad_gauss_blur_sigma": 1.5,
        "jitter": None,
        "mse_reduction": "per_sample_mean_sum",
        "device": config["device"],
    },
    use_brainreader_encoder=True,
    get_encoder_fn=get_encoder_brainreader,
    device=config["device"],
)

[INFO] Loading encoder checkpoint from /media/jan/ext_ssd/csng_data/models/encoder_ball.pt




In [24]:
encoder.decoders[0].encoder.readout["1"].shared_grid

Parameter containing:
Parameter[1, 9395, 1, 2] n=18790 x∈[-1.000, 0.960] μ=-0.013 σ=0.240 grad cuda:0

In [46]:
encoder.decoders[0].encoder.readout["1"].mu.detach().cpu()

tensor[1, 9395, 1, 2] n=18790 x∈[-1.000, 0.960] μ=-0.013 σ=0.240

In [44]:
encoder.decoders[0].encoder.readout["1"].sample_grid(batch_size=1, sample=False) == encoder.decoders[0].encoder.readout["1"].mu

tensor[1, 9395, 1, 2] n=18790 x∈[1.000, 1.000] μ=1.000 σ=0. bool cuda:0