# KMEANS

### 18/05/20

## Importing Libraries

In [2]:
# Import all libraries needed for the tutorial

# General syntax to import specific functions in a library: 
##from (library) import (specific library function)
from pandas import DataFrame, read_csv

# General syntax to import a library but no functions: 
##import (library) as (give the library a nickname/alias)
import matplotlib.pyplot as plt
import pandas as pd #this is how I usually import pandas
import sys #only needed to determine Python version number
import matplotlib #only needed to determine Matplotlib version number

# Enable inline plotting
%matplotlib inline

import scipy
import numpy as np
#import scipy.signal as signal
from scipy.signal import *
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans

import sys  
sys.path.insert(0, '/Users/louiseplacidet/Desktop/PIR/GITPIR/GIT_29_04/PIR/AdabandFlt')
from AdaBandFlt import *

%matplotlib tk

## Loading Data

In [3]:
#load data
# file path of csv file
#Location = r'/Users/33614/ExternalDrive/SUPAERO/PIR_2A/Data/data_spikes/E18KABaseline_Bcut.txt'
#Location = r'/Users/SYL21/D_Drive/SUPAERO/PIR_2A/Data/data_spikes/E18KABaseline_Bcut.txt'
#Location = r'/Users/louiseplacidet/Desktop/PIR/Data/data_spikes/E18KABaseline_Bcut.txt'
Location = r'/Users/louiseplacidet/Desktop/PIR/Data/new_spike_data/newdata/E18KABaseline_BcutV2groundAll.txt'

# create dataframe
df = pd.read_csv(Location, sep='\t',skiprows=[0,1,3] , index_col='%t           ')

In [4]:
#####################################################################################################################
####  BANK OF PARTS OF DATA

all_raw_data = df #Entire recording from all electrodes

full_signal = df.iloc[:,1] #Entire recording from the first electrode

def butter_bandpass(lowcut, highcut, fs, order=5):
    nyq = 0.5 * fs
    low = lowcut / nyq
    high = highcut / nyq
    b, a = butter(order, [low, high], btype='band')
    return b, a


def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    b, a = butter_bandpass(lowcut, highcut, fs, order=order)
    y = filtfilt(b, a, data)
    return y

# Sample rate and desired cutoff frequencies (in Hz).
fs = 25000.0
lowcut = 100.0
highcut = 2500.0


y = butter_bandpass_filter(df.iloc[:,1], lowcut, highcut, fs, order=6)


filtereddf = pd.DataFrame(y)
filtereddf.index = df.index

signal_filtered = filtereddf.iloc[:,0] #Entire recording filtered by bandpass, for one electrode


fs = 25000

#xminnoise = int(np.round(11114*(fs/1000)))
#xmaxnoise = int(np.round(18511*(fs/1000)))

#noise_data = filtereddf.iloc[xminnoise:xmaxnoise,0]

#xminspike = int(np.round(130826*(fs/1000)))
#xmaxspike = int(np.round(131699*(fs/1000)))

#burst_data = filtereddf.iloc[xminspike:xmaxspike,0]

#begin_data = signal_filtered.iloc[:500000]

In [5]:
################################################################################################
####   TEST ADABANDFLT AVEC SIGNAL ORIGINAL (PASSE-BANDE+WIENER)

# Choices:
#  - full_signal : entire signal from first electrode
#  - signal_filtered : entire signal from first electrode after bandpass
#  - noise_data : part of signal where only noise (after bandpass)
#  - spike_data : part of signal where burst (after bandpass)


signal = signal_filtered

## Detecting and Aligning the Spikes

### Setting Up Thresholds adapted to Noise Levels

In [6]:
fs = 25000

noise_levels = init_noise_levels(signal_filtered, fs, 
                      noise_window_size = 0.01,
                      required_valid_windows = 100,
                      old_noise_level_propagation = 0.8, 
                      test_level = 5,
                      estimator_type = "RMS",
                      percentile_value = 25)


In [7]:
plt.figure()
plt.plot(noise_levels)
plt.grid(True)
plt.xlabel('Time')
plt.ylabel('Noise Amplitude [µV]')
plt.title('Noise Levels')

Text(0.5, 1.0, 'Noise Levels')

### Detecting Spikes

