## Notes

-------------- 
### Comments and observations  
------------------------------
Created by Amparo Guemes (1 April 2023)
Code for processing local field potentials (LFP) recorded with PEDOT-coated tungsten microwires


## Libraries

In [None]:
import IPython
# IPython.Application.instance().kernel.do_shutdown(True)


%matplotlib widget

import os
import sys
import json
import time
import datetime
import pycwt
import statistics
import random
import pickle
import numpy as np
import scipy as sp
import pandas as pd
import seaborn as sns
import sklearn as sk
import tkinter as tk
import matplotlib.pyplot as plt
from sklearn import metrics
from sklearn import decomposition
from sklearn.decomposition import PCA
from tkinter import *
from tkinter import ttk
from sklearn import preprocessing
from datetime import date
import matplotlib.dates as mdates

from neurodsp.rhythm import sliding_window_matching
from neurodsp.utils.download import load_ndsp_data
from neurodsp.plts.rhythm import plot_swm_pattern
from neurodsp.plts.time_series import plot_time_series
from neurodsp.utils import set_random_seed, create_times
# Import listed colormap
from matplotlib.colors import ListedColormap
import matplotlib.dates as md
from matplotlib import colors as mcolors
# Scipy
from scipy import signal
from scipy import ndimage
# TKinter for selecting files
from tkinter import Tk     # from tkinter import Tk for Python 3.x
from tkinter.filedialog import askdirectory
from scipy.stats import zscore


# Add my module to python path
sys.path.append("../")


In [None]:
# Functions
def filter(signal2filt,filter_ch, fs, filtername, **kargs):
    """
    Apply the selected filter to the chosen channels of a recording.

    The whole dataframe is passed in, but only the columns whose name is listed in
    filter_ch are filtered; every other column is returned untouched.

    Parameters
    ------------
    signal2filt: [dataframe] signals to filter (one signal per column)
    filter_ch:   [list of string] names of the columns the filter is applied to
    fs:          [number] sampling frequency; forwarded to the butterworth design only
    filtername:  [string] name of the filter to apply {'None', 'butter', 'fir', 'notch'}
    kargs:       [dict] filter-specific parameters, forwarded as keyword arguments to
                 signal.butter ('butter'), FIR_smooth ('fir') or iir_notch ('notch')

    Returns
    ------------
    filtered: [dataframe] result of the filtering; for 'None' the input object itself

    Raises
    ------------
    ValueError: if filtername is not one of the supported names
    """
    if filtername == 'None':
        print('No filter applied!')
        return signal2filt

    if filtername == 'butter':
        # The explicit fs argument overrides any 'fs' entry found in kargs
        kargs['fs'] = fs
        sos = signal.butter(**kargs, output='sos')

        def transform(column):
            return signal.sosfilt(sos, column)

    elif filtername == 'fir':
        print(filter_ch)

        # NOTE(review): FIR_smooth is not defined in this file; it must be in scope at call time
        def transform(column):
            return FIR_smooth(column, **kargs)

    elif filtername == 'notch':

        # NOTE(review): iir_notch is not defined in this file; it must be in scope at call time
        def transform(column):
            return iir_notch(column, **kargs)

    else:
        raise ValueError(
            "Unknown filter '%s'; expected one of 'None', 'butter', 'fir', 'notch'" % filtername)

    # Every branch returns its result (previously only 'notch' did)
    return signal2filt.apply(lambda x: transform(x) if x.name in filter_ch else x)
    
