# <a id='toc1_'></a>[set up](#toc0_)

In [1]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Enable autoreload (mode 2) so that edits to imported modules are picked up
# before each cell runs, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# import my code
import GLM
from DataLoader import Allen_dataset, Allen_dataloader_multi_session
import utility_functions as utils

# Fix the random seed up front so that repeated runs of this notebook are reproducible.
SEED = 0
utils.set_seed(SEED)

  from .autonotebook import tqdm as notebook_tqdm


# Load all sessions

In [3]:
### Choose the sessions for the cross-session dataloader
# This example keeps things light by using only two sessions.
# Swap in any session ids you like; larger presets are provided
# (commented out) further down in this cell.
session_ids = [
    757216464,
    715093703,
]

# Sessions with all 6 probes; if you need this, just uncomment below
# session_ids = [
#     715093703, 719161530, 721123822, 737581020, 739448407, 742951821, 743475441,
#     744228101, 746083955, 750332458, 750749662, 751348571, 754312389, 755434585,
#     756029989, 757216464, 760693773, 761418226, 762602078, 763673393, 766640955,
#     767871931, 768515987, 771160300, 771990200, 773418906, 774875821, 778240327,
#     778998620, 779839471, 781842082, 786091066, 787025148, 789848216, 791319847,
#     793224716, 794812542, 797828357, 798911424, 799864342, 831882777, 839068429,
#     840012044, 847657808
# ]

# All sessions provided in Allen Institute dataset
# session_ids = [
#     715093703, 719161530, 721123822, 732592105, 737581020, 739448407,
#     742951821, 743475441, 744228101, 746083955, 750332458, 750749662,
#     751348571, 754312389, 754829445, 755434585, 756029989, 757216464,
#     757970808, 758798717, 759883607, 760345702, 760693773, 761418226,
#     762120172, 762602078, 763673393, 766640955, 767871931, 768515987,
#     771160300, 771990200, 773418906, 774875821, 778240327, 778998620,
#     779839471, 781842082, 786091066, 787025148, 789848216, 791319847,
#     793224716, 794812542, 797828357, 798911424, 799864342, 816200189,
#     819186360, 819701982, 821695405, 829720705, 831882777, 835479236,
#     839068429, 839557629, 840012044, 847657808
# ]

# Options forwarded to Allen_dataloader_multi_session.
kwargs = {
    'shuffle': True,
    # Align every trial so that stimulus onset falls at t=0.
    'align_stimulus_onset': True,
    # When True, two consecutive trials with the same stimulus might be merged
    # into one trial (e.g. trials 1 and 2 are both "Gabors", each lasting
    # 250 ms, with no interval between them).
    'merge_trials': True,
    'batch_size': 64,
    'fps': 500,
    'start_time': 0.0,
    'end_time': 0.4,
    # Time (in seconds) to include before t=0. This is useful for GLM coupling
    # models: because the model has a delay, data at e.g. t=-20 ms must be
    # available as a predictor/covariate for the outcome at t=0 ms.
    'padding': 0.1,
    'selected_probes': ['probeA', 'probeB', 'probeC', 'probeD', 'probeE', 'probeF'],
}
# Build the multi-session dataloader from the chosen sessions and options.
# This loads the data for every session in session_ids, so it can be slow
# (about 30 s for the two sessions in the recorded run).
cross_session_dataloader = Allen_dataloader_multi_session(session_ids, **kwargs)

Start loading data
  0%|          | 0/2 [00:00<?, ?it/s]

100%|██████████| 2/2 [00:30<00:00, 15.00s/it]

Total sessions: 2, Batch size: 64, Train set size: 264, Val set size: 40, Test set size: 78





Now this `cross_session_dataloader` is all you need. Let's take a look at it.

In [4]:
# Report how many batches each split (train / validation / test) contains.
print(
    f"Number of batches in the training set: {len(cross_session_dataloader.train_loader)}",
    f"Number of batches in the validation set: {len(cross_session_dataloader.val_loader)}",
    f"Number of batches in the test set: {len(cross_session_dataloader.test_loader)}",
    sep="\n",
)

Number of batches in the training set: 264
Number of batches in the validation set: 40
Number of batches in the test set: 78


**For fast loading, all trials in a batch come from the same session (animal).**

Let's take a look at the first batch. 
Each batch is a ```dict``` that stores the following information:
- spike_trains: a `torch.Tensor` of shape (number of time bins, total number of neurons, number of trials, i.e. the batch size)
- presentation_ids: the stimulus presentation id for each trial
- neuron_id: the neuron id for each neuron
- session_id: just a string identifying the session that all trials in the batch come from
- nneuron_list: a list with the number of neurons in each area/population for this session; the sum of this list equals the total number of neurons

In [5]:
# Pull a single batch from the training loader and list the fields it carries.
# The iterator is created inline and discarded right after the first batch.
batch = next(iter(cross_session_dataloader.train_loader))
print(batch.keys())

dict_keys(['spike_trains', 'presentation_ids', 'neuron_id', 'session_id', 'nneuron_list'])


In [6]:
# Axes are (time bins, neurons, trials), per the batch field description.
batch['spike_trains'].shape

torch.Size([250, 258, 64])

In [7]:
# One stimulus presentation id per trial in the batch.
batch['presentation_ids'].shape

(64,)

In [8]:
# One id per neuron; the length should match the neuron axis of spike_trains.
batch['neuron_id'].shape

(258,)

In [9]:
# The session that every trial in this batch comes from, stored as a string.
batch['session_id']

'715093703'

In [10]:
# Neuron count per area/population; the entries sum to the total number of neurons.
batch['nneuron_list']

[30, 50, 60, 42, 30, 46]

# Save and reload the data loader

In [None]:
# Persist the fully built dataloader to disk so that later sessions can
# reload it instead of re-running the slow loading code above.
import joblib

data_path = 'two_sessions.joblib'
joblib.dump(value=cross_session_dataloader, filename=data_path)

In [None]:
# Reload a previously saved dataloader instead of rebuilding it.
# This cell is deliberately self-contained: it imports joblib and defines the
# path itself, so it also works on a fresh kernel where the save cell has not
# been run (otherwise `joblib` and `data_path` would be undefined here).
# NOTE: joblib.load unpickles the file, which can execute arbitrary code --
# only load files you created yourself or otherwise trust.
# NOTE(review): unpickling presumably needs the project's DataLoader module to
# be importable in this environment -- confirm.
import joblib

data_path = 'two_sessions.joblib'  # must match the path used when saving
cross_session_dataloader = joblib.load(data_path)