In [8]:
spike_info = find_spikes(signal, noise_levels, fs, 
                           window_size = 0.0005, 
                           noise_window_size = 0.01,
                           threshold_factor = 3.5,
                           maxseparation = 0.0015)

#spike_info.columns = ['indice_max','indice_min','indice_depass_positif','indice_depass_negatif', 'indice_1er_depass','indice_zero_central','i_max-i_min','Delta_amplitudes']
spike_info

Unnamed: 0,indice_max,indice_min,indice_depass_positif,indice_depass_negatif,indice_1er_depass,indice_zero_central,i_max-i_min,Delta_amplitudes
0,53278,53286,53277,53285,53277,53281,-8,16.236913
1,53391,53375,53389,53373,53373,53386,16,18.401228
2,53492,53504,53489,53499,53489,53497,-12,31.926263
3,53517,53504,53514,53499,53499,53511,13,30.926425
4,53645,53652,53643,53652,53643,53649,-7,17.141957
...,...,...,...,...,...,...,...,...
538,477905,477919,477897,477916,477897,477912,-14,29.808358
539,477952,477961,477947,477958,477947,477956,-9,33.708894
540,478147,478158,478142,478155,478142,478152,-11,32.536774
541,478169,478158,478165,478155,478155,478163,11,27.879169


### Recording the spikes

In [9]:
spike_data = record_spikes(signal, fs, spike_info,
                           'indice_zero_central',
                           t_before = 0.001, 
                           t_after = 0.002)


(544, 76)


In [10]:
spike_data_oneline = record_spikes_oneline(signal, fs, spike_info,
                                           'indice_zero_central',
                                           t_before = 0.001,
                                           t_after = 0.002)

### Plotting the spikes

In [11]:
#signal.plot(color = 'blue')
#spike_data_oneline.plot(color = 'red')
plt.plot(df.index, signal, color = 'blue')
plt.plot(spike_data_oneline.index, spike_data_oneline, color = 'red')
plt.title('Filtered Signal with Detected Spikes with RMS')
plt.xlabel('Time Windows')
plt.ylabel('Amplitude [µV]')
plt.legend()
plt.grid(True)

No handles with labels found to put in legend.


In [12]:
print_spikes(spike_data,
             t_before_alignement = 0.001,
             first_spike = 1,
             last_spike = 50,
             fs = 25000)

  func(*args)
  return self.func(*args)


# Bilan PCA + KMEANS

## PCA and KMEANS on spikes

In [43]:
def PCA_and_KMEANS_spikes(spike_data, spike_info, nb_PCA_components=3, nb_clusters=3):
    
    ## PCA
    pca_data = np.array(spike_data.iloc[:,1:].values).transpose()
    pca = PCA(n_components=nb_PCA_components)
    pca.fit(pca_data)
    PCA_X = pca.transform(pca_data)
    
    ## KMEANS
    kmeans = KMeans(n_clusters=nb_clusters, random_state=0).fit(PCA_X)

    
    ## Ajout du label des clusters dans spike info
    spike_info['cluster_label'] = kmeans.labels_
    
    return PCA_X, kmeans, spike_info
    

## Plotting the PCA

In [47]:
## Fonction qui plot la PCA

def PCA_plot(PCA_X):
    
    fig = plt.figure(4,figsize=(4,3))
    plt.clf()
    ax = Axes3D(fig, rect=[0, 0, .95, 1], elev=48, azim=134)
    plt.cla()
  
    ax.scatter(PCA_X[:, 0], PCA_X[:, 1], PCA_X[:, 2], cmap=plt.cm.nipy_spectral,
           edgecolor='k')

    ax.w_xaxis.set_ticklabels([])
    ax.w_yaxis.set_ticklabels([])
    ax.w_zaxis.set_ticklabels([])

    plt.show()

## Plotting the KMEANS

In [51]:
## Fonction qui plot le KMEANS

def KMEANS_plot(PCA_X,kmeans):
    
    y = kmeans.labels_
    
    fig = plt.figure(1, figsize=(4, 3))
    plt.clf()
    ax = Axes3D(fig, rect=[0, 0, .95, 1], elev=48, azim=134)

    ax.scatter(PCA_X[:, 0], PCA_X[:, 1], PCA_X[:, 2], c=y, cmap=plt.cm.nipy_spectral,
           edgecolor='k')

    ax.w_xaxis.set_ticklabels([])
    ax.w_yaxis.set_ticklabels([])
    ax.w_zaxis.set_ticklabels([])

    plt.show()

