# MLP and CNN models for seizure prediction

In [34]:
"""Models definition."""
import pytorch_lightning as pl
from torchmetrics import Accuracy, F1Score, Precision, Recall
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
from torch import Tensor
from typing import Any


class MLP(pl.LightningModule):
    """Fully connected baseline classifier for fixed-size windows.

    Each sample is flattened and passed through a 3-layer MLP that outputs
    logits for 4 classes. The module owns its datasets, builds its own
    dataloaders and tracks accuracy, macro F1, macro precision and macro
    recall.
    """

    def __init__(
        self,
        train_data: Any,
        val_data: Any,
        test_data: Any,
        n_workers: int = 8,
        lr: float = 0.001,
        batch_size: int = 64,
    ) -> None:
        """
        Initialize the MLP.

        :param train_data: Dataset served (shuffled) by ``train_dataloader``.
        :param val_data: Dataset served by ``val_dataloader``.
        :param test_data: Dataset served by ``test_dataloader``.
        :param n_workers: Number of worker processes for every DataLoader.
        :param lr: Learning rate for the Adam optimizer.
        :param batch_size: Batch size for every DataLoader.
        :return: None
        """
        super().__init__()

        self.train_data = train_data
        self.val_data = val_data
        self.test_data = test_data
        self.n_workers = n_workers
        self.lr = lr
        self.batch_size = batch_size

        # The first layer expects 22 * 256 * 60 values per flattened sample
        # (presumably 22 channels x 60 s x 256 samples/s -- confirm against the data).
        self.network = nn.Sequential(
            nn.Flatten(),
            nn.Linear(22 * 256 * 60, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 4),
        )

        # `compute_on_step` is a deprecated torchmetrics argument and is
        # omitted; step/epoch aggregation is driven by the self.log() flags.
        self.train_accuracy = Accuracy(num_classes=4, task="multiclass")
        self.val_accuracy = Accuracy(num_classes=4, task="multiclass")
        self.test_accuracy = Accuracy(num_classes=4, task="multiclass")

        self.val_f1 = F1Score(num_classes=4, average='macro', task="multiclass")
        self.test_f1 = F1Score(num_classes=4, average='macro', task="multiclass")

        self.val_precision = Precision(num_classes=4, average='macro', task="multiclass")
        self.test_precision = Precision(num_classes=4, average='macro', task="multiclass")

        self.val_recall = Recall(num_classes=4, average='macro', task="multiclass")
        self.test_recall = Recall(num_classes=4, average='macro', task="multiclass")

        # Record hyperparameters, excluding the datasets: they are data, not
        # hyperparameters, and would otherwise be stored alongside them
        # (e.g. inside every checkpoint).
        self.save_hyperparameters(ignore=["train_data", "val_data", "test_data"])

    def forward(self, tensor: Tensor) -> Tensor:
        """
        Forward pass through the network.

        :param tensor: Input batch; it is cast to float32 before use.
        :return: Unnormalized class logits, one row per sample.
        """
        # Cast so inputs of another dtype (e.g. float64 arrays) match the
        # float32 weights; this is a no-op for float32 input.
        return self.network(tensor.to(torch.float32))

    def training_step(self, batch: Any, batch_idx: int) -> Tensor:
        """
        Perform one step of training.

        :param batch: ``(inputs, integer class labels)``
        :param batch_idx: Index of the batch
        :return: Cross-entropy loss tensor
        """
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)

        # update metrics
        self.train_accuracy(preds, y)

        # log metrics (the loss too, so train/val curves can be compared)
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_acc', self.train_accuracy, on_step=True, on_epoch=True, logger=True)

        return loss

    def validation_step(self, batch: Any, batch_idx: int) -> None:
        """
        Perform one step of validation.

        :param batch: ``(inputs, integer class labels)``
        :param batch_idx: Index of the batch
        :return: None
        """
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)

        preds = torch.argmax(logits, dim=1)

        # update metrics
        self.val_accuracy(preds, y)
        self.val_f1(preds, y)
        self.val_precision(preds, y)
        self.val_recall(preds, y)

        # log metrics (val_loss is what the checkpoint/early-stop callbacks monitor)
        self.log('val_loss', loss, sync_dist=True)
        self.log('val_acc', self.val_accuracy, on_step=True, on_epoch=True, prog_bar=True)
        self.log('val_f1', self.val_f1, on_step=True, on_epoch=True, prog_bar=True)
        self.log('val_precision', self.val_precision, on_step=True, on_epoch=True, prog_bar=True)
        self.log('val_recall', self.val_recall, on_step=True, on_epoch=True, prog_bar=True)

    def test_step(self, batch: Any, batch_idx: int) -> None:
        """
        Perform one step of testing.

        :param batch: ``(inputs, integer class labels)``
        :param batch_idx: Index of the batch
        :return: None
        """
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)

        preds = torch.argmax(logits, dim=1)

        # update metrics
        self.test_accuracy(preds, y)
        self.test_f1(preds, y)
        self.test_precision(preds, y)
        self.test_recall(preds, y)

        # log metrics
        self.log('test_loss', loss, sync_dist=True)
        self.log('test_acc', self.test_accuracy, on_step=True, on_epoch=True, logger=True)
        self.log('test_f1', self.test_f1, on_step=True, on_epoch=True, logger=True)
        self.log('test_precision', self.test_precision, on_step=True, on_epoch=True, logger=True)
        self.log('test_recall', self.test_recall, on_step=True, on_epoch=True, logger=True)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """
        Configure the optimizer for training.

        :return: Adam over all parameters with learning rate ``self.lr``
        """
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def train_dataloader(self) -> DataLoader:
        """Return a shuffled DataLoader over the training dataset."""
        return DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True, num_workers=self.n_workers)

    def val_dataloader(self) -> DataLoader:
        """Return an in-order DataLoader over the validation dataset."""
        return DataLoader(self.val_data, batch_size=self.batch_size, num_workers=self.n_workers)

    def test_dataloader(self) -> DataLoader:
        """Return an in-order DataLoader over the test dataset."""
        return DataLoader(self.test_data, batch_size=self.batch_size, num_workers=self.n_workers)


