In [1]:
__author__           = "Anzal KS"
__copyright__        = "Copyright 2022-, Anzal KS"
__maintainer__       = "Anzal KS"
__email__            = "anzalks@ncbs.res.in"
import argparse
import multiprocessing
import os
import time
from pathlib import Path
from pprint import pprint

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import neo.io as nio
import numpy as np
import pandas as pd
from matplotlib.font_manager import FontProperties
from matplotlib.offsetbox import TextArea, DrawingArea, OffsetImage, AnnotationBbox
from scipy import signal as spy

import trace_pattern_plot_x_distributed as tpp

In [2]:
"""
Font and color defining functions
"""
# Shared text styles for figure labels and titles (all sans-serif).
y_labels = FontProperties(family='sans-serif', size='large')
sub_titles = FontProperties(family='sans-serif', size='x-large')
main_title = FontProperties(family='sans-serif', weight='bold', size='xx-large')

# Condition colours: blue marks pre-training data, orange post-training data.
pre_color = "#377eb8"
post_color = "#ff7f00"

In [3]:
# Root folder holding one sub-folder per recorded cell. Set the
# X_SPREAD_CA1_DIR environment variable to analyse another location
# without editing this cell.
selected_cells = os.environ.get(
    'X_SPREAD_CA1_DIR',
    '/Users/anzalks/Documents/Expt_data/Recordings/x_spread_CA1_recordings',
)

In [4]:
# Expand a leading "~" so the data folder may also be given relative to $HOME.
p = Path(selected_cells).expanduser()

In [5]:
# Destination for generated figures; exist_ok makes re-running this cell safe.
outdir = p.joinpath('result_plots_multi')
outdir.mkdir(parents=True, exist_ok=True)

In [6]:
"""
File listing functions
"""


'\nFile listing functions\n'

In [7]:
def list_folder(p):
    """Return the sorted cell folders found directly inside ``p``.

    A cell folder is a sub-directory whose name contains ``_cell_``
    (e.g. ``2022_10_21_cell_1``). Regular files that happen to match the
    pattern are skipped, because the results are searched as directories.

    Parameters
    ----------
    p : pathlib.Path
        Folder that contains the cell folders.

    Returns
    -------
    list of pathlib.Path
        Matching directories, sorted.
    """
    return sorted(entry for entry in p.glob('*_cell_*') if entry.is_dir())

def list_files(p):
    """Return every ``.abf`` recording below ``p`` (recursively), sorted.

    Only regular files with an ``.abf`` suffix are returned. The dot in the
    pattern and the ``is_file`` filter prevent a directory or a name that
    merely ends in "abf" from being returned as a recording.

    Parameters
    ----------
    p : pathlib.Path
        Folder to search (typically one cell folder).

    Returns
    -------
    list of pathlib.Path
    """
    return sorted(f for f in p.glob('**/*.abf') if f.is_file())

def image_files(i):
    """Return every ``.bmp`` image below folder ``i`` (recursively), sorted.

    Only regular files with a ``.bmp`` suffix are returned. This keeps
    directories and names that merely end in "bmp" out of the result.

    Parameters
    ----------
    i : pathlib.Path
        Folder to search.

    Returns
    -------
    list of pathlib.Path
    """
    return sorted(f for f in i.glob('**/*.bmp') if f.is_file())


In [8]:
"""
single cell analysis list:
1. Cell selection data - RMP and input resistance measured over the course of the experiment.
2. Plasticity proportion - % amplitude difference for points and patterns.
3. Cell heterogeneity factors - spike frequency before and after.

"""

'\nsingle cell analysis list:\n1.cell selection data - RMP measure, Input resistance measur -over the course of experiment.\n2.Plasticity proportion - % amplitude difference for points and patterns\n3.Cell heterogenity factors - spike frequency before and after. \n\n'

In [9]:
"""
Open the folder containing the cell folders, iterate through each cell,
analyse each cell folder and save the results for each cell.
Use multiprocessing to process the cells in parallel.
"""

