In [6]:
import os.path
import glob

import numpy as np
import pandas as pd

from io import BytesIO
from matplotlib import pyplot as plt


from scipy.io import loadmat
from scipy.interpolate import interp1d
from scipy import signal
from scipy.signal import butter, sosfilt

from scipy.stats import multivariate_normal
from scipy.stats import gamma
from scipy.stats import wishart


from sklearn.model_selection import train_test_split
from sklearn.preprocessing import normalize
from sklearn.preprocessing import StandardScaler

from sklearn.cross_decomposition import PLSRegression
from sklearn.linear_model import LinearRegression

from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error
from sklearn.metrics import r2_score

In [None]:
# (x, y) grid coordinates of the 64 ECoG electrodes, listed in channel order:
# column i of an ECoG array is read together with centers[i] by the
# centroid / dispersion helpers. The grid is staggered (every x + y is odd).
centers = [
    (4, 1), (6, 1), (8, 1),
    (3, 2), (5, 2), (7, 2),
    (2, 3), (4, 3), (6, 3), (8, 3),
    (1, 4), (3, 4), (5, 4), (7, 4),
    (2, 5), (4, 5), (6, 5),
    (1, 6), (3, 6), (5, 6), (7, 6),
    (2, 7), (4, 7), (6, 7), (8, 7),
    (1, 8), (3, 8), (5, 8), (7, 8),
    (2, 9), (4, 9), (6, 9), (8, 9),
    (1, 10), (3, 10), (5, 10), (7, 10),
    (2, 11), (4, 11), (6, 11), (8, 11),
    (1, 12), (3, 12), (5, 12), (7, 12),
    (2, 13), (4, 13), (6, 13), (8, 13),
    (1, 14), (3, 14), (5, 14), (7, 14),
    (2, 15), (4, 15), (6, 15), (8, 15),
    (1, 16), (3, 16), (5, 16), (7, 16),
    (2, 17), (4, 17), (6, 17),
]

In [8]:
# Matplotlib text properties used for plot titles and axis labels.
title_font = dict(
    family='Arial',
    size=20,
    color='black',
    weight='normal',
    verticalalignment='bottom',  # anchor titles at their bottom edge to leave extra space
)
axis_font = dict(family='Arial', size=16)


In [9]:
# Recording directories for the training and test sessions.
# os.path.join uses the platform separator; a hard-coded '\\' is only a
# separator on Windows and is a literal file-name character on Linux/macOS.
path_name_train = os.path.join('Data', '20100802S1_ECoG_Motion6')
path_name_test = os.path.join('Data', '20100726S1_ECoG_Motion6')

In [2]:
def synchronize_simple(signal_data, motion_data):
    """Keep only the rows whose time stamps appear in both recordings.

    Column 0 of each array holds the time stamp. Rows are matched by exact
    equality of time stamps via ``np.intersect1d`` with ``assume_unique=True``,
    so neither recording may contain repeated time stamps.

    Returns
    -------
    ecog_signal : matching rows of ``signal_data`` (time column included)
    motion : matching rows of ``motion_data`` (time column included)
    time : the shared time stamps, sorted ascending
    """
    shared_times, signal_rows, motion_rows = np.intersect1d(
        signal_data[:, 0],
        motion_data[:, 0],
        assume_unique=True,
        return_indices=True,
    )
    return signal_data[signal_rows], motion_data[motion_rows], shared_times

