## Classification of preprocessed ictal and preictal multi-channel EEG data (CHB-MIT) using a 2-layer LSTM developed from scratch

In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import pandas as pd
from sklearn.preprocessing import StandardScaler
import torch.utils.data as data
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report
import time

In [None]:
# Mount Google Drive so the CSV files under /content/drive/MyDrive/data can be read.
import google.colab.drive as drive

drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Device configuration: use the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
def create_dataset(dataset, windowlen):
    """Standardise a multichannel recording and cut it into non-overlapping windows.

    Each column of ``dataset`` is z-scored with a ``StandardScaler`` fitted on
    ``dataset`` itself. The rows are then grouped into consecutive,
    non-overlapping windows of ``windowlen`` rows. Trailing rows that do not fill
    a whole window are dropped.

    Args:
        dataset: 2-D array-like (e.g. a DataFrame). The first dimension is time
            steps and the second is channels.
        windowlen: Number of time steps per window. Must be a positive integer.

    Returns:
        A float32 tensor of shape ``(n_windows, windowlen, n_channels)`` with
        ``n_windows = len(dataset) // windowlen`` (possibly 0).

    Raises:
        ValueError: If ``windowlen`` is not positive.
    """
    if windowlen <= 0:
        raise ValueError(f"windowlen must be a positive integer, got {windowlen!r}")

    # NOTE(review): the scaler is fitted on the whole recording, i.e. before any
    # train/test split. data_generator also calls this separately per class.
    # Test-set statistics therefore leak into the features, and per-class
    # mean/scale differences are normalised away. Confirm this is intended.
    signal = StandardScaler().fit_transform(dataset)

    # Integer division: only complete windows are kept.
    n_windows = len(signal) // windowlen
    # Drop the incomplete tail, then view the rows as (window, step, channel).
    # signal.shape[1] is used instead of -1 so the n_windows == 0 case still works.
    windows = signal[: n_windows * windowlen].reshape(n_windows, windowlen, signal.shape[1])
    return torch.tensor(windows).float()

