# <a id='toc1_'></a>[set up](#toc0_)

In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import sys
if sys.platform == 'linux':
    sys.path.append("/home/qix/MultiNeuronGLM")
else:
    sys.path.append("D:/Github/MultiNeuronGLM")

In [4]:
import pandas as pd
import utility_functions as utils
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import numpy as np
import random
import torch
import logging
import joblib

import GLM
from DataLoader import Allen_dataset, Allen_dataloader_multi_session

# sns.set_theme()
sns.set_theme(style="white")
# sns.set_style('whitegrid')

In [3]:
# Set random seed for reproducibility

random.seed(0)
np.random.seed(0) 
torch.manual_seed(0)
if torch.cuda.is_available():
    torch.cuda.manual_seed(0)
    torch.cuda.manual_seed_all(0)

logging.basicConfig(
    level=logging.WARNING,
    format='%(asctime)s - %(levelname)s - %(message)s - [%(filename)s:%(lineno)d]'
)

# Load all sessions

In [5]:
# Build a dataloader for cross-session data
session_ids = [
    715093703, 719161530, 721123822, 732592105, 737581020, 739448407,
    742951821, 743475441, 744228101, 746083955, 750332458, 750749662,
    751348571, 754312389, 754829445, 755434585, 756029989, 757216464,
    757970808, 758798717, 759883607, 760345702, 760693773, 761418226,
    762120172, 762602078, 763673393, 766640955, 767871931, 768515987,
    771160300, 771990200, 773418906, 774875821, 778240327, 778998620,
    779839471, 781842082, 786091066, 787025148, 789848216, 791319847,
    793224716, 794812542, 797828357, 798911424, 799864342, 816200189,
    819186360, 819701982, 821695405, 829720705, 831882777, 835479236,
    839068429, 839557629, 840012044, 847657808
]
kwargs = {
    'shuffle':False,
    'align_stimulus_onset':False, 
    'merge_trials':False, 
    'batch_size':64,
    'fps':1000, 
    'start_time':0.0, 
    'end_time':0.4, 
    'padding':0.1, 
    'selected_probes':['probeA', 'probeB', 'probeC', 'probeD', 'probeE', 'probeF'], 
}
cross_session_dataloader = Allen_dataloader_multi_session(session_ids, **kwargs)

2025-01-15 16:47:08,049 - CRITICAL - Total number of sessions: 58 - [DataLoader.py:91]
2025-01-15 16:47:08,050 - CRITICAL - Train ratio: 0.7 - [DataLoader.py:92]
2025-01-15 16:47:08,051 - CRITICAL - Val ratio: 0.1 - [DataLoader.py:93]
2025-01-15 16:47:08,052 - CRITICAL - Test ratio: 0.20000000000000004 - [DataLoader.py:94]
2025-01-15 16:47:08,053 - CRITICAL - Batch size: 64 - [DataLoader.py:95]
2025-01-15 16:47:08,053 - CRITICAL - Start loading data - [DataLoader.py:116]
Downloading: 100%|██████████| 2.52G/2.52G [03:52<00:00, 10.8MB/s]
100%|██████████| 58/58 [1:05:32<00:00, 67.80s/it] 


In [None]:
# Save the object
joblib.dump(cross_session_dataloader, 'cross_session_dataloader.joblib')

['cross_session_dataloader.joblib']

In [5]:
# Load the object
if sys.platform == 'linux':
    data_path = '/home/qix/user_data/allen_spike_trains/cross_session_dataloader.joblib'
else:
    data_path = 'D:/ecephys_cache_dir/cross_session_dataloader.joblib'
cross_session_dataloader = joblib.load(data_path)

In [7]:
cross_session_dataloader.session_trial_indices