## Data Loader

In [44]:
import os
import pandas as pd
from torch.utils.data import Dataset
import torch



class CustomDataset(Dataset):
    """Dataset of windows stored as one CSV file per window.

    Only ``.csv`` files in ``folder_path`` whose third ``_``-separated name
    field is a key of ``label_mapping`` are used. That field is mapped to an
    integer class id. Each CSV is read lazily in ``__getitem__``.
    """

    def __init__(self, folder_path: str):
        """
        Initialize the dataset.

        :param folder_path: Path to the folder with the data files
        """
        self.folder_path = folder_path
        self.label_mapping = {
            "preictal": 0,
            "ictal": 1,
            "prepreictal": 2,
            "interictal": 3,
        }
        self.file_paths: list[str] = []
        self.labels: list[int] = []
        self._load_file_paths()

    def __len__(self) -> int:
        """
        Get the total number of samples in the dataset.

        :return: Number of accepted CSV files
        """
        return len(self.file_paths)

    def __getitem__(self, idx: int) -> tuple:
        """
        Load one sample.

        :param idx: Index into the (sorted) list of accepted files
        :return: ``(values, label)``, where ``values`` is the CSV contents as a
            float32 NumPy array, transposed so each CSV column becomes a row,
            and ``label`` is the integer class id
        """
        data = pd.read_csv(self.file_paths[idx])
        # Cast first so the transposed frame (and its array) is float32.
        data = data.astype("float32").transpose()
        return data.values, self.labels[idx]

    def _load_file_paths(self) -> None:
        """
        Collect the accepted CSV files and their integer labels.

        Files are visited in sorted order because ``os.listdir`` order is
        arbitrary. A deterministic order keeps sample indices, and therefore
        any seeded ``random_split``, reproducible. CSV files whose name has no
        third ``_`` field, or whose label is unknown, are reported and skipped.

        :return: None
        """
        for filename in sorted(os.listdir(self.folder_path)):
            if not filename.endswith(".csv"):
                continue
            file_path = os.path.join(self.folder_path, filename)
            parts = filename.split("_")
            # Names with fewer than three fields used to raise IndexError and
            # abort loading the whole folder; treat them as unlabeled instead.
            label = parts[2] if len(parts) > 2 else None
            if label in self.label_mapping:
                self.file_paths.append(file_path)
                self.labels.append(self.label_mapping[label])
            else:
                print(f"Ignoring file {file_path} with unexpected label {label}")


