## Notes

-------------- 
### Comments and observations  
------------------------------
Created by Amparo Guemes (1 April 2023). Adapted from the original code into an object-based structure for ease of reuse.

Code for processing evoked compound action potentials (CAPs) recorded from the sciatic nerve using PEDOT-coated tungsten microwires.

## Libraries

In [None]:
import IPython
# IPython.Application.instance().kernel.do_shutdown(True)

%matplotlib widget

import os
import sys
import json
import time
import datetime
import pycwt
import statistics
import random
import pickle
import numpy as np
import scipy as sp
import pandas as pd
import seaborn as sns
import sklearn as sk
import tkinter as tk
import matplotlib.pyplot as plt
from sklearn import metrics
from sklearn import decomposition
from sklearn.decomposition import PCA
from tkinter import *
from tkinter import ttk
from sklearn import preprocessing
from datetime import date
import matplotlib.dates as mdates

from neurodsp.rhythm import sliding_window_matching
from neurodsp.utils.download import load_ndsp_data
from neurodsp.plts.rhythm import plot_swm_pattern
from neurodsp.plts.time_series import plot_time_series
from neurodsp.utils import set_random_seed, create_times
# Import listed colormap
from matplotlib.colors import ListedColormap
import matplotlib.dates as md
from matplotlib import colors as mcolors
# Scipy
from scipy import signal
from scipy import ndimage
# TKinter for selecting files
from tkinter import Tk     # from tkinter import Tk for Python 3.x
from tkinter.filedialog import askdirectory
#Bokeh
from bokeh.io import output_notebook

# Add my module to python path
sys.path.append("../")


## Functions

In [None]:
# Functions
def pipeline_peak_extraction(signal2analyse, ch, fs, spike_detection_config, verbose=False, method=None):
    """
    General method to extract neural spikes and waveforms.

    Runs spike_detection() on one channel, filters the detected peaks and cuts
    out their waveforms with get_waveforms(), and finally builds the binary
    spike-location vector with get_spike_location().

    Parameters
    ------------
    signal2analyse:         [dataframe or dict-like] signals indexed by channel; signal2analyse[ch] is the trace analysed
    ch:                     channel key to analyse (whatever signal2analyse is indexed with)
    fs:                     sampling frequency, forwarded to the detection and waveform helpers
    spike_detection_config: [dict] detection parameters. The 'general' entry must provide
                            'spike_window', 'min_thr' and 'half_width' (used here) plus whatever
                            spike_detection() reads
    verbose:                [bool] print progress information when True
    method:                 [string or None] detection method forwarded to spike_detection().
                            When None, the notebook-level variable `detect_method` is used,
                            which is what earlier versions of this function always did

    Returns
    ------------
    spikes_idx:        indexes of neural peaks (locations) after filtering with the amplitude/width conditions
    waveforms:         [array, rows: number of neural peaks, columns: samples of the window around each peak]
                       (an empty list when no peak survives the filtering)
    spikes_vector_loc: output of get_spike_location(): 1 where there is a spike and 0 otherwise
    index_first_edge:  indexes of detected peaks before alignment to the maximum
    """
    time_start = time.time()

    if method is None:
        # Fall back to the notebook-level setting that used to be read implicitly
        method = detect_method

    # ----------------------------------------------------
    # Identify neural spikes from ENG
    # -----------------------------------------------------
    if verbose:
        print("-------------------------------")
        print("Detecting neural events")

    # The number of unique peaks reported by spike_detection() is not needed here
    spikes_idx, _, index_first_edge = spike_detection(
        signal2analyse[ch], method, fs, verbose=verbose, **spike_detection_config)

    general = spike_detection_config['general']
    waveforms, spikes_idx = get_waveforms(
        signal2analyse, fs, ch, spikes_idx,
        general['spike_window'], general['min_thr'], general['half_width'])
    spikes_vector_loc = get_spike_location(signal2analyse[ch], spikes_idx)

    print("pipeline_peak_extraction completed. Time elapsed: {} seconds".format(time.time()-time_start))
    return spikes_idx, waveforms, spikes_vector_loc, index_first_edge
    
    
