In [21]:
# Import the specific classes and functions we need
from swr_spike_behavior_visualizer import (
    SWRSpikeAnalyzer,
    abi_visual_behavior_units_session_search
)


In [None]:

# Configuration parameters
CACHE_DIR = "/space/scratch/allen_visbehave_data"
OUTPUT_DIR = "/home/acampbell/NeuropixelsLFPOnRamp/Figures_and_Technical_Validation/Relating_SWR_to_other_data/Results"
SWR_INPUT_DIR = "/space/scratch/SWR_final_pipeline/osf_campbellmurphy2025_v2_final"
DATASET_NAME = "allen_visbehave_swr_murphylab2024"
TARGET_REGIONS = ['RSC', 'SUB']
MIN_UNITS_PER_REGION = 100

# Create output directory if it doesn't exist
#os.makedirs(OUTPUT_DIR, exist_ok=True)

# Initialize analyzer
analyzer = SWRSpikeAnalyzer(
    cache_dir=CACHE_DIR,
    swr_input_dir=SWR_INPUT_DIR
)

# Get sessions with good unit counts
search_df = abi_visual_behavior_units_session_search(
    cache_dir=CACHE_DIR,
    target_regions=TARGET_REGIONS,
    min_units_per_region=MIN_UNITS_PER_REGION
)

# Add column to check if session has SWR data
search_df['has_swr_data'] = search_df.index.map(
    lambda x: os.path.exists(os.path.join(SWR_INPUT_DIR, DATASET_NAME, f"swrs_session_{x}"))
)

# Filter for sessions with SWR data
search_df = search_df[search_df['has_swr_data']]

# Display the search results
print("\nSessions with SWR data:")
print(search_df)

# Get the session with the most target region units for testing
target_region = 'SUB'
if target_region in search_df.columns:
    test_session_id = search_df.index[0]  # First row has highest target_region count
    print(f"\nTesting with session {test_session_id} ({target_region} units: {search_df[target_region].iloc[0]})")
    
    # Analyze the test session
    results = analyzer.analyze_session_swr_spikes(
        session_id=test_session_id,
        regions=TARGET_REGIONS,
        window_size=0.05  # 50ms windows
    )
    
    # Display results for each region
    for region, region_results in results.items():
        print(f"\nResults for {region}:")
        print(region_results)
        
        # Create filename with timestamp
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_file = os.path.join(OUTPUT_DIR, f"swr_spike_analysis_{region}_{timestamp}.csv")
        
        # Save to CSV
        #region_results.to_csv(output_file, index=False)
        region_results.head(5)
        

In [2]:
from allensdk.brain_observatory.behavior.behavior_project_cache import VisualBehaviorNeuropixelsProjectCache

# Initialize the cache
cache_dir = "/space/scratch/allen_visbehave_data"
cache = VisualBehaviorNeuropixelsProjectCache.from_s3_cache(cache_dir=cache_dir)

# Try to get a session that doesn't exist
session_id = "1044385384"
try:
    session = cache.get_ecephys_session(session_id)
except Exception as e:
    print(f"Error: {e}")
    print(f"Session ID: {session_id}")
    
    # Let's check what sessions are available
    sessions = cache.get_ecephys_session_table()
    print("\nAvailable session IDs:")
    print(sessions.index.tolist())

Error: The ecephys_session_table should have 1 and only 1 entry for a given ecephys_session_id. No indexed rows found for id=1044385384
Session ID: 1044385384

