In [1]:
# Import statements
import os
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import sys
from glob import glob
from pymatreader import read_mat
from scipy import signal
from scipy import stats
from scipy.ndimage import uniform_filter1d
import time

#import matlab.engine
# Test out matlab spike detection code
#eng = matlab.engine.start_matlab()
#eng.eval("startup", nargout=0)

# Importing custom python files
import consts
import importlib
importlib.reload(consts)

from open_ephys.analysis import Session

%matplotlib tk

In [2]:
# Functions

# Get all of the ephys data and section everything across all trials
def read_openephys_data(pathname, start_chan, frame_chan, stim_chan, flicker_chan):
    session = Session(pathname)
    ephys_data = session.recordnodes[0].recordings[0].continuous[0]

    # Extract and find the start times of recording trials
    start_chan_idx = ephys_data.metadata['channel_names'].index(start_chan)
    raw_start_trigger = ephys_data.samples[:, start_chan_idx]

    frame_chan_idx = ephys_data.metadata['channel_names'].index(frame_chan)
    raw_frame_trigger = ephys_data.samples[:, frame_chan_idx]

    stim_chan_idx = ephys_data.metadata['channel_names'].index(stim_chan)
    raw_stim_trigger = ephys_data.samples[:, stim_chan_idx]

    flicker_chan_idx = ephys_data.metadata['channel_names'].index(flicker_chan)
    raw_flicker_trigger = ephys_data.samples[:, flicker_chan_idx]

    # Get all of the rise idxs
    record_start_idx = consts.get_ephys_rise_indices(raw_start_trigger)
    all_frame_idx = consts.get_ephys_rise_indices(raw_frame_trigger)
    all_stim_idx = consts.get_ephys_rise_indices(raw_stim_trigger)
    all_flicker_idx = consts.get_ephys_rise_indices(raw_flicker_trigger)

    # Get the reverse flicker idxs
    #TODO fix this and try to get the reverse indices for the flicker and get the pulse width
    all_flicker_end_idx = consts.get_ephys_rise_indices(raw_flicker_trigger)

    block_frame_timestamps = []
    block_stim_timestamps = []
    block_flicker_timestamps = []

    #DEBUG
    #plt.figure()
    #plt.plot(ephys_data.timestamps, raw_start_trigger)
    #plt.plot(ephys_data.timestamps[record_start_idx], raw_start_trigger[record_start_idx], '|')
    #plt.title('Ephys Debugging')
    #plt.show()

    #plt.figure()
    #plt.plot(ephys_data.timestamps, raw_start_trigger, '-g')
    #plt.plot(ephys_data.timestamps, raw_stim_trigger, '-g')
    #plt.plot(ephys_data.timestamps, raw_flicker_trigger, '-b')

    # Loop through each recording start trigger and parse together all of the trials
    for i in range(record_start_idx.shape[0]):
        start_idx = record_start_idx[i]
        if i == record_start_idx.shape[0] - 1:
            next_idx = float('inf')
        else:
            next_idx = record_start_idx[i + 1]

        # Find all of the camera frame indices for the given block
        frame_idxs = all_frame_idx[(all_frame_idx > start_idx) & (all_frame_idx < next_idx)]
        stim_idxs = all_stim_idx[(all_stim_idx > start_idx) & (all_stim_idx < next_idx)]
        flicker_idxs = all_flicker_idx[(all_flicker_idx > start_idx) & (all_flicker_idx < next_idx)]

        block_frame_timestamps.append(ephys_data.timestamps[frame_idxs])
        block_stim_timestamps.append(ephys_data.timestamps[stim_idxs])
        block_flicker_timestamps.append(ephys_data.timestamps[flicker_idxs])
    
        #plt.plot(ephys_data.timestamps[stim_idxs], raw_stim_trigger[stim_idxs], '|r')
        #plt.plot(ephys_data.timestamps[flicker_idxs], raw_flicker_trigger[flicker_idxs], '|r')
        #plt.plot(ephys_data.timestamps[start_idx], raw_start_trigger[start_idx], '|m')
    #plt.show()

    return block_stim_timestamps, block_frame_timestamps, block_flicker_timestamps

# Get the flicker experiment
def get_param(fname, param):
    # Loop through each element and search for string pattern
    for exp_param in fname.split('_'):
        if param in exp_param:
            return exp_param    
    print('Missing param: ' + param)
    return