def plot_freq_content(signal2plot, ch,fs, nperseg=512, max_freq=10000, ylim=None, dtformat='%M:%S.%f', figsize=(10, 15), savefigpath='', show=False):
    """
    Plot the raw trace, the Welch PSD and the spectrogram of one channel in a 3-row figure.

    Parameters
    ------------
    signal2plot: [dataframe] recording; the trace is read from column 'ch_<ch>' and the
                 dataframe index is used as x axis of the raw plot
    ch:          [int or string] channel identifier appended to 'ch_' to build the column name
    fs:          [number] sampling frequency in Hz
    nperseg:     [int] segment length used for both the Welch PSD and the spectrogram (NFFT)
    max_freq:    [number] upper limit of the spectrogram frequency axis
    ylim:        [list or None] y limits of the raw plot; None keeps the automatic limits
    dtformat:    [string] date format of the x ticks of the raw plot
    figsize:     [tuple] figure size passed to plt.subplots
    savefigpath: [string] where to save the figure; '' disables saving
    show:        [bool] show the figure when true, otherwise close it

    Notes on the specgram parameters
    ------------
    NFFT : int
        The number of data points used in each block for the FFT. A power of 2 is most
        efficient. A longer FFT gives more frequency resolution.
    mode : {'default', 'psd', 'magnitude', 'angle', 'phase'}
        What sort of spectrum to use. 'psd' (used here) takes the power spectral density.
    scale : {'default', 'linear', 'dB'}
        The scaling of the values. With mode 'psd', 'dB' (used here) is dB power (10 * log10).
    """
    trace = signal2plot['ch_%s'%ch]
    fig, ax = plt.subplots(3, 1, figsize=figsize)
    raw_ax, psd_ax, spec_ax = ax

    # Raw signal
    raw_ax.plot(signal2plot.index, trace, linewidth=0.5, zorder=0)
    raw_ax.set_title('Sampling Frequency: {}Hz'.format(fs))
    raw_ax.set_xlabel('Time [s]')
    raw_ax.set_ylabel('Voltage [uV]')
    if ylim is not None:
        raw_ax.set_ylim(ylim)

    # PSD (frequency distribution of the whole dataset); honours the nperseg argument
    f_data, Pxx_den_data = signal.welch(trace, fs, nperseg=nperseg)
    psd_ax.semilogx(f_data, Pxx_den_data)
    psd_ax.set_xlabel('Frequency [Hz]')
    psd_ax.set_ylabel('PSD [V**2/Hz]')

    # Spectrogram (frequency content vs time); specgram plots 10*np.log10(Pxx) instead of Pxx.
    # Drawn on the third axes explicitly instead of relying on pyplot's current-axes state.
    _, _, _, spec_image = spec_ax.specgram(trace, NFFT=nperseg, Fs=fs, mode='psd', scale='dB')
    spec_ax.set_ylabel('Spectogram \n Frequenct [Hz]')
    spec_ax.set_xlabel('Time [s]')
    spec_ax.set_ylim([0, max_freq])
    clb = fig.colorbar(spec_image, ax=spec_ax)
    clb.ax.set_title('10*np.log10 \n [dB/Hz]')

    # Format axes
    for axis in ax:
        # Hide the right and top spines
        axis.spines['right'].set_visible(False)
        axis.spines['top'].set_visible(False)
        # Only show ticks on the left and bottom spines
        axis.yaxis.set_ticks_position('left')
        axis.xaxis.set_ticks_position('bottom')
    raw_ax.xaxis.set_major_formatter(md.DateFormatter(dtformat))

    if savefigpath != '':
        fig.savefig(savefigpath, facecolor='w')

    if show:
        plt.show()
    else:
        print('Plot will not show')
        plt.close(fig)
         
def select_channels(channels):
    """
    Build the dataframe column names of the channels to analyse.

    Parameters
    ------------
    channels:     [list of numbers] intan channels to be analysed; each entry must be
                  convertible with int(). The 'all' shorthand is not implemented here.

    Return
    --------
        filter_ch:    [list of string] one 'ch_<number>' name per requested channel, in the
                      same order (empty list for empty input)
    """
    # Build the complete list before returning (the early return kept only the first channel)
    return ['ch_%s' % int(ch) for ch in channels]

## Load data

