# EEG Wheelchair Binary Classification ML Model
## Using 1D Convolutional Networks and Recurrent Networks
This notebook is a first experiment in converting 4-channel EEG brain-wave data into a binary stop/go classification for the EEG wheelchair control system.

## Overview
This notebook follows these steps:
1. Data Cleaning
    - Remove inconsistent or bad samples
2. Data Preprocessing
    - Apply wavelet transforms and take the synchrosqueezed signals
3. Data Filtering
    - Filter out irrelevant frequencies and disconnections
4. Build Model
    - Build a CNN-RNN network for prediction
5. Model Training
    - Train the model on the training set of EEG samples
6. Hyperparameter Optimization
    - Optimize model hyperparameters by cross-validation
7. Model Validation
    - Validate the model on the test set


![overview diagram](docs/model_overview_diagram.jpg)
![recurrent diagram](docs/recurrent_model_diagram.png)
![conv_diagram](docs/conv_model_diagram.png)
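
As a rough illustration of step 4, the sketch below stacks a 1D convolutional front end on an LSTM using the Keras layers imported in this notebook. It is one possible arrangement, not necessarily the architecture shown in the diagrams: the function name `build_cnn_rnn`, the window length of 256 samples, and every layer size are assumptions chosen only for the example.

```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (
    Input, Conv1D, BatchNormalization, MaxPooling1D, LSTM, Dropout, Dense,
)


def build_cnn_rnn(n_timesteps=256, n_channels=4):
    """Hypothetical CNN-RNN binary classifier for windows of EEG samples."""
    model = Sequential([
        # Each input is one window: (time steps, EEG channels)
        Input(shape=(n_timesteps, n_channels)),
        # Convolution along the time axis extracts local patterns
        Conv1D(filters=32, kernel_size=5, activation="relu"),
        BatchNormalization(),
        MaxPooling1D(pool_size=2),
        # The LSTM summarizes the downsampled feature sequence into one vector
        LSTM(64),
        Dropout(0.5),
        # A single sigmoid unit gives the probability of one of the two classes
        Dense(1, activation="sigmoid"),
    ])
    model.compile(optimizer="adam", loss="binary_crossentropy",
                  metrics=["accuracy"])
    return model


model = build_cnn_rnn()
model.summary()
```

The sigmoid output paired with binary cross-entropy matches the stop/go framing; the layer sizes would be tuned in the hyperparameter optimization step.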

In [2]:
# Data management
import pandas as pd

# Data processing
from ssqueezepy import ssq_cwt

# Model training
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten
from tensorflow.keras.layers import Conv1D, MaxPooling1D, BatchNormalization
from tensorflow.keras.layers import LSTM

# Plotting
import matplotlib.pyplot as plt





## Importing Data

In [23]:
# Read one EEG recording from the data/GO directory
test_file = pd.read_csv('data/GO/dk1_1673475938.csv', header=None)
test_file

Unnamed: 0,0,1,2,3,4
0,2023-01-11 17:25:39.613477,830.715088,837.294983,847.164856,834.005005
1,2023-01-11 17:25:39.615483,832.360046,858.679749,848.809875,825.780151
2,2023-01-11 17:25:39.615483,824.135132,847.164856,848.809875,832.360046
3,2023-01-11 17:25:39.615483,837.294983,848.809875,843.874939,840.584961
4,2023-01-11 17:25:39.615483,847.164856,853.744812,838.940002,834.005005
...,...,...,...,...,...
3108,2023-01-11 17:25:53.663147,852.099792,861.969727,861.969727,840.584961
3109,2023-01-11 17:25:53.663147,834.005005,845.519897,840.584961,837.294983
3110,2023-01-11 17:25:53.663147,847.164856,845.519897,835.650024,838.940002
3111,2023-01-11 17:25:53.664149,837.294983,850.454834,842.229919,830.715088
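
The table above has integer column labels because the file is read with `header=None`. A minimal sketch of making the layout explicit at load time follows, assuming the file holds one timestamp column followed by the four channels in the order TP9, AF7, AF8, TP10 (the order used when the columns are unpacked later in this notebook). The in-memory CSV reuses the first three rows shown above, and the column names are labels chosen for the example.

```python
import io

import pandas as pd

# First three rows of the recording shown above, as raw CSV text
sample_csv = io.StringIO(
    "2023-01-11 17:25:39.613477,830.715088,837.294983,847.164856,834.005005\n"
    "2023-01-11 17:25:39.615483,832.360046,858.679749,848.809875,825.780151\n"
    "2023-01-11 17:25:39.615483,824.135132,847.164856,848.809875,832.360046\n"
)

# Assumed column order: timestamp, then the four EEG channels
columns = ["timestamp", "tp9", "af7", "af8", "tp10"]

df = pd.read_csv(sample_csv, header=None, names=columns,
                 parse_dates=["timestamp"])

print(df.dtypes)
```

With named columns, each channel can be selected as a float array directly, for example `df["tp9"].to_numpy()`, without transposing the whole frame.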


## Convert to synchrosqueezed wavelet transform

In [36]:
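# Needed to parse the timestamp strings
from datetime import datetime

# Split the DataFrame into one numpy array per column
timestamps, tp9, af7, af8, tp10 = test_file.T.to_numpy()

# The transpose mixes strings and floats, so the arrays have dtype=object;
# cast the channel to float before the transform
tp9 = tp9.astype(float)

# Estimate the sampling frequency (Hz) from the first and last timestamps
start = datetime.strptime(str(timestamps[0]), "%Y-%m-%d %H:%M:%S.%f")
end = datetime.strptime(str(timestamps[-1]), "%Y-%m-%d %H:%M:%S.%f")
duration = (end - start).total_seconds()
fs = timestamps.size / duration

# Synchrosqueezed CWT (Twtp9) and plain CWT (Wtp9) of the TP9 channel
Twtp9, Wtp9, *_ = ssq_cwt(tp9, fs=fs)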

array([['2023-01-11 17:25:39.613477', '2023-01-11 17:25:39.615483',
        '2023-01-11 17:25:39.615483', ..., '2023-01-11 17:25:53.663147',
        '2023-01-11 17:25:53.664149', '2023-01-11 17:25:53.724123'],
       [830.715087890625, 832.3600463867188, 824.1351318359375, ...,
        847.1648559570312, 837.2949829101562, 820.84521484375],
       [837.2949829101562, 858.6797485351562, 847.1648559570312, ...,
        845.5198974609375, 850.454833984375, 848.8098754882812],
       [847.1648559570312, 848.8098754882812, 848.8098754882812, ...,
        835.6500244140625, 842.2299194335938, 842.2299194335938],
       [834.0050048828125, 825.7801513671875, 832.3600463867188, ...,
        838.9400024414062, 830.715087890625, 834.0050048828125]],
      dtype=object)
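
The timestamps in this recording arrive in bursts, with many consecutive samples sharing the same value, so the sampling frequency is estimated from the total span of the recording rather than from differences between neighbouring samples. A small standard-library sketch of that estimate is shown below. The helper name `estimate_sampling_rate` and the example timestamps are hypothetical, and the sketch divides by the number of intervals (`n - 1`) rather than the number of samples, a difference that is negligible for recordings with thousands of samples.

```python
from datetime import datetime

TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S.%f"


def estimate_sampling_rate(timestamps):
    """Return the mean sampling rate in Hz from a sequence of timestamp strings."""
    if len(timestamps) < 2:
        raise ValueError("need at least two timestamps")
    start = datetime.strptime(timestamps[0], TIMESTAMP_FORMAT)
    end = datetime.strptime(timestamps[-1], TIMESTAMP_FORMAT)
    duration = (end - start).total_seconds()
    if duration <= 0:
        raise ValueError("timestamps must span a positive duration")
    # n samples bound n - 1 intervals between the first and last sample
    return (len(timestamps) - 1) / duration


# Hypothetical timestamps: five samples spread evenly over one second
example = [
    "2023-01-11 17:25:39.000000",
    "2023-01-11 17:25:39.250000",
    "2023-01-11 17:25:39.500000",
    "2023-01-11 17:25:39.750000",
    "2023-01-11 17:25:40.000000",
]
fs_example = estimate_sampling_rate(example)
print(fs_example)
```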