# Imports
---

In [None]:
from scipy.io import loadmat
from scipy.signal import welch, butter, lfilter
import numpy as np
import matplotlib.pyplot as plt
from numpy.lib.stride_tricks import as_strided
import itertools
from matplotlib.patches import Patch
%matplotlib inline

In [None]:
from sklearn.model_selection import cross_val_score
from sklearn.metrics import confusion_matrix
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis, QuadraticDiscriminantAnalysis
import numpy as np

# Verify Original Data
---

In [None]:
# Read the subject recording; loadmat returns a dict keyed by MATLAB variable name.
data = loadmat('../p3_subjectData.mat')
# Reference values taken from the header at period index 0, run index 0
# (the same nesting generateData walks with subject[period][0,run][1][0,0]).
# generateData asserts that the header of whichever run it loads matches them.
# NOTE(review): presumably `fs` is the sampling rate — confirm against the .mat layout.
fs = data['subjectData'][0,0][0][0,0][1][0,0][0][0,0]
# First-column entries of the header's second field, joined into one flat array
# (presumably the sensor/channel names — confirm).
sensors = np.concatenate(data['subjectData'][0,0][0][0,0][1][0,0][1][:,0])
# Names of the three trial types (presumably flexion / extension / rest — confirm).
trialTypes = ['Flx','Ext','Rst']

In [None]:
def generateData(period,run):
    """Return the trimmed signal, motion codes and trial timestamps of one run.

    Parameters
    ----------
    period
        Index into the subject record (``data['subjectData'][0,0]``).
    run
        Column index of the run inside the selected period.

    Returns
    -------
    tuple
        ``(eeg, motions, timestamps)`` where

        * ``eeg`` is the run's data sliced from the first to the last
          timestamp value (end exclusive),
        * ``motions`` is every fifth marker value starting at index 3 (the
          final marker excluded), integer-divided by 100 and reduced by 1,
        * ``timestamps`` is the timestamp column without its first and last
          entries, reshaped to five columns per row.

    Raises
    ------
    AssertionError
        If the run header disagrees with the module-level ``sensors`` or
        ``fs`` reference values.
    """
    run_record = data['subjectData'][0,0][period][0,run]
    run_header = run_record[1][0,0]

    # Sanity checks: this run must match the reference header values.
    run_sensors = np.concatenate(run_header[1][:,0])
    assert np.array_equal(sensors, run_sensors)
    assert fs == run_header[0][0,0]

    # Third header field holds the marker values ([0]) and their timestamps ([1]).
    marker_info = run_header[2][0,0]
    stamps = marker_info[1]

    start, stop = stamps[0,0], stamps[-1,0]
    signal = run_record[0][start:stop]
    # NOTE(review): markers look like hundreds-coded motion ids; "//100 - 1"
    # presumably maps them to 0-based indices into trialTypes — confirm.
    motion_codes = marker_info[0][3:-1:5,0]//100 - 1
    trial_stamps = stamps[1:-1].reshape(-1,5)
    return signal, motion_codes, trial_stamps

def filterTrials(period,run,motion):
    """Return a run's signal plus the timestamp rows of one motion code.

    Calls ``generateData(period, run)`` and keeps only the timestamp rows
    whose motion code equals ``motion``; the signal is returned unfiltered.
    """
    signal, codes, stamps = generateData(period,run)
    wanted = codes == motion  # boolean mask over trials
    return signal, stamps[wanted]

In [None]:
# Load the signal, motion codes and trial timestamps for period index 0, run index 0.
eeg, motions, timestamps = generateData(0,0)

