# <a id='toc1_'></a>[set up](#toc0_)

In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import sys
# Make the local MultiNeuronGLM checkout importable; its location depends on the OS
# (Linux machine vs. everything else, which is assumed to be the Windows checkout).
sys.path.append(
    "/home/qix/MultiNeuronGLM" if sys.platform == 'linux' else "D:/Github/MultiNeuronGLM"
)

In [4]:
import pandas as pd
import utility_functions as utils
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import numpy as np
import random
import torch
import logging
import joblib

import GLM
from DataLoader import Allen_dataset, Allen_dataloader_multi_session

# Apply seaborn's "white" style globally to every figure in this notebook.
# Alternatives kept for reference: sns.set_theme() (defaults), sns.set_style('whitegrid').
sns.set_theme(style="white")

In [3]:
# Seed every random number generator used in this notebook (Python, NumPy, PyTorch)
# with the same fixed value so repeated runs are reproducible.
for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    seed_rng(0)

# The CUDA generators are only seeded when a GPU is actually present.
if torch.cuda.is_available():
    for seed_gpu in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_gpu(0)

# Root logger: WARNING and above, each record tagged with the emitting file and line.
log_format = '%(asctime)s - %(levelname)s - %(message)s - [%(filename)s:%(lineno)d]'
logging.basicConfig(level=logging.WARNING, format=log_format)

# Load all sessions

In [5]:
# Assemble a single dataloader that spans all 58 recording sessions listed here.
# The captured run below shows this step taking over an hour (download + per-session load).
session_ids = [
    715093703, 719161530, 721123822, 732592105, 737581020, 739448407,
    742951821, 743475441, 744228101, 746083955, 750332458, 750749662,
    751348571, 754312389, 754829445, 755434585, 756029989, 757216464,
    757970808, 758798717, 759883607, 760345702, 760693773, 761418226,
    762120172, 762602078, 763673393, 766640955, 767871931, 768515987,
    771160300, 771990200, 773418906, 774875821, 778240327, 778998620,
    779839471, 781842082, 786091066, 787025148, 789848216, 791319847,
    793224716, 794812542, 797828357, 798911424, 799864342, 816200189,
    819186360, 819701982, 821695405, 829720705, 831882777, 835479236,
    839068429, 839557629, 840012044, 847657808,
]

# Options forwarded verbatim to the multi-session loader.
# NOTE(review): start_time / end_time / padding are presumably seconds and fps a
# sampling rate in Hz — confirm against DataLoader.py.
kwargs = dict(
    shuffle=False,
    align_stimulus_onset=False,
    merge_trials=False,
    batch_size=64,
    fps=1000,
    start_time=0.0,
    end_time=0.4,
    padding=0.1,
    # Expands to ['probeA', 'probeB', 'probeC', 'probeD', 'probeE', 'probeF']
    selected_probes=[f'probe{letter}' for letter in 'ABCDEF'],
)

cross_session_dataloader = Allen_dataloader_multi_session(session_ids, **kwargs)

2025-01-15 16:47:08,049 - CRITICAL - Total number of sessions: 58 - [DataLoader.py:91]
2025-01-15 16:47:08,050 - CRITICAL - Train ratio: 0.7 - [DataLoader.py:92]
2025-01-15 16:47:08,051 - CRITICAL - Val ratio: 0.1 - [DataLoader.py:93]
2025-01-15 16:47:08,052 - CRITICAL - Test ratio: 0.20000000000000004 - [DataLoader.py:94]
2025-01-15 16:47:08,053 - CRITICAL - Batch size: 64 - [DataLoader.py:95]
2025-01-15 16:47:08,053 - CRITICAL - Start loading data - [DataLoader.py:116]
Downloading: 100%|██████████| 2.52G/2.52G [03:52<00:00, 10.8MB/s]
100%|██████████| 58/58 [1:05:32<00:00, 67.80s/it] 


In [27]:
# Persist the fully built dataloader to disk so later runs can skip rebuilding it
# (the build cell above took over an hour in the captured run).
# NOTE(review): this writes to the current working directory, whereas the load cell
# below reads from '/home/qix/user_data/allen_spike_trains/' — confirm the file is
# moved or copied there.
joblib.dump(cross_session_dataloader, 'cross_session_dataloader.joblib')

['cross_session_dataloader.joblib']