In [3]:
f = os.sep

# Initialize empty dataframe
df = pd.DataFrame()

# Data rootpath is actually in handata3 for now
server_root = '/home/pierfier/handata_server/eng_research_handata3'
data_root_path = f.join([server_root, 'Yangyang_Wang', 'PV_V1_LED_SomArchon']) + f

# Find all folders that have the traces matfile and were flicker experiments
mat_paths = glob(data_root_path + '**/**RawTraces**flicker**DBS**.mat', recursive=True)

# Ignore matfiles that are in archive files
fun = lambda s: 'archive' not in s.lower()
mat_paths = list(filter(fun, mat_paths))

# Set an amount of frames to drop in front of the traces
front_frame_drop = 14

# Loop through each matfile that contains
not_aligned = []
for mat_i, mat_path in enumerate(mat_paths): #TODO DEBUG mat_paths:    
    print(str(mat_i) +' '+ mat_path)

    #Get the current experiments data 
    cur_flicker = get_param(os.path.basename(mat_path), 'flicker')
    cur_freq = get_param(os.path.basename(mat_path), 'DBS')
    cur_fov = get_param(os.path.basename(mat_path), 'fov')

    # Set the experimental parameters into variables
    cur_stim_freq = int(cur_freq.replace('DBS', '').replace('hz', ''))
    cur_flicker_freq = int(cur_flicker.replace('flicker', '').replace('hz', ''))
    session_id = os.path.basename(os.path.dirname(mat_path))
    mouse_id = os.path.basename(os.path.dirname(os.path.dirname(mat_path)))
    stim_param_string = 'currentamp:' + get_param(os.path.basename(mat_path), 'ua') 

    #Find corresponding experiment openephys data
    ephys_dir = f.join([os.path.dirname(mat_path), 'ephys']) + f
    ephys_path = glob(ephys_dir + '*' + cur_flicker + '*' + \
                        cur_freq + '*' + cur_fov + '*')
    if not ephys_path:
        #TODO read in model ephys folder, depending on stim frequency
        print('Missing Corresponding Ephys Folder')
        print('Matfile: ' + os.path.basename(mat_path))
        print('Using model ephys data')
        #continue
        ephys_path = glob('.'+f+ '*' + cur_flicker + '*' + \
                        cur_freq + '*')

    if len(ephys_path) > 1:
        print('There are multiple ephys_path\'s with this experiment')
        not_aligned = not_aligned.append(mat_path)
        continue

    # Read in all of the Open Ephys data
    #1 are the frames
    #3 are the flicker pulses
    #4 are the stim pulses
    #5 are the trial start triggers
    stim_timestamps, frame_timestamps, flicker_timestamps = \
        read_openephys_data(pathname=ephys_path[0], start_chan='ADC5', frame_chan='ADC1', \
                            stim_chan='ADC4', flicker_chan='ADC3')

    # Read in trace data
    data = read_mat(mat_path)
    sp_info = data['spike_detect_SA_v4_info']
    data = data['roi_list']

    #TODO figure out how to check if there are multiple neuron traces
    if len(data['traces'].shape) < 2:
        data['traces'] = data['traces'].reshape(-1, 1)

    # If there are more voltage trial traces than ephys start triggers,
    # Just duplicate it again
    if len(sp_info) > len(frame_timestamps):
        frame_timestamps.extend(frame_timestamps)
        stim_timestamps.extend(stim_timestamps)
        flicker_timestamps.extend(flicker_timestamps)

    for roi_i in range(data['traces'].shape[1]):

        # Loop through each trial
        for trial_i in range(len(sp_info)):
            #TODO need to figure out a way to skip this trial during alignment
            # Maybe try to remove that index ephys information across all of its arrays

            # Delete trial elements that do not have the same amount 
            while len(frame_timestamps[trial_i]) < 1:
                print('Removed trial stuff')
                frame_timestamps.pop(trial_i)
                stim_timestamps.pop(trial_i)

                flicker_timestamps.pop(trial_i)
            
            cur_frame_time = frame_timestamps[trial_i][front_frame_drop:]

            raw_trace = data['traces'][np.where(data['trial_vec'] == trial_i + 1)[0], roi_i]
            raw_trace = raw_trace[front_frame_drop:]

            cur_trace_noise = sp_info[trial_i*(data['traces'].shape[1]) + roi_i]['trace_noise'][front_frame_drop:]

            # Grab the first number set of camera frames as the trace
            cur_frame_time = cur_frame_time[:raw_trace.shape[0]]
            cur_stim_time = stim_timestamps[trial_i]
            trial_start = cur_frame_time[0]
            stim_start = cur_stim_time[0]

            cur_frame_time = cur_frame_time - trial_start
            cur_stim_time = cur_stim_time - trial_start
            cur_flicker_time = flicker_timestamps[trial_i] - trial_start

            # Grab the original spike detection data
            cur_spike_raster = np.zeros_like(raw_trace)
            spike_idx = sp_info[trial_i*(data['traces'].shape[1]) + roi_i]['spike_idx']
            if isinstance(spike_idx, int):
                spike_idx = np.array([spike_idx])

            spike_idx = spike_idx[np.where(spike_idx >= front_frame_drop + 1)]

            cur_spike_raster[spike_idx - front_frame_drop - 1] = 1

            # Define the 'interp_time' and interpolate all of the data
            #TODO need to change how the interpolation is done here
            #ideal_trial_end
            #TODO this also messes up the flicker and stim raster stuff for some reasons
            step = 1/500
            interp_time = np.arange(0, raw_trace.shape[0]*step, step) 
            # Old method that did not really take into account the idealized frequency
            #interp_time = np.linspace(0, cur_frame_time[-1] - cur_frame_time[0], raw_trace.shape[0])
            interp_raw_trace = np.interp(interp_time, cur_frame_time, raw_trace)
            interp_spike_raster = np.interp(interp_time, cur_frame_time, cur_spike_raster)

            # Interpolate subthreshold Vm, with spikes removed
            sub_vm = sp_info[trial_i*(data['traces'].shape[1]) + roi_i]['trace_spikeRemoved'][front_frame_drop:]
            interp_subvm = np.interp(interp_time, cur_frame_time, sub_vm)

            # Interpolated the detrended raw trace 
            #(this is from the quick_trial_check.m moving window average before spike detection)
            detrend_trace = sp_info[trial_i*(data['traces'].shape[1]) + roi_i]['trace_raw'][front_frame_drop:]
            interp_raw_detrend_trace = np.interp(interp_time, cur_frame_time, detrend_trace)

            # Adjusted spike index
            peak_idx, _ = signal.find_peaks(interp_spike_raster)
            interp_spike_raster_adj = np.zeros(interp_raw_trace.shape)
            interp_spike_raster_adj[peak_idx] = 1

            # Create a spike amplitude raster
            spike_amp_raster = np.full(interp_raw_trace.shape, np.nan)
            spike_amp = sp_info[trial_i*(data['traces'].shape[1]) + roi_i]['spike_amplitude']            
            
            if not isinstance(spike_amp, np.ndarray):
                spike_amp = np.array([spike_amp])

            spike_amp = spike_amp[spike_amp.shape[0] - spike_idx.shape[0]:]
            spike_amp = spike_amp[:np.min([np.where(interp_spike_raster_adj == 1)[0].shape[0], spike_amp.shape[0]])]
            spike_amp_raster[np.where(interp_spike_raster_adj == 1)[0]] = spike_amp

            # Create stimulation raster
            cur_stim_raster = np.zeros(raw_trace.shape[0])
            diff_mat = [interp_time] - np.transpose([cur_stim_time])
            diff_mat[diff_mat < 0] = float('inf')
            stim_idx_i = np.argmin(diff_mat, axis=1)
            cur_stim_raster[stim_idx_i] = 1

            # Create flicker raster
            # Calculate the light duration for 1Hz
            # TODO will need to eventually adjust this maybe?
            cur_flicker_raster = np.zeros(raw_trace.shape[0])
            diff_mat = [interp_time] - np.transpose([cur_flicker_time])
            diff_mat[diff_mat < 0] = float('inf')
            flicker_idx_i = np.argmin(diff_mat, axis=1)
            cur_flicker_raster[flicker_idx_i] = 1
            flick_dur = 0.06256 # Calculated one time
            flick_pts = flick_dur/np.mean(np.diff(cur_frame_time))

            # Adjust the interpolated time from the flicker onset here?
            flicker_start = interp_time[np.where(cur_flicker_raster == 1)[0][0]]
            interp_time = interp_time - flicker_start
                        
            #DEBUG
            #plt.figure()
            #plt.plot(interp_time, cur_flicker_raster, label='Flicker Raster')
            #plt.plot(flicker_start, 1, '|', label='Flicker Start')
            #plt.show()
            #if mat_i == 1:
            #    raise Exception("Stopping execution")

            mask = np.zeros(cur_flicker_raster.shape[0])

            #Construct mask for flicker
            for i in np.where(cur_flicker_raster == 1)[0]:
                mask[np.arange(i, i + flick_pts + 1, dtype=int)] = 1
            mask = mask.astype(bool)
            cur_flicker_raster[mask] = 1

            trace_dict = {'frame_time':cur_frame_time,
                          'interp_time':interp_time,
                          'raw_trace':raw_trace,
                          'sub_vm':sub_vm,
                          'detrend_trace':detrend_trace,
                          'interp_subvm':interp_subvm,
                          'trace_noise':cur_trace_noise,
                          'interp_raw_trace':interp_raw_trace,
                          'interp_raw_detrend_trace':interp_raw_detrend_trace,
                          'spike_raster':cur_spike_raster,
                          'interp_spike_raster':interp_spike_raster_adj,
                          'spike_amp_raster':spike_amp_raster,
                          'flicker_raster':cur_flicker_raster,
                          'stim_raster':cur_stim_raster,
                          'stim_freq':cur_stim_freq,
                          'flicker_freq':cur_flicker_freq,
                          'stim_param':stim_param_string,
                          'mouse_id':mouse_id,
                          'session_id':session_id,
                          'trial_id':trial_i,
                          'fov_id':int(cur_fov.replace('fov', '')),
                          'roi_id':roi_i,                          
                          }
            
            df = pd.concat([df, pd.DataFrame(trace_dict)], ignore_index=True, join='outer')

            df['stim_freq'] = df['stim_freq'].astype('category')
            df['mouse_id'] = df['mouse_id'].astype('category')
            df['session_id'] = df['session_id'].astype('category')

