<a href="https://colab.research.google.com/github/Zfeng0207/FIT3199-FYP/blob/dev%2Fzfeng/test_baseline_res_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Starting the Notebook



## Loading Dependencies

In [1]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [2]:
import os
os.chdir('/content/drive/MyDrive/Colab Notebooks/ECG-MIMIC-main')

In [3]:
!pip install -qqqq mlflow torchmetrics pytorch_lightning

In [4]:
import mlflow

In [5]:
memmap_meta_path = "src/data/memmap/memmap_meta.npz"
memmap_path = "src/data/memmap/memmap.npy"
df_diag_path = "src/data/records_w_diag_icd10.csv"
df_memmap_pkl_path = "src/data/memmap/df_memmap.pkl"


## Merge dataset with labels and ecg paths

In [6]:
import pandas as pd

df_diag = pd.read_csv(df_diag_path)
df_mapped = pd.read_pickle(df_memmap_pkl_path)

In [14]:
def multihot_encode(diagnoses, icd_codes):
    num_classes = len(icd_codes)
    res = np.zeros(num_classes, dtype=np.float32)
    for diag in diagnoses:
        for i, code in enumerate(icd_codes):  # Iterate through icd_codes with index
            if diag.startswith(code):
                res[i] = 1
                break
    return res

In [38]:
import numpy as np
import ast

memmap_meta_path = "/content/drive/MyDrive/Colab Notebooks/ECG-MIMIC-main/src/data/memmap/memmap_meta.npz"
memmap_meta = np.load(memmap_meta_path, allow_pickle=True)
df_full = pd.read_csv("/content/drive/MyDrive/Colab Notebooks/ECG-MIMIC-main/src/preprocessed_data/records_w_stroke_labels.csv")
df_full["start"] = memmap_meta["start"]
df_full["length"] = memmap_meta["length"]

df_labels =  df_full[["filename",
            "study_id",
            "patient_id",
            "ecg_time",
            "label_train",
            "all_diag_all",
            "label_stroke",
            "start",
            "length"]]

target_icd_codes = (
    "I20", "I21", "I24", "I25",
    "I42", "E87", "I48", "I44", "I45", "E11", "J44", "J45"
)

df_labels['res'] = df_labels['label_train'].apply(lambda diagnoses: multihot_encode(diagnoses, target_icd_codes))

df_labels['stroke_yn'] = df_labels['res'].apply(lambda x: 1 if 1 in x else 0)

df_rm_nan = df_labels[df_labels['all_diag_all'].apply(lambda x: len(ast.literal_eval(x)) > 0)]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_labels['res'] = df_labels['label_train'].apply(lambda diagnoses: multihot_encode(diagnoses, target_icd_codes))
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_labels['stroke_yn'] = df_labels['res'].apply(lambda x: 1 if 1 in x else 0)


# ECG Dataset

In [47]:
import torch
from torch.utils.data import Dataset

class ECGDataset(Dataset):
    def __init__(self, memmap, X, y):
        self.df = X.reset_index(drop=True)
        self.memmap = memmap
        self.y = y

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
      # Access data directly from the DataFrame
      start = self.df.loc[idx, 'start']
      length = self.df.loc[idx, 'length']
      # file_idx = self.df.loc[idx, 'file_idx'] # You might not need file_idx here anymore

      # Extract the flat signal slice
      signal = self.memmap[start : start + length * 12]  # 12 features per timestep
      signal = (signal - signal.mean(axis=0)) / (signal.std(axis=0) + 1e-6)

      # Reshape to [length, 12]
      signal = signal.reshape(length, 12)

      # Convert signal to PyTorch tensor before checking for NaN/inf
      signal = torch.tensor(signal, dtype=torch.float32)

      if torch.isnan(signal).any() or torch.isinf(signal).any():
        return None

      label = torch.tensor(self.y.iloc[idx]['res'], dtype=torch.float32)  # Load 'res' as float32
      return signal, torch.tensor(label, dtype=torch.long) # signal is already a tensor