def data_generator(batch_size, windowlen, data_dir='/content/drive/MyDrive/data',
                   test_size=0.25, random_state=42):
    """Load the CHB-MIT ictal/preictal CSVs and wrap them in train/test DataLoaders.

    Each CSV is windowed with ``create_dataset``. Preictal windows get label 0
    and ictal windows get label 1. The pooled windows are split with
    ``train_test_split`` (shuffled, reproducible through ``random_state``).

    Args:
        batch_size: Mini-batch size for both loaders.
        windowlen: Time steps per window, passed to ``create_dataset``.
        data_dir: Directory containing ``preictal_data.csv`` and ``ictal_data.csv``.
        test_size: Fraction of windows held out for testing.
        random_state: Seed for the train/test split.

    Returns:
        ``(train_loader, test_loader)``. Each batch is ``(windows, labels)`` and
        the labels are float tensors of shape ``(batch, 1)``.
    """
    print('Loading CHB-MIT ictal and preictal dataset...')
    # Bug fix: these two paths used to be swapped. The "preictal" variable held
    # the ictal recording and received label 0, and vice versa.
    preictal_data = pd.read_csv(f'{data_dir}/preictal_data.csv')
    ictal_data = pd.read_csv(f'{data_dir}/ictal_data.csv')

    preictal_windows = create_dataset(preictal_data, windowlen=windowlen)
    preictal_labels = torch.zeros(preictal_windows.shape[0], 1)

    ictal_windows = create_dataset(ictal_data, windowlen=windowlen)
    ictal_labels = torch.ones(ictal_windows.shape[0], 1)

    windows = torch.cat((preictal_windows, ictal_windows), 0)
    labels = torch.cat((preictal_labels, ictal_labels), 0)

    X_train, X_test, y_train, y_test = train_test_split(
        windows, labels, test_size=test_size, shuffle=True, random_state=random_state)
    print(f" Shape of the Training data is {X_train.shape}, and Testing data is {X_test.shape}")

    train_loader = data.DataLoader(data.TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=True)
    # Evaluation order does not matter for the metrics; keep it fixed so the
    # collected predictions are reproducible between runs.
    test_loader = data.DataLoader(data.TensorDataset(X_test, y_test), batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

In [None]:
class Embedding(nn.Module):
    """Linear projection of the last dimension from ``d_dim`` to ``embed_dim`` features."""

    def __init__(self, d_dim, embed_dim):
        super().__init__()
        # The sizes are kept for introspection; the projection itself lives in `embed`.
        self.d_dim, self.embed_dim = d_dim, embed_dim
        self.embed = nn.Linear(in_features=d_dim, out_features=embed_dim)

    def forward(self, x):
        """Project the last dimension of ``x`` with the learned affine map."""
        projected = self.embed(x)
        return projected

In [None]:
window_length = 64   # time steps per EEG window
batch_size = 50      # windows per mini-batch
train_loader, test_loader = data_generator(batch_size=batch_size, windowlen=window_length)

Loading CHB-MIT Interical and preictal dataset...
 Shape of the Training data is (torch.Size([24574, 64, 23]),), and Testing data is torch.Size([8192, 64, 23])


In [None]:
# Pull one training batch only to read the input geometry.
# X_train has shape (batch, sequence_length, feature_length); see the shapes printed above.
X_train, label = next(iter(train_loader))
sequence_length = X_train.shape[1]   # time steps per window
feature_length = X_train.shape[2]    # channels per time step

# Model size
embed_dim = 40
hidden_length = 50

# Optimisation and logging
batch_size = 50
lr = 0.001
epochs = 30
log_interval = 10

# NOTE(review): the cells below do not reference these names; `lr` and `epochs`
# are the ones actually used. They are kept for compatibility.
num_classes = 1
num_epoch = 10
learning_rate = 0.001

In [None]:
class LSTM_Scratch(nn.Module):
    """
    A single-layer LSTM written out gate by gate.

    The input is batch-first, ``(batch, seq_len, feature_length)``. The recurrent
    state ``(h, c)`` is a pair of ``(batch, hidden_length)`` tensors. Gate
    equations (``*`` is element-wise):

        i_t = sigmoid(W2 x_t + b2 + R2 h_{t-1})     input gate
        f_t = sigmoid(W1 x_t + b1 + R1 h_{t-1})     forget gate
        k_t = tanh(W3 x_t + b3 + R3 h_{t-1})        candidate cell content
        c_t = f_t * c_{t-1} + i_t * k_t
        o_t = sigmoid(W4 x_t + b4 + R4 h_{t-1})     output gate
        h_t = o_t * tanh(c_t)
    """
    def __init__(self, feature_length, hidden_length):
        super(LSTM_Scratch, self).__init__()
        self.feature_length = feature_length
        self.hidden_length = hidden_length

        # forget gate components
        self.linear_forget_w1 = nn.Linear(self.feature_length, self.hidden_length, bias=True)
        self.linear_forget_r1 = nn.Linear(self.hidden_length, self.hidden_length, bias=False)
        self.sigmoid_forget = nn.Sigmoid()

        # input gate components
        self.linear_gate_w2 = nn.Linear(self.feature_length, self.hidden_length, bias=True)
        self.linear_gate_r2 = nn.Linear(self.hidden_length, self.hidden_length, bias=False)
        self.sigmoid_gate = nn.Sigmoid()

        # cell memory components
        self.linear_gate_w3 = nn.Linear(self.feature_length, self.hidden_length, bias=True)
        self.linear_gate_r3 = nn.Linear(self.hidden_length, self.hidden_length, bias=False)
        self.activation_gate = nn.Tanh()

        # out gate components
        self.linear_gate_w4 = nn.Linear(self.feature_length, self.hidden_length, bias=True)
        self.linear_gate_r4 = nn.Linear(self.hidden_length, self.hidden_length, bias=False)
        self.sigmoid_hidden_out = nn.Sigmoid()

        self.activation_final = nn.Tanh()

    def forget(self, x, h):
        """Forget gate f_t = sigmoid(W1 x + b1 + R1 h)."""
        x = self.linear_forget_w1(x)
        h = self.linear_forget_r1(h)
        return self.sigmoid_forget(x + h)

    def input_gate(self, x, h):
        """Input gate i_t = sigmoid(W2 x + b2 + R2 h)."""
        x_temp = self.linear_gate_w2(x)
        h_temp = self.linear_gate_r2(h)
        return self.sigmoid_gate(x_temp + h_temp)

    def cell_memory_gate(self, i, f, x, h, c_prev):
        """New cell state c_t = f * c_prev + i * tanh(W3 x + b3 + R3 h)."""
        x = self.linear_gate_w3(x)
        h = self.linear_gate_r3(h)

        # new information that will be injected into the cell state
        k = self.activation_gate(x + h)
        g = k * i

        # forget part of the old cell state, then add the new information
        c = f * c_prev
        c_next = g + c
        return c_next

    def out_gate(self, x, h):
        """Output gate o_t = sigmoid(W4 x + b4 + R4 h)."""
        x = self.linear_gate_w4(x)
        h = self.linear_gate_r4(h)
        return self.sigmoid_hidden_out(x + h)

    def forward(self, x_t, tuple_in):
        """Run the cell over every time step of ``x_t``.

        Args:
            x_t: Input sequence, ``(batch, seq_len, feature_length)``.
            tuple_in: Initial ``(h, c)``, each ``(batch, hidden_length)``.

        Returns:
            ``(hidden_seq, h_next, c_next)``. ``hidden_seq`` stacks every hidden
            state into ``(batch, seq_len, hidden_length)``; the other two are the
            final hidden and cell states. For an empty sequence the initial state
            is returned unchanged.
        """
        h, c = tuple_in
        hidden_seq = []
        # Bug fix: iterate over the input's own length rather than the notebook-level
        # `sequence_length` global, so any window length works.
        for t in range(x_t.size(1)):
            x = x_t[:, t, :]
            i = self.input_gate(x, h)                       # Equation 1. input gate
            f = self.forget(x, h)                           # Equation 2. forget gate
            c_next = self.cell_memory_gate(i, f, x, h, c)   # Equation 3. cell update
            o = self.out_gate(x, h)                         # Equation 4. output gate (uses h_{t-1})
            h = o * self.activation_final(c_next)           # Equation 5. next hidden state
            c = c_next
            hidden_seq.append(h)

        if hidden_seq:
            # (batch, seq_len, hidden_length), contiguous
            hidden_seq = torch.stack(hidden_seq, dim=1)
        else:
            hidden_seq = h.new_zeros(h.size(0), 0, self.hidden_length)

        return hidden_seq, h, c

In [None]:
class Net(nn.Module):
    """Binary classifier: linear embedding -> scratch LSTM -> linear -> sigmoid.

    Layer sizes come from the notebook-level ``feature_length``, ``embed_dim``
    and ``hidden_length`` at construction time. ``forward`` maps windows of shape
    ``(batch, seq_len, feature_length)`` to sigmoid outputs of shape
    ``(batch, 1)``, computed from the LSTM's final hidden state.
    """
    def __init__(self):
        super().__init__()
        self.embedding = Embedding(feature_length, embed_dim)
        self.lstm = LSTM_Scratch(embed_dim, hidden_length)
        self.fc1 = nn.Linear(hidden_length, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        xe = self.embedding(x)
        # The zero initial state takes x's device and dtype rather than the global
        # `device`, so the model keeps working after .to()/.cpu().
        h_0 = x.new_zeros(x.size(0), self.lstm.hidden_length)
        c_0 = x.new_zeros(x.size(0), self.lstm.hidden_length)

        # Only the final hidden state feeds the classifier; per-step outputs are unused.
        _, h_last, _ = self.lstm(xe, (h_0, c_0))

        out = self.fc1(h_last)
        return self.sigmoid(out)

In [None]:
# Instantiate the classifier and move its parameters to the selected device.
# Seed first so the random weight initialisation is reproducible across runs.
torch.manual_seed(42)
model = Net().to(device)


In [None]:
# Name the run after its key hyper-parameters; the log and checkpoint files share that name.
model_name = "Model_{}_dim_{}_lr_{}".format('LSTM', embed_dim, lr)
message_filename = f"r_{model_name}.txt"
model_filename = f"m_{model_name}.pt"

# Truncate (or create) this run's log file; output_s appends to it afterwards.
with open(message_filename, mode='w') as out:
    out.write('start\n')


def output_s(message, save_filename):
    """Print ``message`` and append it, followed by a newline, to ``save_filename``."""
    print(message)
    with open(save_filename, mode='a') as log:
        log.writelines((message, '\n'))

# Binary cross-entropy on the model's sigmoid outputs, minimised with Adam.
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)


def train(ep):
    """Run one training epoch over ``train_loader``.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion``, ``device``,
    ``train_loader`` and ``log_interval``. Class predictions come from rounding
    the model output, i.e. a 0.5 threshold.

    Args:
        ep: Epoch number, used only in the progress log.

    Returns:
        ``(accuracy_percent, mean_batch_loss)`` for the epoch. For an empty
        loader this is ``(0.0, nan)``.
    """
    model.train()

    running_loss = 0.0
    correct = 0
    seen = 0
    n_batches = 0

    for batch_idx, (inputs, target) in enumerate(train_loader):
        inputs, target = inputs.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(inputs)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # Accumulate with .item(): summing the tensor itself kept every batch's
        # autograd graph alive for the whole epoch.
        running_loss += loss.item()
        n_batches += 1
        correct += (output.round() == target).sum().item()
        # Count real samples, since the last batch may be smaller than batch_size.
        seen += target.size(0)

        if batch_idx > 0 and batch_idx % log_interval == 0:
            print("Train Epoch: {} [{}/{} ({:.2f}%)]\tLoss: {:.2f} \t Acc: {:.2f}".format(
                ep, seen, len(train_loader.dataset),
                100. * batch_idx / len(train_loader), running_loss / n_batches,
                100. * correct / seen))

    if n_batches == 0:
        return 0.0, float('nan')
    # Mean loss over batches (previously the sum was divided by batch_size).
    return 100. * correct / len(train_loader.dataset), running_loss / n_batches


def test():
    """Evaluate ``model`` on ``test_loader`` without gradient tracking.

    Uses the notebook-level ``model``, ``criterion``, ``device`` and
    ``test_loader``. Predictions come from rounding the model output.

    Returns:
        ``(targets, preds, accuracy_percent, mean_loss)``:
        - ``targets`` and ``preds`` are lists with one small numpy array per test
          window, in loader order.
        - ``mean_loss`` is the per-sample average of the criterion as a float
          (``nan`` for an empty test set).
    """
    model.eval()

    targets = list()
    preds = list()
    loss_sum = 0.0
    correct = 0
    n_samples = len(test_loader.dataset)

    with torch.no_grad():
        for inputs, target in test_loader:
            inputs, target = inputs.to(device), target.to(device)
            output = model(inputs)
            # criterion (BCELoss, mean reduction) returns a batch mean. Weight it by
            # batch size so dividing by the dataset size gives a per-sample mean.
            loss_sum += criterion(output, target).item() * target.size(0)
            pred = output.round()
            correct += (pred == target).sum().item()
            targets += list(target.cpu().numpy())
            preds += list(pred.cpu().numpy())

    Acc = 100. * correct / n_samples if n_samples else 0.0
    test_loss = loss_sum / n_samples if n_samples else float('nan')
    print('\nTest set: Average loss: {:.3f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, n_samples, Acc))
    return targets, preds, Acc, test_loss

#model_total_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
# In[112]:

In [None]:
if __name__ == "__main__":
    exec_time = 0
    # Per-epoch history read by the plotting cells below: `accuracy` holds the
    # test accuracy (%) and `losses` the training loss of each epoch.
    accuracy = []
    losses = []
    for epoch in range(1, epochs + 1):
        start = time.time()
        train_acc, train_loss = train(epoch)
        t = time.time() - start
        exec_time += t

        # test() returns (targets, preds, ...). The old unpacking swapped the two,
        # which transposed the confusion matrix and swapped precision/recall.
        targets, preds, test_acc, test_loss = test()
        test_loss = float(test_loss)  # accepts a Python float or a 0-dim tensor
        accuracy.append(test_acc)
        losses.append(train_loss)

        message = ('Train Epoch: {}, Train loss: {:.4f}, Time taken: {:.4f}, Train Accuracy: {:.4f}, Test loss: {:.4f}, Test Accuracy: {:.4f}' .format(
                epoch, train_loss, t, train_acc, test_loss, test_acc))
        output_s(message, message_filename)

        # Step decay: divide the learning rate by 10 every 10 epochs.
        # NOTE: this rebinds the global `lr`; re-run the hyperparameter cell before re-training.
        if epoch % 10 == 0:
            lr /= 10
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        if epoch == epochs:
            print('Total Execution time for training:', exec_time)
            # Labels and predictions are collected as (1,)-shaped arrays; flatten for sklearn.
            y_true = np.array(targets).ravel()
            y_pred = np.array(preds).ravel()
            conf_mat = confusion_matrix(y_true, y_pred)
            disp = ConfusionMatrixDisplay(confusion_matrix=conf_mat)
            disp.plot()
            print(classification_report(y_true, y_pred, digits=4))


Test set: Average loss: 0.012, Accuracy: 5672/8192 (69.24%)

Train Epoch: 1, Train loss: 6.3704, Time taken: 47.1186, Train Accuracy: 63.0422, Test loss: 0.0119, Test Accuracy: 69.2383

Test set: Average loss: 0.013, Accuracy: 5216/8192 (63.67%)

Train Epoch: 2, Train loss: 5.9117, Time taken: 46.6705, Train Accuracy: 69.3416, Test loss: 0.0131, Test Accuracy: 63.6719

Test set: Average loss: 0.011, Accuracy: 5949/8192 (72.62%)

Train Epoch: 3, Train loss: 5.9827, Time taken: 44.7475, Train Accuracy: 68.0190, Test loss: 0.0112, Test Accuracy: 72.6196

Test set: Average loss: 0.010, Accuracy: 6415/8192 (78.31%)

Train Epoch: 4, Train loss: 5.1726, Time taken: 44.9652, Train Accuracy: 75.2747, Test loss: 0.0097, Test Accuracy: 78.3081

Test set: Average loss: 0.009, Accuracy: 6543/8192 (79.87%)

Train Epoch: 5, Train loss: 4.5434, Time taken: 47.3935, Train Accuracy: 79.4864, Test loss: 0.0091, Test Accuracy: 79.8706

Test set: Average loss: 0.005, Accuracy: 7374/8192 (90.01%)

Train Ep

KeyboardInterrupt: 

In [None]:
# Assumes `accuracy` holds one value per epoch, collected by the training loop.
# Epochs are numbered from 1, matching the training log.
fig, ax = plt.subplots()
ax.plot(range(1, len(accuracy) + 1), accuracy, marker='o')
ax.set_title('Accuracy Vs Epochs')
ax.set_xlabel('Epochs')
ax.set_ylabel('Accuracy')
plt.show()

In [None]:
# Assumes `losses` holds one value per epoch, collected by the training loop.
# Epochs are numbered from 1, matching the training log.
fig, ax = plt.subplots()
ax.plot(range(1, len(losses) + 1), losses, marker='o')
ax.set_title('Loss Vs Epochs')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
plt.show()

In [None]:
# Stand-alone accuracy check of the current model on the held-out windows.
# Switch to eval mode explicitly: training may have been interrupted mid-epoch,
# which leaves the model in train mode.
model.eval()
with torch.no_grad():
    n_correct = 0
    n_samples = 0
    for segments, labels in test_loader:
        segments = segments.to(device)
        labels = labels.to(device)

        # Rounding the model output gives the 0/1 class prediction (0.5 threshold).
        predicted_classes = model(segments).round()
        n_samples += labels.size(0)
        n_correct += (predicted_classes == labels).sum().item()

    acc = 100.0 * n_correct / n_samples if n_samples else 0.0
    print(f'Accuracy of the network for ictal and preictal class: {acc:.3f}%')