# COGS 189 Final Project 

**Team Name**  
Argus

**Team Members**  
Xiaolong Huang (xih002@ucsd.edu),  
Bohan Lei (blei@ucsd.edu),  
Mingze Xu (m6xu@ucsd.edu),  
Weijia Zeng (wezeng@ucsd.edu)

**Project Goal**
- To fully understand and implement the vanilla Common Spatial Pattern (CSP) algorithm for EEG pattern recognition.
- To experiment with variants of CSP and Linear Discriminant Analysis (LDA) for classifying motor imagery signals.
- To improve motor imagery classification accuracy by tuning CSP and LDA hyperparameters.

**Sources**
- Data source: https://github.com/bregydoc/bcidatasetIV2a
- Data explanation: https://www.bbci.de/competition/iv/desc_2a.pdf
- CSP paper: https://ieeexplore.ieee.org/document/4408441
- SACSP paper: https://cogsci.ucsd.edu/~desa/Winter_conference_on_bci_2022.pdf
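
As a rough illustration of the vanilla CSP named in the project goal, the sketch below solves the two-class problem as a generalized eigenvalue problem with `scipy.linalg.eigh`. The helper names (`average_covariance`, `fit_csp`, `csp_features`), the trace normalization, the normalized log-variance features, and the synthetic data are all assumptions made for this example and are not necessarily the choices used later in this project.

```python
import numpy as np
from scipy.linalg import eigh


def average_covariance(epochs):
    """Mean trace-normalized spatial covariance over trials.

    epochs: array of shape (n_trials, n_channels, n_samples).
    """
    covs = []
    for trial in epochs:
        trial = trial - trial.mean(axis=1, keepdims=True)  # remove per-channel mean
        cov = trial @ trial.T
        covs.append(cov / np.trace(cov))
    return np.mean(covs, axis=0)


def fit_csp(epochs_a, epochs_b, n_pairs=2):
    """Return spatial filters of shape (2 * n_pairs, n_channels)."""
    cov_a = average_covariance(epochs_a)
    cov_b = average_covariance(epochs_b)
    # Generalized eigenproblem: cov_a w = lambda (cov_a + cov_b) w.
    # eigh returns the eigenvalues in ascending order.
    eigvals, eigvecs = eigh(cov_a, cov_a + cov_b)
    # Smallest eigenvalues favor class b, largest favor class a.
    picks = np.r_[np.arange(n_pairs), np.arange(-n_pairs, 0)]
    return eigvecs[:, picks].T


def csp_features(filters, epochs):
    """Normalized log-variance features, shape (n_trials, n_filters)."""
    projected = np.einsum('fc,tcs->tfs', filters, epochs)
    var = projected.var(axis=2)
    return np.log(var / var.sum(axis=1, keepdims=True))


# Synthetic demo (not the competition data): class a has extra variance
# on channel 0, class b on channel 3.
rng = np.random.default_rng(0)
n_trials, n_channels, n_samples = 20, 4, 750
scale_a = np.array([3.0, 1.0, 1.0, 1.0])[:, None]
scale_b = np.array([1.0, 1.0, 1.0, 3.0])[:, None]
epochs_a = rng.standard_normal((n_trials, n_channels, n_samples)) * scale_a
epochs_b = rng.standard_normal((n_trials, n_channels, n_samples)) * scale_b

filters = fit_csp(epochs_a, epochs_b, n_pairs=1)
feats_a = csp_features(filters, epochs_a)
feats_b = csp_features(filters, epochs_b)
print(filters.shape, feats_a.shape)
```

The resulting features could then be passed to an LDA classifier; that step is left out here.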

## Section 1: Setup

In [2]:
import numpy as np
from scipy.linalg import eigh
from scipy.signal import butter, sosfiltfilt, sosfreqz
import matplotlib.pyplot as plt

## Section 2: Data Import

In [3]:
n_subject = 9

data_dir = './data/'
data_prefix = 'A'
data_suffix = '.npz'
data_type = {'train':'T', 'test':'E'}
data_path = data_dir + data_prefix + '{subject:02d}{type_:s}' + data_suffix
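
To make the path template concrete, the short check below expands it for two subjects. The variable names `example_train` and `example_test` are introduced only for this illustration.

```python
data_dir = './data/'
data_prefix = 'A'
data_suffix = '.npz'
data_type = {'train': 'T', 'test': 'E'}
data_path = data_dir + data_prefix + '{subject:02d}{type_:s}' + data_suffix

# {subject:02d} zero-pads the subject number to two digits.
example_train = data_path.format(subject=1, type_=data_type['train'])
example_test = data_path.format(subject=9, type_=data_type['test'])
print(example_train)  # ./data/A01T.npz
print(example_test)   # ./data/A09E.npz
```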

In [4]:
raw_train_subject = []
raw_test_subject = []

for subject_num in range(1, n_subject+1):
    train_path = data_path.format(subject=subject_num, type_=data_type['train'])
    test_path  = data_path.format(subject=subject_num, type_=data_type['test'])
    raw_train_subject.append(np.load(train_path))
    raw_test_subject.append(np.load(test_path))

In [5]:
class_code = {'left':769, 'right':770, 'foot':771, 'tongue':772}

In [6]:
# global
fs = 250                                    # sampling frequency (Hz)

# within each epoch (times in seconds, relative to epoch start)
cue_start = 2.                              # cue onset
cue_end = 3.25                              # cue offset
smr_start = 3.                              # start of the SMR window
smr_end = 6.                                # end of the SMR window
cue_start_offset = round(cue_start * fs)    # sample offset of cue onset
cue_end_offset   = round(cue_end * fs)      # sample offset of cue offset
smr_start_offset = round(smr_start * fs)    # first sample of the SMR window (inclusive)
smr_end_offset   = round(smr_end * fs)      # end of the SMR window (exclusive)
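
As a quick sanity check on these offsets, the arithmetic below works out how many samples the SMR window covers and where it sits relative to the cue onset. The names `window_len`, `rel_start`, and `rel_end` are hypothetical and exist only in this sketch.

```python
fs = 250
cue_start, smr_start, smr_end = 2.0, 3.0, 6.0

cue_start_offset = round(cue_start * fs)   # 500
smr_start_offset = round(smr_start * fs)   # 750
smr_end_offset = round(smr_end * fs)       # 1500

# Number of samples kept per epoch: 3 s at 250 Hz.
window_len = smr_end_offset - smr_start_offset
# Window position relative to the cue onset sample.
rel_start = smr_start_offset - cue_start_offset
rel_end = smr_end_offset - cue_start_offset
print(window_len, rel_start, rel_end)  # 750 250 1000
```

In other words, each extracted epoch spans from 1 s to 4 s after the cue onset.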

In [7]:
train_subject = []
for raw_data in raw_train_subject:
    signal = raw_data['s'].T              # (n_channels, n_samples)
    cue_class = raw_data['etyp']          # event type codes
    cue_position = raw_data['epos']       # event positions (sample indices)

    X_train = {}

    for name, code in class_code.items():
        # shift from cue onset back to the start of each epoch
        epoch_start = cue_position[cue_class==code] - cue_start_offset

        # keep only the SMR window of each epoch
        epochs = []
        for i in range(len(epoch_start)):
            epochs.append(signal[:, epoch_start[i]+smr_start_offset:epoch_start[i]+smr_end_offset])
        epochs = np.stack(epochs)         # (n_trials, n_channels, n_samples)

        X_train[name] = epochs

    train_subject.append(X_train)

In [9]:
test_subject = []
for raw_data in raw_test_subject:
    signal = raw_data['s'].T              # (n_channels, n_samples)
    cue_class = raw_data['etyp']          # event type codes
    cue_position = raw_data['epos']       # event positions (sample indices)

    X_test = {}
    name = 'unknown'
    code = 783                            # cue with unknown class (evaluation data)

    # shift from cue onset back to the start of each epoch
    epoch_start = cue_position[cue_class==code] - cue_start_offset
    epochs = []
    for i in range(len(epoch_start)):
        epochs.append(signal[:, epoch_start[i]+smr_start_offset:epoch_start[i]+smr_end_offset])
    epochs = np.stack(epochs)             # (n_trials, n_channels, n_samples)

    X_test[name] = epochs

    test_subject.append(X_test)
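
The training and test loops above share the same slicing logic, which could be factored into one helper. The sketch below is one possible way to do that; `extract_epochs` is a hypothetical name, and the synthetic signal and 1-D event arrays are assumptions for the demo (the arrays stored in the `.npz` files may have a different shape).

```python
import numpy as np


def extract_epochs(signal, event_types, event_positions, code,
                   cue_start_offset=500, smr_start_offset=750,
                   smr_end_offset=1500):
    """Return the SMR window of every epoch cued with `code`.

    signal: (n_channels, n_samples); result: (n_trials, n_channels, window).
    """
    # Shift from cue onset back to the start of each epoch.
    epoch_start = event_positions[event_types == code] - cue_start_offset
    return np.stack([
        signal[:, start + smr_start_offset:start + smr_end_offset]
        for start in epoch_start
    ])


# Synthetic recording: 25 channels, 5000 samples, two cues with code 769.
signal = np.arange(25 * 5000, dtype=float).reshape(25, 5000)
event_types = np.array([768, 769, 768, 770, 768, 769])
event_positions = np.array([500, 1000, 1500, 2000, 2500, 3000])

left = extract_epochs(signal, event_types, event_positions, code=769)
print(left.shape)  # (2, 25, 750)
```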

In [12]:
test_subject[0]['unknown']

array([[[  1.22070312,  -2.19726562,   2.83203125, ...,  -0.09765625,
          -1.953125  ,  -6.8359375 ],
        [  8.88671875,   2.49023438,   7.51953125, ...,   1.46484375,
           0.68359375,  -4.54101562],
        [  2.88085938,  -0.9765625 ,   3.02734375, ...,   3.125     ,
           1.80664062,  -5.22460938],
        ...,
        [ -1.953125  ,   1.953125  ,   3.90625   , ...,   0.        ,
          -6.34765625, -12.20703125],
        [  6.8359375 ,   7.32421875,  11.23046875, ...,  -3.90625   ,
          -2.9296875 ,  -2.44140625],
        [  7.32421875,   2.9296875 ,   9.27734375, ...,  -3.90625   ,
          -2.9296875 ,  -8.30078125]],

       [[  1.22070312,  -1.70898438,   0.34179688, ...,  -9.08203125,
          -9.66796875,  -8.15429688],
        [  2.34375   ,   0.1953125 ,   3.66210938, ...,  -4.54101562,
          -7.91015625,  -7.08007812],
        [  2.734375  ,  -0.87890625,   0.68359375, ...,  -9.22851562,
         -10.44921875,  -9.27734375],
        ...,
