# <a id='toc1_'></a>[set up](#toc0_)

In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import sys
if sys.platform == 'linux':
    sys.path.append("/home/qix/MultiNeuronGLM")
else:
    sys.path.append("D:/Github/MultiNeuronGLM")

In [2]:
import pandas as pd
import utility_functions as utils
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import numpy as np
import random
import torch
import logging
import joblib

import GLM
from DataLoader import Allen_dataset, Allen_dataloader_multi_session, BatchIterator

# sns.set_theme()
sns.set_theme(style="white")
# sns.set_style('whitegrid')

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Set random seed for reproducibility

random.seed(0)
np.random.seed(0) 
torch.manual_seed(0)
if torch.cuda.is_available():
    torch.cuda.manual_seed(0)
    torch.cuda.manual_seed_all(0)

logging.basicConfig(
    level=logging.WARNING,
    format='%(asctime)s - %(levelname)s - %(message)s - [%(filename)s:%(lineno)d]'
)

# Download all sessions

In [146]:
# Download all sessions
session_ids = [
    715093703, 719161530, 721123822, 732592105, 737581020, 739448407,
    742951821, 743475441, 744228101, 746083955, 750332458, 750749662,
    751348571, 754312389, 754829445, 755434585, 756029989, 757216464,
    757970808, 758798717, 759883607, 760345702, 760693773, 761418226,
    762120172, 762602078, 763673393, 766640955, 767871931, 768515987,
    771160300, 771990200, 773418906, 774875821, 778240327, 778998620,
    779839471, 781842082, 786091066, 787025148, 789848216, 791319847,
    793224716, 794812542, 797828357, 798911424, 799864342, 816200189,
    819186360, 819701982, 821695405, 829720705, 831882777, 835479236,
    839068429, 839557629, 840012044, 847657808
]
allen_dataset = cross_session_dataloader.sessions[715093703]
for session_id in tqdm(session_ids):
    allen_dataset._cache.get_session_data(session_id)


Downloading: 100%|██████████| 2.86G/2.86G [01:33<00:00, 30.6MB/s]
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
Downloading: 100%|██████████| 3.07G/3.07G [16:12<00:00, 3.16MB/s]
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
Downloading: 100%|██████████| 1.74G/1.74G [00:57<00:00, 30.4MB/s]
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
Downloading: 100%|██████████| 2.91G/2.91G [01:29<00:00, 32.7MB/s]
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
Downloading: 100%|██████

# Load all sessions and save to hard drive

In [4]:
# Build a dataloader for cross-session data
session_ids = [
    715093703, 719161530, 721123822, 732592105, 737581020, 739448407,
    742951821, 743475441, 744228101, 746083955, 750332458, 750749662,
    751348571, 754312389, 754829445, 755434585, 756029989, 757216464,
    757970808, 758798717, 759883607, 760345702, 760693773, 761418226,
    762120172, 762602078, 763673393, 766640955, 767871931, 768515987,
    771160300, 771990200, 773418906, 774875821, 778240327, 778998620,
    779839471, 781842082, 786091066, 787025148, 789848216, 791319847,
    793224716, 794812542, 797828357, 798911424, 799864342, 816200189,
    819186360, 819701982, 821695405, 829720705, 831882777, 835479236,
    839068429, 839557629, 840012044, 847657808
]
kwargs = {
    'shuffle':True,
    'align_stimulus_onset':False, 
    'merge_trials':False, 
    'batch_size':64,
    'fps':500, 
    'start_time':0.0, 
    'end_time':0.4, 
    'padding':0.1, 
    'selected_probes':['probeA', 'probeB', 'probeC', 'probeD', 'probeE', 'probeF'], 
}
cross_session_dataloader = Allen_dataloader_multi_session(session_ids, **kwargs)

2025-01-22 00:52:46,120 - CRITICAL - Start loading data - [DataLoader.py:155]
  0%|          | 0/58 [00:00<?, ?it/s]

  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
100%|██████████| 58/58 [19:32<00:00, 20.21s/it]

Total sessions: 58, Batch size: 64, Train set size: 10178, Val set size: 1454, Test set size: 2908