'\nopen the folder with multiple cell folders, iterate through each cell, \nperform analysis on each cell folder, save data from each cell.\nuse multiprocessing for each cell.\n'

In [10]:
"""
General use functions - TTL finding, filters, etc...
"""

'\nGeneral use functions - TTL finding, filters, etc...\n'

In [11]:
"""
1D array and get locations with a rapid rise, N defines the rolling window
"""
def find_ttl_start(trace, N, low=0.2, high=0.75):
    """Return sample indices where a TTL trace shows a sharp rise.

    The trace is min-max normalised to [0, 1] and compared every ``N``
    samples. An index ``k`` (a multiple of ``N``) is reported when the
    normalised value at ``k`` is below ``low`` and the value at ``k + N``
    is above ``high``.

    Parameters
    ----------
    trace : array-like
        TTL samples. Any shape is accepted and flattened, e.g. a
        single-channel column array.
    N : int
        Positive step, in samples, between compared points.
    low, high : float, optional
        Normalised thresholds before/after the rising edge (defaults 0.2 and 0.75).

    Returns
    -------
    list of int
        Start indices of detected rising edges, ascending. Empty when the
        trace is empty, flat, or not longer than ``N`` samples.
    """
    # Float copy: integer input can be normalised in place, and the
    # caller's array is never modified.
    data = np.array(trace, dtype=float).ravel()
    if data.size == 0:
        return []
    data -= data.min()
    span = data.max()
    if span == 0:
        # A flat trace has no edges (and would otherwise divide by zero).
        return []
    data /= span
    # Every k = i*N for which k + N is still inside the trace.
    starts = np.arange(0, data.size - N, N)
    rising = (data[starts] < low) & (data[starts + N] > high)
    return starts[rising].tolist()


"""
data filter function
"""
def filter_data(data, cutoff, filt_type, fs, order=3):
    """Zero-phase Butterworth filtering of ``data``.

    A digital Butterworth filter of the given ``order`` is designed with
    ``scipy.signal.butter``. ``filt_type`` is passed as ``btype`` (e.g.
    'low', 'high') and ``cutoff`` is in the same units as the sampling rate
    ``fs``. The filter is then applied forwards and backwards with
    ``filtfilt``, so the output has no phase shift.
    """
    coeff_b, coeff_a = spy.butter(
        order,
        cutoff,
        btype=filt_type,
        analog=False,
        output='ba',
        fs=fs,
    )
    filtered = spy.filtfilt(coeff_b, coeff_a, data)
    return filtered

"""
Convert a channel name to its integer channel index.
"""
def channel_name_to_index(reader, channel_name):
    """Return the integer id of the signal channel named ``channel_name``.

    ``reader.header['signal_channels']`` is scanned in order. Element 0 of
    each entry is compared with ``channel_name``, and on the first match
    element 1 is returned converted with ``int()``.

    Raises
    ------
    ValueError
        If no channel has that name. Failing here gives a clear message,
        whereas a ``None`` result would only fail later as a list index.
    """
    # NOTE(review): callers use this id as a position in
    # ``segment.analogsignals``; that assumes ids match list positions -- confirm.
    available = []
    for signal_channel in reader.header['signal_channels']:
        if channel_name == signal_channel[0]:
            return int(signal_channel[1])
        available.append(signal_channel[0])
    raise ValueError(
        f'channel {channel_name!r} not found; available channels: {available}'
    )
        
"""
Detects the file name with training data (LTP protocol) in it 
"""
def training_finder(f_name):
    """Return ``f_name`` when its ABF protocol is a training protocol, else None.

    The protocol name is taken from the header field ``sProtocolPath``. The
    text after the last backslash is split on '.', and the second-to-last
    piece is used. A name containing 'training' or 'Training' counts as
    training.
    """
    axon_reader = nio.AxonIO(str(f_name))
    protocol_path = str(axon_reader._axon_info['sProtocolPath'])
    protocol_file = protocol_path.split('\\')[-1]
    protocol_stem = protocol_file.split('.')[-2]
    if any(key in protocol_stem for key in ('training', 'Training')):
        return f_name
    return None

