In [None]:
%matplotlib inline
__author__           = "Anzal KS"
__copyright__        = "Copyright 2022-, Anzal KS"
__maintainer__       = "Anzal KS"
__email__            = "anzalks@ncbs.res.in"
from pathlib import Path
import neo.io as nio
import numpy as np
import matplotlib.pyplot as plt
import scipy.signal as spy
import pandas

In [None]:
cell_folder = Path('/Users/anzalks/Documents/Expt_data/Recordings/CA3_recordings')
csv_file = Path('/Users/anzalks/Documents/Expt_data/Recordings/CA3_recordings/experiment_notes.csv')
#outdir = cell_folder.parents[0]/'results_plot'
outdir = cell_folder/'results_plot'
outdir.mkdir(exist_ok=True,parents=True)

In [None]:
"""
file organising fucntions
"""
def list_files(p):
    f_list = []
    f_list=list(p.glob('**/*abf'))
    f_list.sort()
    return f_list

def find_csv(c):
    csv_file = str(c)
    csv_file = pandas.read_csv(csv_file,header=1)
    return(csv_file)


In [None]:
abf_files = list_files(cell_folder)
cell_data = find_csv(csv_file)

In [None]:
"""
assign cell and experiment data to files
"""
def expt_cell_assigner():
    cells = cell_data['cell_ID']
    cell_expt_data = []
    for cell in cells:
        cell_expt_data.append(cell_data[cell_data['cell_ID']==cell])
    print(cell_expt_data)

In [None]:
"""
extra functions
"""

"""
smoothening function
"""
def smooth(data, kernel_size):
    '''apply a smoothing kernel of length kernel_size to 1D data
    '''
    kernel = np.ones(kernel_size) / kernel_size
    return np.convolve(data, kernel, mode='same')



"""
data filter function
"""
def filter_data(data, cutoff, filt_type, fs, order=3):
    '''
    creates and applies a filter
    '''
    b, a = spy.butter(order, cutoff, btype = filt_type, analog=False, output='ba', fs=fs)
    return spy.filtfilt(b, a, data)
    """
    b, a = spy.butter(order, cutoff, analog=False, output='ba', fs=fs)
    return spy.filtfilt(b, a, data)
    """

def downsampling_funct(d_array, initial_fs, final_fs):
    """
    down sampling function
    downsampling from 32 khz to 1khz
    initial_fs = initial smapling rate
    final_fs = final sampling rate
    1_d_array = 1 d rray with samples(ep_single channel data)
    """
    downsampling_factor = np.floor(initial_fs/final_fs)
    dat_idx = np.arange(0,len(d_array), downsampling_factor).astype(np.int32)
    ds_dat = d_array[dat_idx]
    return ds_dat

def butter_bandpass(lowcut, highcut, fs, order=5):
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    b, a = spy.butter(N=order, Wn=[low, high], btype='bandpass')
    return b, a


def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    b, a = butter_bandpass(lowcut, highcut, fs, order=order)
    y = spy.lfilter(b, a, data)
    return y


In [None]:
"""
plot 3 pattern responses, extreme left, middle and right end
"""