In [5]:
# Restore the previously serialized dataloader instead of rebuilding it from raw sessions.
# NOTE(review): the save cell above writes 'cross_session_dataloader.joblib' relative to
# the working directory, while this reads an absolute, machine-specific path — confirm
# the two refer to the same file.
# NOTE: joblib.load unpickles the file's contents; only load files from a trusted source.
data_path = '/home/qix/user_data/allen_spike_trains/cross_session_dataloader.joblib'
cross_session_dataloader = joblib.load(data_path)

KeyboardInterrupt: 

In [6]:
# Sanity check: display the dataloader object currently held in the kernel (default repr)
cross_session_dataloader

<DataLoader.Allen_dataloader_multi_session at 0x794d0fd23b80>

In [14]:
# Inspect the per-session index pairs. In the output below each pair begins where the
# previous one ends — presumably (start, end) trial offsets per session; confirm in DataLoader.py.
cross_session_dataloader.session_trial_indices

[(0, 15871),
 (15871, 31734),
 (31734, 47599),
 (47599, 63463),
 (63463, 79334),
 (79334, 95207),
 (95207, 111074),
 (111074, 126944),
 (126944, 142818),
 (142818, 158693),
 (158693, 174562),
 (174562, 190429),
 (190429, 206300),
 (206300, 222177),
 (222177, 238051),
 (238051, 253930),
 (253930, 269809),
 (269809, 285686),
 (285686, 301563),
 (301563, 317441),
 (317441, 333318),
 (333318, 349192),
 (349192, 365068),
 (365068, 381934),
 (381934, 398799),
 (398799, 415665),
 (415665, 432531),
 (432531, 448661),
 (448661, 464792),
 (464792, 480664),
 (480664, 496532),
 (496532, 512402),
 (512402, 528271),
 (528271, 544142),
 (544142, 560018),
 (560018, 575887),
 (575887, 591760),
 (591760, 607631),
 (607631, 623502),
 (623502, 639367),
 (639367, 655239),
 (655239, 672102),
 (672102, 687970),
 (687970, 703844),
 (703844, 720703),
 (720703, 737565),
 (737565, 754429),
 (754429, 770299),
 (770299, 786170),
 (786170, 802042),
 (802042, 817914),
 (817914, 833784),
 (833784, 849657),
 (849657, 

In [16]:
# Peek at the first training batch: the output is an array of 64 indices, matching batch_size=64
cross_session_dataloader.train_batches[0]

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
       34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
       51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63])

In [24]:
# Request a batch from the training split and keep it for inspection in the next cell
batch = cross_session_dataloader.get_batch(split='train')

In [19]:
# Look at the second element of the batch; per the output below it is a dict holding
# a 'spike_train' table (one row per unit), a 'session_id' and a 'trial_idx'
batch[1]

{'spike_train':                                                            1
 units                                                       
 950911880  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
 950911873  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
 950911932  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
 950911986  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
 950912018  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
 ...                                                      ...
 950956911  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
 950956870  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
 950956845  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
 950956952  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
 950957053  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
 
 [258 rows x 1 columns],
 'session_id': 715093703,
 'trial_idx': 1}

# (Appendix) Check run time and low-level functions

In [22]:
# Run-time check: issue 64 single-trial 'count' metric queries against one session.
# Only the result of the final iteration is kept in `spike_train`.
count_per_unit = cross_session_dataloader.sessions[757970808].get_trial_metric_per_unit_per_trial
for i in range(64):
    spike_train = count_per_unit(selected_trials=[i], metric_type='count')

In [11]:
# Run-time check: fetch the raw spike table for each of 64 trials, one trial per call.
# Only the table from the final iteration is kept in `spike_times`.
fetch_spike_table = cross_session_dataloader.sessions[757970808].get_spike_table
for i in range(64):
    spike_times = fetch_spike_table([i])

In [12]:
# Display the spike table left over from the last iteration of the loop above
# (the output shows stimulus_presentation_id 63, i.e. the final value of i)
spike_times

Unnamed: 0_level_0,stimulus_presentation_id,unit_id,time_since_stimulus_presentation_onset
spike_time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
24.501424,63,951838019,-36.324640
24.501661,63,951841367,-36.324402
24.501955,63,951839810,-36.324109
24.502690,63,951837931,-36.323373
24.502890,63,951837953,-36.323173
...,...,...,...
24.994424,63,951837947,-35.831640
24.995390,63,951838062,-35.830673
24.999258,63,951839099,-35.826806
24.999824,63,951837975,-35.826240