In [5]:
# Save the object
if sys.platform == 'linux':
    data_path = '/home/qix/user_data/allen_spike_trains/cross_session_dataloader.joblib'
else:
    data_path = 'D:/ecephys_cache_dir/cross_session_dataloader.joblib'
joblib.dump(cross_session_dataloader, data_path)

['D:/ecephys_cache_dir/cross_session_dataloader.joblib']

# Make a toy dataset with two sessions for testing


In [4]:
# Build a dataloader for cross-session data
session_ids = [
    715093703, 719161530,
]
kwargs = {
    'shuffle':True,
    'align_stimulus_onset':True, 
    'merge_trials':True, 
    'batch_size':64,
    'fps':500, 
    'start_time':0.0, 
    'end_time':0.4, 
    'padding':0.1, 
    'selected_probes':['probeA', 'probeB', 'probeC', 'probeD', 'probeE', 'probeF'], 
    'stimulus_name': 'all',
}
cross_session_dataloader = Allen_dataloader_multi_session(session_ids, **kwargs)

2025-01-24 03:49:30,746 - CRITICAL - Start loading data - [DataLoader.py:155]
  0%|          | 0/2 [00:00<?, ?it/s]

100%|██████████| 2/2 [00:28<00:00, 14.09s/it]

Total sessions: 2, Batch size: 64, Train set size: 266, Val set size: 38, Test set size: 76





In [5]:
# Save the object
if sys.platform == 'linux':
    data_path = '/home/qix/user_data/allen_spike_trains/two_session_toy_dataloader.joblib'
else:
    data_path = 'D:/ecephys_cache_dir/two_session_toy_dataloader.joblib'
joblib.dump(cross_session_dataloader, data_path)

['/home/qix/user_data/allen_spike_trains/two_session_toy_dataloader.joblib']

In [11]:
for ibatch, batch in enumerate(cross_session_dataloader.train_loader):
    if ibatch == 5:
        break
    print(batch["session_id"])

715093703
719161530
719161530
719161530
719161530


[30, 50, 60, 42, 30, 46]

# Reload from hard drive

In [6]:
# Load the object
if sys.platform == 'linux':
    data_path = '/home/qix/user_data/allen_spike_trains/cross_session_dataloader.joblib'
else:
    data_path = 'D:/ecephys_cache_dir/cross_session_dataloader.joblib'
cross_session_dataloader = joblib.load(data_path)

In [8]:
for ibatch, batch in enumerate(cross_session_dataloader.train_loader):
    if ibatch == 5:
        break
    print(batch["session_id"])

739448407
831882777
774875821
831882777
831882777


In [9]:
batch

{'spike_trains': array([[[0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         ...,
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0]],
 
        [[0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         ...,
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 1, 0]],
 
        [[0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         ...,
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0]],
 
        ...,
 
        [[0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         ...,
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0]],
 
        [[0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0, ..., 0, 0, 0],
         [0, 0, 0

In [15]:
batch.keys()

dict_keys(['spike_trains', 'presentation_ids', 'neuron_id', 'session_id', 'nneuron_list'])

In [14]:
(
    batch["spike_trains"].shape, 
    batch["neuron_id"].shape, 
    batch["presentation_ids"].shape, 
    batch["session_id"],
    batch['nneuron_list'],
)

((250, 258, 64), (258,), (64,), 715093703, [30, 50, 60, 42, 30, 46])

In [12]:
cross_session_dataloader.common_kwargs

{'shuffle': True,
 'align_stimulus_onset': False,
 'merge_trials': False,
 'batch_size': 64,
 'fps': 500,
 'start_time': 0.0,
 'end_time': 0.4,
 'padding': 0.1,
 'selected_probes': ['probeA',
  'probeB',
  'probeC',
  'probeD',
  'probeE',
  'probeF']}

In [15]:
# Control experiment: don't have low resolution but only original resolution
for ibatch, batch in enumerate(cross_session_dataloader.train_loader):
    if ibatch == 5:
        break
    batch["low_res_spike_trains"] = utils.change_temporal_resolution_single(batch["spike_trains"], 10)