def spike_detection(signal, method, fs, manual_thres=None, verbose=False, **kwargs):
    """
    Method for identifying neural spikes based on different methods

    Parameters
    ------------
    signal:         trace to analyse (signal2analyse[ch])
    method:         [string] detection method; only "get_spikes_method" is implemented
    fs:             sampling frequency, forwarded to get_spikes()
    manual_thres:   [None or pair] if not None, used by set_height() instead of the noise-based threshold
    verbose:        [bool] print extra progress information when True
    kwargs:         detection configuration; kwargs['general'] must provide 'C', 'Cmax',
                    'spike_window' and 'find_peaks_args'

    Returns
    -----------
    spikes_idx:         indexes of neural peaks (locations) after re-centring on the maxima
    numunique_peak_pos: second value returned by max_centered_peaks_optimized()
    index_first_edge:   indexes of detected peaks before alignment to the maximum

    Raises
    -----------
    ValueError: if `method` is not a supported detection method

    Notes
    -----------
    kwargs['general']['find_peaks_args']['height'] is overwritten in place with the
    (min, max) height computed for this signal, so the caller's configuration
    dictionary reflects the last signal analysed.
    """
    if method != "get_spikes_method":
        # Previously an unknown method surfaced later as a NameError on `indspos`
        raise ValueError("spike_detection(): unknown detection method %r" % (method,))

    general = kwargs['general']

    #------------------------------------
    # Method 1: get_spikes().
    # detect max of peaks using the find_peaks() method (using minimum distance between peaks)
    #--------------------------------------------
    if verbose:
        print(method)
    threshold = set_height(signal, general['C'], manual_thres)
    print('Have set height')

    print('get positive peaks')
    print(general)
    # Get positive peaks
    print('calculating height')
    # Upper bound: Cmax times the median-based noise estimate (computed once)
    max_height = general['Cmax'] * np.median(np.abs(signal)/0.6745)
    print('Max height: %s' % max_height)
    print('Min height: %s' % threshold[0])
    general['find_peaks_args']['height'] = (threshold[0], max_height)
    indspos = get_spikes(signal, fs, **general, neo=False)
    if verbose:
        print('Length array spike pos from find_peaks(): %s' % len(indspos))

    #----------------------------------------------------------
    # Determine positive/negative/both peaks and waveforms
    #----------------------------------------------------------
    print('Determining waveforms')
    index_first_edge = indspos
    if verbose:
        print('Number of index_first_peaks: %s' % len(index_first_edge))

    # Align peaks around maxima in window from the filtered signal
    # Center in windows in raw signal: This needs to be done as the wavelet
    # decomposition may slightly change the waveform, so we ensure we're picking
    # up the maximum of the original signal
    spikes_idx, numunique_peak_pos = max_centered_peaks_optimized(
        signal, index_first_edge, general['spike_window'], verbose=verbose)

    return spikes_idx, numunique_peak_pos, index_first_edge
                                     
def set_height(signal, C, manual_thres=None):
    """
    Build the pair of amplitude thresholds used for peak detection.

    Parameter
    ----------
    signal:       trace used to estimate the noise level (ignored when manual_thres is given)
    C:            multiplier applied to the noise estimate to obtain the thresholds
    manual_thres: [None or array of 1x2] first element is the positive and second element
                  is the negative manual threshold; used verbatim when provided

    Returns
    ------------
    threshold: [list of 2] positive threshold followed by negative threshold

    Comments
    ----------
    The noise level is estimated as median(|signal| / 0.6745), where 0.6745 is the
    75th percentile of the standard normal distribution. An alternative multiplier,
    np.sqrt(2*np.log(len(signal))), is normally used for denoising wavelet
    components [Diedrich2003]; here the multiplier is the caller-supplied C.
    """
    if manual_thres is not None:
        return [manual_thres[0], manual_thres[1]]

    print('NOT A MANUAL THRESHOLD')
    std_noise = np.median(np.abs(signal)/0.6745)
    limits = [std_noise*C, -std_noise*C]
    print('np.median(np.abs(signal)/0.6745): %s' % std_noise)
    return limits