"""
Sort the supplied files into pre- and post-training lists and return them.
"""
def pre_post_sorted(f_list, finder=None):
    """Split recordings into pre- and post-training lists around the training file.

    Each file is classified with ``finder``. It must return the file itself
    for a training recording and ``None`` otherwise, and defaults to
    ``training_finder``. The first training file splits the list.

    Parameters
    ----------
    f_list : sequence
        Recordings in acquisition order (e.g. the sorted ``list_files`` output).
    finder : callable, optional
        Training-file classifier; lets the split be reused or tested without
        opening ABF files.

    Returns
    -------
    list
        ``[pre, post, no_c_train, pre_f_none, post_f_none]``:
        - ``pre``: files before the first training file;
        - ``post``: the first training file followed by every later file;
        - ``no_c_train``: the last additional training file found after the
          first one, or ``None``;
        - the last two entries are always ``None`` (kept for compatibility).

    Raises
    ------
    ValueError
        If ``f_list`` contains no training file.
    """
    if finder is None:
        finder = training_finder
    training_indx = None
    no_c_train = None
    for indx, f_name in enumerate(f_list):
        if finder(f_name) is None:
            continue
        if training_indx is None:
            training_indx = indx
        else:
            # Another training recording after the first one.
            no_c_train = f_name
    if training_indx is None:
        raise ValueError('no training protocol file found in the supplied list')
    # post keeps the training file as its first element; callers rely on that.
    pre = f_list[:training_indx]
    post = f_list[training_indx:]
    return [pre, post, no_c_train, None, None]

"""
Tag protocols with training, patterns, rmp measure etc.. assign a title to the file
"""
def protocol_tag(file_name):
    """Classify an ABF file by its protocol name and return a short title.

    The protocol name is the second-to-last '.'-separated piece of the text
    after the last backslash in ``sProtocolPath``. The first substring rule
    in ``rules`` that matches wins, so rule order matters.

    Returns
    -------
    str or None
        'Points', 'Patterns', 'Training pattern', 'rmp', 'InputR',
        'step_current', or None when no rule matches.
    """
    reader = nio.AxonIO(str(file_name))
    raw_path = str(reader._axon_info['sProtocolPath'])
    protocol_name = raw_path.split('\\')[-1].split('.')[-2]
    rules = (
        ('12_points', 'Points'),
        ('patternsx', 'Patterns'),
        ('Training', 'Training pattern'),
        ('training', 'Training pattern'),
        ('RMP', 'rmp'),
        ('Input_res', 'InputR'),
        ('threshold', 'step_current'),
    )
    return next(
        (title for key, title in rules if key in protocol_name),
        None,
    )

"""
Pair files pre and post with point, patterns, rmp etc..
"""
def file_pair_pre_pos(pre_list,post_list):
    """Group pre- and post-training files by protocol type.

    Every file is tagged with ``protocol_tag``. All pre-training files are
    visited first, then the post-training ones, so in each group the pre
    file(s) come before the post file(s). Files with any other tag
    (including training and None) are ignored.

    Returns
    -------
    list
        ``[point, pattern, rmp, InputR, step_current]`` -- five lists of files.
    """
    groups = {
        'Points': [],
        'Patterns': [],
        'rmp': [],
        'InputR': [],
        'step_current': [],
    }
    for abf in [*pre_list, *post_list]:
        bucket = groups.get(protocol_tag(abf))
        if bucket is not None:
            bucket.append(abf)
    return [
        groups['Points'],
        groups['Patterns'],
        groups['rmp'],
        groups['InputR'],
        groups['step_current'],
    ]

