<a href="https://colab.research.google.com/github/Deep-Learning-Qatar/EEG-Vision/blob/main/EEG_Net.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Connect to drive**

In [1]:
# Mount Google Drive at /content/drive so later cells can read from and write to it.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# **Import required Libraries**

In [None]:
import os
import random
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
from torch.utils import data

# **Check whether a GPU is available**

In [None]:
# True when PyTorch can see a CUDA device; later cells use it to choose the
# device and to decide whether DataLoader worker processes are requested.
cuda = torch.cuda.is_available()

# **Move the data to local machine**

In [None]:
# change based on where you store your data and where you are gonna load it from
%mkdir data
!cp drive/MyDrive/dl_project_data/eeg_55_95_std.pth data/

In [None]:
# Load the serialized EEG dataset from the local copy.
# NOTE(review): torch.load unpickles the file, so only load files from a trusted source.
data_path = 'data/eeg_55_95_std.pth'
data_dict = torch.load(data_path)

# **Hyper-parameters of Bi-LSTM**

In [None]:
# Model and optimisation settings. Values tagged "Fares et al. (2019)" follow that paper.
bidirectional    = True
lr               = 0.01   # initial learning rate for Adam
wd               = 5e-6   # weight decay for Adam
dropout_prob     = 0.5
record_length    = 440    # Fares et al. (2019)
batch_size       = 440    # Fares et al. (2019)
input_size       = 128    # Fares et al. (2019)
feat_num         = 60     # Fares et al. (2019)
num_classes      = 40     # Fares et al. (2019)
num_hidden_nodes = 128    # Fares et al. (2019) / might be different for us (128)
num_layers       = 2      # Fares et al. (2019)
num_epochs       = 2500   # Fares et al. (2019)
ICA_iterations   = 400    # Fares et al. (2019)
ICA_iteratoins   = ICA_iterations  # original misspelled name, kept as an alias for existing references

# **Independent component analysis classifier**

In [None]:
# source: https://towardsdatascience.com/independent-component-analysis-ica-in-python-a0ef0db0955e

# Seed every RNG this notebook draws from, not just NumPy: the ICA initialisation
# uses `np.random`, the train/val/test shuffle uses Python's `random`, and the
# model (weight initialisation, dropout) uses PyTorch.
np.random.seed(0)
random.seed(0)
torch.manual_seed(0)

def g(x):
    """Non-linearity used by the ICA weight update: element-wise hyperbolic tangent."""
    activated = np.tanh(x)
    return activated

def g_der(x):
    """Derivative of ``g``, computed element-wise as ``1 - g(x) ** 2``."""
    activation = g(x)
    return 1 - activation * activation

def center(X):
    """Return a copy of ``X`` (as an ndarray) with each row's mean subtracted."""
    samples = np.array(X)
    row_means = np.mean(samples, axis=1, keepdims=True)
    return samples - row_means

def whitening(X, eps=1e-12):
    """Whiten ``X`` so that its rows are uncorrelated with unit variance.

    The row covariance is eigendecomposed as ``E diag(d) E^T`` and the result is
    ``E diag(d ** -0.5) E^T X``.

    Args:
        X: 2-D array whose rows are the variables (same layout ``np.cov`` expects).
        eps: relative floor for the eigenvalues. Eigenvalues below
            ``eps * max(d)`` are raised to that floor before the inverse square
            root is taken.

    Returns:
        Array with the same shape as ``X``.
    """
    # atleast_2d: np.cov returns a 0-d array for a single-row input, which eigh rejects.
    cov = np.atleast_2d(np.cov(X))
    eigvals, eigvecs = np.linalg.eigh(cov)

    # A rank-deficient covariance has eigenvalues that are zero or, through
    # round-off, slightly negative; inverting / square-rooting those directly
    # raises LinAlgError or fills the output with NaN/inf.
    largest = eigvals.max()
    floor = eps * largest if largest > 0 else eps
    inv_sqrt = 1.0 / np.sqrt(np.maximum(eigvals, floor))

    # Scaling the columns of E is E @ diag(inv_sqrt) without building the dense
    # diagonal matrix or calling a general matrix inverse on it.
    return np.dot(eigvecs * inv_sqrt, np.dot(eigvecs.T, X))

