In [1]:
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pytorch_metric_learning.losses as pml_losses

In [2]:
with open('eeg/digit/data.pkl', 'rb') as f:
	data = pickle.load(f, encoding='latin1')

In [3]:
x_eval = data['x_train'][:4000]
y_eval = data['y_train'][:4000]
x_train = data['x_train'][4000:]
y_train = data['y_train'][4000:]
x_test = data['x_test']
y_test = data['y_test']

In [4]:
class EEGDataset(torch.utils.data.Dataset):
	def __init__(self, data, labels):
		self.data = data
		self.labels = labels
	
	def __len__(self):
		return len(self.data)
	
	def __getitem__(self, idx):
		return torch.tensor(self.data[idx], dtype=torch.float32), torch.tensor(self.labels[idx], dtype=torch.float32)

In [5]:
class LSTM(nn.Module):
    def __init__(self, num_classes, input_size, hidden_size, num_layers):
        super(LSTM, self).__init__()
        self.bn1 = nn.BatchNorm2d(14)
        self.lstm = nn.LSTM(
            input_size=input_size, 
            hidden_size=hidden_size, 
            num_layers=num_layers, 
            batch_first=True
        )
        self.fc = nn.Linear(hidden_size, num_classes)
        
    def forward(self, x):
        """
        Shape of x: (N, T, C) 
        N: batch_size, T: seq_len(time), C: input_size(features/channels)
        """
        x = self.bn1(x)
        x = x.squeeze()
        x = x.permute(0, 2, 1)
        lstm_out, _ = self.lstm(x)
        logits = self.fc(lstm_out[:, -1, :])
        return logits

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
train_loader = DataLoader(EEGDataset(x_train, y_train), batch_size=64, shuffle=True)
eval_loader = DataLoader(EEGDataset(x_eval, y_eval), batch_size=64, shuffle=False)
test_loader = DataLoader(EEGDataset(x_test, y_test), batch_size=64, shuffle=False)

In [8]:
model = LSTM(10, 32, 128, 2).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

In [9]:
def train(model, loader, optimizer):
	model.train()
	total_loss = 0.0
	for (x, y) in loader:
		x, y = x.to(device), y.to(device)
		optimizer.zero_grad()
		y_pred = model(x)
		loss = criterion(y_pred, y.argmax(dim=1))
		loss.backward()
		optimizer.step()
		total_loss += loss.item()
	return total_loss / len(loader)

def evaluate(model, loader):
	model.eval()
	total_loss = 0.0
	correct = 0
	total = 0
	with torch.no_grad():
		for (x, y) in loader:
			x, y = x.to(device), y.to(device)
			y_pred = model(x)
			loss = criterion(y_pred, y.argmax(dim=1))
			total_loss += loss.item()
			_, predicted = y_pred.max(1)
			correct += predicted.eq(y.argmax(dim=1)).sum().item()
			total += y.size(0)
	return total_loss / len(loader), correct / total

In [10]:
for epoch in range(20):
	train_loss = train(model, train_loader, optimizer)
	eval_loss, eval_acc = evaluate(model, eval_loader)
	test_loss, test_acc = evaluate(model, test_loader)
	print(f'Epoch {epoch}, Train Loss: {train_loss:.4f}, Eval Loss: {eval_loss:.4f}, Eval Acc: {eval_acc:.4f}, Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')

Epoch 0, Train Loss: 1.8238, Eval Loss: 1.4408, Eval Acc: 0.5030, Test Loss: 2.0237, Test Acc: 0.3391
Epoch 1, Train Loss: 1.3076, Eval Loss: 1.0357, Eval Acc: 0.6767, Test Loss: 1.8265, Test Acc: 0.4502
Epoch 2, Train Loss: 0.9543, Eval Loss: 0.6379, Eval Acc: 0.8095, Test Loss: 1.6281, Test Acc: 0.5401
Epoch 3, Train Loss: 0.7087, Eval Loss: 0.3717, Eval Acc: 0.8938, Test Loss: 1.5109, Test Acc: 0.6179
Epoch 4, Train Loss: 0.5058, Eval Loss: 0.2069, Eval Acc: 0.9417, Test Loss: 1.4263, Test Acc: 0.6693
Epoch 5, Train Loss: 0.4073, Eval Loss: 0.1549, Eval Acc: 0.9517, Test Loss: 1.3323, Test Acc: 0.6895
Epoch 6, Train Loss: 0.3171, Eval Loss: 0.1659, Eval Acc: 0.9535, Test Loss: 1.4954, Test Acc: 0.6755
Epoch 7, Train Loss: 0.2562, Eval Loss: 0.1033, Eval Acc: 0.9730, Test Loss: 1.4243, Test Acc: 0.6985
Epoch 8, Train Loss: 0.2228, Eval Loss: 0.1253, Eval Acc: 0.9647, Test Loss: 1.4875, Test Acc: 0.6957
Epoch 9, Train Loss: 0.2129, Eval Loss: 0.1123, Eval Acc: 0.9700, Test Loss: 1.480

In [11]:
# Save model
torch.save(model.state_dict(), 'eeg_lstm.pth')