<a href="https://colab.research.google.com/github/110805/EEG_classification/blob/master/EEG_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import dataloader 
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from sklearn.metrics import accuracy_score
from torch.utils.data import TensorDataset, DataLoader

# Hyperparameter setting
batch_size = 64
learning_rate = 1e-2
epochs = 300

# Loading data
train_data, train_label, test_data, test_label = dataloader.read_bci_data()
train_data = torch.from_numpy(train_data)
train_label = torch.from_numpy(train_label)
test_data = torch.from_numpy(test_data)
test_label = torch.from_numpy(test_label)
train_loader = DataLoader(TensorDataset(train_data, train_label), batch_size=batch_size)
test_loader = DataLoader(TensorDataset(test_data, test_label), batch_size= batch_size)

class EEGNet(nn.Module):
    def __init__(self):
        super(EEGNet, self).__init__()

        # firstconv
        self.conv1 = nn.Conv2d(1, 16, kernel_size=(1,51), stride=(1,1), padding=(0,25), bias=False)
        self.batchnorm1 = nn.BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

        # depthwiseconv
        self.conv2 = nn.Conv2d(16, 32, kernel_size=(2,1), stride=(1,1), groups=16, bias=False)
        self.batchnorm2 = nn.BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.avgpool1 = nn.AvgPool2d(kernel_size=(1,4), stride=(1,4), padding=0)

        # separableconv
        self.conv3 = nn.Conv2d(32, 32, kernel_size=(1,15), stride=(1,1), padding=(0,7), bias=False)
        self.batchnorm3 = nn.BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.avgpool2 = nn.AvgPool2d(kernel_size=(1,8), stride=(1,8), padding=0)

        # classify
        self.linear1 = nn.Linear(736, 2, bias=True)

    def forward(self, x):
        # firstconv
        out = self.conv1(x)
        out = self.batchnorm1(out)

        # depthwiseconv
        out = self.conv2(out)
        out = F.elu(self.batchnorm2(out), alpha=1.0)
        out = F.dropout(self.avgpool1(out), p=0.25)

        # separableconv
        out = self.conv3(out)
        out = F.elu(self.batchnorm3(out), alpha=1.0)
        out = F.dropout(self.avgpool2(out), p=0.25)

        # classify
        out = out.view(-1,736)
        out = self.linear1(out)

        return out

model = EEGNet()
Loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

def train(epoch):
    model.train() # switch to train mode
    
    correct = 0
    for x_batch, y_batch in train_loader:
        optimizer.zero_grad()
        y_pred = model(x_batch.float())
        loss = Loss(y_pred, y_batch.long())
        loss.backward()
        optimizer.step()
        pred = torch.argmax(y_pred, dim=1)
        correct += torch.sum(pred == y_batch.int())

    print('Train epoch {} Accuracy: {:.2f}%'.format(epoch, 100*correct.item()/len(test_label)))

def test(epoch):
    model.eval()

    correct = 0
    for x_batch, y_batch in test_loader:
        with torch.no_grad():
            y_pred = model(x_batch.float())
        
        pred = torch.argmax(y_pred, dim=1)
        correct += torch.sum(pred == y_batch.int())
    
    print('Test epoch {} Accuracy: {:.2f}%'.format(epoch, 100*correct.item()/len(test_label)))

for epoch in range(1, epochs+1):
    train(epoch)
    test(epoch)

(1080, 1, 2, 750) (1080,) (1080, 1, 2, 750) (1080,)
Train epoch 1 Accuracy: 67.69%
Test epoch 1 Accuracy: 58.89%
Train epoch 2 Accuracy: 67.59%
Test epoch 2 Accuracy: 69.91%
Train epoch 3 Accuracy: 69.63%
Test epoch 3 Accuracy: 70.37%
Train epoch 4 Accuracy: 72.50%
Test epoch 4 Accuracy: 69.81%
Train epoch 5 Accuracy: 73.61%
Test epoch 5 Accuracy: 70.93%
Train epoch 6 Accuracy: 74.91%
Test epoch 6 Accuracy: 72.22%
Train epoch 7 Accuracy: 75.09%
Test epoch 7 Accuracy: 72.04%
Train epoch 8 Accuracy: 74.63%
Test epoch 8 Accuracy: 72.22%
Train epoch 9 Accuracy: 75.74%
Test epoch 9 Accuracy: 72.04%
Train epoch 10 Accuracy: 76.48%
Test epoch 10 Accuracy: 71.20%
Train epoch 11 Accuracy: 75.83%
Test epoch 11 Accuracy: 72.59%