## Plotting the spikes from the different clusters after KMEANS

In [57]:
## Fonction qui plot les spikes dans les clusters du KMEANS

def print_spikes_clusterized(spike_data,
                             labels,
                             t_before_alignement = 0,
                             nb_spike = 20,
                             y_lim_min = -50,
                             y_lim_max = 60,
                             fs = 25000):
    
    nb_clusters = max(labels) + 1
    
    nb_line = nb_clusters//2
    if nb_clusters%2 != 0:
        nb_line += 1
    
    #spike_data.iloc[:,first_spike:last_spike].plot()
    t_b = int(np.round(fs*(t_before_alignement)))
    y = (spike_data.iloc[:,0]-t_b)*1000/fs
        
    a = [i+1 for i in range(len(labels))]
    b = np.transpose([a,list(labels)])
    
    figure = plt.figure()
    for nb in range(nb_clusters):
        data = spike_data.iloc[:,[x for x,y in b if y==nb]]
        m = len(data.values[0])
        
        kept = []
        i = 0
        while i < nb_spike:
            r = randint(0,m-1)
            if (r in kept) == False:
                kept.append(r)
                i += 1
                
        x = data.iloc[:,kept].values
        
        #plt.subplot(nb_line, 2, nb+1)
        axes = figure.add_subplot(nb_line, 2, nb+1)
        axes.plot(y, x)
        axes.set_xlabel('Time in ms')
        axes.set_ylim(y_lim_min , y_lim_max)
        axes.grid()
        
"""
print_spikes_clusterized(spike_data,
                             kmeans.labels_,
                             t_before_alignement = 0.0015,
                             nb_spike = 20,
                             y_lim_min = -50,
                             y_lim_max = 60,
                             fs = 25000)
"""

'\nprint_spikes_clusterized(spike_data,\n                             kmeans.labels_,\n                             t_before_alignement = 0.0015,\n                             nb_spike = 20,\n                             y_lim_min = -50,\n                             y_lim_max = 60,\n                             fs = 25000)\n'

## Printing the spikes from clusters on original signal

In [None]:
## Fonction qui affiche les spikes des différents clusters sur échelle temporelle

def print_spikes_clusterized_oneline(signal, spike_info_avec_clusters):

## Histogram of Amplitudes and Time

In [66]:
## Histogramme de différences d'amplitude Peak-to-Peak des spikes

def histogram_spikes_amplitude(spike_info):
    plt.hist(spike_info['Delta_amplitudes'], bins='auto')
    plt.title("Histogram of Spike Peak-to-Peak Amplitude")
    plt.show()

In [68]:
## Histogramme de l'écartement temporel des spikes (imax-imin)

def histogram_spikes_time(spike_info):
    plt.hist(spike_info['i_max-i_min'], bins='auto')
    plt.title("Histogram of Spike i_max-i_min")
    plt.show()

# Tests des fonctions

In [45]:
# Function: PCA_and_KMEANS_spikes(spike_data, spike_info, nb_PCA_components=3, nb_clusters=3)

PCA_X, KMEANS, updated_spike_info = PCA_and_KMEANS_spikes(spike_data, spike_info, 5, 3)

In [48]:
# Function: PCA_plot(spike_data, nb_clusters=3)

PCA_plot(PCA_X)

In [53]:
# Function: KMEANS_plot(spike_data,updated_spike_info)

KMEANS_plot(PCA_X,KMEANS)

In [58]:
# Function: print_spikes_clusterized(spike_data,
#                             labels,
#                             t_before_alignement = 0,
#                             nb_spike = 20,
#                             y_lim_min = -50,
#                             y_lim_max = 60,
#                             fs = 25000)
                        
print_spikes_clusterized(spike_data,
                             KMEANS.labels_,
                             t_before_alignement = 0,
                             nb_spike = 20,
                             y_lim_min = -50,
                             y_lim_max = 60,
                             fs = 25000)                        

In [67]:
# Function: histogram_spikes_amplitude(spike_info)

histogram_spikes(spike_info)

In [None]:
# Function: 