In [None]:
# Load data from pkl (saved after loading with INTAN script)
recording = pd.DataFrame()
path = ''      # folder that contains the pickle (figures are saved under '<path>/figures')
file = ''      # pickle file name without the '.pkl' extension
port = ''      # NOTE(review): label inserted in the figure file names; it was never assigned -- set as needed
fs = 30000     # sampling frequency used by the filtering and plotting cells

# NOTE: pickle.load can execute arbitrary code; only open files from a trusted source
with open('%s/%s.pkl'%(path, file), 'rb') as f:
    recording = pickle.load(f)


# Select channel position/number in intan
channels = [1,27] # device 1: [1,27] 5C vs W # device2: [6, 24] 20C vs W
plot_ch = channels[0] # Illustrative channel to plot (name used by the plotting cells)
plt_ch = plot_ch      # previous name, kept so existing references still work

In [None]:
# Timestamp (day-month-year_hour-minute-second) appended to saved file names to avoid overwriting
now = datetime.datetime.now()
current_time = f"{now:%d%m%Y_%H%M%S}"

## Configuration

###  General configuration  


#### Options list

In [None]:
# Filters offered by the general configuration; "fir" is the Binomial Weighted Average Filter
options_filter = ["None", "butter", "fir"]


In [None]:
# General configuration: pick the filter and report the selection
config_text = []
apply_filter = options_filter[1]

# Summary of the selection, one item per line
print(
    'SELECTED GENERAL CONFIGURATION:',
    'Filter: %s'%apply_filter,
    'Channels: %s' %channels,
    '-------------------------------------',
    sep='\n',
)

# Column names of the selected channels
filter_ch = select_channels(channels)
print('filter_ch %s' %filter_ch)


## START ANALYSIS                                             


In [None]:
# Remove all columns except the selected channels and the 'seconds' column.
# filter_ch is a list, so the extra label is appended rather than used as an index.
columns_to_keep = list(filter_ch) + ['seconds']
recording.drop(columns=recording.columns.difference(columns_to_keep), inplace=True)
print(recording)

### Plot raw signal

In [None]:
# Raw trace, PSD and spectrogram of the illustrative channel (expects plot_ch and port from an earlier cell)
plot_freq_content(
    signal2plot=recording,
    ch=int(plot_ch),
    fs=fs,
    nperseg=512,
    max_freq=4000,
    ylim=[-500, 500],
    dtformat='%H:%M:%S',
    figsize=(10, 10),
    savefigpath='%s/figures/%s_ch%s_original-%s.png' %(path, port, plot_ch, current_time),
    show=True,
)

#### Bandwidth filter

In [None]:
# Filter settings keyed by filter name; 'W' holds the cut-off list shared with the butterworth design
filt_config = {
    'W': [200],
    'None': {},
    # N: order of the filter; btype: type of filter (alternatives: 'bandpass', 'hp')
    'butter': {'N': 9, 'btype': 'lowpass'},
    'fir': {'n': 4},
    'notch': {'quality_factor': 30},
}

# The butterworth entry reuses the shared cut-off list and the sampling frequency
filt_config['butter'].update(Wn=filt_config['W'], fs=fs)


##### Apply filter

In [None]:
# Apply the selected filter and report how long it took
time_start = time.time()
signal2filter = recording
# fs is passed positionally, so an 'fs' entry in the config must not be forwarded again
# (it would raise "got multiple values for argument 'fs'")
filter_kwargs = {key: value for key, value in filt_config[apply_filter].items() if key != 'fs'}
filtered = filter(signal2filter, filter_ch, fs, apply_filter, **filter_kwargs)

print("Time elapsed: {} seconds".format(time.time()-time_start))

##### Plot filtered signal

In [None]:
# Labels describing the signal shown in this cell
text_label = 'Filtered'
text = 'Channels after %s filtering'%apply_filter