def plot_select_patterns(f,outdir,p_no):
    f_plt = f'{outdir}/{f.stem}'
    cell_info =cell_data[cell_data['file_name']==f'{f.stem}.abf']
    expt_details = cell_info.to_string(index=False)
    cell_id = cell_info['cell_ID'].to_string(index=False,header=False)
    if 'cell' in cell_id:
        print(f'the cell id = {cell_id}')
    else:
        cell_id = 'non_categorised'
    f_plt_pat = outdir/cell_id
    f_plt_pat.mkdir(exist_ok=True,parents=True)
    f_plt_pat=f_plt_pat/f'{f.stem}_pattern_'
    expt_details = cell_data[cell_data['file_name']==f'{f.stem}.abf'].to_string(index=False)
    file_name = f'{f.stem}'
    f=str(f)
    reader = nio.AxonIO(f)
    channels = reader.header['signal_channels']
    chan_count = len(channels)
    block  = reader.read_block(signal_group_mode='split-all')
    segments = block.segments
    sample_trace = segments[0].analogsignals[0]
    sampling_rate = sample_trace.sampling_rate
    ti = sample_trace.t_start
    tf = sample_trace.t_stop
    total_time = int(tf-ti)
    protocol_raw = reader.read_raw_protocol()
    protocol_raw = protocol_raw[0]
    #print(total_time)
    unit = str(sample_trace.units).split()[1]
    #        print(unit)
    fig, axs = plt.subplots(2,1,figsize=(8,6),sharex='col')
    for i in range(chan_count):
        trace_average = []
        for s, segment in enumerate(segments):
            analogsignals = segment.analogsignals[i]
            unit = str(analogsignals.units).split()[1]
            if unit =='mV':
                unit ='mV'
            elif unit =='pA':
                unit ='pA'
            else:
                unit='Signal amplitude'
                    
            trace = np.array(analogsignals)
            #trace = smooth(trace, 4)
            #trace  = downsampling_funct(trace, int(sampling_rate), 2000)
            trace_average.append(trace)
            #    print(f'length of trace = {len(trace)}')
            t = np.linspace(0,float(tf-ti),len(trace))
            #downsampling_funct(t, int(sampling_rate), 2000)
            if i ==0:
                axs[0].plot(t,trace,alpha=0.7,linewidth=2, label = f'trial - {s+1}')
                axs[0].set_title('recording')
                axs[0].set_ylabel(unit)
            if i ==2:
                axs[1].plot(t,trace,alpha=0.7,linewidth=2, label = f'trial - {s+1}')
                axs[1].set_xlabel('time(s)')
        if i ==0:
            trace_average = np.mean(trace_average,axis=0)
            axs[i].plot(t,trace_average,linewidth=1,color='k', label = f'average')
    if p_no == 0:
        p_no =3
        axs[0].set_xlim(3.105,3.15) #1st pulse
        fig.suptitle(f'{file_name}_pattern_{p_no}')
        plt.xlabel('time (s)',loc='center')
        f_plt_pat = f'{f_plt_pat}_{p_no}.jpg'
        print(f'saved the plot to {f_plt_pat}_{p_no}.jpg')
    elif p_no == 1:
        p_no =4
        axs[0].set_xlim(4.105,4.15) #middle pulse
        fig.suptitle(f'{file_name}_pattern_{p_no}')
        plt.xlabel('time (s)',loc='center')
        f_plt_pat = f'{f_plt_pat}_{p_no}.jpg'
        print(f'saved the plot to {f_plt_pat}_{p_no}.jpg')
    elif p_no == 2:
        p_no =5
        axs[0].set_xlim(5.105,5.15) #last pulse
        fig.suptitle(f'{file_name}_pattern_{p_no}')
        plt.xlabel('time (s)',loc='center')
        f_plt_pat = f'{f_plt_pat}_{p_no}.jpg'
        print(f'saved the plot to {f_plt_pat}_{p_no}.jpg')
    else:
        print(f'saving file {file_name} to plot_failed')
    plt.text(0.2, -0.2, expt_details, fontsize=11, transform=plt.gcf().transFigure)
    fig.savefig(f'{f_plt_pat}',bbox_inches='tight')
    plt.close(fig)

In [None]:
"""
plot all pattern responses
"""

