In [5]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

### Test Data

<h4>Test Data Generation Parameters</h4>
<ul>
    <li>n_samples: 2</li>
    <li>n_channels: 12</li>
    <li>n_timeStep: 24</li>
    <li>encoding_len: 5</li>
</ul>

In [2]:
def create_test_data(n_samples = 1, n_channels = 12, n_timeStep = 24, n_classes = 4):
    """Generate a random synthetic batch for smoke-testing the pipeline.

    Layout of X is N x C x t (samples x channels x time steps).

    Args:
        n_samples: number of samples N.
        n_channels: number of channels C per sample.
        n_timeStep: number of time steps t per channel.
        n_classes: number of label classes.

    Returns:
        (X, y): X is a standard-normal float tensor of shape (N, C, t); y is an
        int64 tensor of shape (N,) with labels in [0, n_classes), taken as the
        argmax of an (N, n_classes) standard-normal draw.
    """
    features = torch.randn((n_samples, n_channels, n_timeStep))
    label_scores = torch.randn(n_samples, n_classes)
    labels = label_scores.argmax(dim=1)
    return features, labels


In [3]:
# 100 synthetic samples; channels, time steps and classes use the generator defaults (12, 24, 4).
X_test, y_test = create_test_data(100)
print(f"X_test data shape: {X_test.shape} & y_test shape: {y_test.shape}")

X_test data shape: torch.Size([100, 12, 24]) & y_test shape: torch.Size([100])


### Custom Dataset

In [4]:
class eeg_dataset(Dataset):
    """Map-style dataset pairing pre-loaded samples with their class labels.

    Args:
        X: indexable collection of samples (in this notebook, a tensor of shape (N, C, t)).
        y: indexable collection of labels, expected to align with X along the first axis.
    """

    def __init__(self, X, y):
        self.data = X        # samples, indexed along the first axis
        self.class_id = y    # labels aligned with self.data

    def __len__(self):
        """Number of samples, i.e. len(X)."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the (sample, label) pair at position idx."""
        sample = self.data[idx]
        label = self.class_id[idx]
        return sample, label

In [5]:
eeg_data = eeg_dataset(X=X_test, y=y_test)
# Default DataLoader settings apart from batch size: no shuffling, last partial batch kept.
train_dataloader = DataLoader(dataset=eeg_data, batch_size=32)

### Regional Level Information Extraction module

<h4>Testing Parameters</h4>

In [6]:
# Channel indices (into the channel axis of the data) used by regional_info_extraction.
# left_hms[k] is paired position-wise with right_hms[k]; each pair yields one difference feature.
left_hms = [1, 3, 6]    # left-hemisphere channels
right_hms = [0, 2, 5]   # right-hemisphere channels, paired with left_hms by position
middle_hms = [8, 9]     # channels used as-is, without differencing (presumably midline electrodes)

In [26]:
def regional_info_extraction(data, left_hms = left_hms, right_hms = right_hms, middle_hms = middle_hms, verbose = False):
    """Build per-time-step regional features from a batch of multi-channel signals.

    For each sample and time step the feature vector is
    [x[l_0] - x[r_0], ..., x[l_{P-1}] - x[r_{P-1}], x[m_0], ..., x[m_{M-1}]]:
    the left-minus-right difference of every paired channel, followed by the
    raw values of the middle channels.

    Args:
        data: tensor of shape (N, C, T) -- samples, channels, time steps.
        left_hms: left channel indices; paired position-wise with right_hms.
        right_hms: right channel indices.
        middle_hms: channel indices appended without differencing.
        verbose: if True, print intermediate shapes (debug aid). Off by default
            because this function runs on every training/evaluation batch.

    Returns:
        Tensor of shape (N, T, P + M) with P = len(left_hms), M = len(middle_hms).
        Time is on axis 1, matching a batch_first LSTM input.

    Raises:
        ValueError: if data is not 3-D or left_hms and right_hms differ in length.
    """
    if data.ndim != 3:
        raise ValueError(f"expected data of shape (N, C, T), got {tuple(data.shape)}")
    left_idx, right_idx, mid_idx = list(left_hms), list(right_hms), list(middle_hms)
    if len(left_idx) != len(right_idx):
        # zip() would silently drop the unpaired channels instead of failing.
        raise ValueError(
            f"left_hms and right_hms must have the same length, got {len(left_idx)} and {len(right_idx)}"
        )

    # (N, P, T) pairwise differences -> (N, T, P)
    D = (data[:, left_idx, :] - data[:, right_idx, :]).permute(0, 2, 1)
    # (N, M, T) middle channels -> (N, T, M)
    S = data[:, mid_idx, :].permute(0, 2, 1)
    X = torch.cat((D, S), dim = 2)

    if verbose:
        print(data.shape)
        print(f"no. of iter: {len(left_idx)}")
        print(f"D shape: {D.shape} & D dimension {D.ndim}")
        print(f"S shape: {S.shape} & S dimension {S.ndim}")
        print(f"X shape: {X.shape} & X dimension {X.ndim}")
    return X
    