In [14]:
def synchronize_interpol(signal_data, motion_data, downsample=10):
    """Resample motion onto the ECoG time stamps inside the common time span.

    Both inputs are 2-D arrays whose column 0 holds time stamps (assumed
    sorted ascending) and whose remaining columns hold recorded values.
    ECoG rows outside the span covered by both recordings are dropped. Every
    motion column is cubic-interpolated at the remaining ECoG time stamps.
    All outputs are then decimated by keeping every ``downsample``-th sample.

    Parameters
    ----------
    signal_data : ndarray, shape (n_signal, 1 + n_channels)
    motion_data : ndarray, shape (n_motion, 1 + n_motion_channels)
    downsample : int, default 10
        Keep every ``downsample``-th aligned sample (1 keeps all of them).

    Returns
    -------
    ecog_signal : ndarray, shape (n_out, n_channels)
    motion : ndarray, shape (n_out, n_motion_channels)
    time : ndarray, shape (n_out,)

    Raises
    ------
    ValueError
        If the two recordings do not overlap in time.
    """
    signal_data = np.asarray(signal_data)
    motion_data = np.asarray(motion_data)

    # Common time span. Use the first rows (index 0); index 1 would silently
    # drop the first sample of the overlap.
    start = max(signal_data[0, 0], motion_data[0, 0])
    end = min(signal_data[-1, 0], motion_data[-1, 0])
    if start > end:
        raise ValueError("signal and motion recordings do not overlap in time")

    in_span = (signal_data[:, 0] >= start) & (signal_data[:, 0] <= end)
    signal_data = signal_data[in_span]

    # Interpolate on the FULL motion record. Trimming motion to [start, end]
    # first can leave ECoG stamps just outside the interpolation domain, which
    # makes interp1d raise ValueError.
    interpol = interp1d(motion_data[:, 0], motion_data[:, 1:], kind="cubic", axis=0)
    motion = interpol(signal_data[:, 0])

    # Decimate to speed up downstream computations.
    ecog_signal = signal_data[::downsample, 1:]
    motion = motion[::downsample]
    time = signal_data[::downsample, 0]
    return ecog_signal, motion, time

In [15]:
# Signal filtering: causal Butterworth band-pass applied channel-wise.
def bandpass_filter(ecog_signal, lowcut, highcut, fs = 100, order=7):
    """Band-pass filter every column of ``ecog_signal``.

    Parameters
    ----------
    ecog_signal : array-like, shape (n_samples, n_channels)
    lowcut, highcut : float
        Pass-band edges in Hz; must satisfy 0 < lowcut < highcut < fs / 2.
    fs : float, default 100
        Sampling rate in Hz.
    order : int, default 7
        Butterworth order (a band-pass design has 2 * order poles).

    Returns
    -------
    ndarray, shape (n_samples, n_channels)
        The filtered signal. ``sosfilt`` is a single causal pass, so the
        output carries the filter's phase delay.

    Raises
    ------
    ValueError
        If the band edges are not inside (0, fs / 2).
    """
    # Nyquist frequency is fs / 2. Normalising by fs itself halves both cut-offs.
    nyq = 0.5 * fs
    if not 0 < lowcut < highcut < nyq:
        raise ValueError(
            f"need 0 < lowcut < highcut < fs/2 = {nyq}, got lowcut={lowcut}, highcut={highcut}"
        )
    sos = signal.butter(order, (lowcut / nyq, highcut / nyq), btype='band',
                        analog=False, output='sos')
    return sosfilt(sos, np.asarray(ecog_signal), axis=0)

In [16]:
# Generate a sliding-window scalogram via the continuous wavelet transform.
def _morlet_cwt(data, widths):
    """CWT of ``data`` with SciPy's legacy Morlet wavelet (drop-in for cwt).

    Uses the same arithmetic as ``scipy.signal.cwt(data, scipy.signal.morlet,
    widths)``. Both functions were deprecated in SciPy 1.12 and removed in
    later releases. For each width ``w`` the kernel has ``min(10 * w, len(data))``
    points, and ``w`` is passed as Morlet's frequency parameter
    (``morlet(M, w)``), exactly as ``cwt`` did.

    Returns a complex ndarray of shape (len(widths), len(data)).
    """
    data = np.asarray(data)
    out = np.empty((len(widths), len(data)), dtype=np.complex128)
    for row, width in enumerate(widths):
        n_points = min(10 * width, len(data))
        x = np.linspace(-2 * np.pi, 2 * np.pi, n_points)
        wavelet = np.exp(1j * width * x)
        wavelet -= np.exp(-0.5 * width ** 2)  # admissibility correction ("complete" Morlet)
        wavelet *= np.exp(-0.5 * x ** 2) * np.pi ** (-0.25)
        out[row] = signal.convolve(data, np.conj(wavelet[::-1]), mode='same')
    return out


