In [2]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec #Useful tool to arrange multiple plots in one figure (https://matplotlib.org/stable/api/_as_gen/matplotlib.gridspec.GridSpec.html)

%matplotlib inline

import platform
# Choose where the dataset lives based on the host platform string.
platstring = platform.platform()

# Defaults used by every branch except AWS, which overrides both.
data_dirname = 'visual-behavior-neuropixels'
use_static = False

if any(tag in platstring for tag in ('Darwin', 'macOS')):
    # macOS
    data_root = "/Volumes/Brain2021/"
elif 'Windows' in platstring:
    # Windows (replace with the drive letter of the USB drive)
    data_root = "E:/"
elif 'amzn' in platstring:
    # Running on AWS: different root, directory name and static-cache flag
    data_root, data_dirname, use_static = "/data/", 'visual-behavior-neuropixels-data', True
else:
    # Any other platform (e.g. your own Linux machine):
    # EDIT this to the location where you mounted the hard drive
    data_root = "/home/andrew/Documents/tmp/5-1-24-change-detection"

from allensdk.brain_observatory.behavior.behavior_project_cache.\
    behavior_neuropixels_project_cache \
    import VisualBehaviorNeuropixelsProjectCache

# Build the project cache rooted at <data_root>/<data_dirname>.
# this path should point to the location of the dataset on your platform
cache_dir = os.path.join(data_root, data_dirname)
# NOTE(review): the S3-backed constructor is the active one, so `use_static`
# (set in the platform branch above) is only consumed by the commented-out
# local-cache alternative below.
cache = VisualBehaviorNeuropixelsProjectCache.from_s3_cache(cache_dir=cache_dir)

# Alternative kept for reference (presumably for a dataset already on disk):
# cache = VisualBehaviorNeuropixelsProjectCache.from_local_cache(
#             cache_dir=cache_dir, use_static_cache=use_static)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Session-level metadata table from the project cache; reused by later cells.
sessions = cache.get_ecephys_session_table()

In [4]:
# List the metadata columns available for each session.
sessions.columns

Index(['behavior_session_id', 'date_of_acquisition', 'equipment_name',
       'session_type', 'mouse_id', 'genotype', 'sex', 'project_code',
       'age_in_days', 'unit_count', 'probe_count', 'channel_count',
       'structure_acronyms', 'image_set', 'prior_exposures_to_image_set',
       'session_number', 'experience_level', 'prior_exposures_to_omissions',
       'file_id', 'abnormal_histology', 'abnormal_activity'],
      dtype='object')

In [5]:
# Load one session and fetch its units and channels tables once each.
session = cache.get_ecephys_session(ecephys_session_id=1047969464)
units_df = session.get_units()
channels_df = session.get_channels()
# Print the columns of the frames already fetched rather than calling
# get_units()/get_channels() a second time just to inspect them.
print(units_df.columns)
print(channels_df.columns)

Index(['PT_ratio', 'amplitude', 'amplitude_cutoff', 'cluster_id',
       'cumulative_drift', 'd_prime', 'firing_rate', 'isi_violations',
       'isolation_distance', 'l_ratio', 'local_index', 'max_drift',
       'nn_hit_rate', 'nn_miss_rate', 'peak_channel_id', 'presence_ratio',
       'quality', 'recovery_slope', 'repolarization_slope', 'silhouette_score',
       'snr', 'spread', 'velocity_above', 'velocity_below',
       'waveform_duration'],
      dtype='object')
Index(['anterior_posterior_ccf_coordinate', 'dorsal_ventral_ccf_coordinate',
       'filtering', 'left_right_ccf_coordinate', 'probe_channel_number',
       'probe_horizontal_position', 'probe_id', 'probe_vertical_position',
       'structure_acronym'],
      dtype='object')


In [6]:
# Attach each unit's brain structure by looking up its peak channel in the
# channels table (the units table has no 'structure_acronym' of its own).
merged_df = units_df.merge(
    channels_df[['structure_acronym']],
    left_on='peak_channel_id',
    right_index=True,
)
merged_df.columns