In [88]:
from torch.utils.data import DataLoader
class ECGDataModule(pl.LightningDataModule):
    def __init__(self, memmap, X_train, y_train, X_val, y_val, X_test, y_test, batch_size=32):
        super().__init__()
        self.memmap = memmap
        self.X_train = X_train
        self.y_train = y_train
        self.X_val = X_val
        self.y_val = y_val
        self.X_test = X_test
        self.y_test = y_test
        self.batch_size = batch_size

    def setup(self, stage=None):
        self.train_dataset = ECGDataset(self.memmap, self.X_train, self.y_train)
        self.val_dataset = ECGDataset(self.memmap, self.X_val, self.y_val)
        self.test_dataset = ECGDataset(self.memmap, self.X_test, self.y_test)

    def setup_fold_data(self, train_idx, val_idx):
        """
        Prepares data for the current fold using the provided indices.
        """
        # Assuming X_train, y_train, etc. are numpy arrays
        self.X_train_fold = self.X_train[train_idx]
        self.y_train_fold = self.y_train[train_idx]
        self.X_val_fold = self.X_train[val_idx]  # Using a portion of train data for validation
        self.y_val_fold = self.y_train[val_idx]

        # Update datasets with fold data
        self.train_dataset = ECGDataset(self.memmap, self.X_train_fold, self.y_train_fold)
        self.val_dataset = ECGDataset(self.memmap, self.X_val_fold, self.y_val_fold)
        # self.test_dataset remains the same as it's not used in k-fold training


    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=11, collate_fn=safe_collate, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size,  num_workers=11, collate_fn=safe_collate, pin_memory=True)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=11, collate_fn=safe_collate, pin_memory=True)


## Swish

In [49]:
import pytorch_lightning as pl
import torch

class Swish(pl.LightningModule):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)

## ConvNormPool

In [50]:
class ConvNormPool(pl.LightningModule):
    """Conv Skip-connection module"""
    def __init__(
        self,
        input_size,
        hidden_size,
        kernel_size,
        norm_type='bachnorm'
    ):
        super().__init__()

        self.kernel_size = kernel_size
        self.conv_1 = nn.Conv1d(
            in_channels=input_size,
            out_channels=hidden_size,
            kernel_size=kernel_size
        )
        self.conv_2 = nn.Conv1d(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=kernel_size
        )
        self.conv_3 = nn.Conv1d(
            in_channels=hidden_size,
            out_channels=hidden_size,
            kernel_size=kernel_size
        )
        self.swish_1 = Swish()
        self.swish_2 = Swish()
        self.swish_3 = Swish()
        if norm_type == 'group':
            self.normalization_1 = nn.GroupNorm(
                num_groups=8,
                num_channels=hidden_size
            )
            self.normalization_2 = nn.GroupNorm(
                num_groups=8,
                num_channels=hidden_size
            )
            self.normalization_3 = nn.GroupNorm(
                num_groups=8,
                num_channels=hidden_size
            )
        else:
            self.normalization_1 = nn.BatchNorm1d(num_features=hidden_size)
            self.normalization_2 = nn.BatchNorm1d(num_features=hidden_size)
            self.normalization_3 = nn.BatchNorm1d(num_features=hidden_size)

        self.pool = nn.MaxPool1d(kernel_size=2)

    def forward(self, input):
        conv1 = self.conv_1(input)
        x = self.normalization_1(conv1)
        x = self.swish_1(x)
        x = F.pad(x, pad=(self.kernel_size - 1, 0))

        x = self.conv_2(x)
        x = self.normalization_2(x)
        x = self.swish_2(x)
        x = F.pad(x, pad=(self.kernel_size - 1, 0))

        conv3 = self.conv_3(x)
        x = self.normalization_3(conv1+conv3)
        x = self.swish_3(x)
        x = F.pad(x, pad=(self.kernel_size - 1, 0))

        x = self.pool(x)
        return x