def get_waveforms(signal2analyse, fs, channel, spikes_idx, spike_window, min_thr, half_width):
    """
    Method to filter detected peaks based on their amplitude and width (to discard peaks with shapes different to AP),
    and extract the window around the identified neural spikes.

    Parameters
    ------------
    signal2analyse: [dataframe or dict-like] signals indexed by channel
    fs:             sampling frequency, used to convert half_width from seconds to samples
    channel:        channel to extract waveforms from (key into signal2analyse)
    spikes_idx:     positional indexes of neural peaks (locations) before filtering with conditions
    spike_window:   [array 1x2] Length of window in samples: first position are samples before identified spike, second position after
    min_thr:        [array 1x2] exclusive lower and upper bound on the absolute amplitude at the peak
    half_width:     [array 1x2] exclusive lower and upper bound, in seconds, on the distance between
                    the window maximum and its closest zero crossing

    Returns
    ------------
    wdws:       [array, rows: number of neural peaks, columns: window samples around each peak];
                an empty list when no peak is kept
    spikes_idx: indexes of the peaks that were kept. When a list is passed in, that same list is
                filtered in place (as in earlier versions) and returned

    Notes
    ------------
    Peaks are addressed by position, so the channel is converted to a numpy array once.
    Peaks whose window would start before the first sample or end after the last one
    are discarded instead of wrapping around or raising.
    """
    print("Getting waveforms")
    trace = np.asarray(signal2analyse[channel])
    n_samples = len(trace)
    pre, post = spike_window[0], spike_window[1]
    min_width = int(fs*half_width[0])
    max_width = int(fs*half_width[1])

    waveforms = []
    kept = []
    nowidth = 0
    noamp = 0
    noedge = 0

    for p in spikes_idx:
        # First filter peaks based on amplitude
        amplitude = np.abs(trace[p])
        if not (amplitude > min_thr[0] and amplitude < min_thr[1]):
            noamp += 1
            continue

        # The whole window has to lie inside the recording
        start, stop = p - pre, p + post
        if start < 0 or stop > n_samples:
            noedge += 1
            continue
        window = trace[start:stop]

        # Second filter based on width: distance (in samples) from the window
        # maximum to the closest zero crossing
        cross_zero = np.where(np.diff(np.sign(window)))[0] + 1
        if cross_zero.size == 0:
            print('get_waveforms(): no zero-crossing within window ')
            continue
        min2max = np.min(np.abs(np.argmax(window) - cross_zero))
        if min2max > min_width and min2max < max_width:
            waveforms.append(window)
            kept.append(p)
        else:
            nowidth += 1

    if isinstance(spikes_idx, list):
        # Keep the historical side effect: the caller's list ends up filtered
        spikes_idx[:] = kept
    else:
        spikes_idx = kept

    wdws = np.stack(waveforms) if waveforms else []
    print('Peaks excluded - not appropriate width: %s' % nowidth)
    print('Peaks excluded - not appropriate amplitude: %s' % noamp)
    print('Peaks excluded - window outside signal: %s' % noedge)
    print('Len filtered peaks: %s' % len(spikes_idx))
    print('Len wdsw peaks: %s' % len(wdws))
    return wdws, spikes_idx


def get_spike_location(signal, spikes_idx):
    """
    Build a binary vector marking the samples where spikes were identified.

    Parameters:
    -------------
        signal: input signal being analysed (signal2analyse[ch]); only its length is used
        spikes_idx: 1D collection (list, set or array) with the positions where neural spikes were identified
    Return:
    -------------
        spikes_vector_loc: numpy array with same size as signal with 1 where there's a spike and 0 otherwise.
                           An empty list is returned when there are no spikes (kept for backward compatibility)
    """
    # len() works for lists, sets and numpy arrays alike; plain truthiness raises on arrays
    if len(spikes_idx) == 0:
        return []

    spikes_vector_loc = np.zeros(len(signal))
    spikes_vector_loc[np.array(list(spikes_idx))] = 1
    return spikes_vector_loc


