In [5]:
import os 
import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt
from scipy.signal import butter, filtfilt, lfilter

# Folder that holds the recording and the helper scripts.
# NOTE(review): absolute, machine-specific path -- edit it before running elsewhere.
path = 'C:/Users/DJ/Documents/Python_Scripts/EPG/GIT_EPG/in_vivo'

# Switch into that folder, then list its contents as a quick sanity check.
os.chdir(path)
directory_entries = os.listdir()
print(directory_entries)

['.ipynb_checkpoints', 'instant_decode.py', 'sdf.py', 'signal_analysis.py', 'Untitled.ipynb']


## Load the utilities

In [7]:
%run importrhsutilities.py # this has to be downloaded. see the README
%matplotlib inline

In [None]:
# Run the project's own helper script.
# NOTE(review): butter_bandpass_filter, frequency_analysis and find_peaks_arg are
# called further down but are not defined in this notebook -- presumably they come
# from this script or from importrhsutilities.py; confirm.
%run DJutilities.py

## Select the file and load the data

In [None]:
# Recording to analyse (relative file name).
filename = "1.rhs"

# Load the recording. Check here first which channels are the right ones to use.
# `result` is indexed further down with 't', 'amplifier_data' and
# 'frequency_parameters'; `data_present` is not used again in this notebook.
result, data_present = load_file(filename)

# Uncomment the next line to list all channels (stimulator, amplitude, etc.).
# print_all_channel_names(result)

### 1. Show the raw data

In [None]:
# Plot the unprocessed traces stored in rows 9 and 13 of
# result['amplifier_data'], one channel per panel.
channel_1 = 9
channel_2 = 13

fig, (ax1, ax2) = plt.subplots(2, 1)
for axis, channel in ((ax1, channel_1), (ax2, channel_2)):
    axis.plot(result['t'], result['amplifier_data'][channel])
plt.show()

### 2. Show the filtered data

In [None]:
# Channels to analyse (rows of result['amplifier_data']).
channel_1 = 9
channel_2 = 13

# Band-pass cut-off frequencies as [low, high]; pick one pair as the active mode.
spike_detection = [500, 5000]
lfp_detection = [1, 500]
detection_mode = lfp_detection
low_cut, high_cut = detection_mode

# Filter order handed to butter_bandpass_filter, and the threshold handed to the
# peak-detection step later on.
order = 2
inclination_threshold = 3  # author's empirical values: 3 for lfp, 11 for spike
fs = result['frequency_parameters']['amplifier_sample_rate']  # sampling frequency

raw_trace = result['amplifier_data'][channel_1]
filtered_data1 = butter_bandpass_filter(raw_trace, low_cut, high_cut, fs, order)

# Overlay the filtered trace on the original one.
fig = plt.figure(figsize=(16, 8))
plt.plot(result['t'], raw_trace, label="original data")
plt.plot(result['t'], filtered_data1, label='filtered_data')
plt.legend(loc='upper right')


### 3. Show the frequency distribution

In [None]:
# Spectrum of the filtered trace (red) against the raw trace (black).
# NOTE(review): 30000 is passed as the second argument of frequency_analysis,
# presumably the sampling rate in Hz -- confirm, and consider passing `fs` instead.
freq_distribution, freq_amp = frequency_analysis(filtered_data1, 30000)
freq_distribution2, freq_amp2 = frequency_analysis(result['amplifier_data'][channel_1], 30000)

plt.plot(freq_distribution, freq_amp, color='red', label='filtered')
plt.plot(freq_distribution2, freq_amp2, color='black', label='raw')
plt.legend(loc='upper right')
# Zoom in on the low end of both axes.
plt.xlim(0, 30)
plt.ylim(0, 1.5)

### Alternative: resample (downsample) the signal, which significantly reduces the data size

In [None]:
# how to resample

from scipy import signal
# Downsample the filtered trace by a factor of 30 (30khz -> 1khz) and compare
# it with the full-rate trace.
DOWNSAMPLE_FACTOR = 30