def calculate_new_w(w, X):
    """Apply one fixed-point update to the unmixing vector ``w`` and normalise it.

    The update is ``mean(X * g(w^T X)) - mean(g'(w^T X)) * w`` (means over the
    sample axis), divided by its Euclidean norm.
    """
    projection = np.dot(w.T, X)
    updated = (X * g(projection)).mean(axis=1) - g_der(projection).mean() * w
    norm = np.sqrt((updated ** 2).sum())
    updated /= norm
    return updated

def ica(X, iterations, tolerance=1e-5):
    """Estimate independent components of ``X`` one unmixing vector at a time.

    ``X`` is centred and whitened, then one unmixing vector is fitted per row of
    ``X``. Each vector starts from ``np.random.rand`` and is updated with
    ``calculate_new_w`` for at most ``iterations`` steps, stopping early once
    ``| |w . w_new| - 1 |`` drops below ``tolerance``.

    Returns:
        ``W @ X_whitened``, where the rows of ``W`` are the fitted vectors.
    """
    X = whitening(center(X))
    n_components = X.shape[0]

    W = np.zeros((n_components, n_components), dtype=X.dtype)
    for row in range(n_components):
        w = np.random.rand(n_components)

        for _ in range(iterations):
            candidate = calculate_new_w(w, X)
            if row > 0:
                # Remove the projections onto the vectors that were already fitted.
                fitted = W[:row]
                candidate -= np.dot(np.dot(candidate, fitted.T), fitted)

            converged = np.abs(np.abs((w * candidate).sum()) - 1) < tolerance
            w = candidate
            if converged:
                break

        W[row, :] = w

    return np.dot(W, X)

# **Architecture of Bi-LSTM**

In [None]:
# BiLSTM for EEG features encoding 95% accuracy
class BiLSTM(nn.Module):
    """Conv1d embedding followed by a stacked (Bi)LSTM and a two-layer classifier head.

    Args:
        record_length: accepted for interface compatibility; not used by the model.
        input_size: number of input channels per time step.
        nodes_num: Conv1d output channels, LSTM hidden size and feature width.
        feat_num: accepted for interface compatibility; not used by the model.
        num_classes: number of output classes.
        n_layers: number of stacked LSTM layers.
        bidirectional: whether the LSTM runs in both directions.
        dropout: dropout probability, used between LSTM layers and after the
            feature ReLU.
    """

    def __init__(self, record_length, input_size, nodes_num, feat_num, num_classes, n_layers, bidirectional, dropout):
        super().__init__()

        self.bidirectional = bidirectional
        num_directions = 2 if bidirectional else 1

        # Sequence input layer. BatchNorm normalises the Conv1d *output*, so it
        # must be sized with nodes_num (sizing it with input_size only works
        # when the two happen to be equal).
        self.embedding = torch.nn.Sequential(
            nn.Conv1d(input_size, nodes_num, 3, padding=1, bias=False),
            nn.BatchNorm1d(nodes_num),
            nn.ReLU(inplace=True))

        # Stacked (Bi)LSTM network
        self.lstm = nn.LSTM(nodes_num,
                            nodes_num,
                            num_layers=n_layers,
                            bidirectional=bidirectional,
                            batch_first=True,
                            dropout=dropout,
                            bias=True)

        # Feature layer: consumes one final hidden state per direction.
        self.fc1 = nn.Linear(nodes_num * num_directions, nodes_num)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

        # Classification layer
        self.fc2 = nn.Linear(nodes_num, num_classes)

    def forward(self, x, x_len):
        """Encode a batch of sequences.

        Args:
            x: tensor of shape (batch, time, input_size).
            x_len: per-sample sequence lengths, as required by
                ``pack_padded_sequence`` (sorted in decreasing order).

        Returns:
            ``(features_out, class_out)`` with shapes (batch, nodes_num) and
            (batch, num_classes); ``class_out`` holds unnormalised logits.
        """
        # Conv1d expects channels before time: (B, T, C) -> (B, C, T)
        x2 = x.permute(0, 2, 1)

        embedded = self.embedding(x2)            # (B, nodes_num, T)
        embedded = embedded.permute(0, 2, 1)     # (B, T, nodes_num) for the batch_first LSTM

        packed_embedded = pack_padded_sequence(embedded, x_len, batch_first=True)
        _, (hidden, _) = self.lstm(packed_embedded)

        # `hidden` is (num_layers * num_directions, B, nodes_num); the last
        # entries belong to the top layer (forward, then backward direction).
        if self.bidirectional:
            summary = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
        else:
            summary = hidden[-1, :, :]

        features_out = self.dropout(self.relu(self.fc1(summary)))

        # Classify from the regularised features. Dropout must not be applied
        # to the logits themselves: it would zero random class scores during
        # training and corrupt the cross-entropy loss.
        class_out = self.fc2(features_out)

        return features_out, class_out