In [8]:
# X = regional_info_extraction(test_data)

### Feature Encoding module

<h4>Testing Parameters</h4>
<ul>
    <li>input_size: 25</li>
    <li>num_layer: 2</li>
    <li>n_nodes (hidden_size): 5</li>
    <li>LSTM config: BiLSTM</li>
</ul>

In [9]:
class Feature_Encoding(nn.Module):
    """Stacked (Bi)LSTM encoder followed by a linear projection and ReLU.

    Expects batch-first input of shape (N, T, input_size) and returns an
    (N, out_features) encoding computed from the LSTM output at the last time step.

    Args:
        input_size: number of features per time step.
        out_features: size of the returned encoding.
        num_layer: number of stacked LSTM layers.
        hidden_size: LSTM hidden units per direction.
        lstm_config: "BiLSTM" for a bidirectional LSTM; any other value gives a
            unidirectional one.
    """

    def __init__(
        self,
        input_size = 25,
        out_features = 5,
        num_layer = 2,
        hidden_size = 5,
        lstm_config = "BiLSTM"
    ):
        # nn.Module.__init__ takes no extra positional args; passing `self` again raises TypeError.
        super().__init__()
        bidirectional = lstm_config == "BiLSTM"
        self.Stack_BiLstm_layer = nn.LSTM(input_size = input_size,
                                          hidden_size = hidden_size,
                                          num_layers = num_layer,
                                          bias = True,
                                          batch_first = True,
                                          dropout=0.0,
                                          bidirectional = bidirectional)
        # A bidirectional LSTM concatenates both directions, so each step's output is 2 * hidden_size wide.
        num_directions = 2 if bidirectional else 1
        self.fcn_layer = nn.Linear(in_features = hidden_size * num_directions, out_features = out_features, bias=True)
        self.activation_layer = nn.ReLU()

    def forward(self, x):
        """Encode a (N, T, input_size) batch into (N, out_features) non-negative features."""
        # nn.LSTM returns (output, (h_n, c_n)); only the per-step output sequence is used here.
        lstm_out, _ = self.Stack_BiLstm_layer(x)
        # NOTE(review): in the reverse direction the last step has seen a single input;
        # h_n may be a better sequence summary -- confirm the intended design.
        x = self.fcn_layer(lstm_out[:, -1, :])
        return self.activation_layer(x)
        

In [10]:
# Encoder with its default configuration, spelled out explicitly.
model = Feature_Encoding(input_size=25, out_features=5, num_layer=2, hidden_size=5, lstm_config="BiLSTM")

### Classification Module

In [151]:
import torch.nn.functional as F

In [153]:
class Classification_model(nn.Module):
    """Linear classification head returning per-class softmax probabilities.

    Args:
        input_size: width of the incoming feature vector.
        n_classes: number of output classes (defaults to 40, the previously hard-coded value).
    """

    def __init__(self, input_size, n_classes = 40):
        # nn.Module.__init__ takes no extra positional args; passing `self` again raises TypeError.
        super().__init__()
        self.fcn = nn.Linear(in_features = input_size, out_features = n_classes)

    def forward(self, x):
        """Map (N, input_size) features to (N, n_classes) probabilities that sum to 1 per row."""
        logits = self.fcn(x)
        # NOTE(review): nn.CrossEntropyLoss expects raw logits; if this output feeds it,
        # return `logits` instead to avoid applying softmax twice.
        return F.softmax(logits, dim = 1)

### Optimizer

In [13]:
from torch.optim import lr_scheduler
from torch import optim 

In [15]:
# Adam over the encoder's parameters; each scheduler.step() multiplies the learning rate by 0.9.
optimizer = optim.Adam(params=model.parameters(), lr=0.05)
scheduler = lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.9)

### Training

In [37]:
def train_step(model, dataloader, optim, loss_fn, device):
    """Run one training epoch and return the mean per-batch loss and accuracy.

    Each (X, y) batch is moved to `device`, passed through
    regional_info_extraction and then `model`; the optimizer steps once per batch.

    Args:
        model: module assumed to map extracted features to per-class scores of shape (B, n_classes).
        dataloader: iterable of (X, y) batches (e.g. a torch DataLoader).
        optim: optimizer over model's parameters (shadows the torch.optim module locally).
        loss_fn: callable (yhat, y) -> scalar loss tensor.
        device: device the batches are moved to.

    Returns:
        (train_loss, train_acc) as Python floats averaged over batches,
        or (0.0, 0.0) if the dataloader yields no batches.
    """
    train_loss, train_acc = 0.0, 0.0
    n_batches = 0
    model.train()
    # A DataLoader is iterated directly; it is not callable.
    for X, y in dataloader:
        X = X.to(device)
        y = y.to(device)
        X = regional_info_extraction(X)

        yhat = model(X)
        loss = loss_fn(yhat, y)

        optim.zero_grad()
        loss.backward()
        optim.step()

        # .item() gives a plain float; summing the tensor would keep each batch's autograd graph alive.
        train_loss += loss.item()
        # softmax is monotonic, so argmax over the raw scores gives the same predicted class.
        yhat_classes = yhat.argmax(dim=1)
        train_acc += (yhat_classes == y).sum().item() / len(yhat)
        n_batches += 1

    if n_batches == 0:
        return 0.0, 0.0
    return train_loss / n_batches, train_acc / n_batches

