In [61]:
import os
import time

import h5py
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import torch.utils.data
import matplotlib.pyplot as plt
from tempfile import TemporaryDirectory

In [62]:
torch.cuda.is_available(), torch.cuda.get_device_name(torch.cuda.current_device())
device = torch.device('cuda')

In [63]:
labelsDict = {
    "rest": 0,
    "leftHand": 1,
    "rightHand": 2,
    "bothHands": 3,
    "bothFeet": 4
}

def norm_standardize(data):
    means = data.mean(axis=1, keepdims=True)
    stds = data.std(axis=1, keepdims=True)
    if not stds.all():
        print(f"first std is {stds[0]}")
        print(stds[0].shape)
        print("the 1st fucking data is:")
        print(data[0])
        print(data[0].shape)
        
    return (data - means) / stds
    
class CustomEEGDataset(torch.utils.data.Dataset):
    def __init__(self, h5_dir, transform=None, target_transform=None):
        self.transform = transform
        self.target_transform = target_transform
        
        # Load h5 file and get class names
        self.hf = h5py.File(h5_dir, 'r')
        print("Opened h5py file")
        print(f'Keys: {[key for key in self.hf.keys()]}')
        print(f'Samples shape: {self.hf["Samples"].shape}')
        print(f'Classes shape: {self.hf["Classes"].shape}')
        
    def __len__(self):
        return len(self.hf['Samples'])

    def __getitem__(self, idx):
        # idx is a list [start -> start + batch_size]
        return self.transform(self.hf['Samples'][idx]), self.hf['Classes'][idx].argmax(axis=1)
    
    def size(self):
        return self.__len__()
            
    def close(self):
        self.hf.close()
        print("h5 File Closed!")

In [64]:
class RandomBatchSampler(torch.utils.data.Sampler):
    """
    Sampling class to create random sequential batches for weak shuffling
    """
    
    def __init__(self, dataset, batch_size):
        self.batch_size = batch_size
        self.dataset_length = len(dataset)
        self.n_batches = self.dataset_length / self.batch_size
        self.batch_ids = torch.randperm(int(self.n_batches))
    
    def __len__(self):
        return self.batch_size
    
    def __iter__(self):
        # Yield list of indices for that particular batch
        for x in self.batch_ids:
            idx = torch.arange(x * self.batch_size, (x + 1) * self.batch_size)
            for index in idx:
                yield int(index)
                
        # Last batch is smaller than batch_size
        if int(self.n_batches) < self.n_batches:
            idx = torch.arange(int(self.n_batches) * self.batch_size, self.dataset_length)
            for index in idx:
                yield int(index)

def fast_loader(dataset, batch_size, drop_last=False, transforms=None):
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=None,  # None when using samplers
        sampler=torch.utils.data.BatchSampler(RandomBatchSampler(dataset, batch_size), batch_size=batch_size, drop_last=drop_last)
    )

In [65]:
class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        
        # Block 1
        self.block1 = nn.Sequential(
            # 64 x 81 x 31
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3,3), padding='same'),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3,3), padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        # Block 2
        self.block2 = nn.Sequential(
            # 64 x 40 x 15 
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3,3), padding='same'),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=(3,3), padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        # Block 3
        self.block3 = nn.Sequential(
            # 128 x 20 x 7
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3,3), padding='same'),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3,3), padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )

        # Linear
        self.linear1 = nn.Flatten()
        self.linear2 = nn.Linear(in_features=7680, out_features=100)  
        self.linear3 = nn.Linear(in_features=100, out_features=5)  # 5 classes
        self.softmax = nn.Softmax()  

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        x = self.softmax(x)
        return x