def scalo(ecog_signal, motion, time, window, freqs, start, end, step=100, div=10):
    """Sliding-window wavelet power features with aligned motion/time targets.

    Let ``X = ecog_signal[start:end]``. For each row offset r, channel c and
    width f, this computes the CWT of ``X[r:r + window_len, c]``. It keeps the
    power ``|CWT|**2`` at every ``step // div``-th sample, which gives
    ``n_bins = int(window * 1000 // step) + 2`` time bins per window.

    Parameters
    ----------
    ecog_signal : ndarray, shape (n_samples, n_channels)
    motion, time : arrays indexed along the same sample axis as ``ecog_signal``
    window : float
        Window length in seconds.
    freqs : sequence of int
        Widths passed to the Morlet CWT (see ``_morlet_cwt``).
        NOTE(review): the original comment calls these Hz, but each value is
        used as a kernel-length factor and as Morlet's ``w`` parameter, not
        as a frequency. Confirm the intended meaning.
    start, end : int
        Sample range of ``ecog_signal`` to analyse.
    step : int, default 100
        Spacing of the kept time bins, in ms.
    div : int, default 10
        Presumably the sampling period of ``ecog_signal`` in ms, inferred from
        the ms -> samples conversion below. TODO confirm.

    Returns
    -------
    scalogram : float ndarray, shape (n_windows, n_channels, len(freqs), n_bins)
        ``n_windows = len(X) - window_len``.
    motion[start + window_len:end], time[start + window_len:end]
        Entry r is the sample immediately after window r.

    Raises
    ------
    ValueError
        If ``ecog_signal[start:end]`` is shorter than one window.
    """
    X = np.asarray(ecog_signal)[start:end, :]
    # int() so the value can be used as an array dimension even for a float window.
    n_bins = int(window * 1000 // step) + 2
    window_len = n_bins * step // div  # window length in samples
    stride = step // div               # samples between kept time bins
    n_windows = X.shape[0] - window_len
    if n_windows < 0:
        raise ValueError("ecog_signal[start:end] is shorter than one analysis window")

    scalogram = np.empty((n_windows, X.shape[1], len(freqs), n_bins))
    for ch in range(X.shape[1]):
        for row in range(n_windows):
            coeffs = _morlet_cwt(X[row:row + window_len, ch], freqs)[:, ::stride]
            # Power must use |z|**2: z**2 on complex coefficients gives Re(z**2),
            # which can be negative, and the imaginary part was silently dropped.
            scalogram[row, ch] = np.abs(coeffs) ** 2
    return scalogram, motion[start + window_len:end, :], time[start + window_len:end]
    

In [30]:
def get_mean(ECoG, time, electrode_centers=None):
    """Intensity-weighted centroid of the electrode positions for each row.

    For every row j < len(time), the weights are ``|ECoG[j, i]|`` for each
    electrode i. The centroid is ``sum_i w_i * (x_i, y_i) / sum_i w_i``.

    Parameters
    ----------
    ECoG : array-like, shape (>= len(time), >= n_electrodes)
        Column i is paired with ``electrode_centers[i]``.
    time : sized
        Only its length is used: the number of rows to process.
    electrode_centers : sequence of (x, y), optional
        Electrode coordinates in channel order. Defaults to the module-level
        ``centers`` list.

    Returns
    -------
    ndarray, shape (len(time), 2)
        (x, y) centroid per row. A row whose weights are all zero yields NaN
        (numpy emits a RuntimeWarning).
    """
    if electrode_centers is None:
        electrode_centers = centers
    coords = np.asarray(electrode_centers, dtype=float)
    weights = np.abs(np.asarray(ECoG)[:len(time), :len(coords)])
    return (weights @ coords) / weights.sum(axis=1, keepdims=True)

In [27]:

def get_disp(mean, ECoG, time, electrode_centers=None):
    """Intensity-weighted spread of the electrode positions around ``mean``.

    For every row j < len(time), with weights ``w_i = |ECoG[j, i]|`` and
    centroid ``(mx, my) = mean[j]``, this returns
    ``(sum_i w_i (x_i - mx)**2 / sum_i w_i, sum_i w_i (y_i - my)**2 / sum_i w_i)``.

    Parameters
    ----------
    mean : array-like, shape (>= len(time), 2)
        Per-row centroid, e.g. the output of ``get_mean``.
    ECoG : array-like, shape (>= len(time), >= n_electrodes)
    time : sized
        Only its length is used.
    electrode_centers : sequence of (x, y), optional
        Electrode coordinates in channel order. Defaults to the module-level
        ``centers`` list.

    Returns
    -------
    ndarray, shape (len(time), 2)
        Weighted (x, y) variance per row. All-zero weights give NaN.
    """
    if electrode_centers is None:
        electrode_centers = centers
    coords = np.asarray(electrode_centers, dtype=float)
    n_rows = len(time)
    weights = np.abs(np.asarray(ECoG)[:n_rows, :len(coords)])
    means = np.asarray(mean, dtype=float)[:n_rows]
    # Squared offset of each electrode from the row's centroid: (n_rows, n_electrodes).
    dx2 = (coords[:, 0] - means[:, [0]]) ** 2
    dy2 = (coords[:, 1] - means[:, [1]]) ** 2
    total = weights.sum(axis=1)
    spread = np.column_stack([(weights * dx2).sum(axis=1), (weights * dy2).sum(axis=1)])
    return spread / total[:, None]

In [29]:
def get_intens(mean, disp, mv, ECoG, time, electrode_centers=None):
    """ECoG value at the electrode nearest the centroid, divided by a density.

    For every row j < len(time), this picks the electrode i whose position is
    closest (Euclidean) to ``mean[j]`` and returns
    ``ECoG[j, i] / mv[j].pdf(position_i)``.

    Previously the nearest index was computed and then overwritten with
    ``i = 0``, so channel 0 and ``centers[0]`` were always used. The
    rounding-based lattice lookup could also raise ValueError when the
    rounded mean fell on a gap in the grid.

    Parameters
    ----------
    mean : array-like, shape (>= len(time), 2)
        Per-row (x, y) centroid, e.g. from ``get_mean``.
    disp : unused
        Kept for call compatibility.
    mv : sequence indexed by row
        Objects with a ``.pdf(point)`` method. Presumably frozen
        ``scipy.stats.multivariate_normal`` instances; TODO confirm against
        the caller.
    ECoG : array-like, shape (>= len(time), >= n_electrodes)
    time : sized
        Only its length is used.
    electrode_centers : sequence of (x, y), optional
        Electrode coordinates in channel order. Defaults to the module-level
        ``centers`` list.

    Returns
    -------
    list of length len(time)
    """
    if electrode_centers is None:
        electrode_centers = centers
    coords = np.asarray(electrode_centers, dtype=float)
    n_rows = len(time)
    means = np.asarray(mean, dtype=float)[:n_rows]
    # Squared distance of each row's centroid to every electrode: (n_rows, n_electrodes).
    sq_dist = ((means[:, None, :] - coords[None, :, :]) ** 2).sum(axis=-1)
    nearest = sq_dist.argmin(axis=1)
    signal_values = np.asarray(ECoG)
    return [signal_values[j, i] / mv[j].pdf(coords[i]) for j, i in enumerate(nearest)]

In [25]:
def plot_ecog(time, ecog, channel=0):
    """Plot one ECoG channel against time and show the figure.

    Parameters
    ----------
    time : array-like
        x-axis values.
    ecog : ndarray, shape (n_samples, n_channels)
    channel : int, default 0
        Zero-based column to plot. The title shows it 1-based.
    """
    fig, ax = plt.subplots(figsize=(14, 6))
    ax.plot(time, ecog[:, channel])
    ax.tick_params(axis='both', labelsize=14, labelrotation=0)
    # Labels are in Russian: "Channel #N", "Time, ms", "Voltage, mV".
    ax.set_title(f'Канал #{channel + 1}', **title_font)
    ax.set_xlabel('Время, мс', **axis_font)
    ax.set_ylabel('Напряжение, мВ', **axis_font)
    ax.grid(True)
    plt.show()
    plt.close(fig)  # release the figure so repeated calls don't accumulate memory