In [38]:
def test_step(model, dataloader, loss_fn, device):
    """Evaluate `model` over one pass of `dataloader` without updating its weights.

    Each (X, y) batch is moved to `device`, passed through
    regional_info_extraction and then `model`, under torch.inference_mode().

    Args:
        model: module assumed to map extracted features to per-class scores of shape (B, n_classes).
        dataloader: iterable of (X, y) batches (e.g. a torch DataLoader).
        loss_fn: callable (yhat, y) -> scalar loss tensor.
        device: device the batches are moved to.

    Returns:
        (test_loss, test_acc) as Python floats averaged over batches,
        or (0.0, 0.0) if the dataloader yields no batches.
    """
    test_loss, test_acc = 0.0, 0.0
    n_batches = 0
    model.eval()
    with torch.inference_mode():
        # A DataLoader is iterated directly; it is not callable.
        for X, y in dataloader:
            X = X.to(device)
            y = y.to(device)
            X = regional_info_extraction(X)

            yhat = model(X)
            loss = loss_fn(yhat, y)
            test_loss += loss.item()

            # softmax is monotonic, so argmax over the raw scores gives the same predicted class.
            yhat_classes = yhat.argmax(dim=1)
            test_acc += (yhat_classes == y).sum().item() / len(yhat)
            n_batches += 1

    if n_batches == 0:
        return 0.0, 0.0
    return test_loss / n_batches, test_acc / n_batches

In [39]:
def train(
    model,
    optimizer,
    scheduler,
    train_dataloader,
    test_dataloader,
    epochs,
    loss_fn,
    device
    ):
    """Train for `epochs` epochs, evaluating after each one.

    Per epoch: one train_step, one test_step, record the four metrics, step the
    LR scheduler, then print a summary line. The printed lr is read after
    scheduler.step(), so it is the rate that the next epoch will use.

    Returns:
        dict with keys "train_loss", "train_acc", "test_loss", "test_acc", each
        mapping to a list with one entry per epoch.
    """
    history = {key: [] for key in ("train_loss", "train_acc", "test_loss", "test_acc")}
    for epoch in range(epochs):
        train_loss, train_acc = train_step(model, train_dataloader, optimizer, loss_fn, device)
        test_loss, test_acc = test_step(model, test_dataloader, loss_fn, device)

        epoch_metrics = {
            "train_loss": train_loss,
            "train_acc": train_acc,
            "test_loss": test_loss,
            "test_acc": test_acc,
        }
        for key, value in epoch_metrics.items():
            history[key].append(value)

        scheduler.step()
        print(
        f"Epoch: {epoch+1} | "
        f"train_loss: {train_loss:.4f} | "
        f"train_acc: {train_acc:.4f} | "
        f"test_loss: {test_loss:.4f} | "
        f"test_acc: {test_acc:.4f} | "
        f"lr {optimizer.param_groups[0]['lr']:.4f}"
        )
    return history

        

In [18]:
from pathlib import Path
import os
import numpy as np

In [19]:
file_name = "data(5-95).npy"  # data file read from the local `data` directory

In [20]:
# Resolved against the kernel's current working directory, so it depends on where Jupyter was launched.
file_path = os.path.join(os.getcwd(), "data", file_name)

'C:\\Users\\Abhishek Rathore\\Desktop\\Paper & Code\\EEG_Imgcls_BiLSTM\\src\\data\\data(5-95).npy'

In [None]:
# Load the array and show it as this cell's output (it is not assigned to a variable).
np.load(file_path)

In [None]:
class EEG_dataset(Dataset):
    """Dataset backed by a .npy file located at `<cwd>/<dir>/<file_name>` (work in progress).

    Args:
        dir: directory containing the file, relative to the current working directory.
        file_name: name of the .npy file. Required, even though it has a None default.

    Raises:
        TypeError: if file_name is not given.
    """
    def __init__(self, dir = 'data', file_name = None):
        # file_name defaults to None only so that it can follow the defaulted `dir`
        # (a required parameter after a defaulted one is a SyntaxError); it is still mandatory.
        if file_name is None:
            raise TypeError("EEG_dataset() missing required argument: 'file_name'")
        # `dir` shadows the builtin of the same name; it is the author's chosen parameter name.
        self.file_path = os.path.join(os.getcwd(), dir, file_name)