Index(['PT_ratio', 'amplitude', 'amplitude_cutoff', 'cluster_id',
       'cumulative_drift', 'd_prime', 'firing_rate', 'isi_violations',
       'isolation_distance', 'l_ratio', 'local_index', 'max_drift',
       'nn_hit_rate', 'nn_miss_rate', 'peak_channel_id', 'presence_ratio',
       'quality', 'recovery_slope', 'repolarization_slope', 'silhouette_score',
       'snr', 'spread', 'velocity_above', 'velocity_below',
       'waveform_duration', 'structure_acronym'],
      dtype='object')

In [9]:
# Stimulus presentation table for the current session.
stimulus_presentations = session.stimulus_presentations
# Alternative overview (first row of every stimulus block), kept for reference:
# stimulus_presentations.groupby('stimulus_block')[['stimulus_block', 
#                                                 'stimulus_name', 
#                                                 'active', 
#                                                 'duration', 
#                                                 'start_time']].head(1)
# Preview the first five presentations that belong to stimulus block 4.
stimulus_presentations.loc[stimulus_presentations['stimulus_block'].eq(4)].head(5)

Unnamed: 0_level_0,stimulus_block,image_name,duration,start_time,end_time,start_frame,end_frame,is_change,is_image_novel,omitted,...,stimulus_name,is_sham_change,color,position_y,orientation,temporal_frequency,stimulus_index,position_x,spatial_frequency,active
stimulus_presentations_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
8447,4,,0.250204,4859.441647,4859.691851,288733,288748,,,,...,flash_250ms,False,1.0,,,,1,,,False
8448,4,,0.250217,4861.443267,4861.693483,288853,288868,,,,...,flash_250ms,False,-1.0,,,,1,,,False
8449,4,,0.250212,4863.444957,4863.695168,288973,288988,,,,...,flash_250ms,False,-1.0,,,,1,,,False
8450,4,,0.250207,4865.446657,4865.696863,289093,289108,,,,...,flash_250ms,False,1.0,,,,1,,,False
8451,4,,0.250217,4867.448307,4867.698523,289213,289228,,,,...,flash_250ms,False,-1.0,,,,1,,,False


In [None]:
# Number of units recorded in each brain structure for the current `session`.
# The table returned by session.get_units() has no 'structure_acronym' column
# (see its printed column list), so the acronym is looked up from the channels
# table through each unit's peak channel before counting.
count_structure_acronyms = (
    session.get_units()
    .merge(
        session.get_channels()[['structure_acronym']],
        left_on='peak_channel_id',
        right_index=True,
    )['structure_acronym']
    .value_counts()
)

In [None]:

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

def get_time_window(session, buffer=1, num_flashes=30, verbose=True):
    """
    Get the time window spanning the first `num_flashes` flash presentations,
    extended by `buffer` seconds on both sides.

    Parameters:
        session: Object exposing ``get_stimulus_table("flashes")``; the returned
            table must provide 'start_time' and 'stop_time' columns.
        buffer (float): Time in seconds subtracted from the first flash start
            and added to the last flash stop.
        num_flashes (int): Maximum number of leading flashes to include. If the
            session has fewer flashes, all of them are used.
        verbose (bool): When True (default), print the number of flashes used
            and their start times, as the original implementation did.

    Returns:
        tuple: (start_time, end_time, flash_start_times, flash_end_times), where
        the last two are arrays of the start/stop times of the selected flashes.

    Raises:
        ValueError: If no flash presentations are available for the session.
    """
    flashes = session.get_stimulus_table("flashes").head(num_flashes)
    if verbose:
        print("number of flashes!!!-------------")
        print(len(flashes))

    # Fail with a clear message instead of an IndexError from .iloc[0].
    if len(flashes) == 0:
        raise ValueError("no flash presentations available to build a time window")

    # `flashes` already holds at most `num_flashes` rows, so the last row is
    # the last flash of the window.
    start_time = flashes.iloc[0]['start_time'] - buffer
    end_time = flashes.iloc[-1]['stop_time'] + buffer

    flash_start_times = flashes['start_time'].values
    flash_end_times = flashes['stop_time'].values
    if verbose:
        print(flash_start_times)

    return start_time, end_time, flash_start_times, flash_end_times