def plot_all_patterns(f,outdir):
    cell_info =cell_data[cell_data['file_name']==f'{f.stem}.abf']
    expt_details = cell_info.to_string(index=False)
    cell_id = cell_info['cell_ID'].to_string(index=False,header=False)
    if 'cell' in cell_id:
        print(f'cell id = {cell_id}')
    else:
        cell_id = 'non_categorised'
    f_plt = outdir/cell_id
    f_plt.mkdir(exist_ok=True,parents=True)
    f_plt=f_plt/f'{f.stem}.jpg'
    file_name = f'{f.stem}'
    f=str(f)
    reader = nio.AxonIO(f)
    channels = reader.header['signal_channels']
    chan_count = len(channels)
    block  = reader.read_block(signal_group_mode='split-all')
    segments = block.segments
    sample_trace = segments[0].analogsignals[0]
    sampling_rate = sample_trace.sampling_rate
    ti = sample_trace.t_start
    tf = sample_trace.t_stop
    total_time = int(tf-ti)
    protocol_raw = reader.read_raw_protocol()
    protocol_raw = protocol_raw[0]
    #print(total_time)
    unit = str(sample_trace.units).split()[1]
    #        print(unit)
    fig, axs = plt.subplots(2,1,figsize=(12,3),sharex='col')
    for i in range(chan_count):
        trace_average = []
        for s, segment in enumerate(segments):
            analogsignals = segment.analogsignals[i]
            unit = str(analogsignals.units).split()[1]
            if unit =='mV':
                unit ='mV'
            elif unit =='pA':
                unit ='pA'
            else:
                unit='Signal amplitude'
                    
            trace = np.array(analogsignals)
            #trace = smooth(trace, 4)
            #trace  = downsampling_funct(trace, int(sampling_rate), 2000)
            trace_average.append(trace)
            #    print(f'length of trace = {len(trace)}')
            t = np.linspace(0,float(tf-ti),len(trace))
            #downsampling_funct(t, int(sampling_rate), 2000)
            if i ==0:
                axs[0].plot(t,trace,alpha=0.7,linewidth=2, label = f'trial - {s+1}')
                axs[0].set_title('recording')
                axs[0].set_ylabel(unit)
            if i ==2:
                axs[1].plot(t,trace,alpha=0.7,linewidth=2, label = f'trial - {s+1}')
                axs[1].set_xlabel('time(s)')
        if i ==0:
            trace_average = np.mean(trace_average,axis=0)
            axs[i].plot(t,trace_average,linewidth=1,color='k', label = f'average')

    axs[0].set_xlim(0.9,7.2) #last pulse
    plt.text(0.2, -0.2, expt_details, fontsize=11, transform=plt.gcf().transFigure)
    fig.suptitle(file_name)
    plt.xlabel('time (s)',loc='center')
    fig.savefig(f_plt,bbox_inches='tight')
    print(f'saved the plot to {f_plt}')
    plt.close(fig)

In [9]:
for abf in abf_files:
    plot_all_patterns(abf,outdir)
    for s in range(3):
        try:
            print(f'plotting pattern {s} in file {abf.stem}')
            plot_select_patterns(abf,outdir,s)
            print(f'saved file {abf.stem}')
        except:
            continue

saved the plot to /Users/anzalks/Documents/Expt_data/Recordings/CA3_recordings/results_plot/non_categorised/2022_11_21_0000.jpg
plotting pattern 0 in file 2022_11_21_0000
saved the plot to /Users/anzalks/Documents/Expt_data/Recordings/CA3_recordings/results_plot/non_categorised/2022_11_21_0000_pattern__3.jpg_3.jpg
saved file 2022_11_21_0000
plotting pattern 1 in file 2022_11_21_0000
saved the plot to /Users/anzalks/Documents/Expt_data/Recordings/CA3_recordings/results_plot/non_categorised/2022_11_21_0000_pattern__4.jpg_4.jpg
saved file 2022_11_21_0000
plotting pattern 2 in file 2022_11_21_0000
saved the plot to /Users/anzalks/Documents/Expt_data/Recordings/CA3_recordings/results_plot/non_categorised/2022_11_21_0000_pattern__5.jpg_5.jpg
saved file 2022_11_21_0000
saved the plot to /Users/anzalks/Documents/Expt_data/Recordings/CA3_recordings/results_plot/non_categorised/2022_11_21_0001.jpg
plotting pattern 0 in file 2022_11_21_0001
saved the plot to /Users/anzalks/Documents/Expt_data/Rec