In [2]:
from models import EEGNet
from dataloader import read_bci_data
import torch
import torch.nn as nn
from torch.utils.data import DataLoader,TensorDataset
import torch.optim as optim

In [3]:
# Choose where tensors and the model live: CUDA when PyTorch can see a GPU, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


In [4]:
def seed(seed):
    """Seed every RNG this notebook relies on so runs are reproducible.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices),
    and puts cuDNN into deterministic mode.

    Args:
        seed: integer seed applied to all generators.
    """
    import random

    import numpy as np

    random.seed(seed)
    np.random.seed(seed)
    # Also drives DataLoader(shuffle=True) ordering, which uses the global torch generator.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode autotunes conv algorithms and may select non-deterministic kernels,
    # undoing `deterministic = True`; disable it explicitly.
    torch.backends.cudnn.benchmark = False


SEED = 500
seed(SEED)

In [5]:
# Load the BCI data and wrap the train/test splits in DataLoaders.
BATCH_SIZE = 64

train_data, train_label, test_data, test_label = read_bci_data()

# Inputs as float32 (what torch.Tensor(...) produced before). Labels are converted once
# to int64 class indices, the dtype CrossEntropyLoss expects. Going through float32
# first keeps exactly the truncation the per-batch `.type(torch.LongTensor)` applied.
train_data = torch.as_tensor(train_data, dtype=torch.float32)
train_label = torch.as_tensor(train_label, dtype=torch.float32).long()

test_data = torch.as_tensor(test_data, dtype=torch.float32)
test_label = torch.as_tensor(test_label, dtype=torch.float32).long()

train = TensorDataset(train_data, train_label)
# Shuffle order comes from the global torch RNG seeded in the previous cell.
train_dataload = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)

test = TensorDataset(test_data, test_label)
test_dataload = DataLoader(test, batch_size=BATCH_SIZE, shuffle=False)

## EEGNet Training

In [5]:
## EEGNet training
# Train a fresh EEGNet on the training split; log accuracy and loss every LOG_EVERY epochs.
EPOCHS = 300
LOG_EVERY = 100
LEARNING_RATE = 0.006
WEIGHT_DECAY = 0.003

model = EEGNet()
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
loss_function = torch.nn.CrossEntropyLoss()

# Accuracy denominator: the actual number of training samples, not a hard-coded 1080.
n_train = len(train_dataload.dataset)

for epoch in range(EPOCHS):
    # Explicit train mode each epoch. Other cells call model.eval(), and any dropout or
    # batch-norm layers inside EEGNet (not visible here) behave differently in eval mode.
    model.train()

    train_loss = 0.0  # sum of the per-batch mean losses over this epoch
    correct = 0

    for data, label in train_dataload:
        data = data.to(device)
        # CrossEntropyLoss requires int64 class indices.
        label = label.long().to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = loss_function(output, label)
        loss.backward()
        optimizer.step()

        predicted = output.argmax(dim=1)
        correct += (predicted == label).sum().item()
        train_loss += loss.item()

    if (epoch + 1) % LOG_EVERY == 0:
        print("-------{}th epoch-------".format(epoch + 1))
        print("EEGNet Accuracy: ")
        print(correct / n_train)
        print("Training Loss: ")
        print(train_loss)

-------100th epoch-------
EEGNet Accuracy: 
0.862037037037037
Training Loss: 
5.739576026797295
-------200th epoch-------
EEGNet Accuracy: 
0.8657407407407407
Training Loss: 
4.786291986703873
-------300th epoch-------
EEGNet Accuracy: 
0.8712962962962963
Training Loss: 
4.707398787140846


## EEGNet Testing

In [6]:
## EEGNet Testing
# Evaluate the model trained above on the held-out test split.
model.eval()  # switch any dropout / batch-norm layers to inference behavior

# Accuracy denominator: the actual number of test samples, not a hard-coded 1080.
n_test = len(test_dataload.dataset)

with torch.no_grad():
    correct = 0

    for data, label in test_dataload:
        data = data.to(device)
        label = label.long().to(device)

        output = model(data)

        predicted = output.argmax(dim=1)
        correct += (predicted == label).sum().item()

    acc = correct / n_test
    print("EEGNet Accuracy: ")
    print(acc)


EEGNet Accuracy: 
0.8083333333333333


## Best EEGNet Model

In [10]:
## Best EEGNet Model
# Reload the saved "best" EEGNet checkpoint and evaluate it on the test split.
# NOTE(review): no cell shown here writes 'EEG_model.pt' — confirm which run produced it,
# otherwise Restart & Run All cannot reproduce this result.
# torch.load unpickles the whole nn.Module, so only load trusted files. On newer PyTorch
# (>= 2.6) the default weights_only=True rejects full-module pickles; either pass
# weights_only=False for this trusted file or switch to saving/loading model.state_dict().
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model = torch.load('EEG_model.pt', map_location=device)
model.to(device)
model.eval()

# Accuracy denominator: the actual number of test samples, not a hard-coded 1080.
n_test = len(test_dataload.dataset)

with torch.no_grad():
    correct = 0

    for data, label in test_dataload:
        data = data.to(device)
        label = label.long().to(device)

        output = model(data)

        predicted = output.argmax(dim=1)
        correct += (predicted == label).sum().item()

    acc = correct / n_test
    print("EEG Accuracy: ")
    print(acc)


EEG Accuracy: 
0.8712962962962963
