In [2]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from experanto.data import SimpleChunkedDataset, Mouse2pStaticImageDataset
from utils import MultiEpochsDataLoader, LongCycler

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
def get_n_dataloaders(N, bs, num_workers, persistent_workers, prefetch_factor, pin_memory,
                      dataset=None, loader_cls=None):
    """Build ``N`` independent, shuffling dataloaders over the same dataset.

    Used here to benchmark many loaders at once (they are cycled with ``LongCycler``).

    Parameters
    ----------
    N : int
        Number of dataloaders to create.
    bs : int
        Batch size of every loader.
    num_workers : int
        Worker processes per loader; ``0`` loads in the main process.
    persistent_workers : bool
        Keep workers alive between epochs. Only forwarded when ``num_workers > 0``
        because torch's DataLoader rejects ``persistent_workers=True`` without workers.
    prefetch_factor : int or None
        Batches prefetched per worker. Only forwarded when ``num_workers > 0``
        for the same reason.
    pin_memory : bool
        Return batches in page-locked host memory, which is what makes
        ``.to("cuda", non_blocking=True)`` asynchronous. Forwarded for every
        ``num_workers`` value (the previous version dropped it for ``num_workers == 0``).
    dataset : Dataset, optional
        Dataset to load from. Defaults to the notebook-global ``dataset`` so
        existing positional calls keep working.
    loader_cls : type, optional
        DataLoader class to instantiate. Defaults to ``MultiEpochsDataLoader``.

    Returns
    -------
    dict[int, DataLoader]
        Loaders keyed ``0 .. N-1``.
    """
    if dataset is None:
        # Backward compatibility: earlier calls relied on the notebook-global ``dataset``.
        try:
            dataset = globals()["dataset"]
        except KeyError:
            raise NameError(
                "get_n_dataloaders: pass dataset=... or define a global `dataset` first"
            ) from None
    if loader_cls is None:
        loader_cls = MultiEpochsDataLoader

    loader_kwargs = dict(
        batch_size=bs,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    if num_workers > 0:
        # These options are only valid when worker processes exist.
        loader_kwargs.update(
            persistent_workers=persistent_workers,
            prefetch_factor=prefetch_factor,
        )

    return {i: loader_cls(dataset, **loader_kwargs) for i in range(N)}

In [19]:
# Dataset location. Override it with the EXPERANTO_ROOT_FOLDER environment variable
# instead of editing the absolute path below.
# (Smaller sanity-check copy: ./dynamic29228-2-10-Video-sensorium23_sanitycheck_times_fixed/)
root_folder = os.environ.get(
    "EXPERANTO_ROOT_FOLDER",
    "/data/mouse_datasets/dynamic29228-2-10-Video-sensorium23_full_256_144/",
)
sampling_rate = 8  # Hz
chunk_size = 80  # samples per chunk, i.e. context length in samples per neuron
dataset = SimpleChunkedDataset(
    root_folder=root_folder,
    chunk_size=chunk_size,
    sampling_rate=sampling_rate,
)

In [20]:
datapoint = dataset[0]
# List every modality in one sample together with the shape of its array.
for k in datapoint:
    v = datapoint[k]
    print(k, v.shape)

responses (80, 7928)
screen (80, 1, 144, 256)
eye_tracker (80, 4)
treadmill (80, 1)
timestamps (80, 7928)


In [21]:
# Interpolating the data is not very fast, so use several worker processes (>= 4 recommended).
dataloader = MultiEpochsDataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    num_workers=12,
    prefetch_factor=1,
    persistent_workers=True,
    # Pass pin_memory=True to make later .to("cuda", non_blocking=True) copies asynchronous.
)
# One pass without touching the GPU: measures pure loading throughput.
for b in tqdm(dataloader):
    pass

100% 92/92 [00:05<00:00, 15.44it/s]


In [7]:
# Loading plus host-to-GPU transfer of the screen frames only.
for b in tqdm(dataloader):
    gpu_tensor = b["screen"].to(device="cuda", non_blocking=True)
    

100% 228/228 [00:13<00:00, 16.85it/s]


In [None]:
# Re-run of the screen-transfer benchmark (this run was interrupted; see traceback).
for b in tqdm(dataloader):
    gpu_tensor = b["screen"].to(device="cuda", non_blocking=True)
    

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.10/dist-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/usr/local/lib/python3.10/dist-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelapp.py", line 739, in start
    self.io_loop.start()
  File "/usr/local/lib/python3.10/dist-packages/tornado/platform/asyncio.py", line 205, in start
    self.asyncio_loop.run_forever()
  File "/usr/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
    self._run_once()
  File "/usr/lib/python3.10/asyncio/base_events.py", line 1871, in _run_once
    event_list = self._selector.select(timeout)
  File "/usr/lib/python3.10/sele

In [9]:
dls = get_n_dataloaders(N=100, bs=8, num_workers=1, persistent_workers=True, prefetch_factor=1, pin_memory=True)

In [10]:
# Short warm-up: cycle over all loaders and copy every modality to the GPU as bfloat16.
for i, (k, b) in tqdm(enumerate(LongCycler(dls))):
    _ = b["screen"].to(device="cuda", dtype=torch.bfloat16, non_blocking=True)
    _ = b["responses"].to(device="cuda", dtype=torch.bfloat16, non_blocking=True)
    _ = b["eye_tracker"].to(device="cuda", dtype=torch.bfloat16, non_blocking=True)
    _ = b["treadmill"].to(device="cuda", dtype=torch.bfloat16, non_blocking=True)
    if i >= 101:
        break

101it [00:02, 34.92it/s] 


In [11]:
# Long benchmark with pinned-memory loaders: ~10k batches, all modalities to GPU as bfloat16.
for i, (k, b) in tqdm(enumerate(LongCycler(dls))):
    _ = b["screen"].to(device="cuda", dtype=torch.bfloat16, non_blocking=True)
    _ = b["responses"].to(device="cuda", dtype=torch.bfloat16, non_blocking=True)
    _ = b["eye_tracker"].to(device="cuda", dtype=torch.bfloat16, non_blocking=True)
    _ = b["treadmill"].to(device="cuda", dtype=torch.bfloat16, non_blocking=True)
    if i >= 10_001:
        break

10001it [05:46, 28.90it/s]


In [1]:
10_001 / (5 * 60 + 46)  # iterations / seconds (5:46) from the 10k-batch tqdm run

28.904624277456648

In [12]:
dls = get_n_dataloaders(N=100, bs=8, num_workers=1, persistent_workers=True, prefetch_factor=1, pin_memory=False)

In [None]:
# Same long benchmark, but with loaders that do not pin memory (for comparison).
for i, (k, b) in tqdm(enumerate(LongCycler(dls))):
    _ = b["screen"].to(device="cuda", dtype=torch.bfloat16, non_blocking=True)
    _ = b["responses"].to(device="cuda", dtype=torch.bfloat16, non_blocking=True)
    _ = b["eye_tracker"].to(device="cuda", dtype=torch.bfloat16, non_blocking=True)
    _ = b["treadmill"].to(device="cuda", dtype=torch.bfloat16, non_blocking=True)
    if i >= 10_001:
        break

669it [00:37, 17.38it/s]