def filter(signal2filt, filter_ch, fs, filtername, **kargs):
    """
    Method to apply filtering to recordings (ENG)
    Note that despite the whole dataframe is passed, the algorithm only applies to the selected channels (filter_ch)

    The function name shadows the builtin `filter`; it is kept because the rest of the
    notebook calls it by this name.

    Parameters
    ------------
    signal2filt: [dataframe] signals to filter (columns in dataframe structure)
    filter_ch:   [list of string] names of the columns to filter; other columns are passed through unchanged
    fs:          sampling frequency, handed to scipy.signal.butter for the 'butter' filter
    filtername:  [string] name of the filter to apply {'None', 'butter', 'fir', 'notch'}
    kargs:       [dict] specific parameters for the filters
                 ('butter': arguments of scipy.signal.butter; 'fir' / 'notch': arguments of
                 FIR_smooth / iir_notch, which are expected to be defined elsewhere)

    Returns
    ------------
    filtered: [dataframe] the filtered recording. For 'None' the input dataframe itself is returned

    Raises
    ------------
    ValueError: if filtername is not one of the supported filters
    """
    if filtername == 'None':
        print('No filter applied!')
        return signal2filt

    if filtername == 'butter':
        # Configure butterworth filter (second-order sections)
        kargs['fs'] = fs
        sos = signal.butter(**kargs, output='sos')

        def apply_to_channel(x):
            return signal.sosfilt(sos, x)

    elif filtername == 'fir':
        print(filter_ch)

        def apply_to_channel(x):
            return FIR_smooth(x, **kargs)

    elif filtername == 'notch':

        def apply_to_channel(x):
            return iir_notch(x, **kargs)

    else:
        raise ValueError("filter(): unknown filter name %r" % (filtername,))

    # Every branch returns its result; previously only 'notch' did
    return signal2filt.apply(lambda x: apply_to_channel(x) if x.name in filter_ch else x)
    