"""
get cell trace data
"""
def cell_trace(file_name):
    """Load every sweep of the 'IN0' channel from an ABF file.

    Parameters
    ----------
    file_name : str or pathlib.Path
        Path to the ABF recording.

    Returns
    -------
    list
        ``[(t, sweeps), sampling_rate]``:
        - ``sweeps``: one ``np.array`` of the IN0 analog signal per segment;
        - ``t``: time axis from 0 to the first signal's duration, with as
          many points as the last sweep;
        - ``sampling_rate``: magnitude (units stripped) of the first
          signal's sampling rate.
    """
    reader = nio.AxonIO(str(file_name))
    block = reader.read_block(signal_group_mode='split-all')
    segments = block.segments
    sample_trace = segments[0].analogsignals[0]
    sampling_rate = sample_trace.sampling_rate.magnitude
    duration = float(sample_trace.t_stop - sample_trace.t_start)
    # The channel index is the same for every segment, so look it up once.
    cell = channel_name_to_index(reader, 'IN0')
    cell_trace_all = [np.array(segment.analogsignals[cell]) for segment in segments]
    t = np.linspace(0, duration, len(cell_trace_all[-1]))
    cell_ = (t, cell_trace_all)
    return [cell_, sampling_rate]




In [12]:
"""
Analysis functions
"""
"""
Get peak events by taking TTL into account
"""
def peak_event(file_name):
    """Peak baseline-subtracted 'IN0' response after each TTL pulse.

    Rising edges are detected on the 'FrameTTL' channel with
    ``find_ttl_start(..., 3)``. For each edge, the maximum of the cell trace
    is taken over the next ``0.2 * sampling_rate`` samples (0.2 s if the rate
    is in Hz). The baseline is the mean of the first 2000 samples of the
    sweep-averaged trace. It is subtracted from both the averaged trace and
    every individual sweep.

    Returns
    -------
    list
        ``[event_av, events]``:
        - ``event_av[i]``: peak of the sweep-averaged trace after TTL edge ``i``;
        - ``events[i][n]``: peak of sweep ``n`` after TTL edge ``i``.
    """
    reader = nio.AxonIO(str(file_name))
    block = reader.read_block(signal_group_mode='split-all')
    segments = block.segments
    sampling_rate = segments[0].analogsignals[0].sampling_rate.magnitude
    cell_idx = channel_name_to_index(reader, 'IN0')
    ttl_idx = channel_name_to_index(reader, 'FrameTTL')
    cell_trace_all = [np.array(seg.analogsignals[cell_idx]) for seg in segments]
    ttl_sig_all = [np.array(seg.analogsignals[ttl_idx]) for seg in segments]

    # NOTE(review): pulse times are taken from the LAST sweep's TTL trace
    # only. If TTL timing can differ between sweeps, the sweep-averaged TTL
    # may be the intended source -- confirm.
    ttl_xi = find_ttl_start(ttl_sig_all[-1], 3)
    ttl_xf = (np.asarray(ttl_xi) + 0.2 * sampling_rate).astype(int)

    cell_trace_avg = np.average(cell_trace_all, axis=0)
    baseline = np.mean(cell_trace_avg[0:2000])
    cell_trace_av = cell_trace_avg - baseline
    cell_trace_b_sub = np.asarray(cell_trace_all) - baseline

    event_av = []
    events = []
    for start, stop in zip(ttl_xi, ttl_xf):
        event_av.append(np.max(cell_trace_av[start:stop]))
        events.append([np.max(sweep[start:stop]) for sweep in cell_trace_b_sub])
    return [event_av, events]