In [66]:
class EEGModel:
    def __init__(self, network):
        self.model = network
        
        # Hyper-parameters
        self.optimizer = None 
        self.criterion = None
        
        # Data
        self.dataloaders = dict()
        self.dataset_size = dict()
    
    def set_dataloaders(self, train_loader, val_loader, train_size, val_size):
        self.dataloaders['train'] = train_loader
        self.dataloaders['val'] = val_loader
        self.dataset_size['train'] = train_size
        self.dataset_size['val'] = val_size
    
    def get_trainable_parameters(self):
        return self.model.parameters()
    
    def train_model(self, epochs): 
        # Send to GPU
        self.model = self.model.to(device)
        
        start_time = time.time()
        
        # Only stores the weights with best training accuracy
        hist = {
            'epoch_loss': [],
            'epoch_acc': []
        }
        
        # Saves training checkpoints to a temp directory
        with TemporaryDirectory() as tempdir:
            best_model_params_path = os.path.join(tempdir, 'best_model_params.pt')
            torch.save(self.model.state_dict(), best_model_params_path)
            
            # Looping through epochs
            for epoch in range(epochs):
                print(f'Epoch {epoch}/{epochs - 1}')
                print('-' * 10)

                # Each epoch has a training and validation phase
                for phase in ['train', 'val']:
                    if phase == 'train':
                        self.model.train()  
                    else:
                        self.model.eval()  
                        
                    running_loss = 0.0
                    running_corrects = 0
                    
                    # Iterate over samples
                    s = time.time()
                    for inputs, labels in self.dataloaders[phase]:
                        # Inputs are batched according to batch size
                        shape = inputs.shape
                        inputs = inputs.reshape([shape[0], shape[3], shape[1], shape[2]]) # PyTorch is channels first
                        inputs = inputs.to(device)
                        labels = labels.to(device)
                        
                        # Zero grad
                        self.optimizer.zero_grad()
                    
                        # Forward pass
                        with torch.set_grad_enabled(phase == 'train'):
                            outputs = self.model(inputs)
                            _, preds = torch.max(outputs, 1)
                            loss = self.criterion(outputs, labels)
                            
                            # Backwards pass and step optimizer if training phase
                            if phase == 'train':
                                loss.backward()
                                self.optimizer.step()
                        
                            # Running stats
                            running_loss += loss.item() * inputs.size(0)
                            running_corrects += torch.sum(preds == labels)
                        
                    # Epoch stats
                    hist['epoch_loss'].append(running_loss / self.dataset_size[phase])
                    print(f'Total running corrects: {running_corrects.double()}')
                    hist['epoch_acc'].append(running_corrects.double() / self.dataset_size[phase])
                    
                    print(f'Epoch Done after: {time.time() - s}')
                    print(f'{phase} Loss: {hist["epoch_loss"][-1]:.4f} Acc: {hist["epoch_acc"][-1]:.6f}')
                
                # Store weights if this is best so far
                if hist['epoch_acc'][-1] > max(hist['epoch_acc']):
                    print("Model saved!")
                    torch.save(self.model.state_dict(), best_model_params_path)
                        
            time_elapsed = time.time() - start_time
            print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
            print(f'Best Validation Accuracy: {max(hist["epoch_acc"]):4f}')
            
            # Take the model with best val accuracy
            self.model.load_state_dict(torch.load(best_model_params_path))
        
        # Save model
        torch.save(self.model.state_dict(), "final_model.pt")
        return hist

In [67]:
train_dataset = CustomEEGDataset("h5py/train.h5", transform=norm_standardize)
val_dataset = CustomEEGDataset("h5py/val.h5", transform=norm_standardize)

Opened h5py file
Keys: ['Classes', 'Samples']
Samples shape: (35080, 81, 31, 64)
Classes shape: (35080, 5)
Opened h5py file
Keys: ['Classes', 'Samples']
Samples shape: (7185, 81, 31, 64)
Classes shape: (7185, 5)


In [68]:
# Dataloader object
BATCH_SIZE = 256
train_loader = fast_loader(train_dataset, BATCH_SIZE)
val_loader = fast_loader(val_dataset, BATCH_SIZE)

# Train dataloader and validation dataloader combined as a dict
dataloaders = {
    'train': train_loader,
    'val': val_loader
}

In [69]:
net = Network()

In [70]:
# Create model object for training
model = EEGModel(net)
model.set_dataloaders(train_loader, val_loader, train_dataset.size(), val_dataset.size())

# Hyperparameters
WEIGHTS = [0.01, 1, 1, 1, 1]
class_weights = torch.FloatTensor(WEIGHTS).to(device)
LR = 0.005
MOMENTUM = 0.9
CRITERION = nn.CrossEntropyLoss(weight=class_weights)
OPTIMIZER = optim.SGD(model.get_trainable_parameters(), LR, 0.9)

model.criterion = CRITERION
model.optimizer = OPTIMIZER

In [71]:
model.dataset_size

{'train': 35080, 'val': 7185}

In [72]:
history = model.train_model(10)

Epoch 0/9
----------
first std is [[[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]]
(1, 31, 64)
the 1st fucking data is:
[[[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]

 [[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]

 [[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]

 ...

 [[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]

 [[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 

  return (data - means) / stds


first std is [[[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]]
(1, 31, 64)
the 1st fucking data is:
[[[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]

 [[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]

 [[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]

 ...

 [[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]]

 [[0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ... 0. 0. 0.]
  ...
  [0. 0. 0. ... 0. 0. 0.]
  [0. 0. 0. ..

KeyboardInterrupt: 