# Stim and LFP phase
## How well is the system detecting target LFP phases?

### Needs emk_neuro_analysis environment
### load LFP, DIO and time data

In [1]:
import pandas as pd
import numpy as np
from pprint import pprint
from matplotlib import pyplot as plt
import matplotlib.dates as mdates
import itertools
from scipy.signal import sosfiltfilt, butter, hilbert
from scipy.stats import circstd
from scipy import stats

from emk_analysis import builder_experiment as bld_exp
# from emk_neuro_analysis.align import align_to_events_discrete, align_to_events_continuous
from emk_neuro_analysis.lfp import iterator as lfp_iter
from emk_neuro_analysis.position import iterator as pos_iter
from mountainlab_pytools import mdaio
from emk_analysis import iterator as emk_iter

from scipy import signal
from pathlib import Path
from rec_to_binaries.read_binaries import readTrodesExtractedDataFile

### Experiment parameters

In [2]:
# name of experiment
experiment_name = '6082737'

experiment_phase = 'stim'

# data drive
data_disk = 'nvme0'

# directory with the preprocessed/extracted data files
dir_preprocess = f'/media/{data_disk}/Data/{experiment_name}/preprocessing/'

# Figure folder, where you want to save the output figures. Usually in your experiment folder for analysis.
dir_fig = f'/media/{data_disk}/Analysis/{experiment_name}/Results/'

# Location of track config file. 
# This is an excel spreadsheet that specifies the identities of the DIO for your experiment.
fname_config_track = (f'/media/{data_disk}/Data/{experiment_name}/config/CLC_linear_Config_laser.xlsx')

# Location of day records. 
# This is an excel spreadsheet that lists details for each session on your experiment day.
dir_records = (f'/media/{data_disk}/Data/{experiment_name}/dayrecords/')

# chose the date - as a list
choose_dates = [ '20220608',]

# choose the epoch - as a list
epoch_list = [1, 2, 3]

# choose the tetrodes - as a list
# tet_list = [28, 27, 20, 19, ]
tet_list = [28, 27, 20, 19, 18]

### Build day records from track config file and experiment file

In [3]:
data_days = []
for curr_date in choose_dates:

    fname_day_record = f'{dir_records}{curr_date}_{experiment_phase}_training_record.xlsx'
    
    dict_sessions_day = bld_exp.build_day_from_file(experiment_name, 
                                track_config_file=fname_config_track,
                                day_record_file=fname_day_record)
    data_days.append(dict_sessions_day)

dict_sessions_all = bld_exp.build_all_sessions(data_days)
pprint(dict_sessions_all)