# Same three-panel plot as for the raw signal, limited to the low-frequency band
plot_freq_content(
    signal2plot=filtered,
    ch=int(plot_ch),
    fs=fs,
    nperseg=512,
    max_freq=250,
    ylim=[-100, 100],
    dtformat='%H:%M:%S',
    figsize=(10, 10),
    savefigpath='%s/figures/%s_ch%s_butter_filtering-%s.png' %(path,port,plot_ch, current_time),
    show=True,
)


#### Notch filtering

In [None]:
time_start = time.time()
freq_notch =  [50, 150]
# Chain the notches: each pass filters the output of the previous one, so every listed
# frequency is removed (restarting from `filtered` kept only the last one)
filtered_notch = filtered
for n in freq_notch:
    filt_config['notch']['notch_freq'] = n
    # filter() needs the channel list and the sampling frequency before the filter name
    # NOTE(review): fs is not forwarded to iir_notch by filter(); confirm iir_notch does not need it
    filtered_notch = filter(filtered_notch, filter_ch, fs, 'notch', **filt_config['notch'])
print("Time elapsed: {} seconds".format(time.time()-time_start))

# From here on `recording` refers to the fully filtered signal
recording=filtered_notch
recording.name = 'filtered'


#### Plot filtered signal

In [None]:
text_label = 'Filtered'
text = 'Channels after %s filtering'%'notch'

# plot_freq_content takes (signal2plot, ch, fs, ...); the extra positional arguments that were
# passed here ended up in nperseg/max_freq and clashed with the keyword values
plot_freq_content(filtered_notch, int(plot_ch), fs, nperseg=512, max_freq=200, ylim=[-400, 400], dtformat='%H:%M:%S',
                         figsize=(10, 10), savefigpath='%s/figures/%s_ch%s_allfilt-%s.png' %(path, port,plot_ch, current_time), show=True) 


## Create datasets of metrics from rest and active activity

In [None]:
# Select device start and end times, in seconds (from filtered plot)
# device 1
ind_st_base    = [2, 16, 78, 97,  115, 151, 204, 212, 247, 284, 297, 312, 362, 387, 394, 420, 429, 440, 462, 495]
ind_end_base   = [6, 24, 82, 100, 119, 154, 206, 219, 250, 291, 301, 316, 365, 390, 396, 426, 435, 447, 473, 499]
ind_st_active  = [7, 25, 83, 100, 119, 154, 206, 219, 250, 291, 301, 316, 365, 390, 396, 426, 435, 447, 473, 499]
ind_end_active = [10,28, 86, 103, 151, 162, 212, 225, 258, 297, 305, 323, 369, 394, 400, 429, 440, 451, 478, 504]
# NOTE(review): the 5th active window (119-151 s) is much longer than the others and ends where
# the next rest window starts -- confirm it is not a typo.

# device 2 (alternative set of windows, disabled)
# ind_st_base    = [135, 198, 218, 240, 286, 310, 319, 347, 363, 382, 411]
# ind_end_base   = [151, 206, 228, 246, 291, 316, 323, 352, 370, 391, 414]
# ind_st_active  = [151, 206, 228, 246, 291, 316, 323, 352, 370, 391, 414]
# ind_end_active = [156, 211, 235, 251, 302, 319, 327, 363, 374, 395, 416]

# Initialise
alllist = range(len(ind_st_base))
samples_per_second = int(fs)   # window limits are in seconds, rows are samples

var_active_0 = []
var_active_1 = []

# Maximum of each active window (printed below; previously never computed)
# NOTE(review): plain .max() is assumed to be the intended "MAX" metric -- confirm
max_active_0 = []
max_active_1 = []

var_base_0 = []
var_base_1 = []

# Standard deviation of each rest window (used below and by the statistics cell; previously left empty)
std_base_0 = []
std_base_1 = []