# The time axis and the sample count are taken from the arrays that are actually
# plotted/resampled here, so the cell does not depend on variables left over
# from other cells.
time1 = result['t']
a_signal = len(filtered_data1) // DOWNSAMPLE_FACTOR  # number of samples after resampling
f2 = signal.resample(filtered_data1, a_signal)
x_new = np.linspace(time1[0], time1[-1], a_signal)

plt.plot(time1, filtered_data1, color='black')
plt.plot(x_new, f2, color='red')
plt.show()
print(len(x_new), len(f2))

# NOTE(review): f2 holds 1/30 of the original samples; if the second argument of
# frequency_analysis is the sampling rate, it should probably be
# 30000 / DOWNSAMPLE_FACTOR here -- confirm against the helper.
freq_time, freq_amp = frequency_analysis(f2, 30000)
plt.plot(freq_time, freq_amp)

# Bar chart of the absolute values of `data` at the positions given by `tim`.
# NOTE(review): `amplitude`, `tim` and `data` are not assigned anywhere in the
# visible notebook, so this part raises NameError on a fresh kernel. They were
# presumably left over from a deleted cell -- restore their definitions or
# remove this part.
amplitude
index = np.arange(len(tim))
# for i in tim:
#     plt.bar(index, data[i])
plt.bar(index, abs(data[tim]))
# plt.bar(index, data[tim])
plt.ylim(0,300)
    

# Below needs more revision...
### (4. Select the typical events, such as firing, etc.)

In [None]:
# Run the peak finder on the filtered trace and overlay the selected samples.
# `arg` is used below to index both the time axis and the filtered trace.
time_axis = result['t']
arg = find_peaks_arg(time_axis, filtered_data1, inclination_threshold, int(fs), mode='lfp')

fig = plt.figure(figsize=(16, 8))
plt.plot(time_axis, filtered_data1, label="filtered signal")
plt.plot(time_axis[arg], filtered_data1[arg], label="chosen signal")
plt.legend(loc='upper right', fontsize=20)

### (5. get the spike)

In [None]:
import numpy as np
from scipy.signal import argrelextrema
import pandas as pd

# Per-spike feature extraction.
# For every detected index in `arg`, locate the closest local maximum on each
# side and the first local minimum (trough) around the index, then store their
# positions, voltages and mutual distances as one column of `df`
# (rows = feature names, columns = detected index).

# NOTE(review): `data` was never assigned in this notebook. `arg` was detected on
# filtered_data1, so that trace is used here -- confirm it is the intended signal.
data = filtered_data1

SPIKE_START_POINT = 60     # samples inspected before the detected index
SPIKE_END_POINT = 90       # samples inspected after the detected index
TROUGH_SEARCH_BEFORE = 4   # the trough is searched in [index - 4, index + 20)
TROUGH_SEARCH_AFTER = 20

FEATURE_NAMES = ["Left_point", "Right_point", "Middle_point",
                 "Left_voltage", "Middle_voltage", "Right_voltage",
                 "Left_distance", "Right_distance",
                 "Left_middle_voltage", "Right_middle_voltage"]

plot_configuration = False  # set to True to draw each spike with its three landmarks