[(0, 15871),
 (15871, 31734),
 (31734, 47599),
 (47599, 63463),
 (63463, 79334),
 (79334, 95207),
 (95207, 111074),
 (111074, 126944),
 (126944, 142818),
 (142818, 158693),
 (158693, 174562),
 (174562, 190429),
 (190429, 206300),
 (206300, 222177),
 (222177, 238051),
 (238051, 253930),
 (253930, 269809),
 (269809, 285686),
 (285686, 301563),
 (301563, 317441),
 (317441, 333318),
 (333318, 349192),
 (349192, 365068),
 (365068, 381934),
 (381934, 398799),
 (398799, 415665),
 (415665, 432531),
 (432531, 448661),
 (448661, 464792),
 (464792, 480664),
 (480664, 496532),
 (496532, 512402),
 (512402, 528271),
 (528271, 544142),
 (544142, 560018),
 (560018, 575887),
 (575887, 591760),
 (591760, 607631),
 (607631, 623502),
 (623502, 639367),
 (639367, 655239),
 (655239, 672102),
 (672102, 687970),
 (687970, 703844),
 (703844, 720703),
 (720703, 737565),
 (737565, 754429),
 (754429, 770299),
 (770299, 786170),
 (786170, 802042),
 (802042, 817914),
 (817914, 833784),
 (833784, 849657),
 (849657, 

In [29]:
cross_session_dataloader.shuffle = True
cross_session_dataloader._split_data()

In [31]:
cross_session_dataloader.train_batches[0]

array([564754, 564755, 564756, 564757, 564758, 564759, 564760, 564761,
       564762, 564763, 564764, 564765, 564766, 564767, 564768, 564769,
       564770, 564771, 564772, 564773, 564774, 564775, 564776, 564777,
       564778, 564779, 564780, 564781, 564782, 564783, 564784, 564785,
       564786, 564787, 564788, 564789, 564790, 564791, 564792, 564793,
       564794, 564795, 564796, 564797, 564798, 564799, 564800, 564801,
       564802, 564803, 564804, 564805, 564806, 564807, 564808, 564809,
       564810, 564811, 564812, 564813, 564814, 564815, 564816, 564817])

In [20]:
batch = cross_session_dataloader.get_batch(split='train')

Get next batch


In [13]:
len(cross_session_dataloader.train_batches)*3/3600

8.4675

In [11]:
batch[1]

{'spike_train':                                                          193
 units                                                       
 950911880  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
 950911873  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
 950911932  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
 950911986  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
 950912018  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
 ...                                                      ...
 950956911  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
 950956870  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
 950956845  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
 950956952  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
 950957053  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
 
 [258 rows x 1 columns],
 'session_id': 715093703,
 'trial_idx': 193}

# (Appendix) Check run time and bottom functions

In [64]:
spikes_table = cross_session_dataloader.sessions[757970808].get_spike_table_optimized(list(np.arange(1)))
spikes_table

Unnamed: 0,stimulus_presentation_id,unit_id,spike_time,time_since_stimulus_presentation_onset
227,0,951837953,24.478824,-0.097202
82,0,951836777,24.478887,-0.097138
279,0,951838062,24.481290,-0.094735
250,0,951837987,24.481624,-0.094402
341,0,951839806,24.488022,-0.088004
...,...,...,...,...
6,0,951840998,24.970495,0.394469
154,0,951837066,24.970554,0.394528
16,0,951841172,24.970695,0.394669
224,0,951837931,24.971857,0.395831


In [65]:
spikes_table = cross_session_dataloader.sessions[757970808].get_spike_table(list(np.arange(1)))
spikes_table

Unnamed: 0_level_0,stimulus_presentation_id,unit_id,time_since_stimulus_presentation_onset
spike_time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
24.478824,0,951837953,-0.097202
24.478887,0,951836777,-0.097138
24.481290,0,951838062,-0.094735
24.481624,0,951837987,-0.094402
24.488022,0,951839806,-0.088004
...,...,...,...
24.970495,0,951840998,0.394469
24.970554,0,951837066,0.394528
24.970695,0,951841172,0.394669
24.971857,0,951837931,0.395831


In [45]:
spike_train =cross_session_dataloader.sessions[757970808].get_trial_metric_per_unit_per_trial_test(
    selected_trials=list(np.arange(64)), 
)

In [47]:
spike_train

Unnamed: 0_level_0,24,33,36,44,53,57,0,1,3,7,...,41,42,45,46,54,61,62,63,10,28
units,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
951841982,"[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",,,,,...,,,,,,,,,,
951841977,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",,"[0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",...,,,,,,,,,,
951841010,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ...","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, ...",,"[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, ...","[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",,
951840998,"[0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, ...","[0, 2, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, ...","[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, ...","[0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",...,"[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, ...","[0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, ...","[0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, ...","[0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, ...","[0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
951841002,,,,,,,,,,,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
951840074,,,,,,,,,,,...,,,,,,,,,,
951839907,,,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",,,,,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",,,...,,"[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",,,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",,,,,
951839916,,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",,,,,,...,,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",,,,,,,,"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, ..."
951839940,,,"[0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",,,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[3, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 3, 0, 0, ...",...,"[0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, ...","[0, 1, 0, 0, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, ...","[1, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...","[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


In [22]:
for i in range(64):
    spike_train =cross_session_dataloader.sessions[757970808].get_trial_metric_per_unit_per_trial(
        selected_trials=[i], 
        metric_type='count'
    )

In [11]:
for i in range(64):
    spike_times = cross_session_dataloader.sessions[757970808].get_spike_table([i])

In [12]:
spike_times

Unnamed: 0_level_0,stimulus_presentation_id,unit_id,time_since_stimulus_presentation_onset
spike_time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
24.501424,63,951838019,-36.324640
24.501661,63,951841367,-36.324402
24.501955,63,951839810,-36.324109
24.502690,63,951837931,-36.323373
24.502890,63,951837953,-36.323173
...,...,...,...
24.994424,63,951837947,-35.831640
24.995390,63,951838062,-35.830673
24.999258,63,951839099,-35.826806
24.999824,63,951837975,-35.826240