def plot_freq_content(signal2plot, ch, fs, nperseg=512, max_freq=10000, ylim=None, dtformat='%M:%S.%f', figsize=(10, 15), savefigpath='', show=False):
    """
    Plot the time course, the Welch power spectral density and the spectrogram of one channel.

    Parameters
    ------------
    signal2plot: [dataframe] recording; the column 'ch_<ch>' is plotted against the dataframe index
    ch:          channel number used to build the column name 'ch_<ch>'
    fs:          sampling frequency
    nperseg:     segment length used for both the Welch PSD and the spectrogram (NFFT)
    max_freq:    upper limit of the frequency axis of the spectrogram
    ylim:        [None or pair] y-limits of the time-course panel
    dtformat:    date format applied to the x-axis of the time-course panel
                 (presumably the index is datetime-like; verify against the loaded data)
    figsize:     figure size
    savefigpath: when not empty, the figure is saved to this path
    show:        show the figure when truthy, otherwise close it

    Spectrogram parameters (matplotlib specgram):
    NFFT : int
        The number of data points used in each block for the FFT. A power 2 is most efficient. The default value is 256.
        A longer FFT gives more frequency resolution.
    mode : {'default', 'psd', 'magnitude', 'angle', 'phase'}
        What sort of spectrum to use. Default is 'psd', which takes the power spectral density.
        'magnitude' returns the magnitude spectrum. 'angle' returns the phase spectrum without unwrapping.
        'phase' returns the phase spectrum with unwrapping.
    scale : {'default', 'linear', 'dB'}
        The scaling of the values in the spec. 'linear' is no scaling. 'dB' returns the values in dB scale. When mode is 'psd',
        this is dB power (10 * log10). Otherwise this is dB amplitude (20 * log10). 'default' is 'dB' if mode is 'psd' or 'magnitude'
        and 'linear' otherwise. This must be 'linear' if mode is 'angle' or 'phase'.
    """
    trace = signal2plot['ch_%s' % ch]

    # Raw signal
    fig, ax = plt.subplots(3, 1, figsize=figsize)
    ax[0].plot(signal2plot.index, trace, linewidth=0.5, zorder=0)
    ax[0].set_title('Sampling Frequency: {}Hz'.format(fs))
    ax[0].set_xlabel('Time [s]')
    ax[0].set_ylabel('Voltage [uV]')
    if ylim is not None:
        ax[0].set_ylim(ylim)

    # PSD (whole dataset frequency distribution); honours the nperseg argument
    f_data, Pxx_den_data = signal.welch(trace, fs, nperseg=nperseg)
    ax[1].semilogx(f_data, Pxx_den_data)
    ax[1].set_xlabel('Frequency [Hz]')
    ax[1].set_ylabel('PSD [V**2/Hz]')

    # Spectrogram (frequency content vs time), drawn explicitly on the third panel.
    # specgram plots 10*np.log10(Pxx) instead of Pxx
    _, _, _, image_axis = ax[2].specgram(trace, NFFT=nperseg, Fs=fs, mode='psd', scale='dB')
    ax[2].set_ylabel('Spectogram \n Frequenct [Hz]')
    ax[2].set_xlabel('Time [s]')
    ax[2].set_ylim([0, max_freq])
    clb = fig.colorbar(image_axis, ax=ax[2])
    clb.ax.set_title('10*np.log10 \n [dB/Hz]')

    # Format axes
    for panel in ax:
        # Hide the right and top spines
        panel.spines['right'].set_visible(False)
        panel.spines['top'].set_visible(False)
        # Only show ticks on the left and bottom spines
        panel.yaxis.set_ticks_position('left')
        panel.xaxis.set_ticks_position('bottom')
    ax[0].xaxis.set_major_formatter(md.DateFormatter(dtformat))

    if savefigpath != '':
        fig.savefig(savefigpath, facecolor='w')

    if show:
        plt.show()
    else:
        print('Plot will not show')
        plt.close(fig)
         
def select_channels(channels):
    """
    Method to select which channels to analyse

    Parameters
    ------------
    channels:     [list of numbers or numeric strings] list of intan channels to be analysed.
                  Every entry must be convertible with int(); an 'all' shortcut is not implemented

    Return
    --------
        filter_ch:    [list of string] list with the selected intan channels in string mode (starting in 'ch_'),
                      one entry per requested channel (empty list for empty input)
    """
    # Build the complete list; the previous loop returned after the first channel
    return ['ch_%s' % int(ch) for ch in channels]

