In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, TensorDataset, random_split
from torchvision import transforms, utils
import time
# get the device type of machine
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
from torchsummary import summary
from eeg_net.solver import * 
%load_ext autoreload
%autoreload 2
%matplotlib inline
plt.rcParams['figure.figsize'] = [12, 8]
plt.rcParams['figure.dpi'] = 200

X_test = np.load("data/X_test.npy")
y_test = np.load("data/y_test.npy")
person_train_valid = np.load("data/person_train_valid.npy")
X_train_valid = np.load("data/X_train_valid.npy")
y_train_valid = np.load("data/y_train_valid.npy")
person_test = np.load("data/person_test.npy")

print ('Training/Valid data shape: {}'.format(X_train_valid.shape))
print ('Test data shape: {}'.format(X_test.shape))
print ('Training/Valid target shape: {}'.format(y_train_valid.shape))
print ('Test target shape: {}'.format(y_test.shape))
print ('Person train/valid shape: {}'.format(person_train_valid.shape))
print ('Person test shape: {}'.format(person_test.shape))
X_test_dir = './data/X_test.npy'
y_test_dir = './data/y_test.npy' 
X_train_valid_dir = './data/X_train_valid.npy' 
y_train_valid_dir = './data/y_train_valid.npy'
X_test_dsample_dir = './data/X_test_downsample.npy'
y_test_dsample_dir = './data/y_test_downsample.npy' 
X_train_valid_dsample_dir = './data/X_train_valid_downsample.npy' 
y_train_valid_dsample_dir = './data/y_train_valid_downsample.npy'
X_test_ds = np.load(X_test_dsample_dir)
y_test_ds = np.load(y_test_dsample_dir)
X_train_valid_ds = np.load(X_train_valid_dsample_dir)
y_train_valid_ds = np.load(y_train_valid_dsample_dir)

X_train_val_05_70_dir = './data/band_pass_data/X_train_val_05_70.npy'
X_train_val_01_70_dir = './data/band_pass_data/X_train_val_01_70.npy' 
X_train_val_05_70_ds_dir = './data/band_pass_data/X_train_val_downsample_05_70.npy'
X_train_val_01_70_ds_dir = './data/band_pass_data/X_train_val_downsample_01_70.npy' 
X_train_val_01_45_dir = './data/band_pass_data/X_train_val_01_45.npy'
X_train_val_01_45_ds_dir = './data/band_pass_data/X_train_val_downsample_01_45.npy' 

X_train_val_05_70 = np.load(X_train_val_05_70_dir)
X_train_val_01_70 = np.load(X_train_val_01_70_dir) 
X_train_val_05_70_ds = np.load(X_train_val_05_70_ds_dir)
X_train_val_01_70_ds = np.load(X_train_val_01_70_ds_dir)
y_train_valid -= 769
y_test -= 769

Training/Valid data shape: (2115, 22, 1000)
Test data shape: (443, 22, 1000)
Training/Valid target shape: (2115,)
Test target shape: (443,)
Person train/valid shape: (2115, 1)
Person test shape: (443, 1)


In [2]:
def make_steps(samples,samples_per_frame,stride):
    '''
    in:
    samples - number of samples in the session
    samples_per_frame - number of samples in the frame
    stride - the gap between succesive frames
    out: list of tuple ranges
    '''
    
    i = 0
    intervals = []
    while i+samples_per_frame <= samples:
        intervals.append((i,i+samples_per_frame))
        i = i + stride
    return intervals

def make_win_data_pipeline(data_arr,label_arr,num_samples_frame,stride):
    '''
    in:
    data_arr - original data array without windowing
    label_arr - labels of the data array without windowing
    num_samples_frame - number of samples in the frame
    stride - the gap between succesive frames
    
    out:
    data_win_arr - windowed data array
    label_win_arr - labels of the windowed data array
    
    '''
    
    num_trials = data_arr.shape[0]
    num_channels = data_arr.shape[1]
    num_samples = data_arr.shape[2]
    
    steps_list = make_steps(num_samples,num_samples_frame,stride)
    num_windows = len(steps_list)
    
    data_win_arr = np.zeros((num_trials*num_windows,num_channels,num_samples_frame))
    label_win_arr = []
    k = 0
    
    for i in range(num_trials):
        
        trial_label = label_arr[i]
        trial_data = data_arr[i,:,:]
        
        for m,n in enumerate(steps_list):
            start_ind = n[0]
            end_ind = n[1]
            
            win_data = trial_data[:,start_ind:end_ind]
            data_win_arr[k,:,:] = win_data
            label_win_arr.append(trial_label)
            k = k+1
    
    label_win_arr = np.asarray(label_win_arr)
    return data_win_arr, label_win_arr


In [3]:
# Creating the custom dataset

class EEGDataset(Dataset):
    
    """EEG dataset"""
    def __init__(self, subset, transform=None):
        
        'Initialization'
        
        self.subset = subset
        self.transform = transform
        
    def __getitem__(self, index):
        
        'Generates one sample of data'
        
        x, y = self.subset[index]
        if self.transform:
          pass 
            # x = self.transform(x)
            # y = self.transform(y)
        return x, y
        
    def __len__(self):
        
        'Denotes the total number of samples'
        return len(self.subset)
    