In [None]:
# Sliding-window settings read by mav(): rows (axis-0 entries) per window,
# and rows shared by two consecutive windows (so the hop is window - overlap).
window, overlap = 1000, 900
def rolling_window(a, window, overlap):
    """Return overlapping windows taken along the first axis of ``a``.

    Parameters
    ----------
    a : numpy.ndarray
        Array to window; only axis 0 is windowed, trailing axes are kept.
    window : int
        Number of axis-0 entries in each window.
    overlap : int
        Number of axis-0 entries shared by consecutive windows; must be
        smaller than ``window``.

    Returns
    -------
    numpy.ndarray
        Strided view of shape ``(n_windows, window) + a.shape[1:]``. Only
        complete windows are produced, so an input shorter than ``window``
        yields zero windows. The view shares memory with ``a`` and its
        windows alias each other, so treat it as read-only.

    Raises
    ------
    ValueError
        If ``window - overlap`` is not positive (the windows would never
        advance).
    """
    step = window - overlap
    if step <= 0:
        raise ValueError(
            f"window ({window}) must be greater than overlap ({overlap})"
        )
    # Integer floor division avoids float rounding; clamping keeps very short
    # inputs from producing a negative (invalid) window count.
    n_windows = max((a.shape[0] - overlap) // step, 0)
    shape = (n_windows, window) + a.shape[1:]
    # Advancing one window skips `step` rows; inside a window the original
    # strides of `a` apply unchanged.
    strides = (a.strides[0] * step,) + a.strides
    return as_strided(a, shape=shape, strides=strides)
def mav(eeg):
    """Mean absolute value of ``eeg`` over sliding windows along axis 0.

    Window length and overlap are read from the module-level ``window`` and
    ``overlap``; the result has one axis-0 entry per window.
    """
    rectified = np.abs(eeg)
    windows = rolling_window(rectified, window, overlap)
    return windows.mean(axis=1)  # average over the within-window axis

In [None]:
# Overlay columns 0..31 of `eeg` on the current axes, one line per column.
# NOTE(review): 32 is hard-coded; presumably it equals the number of sensors — confirm.
for i in range(32):
    plt.plot(eeg[:,i])

# Import Data
---

In [None]:
def load(name):
    """Read ``features/<name>.mat`` and return the variable called ``name`` in it."""
    mat_contents = loadmat(f'features/{name}.mat')
    return mat_contents[name]
def rolling_window(a, window, overlap):
    """Return overlapping windows taken along the first axis of ``a``.

    Parameters
    ----------
    a : numpy.ndarray
        Array to window; only axis 0 is windowed, trailing axes are kept.
    window : int
        Number of axis-0 entries in each window.
    overlap : int
        Number of axis-0 entries shared by consecutive windows; must be
        smaller than ``window``.

    Returns
    -------
    numpy.ndarray
        Strided view of shape ``(n_windows, window) + a.shape[1:]``. Only
        complete windows are produced, so an input shorter than ``window``
        yields zero windows. The view shares memory with ``a`` and its
        windows alias each other, so treat it as read-only.

    Raises
    ------
    ValueError
        If ``window - overlap`` is not positive (the windows would never
        advance).
    """
    step = window - overlap
    if step <= 0:
        raise ValueError(
            f"window ({window}) must be greater than overlap ({overlap})"
        )
    # Integer floor division avoids float rounding; clamping keeps very short
    # inputs from producing a negative (invalid) window count.
    n_windows = max((a.shape[0] - overlap) // step, 0)
    shape = (n_windows, window) + a.shape[1:]
    # Advancing one window skips `step` rows; inside a window the original
    # strides of `a` apply unchanged.
    strides = (a.strides[0] * step,) + a.strides
    return as_strided(a, shape=shape, strides=strides)
def mav(eeg):
    """Mean absolute value of ``eeg`` over sliding windows along axis 0.

    Window length and overlap are read from the module-level ``window`` and
    ``overlap``; the result has one axis-0 entry per window.
    """
    rectified = np.abs(eeg)
    windows = rolling_window(rectified, window, overlap)
    return windows.mean(axis=1)  # average over the within-window axis

In [None]:
# Load the two feature arrays for subject 1 (the "PRE" and "POST" files).
sub1pre, sub1post = (load(key) for key in ('sub1PRE_DATA', 'sub1POST_DATA'))
# Sizes of axes 1 and 2 of the raw PRE array, recorded before it is transformed.
nsensors, ntrials = sub1pre.shape[1], sub1pre.shape[2]
# Replace each raw array by its windowed mean absolute value; .T reverses the
# axis order, so the former last axis comes first.
sub1pre, sub1post = (mav(raw).T for raw in (sub1pre, sub1post))

In [None]:
# Side-by-side images of the first entry (index 0 on axis 0) of the PRE and
# POST feature arrays, drawn with the same fixed colour range (0 to 4.7) so the
# two panels are directly comparable.
fig,ax = plt.subplots(1,2)
ax[0].imshow(sub1pre[0,:,:],vmin=0, vmax=4.7)
ax[1].imshow(sub1post[0,:,:],vmin=0, vmax=4.7)
plt.show()

In [None]:
# Stack PRE on top of POST along axis 0 and flatten the two trailing axes of
# every entry into a single feature row.
sub1all = np.concatenate((sub1pre, sub1post), axis=0).reshape(
    -1, sub1pre.shape[1] * sub1pre.shape[2]
)
# Class labels in the same order as the rows: 0.0 for PRE entries, 1.0 for POST.
labels = np.repeat([0.0, 1.0], [sub1pre.shape[0], sub1post.shape[0]])

In [None]:
# PRE-vs-POST classification on the flattened features: for each classifier,
# run cross_val_score with cv=15 and print the mean of the returned scores
# to four decimal places.
print(f'LDA accuracy: {cross_val_score(LinearDiscriminantAnalysis(),sub1all,labels,cv=15).mean():.4f}')
print(f'QDA accuracy: {cross_val_score(QuadraticDiscriminantAnalysis(),sub1all,labels,cv=15).mean():.4f}')

In [None]:
def classify(sig1,sig2):
    clf = LinearDiscriminantAnalysis().fit()