<a href="https://colab.research.google.com/github/110805/EEG_classification/blob/master/EEG_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [44]:
import dataloader 
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

# Hyperparameter setting
batch_size = 64
learning_rate = 0.01
epochs = 300
momentum = 0.9

# Loading data
train_data, train_label, test_data, test_label = dataloader.read_bci_data()
train_data = torch.from_numpy(train_data)
train_label = torch.from_numpy(train_label)
test_data = torch.from_numpy(test_data)
test_label = torch.from_numpy(test_label)
train_loader = DataLoader(TensorDataset(train_data, train_label), batch_size=batch_size)
test_loader = DataLoader(TensorDataset(test_data, test_label), batch_size= batch_size)

class EEGNet(nn.Module):
    def __init__(self):
        super(EEGNet, self).__init__()

        # firstconv
        self.conv1 = nn.Conv2d(1, 16, kernel_size=(1,51), stride=(1,1), padding=(0,25), bias=False)
        self.batchnorm1 = nn.BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

        # depthwiseconv
        self.conv2 = nn.Conv2d(16, 32, kernel_size=(2,1), stride=(1,1), groups=16, bias=False)
        self.batchnorm2 = nn.BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.avgpool1 = nn.AvgPool2d(kernel_size=(1,4), stride=(1,4), padding=0)

        # separableconv
        self.conv3 = nn.Conv2d(32, 32, kernel_size=(1,15), stride=(1,1), padding=(0,7), bias=False)
        self.batchnorm3 = nn.BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.avgpool2 = nn.AvgPool2d(kernel_size=(1,8), stride=(1,8), padding=0)

        # classify
        self.linear1 = nn.Linear(736, 2, bias=True)

    def forward(self, x):
        # firstconv
        out = self.conv1(x)
        out = self.batchnorm1(out)

        # depthwiseconv
        out = self.conv2(out)
        out = F.relu(self.batchnorm2(out))
        out = F.dropout(self.avgpool1(out), p=0.25)

        # separableconv
        out = self.conv3(out)
        out = F.relu(self.batchnorm3(out))
        out = F.dropout(self.avgpool2(out), p=0.25)

        # classify
        out = out.view(-1,736)
        out = self.linear1(out)

        return out

model = EEGNet()
Loss = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
#optimizer = optim.Adam(model.parameters(), lr=learning_rate)

device = torch.device('cuda')
model.to(device)

def train(epoch):
    model.train() # switch to train mode
    correct = 0
    for x_batch, y_batch in train_loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()
        y_pred = model(x_batch.float())
        loss = Loss(y_pred, y_batch.long())
        loss.backward()
        optimizer.step()
        pred = torch.argmax(y_pred, dim=1)
        correct += torch.sum(pred == y_batch.int())

    if (epoch+1)%30 == 0:
        print('Train epoch {} Accuracy: {:.2f}%'.format(epoch+1, 100*correct.item()/len(test_label)))

def test(epoch):
    model.eval()
    correct = 0
    for x_batch, y_batch in test_loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        with torch.no_grad():
            y_pred = model(x_batch.float())
        
        pred = torch.argmax(y_pred, dim=1)
        correct += torch.sum(pred == y_batch.int())
    
    if (epoch+1)%30 == 0:
        print('Test epoch {} Accuracy: {:.2f}%'.format(epoch+1, 100*correct.item()/len(test_label)))

for epoch in range(epochs):
    train(epoch)
    test(epoch)
    optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 200], gamma=0.1)

(1080, 1, 2, 750) (1080,) (1080, 1, 2, 750) (1080,)
Train epoch 30 Accuracy: 87.41%
Test epoch 30 Accuracy: 77.69%
Train epoch 60 Accuracy: 94.91%
Test epoch 60 Accuracy: 80.56%
Train epoch 90 Accuracy: 96.30%
Test epoch 90 Accuracy: 79.91%
Train epoch 120 Accuracy: 98.15%
Test epoch 120 Accuracy: 78.89%
Train epoch 150 Accuracy: 98.06%
Test epoch 150 Accuracy: 79.72%
Train epoch 180 Accuracy: 98.06%
Test epoch 180 Accuracy: 80.09%
Train epoch 210 Accuracy: 98.43%
Test epoch 210 Accuracy: 80.09%
Train epoch 240 Accuracy: 99.07%
Test epoch 240 Accuracy: 80.37%
Train epoch 270 Accuracy: 99.35%
Test epoch 270 Accuracy: 80.09%
Train epoch 300 Accuracy: 99.35%
Test epoch 300 Accuracy: 81.30%
