In [1]:
import os
# Expose only GPU 6 to this process. CUDA reads this variable when it is first
# initialised, so it is set here, before `import torch` further down this cell.
# NOTE(review): the GPU index is host-specific; adjust it for your machine.
os.environ['CUDA_VISIBLE_DEVICES'] = '6'
%matplotlib inline
# Automatically re-import edited modules (e.g. experanto) before each cell runs.
%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
from os import path

from torch import nn
from torch.nn import functional as F
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel
from tqdm import tqdm
from einops import rearrange
from torch.optim import AdamW, Adam

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from omegaconf import OmegaConf, open_dict
from experanto.datasets import ChunkDataset, SimpleChunkedDataset
from experanto.utils import LongCycler, MultiEpochsDataLoader

In [2]:
# Float32 matmul precision. `torch.set_float32_matmul_precision` and
# `torch.backends.cuda.matmul.allow_tf32` control the same underlying setting,
# so only one of them is set here. 'high' lets CUDA matmuls use TF32 where the
# GPU supports it; use 'highest' instead to force full-FP32 matmuls
# (the equivalent of `allow_tf32 = False`).
torch.set_float32_matmul_precision('high')

# Hyperparameters

In [3]:
# Target size for the screen Resize transform (presumably [height, width]).
video_size = [36, 64]
batchsize = 16

# Chunking / sampling: the screen settings are reused for the eye tracker and
# treadmill modalities; `chunk_size` is used for the responses modality.
screen_chunk_size = 30
screen_fs = 30
chunk_size = 8

# Dataset options, copied into cfg.dataset when the dataloader is built.
behavior_as_channels = True
replace_nans_with_means = True

# Attention-model hyperparameters (presumably for a model built later; they are
# not referenced in the cells shown here).
dim_head, num_heads = 64, 2
drop_path_rate = 0
mlp_ratio = 4

### Get dataloaders

In [4]:
from experanto.dataloaders import get_multisession_dataloader

from experanto.configs import DEFAULT_CONFIG as cfg
# Root folder holding one sub-folder per recording session.
# NOTE(review): machine-specific absolute path; change it on other hosts.
DATA_ROOT = "/data/mouse_polly/"

# Session folder names, relative to DATA_ROOT.
paths = ['dynamic29513-3-5-Video-full',
         'dynamic29514-2-9-Video-full',
         'dynamic29755-2-8-Video-full',
         'dynamic29647-19-8-Video-full',
         'dynamic29156-11-10-Video-full',
         'dynamic29623-4-9-Video-full',
         'dynamic29515-10-12-Video-full',
         'dynamic29234-6-9-Video-full',
         'dynamic29712-5-9-Video-full',
         'dynamic29228-2-10-Video-full'
        ]
full_paths = [path.join(DATA_ROOT, f) for f in paths]

In [5]:
import copy

# `cfg` is experanto's shared module-level DEFAULT_CONFIG. Override a private
# deep copy so these settings do not mutate the library default for the rest of
# the kernel; re-running the import cell gives back the pristine defaults.
cfg = copy.deepcopy(cfg)

# Sampling rate for the `responses` modality.
responses_fs = 8

# Set inside open_dict so these keys can be added even if the config is in
# struct mode and does not define them yet.
with open_dict(cfg):
    cfg.dataset.add_behavior_as_channels = behavior_as_channels
    cfg.dataset.replace_nans_with_means = replace_nans_with_means
cfg.dataset.global_chunk_size = None
cfg.dataset.global_sampling_rate = None

# Per-modality chunk sizes and sampling rates; eye tracker and treadmill reuse
# the screen settings.
cfg.dataset.modality_config.screen.chunk_size = screen_chunk_size
cfg.dataset.modality_config.screen.sampling_rate = screen_fs
cfg.dataset.modality_config.responses.chunk_size = chunk_size
cfg.dataset.modality_config.responses.sampling_rate = responses_fs
cfg.dataset.modality_config.eye_tracker.chunk_size = screen_chunk_size
cfg.dataset.modality_config.eye_tracker.sampling_rate = screen_fs
cfg.dataset.modality_config.treadmill.chunk_size = screen_chunk_size
cfg.dataset.modality_config.treadmill.sampling_rate = screen_fs

# Screen sampling: stride 1, blanks included, train tier only, frames resized
# to video_size.
cfg.dataset.modality_config.screen.sample_stride = 1
cfg.dataset.modality_config.screen.include_blanks = True
cfg.dataset.modality_config.screen.valid_condition = {"tier": "train"}
cfg.dataset.modality_config.screen.transforms.Resize.size = video_size

cfg.dataloader.num_workers = 2
cfg.dataloader.prefetch_factor = 2
cfg.dataloader.batch_size = batchsize
cfg.dataloader.pin_memory = False
cfg.dataloader.shuffle = True

# Only the first session is loaded for this inspection.
train_dl = get_multisession_dataloader(full_paths[:1], cfg)



In [6]:
# Draw one item from the loader; it unpacks into the session key and `v`
# (presumably that session's batch).
first_item = next(iter(train_dl))
session_key, v = first_item

In [8]:
# Session the first item came from (only one session is loaded above).
session_key

'29513-3-5'

In [15]:
# Statistics held by the session's dataset (a private attribute; indexed below
# as stats1[device]["mean"/"std"]).
session_loader = train_dl.loaders[session_key]
stats1 = session_loader.dataset._statistics

In [16]:
# Underlying dataset of this session's loader, used to reach its devices below.
loaders_by_session = train_dl.loaders
dataset = loaders_by_session[session_key].dataset

In [22]:
# On-disk folder backing the `responses` device of this session.
dataset._experiment.devices["responses"].root_folder

PosixPath('/data/mouse_polly/dynamic29513-3-5-Video-full/responses')

In [65]:
# Load the std/mean arrays stored in the device's meta/ folder, to compare them
# with the statistics the dataset reports in stats1.
device_name = "responses"
meta_dir = dataset._experiment.devices[device_name].root_folder / "meta"
stds = np.load(meta_dir / "stds.npy")
means = np.load(meta_dir / "means.npy")

In [66]:
# Stds from meta/stds.npy (shape (1, N) in the output).
stds

array([[ 8.14456711,  6.38401159,  8.73280993, ..., 63.49786429,
        12.78671866, 22.40623612]])

In [67]:
# Means from meta/means.npy; these are non-zero.
means

array([[ 2.11708711,  1.82554417,  2.2468886 , ..., 10.67869377,
         2.29983346,  3.75235768]])

In [68]:
# Std the dataset reports for this device. In the recorded output it differs
# from meta/stds.npy (e.g. 7.80 vs 8.14 for the first entry).
# NOTE(review): presumably recomputed over a different subset of the data
# (e.g. the train tier only); confirm in experanto.
stats1[device_name]["std"]

array([[ 7.80178436,  6.34071646,  7.64709398, ..., 53.70008149,
        11.90085297, 19.61641034]])

In [69]:
# Mean the dataset reports for this device. In the recorded output it is all
# zeros, unlike meta/means.npy.
# NOTE(review): suggests responses are only std-scaled, not mean-centred, for
# this modality; verify against experanto's normalisation.
stats1[device_name]["mean"]

array([[0., 0., 0., ..., 0., 0., 0.]])