spike_features = {}
for test_i in arg:
    # Skip events whose analysis window would run off either end of the trace.
    if test_i - SPIKE_START_POINT < 0 or test_i + SPIKE_END_POINT > len(data):
        continue

    left_maxima = argrelextrema(data[test_i - SPIKE_START_POINT:test_i], np.greater)[0]
    right_maxima = argrelextrema(data[test_i:test_i + SPIKE_END_POINT], np.greater)[0]
    troughs = argrelextrema(
        data[test_i - TROUGH_SEARCH_BEFORE:test_i + TROUGH_SEARCH_AFTER], np.less)[0]

    # A window without a strict local extremum yields an empty result; skip the
    # event instead of failing with IndexError.
    if len(left_maxima) == 0 or len(right_maxima) == 0 or len(troughs) == 0:
        continue

    # argrelextrema returns positions relative to the slice, so each one is
    # shifted by the start of its own slice to get an index into `data`.
    data_lp = left_maxima[-1] + test_i - SPIKE_START_POINT   # last maximum before the index
    data_rp = right_maxima[0] + test_i                       # first maximum after the index
    data_mp = troughs[0] + test_i - TROUGH_SEARCH_BEFORE     # first minimum around the index

    ldist = data_mp - data_lp
    rdist = data_rp - data_mp

    data_lv = data[data_lp]
    data_rv = data[data_rp]
    data_mv = data[data_mp]

    lmv = data_lv - data_mv
    rmv = data_rv - data_mv

    # Values are stored in the same order as FEATURE_NAMES.
    spike_features[test_i] = np.array(
        [data_lp, data_rp, data_mp, data_lv, data_mv, data_rv, ldist, rdist, lmv, rmv],
        dtype=float)

    if plot_configuration:
        window_start = test_i - SPIKE_START_POINT
        # One x value per sample of the window, in steps of 1/30 as in the
        # original axis (presumably milliseconds at 30 kHz -- confirm).
        window_axis = np.arange(SPIKE_START_POINT + SPIKE_END_POINT) / 30
        plt.plot(window_axis, data[window_start:test_i + SPIKE_END_POINT])
        for landmark in (data_lp, data_mp, data_rp):
            plt.scatter(window_axis[landmark - window_start], data[landmark])
        plt.show()

# Build the table once instead of inserting one column per loop iteration.
df = pd.DataFrame(spike_features, index=FEATURE_NAMES)

In [None]:
# Overlay the spike waveforms in two panels, split by the number of samples
# between the left maximum and the trough (top: < 13, bottom: 13 to 29).
# Spikes with a distance of 30 or more are not drawn.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(4, 8))
d = int(max(df.T['Right_distance']))  # largest trough-to-right-maximum distance

for spike_id in df.columns:
    LP = int(df.loc['Left_point', spike_id])
    MP = int(df.loc['Middle_point', spike_id])
    RP = d + LP  # common window length for every spike, measured from LP

    graph_want = data[LP-20:RP+20]
    gap = MP - LP
    if gap < 13:
        target_ax = ax1
    elif gap < 30:
        target_ax = ax2
    else:
        continue
    target_ax.plot(np.arange(len(graph_want)), graph_want, alpha=0.1)

plt.show()

In [None]:
# Overlay every detected spike waveform in a single plot.
d = int(max(df.T['Right_distance']))  # largest trough-to-right-maximum distance

for i in df.columns:
    LP = int(df[i]['Left_point'])
    RP = d + LP  # common window length for every spike, measured from LP

    # Slice only after LP/RP have been set for this spike, so the window never
    # comes from the previous iteration or from another cell. The start is
    # clamped at 0 because a negative start would wrap around to the end of `data`.
    graph_want = data[max(LP - 20, 0):RP + 20]
    plt.plot(np.arange(len(graph_want)), graph_want, alpha=0.1)

plt.show()



### Machine-learning application (needs more study)

In [None]:
import seaborn as sns
# Pairwise scatter plots of every feature from row 3 ("Left_voltage") onward,
# i.e. without the three sample-position rows; one point per spike.
feature_table = df.iloc[3:].T
sns.pairplot(feature_table)

In [None]:
from sklearn.preprocessing import StandardScaler
# Standardise the four shape features (one row per spike) and keep the trough
# voltage separately as `y`.
features = ['Left_voltage', 'Right_voltage', 'Left_distance', 'Right_distance']
spike_table = df.T

# Selecting these four columns leaves the same columns, in the same order, as
# dropping the other six feature columns.
x = StandardScaler().fit_transform(spike_table[features].values)
y = spike_table['Middle_voltage'].values
pd.DataFrame(x, columns=features).head()

In [None]:
from sklearn.decomposition import PCA

# Project the standardised features in `x` onto their first two principal
# components and scatter one against the other.
pca = PCA(n_components=2)
principalComponents = pca.fit_transform(x)
principalDf = pd.DataFrame(principalComponents, columns=[1, 2])

# pca.explained_variance_ratio_
plt.scatter(principalDf[1], principalDf[2])