# Print the not aligned matfiles
if not not not_aligned:
    print('Not aligned Matfiles')
    for mf in not_aligned:
        print(mf)


0 /home/pierfier/handata_server/eng_research_handata3/Yangyang_Wang/PV_V1_LED_SomArchon/109558_Vb_male/20240308/RawTracesV2_109558_20240308_flicker8hz_DBS140hz_fov1_100ua_red__all.mat
Missing Corresponding Ephys Folder
Matfile: RawTracesV2_109558_20240308_flicker8hz_DBS140hz_fov1_100ua_red__all.mat
Using model ephys data
1 /home/pierfier/handata_server/eng_research_handata3/Yangyang_Wang/PV_V1_LED_SomArchon/109558_Vb_male/20240308/RawTracesV2_109558_20240308_flicker8hz_DBS140hz_fov2_100ua_red__all.mat
Missing Corresponding Ephys Folder
Matfile: RawTracesV2_109558_20240308_flicker8hz_DBS140hz_fov2_100ua_red__all.mat
Using model ephys data
2 /home/pierfier/handata_server/eng_research_handata3/Yangyang_Wang/PV_V1_LED_SomArchon/109558_Vb_male/20240308/RawTracesV2_109558_20240308_flicker8hz_DBS140hz_fov3_100ua_red__all.mat
Missing Corresponding Ephys Folder
Matfile: RawTracesV2_109558_20240308_flicker8hz_DBS140hz_fov3_100ua_red__all.mat
Using model ephys data
3 /home/pierfier/handata_server