class MyWaveforms:
    """Container for the waveforms extracted from one channel, with helpers to summarise and plot them."""

    def __init__(self, waveforms, recording, fs, spikes_vector_loc, num_clusters, path):
        """
        waveforms:         [array, rows: neural peaks, columns: samples of the window around each peak]
        recording:         signal the waveforms were extracted from (stored, not modified)
        fs:                sampling frequency
        spikes_vector_loc: numpy array with same size as analysed signal with 1 where there's a spike and 0 otherwise
        num_clusters:      number of clusters (stored only)
        path:              output path (stored only)
        """
        self.waveforms = waveforms
        self.recording = recording
        self.fs = fs
        self.num_clusters = num_clusters
        self.spikes_vector_loc = spikes_vector_loc
        self.path = path
        # Mean waveform; filled in by plot_waveforms_multipleCh_mean()
        self.meanWave = None

    def noise_sd(self, signal, ch, noise_samples=(0, 1000)):
        """
        Standard deviation of the samples in a baseline (noise) segment.

        signal:        object supporting .iloc slicing (e.g. a pandas Series) with the trace of one channel
        ch:            channel identifier (not used in the computation)
        noise_samples: [pair] first and last sample (exclusive) of the baseline segment

        Returns the value computed by np.std over that segment (also printed).
        """
        first, last = int(noise_samples[0]), int(noise_samples[1])
        noise_sd = np.std(signal.iloc[first:last], 0)  # SD of noise
        print('sd noise: %s' % noise_sd)
        return noise_sd

    def plot_waveforms_multipleCh_mean(self, axes, ch, spike_window, ylim=(), dtformat='%S.%f'):
        """
        Plot the mean waveform (dashed line) and a +/- 1 SD band on the given axes.

        axes:         matplotlib axes to draw on
        ch:           channel identifier, used for the legend label
        spike_window: window definition (currently not used for the time axis)
        ylim:         y-limits; left untouched when empty
        dtformat:     kept for interface compatibility (not used)

        Side effect: stores the mean waveform in self.meanWave.
        """
        print('Plotting waveforms')

        # Time axis in milliseconds, starting at 0 at the first sample of the window
        n_samples = self.waveforms.shape[1]
        time_ms = np.linspace(0, n_samples / self.fs, n_samples) * 1000

        self.meanWave = np.mean(self.waveforms, 0)
        std_waveforms = np.std(self.waveforms, 0)

        axes.plot(time_ms, self.meanWave, '--', )
        axes.fill_between(time_ms,
                          self.meanWave - std_waveforms,
                          self.meanWave + std_waveforms,
                          alpha=0.15, label='Channel: %s' % ch)
        axes.legend()
        if len(ylim) != 0:
            axes.set_ylim(ylim)

        # Hide the right and top spines
        axes.spines['right'].set_visible(False)
        axes.spines['top'].set_visible(False)

        # Only show ticks on the left and bottom spines
        axes.yaxis.set_ticks_position('left')
        axes.xaxis.set_ticks_position('bottom')


## Load data

In [None]:
# Dataset folder and working path for this session
dir_name, path = '../datasets/', '../xx'

# Flag for the first pass over a recording
run_first_time = True


In [None]:
# Load data from two ports
chC = 23
chD = 24

time_start = time.time()

# Location of the pickled data (saved after loading with INTAN script).
# NOTE(review): `path` and `file` are placeholders and must be filled in before running.
path = ''
file = ''
fs = 30000

channels = [0, 1] #0: port C , 1:portD
plot_ch = channels[0] #Illustrative channel to plot
plt_ch = plot_ch  # original name of the same setting, kept for compatibility

if run_first_time:
    # First run: load the ports independently, merge them and save the merged dataframe.
    # SECURITY: pickle.load can execute arbitrary code; only open pickles you created yourself.
    # NOTE(review): both ports read the same '<path>/<file>.pkl' placeholder; point each
    # one at its own file.
    ################### Load port 1 ##################################
    port = 'Port C'
    with open('%s/%s.pkl'%(path, file), 'rb') as f:
        recording_C = pickle.load(f)

    ################### Load port 2 ##################################
    port = 'Port D'
    with open('%s/%s.pkl'%(path, file), 'rb') as f:
        recording_D = pickle.load(f)

    # Create directory to save figures
    os.makedirs('%s/figures/' %(path), exist_ok=True)

    # Merge in single dataframe
    recording = pd.concat([recording_C['seconds'], recording_C['ch_0'], recording_D['ch_1']], axis=1, keys=['seconds','ch_0', 'ch_1'])
    recording.name = 'original'

    print('Saving data into: %s/recording.pkl' %path)
    recording.to_pickle(r'%s/recording.pkl' %path)
else:
    # Later runs: load the merged dataframe saved by the first run
    recording = pd.read_pickle(r'%s/recording.pkl' %path)
    recording.name = 'original'


## Configuration

###  General configuration  


#### Options list

In [None]:
# General options: which filter and which detection method the pipeline uses
config_text = []
apply_filter, detect_method = 'butter', 'get_spikes_method'

# Summary of the selected configuration
print('SELECTED GENERAL CONFIGURATION:')
print('\n'.join((
    'Filter: %s'%apply_filter,
    'Detection: %s'%detect_method,
    'Channels: %s' %channels,
)))
print('-------------------------------------')

