<a href="https://colab.research.google.com/github/JessalynWang/neurotechML/blob/master/EEGNet_current.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install mne



In [2]:
import numpy as np

from sklearn.metrics import roc_auc_score, precision_score, recall_score, accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable, gradcheck
from torch.utils.data import TensorDataset, DataLoader

import pandas as pd

from matplotlib import pyplot

import mne
from mne.io import concatenate_raws, read_raw_fif
import mne.viz

import math

from os import walk

In [3]:
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [4]:
data_file = '/content/drive/My Drive/data/P-09.fif'

epochs = mne.read_epochs(data_file, verbose='error')
print(epochs)

<EpochsFIF  |   380 events (all good), 0 - 1.49609 sec, baseline off, ~71.4 MB, data loaded,
 'FN': 109
 'FP': 100
 'FU': 83
 'NN': 32
 'NP': 35
 'NU': 21>


In [5]:
epochs_UN = epochs['FU', 'FN'] # Unpleasant vs. Neutral
epochs_UP = epochs['FU', 'FP'] # Unpleasant vs. Pleasant
epochs_NP = epochs['FN', 'FP'] # Neutral vs. Pleasant

# Dataset with unpleasant and neutral events
print(epochs_UN)
data_UN = epochs_UN.get_data() #we will classify between unpleasant and neutral
labels_UN = epochs_UN.events[:,-1]
print(len(labels_UN))

<EpochsFIF  |   192 events (all good), 0 - 1.49609 sec, baseline off, ~36.2 MB, data loaded,
 'FN': 109
 'FU': 83>
192


In [6]:
train_data_UN, test_data_UN, labels_train_UN, labels_test_UN = train_test_split(data_UN, labels_UN, test_size=0.3, random_state=42)

In [7]:
print(labels_train_UN.shape, labels_test_UN.shape, train_data_UN.shape[-1])
chunk_train = labels_train_UN.shape[0]
chunk_test = labels_test_UN.shape[0]
channels = train_data_UN.shape[1]
timepoints = train_data_UN.shape[2]


(134,) (58,) 384


In [8]:
BATCH_SIZE = 35

eeg_data_scaler = StandardScaler()

X_train = eeg_data_scaler.fit_transform(train_data_UN.reshape(-1, train_data_UN.shape[-1])).reshape(train_data_UN.shape)
X_test = eeg_data_scaler.fit_transform(test_data_UN.reshape(-1, test_data_UN.shape[-1])).reshape(test_data_UN.shape)

labels_train_UN = np.array([1 if x > 0 else 0 for x in labels_train_UN])
labels_test_UN = np.array([1 if x > 0 else 0 for x in labels_test_UN])

labels_train_UN = labels_train_UN.reshape((chunk_train, 1))
labels_train_UN = labels_train_UN.astype(np.float32)
X_actual = torch.from_numpy(labels_train_UN)

labels_test_UN = labels_test_UN.reshape((chunk_test, 1))
labels_test_UN = labels_test_UN.astype(np.float32)
X_test_actual = torch.from_numpy(labels_test_UN)

X_train = torch.from_numpy(X_train)
X_train = X_train.unsqueeze(1)
X_test = torch.from_numpy(X_test)
X_test = X_test.unsqueeze(1)

X_list = [0] * (math.ceil(X_train.shape[0] / BATCH_SIZE))
for i in range(len(X_list)):
    a, b = BATCH_SIZE * i, BATCH_SIZE * (i + 1)
    if i != len(X_list) - 1:
        X_list[i] = (X_train[a:b, :, : ], X_actual[a:b, :])
    else:
        X_list[i] = (X_train[a:, :, : ], X_actual[a:, :])


print(X_train.shape, X_actual.shape, X_test.shape)

torch.Size([134, 1, 64, 384]) torch.Size([134, 1]) torch.Size([58, 1, 64, 384])


In [9]:
freq, avg1stride, avg2stride = 256, (1, 4), (1, 8)
convstride = 1
conv1_neurons = 4
conv2_neurons = 8
conv3_neurons = 4
flat1_out = 12
kern1size = freq // 2

In [10]:
conv1outx, conv1outy = (channels, (timepoints - kern1size)/convstride + 1)

conv2outx, conv2outy = ((conv1outx - channels)/convstride + 1, conv1outy)
conv2outx, conv2outy = conv2outx // avg1stride[0], conv2outy // avg1stride[1]