# CNN

In [51]:
class CNN(pl.LightningModule):
    def __init__(
        self,
        input_size = 1,
        hid_size = 256,
        kernel_size = 5,
        num_classes = 5,
    ):

        super().__init__()

        self.conv1 = ConvNormPool(
            input_size=input_size,
            hidden_size=hid_size,
            kernel_size=kernel_size,
        )
        self.conv2 = ConvNormPool(
            input_size=hid_size,
            hidden_size=hid_size//2,
            kernel_size=kernel_size,
        )
        self.conv3 = ConvNormPool(
            input_size=hid_size//2,
            hidden_size=hid_size//4,
            kernel_size=kernel_size,
        )
        self.avgpool = nn.AdaptiveAvgPool1d((1))
        self.fc = nn.Linear(in_features=hid_size//4, out_features=num_classes)

    def forward(self, input):
        x = self.conv1(input)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.avgpool(x)
        # print(x.shape) # num_features * num_channels
        x = x.view(-1, x.size(1) * x.size(2))
        x = F.softmax(self.fc(x), dim=1)
        return x


# RNN

In [52]:
class RNN(pl.LightningModule):
    """RNN module(cell type lstm or gru)"""
    def __init__(
        self,
        input_size,
        hid_size,
        num_rnn_layers=1,
        dropout_p = 0.2,
        bidirectional = False,
        rnn_type = 'lstm',
    ):
        super().__init__()

        if rnn_type == 'lstm':
            self.rnn_layer = nn.LSTM(
                input_size=input_size,
                hidden_size=hid_size,
                num_layers=num_rnn_layers,
                dropout=dropout_p if num_rnn_layers>1 else 0,
                bidirectional=bidirectional,
                batch_first=True,
            )

        else:
            self.rnn_layer = nn.GRU(
                input_size=input_size,
                hidden_size=hid_size,
                num_layers=num_rnn_layers,
                dropout=dropout_p if num_rnn_layers>1 else 0,
                bidirectional=bidirectional,
                batch_first=True,
            )
    def forward(self, input):
        outputs, hidden_states = self.rnn_layer(input)
        return outputs, hidden_states


# RNN Model

In [53]:
class RNNModel(pl.LightningModule):
    def __init__(
        self,
        input_size,
        hid_size,
        rnn_type,
        bidirectional,
        n_classes=5,
        kernel_size=5,
    ):
        super().__init__()

        self.rnn_layer = RNN(
            input_size=46,#hid_size * 2 if bidirectional else hid_size,
            hid_size=hid_size,
            rnn_type=rnn_type,
            bidirectional=bidirectional
        )
        self.conv1 = ConvNormPool(
            input_size=input_size,
            hidden_size=hid_size,
            kernel_size=kernel_size,
        )
        self.conv2 = ConvNormPool(
            input_size=hid_size,
            hidden_size=hid_size,
            kernel_size=kernel_size,
        )
        self.avgpool = nn.AdaptiveAvgPool1d((1))
        self.fc = nn.Linear(in_features=hid_size, out_features=n_classes)

    def forward(self, input):
        x = self.conv1(input)
        x = self.conv2(x)
        x, _ = self.rnn_layer(x)
        x = self.avgpool(x)
        x = x.view(-1, x.size(1) * x.size(2))
        x = F.sigmoid(self.fc(x), dim=1)#.squeeze(1)
        return x


# RNN Attention Model

In [84]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torchmetrics.classification import MultilabelF1Score
from sklearn.metrics import hamming_loss

class RNNAttentionModel(pl.LightningModule):
    def __init__(
        self,
        input_size,
        hid_size,
        rnn_type,
        bidirectional,
        kernel_size=5,
        lr=1e-3,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.conv1 = ConvNormPool(
            input_size=input_size,  # input_size = 12 for ECG
            hidden_size=hid_size,
            kernel_size=kernel_size,
        )
        self.conv2 = ConvNormPool(
            input_size=hid_size,
            hidden_size=hid_size,
            kernel_size=kernel_size,
        )

        self.rnn_layer = RNN(
            input_size=hid_size,
            hid_size=hid_size,
            rnn_type=rnn_type,
            bidirectional=bidirectional
        )

        self.attn = nn.Linear(hid_size, hid_size, bias=False)
        self.fc = nn.Linear(in_features=hid_size, out_features=12)  # Binary output
        self.loss_fn = nn.BCEWithLogitsLoss()
        self.lr = lr

        self.fc = nn.Linear(in_features=hid_size, out_features=12)  # Changed for multi-label
        self.loss_fn = nn.BCEWithLogitsLoss()  # Changed for multi-label

        # Metrics for multi-label
        # Remove the calls to hamming_loss during initialization
        # self.train_hamming_loss = hamming_loss  # Store the function itself
        # self.val_hamming_loss = hamming_loss  # Store the function itself
        self.train_f1 = MultilabelF1Score(num_labels=12)  # Macro F1 score by default
        self.val_f1 = MultilabelF1Score(num_labels=12)


    def forward(self, input):
        input = input.permute(0, 2, 1)  # (batch, 12, 1000)
        x = self.conv1(input)
        x = self.conv2(input)
        x = x.permute(0, 2, 1)  # (batch, time_steps, features)

        x_out, _ = self.rnn_layer(x)  # (batch, time, hid_size)

        attn_weights = torch.softmax(self.attn(x_out), dim=1)  # (batch, time, hid_size)
        x = torch.sum(attn_weights * x_out, dim=1)  # (batch, hid_size)

        logits = self.fc(x)  # (batch, 1)
        return logits

    def on_train_start(self):
        # Log model type as a parameter or tag
        # mlflow.pytorch.log_model(self, "model") # Registers the model
        mlflow.log_param("model_type", "RNNAttentionModel")  # Log as parameter
        mlflow.set_tag("model_type", "RNNAttentionModel")

    def training_step(self, batch, batch_idx):

        x, y = batch
        logits = self(x).squeeze()
        loss = self.loss_fn(logits, y.float())
        probs = torch.sigmoid(logits)
        preds = probs > 0.5

        # Calculate hamming loss within the training step
        hamming_loss_val = hamming_loss(y.cpu().numpy(), preds.cpu().numpy())

        acc = self.train_acc(probs, y.int())
        f1 = self.train_f1(probs, y.int())
        auc = self.train_auc(probs, y.int())


        self.log("train_loss", loss, prog_bar=True)
        self.log("train_hamming_loss", hamming_loss_val, prog_bar=True)
        self.log("train_f1", f1, prog_bar=True)

        self.log("train_auc", auc, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x).squeeze()
        loss = self.loss_fn(logits, y.float())
        probs = torch.sigmoid(logits)
        preds = probs > 0.5

        acc = self.val_acc(probs, y.int())
        f1 = self.val_f1(probs, y.int())
        auc = self.val_auc(probs, y.int())
        # Calculate hamming loss within the validation step
        hamming_loss_val = hamming_loss(y.cpu().numpy(), preds.cpu().numpy())


        self.log("val_hamming_loss", hamming_loss_val, prog_bar=True)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        self.log("val_f1", f1, prog_bar=True)
        self.log("val_auc", auc, prog_bar=True)

    def on_test_start(self):
        self.test_probs = []
        self.test_preds = []
        self.test_targets = []

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x).squeeze()
        probs = torch.sigmoid(logits)
        preds = (probs > 0.5).int()

        self.test_probs.append(probs.detach().cpu())
        self.test_preds.append(preds.detach().cpu())
        self.test_targets.append(y.detach().cpu())

    def on_test_end(self):
        self.all_probs = torch.cat(self.test_probs)
        self.all_preds = torch.cat(self.test_preds)
        self.all_targets = torch.cat(self.test_targets)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [70]:
## Safe collate

In [71]:
from torch.nn.utils.rnn import pad_sequence
def safe_collate(batch):
    # Filter out None entries
    batch = [b for b in batch if b is not None]
    if len(batch) == 0:
        return None  # Skip entire batch if empty (optional, or raise Exception)

    signals, labels = zip(*batch)
    signals = pad_sequence(signals, batch_first=True)  # if variable-length ECG
    labels = torch.tensor(labels)
    return signals, labels


# Simple LSTM Model

In [72]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
import torch.nn.functional as F
from torchmetrics.classification import BinaryF1Score, BinaryAUROC


class LSTMClassifier(pl.LightningModule):
    def __init__(self, input_size=12, hidden_size=64, num_layers=2, lr=1e-3):
        super().__init__()
        self.save_hyperparameters()

        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True
        )

        self.train_f1 = BinaryF1Score()
        self.val_f1 = BinaryF1Score()
        self.test_f1 = BinaryF1Score()

        self.train_auc = BinaryAUROC()
        self.val_auc = BinaryAUROC()
        self.test_auc = BinaryAUROC()

        self.fc = nn.Linear(hidden_size * 2, 1)  # bidirectional
        self.loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight]).to(self.device))

    def forward(self, x):
        # x: (B, T, C) → needs to be (B, T, 12)
        out, _ = self.lstm(x)
        out = out[:, -1, :]  # take last timestep
        logits = self.fc(out)
        return logits.squeeze()

    def on_train_start(self):
      # Log model type as a parameter or tag
      mlflow.pytorch.log_model(self, "model") # Registers the model
      mlflow.log_param("model_type", "LSTM")  # Log as parameter
      mlflow.set_tag("model_type", "LSTM")

    def training_step(self, batch, batch_idx):
        self.train()
        x, y = batch
        x = x.to(self.device)
        y = y.to(self.device)
        logits = self(x)
        loss = self.loss_fn(logits, y.float())
        probs = torch.sigmoid(logits)
        preds = probs > 0.5
        auc = self.train_auc(probs, y.int())
        acc = (preds == y).float().mean()
        f1 = self.train_f1(preds, y)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True)
        self.log("train_f1", f1, prog_bar=True)
        self.log("train_auc", auc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y.float())
        probs = torch.sigmoid(logits)
        preds = probs > 0.5
        auc = self.train_auc(probs, y.int())
        acc = (preds == y).float().mean()
        f1 = self.val_f1(preds, y)

        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        self.log("val_f1", f1, prog_bar=True)
        self.log("val_auc", auc, prog_bar=True)

        return loss


    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y.float())
        probs = torch.sigmoid(logits)
        preds = probs > 0.5
        auc = self.train_auc(probs, y.int())

        acc = (preds == y).float().mean()
        f1 = self.test_f1(preds, y)

        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", acc, prog_bar=True)
        self.log("test_f1", f1, prog_bar=True)
        self.log("test_auc", auc, prog_bar=True)

        return loss

    # def configure_gradient_clipping(
    #     self,
    #     optimizer=None,
    #     optimizer_idx=None,
    #     gradient_clip_val=None,
    #     gradient_clip_algorithm=None
    # ):
    #     if optimizer is not None:
    #         torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)



    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)  # Reduced lr