# Compute metrics for each active period
for start_s, end_s in zip(ind_st_active, ind_end_active):
    df = recording.iloc[start_s*samples_per_second: end_s*samples_per_second]
    var_active_0.append(df[filter_ch[0]].var())
    var_active_1.append(df[filter_ch[1]].var())
    max_active_0.append(df[filter_ch[0]].max())
    max_active_1.append(df[filter_ch[1]].max())

print('--------------')
# Compute metrics for each rest (baseline) period
for start_s, end_s in zip(ind_st_base, ind_end_base):
    df = recording.iloc[start_s*samples_per_second: end_s*samples_per_second]
    var_base_0.append(df[filter_ch[0]].var())
    var_base_1.append(df[filter_ch[1]].var())
    std_base_0.append(df[filter_ch[0]].std())
    std_base_1.append(df[filter_ch[1]].std())

# Print obtained metrics
print('------ VAR active ------')   # label matches the printed variances
print(var_active_0)
print(var_active_1)

print('------ MAX active ------')
print(max_active_0)
print(max_active_1)

print('------ VAR base ---------')
print(var_base_0)
print(var_base_1)

print('------ STD base ---------')
print(std_base_0)
print(std_base_1)

# SNR in dB: variance of the active window over variance of the matching rest window
print('-------SNR base--------')
snr0 = [10*np.log10(i / j) for i, j in zip(var_active_0, var_base_0)]
snr1 = [10*np.log10(i / j) for i, j in zip(var_active_1, var_base_1)]
print(snr0)
print(snr1)

print('-------------Mean SNR---')
print('%s +- %s' %(np.mean(snr0), np.std(snr0)))
print('%s +- %s' %(np.mean(snr1), np.std(snr1)))

print('-------------Mean STD rest---')
print('%s +- %s' %(np.mean(std_base_0), np.std(std_base_0)))
print('%s +- %s' %(np.mean(std_base_1), np.std(std_base_1)))


In [None]:
### Plotting comparison variables
import seaborn as sns
sns.set_theme(style="whitegrid")

# One row per (wire, period); the first half of every column holds the *_1 lists, the second half the *_0 lists
# NOTE(review): wire 1 is therefore paired with the *_1 metrics (second selected channel) -- confirm the labelling
data = pd.DataFrame({'wire': np.repeat([1,2], len(ind_st_base)),
                   'response_snr': snr1 +snr0,
                   'response_var_active': var_active_1+ var_active_0,
                   'response_var_base': var_base_1+ var_base_0,})

# The rest-period standard deviation is plotted below but was never added as a column
# (the dict repeated 'response_var_base' instead). Add it only when it has one value per row.
std_rest = std_base_1 + std_base_0
if len(std_rest) == len(data):
    data['response_std_rest'] = std_rest
else:
    print('response_std_rest not plotted: std_base lists do not hold one value per period')

# (column, extra boxplot options, also draw a violin plot on the second axes)
comparison_plots = (
    ('response_snr', dict(boxprops=dict(alpha=.6)), False),
    ('response_var_active', dict(boxprops=dict(alpha=.6)), False),
    ('response_var_base', {}, True),
    ('response_std_rest', {}, True),
)

for column, box_options, with_violin in comparison_plots:
    if column not in data.columns:
        continue
    fig1, [ax1, ax2] = plt.subplots(1,2)
    ax1.set_title('Boxplot')
    sns.boxplot(data=data, x="wire", y=column, ax=ax1, **box_options)
    if with_violin:
        sns.violinplot(data=data, x="wire", y=column, ax=ax2)
    sns.despine(left=True)
    plt.show()

## Calculate statistics and make plots

In [None]:

###############################################################################
#
#                       STATISTICAL TESTS 2 GROUPS
#
###############################################################################

from processing.stat_analysis import jhu_stat_analysis
import permutation_test