In [4]:
# Defining the shallow conv net


class ShallowConv(nn.Module):
    
    # Defining the building blocks of shallow conv net
    
    def __init__(self, in_channels, num_conv_filters, num_samples_frame, num_eeg_channels,classes):
    
        # Defining as a subclass
        super(ShallowConv, self).__init__()

        self.num_samples_frame = num_samples_frame
        self.num_conv_filters = num_conv_filters
        self.num_eeg_channels = num_eeg_channels
        
        # Define the convolution layer, https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
        self.conv1 = nn.Conv2d(in_channels, self.num_conv_filters, (1, 25), stride=1)
        self.conv_output_width =  int(self.num_samples_frame - (25-1) - 1 + 1)
        
        # Define the 2d batchnorm layer
        self.bnorm2d = nn.BatchNorm2d(self.num_conv_filters)
        
        # Define the 1d batchnorm layer
        self.bnorm1d = nn.BatchNorm1d(self.num_conv_filters)


        # Define the fc layer, https://pytorch.org/docs/stable/generated/torch.nn.Linear.html
        self.fc1 = nn.Linear(self.num_eeg_channels*self.num_conv_filters, self.num_conv_filters)
        
        # Define the elu activation
        self.elu = nn.ELU(0.2)

        # Define the avg pooling layer
        self.avgpool = nn.AvgPool1d(75, stride=15)
        
        self.num_features_linear = int(np.floor(((self.conv_output_width - 75)/15)+1))
        
        

        # Define the fc layer for generating the scores for classes 
        self.fc2 = nn.Linear(self.num_features_linear*self.num_conv_filters, classes)

        # Define the softmax layer for converting the class scores to probabilities
        self.softmax = nn.Softmax(dim=1)
        
    # Defining the connections of shallow conv net
    
    def forward(self, x):
        
        # Reshaping the input for 2-D convolution (B,22,num_samples_frame) -> (B,1,22,num_samples_frame)
        
        x = x.view(-1, 1, 22, self.num_samples_frame)
        
        # Performing the 2-D convolution (B,1,22,300) -> (B,40,22,x_shape_4dim)
        
        x = self.conv1(x)
        x_shape_4dim = x.shape[3]
        
        # ELU activation
        
        x = self.elu(x)
        
        # 2d Batch normalization
        
        x = self.bnorm2d(x)
        
        
        # Reshaping the input to dense layer (B,40,22,x_shape_4dim) -> (B,x_shape_4dim,880)
        
        x = x.permute(0,3,1,2) # (B,40,22,x_shape_4dim) -> (B,x_shape_4dim,40,22)
        x = x.view(-1,x_shape_4dim,880)
        
        # Passing through the dense layer (B,x_shape_4dim,880) -> (B,x_shape_4dim,40)
        
        x = self.fc1(x)
        
        # ELU activation
        
        x = self.elu(x)
        
        # Square activation
        
        x = torch.square(x)
        
        # Reshaping the input for average pooling layer (B,x_shape_4dim,40) -> (B,40,x_shape_4dim)
        
        x = x.permute(0,2,1)
        
        # Passing through the average pooling layer (B,40,x_shape_4dim) -> (B,40,x_pool_3dim)
        
        x = self.avgpool(x)
        x_pool_3dim = x.shape[2]
        
        # Log activation
        
        x = torch.log(x)
        
        # 1D Batch normalization
        
        x = self.bnorm1d(x)
        #print(x.shape)
        
        # Reshaping the input to dense layer (B,40,x_pool_3dim) -> (B,40*x_pool_3dim)
        
        x = x.reshape(-1, 40*x_pool_3dim)
        
        # Passing through the dense layer (B,40*x_pool_3dim) -> (B,classes)
        
        x = self.fc2(x)
        
        # Passing through the softmax layer
        
        x = self.softmax(x)
        
        return x

In [5]:
## Defining the training and validation function