## Data Sampling

In [73]:
# !pip install scikit-multilearn



In [None]:
X_train[0]

In [None]:
df.shape

In [None]:
# import pandas as pd
# import numpy as np

# # Define paths
# df_path = "src/data/df_memmap.csv"
# train_df_path = "src/data/train_df.csv"
# val_df_path = "src/data/val_df.csv"
# test_df_path = "src/data/test_df.csv"

#   # Load labels CSV
# df = df_labels.copy()

# # Now you can split the DataFrame while keeping track of ECG data pointers
# from sklearn.model_selection import train_test_split

# # Split test set with preserved stroke ratio
# train_val_df, test_df = train_test_split(
#     df, test_size=0.10, stratify=df['stroke_yn'], random_state=42
# )

# # Then split stroke/non-stroke from train_val_df as discussed before
# stroke_df = train_val_df[train_val_df['stroke_yn'] == 1]
# nonstroke_df = train_val_df[train_val_df['stroke_yn'] == 0]

# # Balanced sampling
# train_stroke, val_stroke = train_test_split(stroke_df, test_size=0.1, random_state=42)
# train_nonstroke = nonstroke_df.sample(n=len(train_stroke)*2, random_state=42)
# val_nonstroke = nonstroke_df.drop(train_nonstroke.index).sample(n=len(val_stroke)*2, random_state=42)

