In [44]:
import os
import numpy as np
import scipy.io as sio
import scipy
from scipy import signal
from scipy.spatial import distance as hamm_dist

In [146]:
def STFT(x, N, hop=48, band=(3, 8)):
    """Build one feature column per trial from Hann-windowed frames of every channel.

    Each channel of each trial is cut into length-``N`` frames whose starts are
    ``hop`` samples apart (only frames that fit entirely inside the signal are
    kept). Every frame is multiplied element-wise by a Hann window, rows
    ``band[0] .. band[1]-1`` of the resulting frame matrix are kept, and the
    kept values are concatenated into a single column.

    NOTE(review): despite the name, no Fourier transform is applied anywhere in
    this function -- the kept "rows" are windowed time-domain samples, not
    frequency bins. If frequency bins were intended, a DFT step is missing;
    the numerical behaviour of the original code is preserved here. Confirm.

    Parameters
    ----------
    x : array_like, shape (n_samples, n_channels, n_trials)
        Input recordings; axis 0 is the axis that gets framed.
    N : int
        Frame / window length in samples.
    hop : int, optional
        Distance in samples between consecutive frame starts (default 48).
    band : tuple of (int, int), optional
        Half-open range of row indices kept from each frame (default ``(3, 8)``,
        i.e. rows 3, 4, 5, 6, 7).

    Returns
    -------
    numpy.ndarray, shape (n_channels * n_rows * n_frames, n_trials)
        Column ``t`` holds the features of trial ``t``, ordered channel-major,
        then by kept row, then by frame.

    Raises
    ------
    ValueError
        If ``x`` is not 3-D, ``hop`` is not positive, or ``band`` does not lie
        inside ``[0, N]``.
    """
    x = np.asarray(x, dtype=float)
    if x.ndim != 3:
        raise ValueError("x must be 3-D: (n_samples, n_channels, n_trials)")
    if hop <= 0:
        raise ValueError("hop must be a positive integer")
    row_lo, row_hi = band
    if not 0 <= row_lo < row_hi <= N:
        raise ValueError("band must satisfy 0 <= band[0] < band[1] <= N")

    n_samples, n_channels, n_trials = x.shape

    # signal.hann was removed from the scipy.signal namespace; the window
    # functions live in scipy.signal.windows.
    window = signal.windows.hann(N)

    # Frame starts: every `hop` samples, keeping only frames of full length N.
    starts = np.arange(0, n_samples - N + 1, hop)
    n_frames = starts.shape[0]
    n_rows = row_hi - row_lo

    # Gather only the samples that survive the row selection:
    # sample_idx[f, r] = starts[f] + (row_lo + r)
    sample_idx = starts[:, np.newaxis] + np.arange(row_lo, row_hi)[np.newaxis, :]
    kept = x[sample_idx]                                   # (frames, rows, channels, trials)
    kept = kept * window[row_lo:row_hi][np.newaxis, :, np.newaxis, np.newaxis]

    # Reorder to (channels, rows, frames, trials) and flatten the first three
    # axes so that each trial becomes one column.
    stacked = kept.transpose(2, 1, 0, 3)
    return stacked.reshape(n_channels * n_rows * n_frames, n_trials)

In [154]:
# Calculate matrix A
def calc_mat_A(L, M):
    """Generate a reproducible random projection matrix.

    ``len(M)`` rows of ``L`` values are drawn uniformly from [-1, 1), each row
    is divided by its own sum, and the transpose is returned.

    Parameters
    ----------
    L : int
        Number of values drawn per row (number of rows of the returned matrix).
    M : sized object
        Only ``len(M)`` is used; it fixes the number of columns of the result.

    Returns
    -------
    numpy.ndarray, shape (L, len(M))
        Matrix whose every column sums to 1 (up to rounding).

    Notes
    -----
    A private ``RandomState(1235)`` is used, so the output is the same on every
    call without reseeding NumPy's global generator.

    NOTE(review): a row sum can be negative or arbitrarily close to zero, so
    this normalisation may flip signs or produce very large entries; that is
    the original behaviour and is kept unchanged.
    """
    n_rows = len(M)

    # Fixed seed -> identical matrix on every run; a single 2-D draw consumes
    # the random stream in the same order as drawing the rows one at a time.
    rng = np.random.RandomState(1235)
    A = rng.uniform(-1, 1, (n_rows, L))

    inv_row_sums = 1 / A.sum(axis=1)
    res = A * inv_row_sums[:, np.newaxis]
    return res.T