In [13]:
df.columns
df['mouse_id'].unique()
df['fov_id'].unique()

array([1, 2, 3, 4, 5, 6])

In [21]:
print(np.min([np.where(interp_spike_raster_adj == 1)[0].shape[0], spike_amp.shape[0]]) + 1)
print(spike_amp)
print(np.where(interp_spike_raster_adj == 1)[0])
print(spike_idx)
plt.figure()
plt.plot(cur_frame_time, cur_spike_raster)
plt.plot(interp_time, interp_spike_raster_adj)
plt.show()

22
[0.01003602 0.0072997  0.00822962 0.00915412 0.0085015  0.0065599
 0.00791407 0.00712737 0.00542821 0.00637546 0.00936654 0.00662582
 0.00758872 0.00778126 0.00617549 0.0076527  0.00825255 0.00649919
 0.00789477 0.00569096 0.00659793]
[ 693  701  723 1085 1202 1244 1273 1277 1421 1434 1711 1797 1838 1860
 1884 1964 2069 2107 2155 2163 2222]
[ 704  712  734 1094 1211 1253 1281 1285 1429 1442 1717 1803 1844 1865
 1889 1969 2073 2111 2159 2167 2226]


In [4]:
# Figure out which trials to remove
ignore_trials = {
    '109558_Vb_male':{
        '20240311':{
            1:{40:[3, 4, 5, 8, 10],
               140:[1, 2, 8, 9]},
            2:{40:[4, 5, 7, 8, 9, 10],
               140:[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]},
            3:{40:[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
               140:[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]},
            4:{40:[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
               140:[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]}
         },
         '20240308':{
             1:{140:[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]},
             2:{140:[2, 5, 6, 8, 9, 10, 11, 14, 16, 17, 18, 19]},
             3:{140:[6, 7, 14, 15, 16]}
         }
    },
    '109567_Vb_male':{
        '20240311':{
            1:{40:[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
               140:[4, 8, 9]}, 
            2:{40:[3, 5, 6, 7, 8],
               140:[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]},
            3:{40:[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
               140:[]},
            4:{40:[2, 3, 6, 7, 8, 9, 10],
               140:[1, 2, 3, 4, 9, 10]},
            5:{40:[7, 8, 9],
               140:[1, 2, 3, 10]},
            6:{40:[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
               140:[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]}
        },
        '20240411':{
            1:{40:[6, 7, 8, 10],
               140:[2, 3, 5, 8, 9, 10]},
            2:{40:[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
               140:[2, 3, 4, 5, 8, 9, 10]},
            3:{40:[1, 6, 7, 8, 9],
               140:[1, 3, 5]},
            4:{40:[4, 5, 7, 9],
               140:[7, 9, 10]},
            5:{40:[1, 6, 7],
               140:[1, 3, 4, 5, 6, 7]}
        },
        '20240424':{
            2:{40:[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]}
        }
    }
}

# Loop through dictionary and remove from dataframe
for m in ignore_trials.keys():
    for s in ignore_trials[m].keys():
        for fov in ignore_trials[m][s].keys():
            for freq in ignore_trials[m][s][fov].keys():
                df = df.drop(df[ (df['mouse_id'] == m) & \
                            (df['session_id'] == s) & \
                            (df['fov_id'] == fov) & \
                            (df['stim_freq'] == freq) & \
                (df['trial_id'].isin(np.array(ignore_trials[m][s][fov][freq]) - 1)) ].index)

In [15]:
print(df[(df['mouse_id'] == '109567_Vb_male') & (df['session_id'] == '20240311') & \
         (df['fov_id'] == 6)]['trial_id'].unique())

[]


In [5]:
# Save data frame to pickle file
save_filename = f + 'home' +f+ 'pierfier' +f+ 'Projects' +f+ 'Pierre Fabris' +f+ 'PV DBS neocortex' +f+ 'Interm_Data' +f+ 'flicker.pkl'
df.to_pickle(save_filename)

In [6]:
plt.close('all')

In [26]:
print(ephys_path)
for path in ephys_path:
    print(path)

['/home/pierfier/handata_server/eng_research_handata3/Yangyang_Wang/PV_V1_LED_SomArchon/109567_Vb_male/20240311/ephys/109567_flicker8hz_DBS140hz_fov1_12024-03-11_18-19-37_2']
/home/pierfier/handata_server/eng_research_handata3/Yangyang_Wang/PV_V1_LED_SomArchon/109567_Vb_male/20240311/ephys/109567_flicker8hz_DBS140hz_fov1_12024-03-11_18-19-37_2


In [27]:
print(len(frame_timestamps))
print(len(flicker_timestamps))
print(len(sp_info))
print(trial_i)

11
11
10
9


In [1]:
print(df['mouse_id'].unique())
print(df['fov_id'].unique())
print(df[ (df['fov_id'] == 1) & (df['mouse_id'] == '109558_Vb_male')])

NameError: name 'df' is not defined