# Column names ('ch_<n>') of the channels selected for filtering/analysis
filter_ch = select_channels(channels)
print('filter_ch %s' %filter_ch)

### Specific configuration  

##### Spike detection config

In [None]:
spike_detection_config = dict(
    general=dict(
        cardiac=False,
        # Window (in samples) cut around each neural spike:
        # 2 ms before and 50 ms after the detected peak
        spike_window=[int(0.002 * fs), int(0.05 * fs)],
        # Min & max absolute amplitude of the spikes that are kept
        min_thr=[0, 100],
        # Min & max length in sec from zero cross to max of waveform
        half_width=[0.01/1000, 10],
        # Multiplier of the noise estimate for the detection threshold
        C=11,
        # Multiplier of the noise estimate for the maximum accepted height
        Cmax=1000,
        # Input to find_peaks() function
        find_peaks_args=dict(
            # Required minimal horizontal distance (>= 1) in samples between
            # neighbouring peaks. Smaller peaks are removed first until the
            # condition is fulfilled for all remaining peaks.
            distance=int(0.0018 * 2 * fs),
            # Initial value; spike_detection() replaces it with a (min, max) pair
            height=1000,
        ),
    ),
)



## START ANALYSIS                                             


### Filtering


#### Bandwidth filter

In [None]:
# Parameters of every available filter, keyed by filter name
filt_config = {
    # Cut-off frequencies of the band (same units as fs)
    'W': [200, 7000],
    'None': {},
    # N: order of the filter, btype: type of filter ('bandpass', 'hp', 'lowpass', ...)
    'butter': {'N': 9, 'btype': 'bandpass'},
    'fir': {'n': 4},
    'notch': {'quality_factor': 30},
}

# The butterworth filter shares the cut-off list above and needs the sampling frequency
filt_config['butter'].update(Wn=filt_config['W'], fs=fs)



##### Apply filter

In [None]:
# Filter the selected channels of the recording with the configured filter
signal2filter = recording
time_start = time.time()
filtered = filter(signal2filter, filter_ch, fs, apply_filter, **filt_config[apply_filter])

# Report how long the filtering took
print("Time elapsed: {} seconds".format(time.time()-time_start))

##### Plot filtered signal

In [None]:
text_label = 'Filtered'
text = 'Channels after %s filtering'%apply_filter

# Names needed for the figure that were not defined before this cell
plot_ch = channels[0]  # Illustrative channel to plot
# NOTE(review): label derived from the channel mapping "0: port C, 1: port D"; confirm
port = {0: 'Port C', 1: 'Port D'}.get(int(plot_ch), 'Port')
# Time stamp for saving (avoid overwriting)
current_time = datetime.datetime.now().strftime("%d%m%Y_%H%M%S")

# The figure is saved under <path>/figures, which has to exist
os.makedirs('%s/figures' %path, exist_ok=True)

# NOTE(review): max_freq=250 shows only the lowest part of the spectrogram — confirm this is intended
plot_freq_content(filtered,int(plot_ch), fs, nperseg=512, max_freq=250, ylim=[-100, 100], dtformat='%H:%M:%S',
                         figsize=(10, 10), savefigpath='%s/figures/%s_ch%s_%s_filtering-%s.png' %(path,port,plot_ch, apply_filter, current_time),
                         show=True)

#### Signal selection

In [None]:
# Continue the analysis with the filtered signal.
# Comment these two lines out to keep working on the unfiltered recording instead.
recording = filtered
recording.name = 'filtered'

## Signal analysis

In [None]:
# Signal used to extract the neural spikes from: the recording with its sign
# flipped, so that negative peaks are detected as maxima.
# NOTE(review): the negation applies to every column of `recording`, not only the channels.
signal2analyse = -recording

# Remaining analysis settings
verbose = True
dtformat = '%S.%f' #'%H:%M:%S' '%M:%S.%f''
save_waveforms = [] # To store waveform objects for more than one channel


In [None]:
#-------------------------------------------------------------
# Initialize figures
#-------------------------------------------------------------
time_start_analysis=time.time()