{'20220608_01': {'date': '20220608',
                 'description': 'maze stim',
                 'end': Timestamp('2022-06-08 23:55:00'),
                 'experiment': '6082737',
                 'id': 1,
                 'name': 'stim',
                 'start': Timestamp('2022-06-08 01:00:00'),
                 'tasks': {'Track 1': {'animal_id': 'clc',
                                       'description': 'S',
                                       'dio': {'10': {'bit': '_',
                                                      'notes': None,
                                                      'type': 'Unnamed:'},
                                               'exit_sensor': {'bit': 3,
                                                               'notes': None,
                                                               'type': 'in'},
                                               'laser_pump': {'bit': 2,
                                                              'notes':

### Import LFP data

In [4]:
lfp_data, lfp_timestamp, _ = lfp_iter.iterate_lfp_load(dir_preprocess,
                                                       tet_list,
                                                       choose_dates,
                                                       epoch_list=epoch_list, 
                                                       remove_movement_artifact=False,
                                                       filter_linenoise=True,
                                                       print_debug=False)

Loaded /media/nvme0/Data/6082737/preprocessing/20220608/20220608_6082737_02_stim.LFP tet 28
filter 60Hz noise
Loaded /media/nvme0/Data/6082737/preprocessing/20220608/20220608_6082737_02_stim.LFP tet 27


  return np.dtype(typearr)


filter 60Hz noise
Loaded /media/nvme0/Data/6082737/preprocessing/20220608/20220608_6082737_02_stim.LFP tet 20
filter 60Hz noise
Loaded /media/nvme0/Data/6082737/preprocessing/20220608/20220608_6082737_02_stim.LFP tet 19
filter 60Hz noise
Loaded /media/nvme0/Data/6082737/preprocessing/20220608/20220608_6082737_02_stim.LFP tet 18
filter 60Hz noise
Loaded /media/nvme0/Data/6082737/preprocessing/20220608/20220608_6082737_03_stim.LFP tet 28
filter 60Hz noise
Loaded /media/nvme0/Data/6082737/preprocessing/20220608/20220608_6082737_03_stim.LFP tet 27
filter 60Hz noise
Loaded /media/nvme0/Data/6082737/preprocessing/20220608/20220608_6082737_03_stim.LFP tet 20
filter 60Hz noise
Loaded /media/nvme0/Data/6082737/preprocessing/20220608/20220608_6082737_03_stim.LFP tet 19
filter 60Hz noise
Loaded /media/nvme0/Data/6082737/preprocessing/20220608/20220608_6082737_03_stim.LFP tet 18
filter 60Hz noise
Loaded /media/nvme0/Data/6082737/preprocessing/20220608/20220608_6082737_01_stim.LFP tet 28
filter 60H

### Transform time

In [5]:
time_dict = {}

fs = 30000

for i in lfp_timestamp.items():
    time_curr = i[1]
    time_dict.update({i[0]:{'timestamp':np.array(time_curr),
                            'time':(np.array(time_curr) - time_curr[0])/fs}})

In [6]:
%matplotlib notebook

for i, (k, v) in enumerate(lfp_data.items()):
    for e in epoch_list:
        plt.plot((np.array(lfp_timestamp.get(e))-np.array(lfp_timestamp.get(e))[0])/fs,
                 np.array(v.get(e))/1000+5*(i-10), lw=.15)

<IPython.core.display.Javascript object>

### Import DIO data

In [7]:
# Load DIO
%matplotlib notebook

plot_DIO = False

filter_retrigger = 0

# time plotting settings
tick_minutes = mdates.MinuteLocator(interval=5)
tick_minutes_fmt = mdates.DateFormatter('%H:%M')
tick_minor = mdates.SecondLocator(interval=10)

# Specify parameters
dict_sensor_pump_map = {2: {'pump': 'laser_pump'},
                        6: {'sensor': 'reward_1_sensor'},}

# list dio to extract
list_dio = [2, 6]
y_label = ['laser',
           'reward 1',]

# plot each session
# get data for each animal
# initiate output
dict_dio_out = {}

dict_dio_in = {}

for animal_id in ['clc', ]:
    
    print(animal_id)
    cls_behavior = emk_iter.ProcessBehavior(dict_sessions_all,
                                        experiment_name, trodes_version=2)
    cls_behavior.filter_animals(animal_id)
    dict_rewards = cls_behavior.count_reward_delivered()
    
    if not dict_rewards:
        continue
        
    df_pump = cls_behavior.report_reward_delivered(remove_zeroth=False,
                                                   output_raw=False,
                                                   filter_retrigger=None)
    
    df_sensor = cls_behavior.report_triggers(remove_zeroth=False,
                                             output_raw=False,
                                             filter_retrigger=None)
    
    # get unique sessions
    sessions_unique = np.sort(df_sensor['session'].unique())
    print(sessions_unique)
    n_subplots = len(sessions_unique)
    
    if plot_DIO:
        fig = plt.figure(figsize=(10, n_subplots*3+2))
        axs = fig.subplots(n_subplots, 1)
        if n_subplots == 1:
            axs = [axs, ]
            sessions_unique = [sessions_unique[0], ]
    
    else:
        axs = [0]*len(sessions_unique)
    
    # divide into sessions
    for sn, (ax, session) in enumerate(zip(axs, sessions_unique)):
        # get session times
        curr_start = dict_sessions_all.get(session).get('start')
        curr_end = dict_sessions_all.get(session).get('end')
        # get sensor and pump times
        df_sensor_curr = df_sensor[df_sensor['session']==session]
        df_sensor_curr = df_sensor_curr[(df_sensor_curr['on_time_sys']>=curr_start)
                                       & (df_sensor_curr['on_time_sys']<curr_end)]
        df_pump_curr = df_pump[df_pump['session']==session]
        df_pump_curr = df_pump_curr[(df_pump_curr['on_time_sys']>=curr_start)
                                       & (df_pump_curr['on_time_sys']<curr_end)]
        dict_dio_out.update({int(session.split('_')[1]): df_pump_curr})
        dict_dio_in.update({int(session.split('_')[1]): df_sensor_curr})
        
        if not plot_DIO:
            continue
        
        # plot DIO data for all sessions
        for i, d in enumerate(list_dio):
            #print(d)
            yval = i+1
            curr_pump_name = dict_sensor_pump_map.get(d).get('pump')
            df_plot_pump = df_pump_curr[df_pump_curr['dio']==curr_pump_name]
            curr_sensor_name = dict_sensor_pump_map.get(d).get('sensor')
            df_plot_sensor = df_sensor_curr[df_sensor_curr['dio']==curr_sensor_name]
            # plot well triggers
            
            for ind, row in df_plot_sensor.iterrows():
                ax.scatter(row['on_time_sys'], yval+.3, s=25, c='k')
                
            for ind, row in df_plot_pump.iterrows():
                
                try:
                    ax.plot([row['on_time_sys'],
                              row['off_time_sys']], [yval+.15, yval+.15], c='r')
                
                except:
                    pass

clc
no dataframe created - data not suitable


  return np.dtype(typearr)
  return np.dtype(typearr)
  return np.dtype(typearr)
  return np.dtype(typearr)
  return np.dtype(typearr)
  return np.dtype(typearr)
  return np.dtype(typearr)
  return np.dtype(typearr)
  return np.dtype(typearr)
  return np.dtype(typearr)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_off['time_sys_diff'] = df_off['time_sys'] - curr_time
  return np.dtype(typearr)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_off['time_sys_diff'] = df_off['time_sys'] - curr_time
  return np.dtype(typearr)
A value is trying to be set on a

no dataframe created - data not suitable
no dataframe created - data not suitable
no dataframe created - data not suitable


  return np.dtype(typearr)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_off['time_sys_diff'] = df_off['time_sys'] - curr_time
  return np.dtype(typearr)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_off['time_sys_diff'] = df_off['time_sys'] - curr_time
  return np.dtype(typearr)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_off['ti

no dataframe created - data not suitable
no dataframe created - data not suitable
no dataframe created - data not suitable


  return np.dtype(typearr)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_off['time_sys_diff'] = df_off['time_sys'] - curr_time
  return np.dtype(typearr)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_off['time_sys_diff'] = df_off['time_sys'] - curr_time
  return np.dtype(typearr)
  return np.dtype(typearr)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view

no dataframe created - data not suitable
no dataframe created - data not suitable
no dataframe created - data not suitable
no dataframe created - data not suitable


  return np.dtype(typearr)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_off['time_sys_diff'] = df_off['time_sys'] - curr_time


no dataframe created - data not suitable


  return np.dtype(typearr)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_off['time_sys_diff'] = df_off['time_sys'] - curr_time
  return np.dtype(typearr)
  return np.dtype(typearr)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_off['time_sys_diff'] = df_off['time_sys'] - curr_time
  return np.dtype(typearr)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view

no dataframe created - data not suitable
no dataframe created - data not suitable
no dataframe created - data not suitable
no dataframe created - data not suitable
no dataframe created - data not suitable


  return np.dtype(typearr)
  return np.dtype(typearr)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_off['time_sys_diff'] = df_off['time_sys'] - curr_time
  return np.dtype(typearr)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_off['time_sys_diff'] = df_off['time_sys'] - curr_time


no dataframe created - data not suitable
no dataframe created - data not suitable
no dataframe created - data not suitable
no dataframe created - data not suitable
['20220608_01' '20220608_02' '20220608_03']


  return np.dtype(typearr)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_off['time_sys_diff'] = df_off['time_sys'] - curr_time


In [8]:
stim_time_dict = {}
for item in dict_dio_out.items():
    df_pump_curr = item[1]
    key_curr = item[0]
    for i, d in enumerate(list_dio):
        # extract the wanted DIO data
        try:
            df_pump_timestamp = np.array(df_pump_curr[df_pump_curr['dio_bit']=='ECU_Dout2'].get('on_time'))
            df_pump_time = (df_pump_timestamp - time_dict.get(key_curr).get('timestamp')[0])/fs
            stim_time_dict.update({item[0]:{'timestamp':df_pump_timestamp,'time':df_pump_time}})
        except AttributeError:
            pass

### Identify, select phase targets

In [9]:
def index_in(butter_filter, raw_data, num_std=0):
    filtered_data = sosfiltfilt(butter_filter,raw_data)
    filtered_envelope = np.abs(hilbert(filtered_data))
    bar = np.mean(filtered_envelope) + num_std*np.std(filtered_envelope)
    index_in = np.array(filtered_envelope) > bar
    return index_in


def true_phase(butter_filter, raw_data):
    filtered_data = sosfiltfilt(butter_filter,raw_data)
    filtered_phase = np.angle(hilbert(filtered_data)) + np.pi
    return filtered_phase


def generate_matrix(regr_buffer_size):
    sampling_axis = np.arange(regr_buffer_size)
    A = np.vstack([sampling_axis, np.ones(len(sampling_axis))]).T
    return A


def calculate_derv(A, buffer):
    curr_regr = buffer[:, np.newaxis]
    pinv = np.linalg.pinv(A)
    alpha = pinv.dot(curr_regr)
    return alpha[0][0]


def critical_point(derivative_history,time,regr_buffer_size,index_in,which,num_to_wait=5):
# def critical_point(derivative_history,time,regr_buffer_size,num_to_wait=5):
    critical_index = []
    critical_time = []
    current_sign = True
    sign_buffer = []

    for i in range(len(derivative_history)):
        if i == 0:
            current_sign = (derivative_history[0]>0)
            sign_buffer = [derivative_history[0]>0]*num_to_wait
            continue

        sign_buffer.append(derivative_history[i]>0)
        sign_buffer.pop(0)

        flip = True
        for sign in sign_buffer:
            if current_sign == sign:
                flip = False
                break
    
        if flip:
            current_sign = sign_buffer[-1]
            if which: 
            # which == True means you want the peaks
                if (not current_sign) and index_in[i+int(regr_buffer_size/2)-num_to_wait]:
                    # at peaks, current_sign already flipped to be false
                    critical_index.append(i+int(regr_buffer_size/2)-num_to_wait)
                    critical_time.append(time[i+int(regr_buffer_size/2)-num_to_wait])
            else:  
            # otherwise you want the troughs
                if current_sign and index_in[i+int(regr_buffer_size/2)-num_to_wait]: 
                    # at troughs, current_sign flipped to be true
                    critical_index.append(i+int(regr_buffer_size/2)-num_to_wait)
                    critical_time.append(time[i+int(regr_buffer_size/2)-num_to_wait])
    
    return np.array(critical_index), np.array(critical_time)


def select_target(butter_filter, raw_data, time, target, index_in, error_bound=0.0015, buffer_size=200):
    
    if target == 0 or target == np.pi:
        
        filtered_data = sosfiltfilt(butter_filter,raw_data)
        derivative_history = []
        A = generate_matrix(buffer_size)
        
        for i in range(buffer_size,len(filtered_data)):
            curr_buffer = filtered_data[i-buffer_size:i]
            curr_derv = calculate_derv(A,curr_buffer)
            derivative_history.append(curr_derv)
        
        target_index, target_time = critical_point(derivative_history,time,buffer_size,index_in,
                                                    target/np.pi,num_to_wait=3)
        
    else:
        
        df_phase = pd.DataFrame({"phase":true_phase(butter_filter, raw_data), 
                                 "time":time, "index_in":index_in})
        selected_targets = df_phase[(np.abs(df_phase["phase"]-target) <= error_bound) & 
                                     df_phase['index_in']]
        target_index = np.array(selected_targets.index)
        target_time = np.array(selected_targets["time"])
        
    return target_index, target_time


def calculate_accuracy(target_time, hardware_on_time, window=0.060):
    # find the timepoint when detection started
    detection_start = np.searchsorted(np.array(target_time) - hardware_on_time[0], [0])[0]
    target_time = target_time[detection_start:]
    
    # determine if there's a stim event within window time around phase target
    hit = []
    for target in target_time[detection_start:]:
        for i in range(len(hardware_on_time)):
            if 0 <= (hardware_on_time[i]-target) <= window:
                hit.append(hardware_on_time[i])
                break

    return np.array(hit), len(hit)/len(target_time[detection_start:])


def calculate_precision(target_time, hardware_on_time, window=0.060):
    # find the timepoint when detection started
    detection_start = np.searchsorted(np.array(target_time) - hardware_on_time[0], [0])[0]
    target_time = target_time[detection_start:]
    
    # determine if there's a phase target within window time around stim
    precise = []
    for hardware in hardware_on_time:
        for target in target_time[detection_start:]:
            if 0 <= (hardware-target) <= 0.100:
                precise.append(hardware_on_time[i])
                break
 
    return precise, len(precise)/len(hardware_on_time)


def replace_with_nearest(lst1, lst2):
    '''
    replace lst1 elements with their nearest elements
    in lst2
    '''
    return lst2[[np.abs(lst2 - e).argmin() for e in lst1]]

### Identify phase targets

In [10]:
target_lowcut = 4
target_highcut = 10
buffer_size = 200

# session_target_map = {1:0,
#                       2:np.pi,
#                       3:np.pi/2,
#                       4:3*np.pi/2}

# session_target_map = {1:np.pi/2,
#                       2:3*np.pi/2,}

session_target_map = {1:np.pi,
                      2:np.pi,
                      3:np.pi}

tetrode = 27
session = 1

butter_filter = butter(1, [target_lowcut, target_highcut], 'bp', fs=fs/20, output='sos')
raw = lfp_data[tetrode][session]
filtered = sosfiltfilt(butter_filter, raw)

# LFP time
time = time_dict.get(session).get('time')
timestamp = time_dict.get(session).get('timestamp')

# stim time
stim_time = stim_time_dict.get(session).get('time')
stim_timestamp = replace_with_nearest(stim_time_dict.get(session).get('timestamp'), 
                                      time_dict.get(session).get('timestamp'))

In [11]:
# identify stim periods
stim_pivots = np.where(np.diff(np.append(-19, stim_time)) > 1)[0][1:]
stim_timestamp_cut = np.split(stim_timestamp, stim_pivots)
stim_start = np.array([s[0] for s in stim_timestamp_cut])
stim_end = np.array([s[-1] for s in stim_timestamp_cut])

In [12]:
# power_bar - only consider periods of considerable theta power
power_bar = False

if power_bar:
    index_in = index_in(butter_filter, raw)
else: 
    index_in = [True]*len(raw)

target_index, target_time = select_target(butter_filter, raw, time, 
                                          session_target_map.get(session), 
                                          index_in, error_bound=0.015)

# select targets in the stim periods
real = np.isin(time_dict.get(session).get('timestamp')[target_index],
               np.hstack([np.linspace(start, end, (end-start+1))
                          for start, end in zip(stim_start, stim_end)]))

### Visualize LFP and stim

In [13]:
# visualize LFP, targets, and stims
%matplotlib notebook

filtered = sosfiltfilt(butter_filter, raw)

# raw signal
plt.plot(time_dict.get(session).get('time'), raw/1000, c='lightgray', label='Raw')

# filtered signal
plt.plot(time_dict.get(session).get('time'), filtered/1000, lw=3, color='#00AEEF', label='6-10Hz')

# stimulation events
for on_time in stim_time:
    plt.plot([on_time, on_time], [-2, 2], c='r', lw=1.5)

# targets
plt.scatter(target_time[real], (filtered/1000)[target_index[real]],
            c='k', s=25, zorder=3, label='target')

# plt.scatter(target_time, (filtered/1000)[target_index],
#             c='k', s=25, zorder=3, label='target')

# legend placeholder
# plt.plot([0],[0], c='r', label='Stim')
# plt.legend()

plt.xlabel('Time (s)')
plt.ylabel('Signal (uV)')

<IPython.core.display.Javascript object>

Text(0, 0.5, 'Signal (uV)')

In [14]:
### inspect good detection cycles and bad detection cycles respectively ###

win_size_time = 140 # in ms
win_size_timestamp = win_size_time*(10**(-3))*fs

df_cycles = pd.DataFrame({'start_idx':[],
                          'start_timestamp':[],
                          'end_idx':[],
                          'end_timestamp':[],
                          'stim_time_aligned':[],
                          'stim_time_real':[],
                          'stim_count':[],
                          'raw':[],
                          'filtered':[]})

# use theta peaks as window end
for t in target_index[real]:
    win_end = timestamp[t]
    # index of theta cycle window start
    win_start_idx = np.abs(timestamp-(win_end-win_size_timestamp)).argmin()
    # real stim time
    stim_time_real = stim_time[((stim_timestamp > timestamp[win_start_idx])
                                  &(stim_timestamp < win_end))]
    # relative stim time in milliseconds aligned to theta peak
    stim_time_aligned = (stim_time_real - time[t])*1000 + win_size_time
    df_cycles = df_cycles.append({'start_idx':win_start_idx,
                                  'start_timestamp':timestamp[win_start_idx],
                                  'end_idx':t,
                                  'end_timestamp':win_end,
                                  # if the cycle only has one stim, it's good
                                  'stim_time_aligned':stim_time_aligned,
                                  'stim_time_real':stim_time_real,
                                  'stim_count':len(stim_time_aligned),
                                  'raw':raw[win_start_idx:t],
                                  'filtered':filtered[win_start_idx:t]},
                                  ignore_index=True)

### Compare good and bad cycles

In [15]:
signal_type='raw'

group_1 = df_cycles[df_cycles['stim_count']==1].get(signal_type).tolist()
group_1_avg = np.average(group_1, axis=0)
group_1_sem = stats.sem(group_1, axis=0)
group_1_std = stats.tstd(group_1, axis=0)

group_2 = df_cycles[df_cycles['stim_count']!=1].get(signal_type).tolist()
group_2_avg = np.average(group_2, axis=0)
group_2_sem = stats.sem(group_2, axis=0)
group_2_std = stats.tstd(group_2, axis=0)

In [16]:
%matplotlib notebook

# stim distribution
stim_good_hist, stim_good_edges = np.histogram(np.hstack(df_cycles[df_cycles['stim_count']==1].get('stim_time_aligned').tolist()), 
                                               bins=50, range=(0,140))

stim_bad_hist, stim_bad_edges = np.histogram(np.hstack(df_cycles[df_cycles['stim_count']!=1].get('stim_time_aligned').tolist()), 
                                               bins=50, range=(0,140))

plt.bar(stim_bad_edges[:-1], stim_bad_hist, color='r', alpha=0.4)
plt.bar(stim_good_edges[:-1], stim_good_hist, color='b', alpha=0.8)

# plt.plot(stim_good_edges[:-1], stim_good_hist/sum(stim_good_hist), color='b', alpha=0.4)
# plt.plot(stim_bad_edges[:-1], stim_bad_hist/sum(stim_bad_hist), color='r', alpha=0.4)
# plt.plot(stim_good_edges[:-1], stim_good_hist, color='b', alpha=0.4)
# plt.plot(stim_bad_edges[:-1], stim_bad_hist, color='r', alpha=0.4)

shrink_factor = 10

plt.plot(np.linspace(0, win_size_time, len(group_1_avg)), group_1_avg/shrink_factor, c='b')
plt.fill_between(np.linspace(0, win_size_time, len(group_1_avg)),
                 (group_1_avg + group_1_sem)/shrink_factor,
                 (group_1_avg - group_1_sem)/shrink_factor,
                 alpha=0.4, label='Good', color='b')

plt.plot(np.linspace(0, win_size_time, len(group_2_avg)), group_2_avg/shrink_factor, c='r')
plt.fill_between(np.linspace(0, win_size_time, len(group_2_avg)),
                 (group_2_avg + group_2_sem)/shrink_factor,
                 (group_2_avg - group_2_sem)/shrink_factor,
                 alpha=0.4, label='Bad', color='r')

plt.legend()
plt.xlabel('Time (ms)')
plt.ylabel('Signal (uV)')
plt.title('Theta Cycles Averaged')

<IPython.core.display.Javascript object>

Text(0.5, 1.0, 'Theta Cycles Averaged')

In [17]:
%matplotlib notebook

# signal_type = 'raw'
signal_type = 'filtered'
# shrink_factor = 50
offset = 1000

for ind, row in df_cycles.iterrows():
    if row['stim_count'] == 1:
        plt.plot(np.linspace(0, win_size_time, len(row[signal_type])), 
                 row[signal_type]/shrink_factor+offset, c='b', alpha=0.1)
        # pass
    else:
        plt.plot(np.linspace(0, win_size_time, len(row[signal_type])), 
                 row[signal_type]/shrink_factor+2*offset, c='r', alpha=0.1)
        # pass

# stim distribution
plt.bar(stim_bad_edges[:-1], stim_bad_hist*shrink_factor, 
        color='r', alpha=0.4, label='Bad Stim')
plt.bar(stim_good_edges[:-1], stim_good_hist*shrink_factor, 
        color='b', alpha=0.8, label='Good Stim')

# plt.plot(stim_good_edges[:-1], stim_good_hist/sum(stim_good_hist), color='b', alpha=0.4)
# plt.plot(stim_bad_edges[:-1], stim_bad_hist/sum(stim_bad_hist), color='r', alpha=0.4)
# plt.plot(stim_good_edges[:-1], stim_good_hist, color='b', alpha=0.4)
# plt.plot(stim_bad_edges[:-1], stim_bad_hist, color='r', alpha=0.4)
        
plt.plot([0],[0],c='b',label='Good Cycles')
plt.plot([0],[0],c='r',label='Bad Cycles')
plt.legend()
plt.xlabel('Time (ms)')
plt.ylabel('Signal (uV)')

<IPython.core.display.Javascript object>

Text(0, 0.5, 'Signal (uV)')

### Signal aligned to stim event groups

In [19]:
def align_to_events_continuous(data_in, times_in, ref_events, win):
    """
    Aligns a continuous data to reference events.
    Example usage: align LFP relative to known events.
    All event times are in seconds

    :param data_in: array of values (LFP signal)
    :param times_in: array of timestamps for values
    :param ref_events: array of event times to align to
    :param win: window length around alignment time in seconds
    :return: array of event values and times aligned to ref_events
    """

    arr_out_data = []
    arr_out_time = []

    for e in ref_events:

        # get time indices waround window
        try:
            ind_start = np.argmax(times_in >= e - win)
            ind_end = np.argmax(times_in >= e + win)
            data_incl = data_in[ind_start:ind_end]
            time_incl = times_in[ind_start:ind_end] - e
            arr_out_data.append(data_incl)
            arr_out_time.append(time_incl)
        except:
            print(f'{e} failed')

    return arr_out_data, arr_out_time

In [20]:
### good stim - CONTROL group - only one stim in cycle ###
good_stim = np.hstack(df_cycles[df_cycles['stim_count']==1].get('stim_time_real'))

In [21]:
### off stim - stim events delayed to opposite phase ###
off_stim = np.hstack(df_cycles.get('stim_time_real'))[np.hstack(df_cycles.get('stim_time_aligned')) > 60]

In [22]:
### double stim - first ones and second ones ###
double_stim = np.hstack(df_cycles[df_cycles['stim_count']==2].get('stim_time_real'))
first_stim = double_stim[::2]
second_stim = double_stim[1::2]

In [53]:
stim_group_dict = {}

win_size = 0.1
tvals = np.linspace(-win_size, win_size, 300)

# for s_t, n in zip([good_stim, off_stim, first_stim, second_stim],
#                   ['good', 'off', 'first', 'second']):

for s_t, n in zip([good_stim, first_stim, second_stim],
                  ['good', 'first', 'second']):

# for s_t, n in zip([off_stim], ['off']):
    
    raw_stim, time_stim = align_to_events_continuous(raw, time, s_t, win_size)
    filtered_stim, _ = align_to_events_continuous(filtered, time, s_t, win_size)
    
    raw_interp = []
    filtered_interp = []
    for r, f, t in zip(raw_stim, filtered_stim, time_stim):

        try:
            raw_interp.append(np.interp(tvals, t, r))
            filtered_interp.append(np.interp(tvals, t, f))

        except ValueError:
            pass
    
    stim_group_dict.update({n:{
        'raw_interp':raw_interp,
        'filtered_interp':filtered_interp,}})

In [55]:
%matplotlib notebook

group_color_map = {'good':'green',
                   'off':'gray',
                   'first':'red',
                   'second':'blue'}

signal_type = 'raw'
# signal_type = 'filtered'

for n, s_g in stim_group_dict.items():
    
    curr_color = group_color_map.get(n)
    
    for s in s_g.get(signal_type+'_interp'):
        
        plt.plot(tvals, s, c=curr_color, alpha=0.2)

for n, color in group_color_map.items():
    plt.plot([0], [0], c=color, label=n)
        
plt.legend()
plt.xlabel('Time (s)')
plt.ylabel('Signal (uV)')

<IPython.core.display.Javascript object>

Text(0, 0.5, 'Signal (uV)')

In [56]:
%matplotlib notebook

for n, s_g in stim_group_dict.items():
    
    curr_color = group_color_map.get(n)
    avg = np.average(s_g.get(signal_type+'_interp'), axis=0)
    sem = stats.sem(s_g.get(signal_type+'_interp'), axis=0)
    plt.plot(tvals, avg, c=curr_color)
    plt.fill_between(tvals, avg-sem, avg+sem, color=curr_color, alpha=0.2)

for n, color in group_color_map.items():
    plt.plot([0], [0], c=color, label=n)

plt.legend()
plt.title(f'{signal_type} signal')
plt.xlabel('Time (s)')
plt.ylabel('Signal (uV)')

<IPython.core.display.Javascript object>

Text(0, 0.5, 'Signal (uV)')

### Accuracy and Precision

In [28]:
hit, accuracy = calculate_accuracy(target_time[real], stim_time)
precise, precision = calculate_precision(target_time[real], stim_time)
# precise, precision = calculate_precision(target_time, stim_time)
print(accuracy)
print(precision)

0.8642335766423358
0.7205673758865249


### Lag

In [29]:
# analyze general lag
%matplotlib notebook

on_lag_hist, on_lag_edges = np.histogram(np.hstack([stim - target_time[real] for stim in stim_time])*1000, 
                                         bins=20, range=(-100,100))

plt.plot(on_lag_edges[:-1],
         on_lag_hist/sum(on_lag_hist),)

plt.ylabel('Probability')
plt.xlabel('Lag (ms)')

<IPython.core.display.Javascript object>

Text(0.5, 0, 'Lag (ms)')

# LFP movement artifact
### determine frequency composition of movement artifact

In [32]:
%matplotlib notebook

# for e in epoch_list:
#     for t in tet_list:
for t in tet_list:
    # for e in epoch_list:
    plt.figure()
    for e in [1, 3]: 
        # plt.figure()
        
        # f, Pxx_spec = signal.welch(stim_period_only, fs, nperseg=fs, noverlap=fs/2)
        # f, Pxx_spec = signal.welch(raw, fs, nperseg=4*fs, noverlap=fs/2)
        f, Pxx_spec = signal.welch(lfp_data.get(t).get(e), fs, nperseg=4*fs, noverlap=fs/2)
        
        plt.scatter(f, Pxx_spec)
        plt.plot(f, Pxx_spec, label=('ses_%d'%e))
        # plt.scatter(f, np.sqrt(Pxx_spec))
        # plt.plot(f, np.sqrt(Pxx_spec))
        
        plt.legend()
        plt.xlabel('frequency [Hz]')
        plt.ylabel('Linear spectrum [V RMS]')
        plt.xlim([0, 20])
        
        # title = 'PSD-session%d-tet%d' %(e,t)
        title = 'PSD-session-comparison-tet%d' %t
        plt.title(title)
        plt.savefig(f"{dir_fig}{choose_dates[0]}_{experiment_name}_{title}.png")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>