# **Data pre-processing and dataset**

In [None]:
# Load data and create dataset for each split
def split_train_val_test(data_dict, splits=(0.8, 0.1, 0.1), seed=None):
    """Split the recordings into train/val/test sets, keeping each image in one set.

    All entries that share an ``'image'`` ID are assigned to the same split.

    Args:
        data_dict: mapping whose ``'dataset'`` entry is a list of dicts, each
            with an ``'image'`` key.
        splits: ``(train, val, test)`` fractions. Only the val and test
            fractions are used; the training set receives every remaining image.
        seed: optional seed for a private shuffle. When ``None`` the global
            ``random`` state is used, so ``random.seed(...)`` still controls
            the result.

    Returns:
        ``(train_data, val_data, test_data)``: dicts mapping image ID to the
        list of dataset entries recorded for that image.

    Raises:
        ValueError: if a fraction is negative or val + test need more images
            than exist.
    """
    val_frac, test_frac = splits[1], splits[2]
    if val_frac < 0 or test_frac < 0:
        raise ValueError("split fractions must be non-negative")

    # Group by image in first-seen order. Going through a set would make the
    # order (and therefore the split) depend on hash ordering, which is not
    # stable across interpreter runs for string IDs.
    data_by_image = {}
    for entry in data_dict['dataset']:
        data_by_image.setdefault(entry['image'], []).append(entry)

    image_ids = list(data_by_image)
    if seed is None:
        random.shuffle(image_ids)
    else:
        random.Random(seed).shuffle(image_ids)

    n_images = len(image_ids)
    val_len, test_len = int(val_frac * n_images), int(test_frac * n_images)
    train_len = n_images - val_len - test_len
    if train_len < 0:
        raise ValueError("val and test fractions exceed the available images")

    train_ids = image_ids[:train_len]
    test_ids = image_ids[train_len:train_len + test_len]
    val_ids = image_ids[train_len + test_len:]

    def _subset(ids):
        return {image_id: data_by_image[image_id] for image_id in ids}

    return _subset(train_ids), _subset(val_ids), _subset(test_ids)


# General data set for EEG data dictionary
class EEGDataSet(data.Dataset):
    """
    Dataset over the entries of a split dict (image ID -> list of entries).

    ``x_label`` / ``y_label`` choose which entry field is returned as input and
    target; possible labels: eeg, label (ID), image (ID), subject.
    For the ``'eeg'`` field, ``interval`` selects which section of the ~500ms
    signal is returned, and the slice is transposed so the sliced axis comes first.
    """
    def __init__(self, data_dict, x_label='eeg', y_label='label', interval=(20, 460)):
        # Flatten the per-image lists into a single list of entries.
        self.data_list = []
        for entries in data_dict.values():
            self.data_list.extend(entries)
        self.length = len(self.data_list)
        self.interval = interval
        self.x_label = x_label
        self.y_label = y_label

    def __len__(self):
        return self.length

    def _fetch(self, entry, label):
        """Return ``entry[label]`` as a tensor; EEG is cropped to the interval and transposed."""
        value = entry[label]
        if label == 'eeg':
            cropped = value[:, self.interval[0]:self.interval[1]]
            return torch.transpose(cropped, 0, 1)
        return torch.as_tensor(value)

    def __getitem__(self, index):
        entry = self.data_list[index]
        x = self._fetch(entry, self.x_label).float()
        y = self._fetch(entry, self.y_label).long()
        return x, y