## Train

In [30]:
!module load nvhpc
!nvidia-smi

Wed Jun 21 11:57:58 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A30          On   | 00000000:03:00.0 Off |                    0 |
| N/A   36C    P0    99W / 165W |    263MiB / 24576MiB |     72%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A30          On   | 00000000:85:00.0 Off |                    0 |
| N/A   18C    P0    28W / 165W |   1455MiB / 24576MiB |      0%      Default |
|       

In [31]:
import os

# Report how many logical CPUs the OS exposes, as a guide for choosing the
# DataLoader ``num_workers`` value (os.cpu_count() returns None if unknown).
num_cpus = os.cpu_count()
cpu_report_label = "Número de CPUs disponibles:"
print(cpu_report_label, num_cpus)

Número de CPUs disponibles: 20


In [37]:
"""Train the model."""
import mlflow
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from torch.utils.data import random_split
from pytorch_lightning.loggers import MLFlowLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
import json

torch.set_float32_matmul_precision('medium')

# mlf_logger = MLFlowLogger(experiment_name="lightning_logs")#, tracking_uri="/home/mnsosa/CHB-MIT_Seizure_Prediction/mlruns")

# Define los parámetros de tu experimento
params = {
    "batch_size": 64,
    "lr": 0.001,
    "epochs": 50,
}
# Carga los datos
folder_path = "data/windows_per_csv"
dataset = CustomDataset(folder_path)

# Divide los datos
train_len = int(len(dataset) * 0.7)
val_len = (len(dataset) - train_len) // 2
test_len = len(dataset) - train_len - val_len
train_data, val_data, test_data = random_split(dataset, [train_len, val_len, test_len])

# Crea la instancia del modelo
model = MLP(train_data=train_data, val_data=val_data, test_data=test_data)

# Callback para que no se estanque el entrenamiento
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=0.00,
    patience=20,
    verbose=True,
    mode="min",
)

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    filename='best_model-{epoch:02d}-{val_loss:.2f}',
    save_top_k=1,
    save_weights_only=True,
)


trainer = Trainer(
    max_epochs=params["epochs"],
    # callbacks=[early_stop_callback],
    devices=1,
    # logger=mlf_logger,
    log_every_n_steps=1,
    enable_checkpointing=True,
    check_val_every_n_epoch=1,
    benchmark=True,
    accelerator="gpu",
    enable_progress_bar=True,
    enable_model_summary=True,
)


# Entrena el modelo
trainer.fit(
    model,
    train_dataloaders=model.train_dataloader(),
    val_dataloaders=model.val_dataloader(),
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]

  | Name           | Type                | Params