"""
Sum the point responses that make up each X-distributed pattern.
"""
def summate_points_x_dist(points_file):
    """Linear sum of point responses for each X-distributed pattern.

    Per-point responses are the trial-averaged peaks from ``peak_event``
    (mean over sweeps). Each pattern's expected response is the sum over
    one slice of points in ``windows``. The number of sums is printed and a
    7-element tuple is returned.
    """
    per_point = np.mean(peak_event(points_file)[1], axis=1)
    # NOTE(review): these windows are not a uniform sliding window. The
    # first covers points 0-4, the second starts at point 2 (points 1-5 are
    # never used), and the last stops one point short of the end because of
    # the ``-1``. Confirm against the real pattern layout.
    windows = (
        slice(0, 5),
        slice(2, 7),
        slice(3, 8),
        slice(4, 9),
        slice(5, 10),
        slice(6, 11),
        slice(7, -1),
    )
    sum_of_points = tuple(np.sum(per_point[window]) for window in windows)
    print(f'length of summated points = {len(sum_of_points)}')
    return sum_of_points

"""
Change in input resistance between a pre- and a post-training recording.
"""
def input_res(abf_for_InpR, t_start=None, t_end=None, injected_current=-20):
    """Per-sweep resistance change (post - pre) between two recordings.

    Resistance is measured on both files with ``series_res_measure`` using
    the same windows and current. The difference post - pre is rounded to
    2 decimals. Both files must have the same number of sweeps.

    Parameters
    ----------
    abf_for_InpR : sequence
        ``(pre_file, post_file)``.
    t_start, t_end : float
        Start times, in seconds, of the two voltage-averaging windows used by
        ``series_res_measure``. Required, because the right windows depend on
        the protocol.
    injected_current : float, optional
        Step amplitude in pA (default -20).

    Raises
    ------
    ValueError
        If ``t_start`` or ``t_end`` is missing.
    """
    if t_start is None or t_end is None:
        raise ValueError(
            'input_res needs t_start and t_end (seconds) for the measurement windows'
        )
    pre_f, post_f = abf_for_InpR[0], abf_for_InpR[1]
    # series_res_measure loads the sweeps itself, so it takes file paths.
    sampling_rate = cell_trace(pre_f)[1]
    series_r_pre = series_res_measure(pre_f, injected_current, t_start, t_end, sampling_rate)
    series_r_post = series_res_measure(post_f, injected_current, t_start, t_end, sampling_rate)
    return np.round(series_r_post - series_r_pre, 2)
"""
Measure resistance per sweep from the voltage change caused by the injected current.
"""
def series_res_measure(abf_path,injected_current,t_start,t_end,sampling_rate):
    """Per-sweep resistance from a current step recorded in ``abf_path``.

    For every 'IN0' sweep, the mean voltage is taken over two windows. The
    first spans ``int(t_start*sampling_rate)`` to
    ``int((t_start+0.2)*sampling_rate)`` and the second is the same window
    starting at ``t_end``. The result is ``(V_end - V_start) /
    injected_current * 1000``, rounded to 2 decimals. With mV and pA inputs
    this is MOhm (assumed units -- confirm).

    Returns
    -------
    numpy.ndarray
        One value per sweep.
    """
    sweeps = cell_trace(abf_path)[0][1]

    def window_mean(sweep, t0):
        first = int(t0 * sampling_rate)
        last = int((t0 + 0.2) * sampling_rate)
        return np.mean(sweep[first:last])

    v_start = np.array([window_mean(sweep, t_start) for sweep in sweeps])
    v_end = np.array([window_mean(sweep, t_end) for sweep in sweeps])
    return np.around((v_end - v_start) / injected_current * 1000, 2)

"""
Find the amplitude of the injected current.
"""
def current_injected(abf_file):
    """Amplitude of the command step stored in an ABF file's protocol.

    The stimulus is read with ``read_raw_protocol``. For every segment the
    first waveform is taken, and the result is ``|max| - |min|`` over all
    those samples, rounded to 2 decimals. For a step between 0 and X this
    equals X, sign included (e.g. -20 for a 0 -> -20 step).

    Parameters
    ----------
    abf_file : str or pathlib.Path
    """
    # Only the protocol is needed, so the recorded signals are not loaded.
    reader = nio.AxonIO(str(abf_file))
    per_segment = reader.read_raw_protocol()[0]
    command = np.concatenate([np.ravel(waveforms[0]) for waveforms in per_segment])
    i_min = np.abs(np.min(command))
    i_max = np.abs(np.max(command))
    return np.around((i_max - i_min), 2)