# # Final splits
# train_df = pd.concat([train_stroke, train_nonstroke]).reset_index(drop=True)
# val_df = pd.concat([val_stroke, val_nonstroke]).reset_index(drop=True)
# test_df = test_df.reset_index(drop=True)

# # df.to_csv("src/data/df_memmap.csv", index=False)
# # train_df.to_csv("src/data/train_df.csv", index=False)
# # val_df.to_csv("src/data/val_df.csv", index=False)
# # test_df.to_csv("src/data/test_df.csv", index=False)

In [None]:
from sklearn.utils.class_weight import compute_class_weight

labels = y_train
class_weights = compute_class_weight(class_weight='balanced', classes=np.array([0, 1]), y=labels)

# class_weights[1] is the weight for positive class
pos_weight = class_weights[1] / class_weights[0]  # Convert to ratio


Model and data Initialization

In [None]:
# Create the data module
ecg_dm = ECGDataModule(
    memmap=memmap_data,
    X_train = X_train,
    X_val = X_val,
    X_test = X_test,
    y_train = y_train,
    y_val = y_val,
    y_test = y_test,
    batch_size=64
)

## Setting up Mlflow for model baseline tracking

In [None]:
from pytorch_lightning.loggers import MLFlowLogger
import mlflow

import mlflow
import mlflow.pytorch
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
import os