def extract_and_bin_spikes(session, start_time, end_time, bin_size, regions, flash_start_times, flash_end_times):
    """
    Bin spikes per region inside a time window and average across units.

    Parameters:
        session: Object exposing ``spike_times`` (mapping of unit id to an array
            of spike times) and ``units`` (table indexed by unit id with an
            'ecephys_structure_acronym' column).
        start_time (float): Start time of the window.
        end_time (float): End time of the window.
        bin_size (float): Bin size in seconds; must be positive.
        regions (list): Regions to include in the analysis.
        flash_start_times (array-like): Flash start times, binned on the same edges.
        flash_end_times (array-like): Flash end times, binned on the same edges.

    Returns:
        tuple: (region_data, binned_flash_starts, binned_flash_ends).
        ``region_data`` maps each region to a pandas Series (positional index,
        one entry per bin) holding the firing rate (count / bin_size) averaged
        over that region's units, or all zeros when the region has no units.
        The other two items are per-bin counts of flash starts and flash ends.

    Raises:
        ValueError: If ``bin_size`` is not positive.
    """
    if bin_size <= 0:
        raise ValueError("bin_size must be positive")

    # One shared set of bin edges for flashes and every unit. np.arange
    # excludes `end_time`, so the last edge is the final step below it.
    bin_edges = np.arange(start_time, end_time, bin_size)

    binned_flash_starts, _ = np.histogram(flash_start_times, bins=bin_edges)
    binned_flash_ends, _ = np.histogram(flash_end_times, bins=bin_edges)

    # Read the unit -> region column once instead of once per unit.
    unit_regions = session.units['ecephys_structure_acronym']

    region_data = {region: [] for region in regions}
    for unit_id, spikes in session.spike_times.items():
        region = unit_regions.loc[unit_id]
        if region in regions:
            binned_spikes, _ = np.histogram(spikes, bins=bin_edges)
            # Convert counts to a rate; positional index keeps units aligned.
            region_data[region].append(pd.Series(binned_spikes / bin_size, name=unit_id))

    # Average across units in each region, or provide zeros where no data exists
    for region in regions:
        if region_data[region]:
            region_data[region] = pd.concat(region_data[region], axis=1).mean(axis=1)
        else:
            print("------------------------alert missing region data: ", region)
            region_data[region] = pd.Series(np.zeros(len(bin_edges[:-1])))

    return region_data, binned_flash_starts, binned_flash_ends


# Main analysis: iterate over every session in `sessions` and collect the
# region-averaged firing rates around the flash stimuli.
# NOTE(review): this cell uses `cache.get_session_data` and `session.units`,
# whereas the other cells use `cache.get_ecephys_session` and
# `session.get_units()` — confirm it targets the same cache object.
print(len(sessions))

regions = ["LGd", "LGv", "LP", "VISp", "VISl", "VISal", "VISrl", "VISpm", "VISam"]
bin_size = 0.005  # bin size in seconds (5 ms)

# Per-region accumulators, one entry appended per processed session
per_unit_counts = {region: [] for region in regions}
all_region_data = {region: [] for region in regions}
blank_regions = dict.fromkeys(regions, 0)
all_binned_flash_starts = []
all_binned_flash_ends = []