"""
Raw epsp response amplitude values 
"""

def raw_peak_dist(points_or_pattern_file_set_abf):
    """Raw per-sweep peak responses for a (pre, post) pair of recordings.

    Returns ``[epsp_pre, epsp_post]``. Each is ``peak_event(file)[1]``
    transposed, so it is indexed ``[sweep][TTL event]``.
    """
    pre_file = points_or_pattern_file_set_abf[0]
    post_file = points_or_pattern_file_set_abf[1]
    return [np.transpose(peak_event(abf)[1]) for abf in (pre_file, post_file)]

"""
membrane voltage shift will return how much voltage change happened 
for individual trials as well as the mean change for a file
"""

def vm_shift(abf_file):
    """Within-sweep membrane-voltage drift for every sweep of a recording.

    For each 'IN0' sweep, the mean of the first ``int(0.2 * sampling_rate)``
    samples is subtracted from the mean of the last ``int(0.2 *
    sampling_rate)`` samples (0.2 s each if the rate is in Hz).

    Returns
    -------
    numpy.ndarray
        One value per sweep, ``V_final - V_initial``, rounded to 2 decimals.
    """
    (_, sweeps), sampling_rate = cell_trace(abf_file)
    n = int(0.2 * sampling_rate)
    vm_i = np.array([np.mean(sweep[:n]) for sweep in sweeps])
    # Up to and including the last sample; a ``-1`` stop would drop it.
    vm_f = np.array([np.mean(sweep[-n:]) for sweep in sweeps])
    return np.around(vm_f - vm_i, 2)
        
    
    

In [13]:
"""
single cell functions
"""
"""
pair pre and post, points and patterns for each cell.
"""
def file_pair(cell_path):
    """Pair the pre- and post-training point and pattern recordings of one cell.

    Files from ``list_files(cell_path)`` are split around the first training
    recording, the training recording itself is dropped, and the rest are
    grouped by protocol.

    Returns
    -------
    list
        ``[paired_points, paired_patterns]``. Each list holds the
        pre-training file(s) followed by the post-training file(s).
    """
    pre_files, post_with_training = pre_post_sorted(list_files(cell_path))[:2]
    # post_with_training[0] is the training recording; skip it.
    point_files, pattern_files = file_pair_pre_pos(pre_files, post_with_training[1:])[:2]
    return [point_files, pattern_files]


"""
series resistance changes for X distributed patterns
"""

def series_res_cell(cell_path):
    """Per-sweep resistance of the pre/post point and pattern recordings of a cell.

    The window start times (s) are protocol specific: 4.0/4.55 for points
    and 8.4/8.75 for patterns. The step amplitude and sampling rate come from
    the pre-training points file and are reused for all four recordings.

    Returns
    -------
    list
        ``[[point_pre, point_post], [pattern_pre, pattern_post]]``, each an
        array from ``series_res_measure``.
    """
    point_files, pattern_files = file_pair(cell_path)
    point_pre, point_post = point_files[0], point_files[1]
    pattern_pre, pattern_post = pattern_files[0], pattern_files[1]
    sampling_rate = cell_trace(point_pre)[1]
    injected_current = current_injected(point_pre)
    windows = {'points': (4.0, 4.55), 'patterns': (8.4, 8.75)}

    def measure(abf, protocol):
        t_start, t_end = windows[protocol]
        return series_res_measure(abf, injected_current, t_start, t_end, sampling_rate)

    point_r = [measure(point_pre, 'points'), measure(point_post, 'points')]
    pattern_r = [measure(pattern_pre, 'patterns'), measure(pattern_post, 'patterns')]
    return [point_r, pattern_r]


