## Loading essential modules

In [9]:


import torch
import torch.nn as nn
import torchmetrics
import importlib

from torch.optim.lr_scheduler import StepLR

from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning import LightningModule,Trainer
from torch.utils.data import TensorDataset,DataLoader

import models
importlib.reload(models)
from models import ResNet1D, Bio, Conv1D_v2, EEGInceptionModel, ChronoNet
from dataset import EEG_inception

### Defining the LightningModule

In [2]:
class LModel(LightningModule):
    """LightningModule that trains one binary EEG classifier from a run-config dict.

    Args:
        attribute: run configuration. Keys read:
            - ``"model"``: the ``nn.Module`` to train; its output is flattened
              and treated as one logit per sample (see ``_shared_step``).
            - ``"lr"``: AdamW learning rate.
            - ``"train_dataset"`` / ``"val_dataset"``: datasets whose items
              unpack into ``(signal, label)``.
            - ``"batch_size"`` (optional, default 64),
              ``"num_workers"`` (optional, default 1),
              ``"weight_decay"`` (optional, default 0.0005).
    """

    def __init__(self, attribute):
        super().__init__()
        self.attribute = attribute
        self.model = attribute["model"]
        self.lr = attribute["lr"]
        self.bs = attribute.get("batch_size", 64)
        self.worker = attribute.get("num_workers", 1)
        self.weight_decay = attribute.get("weight_decay", 0.0005)
        # One metric object per stage: torchmetrics metrics carry internal
        # state, so a single instance shared by training and validation would
        # accumulate both stages together.
        self.train_acc = torchmetrics.Accuracy(task="binary")
        self.val_acc = torchmetrics.Accuracy(task="binary")
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        """AdamW plus a StepLR schedule halving the LR every 20 scheduler steps.

        Lightning steps a scheduler returned this way once per epoch by default.
        """
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    def train_dataloader(self):
        dataset = self.attribute["train_dataset"]
        return DataLoader(dataset, batch_size=self.bs, num_workers=self.worker, shuffle=True)

    def val_dataloader(self):
        dataset = self.attribute["val_dataset"]
        return DataLoader(dataset, batch_size=self.bs, num_workers=self.worker, shuffle=False)

    def _shared_step(self, batch, metric):
        """Return the BCE-with-logits loss for ``batch`` and update ``metric``.

        The model output and labels are both flattened, so the model must emit
        exactly one logit per sample.
        """
        signal, label = batch
        logits = self(signal.float()).flatten()
        target = label.flatten()
        loss = self.criterion(logits, target.float())
        # sigmoid(logit) > 0.5 is the decision rule for a single-logit binary model.
        preds = (torch.sigmoid(logits) > 0.5).long()
        # Call the metric (forward) rather than .update(): Lightning needs the
        # per-batch forward cache to log the step-level value.
        metric(preds, target.long())
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch, self.train_acc)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        # Logging the metric object lets Lightning compute/reset it at epoch end.
        self.log('train_acc', self.train_acc, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch, self.val_acc)
        self.log('val_loss', loss, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_acc', self.val_acc, on_epoch=True, prog_bar=True, logger=True)
        return loss


## Loading all dataset combinations

In [3]:
def _make_eeg_dataset(kind, **extra):
    """Build an ``EEG_inception`` split with ``normalize=1`` plus any extra kwargs."""
    return EEG_inception(kind=kind, normalize=1, **extra)


# Validation split (no balancing argument passed).
val_dataset = _make_eeg_dataset("val")

# Training split using the "smote" balancing option.
train_dataset_smote = _make_eeg_dataset("train", balancing="smote")

# Training split using the signal-augmentation balancing from the EEG-Inception
# paper; kept under the name `train_dataset` because later cells refer to it.
train_dataset = _make_eeg_dataset("train", balancing="inception")


100%|██████████| 856/856 [00:00<00:00, 1873.36it/s]
100%|██████████| 7650/7650 [00:04<00:00, 1859.44it/s]
100%|██████████| 7650/7650 [00:04<00:00, 1843.40it/s]


before (7650, 6000)


100%|██████████| 7650/7650 [00:04<00:00, 1803.22it/s]
100%|██████████| 19384/19384 [02:23<00:00, 134.83it/s]
100%|██████████| 7650/7650 [00:03<00:00, 1918.79it/s]


In [4]:
# Sanity check: every architecture must accept a (batch, channels, time) input
# and emit exactly one logit per sample, because LModel flattens the model
# output and compares it element-wise with the flattened labels.

# Random input tensor of shape (batch_size=3, channels=8, time_steps=750)
x = torch.randn(3, 8, 750)

# Named `sanity_models` (not `models`) so the imported `models` module, which the
# imports cell reloads, is not shadowed by this dict.
sanity_models = {
    "ResNet1D model": ResNet1D(),
    "Bio model": Bio(input_size=8),
    "Conv1D_v2 model": Conv1D_v2(channels=8),
    "EEGInceptionModel": EEGInceptionModel(in_channels=8),
    "ChronoNet model": ChronoNet(channel=8),
}

with torch.no_grad():  # shape check only; no autograd graph needed
    for label, net in sanity_models.items():
        out = net(x)
        print(f"Output of {label}: {out.shape}")
        assert out.numel() == x.shape[0], (
            f"{label}: expected one logit per sample ({x.shape[0]}), got shape {tuple(out.shape)}"
        )


Output of ResNet1D model: torch.Size([3, 1])




Output of Bio model: torch.Size([3, 1])
Output of Conv1D_v2 model: torch.Size([3, 1])
Output of SimplifiedEEGInceptionModel: torch.Size([3, 1])
Output of ChronoNet model: torch.Size([3, 1])


In [5]:
# Every experiment configuration for a single sweep:
# 5 architectures x 2 learning rates x 2 training sets = 20 runs, keyed 1..20.
# Iterating architecture -> learning rate -> training set reproduces the run
# numbering, and every run gets its own freshly constructed model instance.
_model_builders = [
    lambda: ResNet1D(),
    lambda: Bio(input_size=8),
    lambda: Conv1D_v2(channels=8),
    lambda: EEGInceptionModel(in_channels=8),
    lambda: ChronoNet(channel=8),
]
_learning_rates = [0.0001, 0.0005]
_train_sets = [train_dataset, train_dataset_smote]

attributes = {
    run_id: {
        "model": build(),
        "train_dataset": train_set,
        "val_dataset": val_dataset,
        "lr": rate,
    }
    for run_id, (build, rate, train_set) in enumerate(
        ((b, r, t) for b in _model_builders for r in _learning_rates for t in _train_sets),
        start=1,
    )
}

## Model training

In [None]:
# Directory under which each run's checkpoints are written.
# NOTE(review): the name advertises "l2_0.0004", but LModel's AdamW uses
# weight_decay=0.0005 — confirm which value these runs were meant to use.
CHECKPOINT_ROOT = "checkpoints_v3(1)_adamw_l2_0.0004_demo"

for attribute in attributes.values():
    model_name = type(attribute["model"]).__name__
    lr = attribute["lr"]

    # Do NOT assign to the globals `train_dataset` / `val_dataset` here: the
    # config cell reads them, so rebinding them in this loop would change what a
    # re-run of that cell builds (the last run would leave `train_dataset`
    # pointing at the SMOTE set).
    # `is` (identity) is the right test: we want to know *which* dataset object
    # this run uses, not whether two datasets compare equal. The inception set
    # keeps the label "train_dataset" so checkpoint paths stay stable.
    dataset_type = (
        "train_dataset_smote"
        if attribute["train_dataset"] is train_dataset_smote
        else "train_dataset"
    )

    model_identifier = f"{model_name}_lr_{lr}_dataset_{dataset_type}"
    print(f"Training model: {model_identifier}")

    # Stop when val_acc has not improved for 6 consecutive validation checks.
    early_stopping_callback = EarlyStopping(
        monitor="val_acc",
        patience=6,
        verbose=True,
        mode="max",
        check_finite=True,
    )

    # Keep the 3 best checkpoints by val_acc, one sub-directory per run.
    checkpoint_callback = ModelCheckpoint(
        dirpath=f"{CHECKPOINT_ROOT}/{model_identifier}",
        filename="{epoch}_v{val_acc:.4f}_t{train_acc:.4f}",
        save_top_k=3,
        verbose=True,
        monitor="val_acc",
        mode="max",
    )

    trainer = Trainer(
        max_epochs=200,
        callbacks=[early_stopping_callback, checkpoint_callback],
    )

    model = LModel(attribute)
    trainer.fit(model)
    print(trainer.callback_metrics)
    print(trainer.callback_metrics)

