In [1]:
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

In [2]:
with open('eeg/digit/data.pkl', 'rb') as f:
	data = pickle.load(f, encoding='latin1')

In [3]:
x_eval = data['x_train'][:4000]
y_eval = data['y_train'][:4000]
x_train = data['x_train'][4000:]
y_train = data['y_train'][4000:]
x_test = data['x_test']
y_test = data['y_test']

In [4]:
class EEGDataset(torch.utils.data.Dataset):
	def __init__(self, data, labels):
		self.data = data
		self.labels = labels
	
	def __len__(self):
		return len(self.data)
	
	def __getitem__(self, idx):
		return torch.tensor(self.data[idx], dtype=torch.float32), torch.tensor(self.labels[idx], dtype=torch.float32)

In [5]:
class CNNModel(nn.Module):
	def __init__(self, channels, observations, num_classes):
		super(CNNModel, self).__init__()
		self.bn1 = nn.BatchNorm2d(channels)
		self.conv1 = nn.Conv2d(1, observations, kernel_size=(1, 4))
		self.relu1 = nn.ReLU()
		self.conv2 = nn.Conv2d(observations, 25, kernel_size=(channels, 1))
		self.relu2 = nn.ReLU()
		self.maxpool1 = nn.MaxPool2d(kernel_size=(1, 3))
		self.conv3 = nn.Conv2d(1, 50, kernel_size=(4, 25))
		self.relu3 = nn.ReLU()
		self.maxpool2 = nn.MaxPool2d(kernel_size=(1, 3))
		self.conv4 = nn.Conv2d(1, 100, kernel_size=(50, 2))
		self.flatten = nn.Flatten()
		self.bn2 = nn.BatchNorm1d(100)
		self.fc1 = nn.Linear(100, 100)
		self.bn3 = nn.BatchNorm1d(100)
		self.fc2 = nn.Linear(100, num_classes)

	def forward(self, x):
		x = self.bn1(x)
		x = x.permute(0, 3, 1, 2)
		x = self.conv1(x)
		x = self.relu1(x)
		x = self.conv2(x)
		x = self.relu2(x)
		x = self.maxpool1(x)
		x = x.permute(0, 2, 3, 1)
		x = self.conv3(x)
		x = self.relu3(x)
		x = x.permute(0, 3, 1, 2)
		x = self.maxpool2(x)
		x = self.conv4(x)
		x = self.flatten(x)
		x = self.bn2(x)
		x = self.fc1(x)
		x = self.bn3(x)
		x = self.fc2(x)
		return x

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
train_loader = DataLoader(EEGDataset(x_train, y_train), batch_size=64, shuffle=True)
eval_loader = DataLoader(EEGDataset(x_eval, y_eval), batch_size=64, shuffle=False)
test_loader = DataLoader(EEGDataset(x_test, y_test), batch_size=64, shuffle=False)

In [8]:
model = CNNModel(channels=14, observations=32, num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, nesterov=True)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

In [9]:
# Training
def train(model, loader, criterion, optimizer):
	model.train()
	running_loss = 0.0
	for i, (x, y) in enumerate(loader):
		x, y = x.to(device), y.to(device)
		optimizer.zero_grad()
		output = model(x)
		loss = criterion(output, y)
		loss.backward()
		optimizer.step()
		running_loss += loss.item()
	return running_loss / len(loader)

# Evaluation
def evaluate(model, loader, criterion):
	model.eval()
	running_loss = 0.0
	correct = 0
	total = 0
	with torch.no_grad():
		for i, (x, y) in enumerate(loader):
			x, y = x.to(device), y.to(device)
			output = model(x)
			loss = criterion(output, y)
			running_loss += loss.item()
			_, predicted = torch.max(output, 1)
			y = torch.argmax(y, 1)
			total += y.size(0)
			correct += (predicted == y).sum().item()
	return running_loss / len(loader), correct / total

In [10]:
for epoch in range(250):
	train_loss = train(model, train_loader, criterion, optimizer)
	scheduler.step()
	eval_loss, eval_acc = evaluate(model, eval_loader, criterion)
	if epoch % 5 == 0:
		test_loss, test_acc = evaluate(model, test_loader, criterion)
	print(f'Epoch {epoch+1:03d}: Train Loss = {train_loss:.4f}, Eval Loss = {eval_loss:.4f}, Eval Acc = {eval_acc:.4f}, Test Loss = {test_loss:.4f}, Test Acc = {test_acc:.4f}')

Epoch 001: Train Loss = 1.8770, Eval Loss = 1.7319, Eval Acc = 0.4208, Test Loss = 2.1771, Test Acc = 0.2788
Epoch 002: Train Loss = 1.4791, Eval Loss = 1.5489, Eval Acc = 0.4387, Test Loss = 2.1771, Test Acc = 0.2788
Epoch 003: Train Loss = 1.2534, Eval Loss = 1.1592, Eval Acc = 0.6210, Test Loss = 2.1771, Test Acc = 0.2788
Epoch 004: Train Loss = 1.0783, Eval Loss = 1.0372, Eval Acc = 0.6760, Test Loss = 2.1771, Test Acc = 0.2788
Epoch 005: Train Loss = 0.9648, Eval Loss = 0.8629, Eval Acc = 0.7585, Test Loss = 2.1771, Test Acc = 0.2788
Epoch 006: Train Loss = 0.8685, Eval Loss = 0.8489, Eval Acc = 0.7360, Test Loss = 1.9635, Test Acc = 0.4679
Epoch 007: Train Loss = 0.7876, Eval Loss = 0.7131, Eval Acc = 0.7825, Test Loss = 1.9635, Test Acc = 0.4679
Epoch 008: Train Loss = 0.7109, Eval Loss = 0.6462, Eval Acc = 0.8140, Test Loss = 1.9635, Test Acc = 0.4679
Epoch 009: Train Loss = 0.6773, Eval Loss = 0.5556, Eval Acc = 0.8415, Test Loss = 1.9635, Test Acc = 0.4679
Epoch 010: Train Lo

In [11]:
# Save model
torch.save(model.state_dict(), 'eeg_cnn.pth')