Available session IDs:
[1044385384, 1044594870, 1047969464, 1047977240, 1048189115, 1048196054, 1049273528, 1049514117, 1051155866, 1052342277, 1052533639, 1053709239, 1053718935, 1053925378, 1053941483, 1055221968, 1055240613, 1055403683, 1055415082, 1062755779, 1063010385, 1064400234, 1064415305, 1064639378, 1064644573, 1065437523, 1065449881, 1065905010, 1065908084, 1067588044, 1067781390, 1069461581, 1071300149, 1081079981, 1081090969, 1081429294, 1081431006, 1086200042, 1086410738, 1087720624, 1087992708, 1089296550, 1090803859, 1091039376, 1091039902, 1092283837, 1092466205, 1093638203, 1093642839, 1093864136, 1093867806, 1095138995, 1095340643, 1096620314, 1096935816, 1098119201, 1099598937, 1099869737, 1104052767, 1104058216, 1104289498, 1104297538, 1105543760, 1105798776, 1108334384, 1108335514, 1108528422, 1108531612, 

In [4]:
cache.get_ecephys_session(session_id)

RuntimeError: The ecephys_session_table should have 1 and only 1 entry for a given ecephys_session_id. No indexed rows found for id=1044385384

In [8]:
ecephys_sessions = cache.get_ecephys_session_table()

print(f"Total number of ecephys sessions: {len(ecephys_sessions)}")

ecephys_sessions.head()

Total number of ecephys sessions: 103


Unnamed: 0_level_0,behavior_session_id,date_of_acquisition,equipment_name,session_type,mouse_id,genotype,sex,project_code,age_in_days,unit_count,...,channel_count,structure_acronyms,image_set,prior_exposures_to_image_set,session_number,experience_level,prior_exposures_to_omissions,file_id,abnormal_histology,abnormal_activity
ecephys_session_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1044385384,1044408432,2020-08-19 14:47:08.574000+00:00,NP.1,EPHYS_1_images_G_5uL_reward,524761,wt/wt,F,NeuropixelVisualBehavior,151,2179,...,1920,"['CA1', 'CA3', 'DG-mo', 'DG-po', 'DG-sg', 'LGv...",G,30,1,Familiar,0,870,,
1044594870,1044624428,2020-08-20 15:03:56.422000+00:00,NP.1,EPHYS_1_images_H_5uL_reward,524761,wt/wt,F,NeuropixelVisualBehavior,152,2103,...,1920,"['CA1', 'CA3', 'DG-mo', 'DG-po', 'DG-sg', 'HPF...",H,0,2,Novel,1,872,,
1047969464,1048005547,2020-09-02 14:53:14.347000+00:00,NP.1,EPHYS_1_images_G_3uL_reward,509808,Sst-IRES-Cre/wt;Ai32(RCL-ChR2(H134R)_EYFP)/wt,M,NeuropixelVisualBehavior,263,2438,...,2304,"['APN', 'CA1', 'CA3', 'DG-mo', 'DG-po', 'DG-sg...",G,62,1,Familiar,0,877,,
1047977240,1048009327,2020-09-02 15:15:03.733000+00:00,NP.0,EPHYS_1_images_G_3uL_reward,524925,Sst-IRES-Cre/wt;Ai32(RCL-ChR2(H134R)_EYFP)/wt,F,NeuropixelVisualBehavior,165,1856,...,2304,"['APN', 'CA1', 'CA3', 'DG-mo', 'DG-po', 'DG-sg...",G,51,1,Familiar,0,878,,
1048189115,1048221709,2020-09-03 14:16:57.913000+00:00,NP.1,EPHYS_1_images_H_3uL_reward,509808,Sst-IRES-Cre/wt;Ai32(RCL-ChR2(H134R)_EYFP)/wt,M,NeuropixelVisualBehavior,264,1925,...,2304,"['APN', 'CA1', 'CA3', 'DG-mo', 'DG-po', 'DG-sg...",H,0,2,Novel,1,879,,


In [13]:
int(session_id)

1044385384

In [15]:
np.isin(int(session_id), ecephys_sessions.index)

array(True)

In [6]:
cache.list_all_downloaded_manifests()

['visual-behavior-neuropixels_project_manifest_v0.5.0.json']

In [5]:
cache.current_manifest()

'visual-behavior-neuropixels_project_manifest_v0.5.0.json'

In [3]:
session = cache.get_ecephys_session(session_id)

RuntimeError: The ecephys_session_table should have 1 and only 1 entry for a given ecephys_session_id. No indexed rows found for id=1044385384

In [1]:
#!/usr/bin/env python3
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from allensdk.brain_observatory.behavior.behavior_project_cache import VisualBehaviorNeuropixelsProjectCache
import os
import sys

from scipy import signal
from scipy.stats import zscore
import seaborn as sns
from scipy.signal import hilbert
import argparse
import json
from datetime import datetime
import logging


  from .autonotebook import tqdm as notebook_tqdm


In [20]:

# =============================================================================
# Configuration Parameters
# =============================================================================
# Cache and data paths
CACHE_DIR = "/space/scratch/allen_visbehave_data"
OUTPUT_DIR = "/home/acampbell/NeuropixelsLFPOnRamp/Figures_and_Technical_Validation/Relating_SWR_to_other_data/Results"
SWR_INPUT_DIR = "/space/scratch/SWR_final_pipeline/osf_campbellmurphy2025_v2_final"  # Directory containing SWR event files

# Dataset configuration
DATASET_NAME = "allen_visbehave_swr_murphylab2024"  # Name of the dataset in SWRExplorer

# Session finding parameters
MIN_UNITS_PER_REGION = 50  # Minimum number of units required in each region
MAX_SPEED_THRESHOLD = 5.0  # Maximum speed during SWR (cm/s)
MIN_PUPIL_DIAMETER = 0.5  # Minimum pupil diameter (arbitrary units)
MAX_PUPIL_DIAMETER = 2.0  # Maximum pupil diameter (arbitrary units)
EVENTS_PER_SESSION = 3  # Number of best events to find per session

# SWR detection parameters
MIN_SW_POWER = 1
MIN_DURATION = 0.05
MAX_DURATION = 0.15
WINDOW_SIZE = 0.2  # Window size for spike correlation (seconds)

# Ripple band power parameters
MIN_RIPPLE_POWER = 5.0  # Minimum ripple band peak power (z-score)
MAX_RIPPLE_POWER = 10.0  # Maximum ripple band peak power (z-score)

# Target regions to analyze
TARGET_REGIONS = ['RSC', 'SUB']


In [17]:
cache = VisualBehaviorNeuropixelsProjectCache.from_s3_cache(cache_dir=CACHE_DIR)
#explorer = SWRExplorer(base_path=SWR_INPUT_DIR)

# Get full unit and channel tables
units = cache.get_unit_table()
channels = cache.get_channel_table()

print("\nInitial data shapes:")
print(f"Units table shape: {units.shape}")
print(f"Channels table shape: {channels.shape}")

# Goo unit counts...
good_units = units[(units['quality'] == 'good') & (units['valid_data'] == True)]
good_unit_counts = good_units.groupby('ecephys_channel_id').size().rename('good_unit_count')
channels = channels.join(good_unit_counts, how='left')
channels['good_unit_count'] = channels['good_unit_count'].fillna(0).astype(int)


channels_in_rois_with_units_and_data_mask = np.isin(channels.structure_acronym, TARGET_REGIONS)
channels_in_rois_with_units_and_data_mask = channels_in_rois_with_units_and_data_mask & channels.valid_data
channels_in_rois_with_units_and_data_mask = channels_in_rois_with_units_and_data_mask
np.sum(channels_in_rois_with_units_and_data_mask)


Initial data shapes:
Units table shape: (319013, 35)
Channels table shape: (347520, 11)


7892

In [18]:
# Step 1: Apply region+valid mask to channels
roi_channels = channels[channels_in_rois_with_units_and_data_mask].copy()

# Step 2: Group by session and region, summing good unit counts
grouped_counts = (
    roi_channels
    .groupby(['ecephys_session_id', 'structure_acronym'])['good_unit_count']
    .sum()
    .unstack(fill_value=0)
)

# Step 3: Filter sessions with any region surpassing the unit count threshold
session_pass_mask = (grouped_counts >= MIN_UNITS_PER_REGION).any(axis=1)
passed_sessions_df = grouped_counts[session_pass_mask].copy()

# Optional: Add total units across all target regions
passed_sessions_df['total_good_units_in_rois'] = passed_sessions_df.sum(axis=1)

# Show result
print(f"\nNumber of sessions passing threshold: {len(passed_sessions_df)}")
display(passed_sessions_df.head())

# Step 1: Apply region+valid mask to channels
roi_channels = channels[channels_in_rois_with_units_and_data_mask].copy()

# Step 2: Group by session and region, summing good unit counts
grouped_counts = (
    roi_channels
    .groupby(['ecephys_session_id', 'structure_acronym'])['good_unit_count']
    .sum()
    .unstack(fill_value=0)
)

# Step 3: Filter sessions with any region surpassing the unit count threshold
session_pass_mask = (grouped_counts >= MIN_UNITS_PER_REGION).any(axis=1)
passed_sessions_df = grouped_counts[session_pass_mask].copy()

# Optional: Add total units across all target regions
passed_sessions_df['total_good_units_in_rois'] = passed_sessions_df.sum(axis=1)

# Show result
print(f"\nNumber of sessions passing threshold: {len(passed_sessions_df)}")
display(passed_sessions_df.head())


Number of sessions passing threshold: 11


structure_acronym,SUB,total_good_units_in_rois
ecephys_session_id,Unnamed: 1_level_1,Unnamed: 2_level_1
1046166369,128,128
1053709239,196,196
1053925378,113,113
1064415305,102,102
1076487758,100,100



Number of sessions passing threshold: 11


structure_acronym,SUB,total_good_units_in_rois
ecephys_session_id,Unnamed: 1_level_1,Unnamed: 2_level_1
1046166369,128,128
1053709239,196,196
1053925378,113,113
1064415305,102,102
1076487758,100,100


In [20]:
grouped_counts

structure_acronym,SUB
ecephys_session_id,Unnamed: 1_level_1
1043752325,86
1044389060,66
1044597824,51
1046166369,128
1046581736,56
...,...
1128520325,95
1130113579,0
1139846596,67
1140102579,1


In [19]:
passed_sessions_df

structure_acronym,SUB,total_good_units_in_rois
ecephys_session_id,Unnamed: 1_level_1,Unnamed: 2_level_1
1046166369,128,128
1053709239,196,196
1053925378,113,113
1064415305,102,102
1076487758,100,100
1084427055,204,204
1084428217,112,112
1086200042,137,137
1093638203,166,166
1093867806,140,140


In [18]:
channels[['unit_count','good_unit_count']]

Unnamed: 0_level_0,unit_count,good_unit_count
ecephys_channel_id,Unnamed: 1_level_1,Unnamed: 2_level_1
1049365509,0,0
1049365511,5,4
1049365512,0,0
1049365513,5,4
1049365514,7,4
...,...,...
1163267723,0,0
1163267724,0,0
1163267725,0,0
1163267726,0,0


In [9]:
print(f'units.columns : {units.columns}')

units.columns : Index(['ecephys_channel_id', 'ecephys_probe_id', 'ecephys_session_id',
       'amplitude_cutoff', 'anterior_posterior_ccf_coordinate',
       'dorsal_ventral_ccf_coordinate', 'left_right_ccf_coordinate',
       'cumulative_drift', 'd_prime', 'structure_acronym', 'structure_id',
       'firing_rate', 'isi_violations', 'isolation_distance', 'l_ratio',
       'local_index', 'max_drift', 'nn_hit_rate', 'nn_miss_rate',
       'presence_ratio', 'probe_horizontal_position',
       'probe_vertical_position', 'silhouette_score', 'snr', 'quality',
       'valid_data', 'amplitude', 'waveform_duration', 'waveform_halfwidth',
       'PT_ratio', 'recovery_slope', 'repolarization_slope', 'spread',
       'velocity_above', 'velocity_below'],
      dtype='object')


In [30]:
#!/usr/bin/env python3
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from allensdk.brain_observatory.behavior.behavior_project_cache import VisualBehaviorNeuropixelsProjectCache
import os
import sys

# Add parent directory to path to import SWRExplorer
try:
    # This works in script context
    parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
except NameError:
    # This works in notebook context
    parent_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.append(parent_dir)

from scipy import signal
from scipy.stats import zscore
import seaborn as sns
from scipy.signal import hilbert
import argparse
import json
from datetime import datetime
import logging

# =============================================================================
# Configuration Parameters
# =============================================================================
# Cache and data paths
CACHE_DIR = "/space/scratch/allen_visbehave_data"
OUTPUT_DIR = "/home/acampbell/NeuropixelsLFPOnRamp/Figures_and_Technical_Validation/Relating_SWR_to_other_data/Results"
SWR_INPUT_DIR = "/space/scratch/SWR_final_pipeline/osf_campbellmurphy2025_v2_final"  # Directory containing SWR event files

# Dataset configuration
DATASET_NAME = "allen_visbehave_swr_murphylab2024"  # Name of the dataset in SWRExplorer

# Session finding parameters
MIN_UNITS_PER_REGION = 100  # Minimum number of units required in each region
MAX_SPEED_THRESHOLD = 5.0  # Maximum speed during SWR (cm/s)
MIN_PUPIL_DIAMETER = 0.5  # Minimum pupil diameter (arbitrary units)
MAX_PUPIL_DIAMETER = 2.0  # Maximum pupil diameter (arbitrary units)
EVENTS_PER_SESSION = 3  # Number of best events to find per session

# SWR detection parameters
MIN_SW_POWER = 1
MIN_DURATION = 0.05
MAX_DURATION = 0.15
WINDOW_SIZE = 0.2  # Window size for spike correlation (seconds)

# Ripple band power parameters
MIN_RIPPLE_POWER = 5.0  # Minimum ripple band peak power (z-score)
MAX_RIPPLE_POWER = 10.0  # Maximum ripple band peak power (z-score)

# Target regions to analyze
TARGET_REGIONS = ['RSC', 'SUB']

from allensdk.brain_observatory.behavior.behavior_project_cache import VisualBehaviorNeuropixelsProjectCache
import numpy as np
import pandas as pd

def abi_visual_behavior_units_session_search(
    cache_dir,
    target_regions,
    min_units_per_region=5
):
    """
    From the Allen Visual Behavior Neuropixels dataset, find sessions with at least
    `min_units_per_region` good units in any of the `target_regions`.

    Parameters
    ----------
    cache_dir : str
        Path to the AllenSDK VisualBehaviorNeuropixelsProjectCache directory.
    target_regions : list of str
        List of structure acronyms to consider as regions of interest.
    min_units_per_region : int, default=5
        Minimum number of good units required in any one region.

    Returns
    -------
    passed_sessions_df : pd.DataFrame
        Session-by-region table with good unit counts, indexed by session ID,
        plus a column with total units across all target regions.
    """

    # Load cache and tables
    cache = VisualBehaviorNeuropixelsProjectCache.from_s3_cache(cache_dir=cache_dir)
    units = cache.get_unit_table()
    channels = cache.get_channel_table()

    print("\nInitial data shapes:")
    print(f"Units table shape: {units.shape}")
    print(f"Channels table shape: {channels.shape}")

    # Filter for good units
    good_units = units[(units['quality'] == 'good') & (units['valid_data'] == True)]

    # Count good units per channel and join to channels
    good_unit_counts = good_units.groupby('ecephys_channel_id').size().rename('good_unit_count')
    channels = channels.join(good_unit_counts, how='left')
    channels['good_unit_count'] = channels['good_unit_count'].fillna(0).astype(int)

    # Filter for target regions with valid data
    region_mask = channels['structure_acronym'].isin(target_regions) & channels['valid_data']
    roi_channels = channels[region_mask].copy()

    # Group by session and region
    grouped_counts = (
        roi_channels
        .groupby(['ecephys_session_id', 'structure_acronym'])['good_unit_count']
        .sum()
        .unstack(fill_value=0)
    )

    # Identify sessions passing threshold
    session_pass_mask = (grouped_counts >= min_units_per_region).any(axis=1)
    passed_sessions_df = grouped_counts[session_pass_mask].copy()

    # Add total across all target regions
    passed_sessions_df['total_good_units_in_rois'] = passed_sessions_df.sum(axis=1)

    print(f"\nNumber of sessions passing threshold: {len(passed_sessions_df)}")
    return passed_sessions_df

class SWRSpikeAnalyzer:
    def __init__(self, cache_dir, swr_input_dir, dataset_name="allen_visbehave_swr_murphylab2024"):
        """
        Initialize the SWRSpikeAnalyzer.
        
        Parameters:
        -----------
        cache_dir : str
            Path to the AllenSDK cache directory
        swr_input_dir : str
            Path to the directory containing SWR event files
        dataset_name : str
            Name of the dataset in SWRExplorer
        """
        self.cache_dir = cache_dir
        self.swr_input_dir = swr_input_dir
        self.dataset_name = dataset_name
        self.cache = VisualBehaviorNeuropixelsProjectCache.from_s3_cache(cache_dir=cache_dir)
        self.explorer = SWRExplorer(base_path=swr_input_dir)
        
    def get_region_units_and_spikes(self, session_id, region):
        """
        Get units and spike times for a specific region in a session.
        
        Parameters:
        -----------
        session_id : str
            Session ID
        region : str
            Region acronym (e.g., 'RSC', 'SUB')
            
        Returns:
        --------
        tuple
            (region_units_df, spike_times_dict)
        """
        # Get good units
        units = self.cache.get_unit_table()
        channels = self.cache.get_channel_table()
        good_units = units[(units['quality'] == 'good') & (units['valid_data'] == True)]
        
        # Merge with channel info to get structure
        units_with_structure = good_units.merge(
            channels[['structure_acronym', 'probe_vertical_position']], 
            left_on='ecephys_channel_id', 
            right_index=True
        )
        
        # Filter for target region and session
        region_units = units_with_structure[
            (units_with_structure.structure_acronym == region) & 
            (units_with_structure.ecephys_session_id == session_id)
        ]
        
        # Get spike times for these units
        session = self.cache.get_ecephys_session(session_id)
        spike_times = {unit_id: session.spike_times[unit_id] for unit_id in region_units.index}
        
        return region_units, spike_times
        
    def calculate_firing_rate_changes(self, spike_times, swr_events, window_size=0.05):
        """
        Calculate firing rate changes around SWR events.
        
        Parameters:
        -----------
        spike_times : dict
            Dictionary mapping unit IDs to spike time arrays
        swr_events : pd.DataFrame
            DataFrame containing SWR events
        window_size : float
            Size of time windows in seconds
            
        Returns:
        --------
        pd.DataFrame
            DataFrame with firing rate change statistics for each unit
        """
        results = []
        
        for unit_id, times in spike_times.items():
            unit_results = {'unit_id': unit_id}
            
            # Calculate firing rates for different windows
            for _, event in swr_events.iterrows():
                # Before peak window
                before_mask = (times >= event['start_time']) & (times < event['peak_time'])
                before_rate = np.sum(before_mask) / (event['peak_time'] - event['start_time'])
                
                # After peak window
                after_mask = (times > event['peak_time']) & (times <= event['end_time'])
                after_rate = np.sum(after_mask) / (event['end_time'] - event['peak_time'])
                
                # During event vs baseline
                during_mask = (times >= event['start_time']) & (times <= event['end_time'])
                during_rate = np.sum(during_mask) / (event['end_time'] - event['start_time'])
                
                # Baseline window (50ms before event)
                baseline_start = event['start_time'] - window_size
                baseline_mask = (times >= baseline_start) & (times < event['start_time'])
                baseline_rate = np.sum(baseline_mask) / window_size
                
                # Store rates
                unit_results.setdefault('before_rates', []).append(before_rate)
                unit_results.setdefault('after_rates', []).append(after_rate)
                unit_results.setdefault('during_rates', []).append(during_rate)
                unit_results.setdefault('baseline_rates', []).append(baseline_rate)
            
            # Calculate t-tests
            from scipy import stats
            from statsmodels.stats.multitest import multipletests
            
            # Before vs After peak
            t_stat, p_val = stats.ttest_rel(unit_results['before_rates'], unit_results['after_rates'])
            unit_results['before_vs_after'] = (np.mean(unit_results['after_rates']) - np.mean(unit_results['before_rates']), p_val)
            
            # During vs Baseline
            t_stat, p_val = stats.ttest_rel(unit_results['during_rates'], unit_results['baseline_rates'])
            unit_results['during_vs_baseline'] = (np.mean(unit_results['during_rates']) - np.mean(unit_results['baseline_rates']), p_val)
            
            results.append(unit_results)
        
        # Convert to DataFrame
        results_df = pd.DataFrame(results)
        
        # Apply Benjamini-Hochberg correction
        for test in ['before_vs_after', 'during_vs_baseline']:
            p_values = [x[1] for x in results_df[test]]
            _, p_adjusted, _, _ = multipletests(p_values, method='fdr_bh')
            results_df[f'{test}_corrected'] = [(x[0], p) for x, p in zip(results_df[test], p_adjusted)]
        
        return results_df
        
    def analyze_session_swr_spikes(self, session_id, regions, window_size=0.05):
        """
        Analyze firing rate changes around SWR events for multiple regions in a session.
        
        Parameters:
        -----------
        session_id : str
            Session ID
        regions : list
            List of region acronyms to analyze
        window_size : float
            Size of time windows in seconds
            
        Returns:
        --------
        dict
            Dictionary mapping regions to analysis results DataFrames
        """
        results = {}
        
        # Get SWR events for the session
        swr_events = self.explorer.find_best_events(
            dataset=self.dataset_name,
            session_id=session_id,
            probe_id=None,  # Will check all probes
            min_sw_power=MIN_SW_POWER,
            min_duration=MIN_DURATION,
            max_duration=MAX_DURATION,
            min_clcorr=0.8,
            exclude_gamma=True,
            exclude_movement=True
        )
        
        if len(swr_events) == 0:
            print(f"No SWR events found for session {session_id}")
            return results
            
        # Analyze each region
        for region in regions:
            region_units, spike_times = self.get_region_units_and_spikes(session_id, region)
            if len(spike_times) == 0:
                print(f"No units found in {region} for session {session_id}")
                continue
                
            results[region] = self.calculate_firing_rate_changes(spike_times, swr_events, window_size)
            
        return results


In [31]:
analyzer = SWRSpikeAnalyzer(
    cache_dir=CACHE_DIR,
    swr_input_dir=SWR_INPUT_DIR
)

# Analyze a single session
results = analyzer.analyze_session_swr_spikes(
    session_id="1044385384",
    regions=['RSC', 'SUB'],
    window_size=0.05  # 50ms windows
)

# Access results for a specific region
rsc_results = results['RSC']

NameError: name 'SWRExplorer' is not defined

In [9]:
units.head(2)

Unnamed: 0_level_0,ecephys_channel_id,ecephys_probe_id,ecephys_session_id,amplitude_cutoff,anterior_posterior_ccf_coordinate,dorsal_ventral_ccf_coordinate,left_right_ccf_coordinate,cumulative_drift,d_prime,structure_acronym,...,valid_data,amplitude,waveform_duration,waveform_halfwidth,PT_ratio,recovery_slope,repolarization_slope,spread,velocity_above,velocity_below
unit_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1157005856,1157001834,1046469925,1046166369,0.5,8453.0,3353.0,6719.0,140.32,6.088133,MB,...,True,143.066332,0.151089,0.096147,0.310791,-0.113863,0.480656,20.0,-0.457845,
1157005853,1157001834,1046469925,1046166369,0.323927,8453.0,3353.0,6719.0,239.76,4.635583,MB,...,True,90.709418,0.357119,0.192295,0.53149,-0.075261,0.366371,30.0,2.060302,-2.060302


In [25]:


session_id = 1044385384
session = cache.get_ecephys_session(
            ecephys_session_id=session_id)

core - cached version: 2.6.0-alpha, loaded version: 2.7.0
  self.warn_for_ignored_namespaces(ignored_namespaces)


In [27]:
session.spike_times

{1049374910: array([8.13544647e+00, 9.07790753e+00, 1.46260412e+01, ...,
        9.74317349e+03, 9.74346252e+03, 9.74406818e+03]),
 1049374988: array([  24.91088005,  135.38955639,  150.93656394, ..., 9739.97497389,
        9742.93272297, 9744.43748069]),
 1049374987: array([5.38499616e+00, 6.86205404e+00, 7.84468153e+00, ...,
        9.74658037e+03, 9.74674443e+03, 9.74677057e+03]),
 1049374986: array([  14.63547448,   43.67156848,   45.07339348, ..., 9727.90681233,
        9732.87331613, 9745.31140883]),
 1049374985: array([ 974.74623164,  998.16219239, 1043.76255455, ..., 8157.38915176,
        8316.40510614, 8877.03013892]),
 1049374984: array([  57.93185035,   93.23850705,  324.52629832, ..., 9747.66542816,
        9747.85126039, 9747.93149324]),
 1049374983: array([7.44198393e+00, 1.37272132e+01, 1.54345364e+01, ...,
        9.74512741e+03, 9.74743226e+03, 9.74763953e+03]),
 1049373961: array([1937.33114074, 1937.37430715, 1937.48257317, ..., 9747.89259348,
        9747.89616012,

In [None]:
session.spike_times

In [10]:
print(f'channels.columns : {channels.columns}')

channels.columns : Index(['ecephys_probe_id', 'ecephys_session_id', 'probe_channel_number',
       'probe_vertical_position', 'probe_horizontal_position',
       'anterior_posterior_ccf_coordinate', 'dorsal_ventral_ccf_coordinate',
       'left_right_ccf_coordinate', 'structure_acronym', 'unit_count',
       'valid_data'],
      dtype='object')


In [10]:
channels.head(2)

Unnamed: 0_level_0,ecephys_probe_id,ecephys_session_id,probe_channel_number,probe_vertical_position,probe_horizontal_position,anterior_posterior_ccf_coordinate,dorsal_ventral_ccf_coordinate,left_right_ccf_coordinate,structure_acronym,unit_count,valid_data
ecephys_channel_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
1049365509,1048089911,1047969464,0,20.0,43.0,8445.0,4013.0,6753.0,MRN,0,True
1049365511,1048089911,1047969464,1,20.0,11.0,8443.0,4005.0,6755.0,MRN,5,True


In [None]:

# =============================================================================
# SWRSpikeBehaviorVisualizer Class
# =============================================================================
class SWRSpikeBehaviorVisualizer:
    def __init__(self, cache_dir=CACHE_DIR, swr_input_dir=SWR_INPUT_DIR, dataset_name=DATASET_NAME):
        """Initialize the visualizer with cache directory and SWR input directory."""
        self.cache = VisualBehaviorNeuropixelsProjectCache.from_s3_cache(cache_dir=cache_dir)
        self.explorer = SWRExplorer(base_path=swr_input_dir)
        self.swr_input_dir = swr_input_dir
        self.dataset_name = dataset_name
        
    def verify_regions(self, session_id, target_regions):
        """Verify if session has units in target regions."""
        print(f"\n=== Verifying regions for session {session_id} ===")
        
        # Get full unit and channel tables
        units = self.cache.get_unit_table()
        channels = self.cache.get_channel_table()
        
        print("\nInitial data shapes:")
        print(f"Units table shape: {units.shape}")
        print(f"Channels table shape: {channels.shape}")
        
        print("\nUnits columns:")
        print(units.columns.tolist())
        print("\nChannels columns:")
        print(channels.columns.tolist())
        
        # Filter for good quality units
        good_units = units[units['quality'] == 'good']
        print(f"\nNumber of good quality units: {len(good_units)}")
        
        # Join units with channels to get structure information
        units_with_structure = good_units.merge(
            channels[['structure_acronym', 'probe_vertical_position']], 
            left_on='ecephys_channel_id', 
            right_index=True,
            how='left'
        )
        print(f"\nShape after joining with channels: {units_with_structure.shape}")
        
        # Filter for target regions
        region_units = units_with_structure[units_with_structure.structure_acronym.isin(target_regions)]
        print(f"\nNumber of units in target regions: {len(region_units)}")
        
        # Group by session and count units per region
        session_region_counts = region_units.groupby(['ecephys_session_id', 'structure_acronym']).size().unstack(fill_value=0)
        print("\nUnits per region per session:")
        print(session_region_counts)
        
        # Find sessions that have enough units in each target region
        valid_sessions = session_region_counts[
            (session_region_counts >= MIN_UNITS_PER_REGION).all(axis=1)
        ].index.tolist()
        
        print(f"\nSessions with enough units in all target regions: {len(valid_sessions)}")
        return valid_sessions
    
    def get_region_units(self, session_id, region):
        """Get units from a specific region."""
        print(f"\n=== Getting units for region {region} in session {session_id} ===")
        
        # Get full unit and channel tables
        units = self.cache.get_unit_table()
        channels = self.cache.get_channel_table()
        
        # Filter for good quality units
        good_units = units[units['quality'] == 'good']
        
        # Join units with channels to get structure information
        units_with_structure = good_units.merge(
            channels[['structure_acronym', 'probe_vertical_position']], 
            left_on='ecephys_channel_id', 
            right_index=True
        )
        
        # Filter for target region and session
        region_units = units_with_structure[
            (units_with_structure.structure_acronym == region) & 
            (units_with_structure.ecephys_session_id == session_id)
        ]
        
        print(f"\nNumber of units in {region} for session {session_id}: {len(region_units)}")
        
        # Get spike times for these units
        session = self.cache.get_ecephys_session(session_id)
        spike_times = {}
        for unit_id in region_units.index:
            spike_times[unit_id] = session.spike_times[unit_id]
        
        print(f"Number of units with spike times: {len(spike_times)}")
        return region_units, spike_times
    
    def calculate_swr_correlations(self, spike_times, swr_events, window=WINDOW_SIZE):
        """Calculate correlation between unit spiking and SWR events."""
        correlations = {}
        
        for unit_id, times in spike_times.items():
            # Create binary spike train
            spike_train = np.zeros(len(swr_events))
            
            for i, (_, event) in enumerate(swr_events.iterrows()):
                # Count spikes in window around event
                event_time = (event['start_time'] + event['end_time']) / 2
                spikes_in_window = np.sum((times >= event_time - window) & 
                                        (times <= event_time + window))
                spike_train[i] = spikes_in_window
                
            # Calculate correlation with SWR power
            if 'sw_peak_power' in swr_events.columns:
                correlation = np.corrcoef(spike_train, swr_events['sw_peak_power'])[0,1]
                correlations[unit_id] = correlation
                
        return correlations
    
    def get_top_correlated_units(self, correlations, percentile=90):
        """Get top percentile of units based on correlation values."""
        if not correlations:
            return []
            
        # Convert correlations to numpy array and remove NaN values
        corr_values = np.array(list(correlations.values()))
        valid_mask = ~np.isnan(corr_values)
        valid_correlations = corr_values[valid_mask]
        
        if len(valid_correlations) == 0:
            return []
            
        # Calculate threshold for top percentile
        threshold = np.percentile(valid_correlations, percentile)
        
        # Get unit IDs for top correlated units
        top_units = [unit_id for unit_id, corr in correlations.items() 
                    if not np.isnan(corr) and corr >= threshold]
        
        return top_units

    def print_correlation_stats(self, correlations, region):
        """Print summary statistics of correlations for a region."""
        if not correlations:
            print(f"\nNo correlations calculated for {region}")
            return
            
        corr_values = np.array(list(correlations.values()))
        valid_correlations = corr_values[~np.isnan(corr_values)]
        
        if len(valid_correlations) == 0:
            print(f"\nNo valid correlations for {region}")
            return
            
        print(f"\nCorrelation statistics for {region}:")
        print(f"Number of units: {len(valid_correlations)}")
        print(f"Mean correlation: {np.mean(valid_correlations):.3f}")
        print(f"Median correlation: {np.median(valid_correlations):.3f}")
        print(f"Min correlation: {np.min(valid_correlations):.3f}")
        print(f"Max correlation: {np.max(valid_correlations):.3f}")
        print(f"25th percentile: {np.percentile(valid_correlations, 25):.3f}")
        print(f"75th percentile: {np.percentile(valid_correlations, 75):.3f}")
    
    def find_best_sessions(self, target_regions, min_events=10, min_units=5, min_sw_power=1.5, 
                          min_duration=0.08, max_duration=0.1, min_clcorr=0.8, exclude_gamma=True, 
                          exclude_movement=True, max_speed=2.0, min_pupil_diameter=0.5, verbose_debugging=False):
        """
        Find sessions with good quality data and events.
        
        Parameters:
        -----------
        target_regions : list
            List of target brain regions
        min_events : int
            Minimum number of events required
        min_units : int
            Minimum number of units required per region
        min_sw_power : float
            Minimum sharp wave power threshold
        min_duration : float
            Minimum event duration in seconds
        max_duration : float
            Maximum event duration in seconds
        min_clcorr : float
            Minimum circular-linear correlation
        exclude_gamma : bool
            Whether to exclude events overlapping with gamma
        exclude_movement : bool
            Whether to exclude events overlapping with movement
        max_speed : float
            Maximum running speed during events
        min_pupil_diameter : float
            Minimum pupil diameter during events
        verbose_debugging : bool
            Whether to print detailed debug information
            
        Returns:
        --------
        list
            List of session IDs with good quality data
        """
        # Get full unit and channel tables from cache
        units = self.cache.get_unit_table()
        channels = self.cache.get_channel_table()
        
        # Filter for good quality units
        good_units = units[units['quality'] == 'good'].copy()
        
        # Join with channels to get structure information
        good_units = good_units.merge(
            channels[['ecephys_channel_id', 'structure_acronym']], 
            left_on='peak_channel_id',
            right_on='ecephys_channel_id',
            how='left'
        )
        
        # Get unique sessions
        sessions = good_units['ecephys_session_id'].unique()
        
        good_sessions = []
        all_events = []  # Store all events for summary statistics
        
        for session_id in sessions:
            # Get units for this session
            session_units = good_units[good_units['ecephys_session_id'] == session_id]
            
            # Count units per region
            region_counts = session_units['structure_acronym'].value_counts()
            
            # Check if we have enough units in target regions
            has_enough_units = all(region_counts.get(region, 0) >= min_units for region in target_regions)
            
            if has_enough_units:
                # Get events for this session
                events = self.explorer.find_best_events(
                    'allen_visbehave_swr_murphylab2024',
                    str(session_id),
                    'probeA',  # We'll check all probes
                    min_sw_power=min_sw_power,
                    min_duration=min_duration,
                    max_duration=max_duration,
                    min_clcorr=min_clcorr,
                    exclude_gamma=exclude_gamma,
                    exclude_movement=exclude_movement,
                    max_speed=max_speed,
                    min_pupil_diameter=min_pupil_diameter
                )
                
                if len(events) >= min_events:
                    good_sessions.append(session_id)
                    all_events.append(events)
        
        # Print summary statistics for all events
        if all_events:
            all_events_df = pd.concat(all_events)
            print("\nSummary statistics for all SWR events across good sessions:")
            print("\nDuration and power metrics:")
            print(all_events_df[['duration', 'sw_peak_power', 'ripple_peak_power', 'sw_ripple_clcorr']].describe())
            
            print("\nEvent characteristics (counts):")
            for col in ['overlaps_with_gamma', 'overlaps_with_movement', 'is_global']:
                if col in all_events_df.columns:
                    print(f"\n{col}:")
                    print(all_events_df[col].value_counts())
        
        return good_sessions
    
    def plot_swr_with_spikes(self, session_id, swr_event_idx, region_units, spike_times, 
                            running_speed, pupil_diameter, window=WINDOW_SIZE):
        """Plot SWR event with spiking activity, running speed, and pupil diameter."""
        # Get the SWR event
        swr_events = self.explorer.find_best_events(
            dataset=self.dataset_name,
            session_id=session_id,
            probe_id=None,
            min_sw_power=MIN_SW_POWER,
            min_duration=MIN_DURATION,
            max_duration=MAX_DURATION,
            min_clcorr=MIN_CLCORR,
            exclude_gamma=True,
            exclude_movement=True
        )
        
        if len(swr_events) == 0:
            print("No SWR events found")
            return
            
        event = swr_events.iloc[swr_event_idx]
        event_time = (event['start_time'] + event['end_time']) / 2
        
        # Calculate correlations for all units
        correlations = self.calculate_swr_correlations(spike_times, swr_events)
        
        # Get top correlated units
        top_units = self.get_top_correlated_units(correlations, percentile=90)
        
        # Sort units by depth (probe_vertical_position)
        sorted_units = sorted(
            [(unit_id, region_units.loc[unit_id, 'probe_vertical_position']) 
             for unit_id in top_units],
            key=lambda x: x[1]
        )
        
        # Create figure with subplots
        fig = plt.figure(figsize=(15, 10))
        gs = fig.add_gridspec(4, 1, height_ratios=[3, 1, 1, 1])
        
        # Plot 1: Raster plot of spikes
        ax1 = fig.add_subplot(gs[0])
        for i, (unit_id, _) in enumerate(sorted_units):
            # Get spikes in window
            times = spike_times[unit_id]
            mask = (times >= event_time - window) & (times <= event_time + window)
            spike_times_window = times[mask]
            
            # Plot spikes
            ax1.vlines(spike_times_window - event_time, i, i+1, 
                      color='k', alpha=0.5)
            
            # Add correlation value as text
            corr = correlations[unit_id]
            ax1.text(window + 0.01, i + 0.5, f'r={corr:.2f}', 
                    fontsize=8, verticalalignment='center')
            
        ax1.set_ylabel('Unit ID (sorted by depth)')
        ax1.set_title(f'SWR Event at {event_time:.2f}s\nTop {len(top_units)} correlated units')
        
        # Plot 2: Running speed
        ax2 = fig.add_subplot(gs[1], sharex=ax1)
        speed_mask = (running_speed['timestamps'] >= event_time - window) & \
                    (running_speed['timestamps'] <= event_time + window)
        ax2.plot(running_speed['timestamps'][speed_mask] - event_time,
                running_speed['speed'][speed_mask])
        ax2.set_ylabel('Speed (cm/s)')
        
        # Plot 3: Pupil diameter
        ax3 = fig.add_subplot(gs[2], sharex=ax1)
        pupil_mask = (pupil_diameter['timestamps'] >= event_time - window) & \
                    (pupil_diameter['timestamps'] <= event_time + window)
        ax3.plot(pupil_diameter['timestamps'][pupil_mask] - event_time,
                pupil_diameter['diameter'][pupil_mask])
        ax3.set_ylabel('Pupil Diameter')
        
        # Plot 4: SWR power
        ax4 = fig.add_subplot(gs[3], sharex=ax1)
        if 'sw_peak_power' in event:
            ax4.axvline(x=0, color='r', linestyle='--', alpha=0.5)
            ax4.set_xlabel('Time from SWR (s)')
            ax4.set_ylabel('SWR Power')
        
        plt.tight_layout()
        return fig

def find_best_sessions(swr_explorer, target_regions, min_units=5, verbose_debugging=False):
    """
    Find sessions with good quality data in target regions.
    
    Parameters:
    -----------
    swr_explorer : SWRExplorer
        Initialized SWRExplorer instance
    target_regions : list
        List of target brain regions
    min_units : int
        Minimum number of units required per region
    verbose_debugging : bool
        Whether to print detailed debug information
        
    Returns:
    --------
    list
        List of session IDs with good quality data
    """
    # Get channel and unit tables
    channels = swr_explorer.allensdk_cache.get_channel_table()
    units = swr_explorer.allensdk_cache.get_unit_table()
    
    print("\n=== Initial Data ===")
    print(f"Channels table shape: {channels.shape}")
    print(f"Units table shape: {units.shape}")
    
    # Filter channels for valid data and target regions
    valid_channels = channels[
        (channels['valid_data'] == True) & 
        (channels['structure_acronym'].isin(target_regions))
    ]
    
    print("\n=== Valid Channels ===")
    print(f"Number of valid channels in target regions: {len(valid_channels)}")
    print("\nChannels per region:")
    print(valid_channels['structure_acronym'].value_counts())
    
    # Get good quality units
    good_units = units[units['quality'] == 'good'].copy()
    print(f"\nNumber of good quality units: {len(good_units)}")
    
    # Join good units with valid channels
    units_with_channels = good_units.merge(
        valid_channels[['ecephys_channel_id', 'ecephys_session_id', 'structure_acronym']],
        left_on='peak_channel_id',
        right_on='ecephys_channel_id',
        how='inner'
    )
    
    print("\n=== Units in Target Regions ===")
    print(f"Number of good units in target regions: {len(units_with_channels)}")
    print("\nUnits per region:")
    print(units_with_channels['structure_acronym'].value_counts())
    
    # Count units per region per session
    session_region_counts = units_with_channels.groupby(['ecephys_session_id', 'structure_acronym']).size().unstack(fill_value=0)
    
    print("\n=== Session Analysis ===")
    print("\nUnits per region per session:")
    print(session_region_counts)
    
    # Find sessions with enough units in each target region
    good_sessions = session_region_counts[
        (session_region_counts >= min_units).all(axis=1)
    ].index.tolist()
    
    print(f"\nFound {len(good_sessions)} sessions with enough units in all target regions")
    print("Session IDs:", good_sessions)
    
    return good_sessions


In [None]:
print("\nUnits columns:")
print(units.columns.tolist())
print("\nChannels columns:")
print(channels.columns.tolist())

# Filter for good quality units
good_units = units[units['quality'] == 'good']
print(f"\nNumber of good quality units: {len(good_units)}")

# Join units with channels to get structure information
units_with_structure = good_units.merge(
    channels[['structure_acronym', 'probe_vertical_position']], 
    left_on='ecephys_channel_id', 
    right_index=True,
    how='left'
)
print(f"\nShape after joining with channels: {units_with_structure.shape}")

# Filter for target regions
region_units = units_with_structure[units_with_structure.structure_acronym.isin(target_regions)]
print(f"\nNumber of units in target regions: {len(region_units)}")

# Group by session and count units per region
session_region_counts = region_units.groupby(['ecephys_session_id', 'structure_acronym']).size().unstack(fill_value=0)
print("\nUnits per region per session:")
print(session_region_counts)