-------------------------------------------------------
0 | network        | Sequential          | 21.6 M
1 | train_accuracy | MulticlassAccuracy  | 0     
2 | val_accuracy   | MulticlassAccuracy  | 0     
3 | test_accuracy  | MulticlassAccuracy  | 0     
4 | val_f1         | MulticlassF1Score   | 0     
5 | test_f1        | MulticlassF1Score   | 0     
6 | val_precision  | MulticlassPrecision | 0     
7 | test_precision | MulticlassPrecision | 0     
8 | val_recall     | MulticlassRecall    | 0     
9 | test_recall    | MulticlassRecall    | 0     
-------------------------------------------------------
21.6 M    Trainable params
0         Non-trainable params
21.6 M    Total params
86.517    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=50` reached.


In [38]:
# Evaluate the fitted model on the held-out test split; the call returns a
# list with one dict of logged metrics per dataloader.
trainer.test(model=model, dataloaders=model.test_dataloader())

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Runningstage.testing metric      DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_acc_epoch         0.4676409065723419
      test_f1_epoch         0.28222620487213135
        test_loss            87.87905883789062
  test_precision_epoch      0.28250351548194885
    test_recall_epoch       0.28241658210754395
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 87.87905883789062,
  'test_acc_epoch': 0.4676409065723419,
  'test_f1_epoch': 0.28222620487213135,
  'test_precision_epoch': 0.28250351548194885,
  'test_recall_epoch': 0.28241658210754395}]

In [45]:
class CNN1D(MLP):
    """1-D convolutional variant of :class:`MLP`.

    Only ``self.network`` is replaced. Data handling, metrics, the
    training/validation/test steps and the optimizer are inherited from
    :class:`MLP`.
    """

    def __init__(self, train_data: Any, val_data: Any, test_data: Any, n_workers: int = 8) -> None:
        """
        Initialize the CNN1D.

        :param train_data: Training dataset (forwarded to :class:`MLP`).
        :param val_data: Validation dataset (forwarded to :class:`MLP`).
        :param test_data: Test dataset (forwarded to :class:`MLP`).
        :param n_workers: DataLoader worker count (forwarded to :class:`MLP`).
        :return: None
        """
        # MLP.__init__ already stores the datasets and n_workers, builds every
        # metric and calls save_hyperparameters(), so none of that is repeated
        # here. NOTE(review): it also builds the MLP's large dense network,
        # which is discarded below; a network-builder hook on MLP would avoid that.
        super().__init__(train_data, val_data, test_data, n_workers)

        self.network = nn.Sequential(
            # kernel_size == stride == 256, so filters see non-overlapping
            # windows. Assuming (batch, 22, 256 * 60) input, as MLP's first
            # layer implies, this yields 60 positions x 32 filters, matching
            # Linear(60 * 32, ...). NOTE(review): padding=1 shifts every window
            # by one sample; padding=0 would give the same output length.
            nn.Conv1d(
                in_channels=22,
                out_channels=32,
                kernel_size=256,
                stride=256,
                padding=1,
            ),
            nn.Flatten(),
            nn.Linear(60 * 32, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 4),
        )

In [46]:
"""Train the model."""
import mlflow
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from torch.utils.data import random_split
from pytorch_lightning.loggers import MLFlowLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
import json

torch.set_float32_matmul_precision('medium')

# mlf_logger = MLFlowLogger(experiment_name="lightning_logs")#, tracking_uri="/home/mnsosa/CHB-MIT_Seizure_Prediction/mlruns")

# Define los parámetros de tu experimento
params = {
    "batch_size": 64,
    "lr": 0.001,
    "epochs": 50,
}
# Carga los datos
folder_path = "data/windows_per_csv"
dataset = CustomDataset(folder_path)

# Divide los datos
train_len = int(len(dataset) * 0.7)
val_len = (len(dataset) - train_len) // 2
test_len = len(dataset) - train_len - val_len
train_data, val_data, test_data = random_split(dataset, [train_len, val_len, test_len])

# Crea la instancia del modelo
model = CNN1D(train_data=train_data, val_data=val_data, test_data=test_data)

# Callback para que no se estanque el entrenamiento
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=0.00,
    patience=20,
    verbose=True,
    mode="min",
)

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    mode='min',
    filename='best_model-{epoch:02d}-{val_loss:.2f}',
    save_top_k=1,
    save_weights_only=True,
)


trainer = Trainer(
    max_epochs=params["epochs"],
    # callbacks=[early_stop_callback],
    devices=1,
    # logger=mlf_logger,
    log_every_n_steps=1,
    enable_checkpointing=True,
    check_val_every_n_epoch=1,
    benchmark=True,
    accelerator="gpu",
    enable_progress_bar=True,
    enable_model_summary=True,
)


# Entrena el modelo
trainer.fit(
    model,
    train_dataloaders=model.train_dataloader(),
    val_dataloaders=model.val_dataloader(),
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]

  | Name           | Type                | Params
-------------------------------------------------------
0 | network        | Sequential          | 305 K 
1 | train_accuracy | MulticlassAccuracy  | 0     
2 | val_accuracy   | MulticlassAccuracy  | 0     
3 | test_accuracy  | MulticlassAccuracy  | 0     
4 | val_f1         | MulticlassF1Score   | 0     
5 | test_f1        | MulticlassF1Score   | 0     
6 | val_precision  | MulticlassPrecision | 0     
7 | test_precision | MulticlassPrecision | 0     
8 | val_recall     | MulticlassRecall    | 0     
9 | test_recall    | MulticlassRecall    | 0     
-------------------------------------------------------
305 K     Trainable params
0         Non-trainable params
305 K     Total params
1.222     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=50` reached.