# Split the loaded recordings once, using the default split fractions;
# the three dicts are wrapped in datasets in the following cell.
train_data, val_data, test_data = split_train_val_test(data_dict)

# **Load data into Data-loader**

In [None]:
def _loader_kwargs(shuffle):
    """DataLoader keyword arguments; worker processes are only requested when CUDA is available."""
    kwargs = dict(shuffle=shuffle, batch_size=batch_size)
    if cuda:
        kwargs['num_workers'] = 4
    return kwargs


# Training (shuffled)
train_dataset = EEGDataSet(train_data)
train_loader_args = _loader_kwargs(shuffle=True)
train_loader = data.DataLoader(train_dataset, **train_loader_args)

# Validation (fixed order)
val_dataset = EEGDataSet(val_data)
val_loader_args = _loader_kwargs(shuffle=False)
val_loader = data.DataLoader(val_dataset, **val_loader_args)

# Testing (fixed order)
test_dataset = EEGDataSet(test_data)
test_loader_args = _loader_kwargs(shuffle=False)
test_loader = data.DataLoader(test_dataset, **test_loader_args)

# **Set up the model**

In [None]:
# Build the network and move it to the GPU when one is available.
device = torch.device("cuda" if cuda else "cpu")
model = BiLSTM(
    record_length=record_length,
    input_size=input_size,
    nodes_num=num_hidden_nodes,
    feat_num=feat_num,
    num_classes=num_classes,
    n_layers=num_layers,
    bidirectional=bidirectional,
    dropout=dropout_prob,
)
model.to(device)

# Loss, optimiser, and a scheduler that halves the learning rate when the
# monitored metric stops improving.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=1)

print(model)