# `count` tracks the index of the session being processed (-1 if none ran).
count = -1
for count, session_id in enumerate(sessions.index):
    session = cache.get_session_data(session_id)

    # Record how many units this session has in each region of interest.
    count_structure_acronyms = session.units['structure_acronym'].value_counts()
    for region in regions:
        per_unit_counts[region].append(count_structure_acronyms.get(region, 0))

    start_time, end_time, flash_start_times, flash_end_times = get_time_window(session)
    print("start_time, end_time")
    print(start_time, end_time)
    region_data, binned_flash_starts, binned_flash_ends = extract_and_bin_spikes(
        session,
        start_time,
        end_time,
        bin_size,
        regions,
        flash_start_times,
        flash_end_times,
    )

    # Stop as soon as a session yields a different number of bins.
    # NOTE(review): 12059 is hard-coded; presumably the bin count for the
    # current bin_size and window — TODO confirm / derive it instead.
    if any(len(series) != 12059 for series in region_data.values()):
        print("BIG PROBLEM BINS DON't MATCH: ", len(region_data["VISp"]))
        break

    # Tally regions whose trace is all zeros for this session.
    for region in regions:
        if not region_data[region].any():
            print(f"{region} bumping empty count for region", region_data[region])
            blank_regions[region] += 1

    # Store this session's trace for every region.
    for region in regions:
        all_region_data[region].append(region_data[region])

    all_binned_flash_starts.append(binned_flash_starts)
    all_binned_flash_ends.append(binned_flash_ends)


In [64]:
# Session metadata table, plus two hand-picked sessions used by later cells.
ecephys_sessions_table = cache.get_ecephys_session_table()

# Sessions whose genotype mentions 'Sst' and whose experience level is 'Novel'.
sst_novel_sessions = ecephys_sessions_table.loc[
    ecephys_sessions_table['genotype'].str.contains('Sst')
    & ecephys_sessions_table['experience_level'].eq('Novel')
]

# Other session ids tried previously:
# session_id = 1053941483
# session_id = 1047969464 # good mouse familiar
session_id = 1130349290 # mouse novel
session = cache.get_ecephys_session(ecephys_session_id=session_id)
session_id_2 = 1152811536 # mouse novel
session_2 = cache.get_ecephys_session(ecephys_session_id=session_id_2)


# session.metadata
# ecephys_sessions_table

# Display every session at the 'Novel' experience level, regardless of genotype.
condition = ecephys_sessions_table['experience_level'].eq("Novel")
filtered_table = ecephys_sessions_table.loc[condition]
filtered_table