os.environ['MLFLOW_TRACKING_USERNAME'] = "Zfeng0207"
os.environ['MLFLOW_TRACKING_PASSWORD'] = "af7c8365aec4d3ff7a40563a35ec94d4bc9b4512"
os.environ['MLFLOW_TRACKING_PROJECTNAME'] = "stroke-prediction-dagshub-repo"
# Setup
experiment_name = "lstm-ecg"
tracking_uri = f"https://dagshub.com/{os.environ['MLFLOW_TRACKING_USERNAME']}/{os.environ['MLFLOW_TRACKING_PROJECTNAME']}.mlflow"

mlflow.set_tracking_uri(tracking_uri)
mlflow.set_experiment(experiment_name)

print(f"MLflow tracking experiment name: {experiment_name}")
print(f"Tracking URI: {tracking_uri}")

# Use same URI in logger
mlf_logger = MLFlowLogger(
    experiment_name=experiment_name,
    tracking_uri=tracking_uri,
    log_model=True
)

# Model Selection

In [85]:
model = RNNAttentionModel(12, 64, 'lstm', False)
# model = RNNModel(12, 64, 'lstm', True)
# model = CNN(num_classes=5, hid_size=128)
# model = LSTMClassifier(input_size=12)

In [76]:
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer

checkpoint_callback = ModelCheckpoint(
    dirpath='./checkpoints',
    filename='ecgmodel-ep{epoch:02d}-vloss{val_loss:.2f}',
    save_top_k=1,
    monitor='val_loss',
    mode='min',
)


# Update the Trainer to use the callback
trainer = Trainer(
    # logger=mlf_logger,
    max_epochs=1,
    gradient_clip_val=1.0,
    # callbacks=[checkpoint_callback]  # Add callback here
)

INFO:pytorch_lightning.utilities.rank_zero:You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


# Model Training