"""
pickup the changes in series resistances across recordings in the cell.
returns the difference in series resistance change across different protocols
give a percentage change from begining to end of the protocol.
"""
def series_r_shift(cell_path):
    """Summarise how one cell's resistance drifted over the experiment.

    Returns
    -------
    list
        ``[sr_diff, pec_change]``:
        - ``sr_diff``: ``abs([point_diff, pattern_diff, cell_diff])``, each a
          difference of means (post minus pre);
        - ``pec_change``: ``|cell_diff|`` as a percentage of ``|mean(point_pre)|``,
          rounded to 2 decimals.
        The cell-level difference uses the point recordings, because the
        first and last protocols of the experiment are point-based.
    """
    (point_pre, point_post), (pattern_pre, pattern_post) = series_res_cell(cell_path)
    point_delta = np.mean(point_post) - np.mean(point_pre)
    pattern_delta = np.mean(pattern_post) - np.mean(pattern_pre)
    cell_delta = point_delta
    pec_change = np.around(np.abs(cell_delta) / np.abs(np.mean(point_pre)) * 100, 2)
    sr_diff = np.abs([point_delta, pattern_delta, cell_delta])
    return [sr_diff, pec_change]

"""
take in the cell path and list down the Vms for different files
"""

def vm_shift_cell(cell_path):
    """Mean absolute change in within-sweep voltage drift, pre vs post (points).

    ``vm_shift`` is computed for the pre- and post-training point files, and
    ``mean(|post - pre|)`` over sweeps is returned. Both files must have the
    same number of sweeps.
    """
    point_files = file_pair(cell_path)[0]
    # Only the point recordings determine the result, so pattern files are
    # not loaded.
    vm_pre = vm_shift(point_files[0])
    vm_post = vm_shift(point_files[1])
    return np.mean(np.abs(vm_post - vm_pre))

In [14]:
"""
Functions used across cells
"""

"""
cell selection cutoffs, vm change less than 2.5, % change in series res = 25%
"""
def select_healthy_cells(cells, max_sr_change=20.0, max_vm_shift=2.5):
    """Screen cells by resistance stability and voltage drift.

    A cell passes when both conditions hold (strict ``<``):
    - its percentage resistance change (``series_r_shift``) is below
      ``max_sr_change``;
    - its voltage-drift change (``vm_shift_cell``) is below ``max_vm_shift``.
    Cells whose analysis raises are reported, with the reason, and left out.

    Parameters
    ----------
    cells : iterable of pathlib.Path
        Cell folders, e.g. from ``list_folder``.
    max_sr_change : float, optional
        Maximum allowed resistance change in percent (default 20.0).
    max_vm_shift : float, optional
        Maximum allowed voltage-drift change (default 2.5).

    Returns
    -------
    pandas.DataFrame
        Columns 'cell_ID', 'series_R_change(%)', 'RMP_change(mV)' and
        'selection_status' ('pass' or 'fail').
    """
    cell_health_data = []
    for c in cells:
        try:
            percentage_sr = series_r_shift(c)[1]
            vm_sh = vm_shift_cell(c)
            passed = (percentage_sr < max_sr_change) and (vm_sh < max_vm_shift)
            cell_status = 'pass' if passed else 'fail'
            cell_health_data.append([c.stem, percentage_sr, vm_sh, cell_status])
        except Exception as err:
            # Best effort: skip cells that cannot be analysed, but report why
            # so genuine bugs are not hidden (KeyboardInterrupt still stops).
            print(f'file {c} donot have cell data ({type(err).__name__}: {err})')
            continue
    column_names = ['cell_ID','series_R_change(%)','RMP_change(mV)','selection_status']
    return pd.DataFrame(cell_health_data, columns=column_names)


In [15]:
"""
Plotting functions
"""