Unnamed: 0_level_0,behavior_session_id,date_of_acquisition,equipment_name,session_type,mouse_id,genotype,sex,project_code,age_in_days,unit_count,probe_count,channel_count,structure_acronyms,image_set,prior_exposures_to_image_set,session_number,experience_level,prior_exposures_to_omissions,file_id,abnormal_histology,abnormal_activity
ecephys_session_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1044594870,1044624428,2020-08-20 15:03:56.422000+00:00,NP.1,EPHYS_1_images_H_5uL_reward,524761,wt/wt,F,NeuropixelVisualBehavior,152,2103,5,1920,"['CA1', 'CA3', 'DG-mo', 'DG-po', 'DG-sg', 'HPF...",H,0,2,Novel,1,872,,
1048189115,1048221709,2020-09-03 14:16:57.913000+00:00,NP.1,EPHYS_1_images_H_3uL_reward,509808,Sst-IRES-Cre/wt;Ai32(RCL-ChR2(H134R)_EYFP)/wt,M,NeuropixelVisualBehavior,264,1925,6,2304,"['APN', 'CA1', 'CA3', 'DG-mo', 'DG-po', 'DG-sg...",H,0,2,Novel,1,879,,
1048196054,1048222325,2020-09-03 14:25:07.290000+00:00,NP.0,EPHYS_1_images_H_3uL_reward,524925,Sst-IRES-Cre/wt;Ai32(RCL-ChR2(H134R)_EYFP)/wt,F,NeuropixelVisualBehavior,166,2288,6,2304,"['APN', 'CA1', 'CA3', 'DG-mo', 'DG-po', 'DG-sg...",H,0,2,Novel,1,880,,
1049514117,1049542142,2020-09-10 15:11:15.371000+00:00,NP.0,EPHYS_1_images_H_3uL_reward,521466,Vip-IRES-Cre/wt;Ai32(RCL-ChR2(H134R)_EYFP)/wt,F,NeuropixelVisualBehavior,194,1925,6,2304,"['APN', 'CA1', 'CA3', 'DG-mo', 'DG-po', 'DG-sg...",H,0,2,Novel,1,882,,
1051155866,1052162536,2020-09-17 15:05:39.665000+00:00,NP.1,EPHYS_1_images_H_3uL_reward,524760,wt/wt,F,NeuropixelVisualBehavior,180,1922,6,2304,"['APN', 'CA1', 'CA3', 'DG-mo', 'DG-po', 'DG-sg...",H,0,2,Novel,1,883,,
1052533639,1052572359,2020-09-24 15:12:13.229000+00:00,NP.1,EPHYS_1_images_H_3uL_reward,530862,Vip-IRES-Cre/wt;Ai32(RCL-ChR2(H134R)_EYFP)/wt,M,NeuropixelVisualBehavior,149,1677,6,2304,"['APN', 'CA1', 'CA3', 'DG-mo', 'DG-po', 'DG-sg...",H,0,2,Novel,1,887,,
1053925378,1053960984,2020-10-01 16:07:18.990000+00:00,NP.0,EPHYS_1_images_H_3uL_reward,532246,Vip-IRES-Cre/wt;Ai32(RCL-ChR2(H134R)_EYFP)/wt,M,NeuropixelVisualBehavior,145,1823,6,2304,"['APN', 'CA1', 'CA3', 'DG-mo', 'DG-po', 'DG-sg...",H,0,2,Novel,1,890,,
1053941483,1053960987,2020-10-01 17:03:58.362000+00:00,NP.1,EPHYS_1_images_H_3uL_reward,527749,Sst-IRES-Cre/wt;Ai32(RCL-ChR2(H134R)_EYFP)/wt,M,NeuropixelVisualBehavior,180,1543,6,2304,"['APN', 'CA1', 'CA3', 'DG-mo', 'DG-po', 'DG-sg...",H,0,2,Novel,1,891,,
1055403683,1055431030,2020-10-08 15:12:55.061000+00:00,NP.1,EPHYS_1_images_H_3uL_reward,533537,wt/wt,M,NeuropixelVisualBehavior,144,1569,6,2304,"['APN', 'CA1', 'CA3', 'DG-mo', 'DG-po', 'DG-sg...",H,0,2,Novel,1,894,,
1055415082,1055434752,2020-10-08 15:44:11.498000+00:00,NP.0,EPHYS_1_images_H_3uL_reward,533539,wt/wt,F,NeuropixelVisualBehavior,144,1510,6,2304,"['APN', 'CA1', 'CA3', 'DG-mo', 'DG-po', 'DG-sg...",H,0,2,Novel,1,895,,


In [65]:

# Imports come first so this cell does not rely on `pd` leaking in from an
# earlier cell before its own import line runs.
import pandas as pd
import matplotlib.pyplot as plt

pd.set_option('display.max_rows', 500)

# Take an explicit copy of the two relevant columns so the cumulative columns
# added below are written to an independent frame rather than to a slice of
# session_2.trials (which triggers pandas' SettingWithCopyWarning).
trials = session_2.trials[['is_change', 'hit']].copy()

# Running totals of hits and of change trials, in trial order.
trials['cumulative_hit'] = trials['hit'].cumsum()
trials['cumulative_change'] = trials['is_change'].cumsum()

# Plot both running totals against the trial index on one axis.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(trials.index, trials['cumulative_hit'], label="Cumulative Hit Count", color="blue")
ax.plot(trials.index, trials['cumulative_change'], label="Cumulative Change Count", color="green")

ax.set_xlabel("Trial ID")
ax.set_ylabel("Cumulative Count")
ax.set_title("Cumulative Hit and Change Count Over Time")
ax.legend()

plt.show()

# trials.columns
# hit_count = trials[trials['hit']].shape[0]
# hit_count

NameError: name 'session2' is not defined

In [18]:

# Lick events recorded for the current `session`; preview the first rows
# (the table carries a timestamp and a frame number per lick).
licks = session.licks
licks.head()

Unnamed: 0,timestamps,frame
0,27.2133,86
1,27.3754,95
2,27.50548,103
3,27.73266,117
4,31.92935,368