In [90]:
from skmultilearn.model_selection import iterative_train_test_split
import numpy as np

df = df_rm_nan
X = df.drop(columns='res')  # Features

# Convert 'res' column to a 2D numpy array
y = np.array(df['res'].tolist())

# Now perform the split
X_train, y_train, X_test, y_test = iterative_train_test_split(
    X.values, y, test_size=0.2
)

X_train, y_train, X_val, y_val = iterative_train_test_split(
    X_train, y_train, test_size=0.1
)

memmap_data = np.memmap('/content/drive/MyDrive/Colab Notebooks/ECG-MIMIC-main/src/data/memmap/memmap.npy', dtype='float32', mode='r')
# --- Create your DataModule here ---
ecg_dm = ECGDataModule(memmap=memmap_data,  # Assuming you have memmap_data
                        X_train=X_train,
                        y_train=y_train,
                        X_val=X_val,
                        y_val=y_val,
                        X_test=X_test,
                        y_test=y_test,
                        batch_size=64)


In [92]:
from skmultilearn.model_selection import IterativeStratification

k_fold = IterativeStratification(n_splits=2, order=1)

# Convert X_train to DataFrame if necessary for proper indexing
if not isinstance(X_train, pd.DataFrame):
    X_train = pd.DataFrame(X_train, index=X.index[np.where(np.isin(X.index, X_train))])


for train_idx, val_idx in k_fold.split(X_train, y_train):  # Split based on X_train and y_train
    # Prepare data for the current fold
    ecg_dm.setup_fold_data(train_idx, val_idx)  # Call the setup_fold_data method to prep data

    # Now you can use the indices to select data for the current fold
    trainer.fit(model, datamodule=ecg_dm)  # Pass the model and DataModule
    result = trainer.predict(model, datamodule=ecg_dm)

KeyboardInterrupt: 

In [None]:
trainer.test(model, datamodule=ecg_dm)

## Evaluation Metrics

# Model Testing

In [None]:
from sklearn.metrics import (
    confusion_matrix, ConfusionMatrixDisplay, roc_curve, auc,
    accuracy_score, f1_score, roc_auc_score
)
import matplotlib.pyplot as plt

# After test finishes
y_true = model.all_targets.numpy()
y_pred = model.all_preds.numpy()
y_prob = model.all_probs.numpy()

# Confusion Matrix
cm = confusion_matrix(y_true, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(cmap='Blues')
plt.title("Confusion Matrix")
plt.show()

# ROC Curve
fpr, tpr, _ = roc_curve(y_true, y_prob)
roc_auc = auc(fpr, tpr)

plt.figure()
plt.plot(fpr, tpr, label=f"ROC curve (AUC = {roc_auc:.2f})")
plt.plot([0, 1], [0, 1], linestyle='--', color='gray')
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Receiver Operating Characteristic (ROC)")
plt.legend()
plt.grid(True)
plt.show()

# Metrics
acc = accuracy_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred)
roc_auc_score_val = roc_auc_score(y_true, y_prob)

print(f"Accuracy: {acc:.4f}")
print(f"F1 Score: {f1:.4f}")
print(f"AUC-ROC Score: {roc_auc_score_val:.4f}")


In [None]:
import pandas as pd

# Create lists to hold signals and labels
signals = []
labels = []

for idx in range(len(df)):
    start = df.loc[idx, 'start']
    length = df.loc[idx, 'length']

    # Get and normalize the signal
    raw = memmap_data[start : start + length * 12]
    normed = (raw - raw.mean()) / (raw.std() + 1e-6)

    if np.isnan(normed).any() or np.isinf(normed).any():
        continue  # skip bad sample

    signal = normed.reshape(length, 12)
    signals.append(signal)
    labels.append(df.loc[idx, 'Stroke_YN'])

# Create DataFrame
df_signals = pd.DataFrame({
    "signal": signals,  # Each row is a 2D numpy array (object dtype)
    "label": labels
})