Used to analyze photometry data when behavior is analyzed with Ethovision or events are flagged using the Synapse Notes feature.
Curve fit and motion correction are based on Simpson et al. 2024 (https://doi.org/10.1016/j.neuron.2023.11.016). Cursor AI software was used to assist with writing parts of this code.

In [None]:
import numpy as np
from sklearn.metrics import auc
import matplotlib.pyplot as plt  
import scipy.stats as stats

import tdt
import statsmodels.api as sm
import pandas as pd
import os
import pylab as plt
from scipy.signal import medfilt, butter, filtfilt
from scipy.stats import linregress
from scipy.optimize import curve_fit, minimize

### Import Data


Variables

In [None]:
# Experiment type; selects which entry of EXPERIMENT_CONFIGS / ETHO_COLUMN_CONFIGS is used below.
# If you add a new experiment type, you will need to update EXPERIMENT_CONFIGS and ETHO_COLUMN_CONFIGS
experiment_type = 'robobug_manual'

# Zones per experiment type (i.e. Ethovision-tracked zones, or anything marked as 1 or 0 in first tab of etho output).
# An empty list means no zone columns are extracted for that experiment type.
# To use the pattern matching in column_config, you need to have the zone name in the column name.
EXPERIMENT_CONFIGS = {
    'open_field': ['Edge', 'Middle', 'Center'],
    'ezm': ['Open', 'Closed'],
    'synapse_notes_only': [],
    'robobug_manual': []
}

# Column-name templates for the Ethovision export. '{zone}' is filled with each zone name;
# '{arena_prefix}' is filled from 'arena_mapping' using the sheet's arena name.
# 'synapse_notes_only' has no entry here; the Ethovision cells skip that experiment type.
ETHO_COLUMN_CONFIGS = {  # These reflect the column names in the etho output. They may change if you change Ethovision settings
    'open_field': {
        'pattern': 'In zone({arena_prefix}{zone} / Center-point)',
        'arena_mapping': {
            'Arena 1': 'A1',
            'Arena 2': 'A2'
        }
    },
    'ezm': {
        'pattern': 'In zone ({zone})' 
    },
    'robobug_manual':{
        'pattern': '{zone}'
    }
}

# Assign each mouse to a group, quotes around ID number.
# Replace the placeholders below with real IDs: every key must be unique, because
# repeated keys in a dict literal silently collapse to the last one.
group_id = {
    'mouseID':'groupname',
    'mouseID':'groupname'
    }

# Unique group names (order is not guaranteed because it goes through a set)
group_ids = list(set(group_id.values()))

# Folder containing 1 day of photometry data from multiple animals (one sub-folder per recording)
data_directory = './Data'

# Specify where to save the output. A timestamped subfolder will be created in this folder to hold all output
base_output_dir = './Output'

# Folder containing exported data from Ethovision. If you don't have Ethovision data, set this to 'no etho data'
etho_directory = './Ethovision'

# Settings for trimming data
t_start = 8 # time threshold below which we will discard (in seconds)
t_end_target = 1200 # if video is shorter than this, it will not be trimmed (in seconds)

# Downsample rate
N = 100 # Average every N samples into 1 value

# NOTE: despite the name, this is the video sampling INTERVAL in seconds per frame
# (0.05 s per frame = 20 Hz). The transition finders divide min_duration by this value
# to get a number of frames.
video_sampling_rate=0.05

# Settings for finding zone and behavior transitions
min_duration=1 # seconds they must be in the zone to count as a transition
# will also use video_sampling_rate defined above

# Set peri-event window
PRE_TIME = 15 #  seconds before event to include
POST_TIME = 15 #  seconds after event to include

# Set time range for z score baseline and dFF baseline
# (presumably seconds relative to the event; confirm against the peri-event cells)
base_start= -15
base_end= -11

# Set time range for AUC calculations
# (presumably seconds relative to the event; confirm against the AUC cells)
AUC_start= 0
AUC_end= 10

In [None]:
# Build the path of a timestamped subfolder (analysis_YYYYMMDD_HHMMSS) under base_output_dir.
# NOTE(review): this cell only builds and prints the path; it does not create the
# directory. Presumably it is created where output is first written - confirm.
from datetime import datetime
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
data_output = os.path.join(base_output_dir, f"analysis_{timestamp}")
print(f"Output will be saved to: {data_output}")

In [None]:
# Every sub-folder of data_directory is treated as one TDT recording block
folderlist = [
    name
    for name in os.listdir(data_directory)
    if os.path.isdir(os.path.join(data_directory, name))
]

# Parse the metadata encoded in each folder name and read the block.
# Builds 'masterdat', a list of dicts (one per recording folder) with the keys
# 'mouse1', 'mouse2', 'date', 'blockpath' and 'data'.
masterdat = []
for foldername in folderlist:
    # An 18-character name holds a single mouse ID; any other length is parsed
    # as two mouse IDs followed by the date.
    if len(foldername) == 18:
        mouse2_id = '0000'  # placeholder meaning "no second mouse"
        recording_date = foldername[5:11]
    else:
        mouse2_id = foldername[5:9]
        recording_date = foldername[10:16]

    dat = {
        'mouse1': foldername[0:4],
        'mouse2': mouse2_id,
        'date': recording_date,
        'blockpath': foldername,  # tank folder, relative to data_directory
    }

    # Read the whole TDT block for this recording
    dat['data'] = tdt.read_block(os.path.join(data_directory, dat['blockpath']))

    masterdat.append(dat)

### Pulling data and TTLs for each mouse

In [None]:
# The "_470"/"_405" stream names and the epoc names may need to be changed if
# recording from a different photometry setup.

alldat = []

# (masterdat key holding the mouse ID, suffix of that mouse's stream names)
mouse_channels = (('mouse1', 'A'), ('mouse2', 'B'))

for recording in masterdat:
    block = recording['data']

    for mouse_key, suffix in mouse_channels:
        # '0000' marks an unused slot in this recording
        if recording[mouse_key] == '0000':
            continue

        signal_stream = getattr(block.streams, f'_470{suffix}')
        control_stream = getattr(block.streams, f'_405{suffix}')

        mouse = {
            'mouseID': recording[mouse_key],
            'green': signal_stream.data,
            'isos': control_stream.data,
            'sampling_rate': signal_stream.fs,
        }
        # Time of each photometry sample, counted from the first sample at 1/fs
        mouse['photom_time'] = (np.arange(1, len(mouse['green']) + 1)) / mouse['sampling_rate']
        # Camera frame onsets are shared by both mice of a recording
        mouse['video_time'] = block.epocs.Cam1.onset
        if experiment_type == 'synapse_notes_only':
            mouse['ttl_notes'] = block.epocs.Note.data.astype(int)
            mouse['ttl_notes_time'] = block.epocs.Note.onset
        alldat.append(mouse)

# Look up each mouse's group by its ID (a mouse missing from group_id raises KeyError)

for mouse in alldat:
    mouse['group'] = group_id[mouse['mouseID']]

# Cut green, isos and photom_time to a common length.
# photom_time was built from the length of 'green', so it has to be shortened
# together with the signals; otherwise the time vector ends up longer than the
# signals after trimming/downsampling and the curve fit fails on a shape mismatch.

for entry in alldat:
    n_common = min(len(entry['green']), len(entry['isos']))
    for key in ('green', 'isos', 'photom_time'):
        # Slicing to n_common is a no-op for arrays that already have that length
        entry[key] = entry[key][:n_common]

### Importing Ethovision Data

In [None]:
# Read every Ethovision export (.xlsx) in etho_directory.
# No Ethovision files are read when using Synapse Notes only.


def _find_header_value(df, label):
    """Return the column-B value (as str) of the first row whose column A equals label.

    Returns None when no row carries that label.
    """
    rows = df[df[0] == label]
    if rows.empty:
        return None
    return str(rows.iloc[0, 1])


def _parse_etho_sheet(df):
    """Split one raw, header-less Ethovision sheet into (eartag, arena, trial_df).

    The animal is identified by the "Eartag" row, falling back to "Animal ID".
    The data table starts at the "Trial time" row, which becomes the column
    header; the row directly below it is dropped (presumably the units row).

    Returns None when the sheet lacks an animal ID, an "Arena name" row or a
    "Trial time" row.
    """
    eartag = _find_header_value(df, "Eartag")
    if eartag is None:
        eartag = _find_header_value(df, "Animal ID")
    if eartag is None:
        return None

    arena = _find_header_value(df, "Arena name")
    if arena is None:
        return None

    trial_time_row = df[df[0] == "Trial time"]
    if trial_time_row.empty:
        return None
    trial_start_index = trial_time_row.index[0]

    # Drop the row after "Trial time", if there is one
    if trial_start_index + 1 < len(df):
        df = df.drop(index=trial_start_index + 1)

    # Keep everything from the "Trial time" row onward and promote that row to header
    trial_df = df.iloc[trial_start_index:].reset_index(drop=True)
    trial_df.columns = trial_df.iloc[0]
    trial_df = trial_df[1:].reset_index(drop=True)

    return eartag, arena, trial_df


print(f"Current experiment_type is: '{experiment_type}'")
if experiment_type == 'synapse_notes_only':
    print('No Etho files for synapse_notes_only')

# The rest of the block runs for all other experiment types
else:
    folder_path = etho_directory
    arenas = {}              # eartag -> arena name
    etho_output = {}         # eartag -> automated scoring DataFrame
    etho_output_manual = {}  # eartag -> manual scoring DataFrame (sheet name contains "Manual")

    if folder_path == 'no etho data':
        # Documented way of saying there is no Ethovision export; leave the dicts empty
        print("etho_directory is set to 'no etho data'; skipping Ethovision import")
    else:
        for file in os.listdir(folder_path):
            # Skip non-Excel files and the "~$..." lock files Excel creates while a workbook is open
            if not file.endswith(".xlsx") or file.startswith("~$"):
                continue

            file_path = os.path.join(folder_path, file)
            # The context manager closes the workbook even if a sheet fails to parse
            with pd.ExcelFile(file_path) as xls:
                for sheet_name in xls.sheet_names:
                    df = pd.read_excel(xls, sheet_name=sheet_name, header=None)

                    parsed = _parse_etho_sheet(df)
                    if parsed is None:
                        continue  # sheet is not a per-animal data sheet
                    eartag, arena, trial_df = parsed

                    if "Manual" in sheet_name:
                        etho_output_manual[eartag] = trial_df
                    else:
                        etho_output[eartag] = trial_df

                    arenas[eartag] = arena

    # Now `etho_output` holds automated scoring DataFrames and `etho_output_manual` holds manual scoring DataFrames
    print(f"Processed {len(etho_output)} automated scoring sheets and {len(etho_output_manual)} manual scoring sheets")

In [None]:
# Attach each mouse's Ethovision tables and arena name to its alldat entry.
# Anything that has no match for the mouse ID is stored as None.

print(f"Current experiment_type is: '{experiment_type}'")
print()

if experiment_type != 'synapse_notes_only':
    # (alldat key, lookup keyed by eartag, message when found, message when missing)
    etho_sources = (
        ('etho_raw_data', etho_output, " etho_output saved", " etho_output not saved"),
        ('etho_manual_data', etho_output_manual, " etho_output_manual saved", " etho_output_manual not saved"),
    )

    for entry in alldat:
        mouse_id = entry['mouseID']
        print(f"mouse_id:{mouse_id}")

        for dest_key, lookup, found_msg, missing_msg in etho_sources:
            if mouse_id in lookup:
                entry[dest_key] = lookup[mouse_id]
                print(found_msg)
            else:
                entry[dest_key] = None
                print(missing_msg)

        if mouse_id in arenas:
            entry['arena'] = arenas[mouse_id]
            print(f" arena is {entry['arena']}")
        else:
            entry['arena'] = None
            print(" no arena found")
        print()

else:
    print('No Etho files for synapse_notes_only')

In [None]:
# Copy the 'Velocity' and 'Trial time' columns of the Ethovision table into each entry.
# An empty list is stored when the table or the column is missing.

print(f"Current experiment_type is: '{experiment_type}'")
print()

if experiment_type == 'synapse_notes_only':
    print('No Etho files for synapse_notes_only')

else:
    # (Ethovision column, alldat key, message when saved, message when not saved)
    column_specs = (
        ('Velocity', 'velocity', " velocity saved", " velocity not saved"),
        ('Trial time', 'Etho_time', " Etho_time saved", " Etho_time not saved"),
    )

    for entry in alldat:
        print(f"mouse_id:{entry['mouseID']}")

        raw_table = entry.get('etho_raw_data')
        has_table = isinstance(raw_table, pd.DataFrame)

        for column, dest_key, saved_msg, missing_msg in column_specs:
            if has_table and column in raw_table.columns:
                entry[dest_key] = raw_table[column].values
                print(saved_msg)
            else:
                entry[dest_key] = []
                print(missing_msg)
        print()



In [None]:
# Copy the in-zone (0/1) columns of the Ethovision table into each entry

print(f"Current experiment_type is: '{experiment_type}'")
if experiment_type == 'synapse_notes_only':
    print('No Etho files for synapse_notes_only')

else:
    # Zones and column-name template for the current experiment type
    zones = EXPERIMENT_CONFIGS[experiment_type]
    column_config = ETHO_COLUMN_CONFIGS[experiment_type]

    def extract_zone_data(entry, zones, column_config):
        """Store each zone's column from entry['etho_raw_data'] under entry[zone].

        The column name is built from column_config['pattern']. When the config
        has an 'arena_mapping', the entry's arena name is translated to a prefix
        (empty string if the arena is not in the mapping) and passed to the
        pattern as 'arena_prefix'.

        Returns a list of human-readable strings, one per zone that was found;
        zones whose column is absent are skipped silently.
        """
        saved_info = []
        for zone in zones:
            format_args = {'zone': zone}

            # Pattern with arena prefix (like open_field)
            if 'arena_mapping' in column_config:
                arena = entry.get('arena', '')
                format_args['arena_prefix'] = column_config['arena_mapping'].get(arena, '')

            # Otherwise a simple pattern that only needs the zone (like EZM)
            column_name = column_config['pattern'].format(**format_args)

            if not column_name or column_name not in entry['etho_raw_data'].columns:
                continue
            entry[zone] = entry['etho_raw_data'][column_name].values
            saved_info.append(f"{zone} from column [{column_name}]")

        return saved_info

    # Main processing loop
    for entry in alldat:
        if not isinstance(entry.get('etho_raw_data'), pd.DataFrame):
            mouse_id = entry.get('mouseID', 'Unknown')
            print(f"Mouse {mouse_id}: No Etho data available")
            continue

        saved_info = extract_zone_data(entry, zones, column_config)
        mouse_id = entry.get('mouseID', 'Unknown')
        arena = entry.get('arena', 'No arena')
        print(f"Mouse {mouse_id} ({arena}):")
        if not saved_info:
            print("    No zones saved")
        for info in saved_info:
            print(f"    {info}")

In [None]:
# Sanity check of the Ethovision additions: every printed length should match the
# number of rows in the Ethovision raw data output.

print(f"Current experiment_type is: '{experiment_type}'")
if experiment_type == 'synapse_notes_only':
    print('No Etho files for synapse_notes_only')

else:
    # Zones for the current experiment type
    zones = EXPERIMENT_CONFIGS[experiment_type]

    # (alldat key, label used in the printout)
    etho_arrays = (('velocity', 'velocity'), ('Etho_time', 'etho time'))

    for entry in alldat:
        print(f"eartag {entry['mouseID']}")

        # Velocity and Ethovision time, reported only when present and non-empty
        for key, label in etho_arrays:
            if key in entry and len(entry[key]) > 0:
                print(f"    {label} array length {len(entry[key])}")

        # Zone arrays, reported only when present and non-empty
        for zone in zones:
            if zone in entry and len(entry[zone]) > 0:
                print(f"    {zone.lower()} array length {len(entry[zone])}")

        print()

### Remove artifact at start and downsample

In [None]:
# Remove the artifact at the start of each recording and cap the recording length.
# Ethovision arrays are cut with the video indices because they are time aligned to video_time.

zones = EXPERIMENT_CONFIGS[experiment_type]  # zones for the current experiment type

for entry in alldat:
    photom_time = entry['photom_time']
    video_time = entry['video_time']
    last_video_time = max(video_time)

    # Never keep data beyond the end of the video or beyond t_end_target
    t_end = min(last_video_time, t_end_target)

    # First sample strictly after t_start on each clock
    idx_start_photom = np.argmax(photom_time > t_start)
    idx_start_video = np.argmax(video_time > t_start)

    # First sample strictly after t_end, or the full length when nothing lies beyond t_end
    if t_end < last_video_time:
        idx_end_video = np.argmax(video_time > t_end)
    else:
        idx_end_video = len(video_time)
    if t_end < max(photom_time):
        idx_end_photom = np.argmax(photom_time > t_end)
    else:
        idx_end_photom = len(photom_time)

    photom_window = slice(idx_start_photom, idx_end_photom)
    video_window = slice(idx_start_video, idx_end_video)

    entry['trimmed_photom_time'] = photom_time[photom_window]
    entry['trimmed_green'] = entry['green'][photom_window]
    entry['trimmed_isos'] = entry['isos'][photom_window]
    entry['trimmed_video_time'] = video_time[video_window]

    if experiment_type != 'synapse_notes_only':
        entry['trimmed_velocity'] = entry['velocity'][video_window]
        entry['trimmed_Etho_time'] = entry['Etho_time'][video_window]

    # Zone arrays follow the video clock as well
    for zone in zones:
        entry[f'trimmed_{zone}'] = entry[zone][video_window]

In [None]:
# Print the trimmed lengths so they can be compared by eye.
# The camera (video_time) length should match the Ethovision velocity and zone lengths.

zones = EXPERIMENT_CONFIGS[experiment_type]

for entry in alldat:
    print(f"eartag {entry['mouseID']}")

    photom = entry['trimmed_photom_time']
    print(f"photom_time length {len(photom)}  photom_time start {photom[0]}  photom_time end {photom[-1]}")

    video = entry['trimmed_video_time']
    print(f"video_time length {len(video)}  video_time start {video[0]}  video_time end {video[-1]}")

    if experiment_type != 'synapse_notes_only':
        print(f"velocity length {len(entry['trimmed_velocity'])}")
        print(f"Etho time length {len(entry['trimmed_Etho_time'])}")

    for zone in zones:
        trimmed_zone = entry[f'trimmed_{zone}']
        print(f"{zone} length {len(trimmed_zone)}")
    print()

In [None]:
# Downsample by averaging consecutive, non-overlapping windows of N samples


def _block_average(signal, n):
    """Return the mean of each consecutive window of n samples of signal.

    The last window may hold fewer than n samples. Each window is signal[i:i+n];
    the previous slice [i:i+n-1] averaged only n-1 samples, dropping one sample
    per window (and producing an empty slice for n == 1).
    """
    signal = np.asarray(signal)
    return np.array([np.mean(signal[start:start + n]) for start in range(0, len(signal), n)])


for entry in alldat:
    entry['downsampled_green'] = _block_average(entry['trimmed_green'], N)
    entry['downsampled_isos'] = _block_average(entry['trimmed_isos'], N)

    # One timestamp per window: the time of the window's first sample
    entry['downsampled_photom_time'] = entry['trimmed_photom_time'][::N]
    

### Fit double exponential to each curve and use to correct for bleaching

In [None]:
def double_exponential(t, const, amp_fast, amp_slow, tau_slow, tau_multiplier):
    """Evaluate a constant offset plus a slow and a fast exponential decay.

    Parameters
    ----------
    t : array-like
        Time vector in seconds
    const : float
        Constant offset
    amp_fast : float
        Amplitude of the fast decay
    amp_slow : float
        Amplitude of the slow decay
    tau_slow : float
        Time constant of the slow decay in seconds
    tau_multiplier : float
        Ratio of the fast to the slow time constant
        (tau_fast = tau_slow * tau_multiplier)

    Returns
    -------
    array-like
        const + amp_slow * exp(-t / tau_slow) + amp_fast * exp(-t / tau_fast),
        evaluated at every time point
    """
    slow_component = amp_slow * np.exp(-t / tau_slow)
    fast_component = amp_fast * np.exp(-t / (tau_slow * tau_multiplier))
    return const + slow_component + fast_component

def get_bounds(trace, timecourse, tau_slow_min=0.0001, tau_slow_max=30,
               tau_multiplier_min=0, tau_multiplier_max=1):
    """Calculate parameter bounds for fitting double_exponential to a trace.

    The bounds are returned in the parameter order of double_exponential:
    (const, amp_fast, amp_slow, tau_slow, tau_multiplier). Previously the
    offset bounds were placed in the last position, so they constrained
    tau_multiplier instead of const, and const received the amplitude bounds.

    Parameters
    ----------
    trace : array-like
        The signal trace to be fit
    timecourse : array-like
        Time vector corresponding to the trace data points
    tau_slow_min : float, optional
        Minimum value for tau_slow as a fraction of total time duration (default: 0.0001)
    tau_slow_max : float, optional
        Maximum value for tau_slow as a multiple of total time duration (default: 30)
    tau_multiplier_min : float, optional
        Minimum ratio tau_fast / tau_slow (default: 0)
    tau_multiplier_max : float, optional
        Maximum ratio tau_fast / tau_slow (default: 1, i.e. the fast component
        is never slower than the slow one)

    Returns
    -------
    tuple of lists
        (lower, upper), each ordered as
        [const, amp_fast, amp_slow, tau_slow, tau_multiplier]
    """
    signal_min = np.min(trace)
    signal_max = np.max(trace)
    duration = timecourse[-1] - timecourse[0]

    # Amplitudes of both decays: non-negative, up to twice the signal maximum
    amp_min = 0
    amp_max = 2 * signal_max

    # Slow time constant scales with the duration of the recording
    time_constant_min = tau_slow_min * duration
    time_constant_max = tau_slow_max * duration

    # Constant offset: from 0 (or the signal minimum if that is negative) up to the signal maximum
    offset_min = signal_min if signal_min < 0 else 0
    offset_max = signal_max

    return (
        [offset_min, amp_min, amp_min, time_constant_min, tau_multiplier_min],
        [offset_max, amp_max, amp_max, time_constant_max, tau_multiplier_max]
    )

def _fit_bleaching_curve(signal, time):
    """Fit double_exponential to one trace and return the fitted curve."""
    max_sig = np.max(signal)
    lower, upper = get_bounds(signal, time)
    # Clip the starting guess into the bounds: curve_fit raises
    # "x0 is infeasible" when any initial value lies outside them, which
    # happens e.g. for short recordings where tau_slow = 3600 s exceeds the
    # upper bound derived from the recording duration.
    initial_params = np.clip(
        [max_sig / 2, max_sig / 4, max_sig / 4, 3600, 0.1], lower, upper
    )
    params, _ = curve_fit(double_exponential, time, signal,
                          p0=initial_params, bounds=(lower, upper), maxfev=1000)
    return double_exponential(time, *params)


def _plot_bleaching_fit(time, green, isos, green_expfit, isos_expfit,
                        green_detrended, isos_detrended, mouse_id):
    """Show raw signals with their fits (left) and the detrended signals (right)."""
    fig, (ax_raw, ax_detrended) = plt.subplots(1, 2, figsize=(20, 8))

    # Left panel: raw signals with fits, green on the left axis, isos on a twin axis
    green_line = ax_raw.plot(time, green, 'g', label='green')
    fit_line = ax_raw.plot(time, green_expfit, 'k', linewidth=1.5, label='Exponential fit')
    ax_raw.set_xlabel('Time (seconds)')
    ax_raw.set_ylabel('Green Signal (V)', color='g')
    ax_raw.tick_params(axis='y', labelcolor='g')

    ax_raw_twin = ax_raw.twinx()
    isos_line = ax_raw_twin.plot(time, isos, 'b', label='isos')
    ax_raw_twin.plot(time, isos_expfit, 'k', linewidth=1.5)
    ax_raw_twin.set_ylabel('Isos Signal (V)', color='b')
    ax_raw_twin.tick_params(axis='y', labelcolor='b')

    # One legend for the lines of both y-axes
    lines = green_line + isos_line + fit_line
    ax_raw.legend(lines, [line.get_label() for line in lines], loc='upper right')
    ax_raw.set_title(f"{mouse_id} - Raw Signals")

    # Right panel: detrended signals, same axis layout
    green_line = ax_detrended.plot(time, green_detrended, 'g', label='green')
    ax_detrended.set_xlabel('Time (seconds)')
    ax_detrended.set_ylabel('Green Signal (V)', color='g')
    ax_detrended.tick_params(axis='y', labelcolor='g')

    ax_detrended_twin = ax_detrended.twinx()
    isos_line = ax_detrended_twin.plot(time, isos_detrended, 'b', label='isos')
    ax_detrended_twin.set_ylabel('Isos Signal (V)', color='b')
    ax_detrended_twin.tick_params(axis='y', labelcolor='b')

    lines = green_line + isos_line
    ax_detrended.legend(lines, [line.get_label() for line in lines], loc='upper right')
    ax_detrended.set_title(f"{mouse_id} - Detrended Signals")

    plt.tight_layout()
    plt.show()


def process_photometry_signals(alldat):
    """Process photometry signals by fitting and removing exponential bleaching curves.

    For each entry in alldat, fits a double exponential to the downsampled
    green and isos signals and subtracts the fit to correct for bleaching.
    Shows one figure per entry with the raw signals and fits next to the
    detrended signals.

    Parameters
    ----------
    alldat : list of dict
        List of data dictionaries; each must contain 'downsampled_green',
        'downsampled_isos', 'downsampled_photom_time' and 'mouseID'

    Returns
    -------
    None
        Modifies alldat in place, adding:
        - 'expfitgreen', 'expfitisos': fitted curves
        - 'detrended_green', 'detrended_isos': bleaching-corrected signals
    """
    for entry in alldat:
        green = entry['downsampled_green']
        isos = entry['downsampled_isos']
        time = entry['downsampled_photom_time']

        # Fit each channel independently
        green_expfit = _fit_bleaching_curve(green, time)
        entry['expfitgreen'] = green_expfit
        isos_expfit = _fit_bleaching_curve(isos, time)
        entry['expfitisos'] = isos_expfit

        # Detrend by subtracting the fitted bleaching curve
        green_detrended = green - green_expfit
        isos_detrended = isos - isos_expfit
        entry['detrended_green'] = green_detrended
        entry['detrended_isos'] = isos_detrended

        _plot_bleaching_fit(time, green, isos, green_expfit, isos_expfit,
                            green_detrended, isos_detrended, entry['mouseID'])

def plot_motion_correction(time, isos_detrended, green_detrended, green_corrected, green_est_motion, slope, r_value, mouseID,
                           xlim=(60, 120), motion_offset=1.05):
    """Plot motion correction analysis with correlation and correction plots side by side.

    Parameters
    ----------
    time : array-like
        Time points for x-axis
    isos_detrended : array-like
        Detrended isosbestic signal
    green_detrended : array-like
        Detrended green signal before motion correction
    green_corrected : array-like
        Motion-corrected green signal
    green_est_motion : array-like
        Estimated motion artifact
    slope : float
        Regression slope between isos and green
    r_value : float
        Correlation coefficient
    mouseID : str
        Mouse identifier for plot title
    xlim : tuple of float, optional
        Time window in seconds shown in the correction panel (default: (60, 120))
    motion_offset : float, optional
        Amount subtracted from the estimated motion trace for display only, so it
        is drawn below the signal traces (default: 1.05)
    """
    fig, (ax_corr, ax_signal) = plt.subplots(1, 2, figsize=(20, 8))

    # Panel 1: isos vs green scatter (every 5th point) with the regression line
    ax_corr.scatter(isos_detrended[::5], green_detrended[::5], alpha=0.1, marker='.')
    x = np.array([np.min(isos_detrended), np.max(isos_detrended)])
    # Intercept of the line with the given slope through the mean point of the data
    intercept = np.mean(green_detrended) - slope * np.mean(isos_detrended)
    ax_corr.plot(x, intercept + slope*x, 'b', linewidth=2)
    ax_corr.set_xlabel('Isos')
    ax_corr.set_ylabel('Green')
    # "\u00b2" is the superscript two; written as an escape so the title cannot be
    # corrupted by a wrong file encoding (it previously rendered as mojibake)
    ax_corr.set_title(f"{mouseID} - Motion Correlation\nSlope: {slope:.3f}, R\u00b2: {r_value**2:.3f}")

    # Panel 2: signal before and after correction, plus the estimated motion
    ax_signal.plot(time, green_detrended, 'b', label='Pre motion correction', alpha=0.5)
    ax_signal.plot(time, green_corrected, 'g', label='Motion corrected', alpha=0.5)
    ax_signal.plot(time, green_est_motion - motion_offset, 'y', label='Estimated motion')

    ax_signal.set_xlabel('Time (seconds)')
    ax_signal.set_ylabel('Green Signal (V)')
    ax_signal.set_title(f"{mouseID} - Motion Correction")
    ax_signal.legend(loc='upper right')
    ax_signal.set_xlim(*xlim)

    plt.tight_layout()
    plt.show()

def process_motion_correction(alldat):
    """Remove the isos-predicted component from the detrended green signal.

    For every entry, regresses 'detrended_green' on 'detrended_isos', prints
    the slope and R-squared, stores the residual as 'corrected_green' and
    plots the result.

    Parameters
    ----------
    alldat : list of dict
        List of data dictionaries containing photometry signals; modified in place
    """
    for entry in alldat:
        mouse_id = entry['mouseID']
        green_detrended = entry['detrended_green']
        isos_detrended = entry['detrended_isos']
        time = entry['downsampled_photom_time']

        # Linear relationship between the control (isos) and signal (green) channels
        fit = linregress(x=isos_detrended, y=green_detrended)
        slope, intercept, r_value = fit.slope, fit.intercept, fit.rvalue
        print(f"\nMouse {mouse_id}:")
        print(f"Slope    : {slope:.3f}")
        print(f"R-squared: {r_value**2:.3f}")

        # The part of green explained by isos is treated as motion and subtracted
        green_est_motion = intercept + slope * isos_detrended
        green_corrected = green_detrended - green_est_motion
        entry['corrected_green'] = green_corrected

        plot_motion_correction(
            time, isos_detrended, green_detrended,
            green_corrected, green_est_motion,
            slope, r_value, mouse_id,
        )

Curve fit

In [None]:
# Fit and subtract the double-exponential bleaching curve for every mouse.
# Adds 'expfitgreen', 'expfitisos', 'detrended_green' and 'detrended_isos' to each
# alldat entry and shows one figure per mouse.
process_photometry_signals(alldat)

Motion correction

In [None]:
# Regress detrended green on detrended isos and subtract the fitted estimate.
# Adds 'corrected_green' to each alldat entry and shows one figure per mouse.
process_motion_correction(alldat)

### Map event timestamps to photometry time

In [None]:
#Functions to identify entries into zones. Can also be used for any other data output as 1's and 0's, such as freezing

def find_zeros_to_ones(arr, min_duration, sampling_rate):
    """Return the indices in arr where the value switches from 0 to 1 and stays 1.

    Samples that are neither 0 nor 1 are ignored when looking for transitions,
    but the returned indices refer to positions in the original arr.

    Parameters
    ----------
    arr : array-like
        Sequence of 0/1 values (other values are skipped)
    min_duration : float
        Minimum time the value must stay 1 for the transition to count
    sampling_rate : float
        Time between samples, in the same unit as min_duration; the minimum run
        length in samples is int(min_duration / sampling_rate)

    Returns
    -------
    numpy.ndarray
        Indices (into arr) of the first 1 of every accepted run
    """
    min_samples = int(min_duration / sampling_rate)
    values = np.array(arr)
    is_binary = (values == 0) | (values == 1)
    clean = values[is_binary]
    n_clean = len(clean)

    accepted = []
    # Positions (in the cleaned array) of the first 1 after a 0
    for onset in np.where(np.diff(clean) == 1)[0] + 1:
        # Walk to the end of this run of 1s
        run_end = onset + 1
        while run_end < n_clean and clean[run_end] == 1:
            run_end += 1
        if run_end - onset >= min_samples:
            accepted.append(onset)

    # Translate positions in the cleaned array back to positions in arr
    return np.flatnonzero(is_binary)[accepted]

def find_ones_to_zeros(arr, min_duration, sampling_rate): 
    """
    Return indices in arr where the signal switches from 1 to 0 and then stays 0 long enough.

    Entries that are neither 0 nor 1 are discarded before transitions are searched for;
    the returned indices refer to positions in the original (unfiltered) arr.

    Parameters:
    - arr: Sequence of samples; only entries equal to 0 or 1 are considered.
    - min_duration: Minimum time the signal must remain 0 after the transition.
    - sampling_rate: Divided into min_duration to obtain the minimum run length in samples
      (so it acts as time per sample - TODO confirm against the caller's units).

    Returns:
    - Array of indices into arr marking the first 0 of each accepted run.
    """
    min_samples = int(min_duration / sampling_rate)
    arr = np.array(arr)
    valid_mask = (arr == 0) | (arr == 1)
    filtered_arr = arr[valid_mask]
    
    # Positions in the filtered array holding the first 0 after a 1
    transition_indices = np.where(np.diff(filtered_arr) == -1)[0] + 1
    
    valid_transitions = []
    for idx in transition_indices:
        # Count how many consecutive 0s start at this index
        count = 1
        while (idx + count < len(filtered_arr)) and (filtered_arr[idx + count] == 0):
            count += 1
        # Compare against the run length in samples, matching find_zeros_to_ones
        if count >= min_samples:
            valid_transitions.append(idx)
    
    # Translate positions in the filtered array back to positions in arr
    original_indices = np.flatnonzero(valid_mask)
    return original_indices[valid_transitions]

In [None]:
# Locate zone entries and exits for every mouse and store them as '<zone>_entries' / '<zone>_exits'.

if experiment_type == 'synapse_notes_only':
    print('No zone data for synapse_notes_only')

else:       
    for f, mouse in enumerate(alldat):
        zones = EXPERIMENT_CONFIGS[experiment_type]

        print(f"Mouse {mouse['mouseID']}:")

        for zone in zones:
            # 0/1 occupancy trace for this zone
            occupancy = mouse[f'trimmed_{zone}']
            entries = find_zeros_to_ones(occupancy, min_duration=min_duration, sampling_rate=video_sampling_rate)
            exits = find_ones_to_zeros(occupancy, min_duration=min_duration, sampling_rate=video_sampling_rate)

            # Keys use the lower-cased zone name
            zone_tag = zone.lower()
            mouse.update({f'{zone_tag}_entries': entries, f'{zone_tag}_exits': exits})

            print(f"  {zone}: {len(entries)} entries, {len(exits)} exits")


In [None]:
# Split Synapse TTL note timestamps by note type. Each distinct leading digit of a note
# becomes its own array: 'bit1', 'bit2', etc.

if experiment_type != 'synapse_notes_only':
    print('No TTL data')

else:
    for f, mouse in enumerate(alldat):
        # The first character of each note identifies its bit
        bitvalues = [int(str(TTL)[0]) for TTL in mouse['ttl_notes']]
        timestamps = mouse['ttl_notes_time']
        bit_array = np.array(bitvalues)
        unique_bits = np.unique(bitvalues)

        print(f"Dataset {f} (Mouse {mouse['mouseID']}):")
        print(f"  Found bits: {unique_bits}")

        # One timestamp array per bit value
        for bit in unique_bits:
            array_name = f'bit{bit}'
            mouse[array_name] = timestamps[bit_array == bit]
            print(f"  {array_name}: {len(mouse[array_name])} timestamps")

        # Keep the unsplit arrays as well
        mouse['bit_values'] = bit_array
        mouse['bit_timestamps'] = timestamps

In [None]:
# Identify manually scored behavior events and convert their times to etho-time indices


def _has_manual_scoring(entry):
    """True when this mouse carries a manual-scoring DataFrame."""
    return 'etho_manual_data' in entry and isinstance(entry['etho_manual_data'], pd.DataFrame)


def _nearest_time_indices(time_axis, event_times):
    """Return, for each event time, the index of the closest sample in time_axis."""
    if len(event_times) == 0:
        return np.array([])
    nearest = np.searchsorted(time_axis, event_times)
    nearest = np.clip(nearest, 0, len(time_axis) - 1)
    # searchsorted gives the sample at/after the event; step back when the previous one is closer
    prev_is_closer = (nearest > 0) & (np.abs(time_axis[nearest - 1] - event_times) <
                                       np.abs(time_axis[nearest] - event_times))
    nearest[prev_is_closer] -= 1
    return nearest


# Pass 1: gather every behavior name that appears in any mouse
all_unique_behaviors_base = set()
for entry in alldat:
    if _has_manual_scoring(entry):
        all_unique_behaviors_base.update(entry['etho_manual_data']['Behavior'].unique())

# Sorted so the list is the same from run to run
all_unique_behaviors_base = sorted(all_unique_behaviors_base)

# Each behavior contributes its own name plus _start and _stop variants
all_unique_behaviors = [
    name
    for behavior in all_unique_behaviors_base
    for name in (behavior, f"{behavior}_start", f"{behavior}_stop")
]

# Pass 2: per-mouse start/stop times and their nearest etho-time indices
for f, entry in enumerate(alldat):
    if not _has_manual_scoring(entry):
        print(f"Mouse {entry['mouseID']}: No manual scoring data")
        continue

    manual = entry['etho_manual_data']
    etho_time = entry['trimmed_Etho_time']
    unique_behaviors_base = manual['Behavior'].unique()

    print(f"Mouse {entry['mouseID']}:")

    for behavior_base in unique_behaviors_base:
        scored = manual[manual['Behavior'] == behavior_base]
        start_times = scored[scored['Event'] == 'state start']['Trial time'].values
        stop_times = scored[scored['Event'] == 'state stop']['Trial time'].values

        # Raw event times
        entry[f"{behavior_base}_start_times"] = start_times
        entry[f"{behavior_base}_stop_times"] = stop_times

        # Indices of the closest etho-time samples (empty array when there are no events)
        entry[f"{behavior_base}_start"] = _nearest_time_indices(etho_time, start_times)
        entry[f"{behavior_base}_stop"] = _nearest_time_indices(etho_time, stop_times)

        print(f"  {behavior_base}: {len(start_times)} start events, {len(stop_times)} stop events")

    # Same comprehensive list for every mouse that has manual data
    entry['unique_behaviors'] = all_unique_behaviors

# Store the comprehensive list as a global variable 
unique_behaviors = all_unique_behaviors

In [None]:
# Show every behavior key (base name, _start and _stop variants) collected from manual scoring
print(unique_behaviors)

#### Defining time range for snips

In [None]:
# Effective photometry sampling rate after downsampling by N
# (taken from the first mouse, assuming the sampling rate is the same in all mice)
fs = alldat[0]['sampling_rate'] / N

#use downsampled photom time and trimmed camera time
photom_time_interval = alldat[0]['downsampled_photom_time'][1] - alldat[0]['downsampled_photom_time'][0]
# The first mouse is the reference recording for all timing values in this cell
video_time_interval = np.diff(alldat[0]['trimmed_video_time'])
avg_video_time_interval = np.mean(video_time_interval)
print(f"Photom time interval: {photom_time_interval}")
print(f"Avg video time interval: {avg_video_time_interval}")

#time span for peri-event filtering   
#TRANGE is range of indices ex. -100, 200 is pre_time 100 indices and post_time 200 indices
TRANGE = [int(-1*PRE_TIME/photom_time_interval), int(POST_TIME/photom_time_interval)]
print(f"TRANGE: {TRANGE}")
# NOTE(review): idx compares photometry times (seconds) with TRANGE[0] (a sample offset) and is
# not used in the cells visible here - confirm whether it is still needed.
idx = np.argmax(alldat[0]['downsampled_photom_time'] > TRANGE[0])

#peri_time: new timescale of the peri-event window, one value per sample, starting at -PRE_TIME
peri_time = np.arange(int(TRANGE[1] - TRANGE[0])) / (1/photom_time_interval) - PRE_TIME

print(f"len(peri_time): {len(peri_time)}")
print(f"range(photom_time snip): {alldat[0]['downsampled_photom_time'][0]} to {alldat[0]['downsampled_photom_time'][len(peri_time)-1]}")

#### Functions to map indices to photom time, pull out snips, and plot data

In [None]:
#Function to find indices in photometry time that match events in etho time
def map_etho_indices_to_photom_indices(alldat, index_array_name, output_name):
    """
    Convert video-frame event indices into indices of the downsampled photometry time axis.

    Parameters:
    - alldat: List of per-mouse dictionaries; modified in place.
    - index_array_name: Key holding indices into entry['trimmed_video_time'].
    - output_name: Key under which the matching photometry indices are stored.

    A mouse that lacks index_array_name, or has an empty array there, gets an empty array.
    """
    for mouse in alldat:
        # Nothing to map for this mouse
        if index_array_name not in mouse or len(mouse[index_array_name]) == 0:
            mouse[output_name] = np.array([])
            continue

        video_clock = mouse['trimmed_video_time']
        photom_clock = mouse['downsampled_photom_time']

        # Event times on the video clock
        event_times = video_clock[mouse[index_array_name]]

        # For each event, pick the photometry sample whose time is closest
        nearest = []
        for t in event_times:
            nearest.append(np.argmin(np.abs(photom_clock - t)))

        mouse[output_name] = np.array(nearest)

# Example usage
# map_etho_indices_to_photom_indices(alldat, 'edge_entries', 'mapped_edge_entries')

#Function to find indices in photometry time that match TTL timestamps
def map_timestamps_to_photom_indices(alldat, timestamp_array_name, output_name):
    """
    Convert event timestamps into indices of the downsampled photometry time axis.

    Parameters:
    - alldat: List of per-mouse dictionaries; modified in place.
    - timestamp_array_name: Key holding event timestamps (times, not indices).
    - output_name: Key under which the matching photometry indices are stored.

    A mouse that lacks timestamp_array_name, or has no timestamps there, gets an empty
    array, mirroring map_etho_indices_to_photom_indices. This matters because the list of
    events may come from one mouse while another mouse never produced that event.
    """
    for entry in alldat:
        timestamps = entry.get(timestamp_array_name)

        # No events of this type for this mouse
        if timestamps is None or len(timestamps) == 0:
            entry[output_name] = np.array([])
            continue

        photom_time = entry['downsampled_photom_time']

        # Find the closest index in photom_time for each timestamp
        closest_indices = [np.argmin(np.abs(photom_time - x)) for x in timestamps]

        # Save to the dictionary
        entry[output_name] = np.array(closest_indices)

#Function to pull out photometry data surrounding indexed event timestamps
def extract_snips(alldat, trange, onset_key, snips_key, mean_key, trace):
    """
    Cut fixed-length windows out of a signal around each event onset.

    Parameters:
    - alldat: List of per-session dictionaries; modified in place
    - trange: (start_offset, end_offset) in samples relative to each onset
    - onset_key: Key in each entry holding the event onset indices
    - snips_key: Key under which the list of kept snippets is stored
    - mean_key: Key under which the across-trial mean is stored (empty array if no snippets)
    - trace: Key of the signal to cut snippets from

    Snippets that run past either end of the signal (and so are shorter than the
    window) or that contain NaN are discarded.
    """
    window_len = trange[1] - trange[0]

    for session in alldat:
        candidates = []
        for onset in session[onset_key]:
            # Clamp the window to the bounds of the signal
            lo = int(max(0, onset + trange[0]))
            hi = int(min(onset + trange[1], len(session[trace])))
            candidates.append(session[trace][lo:hi])

        # Keep only complete, NaN-free windows
        kept = [
            snip for snip in candidates
            if isinstance(snip, np.ndarray) and not np.isnan(snip).any() and len(snip) == window_len
        ]

        if kept:
            trial_mean = np.array(kept).mean(axis=0)
        else:
            trial_mean = np.array([])

        session[snips_key] = kept
        session[mean_key] = trial_mean

#Function to plot event means
def plot_event_means(events_to_plot, suffix='_corrected_mean', peri_time=peri_time, alldat=alldat, group_ids=group_ids):
    """
    Draw one figure per event: per-mouse mean traces (left) and group averages with SEM bands (right).

    Parameters:
    - events_to_plot: Event names; the trace plotted for each is alldat[f][event + suffix].
    - suffix: Suffix selecting which mean trace to plot.
    - peri_time: Peri-event time axis; traces whose length differs from it are skipped.
    - alldat: List of per-mouse dictionaries (needs 'mouseID' and 'group').
    - group_ids: Group names to average over.

    Note: the default values for peri_time, alldat and group_ids are captured when this
    function is defined, not when it is called.
    """

    def _has_plottable_mean(entry, key):
        # A trace is plottable when it exists, is non-empty and matches the time axis
        trace = entry.get(key)
        return trace is not None and len(trace) > 0 and len(trace) == len(peri_time)

    for event_to_plot in events_to_plot:
        event_mean_key = f'{event_to_plot}{suffix}'

        # Work out which mice can be plotted for this event
        status = [(mouse['mouseID'], _has_plottable_mean(mouse, event_mean_key)) for mouse in alldat]
        mice_with_no_events = [mouse_id for mouse_id, ok in status if not ok]

        if mice_with_no_events:
            print(f"\n{event_to_plot} - Mice with no events: {', '.join(map(str, mice_with_no_events))}")

        if not any(ok for _, ok in status):
            print(f"No valid data found for event: {event_to_plot}")
            continue

        # Left panel: one line per mouse
        plt.figure(figsize=(12, 5))
        plt.subplot(1, 2, 1)
        for mouse in alldat:
            mousenum = mouse['mouseID']
            if _has_plottable_mean(mouse, event_mean_key):
                plt.plot(peri_time, mouse[event_mean_key], label=f"{mousenum} ({mouse['group']})")

        plt.title(f"{event_to_plot}{suffix}\nIndividual Traces")
        plt.legend(loc='upper left', bbox_to_anchor=(1, 1))
        plt.xlabel('Time (s)')
        plt.ylabel('Signal')

        # Right panel: mean +/- SEM across the mice of each group
        plt.subplot(1, 2, 2)
        for grp in group_ids:
            group_traces = [
                mouse[event_mean_key]
                for mouse in alldat
                if mouse['group'] == grp and _has_plottable_mean(mouse, event_mean_key)
            ]
            if not group_traces:
                continue

            stacked = np.vstack(group_traces)
            group_mean = np.mean(stacked, axis=0)
            group_sem = np.std(stacked, axis=0) / np.sqrt(stacked.shape[0])

            plt.plot(peri_time, group_mean, label=f'{grp} (n={len(group_traces)})')
            plt.fill_between(peri_time,
                             group_mean - group_sem,
                             group_mean + group_sem,
                             alpha=0.3,
                             label='_nolegend_')

        plt.title(f"{event_to_plot}{suffix}\nGroup Averages")
        plt.legend()
        plt.xlabel('Time (s)')
        plt.ylabel('Signal')

        plt.tight_layout()

    plt.show()  # Show all figures

#### Mapping indices

In [None]:
#Build the list of events for this experiment and convert them to photometry indices


def _any_mouse_has(key):
    """True when at least one mouse stores a non-empty array under key."""
    return any(key in mouse and len(mouse[key]) > 0 for mouse in alldat)


events = []

if experiment_type != 'synapse_notes_only':
    # Velocity onsets, only when velocity data is present for some mouse
    has_velocity_data = _any_mouse_has('velocity')
    if has_velocity_data:
        has_high_velocity = _any_mouse_has('high_velocity_onsets')
        has_low_velocity = _any_mouse_has('low_velocity_onsets')

        if has_high_velocity:
            events.append('high_velocity_onsets')
        if has_low_velocity:
            events.append('low_velocity_onsets')

    # Zone entries/exits defined by the experiment configuration
    zones = EXPERIMENT_CONFIGS[experiment_type]
    for zone in zones:
        zone_lower = zone.lower()
        entries_key = f'{zone_lower}_entries'
        exits_key = f'{zone_lower}_exits'

        has_entries = _any_mouse_has(entries_key)
        has_exits = _any_mouse_has(exits_key)

        if has_entries:
            events.append(entries_key)
        if has_exits:
            events.append(exits_key)

    # Manually scored behaviors, when the manual-scoring cell defined unique_behaviors
    if 'unique_behaviors' in locals() and len(unique_behaviors) > 0:
        for behavior in unique_behaviors:
            if _any_mouse_has(behavior):
                events.append(behavior)

else:
    # Synapse notes: use only the numbered bit keys (bit1, bit2, etc.) of the first mouse
    events = [key for key in alldat[0].keys() 
             if key.startswith('bit') and key[3:].isdigit()]  # ensures there's a number after 'bit'

# Raise error if no events found
if not events:
    raise ValueError(f"No valid events found for experiment type: {experiment_type}")

print(f"Using events: {events}")

# Keep the events list with each mouse for later reference
for mouse in alldat:
    mouse['events_list'] = events

# TTL events are timestamps; everything else is stored as etho/video indices
if experiment_type == 'synapse_notes_only':
    mapper = map_timestamps_to_photom_indices
else:
    mapper = map_etho_indices_to_photom_indices

for event in events:
    mapper(alldat, event, f'mapped_{event}')

#### Snips from motion corrected data

In [None]:
#Get snips and means from motion corrected data
#Note, "corrected_snips" here is the equivalent of "bit_snips" in the operant box code

for event in events:
    snips_key = f'{event}_corrected_snips'
    extract_snips(alldat, TRANGE, f'mapped_{event}', snips_key, f'{event}_corrected_mean', 'corrected_green')

    # Report how many snips each mouse contributed for this event
    print(f"\n{event} snips extracted:")
    for mouse in alldat:
        if snips_key in mouse:
            print(f"  Mouse {mouse['mouseID']}: {len(mouse[snips_key])} snips")
        else:
            print(f"  Mouse {mouse['mouseID']}: No snips found")

Plot snips from motion corrected data

In [None]:
# Plot the per-mouse and per-group mean of the motion-corrected snips for each event
events_to_plot = alldat[0]['events_list'] # this will plot everything, or you can specify a list of events i.e. ['edge_entries', 'velocity'] etc.
suffix = '_corrected_mean'  # selects the '<event>_corrected_mean' traces

plot_event_means(events_to_plot, suffix)

#### Calculate peri-event z scores

In [None]:
def compute_z_scores(alldat, fs, TRANGE, Base_start, Base_end, AUC_start, AUC_end, vel_key):
    """
    Computes baseline z-scored snippets, their mean, standard error, and AUC values.

    Parameters:
    - alldat: List of dictionaries containing data for each session.
    - fs: Sampling frequency; floor(fs) is used to convert seconds to samples.
    - TRANGE: Peri-event window as [start, end] sample offsets; TRANGE[0] shifts the
      converted times so they index into a snippet.
    - Base_start: Baseline start time (relative to event, in seconds).
    - Base_end: Baseline end time (relative to event, in seconds).
    - AUC_start: AUC calculation start time (relative to event, in seconds).
    - AUC_end: AUC calculation end time (relative to event, in seconds).
    - vel_key: Event name; snippets are read from alldat[f][f"{vel_key}_corrected_snips"].

    Modifies alldat in place by adding '<vel_key>_z_snips', '_z_mean', '_z_sterr',
    '_z_AUCs' and '_z_AUCmeans'. Entries with no snippets are left untouched.
    A baseline with zero standard deviation produces inf/NaN z-scores.
    """

    # np.trapz was renamed np.trapezoid in NumPy 2.0; use whichever this install provides
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz

    # Convert time parameters to indices relative to TRANGE
    samples_per_second = np.floor(fs)
    Base_start = int(Base_start * samples_per_second - TRANGE[0])
    Base_end = int(Base_end * samples_per_second - TRANGE[0])
    AUC_start = int(AUC_start * samples_per_second - TRANGE[0])
    AUC_end = int(AUC_end * samples_per_second - TRANGE[0])

    snips_key = f"{vel_key}_corrected_snips"

    for entry in alldat:
        if len(entry[snips_key]) == 0:
            continue

        # Convert the snippet list to a 2-D array once, rather than once per trial
        snip_matrix = np.array(entry[snips_key])

        z_snips = []
        z_AUC_calc = []
        for dFF_snip in snip_matrix:
            # Baseline mean and std deviation for this trial
            baseline = dFF_snip[Base_start:Base_end]
            zb = np.mean(baseline)
            zsd = np.std(baseline)

            # Z-score the whole snippet against its own baseline
            z_trial = (dFF_snip - zb) / zsd
            z_snips.append(z_trial)

            # AUC over the requested window (unit sample spacing)
            z_AUC_calc.append(trapezoid(z_trial[AUC_start:AUC_end]))

        # Mean and standard error across trials
        z_matrix = np.array(z_snips)
        z_means = np.mean(z_matrix, axis=0)
        z_sterr = np.std(z_matrix, axis=0) / np.sqrt(z_matrix.shape[0])

        z_AUCs = np.array(z_AUC_calc)
        z_AUCmeans = np.mean(z_AUCs)

        # Store results in alldat
        entry[f"{vel_key}_z_snips"] = z_snips
        entry[f"{vel_key}_z_mean"] = z_means
        entry[f"{vel_key}_z_sterr"] = z_sterr
        entry[f"{vel_key}_z_AUCs"] = z_AUCs
        entry[f"{vel_key}_z_AUCmeans"] = z_AUCmeans

In [None]:
#Baseline z-score the corrected snips of every event
#Arguments after TRANGE are baseline start, baseline end, AUC start, AUC end (seconds relative to the event)
for event_name in events:
    compute_z_scores(alldat, fs, TRANGE, base_start, base_end, AUC_start, AUC_end, event_name)

Plot peri-event z scores

In [None]:
# Plot the per-mouse and per-group mean of the baseline z-scored snips for each event
events_to_plot = alldat[0]['events_list'] # this will plot everything, or you can specify a list of events i.e. ['edge_entries', 'velocity'] etc.
suffix = '_z_mean'  # selects the '<event>_z_mean' traces

plot_event_means(events_to_plot, suffix)

### Z scoring to whole trace

In [None]:
#z-score each mouse's corrected signal against the mean and stdev of its whole trace
for mouse in alldat:
    corrected = mouse['corrected_green']
    mouse['z_whole_trace'] = (corrected - np.mean(corrected)) / np.std(corrected)

#Get snips and means from the whole-trace z-scored signal for every event

for event_name in events:
    extract_snips(
        alldat,
        TRANGE,
        f'mapped_{event_name}',
        f'{event_name}_z_whole_trace_snips',
        f'{event_name}_z_whole_trace_mean',
        'z_whole_trace',
    )

Plot whole-trace z scores

In [None]:
# Plot the per-mouse and per-group mean of the whole-trace z-scored snips for each event
events_to_plot = alldat[0]['events_list']
suffix = '_z_whole_trace_mean'  # selects the '<event>_z_whole_trace_mean' traces

plot_event_means(events_to_plot, suffix)

### Export files

In [None]:
#Export means: one workbook per mean type, one sheet per event, one column per mouse

# Mean types to export (suffixes of the '<event>_<key>' entries)
export_keys = ['z_mean', 'z_whole_trace_mean'] #not including corrected_mean but could be added back in

triallength = len(peri_time)

# Order mice by group so columns of the same group sit together
sorted_indices = sorted(range(len(alldat)), key=lambda i: alldat[i]['group'])
colnames = [alldat[i]['mouseID'] for i in sorted_indices] 

for key in export_keys:
    excel_file_path = os.path.join(data_output, f'{key}.xlsx')
    os.makedirs(os.path.dirname(excel_file_path), exist_ok=True)

    with pd.ExcelWriter(excel_file_path, engine='xlsxwriter') as writer:
        for event in events:
            event_key = f'{event}_{key}'

            # One row per mouse, in group order; mice without usable data get a NaN row
            rows = []
            for f in sorted_indices:
                data = alldat[f].get(event_key)
                if data is not None and len(data) > 0 and np.any(~np.isnan(data)):
                    rows.append(data)
                else:
                    rows.append(np.full((triallength,), np.nan))
            zframe = np.vstack([np.empty((0, triallength))] + rows)

            # Transpose so each mouse is a column, then prepend the time axis
            df = pd.DataFrame(zframe).T
            df.columns = colnames
            df.insert(0, 'peri_time', peri_time)

            # Group labels in the same order as colnames; blank above the peri_time column
            group_labels = [alldat[i]['group'] for i in sorted_indices]
            group_row = [''] + group_labels

            # Put the group row above the data
            header_df = pd.DataFrame([group_row], columns=df.columns)
            df_with_group = pd.concat([header_df, df], ignore_index=True)

            # Excel has a 31 character limit for sheet names
            df_with_group.to_excel(writer, sheet_name=event[:31], index=False)

In [None]:
#Write one workbook per mouse containing every trial of every event, one sheet per event and snip type

# Output subfolder
subfolder_name = "individual_trials"
subfolder_path = os.path.join(data_output, subfolder_name)
os.makedirs(subfolder_path, exist_ok=True)

for mouse in alldat:

    excel_file_path = os.path.join(subfolder_path, f"{mouse['mouseID']}_trials.xlsx")

    with pd.ExcelWriter(excel_file_path, engine='xlsxwriter') as writer:

        for event in events:
            # Snip types exported for each event ('corrected' is left out but could be added back in)
            for snip_type in ('z', 'z_whole_trace'):
                key = f'{event}_{snip_type}_snips'

                # Excel sheet names are limited to 31 characters: shorten the event name so
                # the underscore and snip type always fit, then cut again as a final safety check
                max_event_len = 31 - len(snip_type) - 1
                sheet_name = f'{event[:max_event_len]}_{snip_type}'[:31]

                if key in mouse and len(mouse[key]) > 0:
                    # One column per trial, with the time axis in front
                    snips_df = pd.DataFrame(mouse[key]).T
                    snips_df.insert(0, 'Time', peri_time)
                else:
                    # No trials: the sheet holds only the Time column
                    snips_df = pd.DataFrame({'Time': peri_time})

                # A sheet is written even when there is no data
                snips_df.to_excel(writer, sheet_name=sheet_name, index=False)

print("Individual trial workbooks created for all mice")