# squeeze=False makes subplots() return an array even for a 1x1 grid, so flatten() works
fig2, axes_wv = plt.subplots(1, 1, figsize=(15, 8), sharex=True, squeeze=False)
fig2.suptitle('Waveforms', fontsize=16, family='serif')
axes_wv = axes_wv.flatten()
noise_sd = []
save_waveforms = []
# NOTE(review): the number of clusters is only stored by MyWaveforms and was never
# defined in this notebook; set it here if a later step needs it.
num_clus = None
for n, ch in enumerate(filter_ch):

    #-------------------------------------------------------------
    # Start pipeline of neural peaks identification
    #-------------------------------------------------------------
    spikes_idx, waves, spikes_vector_loc, index_first_edge = \
                            pipeline_peak_extraction(signal2analyse, ch, fs, spike_detection_config,
                                                            verbose=verbose)
    if len(waves) == 0:
        # Nothing to average or plot for this channel
        print('No waveforms extracted for %s - skipping' %ch)
        continue

    # Create waveform object for processing of waveforms
    waveforms = MyWaveforms(waves, signal2analyse, fs, spikes_vector_loc, num_clus, path)

    # Noise level from the baseline period: first 2 seconds of the recording
    # (change the sample range for the baseline period in each recording)
    noise_sd.append(waveforms.noise_sd(signal2analyse[ch],ch, noise_samples=[0*fs,2*fs]))
    waveforms.plot_waveforms_multipleCh_mean(axes_wv[0], ch,
                                             spike_detection_config['general']['spike_window'],
                                             ylim=[], dtformat=dtformat)

    save_waveforms.append(waveforms)

    #Save mean waveform
    waveforms.meanWave = np.mean(waveforms.waveforms,0)

# Show the figure once all channels have been drawn on it
plt.show()

print ('Analysis done! Time elapsed: {} seconds'.format(time.time()-time_start_analysis))


## Plot heatmaps for each comparison

In [None]:
import glob
import seaborn
from matplotlib.colors import LinearSegmentedColormap


# Root folder of the dataset to summarise: select ONE of the recordings below.
# NOTE(review): the first option is enabled so that the cell runs; confirm the dataset.
path_gen = '../datasets/recording/Sciatic-microwires/RQ_RQ5C'
#path_gen = '../datasets/recording/Sciatic-microwires/RQ_RQ20C'

max_time = 50 #ms
fs_waves = 30000  # sampling frequency [Hz] used to convert max_time into samples
n_samples = int(max_time/1000*fs_waves)

fig, axes = plt.subplots(1, 2, figsize=(10, 3), sharex=True)
m=[]
for n in range(2):
    print(n)
    # glob allows to load all files in the folder (those whose name ends in '<n>.npy')
    files = sorted(glob.glob('%s/all_results/'%(path_gen) +'/*%s.npy'%n, recursive=True))
    size = len(files)
    print(size)
    if size == 0:
        print('No files found for group %s - heatmap skipped' %n)
        continue

    # One normalised waveform per file; all rows must have the same length
    rows = []
    for npy_file in files:
        print(npy_file)
        new_wave = np.load(npy_file)[0:n_samples]
        rows.append(new_wave / np.linalg.norm(new_wave))
        # Running maximum over the waveforms loaded so far in this group
        m.append(max(row.max() for row in rows))
    waves = np.vstack(rows)

    # Plot heatmap
    seaborn.heatmap(waves, cmap=sns.diverging_palette(0, 255, s=110, l=55, sep=8, n=256),  ax=axes[n], vmin=-0.15, vmax=0.15) # vmin=-m[0], vmax=m[0])# , s=100, l=55

# Get current time for saving (avoid overwriting)
now = datetime.datetime.now()
current_time = now.strftime("%d%m%Y_%H%M%S")
# Save before showing: some backends release the figure once it has been shown
fig.savefig('%s/all_waveforms_new-%s.svg' %(path_gen, current_time), facecolor='w')
plt.show()