In [5]:
# Setup Environment
import numpy as np
import pandas as pd
from tqdm import tqdm
import os
import psutil
import time
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from mpl_toolkits.mplot3d import Axes3D
from sklearn.svm import SVR
from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split
from sklearn.inspection import permutation_importance
from sklearn.preprocessing import LabelEncoder
import torch
import seaborn as sns
import pickle
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, GATConv, GATv2Conv
from torch.nn import Linear
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import concurrent.futures
from concurrent.futures import ProcessPoolExecutor

# Select the compute device and report available hardware resources.
if torch.cuda.is_available():
    print("CUDA GPU is available.")
    device = torch.device('cuda')
    # Query the GPU name only when CUDA is actually available: on CPU-only
    # machines torch.cuda.current_device() raises instead of returning an index.
    print(f"Current GPU device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
    print("CUDA GPU is not available. Using CPU instead.")
    device = torch.device('cpu')

# Take a single memory snapshot so total/available figures are consistent.
ram = psutil.virtual_memory()
print(f"Total RAM: {(ram.total / (1024**3)):.2f} GB")
print(f"Available RAM: {(ram.available / (1024**3)):.2f} GB")

from allensdk.core.brain_observatory_cache import BrainObservatoryCache
from allensdk.brain_observatory.ecephys.ecephys_project_cache import EcephysProjectCache 

# All AllenSDK cache files (including manifest.json) live under ./output.
output_dir = os.path.join(os.getcwd(), 'output')
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# Flag for later cells; it is not read anywhere in this cell.
DOWNLOAD_COMPLETE_DATASET = True

# Report whether a previous run already left a manifest behind.
manifest_path = os.path.join(output_dir, "manifest.json")
print("Using existing manifest.json file." if os.path.exists(manifest_path)
      else "Creating a new manifest.json file.")

# Build the ecephys project cache on top of that manifest and list its sessions.
cache = EcephysProjectCache(manifest=manifest_path)
session_table = cache.get_session_table()

print("Session keys:")
session_keys = list(session_table.index)
print(session_keys)

CUDA GPU is available.
Current GPU device: NVIDIA A100-PCIE-40GB MIG 2g.10gb
Total RAM: 1006.92 GB
Available RAM: 944.93 GB
Using existing manifest.json file.
Session keys:
[715093703, 719161530, 721123822, 732592105, 737581020, 739448407, 742951821, 743475441, 744228101, 746083955, 750332458, 750749662, 751348571, 754312389, 754829445, 755434585, 756029989, 757216464, 757970808, 758798717, 759883607, 760345702, 760693773, 761418226, 762120172, 762602078, 763673393, 766640955, 767871931, 768515987, 771160300, 771990200, 773418906, 774875821, 778240327, 778998620, 779839471, 781842082, 786091066, 787025148, 789848216, 791319847, 793224716, 794812542, 797828357, 798911424, 799864342, 816200189, 819186360, 819701982, 821695405, 829720705, 831882777, 835479236, 839068429, 839557629, 840012044, 847657808]


In [6]:
# Ecephys session IDs processed by the main loop below.
# NOTE(review): in the recorded output every one of these sessions was skipped
# for lacking the 'natural_scenes' stimulus, and that output also lists sessions
# not present here, i.e. it came from a different list. Confirm this selection.
sessions = [
    816200189, 819186360, 819701982, 821695405,
    829720705, 831882777, 835479236, 839068429,
    839557629, 840012044, 847657808,
]

In [7]:
import numpy as np
import torch
import pandas as pd
from tqdm import tqdm
import concurrent.futures
import pickle

def contains_natural_scenes(session):
    """Tell whether any stimulus presentation in ``session`` is 'natural_scenes'.

    Parameters
    ----------
    session : object
        Anything exposing a ``stimulus_presentations`` DataFrame with a
        'stimulus_name' column.

    Returns
    -------
    bool
    """
    names = session.stimulus_presentations['stimulus_name']
    return bool((names.unique() == 'natural_scenes').any())

def print_all_stimuli(session):
    """Print the distinct stimulus names shown during ``session``.

    The names are printed in first-appearance order, prefixed by the
    session's ``ecephys_session_id``.
    """
    session_id = session.ecephys_session_id
    distinct_names = session.stimulus_presentations['stimulus_name'].unique()
    print(f"Unique stimuli in session {session_id}: {distinct_names}")

# Main processing loop.
# For every session in `sessions` that contains the 'natural_scenes' stimulus:
#   1. drop each unit's spikes that fall inside the session's invalid intervals,
#   2. split every image presentation into `timesteps_per_frame` equal-width bins
#      and count each unit's (valid) spikes per bin,
#   3. pickle a (bins x units) DataFrame with an extra 'frame' column naming the
#      image shown during each bin.

# Number of equal-width time bins per image presentation.
timesteps_per_frame = 5


def is_valid_time(spike_times, invalid_intervals):
    """Return a boolean mask that is True where a spike lies outside every invalid interval.

    Parameters
    ----------
    spike_times : array-like of float
        Spike times of one unit.
    invalid_intervals : pandas.DataFrame or None
        One row per interval with 'start_time' and 'stop_time' columns. Both
        ends are treated as invalid (closed interval).

    Returns
    -------
    numpy.ndarray of bool with the same shape as ``spike_times``.
    """
    spike_times = np.asarray(spike_times)
    invalid = np.zeros(spike_times.shape, dtype=bool)
    if invalid_intervals is None or len(invalid_intervals) == 0:
        return ~invalid
    for start, end in zip(invalid_intervals['start_time'], invalid_intervals['stop_time']):
        invalid |= (spike_times >= start) & (spike_times <= end)
    return ~invalid


def bin_spikes_per_presentation(times, starts, stops, bins_per_image):
    """Count spikes in ``bins_per_image`` equal-width bins inside each presentation window.

    Follows torch.histc semantics per window: window p is the closed interval
    [starts[p], stops[p]], a spike exactly at stops[p] goes into the last bin,
    and spikes outside every window are ignored. A spike exactly on a boundary
    shared by two back-to-back windows is therefore counted in both. All
    arithmetic is done in float64.

    Parameters
    ----------
    times : array-like of float
        Spike times of one unit (any order).
    starts, stops : array-like of float
        Start/stop time of each presentation, both of length P.
    bins_per_image : int
        Number of bins per presentation (>= 1).

    Returns
    -------
    numpy.ndarray of int32, shape (P * bins_per_image,)
        Counts laid out presentation-major: all bins of presentation 0, then 1, ...
    """
    times = np.sort(np.asarray(times, dtype=np.float64))
    starts = np.asarray(starts, dtype=np.float64)
    stops = np.asarray(stops, dtype=np.float64)
    n_pres = len(starts)
    counts = np.zeros((n_pres, bins_per_image), dtype=np.int32)
    if n_pres == 0 or times.size == 0:
        return counts.ravel()

    # Half-open index range [lo, hi) of the sorted spikes inside each closed window.
    lo = np.searchsorted(times, starts, side='left')
    hi = np.searchsorted(times, stops, side='right')
    n_in = np.clip(hi - lo, 0, None)  # a malformed window (stop < start) holds nothing
    total = int(n_in.sum())
    if total == 0:
        return counts.ravel()

    # Expand to one entry per (presentation, spike-in-window) pair.
    pres_idx = np.repeat(np.arange(n_pres), n_in)
    offsets = np.arange(total) - np.repeat(np.cumsum(n_in) - n_in, n_in)
    spike_t = times[np.repeat(lo, n_in) + offsets]

    # Relative position in the window -> bin index; the right edge maps to the last bin.
    win_start = starts[pres_idx]
    win_len = stops[pres_idx] - win_start
    frac = np.divide(spike_t - win_start, win_len,
                     out=np.zeros_like(spike_t), where=win_len > 0)
    bin_idx = np.minimum((frac * bins_per_image).astype(np.int64), bins_per_image - 1)

    # Unbuffered accumulation so repeated (presentation, bin) pairs all count.
    np.add.at(counts, (pres_idx, bin_idx), 1)
    return counts.ravel()


for i in sessions:
    session_number = i

    # Pull session. Quality-metric thresholds are set to +/-inf, presumably so
    # that no units are filtered out by the cache.
    session = cache.get_session_data(session_number,
                                     isi_violations_maximum=np.inf,
                                     amplitude_cutoff_maximum=np.inf,
                                     presence_ratio_minimum=-np.inf
                                    )

    print_all_stimuli(session)

    if not contains_natural_scenes(session):
        print(f"Skipping session {session_number} as it does not contain natural scenes.")
        continue

    spike_times = session.spike_times
    stimulus_table = session.get_stimulus_table("natural_scenes")

    # Display the public attributes/methods of the session object.
    print("Session objects")
    print([attr_or_method for attr_or_method in dir(session) if attr_or_method[0] != '_'])

    invalid_times = session.invalid_times

    # Remove spikes that fall inside the session's invalid intervals.
    valid_spike_times = {}
    for neuron, times in tqdm(spike_times.items(), total=len(spike_times),
                              desc='Filtering valid spike times'):
        times = np.asarray(times)
        valid_spike_times[neuron] = times[is_valid_time(times, invalid_times)]

    # Presentation windows as float64 arrays (torch tensor copies kept for later cells).
    starts = stimulus_table.start_time.to_numpy(dtype=np.float64)
    stops = stimulus_table.stop_time.to_numpy(dtype=np.float64)
    image_start_times = torch.tensor(starts)
    image_end_times = torch.tensor(stops)

    bins_per_image = timesteps_per_frame
    total_bins = bins_per_image * len(starts)
    num_neurons = len(valid_spike_times)

    # Bin the *filtered* spike trains so the invalid-interval removal above
    # actually affects the saved output.
    spike_rows = [
        bin_spikes_per_presentation(times, starts, stops, bins_per_image)
        for times in tqdm(valid_spike_times.values(), total=num_neurons,
                          desc='Processing neurons')
    ]
    if spike_rows:
        spike_matrix = torch.from_numpy(np.stack(spike_rows))  # (units, total_bins), int32
    else:
        spike_matrix = torch.zeros((0, total_bins), dtype=torch.int32)

    # Rows = units (keyed by unit id), columns = presentation-major bins;
    # transpose so each row is one time bin.
    spike_dataframe = pd.DataFrame(spike_matrix.numpy(), index=list(valid_spike_times.keys()))
    spike_df = spike_dataframe.T
    # Tag every bin with the image frame of its presentation.
    spike_df['frame'] = np.repeat(stimulus_table['frame'].to_numpy(), timesteps_per_frame)

    # Save the binned spike DataFrame.
    # NOTE(review): written to the current working directory, not `output_dir`.
    with open(f'spike_trains_with_stimulus_session_{session_number}_{timesteps_per_frame}.pkl', 'wb') as f:
        pickle.dump(spike_df, f)



Downloading:   0%|          | 0.00/2.63G [00:00<?, ?B/s]

  return func(args[0], **pargs)
  return func(args[0], **pargs)


Unique stimuli in session 787025148: ['spontaneous' 'gabors' 'flashes' 'drifting_gratings_contrast'
 'natural_movie_one_more_repeats' 'natural_movie_one_shuffled'
 'drifting_gratings_75_repeats' 'dot_motion']
Skipping session 787025148 as it does not contain natural scenes.




Downloading:   0%|          | 0.00/1.88G [00:00<?, ?B/s]

Unique stimuli in session 789848216: ['spontaneous' 'gabors' 'flashes' 'drifting_gratings_contrast'
 'natural_movie_one_more_repeats' 'natural_movie_one_shuffled'
 'drifting_gratings_75_repeats' 'dot_motion']
Skipping session 789848216 as it does not contain natural scenes.




Downloading:   0%|          | 0.00/2.32G [00:00<?, ?B/s]

Unique stimuli in session 791319847: ['spontaneous' 'gabors' 'flashes' 'drifting_gratings'
 'natural_movie_three' 'natural_movie_one' 'static_gratings'
 'natural_scenes' 'drifting_gratings_contrast']
Session objects
['DETAILED_STIMULUS_PARAMETERS', 'LazyProperty', 'age_in_days', 'api', 'channel_structure_intervals', 'channels', 'conditionwise_spike_statistics', 'ecephys_session_id', 'from_nwb_path', 'full_genotype', 'get_current_source_density', 'get_inter_presentation_intervals_for_stimulus', 'get_invalid_times', 'get_lfp', 'get_parameter_values_for_stimulus', 'get_pupil_data', 'get_screen_gaze_data', 'get_stimulus_epochs', 'get_stimulus_parameter_values', 'get_stimulus_table', 'inter_presentation_intervals', 'invalid_times', 'mean_waveforms', 'metadata', 'num_channels', 'num_probes', 'num_stimulus_presentations', 'num_units', 'optogenetic_stimulation_epochs', 'presentationwise_spike_counts', 'presentationwise_spike_times', 'probes', 'rig_equipment_name', 'rig_geometry_data', 'running

Filtering valid spike times: 100%|██████████| 1445/1445 [00:00<00:00, 7787.27it/s]
Processing neurons: 100%|██████████| 1445/1445 [06:37<00:00,  3.63it/s]


Downloading:   0%|          | 0.00/2.69G [00:00<?, ?B/s]

Unique stimuli in session 793224716: ['spontaneous' 'gabors' 'flashes' 'drifting_gratings_contrast'
 'natural_movie_one_more_repeats' 'invalid_presentation'
 'natural_movie_one_shuffled' 'drifting_gratings_75_repeats' 'dot_motion']
Skipping session 793224716 as it does not contain natural scenes.




Downloading:   0%|          | 0.00/2.59G [00:00<?, ?B/s]

Unique stimuli in session 794812542: ['spontaneous' 'gabors' 'flashes' 'drifting_gratings_contrast'
 'natural_movie_one_more_repeats' 'natural_movie_one_shuffled'
 'drifting_gratings_75_repeats' 'dot_motion']
Skipping session 794812542 as it does not contain natural scenes.




Downloading:   0%|          | 0.00/2.55G [00:00<?, ?B/s]

Unique stimuli in session 797828357: ['spontaneous' 'gabors' 'flashes' 'drifting_gratings'
 'natural_movie_three' 'natural_movie_one' 'static_gratings'
 'natural_scenes' 'drifting_gratings_contrast']
Session objects
['DETAILED_STIMULUS_PARAMETERS', 'LazyProperty', 'age_in_days', 'api', 'channel_structure_intervals', 'channels', 'conditionwise_spike_statistics', 'ecephys_session_id', 'from_nwb_path', 'full_genotype', 'get_current_source_density', 'get_inter_presentation_intervals_for_stimulus', 'get_invalid_times', 'get_lfp', 'get_parameter_values_for_stimulus', 'get_pupil_data', 'get_screen_gaze_data', 'get_stimulus_epochs', 'get_stimulus_parameter_values', 'get_stimulus_table', 'inter_presentation_intervals', 'invalid_times', 'mean_waveforms', 'metadata', 'num_channels', 'num_probes', 'num_stimulus_presentations', 'num_units', 'optogenetic_stimulation_epochs', 'presentationwise_spike_counts', 'presentationwise_spike_times', 'probes', 'rig_equipment_name', 'rig_geometry_data', 'running

Filtering valid spike times: 100%|██████████| 1610/1610 [00:01<00:00, 1586.44it/s]
Processing neurons: 100%|██████████| 1610/1610 [06:41<00:00,  4.01it/s]  


Downloading:   0%|          | 0.00/2.86G [00:00<?, ?B/s]

Unique stimuli in session 798911424: ['spontaneous' 'gabors' 'flashes' 'drifting_gratings'
 'natural_movie_three' 'natural_movie_one' 'static_gratings'
 'natural_scenes' 'drifting_gratings_contrast']
Session objects
['DETAILED_STIMULUS_PARAMETERS', 'LazyProperty', 'age_in_days', 'api', 'channel_structure_intervals', 'channels', 'conditionwise_spike_statistics', 'ecephys_session_id', 'from_nwb_path', 'full_genotype', 'get_current_source_density', 'get_inter_presentation_intervals_for_stimulus', 'get_invalid_times', 'get_lfp', 'get_parameter_values_for_stimulus', 'get_pupil_data', 'get_screen_gaze_data', 'get_stimulus_epochs', 'get_stimulus_parameter_values', 'get_stimulus_table', 'inter_presentation_intervals', 'invalid_times', 'mean_waveforms', 'metadata', 'num_channels', 'num_probes', 'num_stimulus_presentations', 'num_units', 'optogenetic_stimulation_epochs', 'presentationwise_spike_counts', 'presentationwise_spike_times', 'probes', 'rig_equipment_name', 'rig_geometry_data', 'running

Filtering valid spike times: 100%|██████████| 1878/1878 [00:00<00:00, 4025.92it/s]
Processing neurons: 100%|██████████| 1878/1878 [07:23<00:00,  4.24it/s] 


Downloading:   0%|          | 0.00/2.74G [00:00<?, ?B/s]

Unique stimuli in session 799864342: ['spontaneous' 'gabors' 'flashes' 'drifting_gratings'
 'natural_movie_three' 'natural_movie_one' 'static_gratings'
 'natural_scenes' 'drifting_gratings_contrast']
Session objects
['DETAILED_STIMULUS_PARAMETERS', 'LazyProperty', 'age_in_days', 'api', 'channel_structure_intervals', 'channels', 'conditionwise_spike_statistics', 'ecephys_session_id', 'from_nwb_path', 'full_genotype', 'get_current_source_density', 'get_inter_presentation_intervals_for_stimulus', 'get_invalid_times', 'get_lfp', 'get_parameter_values_for_stimulus', 'get_pupil_data', 'get_screen_gaze_data', 'get_stimulus_epochs', 'get_stimulus_parameter_values', 'get_stimulus_table', 'inter_presentation_intervals', 'invalid_times', 'mean_waveforms', 'metadata', 'num_channels', 'num_probes', 'num_stimulus_presentations', 'num_units', 'optogenetic_stimulation_epochs', 'presentationwise_spike_counts', 'presentationwise_spike_times', 'probes', 'rig_equipment_name', 'rig_geometry_data', 'running

Filtering valid spike times: 100%|██████████| 1551/1551 [00:00<00:00, 3983.60it/s]
Processing neurons: 100%|██████████| 1551/1551 [06:42<00:00,  3.86it/s]


Downloading:   0%|          | 0.00/2.45G [00:00<?, ?B/s]

Unique stimuli in session 816200189: ['spontaneous' 'gabors' 'flashes' 'drifting_gratings_contrast'
 'natural_movie_one_more_repeats' 'natural_movie_one_shuffled'
 'drifting_gratings_75_repeats' 'dot_motion']
Skipping session 816200189 as it does not contain natural scenes.




Downloading:   0%|          | 0.00/2.35G [00:00<?, ?B/s]

Unique stimuli in session 819186360: ['spontaneous' 'gabors' 'flashes' 'drifting_gratings_contrast'
 'natural_movie_one_more_repeats' 'natural_movie_one_shuffled'
 'drifting_gratings_75_repeats' 'dot_motion']
Skipping session 819186360 as it does not contain natural scenes.




Downloading:   0%|          | 0.00/2.52G [00:00<?, ?B/s]

Unique stimuli in session 819701982: ['spontaneous' 'gabors' 'flashes' 'drifting_gratings_contrast'
 'natural_movie_one_more_repeats' 'natural_movie_one_shuffled'
 'drifting_gratings_75_repeats' 'dot_motion']
Skipping session 819701982 as it does not contain natural scenes.




Downloading:   0%|          | 0.00/2.01G [00:00<?, ?B/s]

Unique stimuli in session 821695405: ['spontaneous' 'gabors' 'flashes' 'drifting_gratings_contrast'
 'natural_movie_one_more_repeats' 'natural_movie_one_shuffled'
 'drifting_gratings_75_repeats' 'dot_motion']
Skipping session 821695405 as it does not contain natural scenes.




Downloading:   0%|          | 0.00/1.68G [00:00<?, ?B/s]

Unique stimuli in session 829720705: ['spontaneous' 'gabors' 'flashes' 'drifting_gratings_contrast'
 'natural_movie_one_more_repeats' 'natural_movie_one_shuffled'
 'drifting_gratings_75_repeats' 'dot_motion']
Skipping session 829720705 as it does not contain natural scenes.




Downloading:   0%|          | 0.00/2.09G [00:00<?, ?B/s]

Unique stimuli in session 831882777: ['spontaneous' 'gabors' 'flashes' 'drifting_gratings_contrast'
 'natural_movie_one_more_repeats' 'natural_movie_one_shuffled'
 'drifting_gratings_75_repeats' 'dot_motion']
Skipping session 831882777 as it does not contain natural scenes.




Downloading:   0%|          | 0.00/2.01G [00:00<?, ?B/s]

Unique stimuli in session 835479236: ['spontaneous' 'gabors' 'flashes' 'drifting_gratings_contrast'
 'natural_movie_one_more_repeats' 'natural_movie_one_shuffled'
 'drifting_gratings_75_repeats' 'dot_motion']
Skipping session 835479236 as it does not contain natural scenes.




Downloading:   0%|          | 0.00/2.82G [00:00<?, ?B/s]

Unique stimuli in session 839068429: ['spontaneous' 'gabors' 'flashes' 'drifting_gratings_contrast'
 'natural_movie_one_more_repeats' 'natural_movie_one_shuffled'
 'drifting_gratings_75_repeats' 'dot_motion']
Skipping session 839068429 as it does not contain natural scenes.




Downloading:   0%|          | 0.00/1.89G [00:00<?, ?B/s]

Unique stimuli in session 839557629: ['spontaneous' 'gabors' 'flashes' 'drifting_gratings_contrast'
 'natural_movie_one_more_repeats' 'natural_movie_one_shuffled'
 'drifting_gratings_75_repeats' 'dot_motion']
Skipping session 839557629 as it does not contain natural scenes.




Downloading:   0%|          | 0.00/2.85G [00:00<?, ?B/s]

Unique stimuli in session 840012044: ['spontaneous' 'gabors' 'flashes' 'drifting_gratings_contrast'
 'natural_movie_one_more_repeats' 'natural_movie_one_shuffled'
 'drifting_gratings_75_repeats' 'dot_motion']
Skipping session 840012044 as it does not contain natural scenes.




Downloading:   0%|          | 0.00/2.81G [00:00<?, ?B/s]

Unique stimuli in session 847657808: ['spontaneous' 'gabors' 'flashes' 'drifting_gratings_contrast'
 'natural_movie_one_more_repeats' 'natural_movie_one_shuffled'
 'drifting_gratings_75_repeats' 'dot_motion']
Skipping session 847657808 as it does not contain natural scenes.