conv3outx, conv3outy = (conv2outx, (conv2outy - 16)/convstride + 1)
conv3outx, conv3outy = (conv3outx // avg2stride[0], conv3outy // avg2stride[1])
flat1_in = int(conv3outx * conv3outy * conv3_neurons)

In [11]:
CNNPoor = nn.Sequential(
    nn.Conv2d(1, conv1_neurons, (1, kern1size)),
    nn.ELU(),
    nn.BatchNorm2d(conv1_neurons, False),
    
    nn.Conv2d(conv1_neurons, conv2_neurons, (channels, 1)),
    nn.ELU(),
    nn.BatchNorm2d(conv2_neurons, False),
    nn.AvgPool2d(avg1stride),
    nn.Dropout(),
    
    nn.Conv2d(conv2_neurons, conv3_neurons, (1, 16)),
    nn.ELU(),
    nn.BatchNorm2d(conv3_neurons, False),
    nn.AvgPool2d(avg2stride),
    nn.Dropout(),
    
    nn.Flatten(),

    nn.Linear(flat1_in, flat1_out),
    nn.ELU(),
    nn.Linear(flat1_out, 1),
    nn.Sigmoid(),
)

CNNPoor.to(device)

Sequential(
  (0): Conv2d(1, 4, kernel_size=(1, 128), stride=(1, 1))
  (1): ELU(alpha=1.0)
  (2): BatchNorm2d(4, eps=False, momentum=0.1, affine=True, track_running_stats=True)
  (3): Conv2d(4, 8, kernel_size=(64, 1), stride=(1, 1))
  (4): ELU(alpha=1.0)
  (5): BatchNorm2d(8, eps=False, momentum=0.1, affine=True, track_running_stats=True)
  (6): AvgPool2d(kernel_size=(1, 4), stride=(1, 4), padding=0)
  (7): Dropout(p=0.5, inplace=False)
  (8): Conv2d(8, 4, kernel_size=(1, 16), stride=(1, 1))
  (9): ELU(alpha=1.0)
  (10): BatchNorm2d(4, eps=False, momentum=0.1, affine=True, track_running_stats=True)
  (11): AvgPool2d(kernel_size=(1, 8), stride=(1, 8), padding=0)
  (12): Dropout(p=0.5, inplace=False)
  (13): Flatten()
  (14): Linear(in_features=24, out_features=12, bias=True)
  (15): ELU(alpha=1.0)
  (16): Linear(in_features=12, out_features=1, bias=True)
  (17): Sigmoid()
)

In [12]:
loss_function = nn.BCELoss()
optimizer = optim.Adam(CNNPoor.parameters(), lr = 0.001)

In [13]:
def evaluate(model, data):
    
    with torch.no_grad():
        CNNPoor.eval()
        content, labels = data
        pred = model(content)
        pred = pred.numpy()
    return accuracy_score(labels, np.round(pred))

In [14]:
for i in range(25):
    print("Epoch: ", i)
    tot_loss = 0.0
    CNNPoor.train()

    for j in range(math.ceil(X_train.shape[0] / BATCH_SIZE)):
        data, labels = X_list[j]
        data, labels = Variable(data.float()), Variable(labels)
        data.to(device)
        labels.to(device)
        
        optimizer.zero_grad()
        
        classification = CNNPoor(data)
        loss = loss_function(classification, labels)
        loss.backward()
        
        optimizer.step()

        tot_loss += loss.item()
    print("Total loss = ", tot_loss)
    print("Train accuracy = ", evaluate(CNNPoor, (X_train.float(), X_actual)))
    print("Test accuracy = ", evaluate(CNNPoor, (X_test.float(), X_test_actual)))

Epoch:  0
Total loss =  2.661384105682373
Train accuracy =  0.5895522388059702
Test accuracy =  0.5172413793103449
Epoch:  1
Total loss =  2.716046452522278
Train accuracy =  0.6119402985074627
Test accuracy =  0.5172413793103449
Epoch:  2
Total loss =  2.688689708709717
Train accuracy =  0.6716417910447762
Test accuracy =  0.5172413793103449
Epoch:  3
Total loss =  2.645352780818939
Train accuracy =  0.7014925373134329
Test accuracy =  0.5344827586206896
Epoch:  4
Total loss =  2.5606635212898254
Train accuracy =  0.7388059701492538
Test accuracy =  0.5172413793103449
Epoch:  5
Total loss =  2.499046564102173
Train accuracy =  0.753731343283582
Test accuracy =  0.5
Epoch:  6
Total loss =  2.462516725063324
Train accuracy =  0.7761194029850746
Test accuracy =  0.5
Epoch:  7
Total loss =  2.3937642574310303
Train accuracy =  0.7985074626865671
Test accuracy =  0.5172413793103449
Epoch:  8
Total loss =  2.3383386731147766
Train accuracy =  0.8134328358208955
Test accuracy =  0.5172413793