Used for analyzing photometry data collected with Synapse software (Tucker-Davis Technologies), with TTLs sent from Med Associates operant boxes used to timestamp events.

Code for curve fit and motion correction adapted from Simpson et al. 2024 (https://doi.org/10.1016/j.neuron.2023.11.016)

In [None]:
import numpy as np
from sklearn.metrics import auc
import matplotlib.pyplot as plt
import scipy.stats as stats
import tdt
import statsmodels.api as sm
import pandas as pd
import os
import pylab as plt
from scipy.signal import medfilt, butter, filtfilt
from scipy.stats import linregress
from scipy.optimize import curve_fit, minimize

### Import Data


In [None]:
# Assign each mouse to a group
group_id = {
    'mouseID':'groupname',
    'mouseID':'groupname'
    }

group_ids = list(set(group_id.values()))

In [None]:
#Folder containing 1 day of TDT data from multiple animals
data_directory = './Data'

#Specify where to save the output
data_output = './Output'

# Get a list of all folders in the data directory. Sorted so that processing
# (and therefore output) order does not depend on os.listdir's arbitrary order.
folderlist = sorted(
    item for item in os.listdir(data_directory)
    if os.path.isdir(os.path.join(data_directory, item))
)

# Extract metadata from the folder names and read in the data.
# Assumes folder names start with the ear tag(s):
#   - one mouse : 18 characters, '1234' at [0:4], date at [5:11]
#   - two mice  : '1234-5678' at [0:9], date at [10:16]
# Creates 'masterdat', a list of dicts (one per TDT block) with keys
# 'mouse1', 'mouse2' ('0000' = no second mouse), 'date', 'blockpath', 'data'.
masterdat = []
for foldername in folderlist:
    dat = {}

    if len(foldername) == 18:
        # single-mouse block
        dat['mouse1'] = foldername[0:4]
        dat['mouse2'] = '0000'

        dat['date'] = foldername[5:11]

    else:
        # two-mouse block
        dat['mouse1'] = foldername[0:4]
        dat['mouse2'] = foldername[5:9]

        dat['date'] = foldername[10:16]

    dat['blockpath'] = foldername

    # Reading in data
    dat['data'] = tdt.read_block(os.path.join(data_directory, dat['blockpath']))

    masterdat.append(dat)

### Pulling data and TTLs for each mouse

In [None]:
#The "_470A" etc. values and epocs may need to be changed if recording from a different photometry setup

# Each TDT block can carry two mice: mouse1 is read from the 'A' streams and
# the BOX1 epoc, mouse2 from the 'B' streams and the BOX2 epoc. A mouse ID of
# '0000' marks an unused channel and is skipped.
_channel_layout = (
    ('mouse1', '_470A', '_405A', 'BOX1'),
    ('mouse2', '_470B', '_405B', 'BOX2'),
)

alldat = []

for block_info in masterdat:
    block = block_info['data']
    for mouse_key, green_name, isos_name, box_name in _channel_layout:
        if block_info[mouse_key] == '0000':
            continue
        green_stream = getattr(block.streams, green_name)
        isos_stream = getattr(block.streams, isos_name)
        box_epoc = getattr(block.epocs, box_name)
        alldat.append({
            'mouseID': block_info[mouse_key],
            'green': green_stream.data,
            'isos': isos_stream.data,
            'TTL_values': box_epoc.data,
            'TTL_timestamps': box_epoc.onset,
            'sampling_rate': green_stream.fs,
        })

# assigning group name (raises KeyError if a mouse is missing from group_id)
for rec in alldat:
    rec['group'] = group_id[rec['mouseID']]

In [None]:
# Trim the green and isos traces to a common length so they line up
# sample-by-sample; only the longer trace is sliced.

for rec in alldat:
    n_common = min(len(rec['green']), len(rec['isos']))
    if len(rec['green']) > n_common:
        rec['green'] = rec['green'][:n_common]
    if len(rec['isos']) > n_common:
        rec['isos'] = rec['isos'][:n_common]


In [None]:
# Build a time vector (in seconds) for each recording:
# sample k (counting from 1) is stamped at k / sampling_rate.

for rec in alldat:
    n_samples = len(rec['green'])
    sample_numbers = np.arange(1, n_samples + 1)
    rec['time'] = sample_numbers / rec['sampling_rate']


### Remove artifact at start and downsample

In [None]:
# remove artifact at start

t = 8 # time threshold (in seconds) below which we will discard

for rec in alldat:
    # Rebuild the full time vector from the untrimmed green trace (same formula
    # as the time-series cell) instead of reusing rec['time'], which this cell
    # overwrites. Otherwise re-running the cell would trim 'time' a second time
    # while 'trimmed_green'/'trimmed_isos' are re-sliced from the untrimmed
    # traces, leaving time and signal misaligned.
    full_time = np.arange(1, len(rec['green']) + 1) / rec['sampling_rate']

    # First index where time crosses the threshold. If no sample is later
    # than t, argmax returns 0 and nothing is trimmed.
    indA = np.argmax(full_time > t)

    rec['time'] = full_time[indA:]  # keep only allowed time
    rec['trimmed_green'] = rec['green'][indA:]
    rec['trimmed_isos'] = rec['isos'][indA:]
    

In [None]:
# Remove the last `end_trim` seconds from the end of each recording, in case
# the recording was not stopped before unplugging.

end_trim = 10  # Time in seconds to remove from the end of each recording

for rec in alldat:
    # New end threshold: retained samples must not be later than this
    t_end = rec['time'][-1] - end_trim

    # First index where time exceeds t_end; samples from here on are removed.
    # The final sample always exceeds t_end when end_trim > 0, so indB == 0 can
    # only mean the whole recording lies within the last `end_trim` seconds
    # (or end_trim <= 0, where nothing should be removed anyway).
    indB = np.argmax(rec['time'] > t_end)

    if indB == 0:
        # Trimming would leave an empty recording, which breaks the curve fits
        # later, so leave it untouched. (The previous edge-case check could
        # never trigger, so such recordings were silently emptied.)
        if end_trim > 0:
            print(f"Warning: recording for {rec['mouseID']} is not longer than "
                  f"{end_trim} s; end trim skipped.")
        continue

    # Trim up to (not including) indB
    rec['time'] = rec['time'][:indB]
    rec['trimmed_green'] = rec['trimmed_green'][:indB]
    rec['trimmed_isos'] = rec['trimmed_isos'][:indB]

In [None]:
# Downsample by block-averaging

N = 100 # Average every N samples into 1 value


def _block_mean(trace, n):
    '''Return the mean of each consecutive block of `n` samples of `trace`.

    The last block holds fewer than `n` samples when len(trace) is not a
    multiple of `n`; it is averaged over the samples it has.
    '''
    # Slice [i:i + n] (the previous [i:i + n - 1] silently dropped one sample
    # from every block, so only N-1 samples were averaged).
    return np.array([np.mean(trace[i:i + n]) for i in range(0, len(trace), n)])


for rec in alldat:
    rec['downsampled_green'] = _block_mean(rec['trimmed_green'], N)
    rec['downsampled_isos'] = _block_mean(rec['trimmed_isos'], N)

    # Each block is stamped with the time of its first sample, so this has
    # the same length as the downsampled traces.
    rec['downsampled_time'] = rec['time'][::N]
    

### Fit double exponential to each curve and use to correct for bleaching

In [None]:
# Model fitted to each trace to estimate the bleaching trend.
def double_exponential(t, const, amp_fast, amp_slow, tau_slow, tau_multiplier):
    '''Evaluate a two-component exponential decay plus a constant offset.

    Returns const + amp_slow*exp(-t/tau_slow) + amp_fast*exp(-t/tau_fast),
    where tau_fast = tau_slow * tau_multiplier.

    Parameters:
    t             : Time vector (or scalar) in seconds.
    const         : Constant offset.
    amp_fast      : Amplitude of the fast component.
    amp_slow      : Amplitude of the slow component.
    tau_slow      : Time constant of the slow component in seconds.
    tau_multiplier: Fast time constant expressed as a multiple of tau_slow.
    '''
    slow_term = amp_slow * np.exp(-t / tau_slow)
    fast_term = amp_fast * np.exp(-t / (tau_slow * tau_multiplier))
    return const + slow_term + fast_term

In [None]:
def get_bounds(trace, tau_slow_min=0.0001, tau_slow_max=30, timecourse=None):
    '''Build curve_fit bounds for fitting `double_exponential` to `trace`.

    The returned lists follow double_exponential's parameter order:
    (const, amp_fast, amp_slow, tau_slow, tau_multiplier).

    Parameters:
    trace       : Signal being fitted.
    tau_slow_min: Lower bound of tau_slow as a fraction of the recording duration.
    tau_slow_max: Upper bound of tau_slow as a multiple of the recording duration.
    timecourse  : Time vector (seconds) matching `trace`. If omitted, the
                  module-level `time` variable is used (previous behaviour).

    Returns:
    (lower_bounds, upper_bounds) suitable for curve_fit's `bounds` argument.
    '''
    signal = trace
    if timecourse is None:
        # Backward compatibility: this function used to always read the global
        # `time` set by the fitting loop. Prefer passing `timecourse`.
        timecourse = time
    duration = timecourse[-1] - timecourse[0]

    # Amplitude bounds
    amp_min = 0  # Assuming amplitude cannot be negative
    amp_max = 2 * np.max(signal)  # Allowing for some flexibility

    # Time constant bounds based on the duration of the experiment
    time_constant_min = tau_slow_min * duration
    time_constant_max = tau_slow_max * duration

    # Offset (const) bounds
    offset_min = np.min(signal) if np.min(signal) < 0 else 0
    offset_max = np.max(signal)

    # tau_multiplier bounds: the fast time constant is a fraction of the slow
    # one (same range as Simpson et al.). Previously the offset bounds occupied
    # this slot and the amplitude bounds occupied the const slot, because the
    # list order did not match double_exponential's parameter order.
    multiplier_min, multiplier_max = 0, 1

    return (
        [offset_min, amp_min, amp_min, time_constant_min, multiplier_min],
        [offset_max, amp_max, amp_max, time_constant_max, multiplier_max],
    )

In [None]:
def _fit_exp_curve(trace):
    '''Fit double_exponential to `trace` against the module-level `time`
    vector and return the fitted curve evaluated at `time`.'''
    peak = np.max(trace)
    start_guess = [peak / 2, peak / 4, peak / 4, 3600, 0.1]
    fit_params, _ = curve_fit(double_exponential, time, trace,
                              p0=start_guess, bounds=get_bounds(trace), maxfev=1000)
    return double_exponential(time, *fit_params)


for rec in alldat:
    # Downsampled traces for this mouse. `time` is deliberately a module-level
    # name: get_bounds() and the helper above read the global `time`.
    green = rec['downsampled_green']
    isos = rec['downsampled_isos']
    time = rec['downsampled_time']

    # Fit a double exponential to the green (470) and isos (405) signals
    green_expfit = _fit_exp_curve(green)
    rec['expfitgreen'] = green_expfit
    isos_expfit = _fit_exp_curve(isos)
    rec['expfitisos'] = isos_expfit

    # Plot fits over the downsampled data: green on the left axis, isos on the right
    fig, ax1 = plt.subplots()
    green_line = ax1.plot(time, green, 'g', label='green')
    fit_line = ax1.plot(time, green_expfit, color='k', linewidth=1.5, label='Exponential fit')
    ax2 = plt.twinx()
    isos_line = ax2.plot(time, isos, color='b', label='isos')
    ax2.plot(time, isos_expfit, color='k', linewidth=1.5)

    ax1.set_xlabel('Time (seconds)')
    ax1.set_ylabel('Green Signal (V)', color='g')
    ax2.set_ylabel('Isos Signal (V)', color='b')
    ax1.set_title(rec['mouseID'])

    handles = green_line + isos_line + fit_line
    ax1.legend(handles, [h.get_label() for h in handles], loc='upper right')



In [None]:
# Subtract the fitted exponential to correct for bleaching, then plot the result

for rec in alldat:
    green_detrended = rec['downsampled_green'] - rec['expfitgreen']
    rec['detrended_green'] = green_detrended
    isos_detrended = rec['downsampled_isos'] - rec['expfitisos']
    rec['detrended_isos'] = isos_detrended

    time = rec['downsampled_time']

    # Green on the left axis, isos on a twin right axis
    fig, ax1 = plt.subplots()
    green_line = ax1.plot(time, green_detrended, 'g', label='green')
    ax2 = plt.twinx()
    isos_line = ax2.plot(time, isos_detrended, color='b', label='isos')

    ax1.set_xlabel('Time (seconds)')
    for axis, text, colour in ((ax1, 'Green Signal (V)', 'g'), (ax2, 'Isos Signal (V)', 'b')):
        axis.set_ylabel(text, color=colour)
    ax1.set_title(rec['mouseID'])

    handles = green_line + isos_line
    ax1.legend(handles, [h.get_label() for h in handles], loc='upper right')


### Correcting for movement artifacts

In [None]:
# Motion correction: regress the detrended green signal on the detrended isos
# signal and subtract the fitted isos component from the green signal.

motion_plot_offset = 1.05    # vertical shift of the estimated-motion trace (display only)
zoom_window = (1000, 1060)   # x-limits (seconds) of the example window plotted

for rec in alldat:

    green_detrended = rec['detrended_green']
    isos_detrended = rec['detrended_isos']
    time = rec['downsampled_time']

    # Scatter plot of Isos vs Green signal with the least-squares fit line
    slope, intercept, r_value, p_value, std_err = linregress(x=isos_detrended, y=green_detrended)

    plt.figure()
    # every 5th point only, to keep the scatter light
    plt.scatter(isos_detrended[::5], green_detrended[::5], alpha=0.1, marker='.')
    x = np.array(plt.xlim())
    plt.plot(x, intercept + slope * x)
    plt.xlabel('Isos')
    plt.ylabel('Green')
    plt.title(f"{rec['mouseID']}: Isos - Green correlation")

    print(rec['mouseID'])
    print('Slope    : {:.3f}'.format(slope))
    print('R-squared: {:.3f}'.format(r_value**2))

    # Calculate motion effect
    green_est_motion = intercept + slope * isos_detrended
    green_corrected = green_detrended - green_est_motion

    rec['corrected_green'] = green_corrected

    # plt.subplots() creates its own figure; the extra plt.figure() that used
    # to precede it only produced an empty figure per mouse.
    fig, ax1 = plt.subplots()
    plot1 = ax1.plot(time, green_detrended, 'b', label='Green - pre motion correction', alpha=0.5)
    plot3 = ax1.plot(time, green_corrected, 'g', label='Green - motion corrected', alpha=0.5)
    plot4 = ax1.plot(time, green_est_motion - motion_plot_offset, 'y', label='estimated motion')

    ax1.set_xlabel('Time (seconds)')
    ax1.set_ylabel('Green Signal (V)', color='g')
    ax1.set_title(rec['mouseID'])

    lines = plot1 + plot3 + plot4
    labels = [l.get_label() for l in lines]
    legend = ax1.legend(lines, labels, loc='upper right', bbox_to_anchor=(0.95, 0.98))

    ax1.set_xlim(*zoom_window)  # 60 sec window


### Z scoring

In [None]:
#Set parameters for Z scoring
# All values are in seconds relative to each event (TTL) timestamp. The
# z-scoring cell converts them to sample indices within each peri-event snip,
# so the baseline and AUC windows should lie inside [-PRE_TIME, POST_TIME].

#Peri-event windows
PRE_TIME = 10 #  seconds before event to include
POST_TIME = 20 #  seconds after event to include

# Each trial is z-scored against the mean/stdev of this baseline window
Base_start = -10 # seconds relative to event to start z-score baseline
Base_end = -6 # seconds relative to event to end z-score baseline

# Window over which the area under the z-scored curve is integrated
AUC_start = 0 # seconds relative to event to start AUC region
AUC_end = 5 # seconds relative to event to end AUC region

In [None]:
#Function to expand TTLs
def find_bits(sum_result):
    '''Return the positions of the set bits in an integer TTL word.

    A TTL value that encodes several simultaneous events as a sum of powers
    of two is split into its components, e.g. 5 (0b101) -> [0, 2].
    Positions are listed least-significant bit first.
    '''
    reversed_digits = reversed(bin(sum_result)[2:])
    return [position for position, digit in enumerate(reversed_digits) if digit == '1']

In [None]:
# Getting list of bit values and timestamps: each TTL word is expanded into
# one (bit, timestamp) pair per set bit.

for rec in alldat:

    bit_values = []
    bit_timestamps = []

    for ttl_value, ts in zip(rec['TTL_values'], rec['TTL_timestamps']):
        for bit in find_bits(int(ttl_value)):
            bit_values.append(bit)
            bit_timestamps.append(ts)

    # Collect in lists and convert once (np.append in a loop is quadratic).
    # Always store float arrays, as before, even when there are no TTLs, so
    # later `.astype(int)` calls do not fail on a plain list.
    rec['bit_values'] = np.asarray(bit_values, dtype=float)
    rec['bit_timestamps'] = np.asarray(bit_timestamps, dtype=float)

In [None]:
# Generate peri_time, centered around TTL event

# Sampling rate after downsampling by N. Assumes every recording shares the
# first recording's sampling rate.
fs = alldat[0]['sampling_rate'] / N
if any(rec['sampling_rate'] != alldat[0]['sampling_rate'] for rec in alldat):
    print('Warning: recordings have different sampling rates; '
          'peri-event windows use the rate of the first recording.')

# time span for peri-event filtering, in downsampled samples relative to the event
TRANGE = [-1 * PRE_TIME * np.floor(fs), POST_TIME * np.floor(fs)]

# Time (s) of each snip sample relative to the event. np.arange is used
# because range(...)/fs raises TypeError when fs is a plain Python float.
peri_time = np.arange(int(TRANGE[1] - TRANGE[0])) / fs - PRE_TIME * np.floor(fs) / fs

In [None]:
#Get snips of trace surrounding each timestamp

#this gives the number of datapoints that should be in each snip
correct_size = int(TRANGE[1] - TRANGE[0])

for rec in alldat:

    trace = rec['corrected_green']
    trace_time = rec['downsampled_time']

    #Get list of bits used in this experiment
    Bit_list = sorted(set(rec['bit_values'].astype(int)))

    Bit_snips = {}
    Bit_means = {}

    for Bit in Bit_list:

        Bitname = 'Bit_' + str(Bit)
        Beh_t = rec['bit_timestamps'][rec['bit_values'] == Bit]

        snips = []
        for event_time in Beh_t:
            # First sample after the event (0 if none; such a snip ends up too
            # short and is dropped below)
            event_ind = int(np.argmax(trace_time > event_time))

            # Window indices, clipped to the bounds of the trace
            start = max(0, event_ind + int(TRANGE[0]))
            stop = min(event_ind + int(TRANGE[1]), len(trace_time))
            snip = trace[start:stop]

            # ignore NaN-containing or short snips caused by events too close
            # to the start or end of the trace
            if len(snip) == correct_size and not np.isnan(snip).any():
                snips.append(snip)

        if snips:
            # trials x samples matrix and its per-sample mean across trials
            Bit_snips[Bitname] = np.array(snips)
            Bit_means[Bitname] = Bit_snips[Bitname].mean(axis=0)
        else:
            # No usable trials: keep the expected shapes (an all-NaN mean
            # instead of a scalar NaN) so plotting and export still work.
            Bit_snips[Bitname] = np.empty((0, correct_size))
            Bit_means[Bitname] = np.full(correct_size, np.nan)

    rec['Bit_snips'] = Bit_snips #Call these by 'Bit_0' etc.
    rec['Bit_means'] = Bit_means



In [None]:
#Plot the mean (non-z-scored) trace of one bit for each mouse

bit_to_plot = 'Bit_1'

plt.figure()
for rec in alldat:
    mousenum = rec['mouseID']
    if bit_to_plot in rec['Bit_means']:
        plt.plot(peri_time, rec['Bit_means'][bit_to_plot], label=str(mousenum + '_' + bit_to_plot))

# Title and legend once, after all lines are drawn (they were re-applied on
# every loop iteration before)
plt.title('Bit means')
plt.legend(loc='upper left')
    

In [None]:
#Plot average across groups (mean +/- SEM across mice)

bit_to_plot = 'Bit_1'

plt.figure()
# Loop variable is `group_name`, not `group_id`: reusing `group_id` overwrote
# the mouse -> group dict, so re-running the group-assignment cell then failed.
for group_name in group_ids:
    bit_arrays = [
        rec['Bit_means'][bit_to_plot]
        for rec in alldat
        if rec['group'] == group_name
        and bit_to_plot in rec['Bit_means']
        and np.size(rec['Bit_means'][bit_to_plot]) == len(peri_time)
    ]
    if not bit_arrays:
        # np.vstack([]) would raise; nothing to plot for this group
        continue

    # Stack the arrays to compute the average and standard error
    bit_stack = np.vstack(bit_arrays)
    n_mice = bit_stack.shape[0]
    bit_mean = np.mean(bit_stack, axis=0)
    # Standard error of the mean from the sample SD (ddof=1); undefined for a single mouse
    if n_mice > 1:
        bit_sem = np.std(bit_stack, axis=0, ddof=1) / np.sqrt(n_mice)
    else:
        bit_sem = np.zeros_like(bit_mean)

    # Plotting the average array with standard error bands
    plt.plot(peri_time, bit_mean, label=f'{group_name}')
    plt.fill_between(peri_time, bit_mean - bit_sem, bit_mean + bit_sem, alpha=0.3, label='_nolegend_')

plt.title('Bit means')
plt.legend()


In [None]:
# z-score everything

# Time windows defined above, relative to timestamp, as indices into each snip
Base_start_calc = (Base_start*np.floor(fs) - TRANGE[0]).astype(int)
Base_end_calc = (Base_end*np.floor(fs) - TRANGE[0]).astype(int)

AUC_start_calc = (AUC_start*np.floor(fs) - TRANGE[0]).astype(int)
AUC_end_calc = (AUC_end*np.floor(fs) - TRANGE[0]).astype(int)

# np.trapz was renamed np.trapezoid in NumPy 2.0 (trapz is deprecated there)
_trapezoid = np.trapezoid if hasattr(np, 'trapezoid') else np.trapz

n_points = len(peri_time)


def _sem(values):
    '''Standard error of the mean along axis 0 using the sample SD (ddof=1);
    NaN when there are fewer than two rows.'''
    n = values.shape[0]
    if n < 2:
        return np.full(values.shape[1:], np.nan)
    return np.std(values, axis=0, ddof=1) / np.sqrt(n)


for rec in alldat:

    z_scores = {}
    z_means = {}
    z_sterr = {}
    z_AUCs = {}
    z_AUCmeans = {}

    Bit_list = sorted(set(rec['bit_values'].astype(int)))

    for Bit in Bit_list:

        Bitname = 'Bit_' + str(Bit)
        snips = np.asarray(rec['Bit_snips'][Bitname])  # trials x samples

        if snips.ndim != 2 or snips.shape[0] == 0:
            # No usable trials: keep array shapes consistent for plotting/export
            z_scores[Bitname] = np.empty((0, n_points))
            z_means[Bitname] = np.full(n_points, np.nan)
            z_sterr[Bitname] = np.full(n_points, np.nan)
            z_AUCs[Bitname] = np.array([])
            z_AUCmeans[Bitname] = np.nan
            continue

        # Z score each trial against its own baseline window
        baseline = snips[:, Base_start_calc:Base_end_calc]
        zb = baseline.mean(axis=1, keepdims=True)   # baseline period mean
        zsd = baseline.std(axis=1, keepdims=True)   # baseline period stdev
        z = (snips - zb) / zsd  # NOTE: a flat baseline (zsd == 0) gives inf/NaN

        z_scores[Bitname] = z
        z_means[Bitname] = z.mean(axis=0)
        z_sterr[Bitname] = _sem(z)

        # Area under the z-scored curve in the AUC window, one value per trial
        z_AUCs[Bitname] = _trapezoid(z[:, AUC_start_calc:AUC_end_calc],
                                     peri_time[AUC_start_calc:AUC_end_calc], axis=1)
        z_AUCmeans[Bitname] = np.mean(z_AUCs[Bitname])

    rec['z_scores'] = z_scores
    rec['z_means'] = z_means
    rec['z_sterr'] = z_sterr
    rec['z_AUCs'] = z_AUCs
    rec['z_AUCmeans'] = z_AUCmeans


In [None]:
#Plot the mean z-scored trace of one bit for each mouse

bit_to_plot = 'Bit_1'

plt.figure()
for rec in alldat:
    mousenum = rec['mouseID']
    if bit_to_plot in rec['z_means']:
        plt.plot(peri_time, rec['z_means'][bit_to_plot], label=str(mousenum + '_' + bit_to_plot))

# Title and legend once, after all lines are drawn
plt.title('Z means')
plt.legend(loc='upper left')

In [None]:
#Plot average z-scored trace per group (mean +/- SEM across mice)

bit_to_plot = 'Bit_1'

plt.figure()
# Loop variable is `group_name`, not `group_id`, so the mouse -> group dict
# is not overwritten.
for group_name in group_ids:
    bit_arrays = [
        rec['z_means'][bit_to_plot]
        for rec in alldat
        if rec['group'] == group_name
        and bit_to_plot in rec['z_means']
        and np.size(rec['z_means'][bit_to_plot]) == len(peri_time)
    ]
    if not bit_arrays:
        # np.vstack([]) would raise; nothing to plot for this group
        continue

    # Stack the arrays to compute the average and standard error
    bit_stack = np.vstack(bit_arrays)
    n_mice = bit_stack.shape[0]
    bit_mean = np.mean(bit_stack, axis=0)
    # Standard error of the mean from the sample SD (ddof=1); undefined for a single mouse
    if n_mice > 1:
        bit_sterr = np.std(bit_stack, axis=0, ddof=1) / np.sqrt(n_mice)
    else:
        bit_sterr = np.zeros_like(bit_mean)

    # Plotting the average array with error bands
    plt.plot(peri_time, bit_mean, label=f'{group_name}')
    plt.fill_between(peri_time, bit_mean - bit_sterr, bit_mean + bit_sterr, alpha=0.3, label='_nolegend_')

plt.title('Z means')
plt.legend()


### Export files

In [None]:
#Export means: one workbook per key, one sheet per bit, one column per mouse
# (ordered by group) plus a leading peri_time column; row 0 holds group labels.

# Specify which keys to export
export_keys = ['z_means', 'Bit_means']  # Add or remove keys as needed

# Name of the data directory (normpath so a trailing slash does not yield '')
directory_name = os.path.basename(os.path.normpath(data_directory))

triallength = len(peri_time)

# Sort alldat by group and get sorted column names
sorted_indices = sorted(range(len(alldat)), key=lambda i: alldat[i]['group'])
colnames = [alldat[i]['mouseID'] for i in sorted_indices]

# Group labels in the same order as colnames (loop-invariant, built once)
group_labels = [alldat[i]['group'] for i in sorted_indices]
group_row = [''] + group_labels  # Blank for peri_time column

# Bits 0-6 always get a sheet (as before); any higher bits present are added too
default_bits = set(range(7))
nan_row = np.full((triallength,), np.nan)

for key in export_keys:
    observed_bits = {
        int(name.split('_', 1)[1])
        for rec in alldat
        for name in rec.get(key, {})
        if name.startswith('Bit_')
    }
    Bit_list = sorted(default_bits | observed_bits)

    excel_file_path = os.path.join(data_output, f'{directory_name}_{key}.xlsx')
    os.makedirs(os.path.dirname(excel_file_path), exist_ok=True)

    with pd.ExcelWriter(excel_file_path, engine='xlsxwriter') as writer:
        for Bit in Bit_list:
            Bitname = 'Bit_' + str(Bit)

            # One row per mouse in group order; a NaN row when the bit is
            # missing for that mouse or its data is entirely NaN
            rows = []
            for i in sorted_indices:
                data = alldat[i].get(key, {}).get(Bitname)
                if data is not None and np.any(~np.isnan(data)):
                    rows.append(data)
                else:
                    rows.append(nan_row)
            zframe = np.vstack(rows) if rows else np.empty((0, triallength))

            df = pd.DataFrame(zframe).T
            df.columns = colnames
            df.insert(0, 'peri_time', peri_time)

            # Insert the group row as the first row of the DataFrame
            df_with_group = pd.concat([
                pd.DataFrame([group_row], columns=df.columns),  # group row
                df
            ], ignore_index=True)

            df_with_group.to_excel(writer, sheet_name=Bitname, index=False)

In [None]:
#Create 1 workbook per mouse that has Bit_snips and z-scores for all trials for all bits

# Name of the data directory (normpath so a trailing slash does not yield '')
directory_name = os.path.basename(os.path.normpath(data_directory))

# Create subfolder path
subfolder_name = f"{directory_name}_individual_trials"
subfolder_path = os.path.join(data_output, subfolder_name)
os.makedirs(subfolder_path, exist_ok=True)

for rec in alldat:

    excel_file_path = os.path.join(subfolder_path, f"{rec['mouseID']}_trials.xlsx")

    Bit_list = sorted(set(rec['bit_values'].astype(int)))

    with pd.ExcelWriter(excel_file_path, engine='xlsxwriter') as writer:

        for Bit in Bit_list:

            Bitname = 'Bit_' + str(Bit)
            # Each sheet is written only when its own data exists (one column
            # per trial). Previously both sheets were written inside the
            # z_scores check, so a bit missing from Bit_snips reused the
            # previous bit's snips (or raised NameError on the first bit).
            if Bitname in rec['Bit_snips']:
                pd.DataFrame(rec['Bit_snips'][Bitname]).T.to_excel(
                    writer, sheet_name=Bitname + '_Bit_snips', index=False)
            if Bitname in rec['z_scores']:
                pd.DataFrame(rec['z_scores'][Bitname]).T.to_excel(
                    writer, sheet_name=Bitname + '_zscores', index=False)