def stat_analysis_2groups(g0, g1, alpha=0.05, print_info=True):
    """
    Compare two groups with several statistical tests and print the conclusions.

    Tests computed: Shapiro-Wilk on each group (normality), Levene with center='trimmed'
    (equal variances), independent t-test, paired t-test and Wilcoxon signed-rank test.
    A null hypothesis is rejected when its p-value is below alpha.

    Parameters
    ------------
    g0, g1:     [list of numbers] the two samples; the paired tests compare them element by element
    alpha:      [float] significance level
    print_info: [bool] print the raw p-values; the analysis summary is printed regardless

    Returns
    ------------
    None (the results are only printed)
    """
    # Local import: no cell of this file binds scipy.stats to the name `stats`
    from scipy import stats

    # ------------------------------------------------------
    # Shapiro-Wilk: H0 = the data was drawn from a normal distribution
    # ------------------------------------------------------
    p_s_g0 = stats.shapiro(g0)[1]
    p_s_g1 = stats.shapiro(g1)[1]

    # ------------------------------------------------------
    # Levene test for equal variances (for not normal distributions)
    # ------------------------------------------------------
    _, pl = stats.levene(g0, g1, center='trimmed')
    if print_info:
        print("p-value in Levene's: %s" %pl)

    # ------------------------------------------------------
    # t-test for equal means (requires equal variances)
    # ------------------------------------------------------
    _, p = stats.ttest_ind(g0, g1)
    if print_info:
        print("p-value in t-test: %s" %p)

    # ------------------------------------------------------
    # Paired t-test (requires equal variances)
    # ------------------------------------------------------
    _, p_rel = stats.ttest_rel(g0, g1)
    if print_info:
        print("p-value in paired t-test: %s" %p_rel)

    # ------------------------------------------------------
    # Wilcoxon as non-parametric alternative
    # ------------------------------------------------------
    _, pw = stats.wilcoxon(g0, g1)
    if print_info:
        print("p-value in Wilcoxon's: %s" %pw)

    # ------------------------------------------------------
    # Analysis of results
    # "can't reject" uses >= so that p == alpha is reported too (it used to print nothing);
    # rejections are printed before non-rejections, as before
    # ------------------------------------------------------
    print("------------ Analysis  -------------")
    normality = (("g0", p_s_g0), ("g1", p_s_g1))
    for label, p_norm in normality:
        if p_norm < alpha:
            print("Shapiro-Wilk test on %s: reject H0 --> Doesn't come from normal distribution" % label)
    for label, p_norm in normality:
        if p_norm >= alpha:
            print("Shapiro-Wilk test on %s: can't reject H0 --> Might come from normal distribution" % label)

    if pl < alpha:
        print("Levene's: reject H0 --> There is a difference in variances of groups" )
        print("           Parametric t-test can't be used")
        if pw < alpha:
            print("Wilcoxon: reject H0 --> Configurations are significantly different" )
        elif pw >= alpha:
            print("Wilcoxon: can't reject H0 --> Not significant difference: come from same population" )
    elif pl >= alpha:
        print("Levene's: can't reject H0 --> Not significant difference in variances" )
        print("           Parametric t-test CAN be used" )
        if p < alpha:
            print("T-test: reject H0 --> Mean of configurations is significantly different" )
        if p_rel < alpha:
            print("Paired T-test: reject H0 --> Mean of configurations is significantly different" )

        if p >= alpha:
            print("T-test: can't reject H0 --> Not significant difference in mean" )
        if p_rel >= alpha:
            print("Paired T-test: can't reject H0 --> Not significant difference in mean")

    return None


# Run the two-group analysis on each metric, announcing it with a banner
for metric_title, group_0, group_1 in (
    (' SNR', snr0, snr1),
    (' std_base', std_base_0, std_base_1),
):
    print('----------------------------------------------------------------')
    print(metric_title)
    print('----------------------------------------------------------------')
    stat_analysis_2groups(g0=group_0, g1=group_1, print_info=True)

    # Two blank lines separated the sections; nothing is printed between them