"""
plot raw peaks for epsps for each cell
"""
def plot_raw_epsp(cell_path):
    """Scatter one cell's raw per-sweep peak responses, pre vs post training.

    Two figures are shown, one for the point recordings and one for the
    pattern recordings. Each figure has one scatter series per sweep, with
    x the stimulus (TTL event) index.
    """
    point_files, pattern_files = file_pair(cell_path)
    point_peaks = raw_peak_dist(point_files)
    pattern_peaks = raw_peak_dist(pattern_files)
    _scatter_pre_post(point_peaks, f'{cell_path.stem} raw response to points')
    _scatter_pre_post(pattern_peaks, f'{cell_path.stem} raw response to patterns')


def _scatter_pre_post(peaks, title):
    """Draw pre (``peaks[0]``) and post (``peaks[1]``) sweeps on one figure and show it."""
    pre_sweeps, post_sweeps = peaks[0], peaks[1]
    x = np.arange(0, len(pre_sweeps[0]), 1)
    fig, ax = plt.subplots()
    conditions = (
        (pre_sweeps, pre_color, 'responses pre training'),
        (post_sweeps, post_color, 'responses post training'),
    )
    for sweeps, color, label in conditions:
        for idx, sweep in enumerate(sweeps):
            # Label one sweep per condition so the legend has exactly two entries.
            ax.scatter(x, sweep, color=color, label=label if idx == 0 else None)
    ax.set_title(title)
    ax.set_xlabel('stimulus index')
    ax.set_ylabel('peak response (baseline subtracted)')
    ax.legend(bbox_to_anchor=(0.75, -0.15), ncol=2)
    plt.show()
    plt.close(fig)
    
"""
Plot trial-by-trial series resistance and its mean.
"""
def plot_series_r(cell_path):
    """Plot sweep-by-sweep resistance and its mean for one cell.

    Two figures are shown, points first and then patterns. Each shows the
    pre- and post-training sweeps with a horizontal line at each mean.
    """
    # Measure the cell passed in (not a notebook-global loop variable).
    (point_pre, point_post), (pattern_pre, pattern_post) = series_res_cell(cell_path)
    _plot_sr_pre_post(point_pre, point_post, 'point',
                      f'series resistance for points:{cell_path.stem}')
    _plot_sr_pre_post(pattern_pre, pattern_post, 'pattern',
                      f'series resistance for patterns:{cell_path.stem}')


def _plot_sr_pre_post(pre, post, kind, title):
    """Scatter pre/post per-sweep resistance with a line at each mean, then show."""
    fig, ax = plt.subplots()
    for values, color, when in ((pre, pre_color, 'pre'), (post, post_color, 'post')):
        ax.scatter(np.arange(len(values)), values, color=color, label=f'{kind} {when}')
        ax.axhline(y=np.mean(values), color=color)
    ax.set_xlabel('sweep')
    ax.set_ylabel('Mohm')
    ax.set_title(title)
    ax.legend(bbox_to_anchor=(0.75, -0.15), ncol=2)
    plt.show()
    plt.close(fig)

In [16]:
# Find every cell folder under `p`, then screen each cell for recording health.
cells = list_folder(p)
cell_health_stat = select_healthy_cells(cells=cells)

file /Users/anzalks/Documents/Expt_data/Recordings/x_spread_CA1_recordings/2022_11_03_cell_2 donot have cell data


In [17]:
# Health-screening table: one row per cell that could be analysed, with its selection_status.
cell_health_stat

Unnamed: 0,cell_ID,series_R_change(%),RMP_change(mV),selection_status
0,2022_10_03_cell_1,29.96,0.256,fail
1,2022_10_21_cell_1,5.96,0.278,pass
2,2022_10_21_cell_2,3.27,0.242,pass
3,2022_10_25_cell_1,8.8,0.106,pass
4,2022_10_26_cell_1,12.83,0.228,pass
5,2022_11_01_cell_1,10.97,0.266,pass
6,2022_11_01_cell_2,8.92,0.772,pass
7,2022_11_02_cell_1,11.46,0.422,pass
8,2022_11_02_cell_2,10.25,0.488,pass
9,2022_11_02_cell_3,8.36,0.214,pass