In [47]:
# Evaluate the fitted CNN1D on the held-out test split; the call returns a
# list with one dict of logged metrics per dataloader.
trainer.test(model=model, dataloaders=model.test_dataloader())

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1]


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Runningstage.testing metric      DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_acc_epoch         0.5344467759132385
      test_f1_epoch          0.227859228849411
        test_loss            7.009418964385986
  test_precision_epoch      0.22744956612586975
    test_recall_epoch       0.2373500019311905
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 7.009418964385986,
  'test_acc_epoch': 0.5344467759132385,
  'test_f1_epoch': 0.227859228849411,
  'test_precision_epoch': 0.22744956612586975,
  'test_recall_epoch': 0.2373500019311905}]

In [None]:
class CNN2D(MLP):
    """2-D convolutional variant of :class:`MLP`.

    Each sample is folded into a 22-channel 2-D grid, convolved, max-pooled
    and classified by a small MLP head. Data handling, metrics, the
    training/validation/test steps and the optimizer are inherited from
    :class:`MLP`.
    """

    def __init__(self, train_data: Any, val_data: Any, test_data: Any, n_workers: int = 8) -> None:
        """
        Initialize the CNN2D.

        :param train_data: Training dataset (forwarded to :class:`MLP`).
        :param val_data: Validation dataset (forwarded to :class:`MLP`).
        :param test_data: Test dataset (forwarded to :class:`MLP`).
        :param n_workers: DataLoader worker count (forwarded to :class:`MLP`).
        :return: None
        """
        # MLP.__init__ already stores the datasets and n_workers, builds every
        # metric and calls save_hyperparameters(), so none of that is repeated
        # here. Its dense network is replaced below.
        super().__init__(train_data, val_data, test_data, n_workers)

        n_channels, n_windows, window_len = 22, 60, 256

        features = nn.Sequential(
            # Conv2d needs (batch, channels, H, W) input. NOTE(review): this
            # assumes each sample is (22, 60 * 256), consistent with MLP's
            # 22 * 256 * 60 input, and folds the time axis into a 60 x 256 grid.
            nn.Unflatten(2, (n_windows, window_len)),
            nn.Conv2d(
                in_channels=n_channels,
                out_channels=32,
                kernel_size=(3, 3),
                stride=1,
                padding=1,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
        )
        # Infer the flattened feature size from one dummy sample instead of
        # hard-coding it (32 * 30 * 128 for the shapes above).
        with torch.no_grad():
            n_features = features(torch.zeros(1, n_channels, n_windows * window_len)).shape[1]

        self.network = nn.Sequential(
            *features,
            nn.Linear(n_features, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 4),
        )