BiLSTM(
  (embedding): Sequential(
    (0): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (lstm): LSTM(128, 128, num_layers=2, batch_first=True, dropout=0.5, bidirectional=True)
  (fc1): Linear(in_features=256, out_features=128, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.5, inplace=False)
  (fc2): Linear(in_features=128, out_features=40, bias=True)
)


# **Training Loop**

In [None]:
def train_epoch(model, train_loader, criterion, optimizer):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: module called as ``model(batch, lengths)`` and returning
            ``(features, logits)``.
        train_loader: iterable of ``(inputs, target)`` batches; ``inputs`` is
            (batch, time, channels).
        criterion: loss applied to ``(logits, target)``.
        optimizer: optimiser that updates ``model``'s parameters.

    Returns:
        ``(mean loss per batch, accuracy in percent)``.
    """
    print("Training...")
    model.train()

    # Run on whatever device the model lives on instead of relying on a
    # notebook-level `device` global.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    running_loss = 0.0
    total_predictions = 0.0
    correct_predictions = 0.0

    start_time = time.time()
    for inputs, target in train_loader:

        optimizer.zero_grad()
        inputs = inputs.to(model_device)
        target = target.to(model_device)

        # Every sample spans the full time axis. The length handed to
        # pack_padded_sequence must be the sequence length (dim 1), not the
        # batch size; the two only coincide when batch_size == record_length.
        seq_lengths = [inputs.shape[1]] * inputs.shape[0]

        features, outputs = model(inputs, seq_lengths)   # Run data through the model
        loss = criterion(outputs, target)                # Compare output to target
        running_loss += loss.item()                      # Add loss to the total

        loss.backward()                                  # Backpropagate the loss
        optimizer.step()                                 # Update the weights & biases

        # Vectorised accuracy: one reduction per batch instead of one
        # host/device round-trip per sample.
        correct_predictions += (outputs.argmax(dim=1) == target).sum().item()
        total_predictions += target.size(0)

    end_time = time.time()

    # Guard against an empty loader so the epoch reports zeros instead of crashing.
    running_loss /= max(len(train_loader), 1)
    print('Training Loss: ', running_loss, ' Time: ', end_time-start_time, 's')
    acc = (correct_predictions/total_predictions)*100.0 if total_predictions else 0.0
    print('Training Accuracy: ', acc, '%')
    return running_loss, acc

# **Validation Loop**

In [None]:
def val_model(model, val_loader, criterion):
    """Evaluate ``model`` on ``val_loader`` without updating it.

    Args:
        model: module called as ``model(batch, lengths)`` and returning
            ``(features, logits)``.
        val_loader: iterable of ``(inputs, target)`` batches; ``inputs`` is
            (batch, time, channels).
        criterion: loss applied to ``(logits, target)``.

    Returns:
        ``(mean loss per batch, accuracy in percent)``.
    """
    print("Validating...")
    with torch.no_grad():
        model.eval()

        # Run on whatever device the model lives on instead of relying on a
        # notebook-level `device` global.
        first_param = next(model.parameters(), None)
        model_device = first_param.device if first_param is not None else torch.device("cpu")

        running_loss = 0.0
        total_predictions = 0.0
        correct_predictions = 0.0

        for inputs, target in val_loader:

            inputs = inputs.to(model_device)
            target = target.to(model_device)

            # Sequence length is dim 1 of the batch, not the batch size; the
            # two only coincide when batch_size == record_length.
            seq_lengths = [inputs.shape[1]] * inputs.shape[0]

            features, outputs = model(inputs, seq_lengths)   # Run data through the model
            loss = criterion(outputs, target)                # Compare output to target
            running_loss += loss.item()                      # Add loss to the total

            # Vectorised accuracy: one reduction per batch.
            correct_predictions += (outputs.argmax(dim=1) == target).sum().item()
            total_predictions += target.size(0)

        # Guard against an empty loader so validation reports zeros instead of crashing.
        running_loss /= max(len(val_loader), 1)
        acc = (correct_predictions/total_predictions)*100.0 if total_predictions else 0.0
        print('Validating Loss: ', running_loss)
        print('Validating Accuracy: ', acc, '%')
        return running_loss, acc

# **Train & Validate**

In [None]:
# Per-epoch history, plotted in the cells below.
Train_loss  = []
Train_acc   = []
Val_loss    = []
Val_acc     = []
trialNumber = 3

# Create the checkpoint folder up front: torch.save does not create missing
# parent directories, so without this the run would abort when the first
# checkpoint is written, after a full epoch of training.
results_dir = 'drive/MyDrive/dl_project_data/results/run_' + str(trialNumber)
os.makedirs(results_dir, exist_ok=True)

# train and get results
for i in range(num_epochs):
    print("Epoch Number: ", i)
    t_loss, t_acc = train_epoch(model, train_loader, criterion, optimizer)
    v_loss, v_acc = val_model(model, val_loader, criterion)
    scheduler.step(v_loss)  # ReduceLROnPlateau monitors the validation loss

    # Add results
    Train_loss.append(t_loss)
    Val_loss.append(v_loss)
    Train_acc.append(t_acc)
    Val_acc.append(v_acc)

    # Save model (epoch)
    torch.save(model.state_dict(), os.path.join(results_dir, 'model_epoch' + str(i) + '.pt'))
    print("")
    print("==="*20)
    print("")

# **Plotting the results**

In [None]:

# Plot the training loss recorded after each epoch
plt.title('Training Loss')
plt.xlabel('Epoch Number')
plt.ylabel('Loss')
plt.plot(Train_loss)

In [None]:

# Plot the validation loss recorded after each epoch
plt.title('Validating Loss')
plt.xlabel('Epoch Number')
plt.ylabel('Loss')
plt.plot(Val_loss)


In [None]:

# Plot the training accuracy (in percent) recorded after each epoch
plt.title('Training Accuracy')
plt.xlabel('Epoch Number')
plt.ylabel('Accuracy (%)')
plt.plot(Train_acc)

In [None]:
# Plot the validation accuracy (in percent) recorded after each epoch
plt.title('Validation Accuracy')
plt.xlabel('Epoch Number')
plt.ylabel('Accuracy (%)')
plt.plot(Val_acc)