# Calculate sign matrix Y
def calc_mat_Y(A,Z):
    """Project ``Z`` with ``A`` and return the projection together with its signs.

    Returns a tuple ``(Y, Y_sign)`` where ``Y = np.dot(A, Z)`` and ``Y_sign``
    is the element-wise sign of ``Y`` (-1, 0 or +1).
    """
    projection = np.dot(A, Z)
    return projection, np.sign(projection)

# Compare bit strings 
def calc_distance(Y, Y_test, labels=None):
    """Rank training labels by Hamming distance to every test column.

    Parameters
    ----------
    Y : array_like, shape (n_bits, n_train)
        One code per column for the reference (training) samples.
    Y_test : array_like, shape (n_bits, n_test)
        One code per column for the query (test) samples.
    labels : array_like, shape (n_train,) or (n_train, 1), optional
        Label of each training column. When omitted, the notebook-level
        ``y_train`` is used, which is what the original implementation did
        implicitly.

    Returns
    -------
    numpy.ndarray, shape (n_test, n_train)
        Row ``i`` lists the training labels ordered from the nearest to the
        farthest training column with respect to test column ``i``.
    """
    if labels is None:
        # Backward-compatible fallback on the global defined in the data cell.
        labels = y_train
    labels = np.asarray(labels)
    # Accept both a flat vector and a column vector; keep the first column.
    label_col = labels.reshape(labels.shape[0], -1)[:, 0]

    # Pairwise fraction of differing components: dist[i, j] compares
    # test column i with training column j.
    dist = hamm_dist.cdist(np.asarray(Y_test).T, np.asarray(Y).T, metric='hamming')

    # Default argsort kind is kept so tie ordering matches the original code.
    nearest_first = dist.argsort()

    return label_col[nearest_first]

#### Load data

In [155]:
# Read the MATLAB file once, then pull out the train / test inputs and labels.
eeg = sio.loadmat('data/eeg.mat')

x_train, y_train = eeg['x_train'], eeg['y_train']
x_test, y_test = eeg['x_te'], eeg['y_te']

#### Execution

In [183]:
# Initialize variables
N = 32
L = 100
K = 22

Z = STFT(x_train,N)
Z_test = STFT(x_test,N)

accuracy = []
for i in range(10,L,5):
    for j in range(3,K,2):
        temp = []
        
        l = i
        k = j
        
        A = calc_mat_A(l,Z)
        
        Y,Y_sign = calc_mat_Y(A,Z)
        Y_test,Y_test_sign = calc_mat_Y(A,Z_test)
        
        index_mat = calc_distance(Y_sign,Y_test_sign)
        k_index_mat = index_mat[:,0:k]        
        
        y_test_pred = np.zeros((y_test.shape[0],1))
        for p in range(0,28):
            y_test_pred[p] = np.median(k_index_mat[p,:])

        count = 0    
        for p in range(0,28):
            if y_test_pred[p] == y_test[p]:
                count+=1

        acc = count/28
        
        temp.extend([i,j,acc])
        accuracy.append(temp)

sorted_reverse_accuracy = sorted(accuracy,key=lambda l:l[-1], reverse=True)

print("ACCURACY TABLE\n")
print("-----------------------------------")
for i in range(len(sorted_reverse_accuracy[0:15])):
    print('| L=' + str(sorted_reverse_accuracy[i][0]).zfill(2) + ' | K=' + str(sorted_reverse_accuracy[i][1]).zfill(2) + ' | Accuracy = ' + str(round(sorted_reverse_accuracy[i][2], 4)) + ' |')

print("-----------------------------------")

ACCURACY TABLE

-----------------------------------
| L=90 | K=03 | Accuracy = 0.7857 |
| L=40 | K=19 | Accuracy = 0.7143 |
| L=75 | K=21 | Accuracy = 0.7143 |
| L=40 | K=07 | Accuracy = 0.6786 |
| L=40 | K=21 | Accuracy = 0.6786 |
| L=40 | K=17 | Accuracy = 0.6429 |
| L=45 | K=15 | Accuracy = 0.6429 |
| L=65 | K=07 | Accuracy = 0.6429 |
| L=90 | K=05 | Accuracy = 0.6429 |
| L=90 | K=07 | Accuracy = 0.6429 |
| L=90 | K=09 | Accuracy = 0.6429 |
| L=15 | K=13 | Accuracy = 0.6071 |
| L=15 | K=17 | Accuracy = 0.6071 |
| L=35 | K=07 | Accuracy = 0.6071 |
| L=45 | K=17 | Accuracy = 0.6071 |
-----------------------------------