def train_val(model,optimizer,criterion,num_epochs):
    
    
    
    train_loss = []
    train_acc = []
    val_loss = []
    val_acc = []
    
    for epoch in range(num_epochs):
        
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        
        
        
        
        for phase in ['train','val']:
            
            
            
            #Initializing the losses and accuracy
            
            training_loss = 0
            correct_train_preds = 0
            total_train_preds = 0
            batch_train_idx = 0
            
            validation_loss = 0
            correct_val_preds = 0
            total_val_preds = 0
            batch_val_idx = 0
            
            
            # Implementing the training phase
            
            if phase == 'train':
                
                # setting the model to training mode
                
                model.train()
                
                # Loading the training dataset in batches 
                
                for inputs, labels in dataloaders['train']:
                    
                    # Transfer input data and labels to device
                    
                    inputs = inputs.to(device)
                    labels = labels.to(device)
                    
                    # Incrementing the batch counter
                    
                    batch_train_idx += 1
                    
                    # Zeroing the gradient buffer
                    
                    optimizer.zero_grad()
                    
                    # Perform the forward pass
                    
                    outputs = model(inputs)
                    
                    # Compute loss
                    
                    loss = criterion(outputs,labels)
                    
                    
                    # Perform the backward pass
                    
                    loss.backward()
                    
                    # Perform optimization step
                    
                    optimizer.step()
                    
                    # Compute training statistics
                    
                    training_loss += loss.item()
                    _, predicted = outputs.max(1)
                    total_train_preds += labels.size(0)
                    correct_train_preds += predicted.eq(labels).sum().item()
                    
                
                train_loss.append(training_loss)
                t_acc = correct_train_preds/total_train_preds
                train_acc.append(t_acc)
                print('Training loss:',training_loss)
                print('Training accuracy:',t_acc)
                
                    
            else:
                
                
                
                # setting the model to evaluation mode
                
                model.eval()
                
                # Disable gradient computation
                
                with torch.no_grad():
                    
                    # Loading the training dataset in batches 
                    
                    for val_inputs, val_labels in dataloaders['val']:
                        
                        
                        # Transfer input data and labels to device
                    
                        val_inputs = val_inputs.to(device)
                        val_labels = val_labels.to(device)
                        
                        # Incrementing the batch counter
                    
                        batch_val_idx += 1
                        
                        # Perform forward pass
                        
                        val_outputs = model(val_inputs)
                        
                        # Compute loss
                        
                        valid_loss = criterion(val_outputs,val_labels)
                        
                        
                        # Compute validation statistics
                    
                        validation_loss += valid_loss.item()
                        _, val_predicted = val_outputs.max(1)
                        total_val_preds += val_labels.size(0)
                        correct_val_preds += val_predicted.eq(val_labels).sum().item()
                        
                    val_loss.append(validation_loss)
                    v_acc = correct_val_preds/total_val_preds
                    val_acc.append(v_acc)
                    print('Validation loss:',validation_loss)
                    print('Validation accuracy:',v_acc)
            

            
           
        
    return model, train_loss, train_acc, val_loss, val_acc

In [6]:

num_samples_frame = 1000
stride = 50
X_train_win,y_train_win = make_win_data_pipeline(X_train_valid,y_train_valid,num_samples_frame,stride)

print ('Windowed Training/Valid data shape: {}'.format(X_train_win.shape))
print ('Windowed Training/Valid label shape: {}'.format(y_train_win.shape))

# Converting the numpy data to torch tensors

X_train_valid_tensor = torch.from_numpy(X_train_win).float().to(device)
y_train_valid_tensor = torch.from_numpy(y_train_win).float().long().to(device) 

print ('Training/Valid tensor shape: {}'.format(X_train_valid_tensor.shape))
print ('Training/Valid target tensor shape: {}'.format(y_train_valid_tensor.shape))

init_dataset = TensorDataset(X_train_valid_tensor, y_train_valid_tensor) 

# Spliting the dataset into training and validation

lengths = [int(len(init_dataset)*0.8), int(len(init_dataset)*0.2)] 
subset_train, subset_val = random_split(init_dataset, lengths) 

train_data = EEGDataset(subset_train, transform=None)
val_data = EEGDataset(subset_val, transform=None)

# Constructing the training and validation dataloaders

dataloaders = {
    'train': torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True, num_workers=0),
    'val': torch.utils.data.DataLoader(val_data, batch_size=8, shuffle=False, num_workers=0)
}

Windowed Training/Valid data shape: (2115, 22, 1000)
Windowed Training/Valid label shape: (2115,)
Training/Valid tensor shape: torch.Size([2115, 22, 1000])
Training/Valid target tensor shape: torch.Size([2115])


In [7]:

weight_decay = 0.15  # weight decay to alleviate overfiting
shallow_model = ShallowConv(in_channels=1, num_conv_filters=40,num_samples_frame=1000,num_eeg_channels=22,classes=4).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(shallow_model.parameters(), lr = 1e-4, weight_decay=weight_decay)

# Training and validating the model

shallow_model,t_l,t_a,v_l,v_a=train_val(shallow_model, optimizer, criterion, num_epochs=100)

Epoch 0/99
----------
Training loss: 72.47871208190918
Training accuracy: 0.32092198581560283
Validation loss: 70.79359459877014
Validation accuracy: 0.3617021276595745
Epoch 1/99
----------
Training loss: 69.54417634010315
Training accuracy: 0.44089834515366433
Validation loss: 69.58691704273224
Validation accuracy: 0.4326241134751773
Epoch 2/99
----------
Training loss: 67.3483636379242
Training accuracy: 0.49822695035460995
Validation loss: 68.2443071603775
Validation accuracy: 0.4846335697399527
Epoch 3/99
----------
Training loss: 65.15122604370117
Training accuracy: 0.5667848699763594
Validation loss: 66.39617872238159
Validation accuracy: 0.524822695035461
Epoch 4/99
----------
Training loss: 62.93272936344147
Training accuracy: 0.6217494089834515
Validation loss: 65.36586606502533
Validation accuracy: 0.5555555555555556
Epoch 5/99
----------
Training loss: 61.07078731060028
Training accuracy: 0.6684397163120568
Validation loss: 63.5584802031517
Validation accuracy: 0.5626477541