In [None]:
# PCA on rows 7 and 8 of df ("Right_distance" and "Left_middle_voltage"), one
# sample per spike, coloured by the trough voltage.
# NOTE(review): confirm that rows 7:9 are the intended pair of features.
pca_input = df[7:9].T
pca = PCA()
X_pca = pca.fit(pca_input).transform(pca_input)
display(X_pca.shape)

plt.scatter(X_pca[:, 0], X_pca[:, 1], c=df.T['Middle_voltage'])
plt.colorbar()

In [None]:
from sklearn.cluster import KMeans
from sklearn import preprocessing

# K-means sweep: min-max scale the voltage and distance features (rows 3-7 of
# df, one row per spike after transposing), then cluster with k = 1..8 and show
# a pair plot coloured by cluster for every k.
scaler = preprocessing.MinMaxScaler()
processed_data = df[3:8].T.copy()

distance_columns = ['Left_distance', 'Right_distance']
voltage_columns = ['Left_voltage', 'Right_voltage', 'Middle_voltage']
# The scaler is re-fitted for each group of columns.
processed_data[distance_columns] = scaler.fit_transform(processed_data[distance_columns])
processed_data[voltage_columns] = scaler.fit_transform(processed_data[voltage_columns])
sns.set_palette("Set2")

# KMeans cannot form more clusters than there are samples.
max_clusters = min(8, len(processed_data))

for i in range(1, max_clusters + 1):
    # Fixed seed and explicit n_init so the labels are reproducible across
    # re-runs and do not depend on the library's version-specific default.
    kmeans = KMeans(n_clusters=i, n_init=10, random_state=0)
    kmeans.fit(processed_data)
    result_by_sklearn = processed_data.copy()
    result_by_sklearn['cluster'] = kmeans.labels_
    sns.pairplot(data=result_by_sklearn, hue='cluster')
    plt.show()
    

# kmeans = KMeans(n_clusters = 8)
# kmeans.fit(df[3:8].T)

# result_by_sklearn =  df[3:8].T.copy()
# result_by_sklearn['cluster'] = kmeans.labels_

# sns.pairplot(data = result_by_sklearn, hue='cluster')


In [None]:
from scipy.cluster.hierarchy import linkage, dendrogram
# Linkage criteria to compare: single = minimum distance, complete = maximum
# distance, average, centroid = distance between centroids, ward = between-cluster
# sum of squares - (within-cluster sum of squares).
linkage_list = ['single', 'complete', 'average', 'centroid', 'ward']
# Two inputs per criterion: the unscaled feature table and processed_data.
data3 = [df[3:8].T.copy(), processed_data]

fig, axes = plt.subplots(nrows=len(linkage_list), ncols=2, figsize=(16, 35))

# One row of dendrograms per linkage criterion, one column per input table.
for row, method in enumerate(linkage_list):
    for col, table in enumerate(data3):
        hierarchical_single = linkage(table, method=method)
        dn = dendrogram(hierarchical_single, ax=axes[row][col])
        axes[row][col].title.set_text(method)
plt.show()

In [None]:
from sklearn.cluster import AgglomerativeClustering
# Ward agglomerative clustering of the spikes into two groups.
agg_clustering = AgglomerativeClustering(n_clusters=2, linkage='ward')

# data3 is a list holding two tables ([unscaled, processed_data]); the estimator
# needs a single 2-D table, so the second entry (processed_data) is passed.
# NOTE(review): switch to data3[0] if the unscaled features were intended.
labels = agg_clustering.fit_predict(data3[1])
labels

In [None]:
from mpl_toolkits.mplot3d import Axes3D

# 3-D scatter of the three voltage features, coloured by the cluster label of
# each spike.
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection='3d')

# data3 is a list of two tables and cannot be indexed by column name; the second
# entry (processed_data) is used here.
# NOTE(review): switch to data3[0] to plot the unscaled voltages instead.
cluster_table = data3[1]
ax.scatter(cluster_table['Left_voltage'],
           cluster_table['Middle_voltage'],
           cluster_table['Right_voltage'],
           c=labels, s=20, alpha=0.5, cmap='rainbow')
ax.set_xlabel('Left_voltage')
ax.set_ylabel('Middle_voltage')
ax.set_zlabel('Right_voltage')