<a href="https://colab.research.google.com/github/Zfeng0207/FIT3199-FYP/blob/dev%2Fzfeng/lstm_baseline_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Loading Dependencies

In [1]:
# Mount Google Drive into the Colab runtime at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import os
# Switch the working directory to the project folder on Drive; the relative
# "src/..." paths used later in the notebook are resolved against it.
os.chdir('/content/drive/MyDrive/Colab Notebooks/ECG-MIMIC-main')

TPU runtimes install dependencies differently; keep the cell below to avoid dependency conflicts.

In [12]:
!pip install -qqqq mlflow torchmetrics pytorch_lightning

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/961.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━[0m [32m481.3/961.5 kB[0m [31m14.4 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m961.5/961.5 kB[0m [31m16.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m823.0/823.0 kB[0m [31m50.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m119.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m94.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m56.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━

In [13]:
import mlflow

In [10]:
# Input locations, relative to the current working directory.
memmap_meta_path = "src/data/memmap/memmap_meta.npz"  # .npz metadata archive, read with np.load
memmap_path = "src/data/memmap/memmap.npy"  # signal array, read with np.load
df_mapped_path = "src/preprocessed_data/records_w_stroke_labels.csv"  # label table, read with pd.read_csv
df_pkl_path = "src/preprocessed_data/records_w_diag.pkl"  # diagnosis records, read with pd.read_pickle

# Merge dataset with labels and ecg paths

In [6]:
import pandas as pd

# The diagnosis records are stored as a pickle, the stroke labels as a CSV.
df_pkl = pd.read_pickle(df_pkl_path)
df_mapped = pd.read_csv(df_mapped_path)  # this file is a CSV, not a pickle

# Left join on study_id: every row of df_pkl is kept and the columns of df_mapped are attached where study_id matches.
merged_df = pd.merge(df_pkl, df_mapped, on=["study_id"], how="left")


In [11]:
import numpy as np

# Open the memmap metadata archive; allow_pickle=True permits object arrays inside it to be unpickled.
meta = np.load(memmap_meta_path, allow_pickle=True)
filenames = meta['filenames']
num_files = len(filenames)
# Report how many entries the 'filenames' array contains.
print(f"Number of files: {num_files}")

Number of files: 1


In [None]:
# List the names of the arrays stored in the metadata archive.
print(meta.files)

In [None]:
# (rows, columns) of the frame loaded from the pickle.
df_pkl.shape

In [None]:
# (rows, columns) of the frame produced by the left join on study_id.
merged_df.shape

# Labeling stroke classes

In [None]:
# Binary stroke target: 0 when label_test is the literal string '[]' (presumably an
# empty label list serialised to text - TODO confirm), 1 for anything else.
# The column is called 'stroke_yn' because that is the name read by ECGDataset and by
# the data-sampling cell; the former name 'label_test_binary' was never referenced.
df_labels = (df_mapped['label_test'] != '[]').astype(int).to_frame(name='stroke_yn')

# Visualizing target class distribution

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Count the occurrences of each target class.
# df_labels is a single-column frame holding the binary stroke target, so the column is
# selected by position; the previous lookup used 'Stroke_YN', a name df_labels never had.
# sort_index() fixes the order to class 0 first, then class 1.
target_counts = df_labels.iloc[:, 0].value_counts().sort_index()

# Plot the distribution
plt.figure(figsize=(6, 4))
sns.barplot(x=target_counts.index, y=target_counts.values, palette="viridis")
plt.title("Target Distribution (Stroke_YN)", fontsize=14)
plt.xlabel("Stroke Y/N (0 = No Stroke, 1 = Stroke)", fontsize=12)
plt.ylabel("Count", fontsize=12)
plt.xticks([0, 1], labels=["No Stroke (0)", "Stroke (1)"])
plt.show()

# Setting up MLflow for model baseline tracking

In [None]:
import mlflow
import mlflow.pytorch
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
import os

import getpass

# Credentials are taken from the environment (or typed interactively) and are never
# written into the notebook source.
# SECURITY: an access token used to be hardcoded in this cell. It is part of the
# notebook history and must be revoked and replaced on DagsHub.
os.environ.setdefault('MLFLOW_TRACKING_USERNAME', "Zfeng0207")
os.environ.setdefault('MLFLOW_TRACKING_PROJECTNAME', "stroke-prediction-dagshub-repo")
if not os.environ.get('MLFLOW_TRACKING_PASSWORD'):
    # getpass does not echo the token and leaves nothing in the cell output.
    os.environ['MLFLOW_TRACKING_PASSWORD'] = getpass.getpass("DagsHub access token: ")

experiment_name = "ecg-lstm-experiment"

# Build the tracking URI once so the value that is configured is the value that is printed.
tracking_uri = (
    'https://dagshub.com/'
    + os.environ['MLFLOW_TRACKING_USERNAME']
    + '/'
    + os.environ['MLFLOW_TRACKING_PROJECTNAME']
    + '.mlflow'
)
mlflow.set_tracking_uri(tracking_uri)
mlflow.set_experiment(experiment_name)

print(f"MLflow tracking experiment name: {experiment_name}")
print(tracking_uri)


# Data Class

In [None]:
import torch
from torch.utils.data import Dataset

class ECGDataset(Dataset):
    """Map-style dataset that cuts one ECG record out of a shared signal array.

    Each item is ``memmap[start:start + length]`` together with its binary label.

    Args:
        memmap: Array that is sliced along its first axis to obtain one record.
        starts: Per-record start offsets, in memmap-metadata order.
        lengths: Per-record lengths, in memmap-metadata order.
        labels_df: Frame with a ``stroke_yn`` column. When it also carries ``start``
            and ``length`` columns, those per-row pointers are used directly, so the
            frame may be shuffled, filtered or re-indexed freely. Otherwise the
            frame's index is used to look the pointers up in ``starts``/``lengths``
            and must therefore still match the memmap-metadata order.
    """

    def __init__(self, memmap, starts, lengths, labels_df):
        self.memmap = memmap
        self.starts = starts
        self.lengths = lengths
        self.labels = labels_df['stroke_yn'].values
        # Only a valid memmap pointer while the frame index has not been reset.
        self.indices = labels_df.index.values

        # Prefer pointers that travel with the rows: they stay correct after
        # reset_index(drop=True), whereas an index lookup would silently pair
        # labels with the wrong signals.
        if {'start', 'length'}.issubset(labels_df.columns):
            self._row_starts = labels_df['start'].values
            self._row_lengths = labels_df['length'].values
        else:
            self._row_starts = None
            self._row_lengths = None

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(signal, label)`` as a float32 tensor and an int64 tensor."""
        if self._row_starts is not None:
            start = int(self._row_starts[idx])
            length = int(self._row_lengths[idx])
        else:
            record = self.indices[idx]
            start = int(self.starts[record])
            length = int(self.lengths[record])

        signal = self.memmap[start:start + length]
        label = self.labels[idx]
        return torch.tensor(signal, dtype=torch.float32), torch.tensor(label, dtype=torch.long)


In [None]:
from torch.utils.data import DataLoader
import pytorch_lightning as pl

class ECGDataModule(pl.LightningDataModule):
    """Lightning data module that wraps the train/val/test label frames in ECGDataset loaders.

    Args:
        memmap, starts, lengths: Forwarded unchanged to every ECGDataset.
        train_df, val_df, test_df: Label frames, one per split.
        batch_size: Batch size shared by all three loaders.
    """

    def __init__(self, memmap, starts, lengths, train_df, val_df, test_df, batch_size=32):
        super().__init__()
        self.memmap, self.starts, self.lengths = memmap, starts, lengths
        self.train_df, self.val_df, self.test_df = train_df, val_df, test_df
        self.batch_size = batch_size

    def _dataset_for(self, frame):
        """Build an ECGDataset over ``frame`` using the shared signal array and pointers."""
        return ECGDataset(self.memmap, self.starts, self.lengths, frame)

    def setup(self, stage=None):
        # All three datasets are created regardless of the requested stage.
        self.train_dataset = self._dataset_for(self.train_df)
        self.val_dataset = self._dataset_for(self.val_df)
        self.test_dataset = self._dataset_for(self.test_df)

    def train_dataloader(self):
        # Only the training loader shuffles.
        loader = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)
        return loader

    def val_dataloader(self):
        loader = DataLoader(self.val_dataset, batch_size=self.batch_size)
        return loader

    def test_dataloader(self):
        loader = DataLoader(self.test_dataset, batch_size=self.batch_size)
        return loader

# Simple LSTM Model

In [14]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics.classification import BinaryAUROC, BinaryF1Score

class LSTMSleepClassifier(pl.LightningModule):
    """Bidirectional LSTM that emits one logit per sequence for binary classification.

    Args:
        hparams: Optional mapping (or attribute-style object such as a Namespace) of
            hyperparameters. Entries named ``input_size``, ``hidden_size``,
            ``num_layers`` or ``lr`` take precedence over the keyword arguments of
            the same name, which keeps a model rebuilt from saved hyperparameters
            identical to the one that was trained.
        input_size: Number of features per timestep.
        hidden_size: Hidden units per LSTM direction.
        num_layers: Number of stacked LSTM layers.
        lr: Learning rate for Adam.
    """

    def __init__(self, hparams=None, input_size=12, hidden_size=64, num_layers=2, lr=1e-3):
        super().__init__()

        # Merge the optional hparams object with the keyword arguments so that
        # self.hparams always exposes lr and the layer sizes, including when the
        # model is created as LSTMSleepClassifier(input_size=12).
        if hparams is None:
            recorded = {}
        elif hasattr(hparams, "keys"):
            recorded = dict(hparams)
        else:
            recorded = dict(vars(hparams))
        recorded.setdefault("input_size", input_size)
        recorded.setdefault("hidden_size", hidden_size)
        recorded.setdefault("num_layers", num_layers)
        recorded.setdefault("lr", lr)
        self.save_hyperparameters(recorded)

        # Define metrics (one instance per stage so their states never mix)
        self.train_f1 = BinaryF1Score()
        self.val_f1 = BinaryF1Score()
        self.test_f1 = BinaryF1Score()

        self.train_auc = BinaryAUROC()
        self.val_auc = BinaryAUROC()
        self.test_auc = BinaryAUROC()

        self.lstm = nn.LSTM(
            input_size=self.hparams.input_size,
            hidden_size=self.hparams.hidden_size,
            num_layers=self.hparams.num_layers,
            batch_first=True,
            bidirectional=True
        )

        self.fc = nn.Linear(self.hparams.hidden_size * 2, 1)  # bidirectional
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, x):
        """Return one logit per sequence for ``x`` of shape (B, T, input_size)."""
        _, (h_n, _) = self.lstm(x)
        # h_n is (num_layers * 2, B, H); its last two rows are the top layer's final
        # forward and backward states. Taking out[:, -1, :] instead would hand the
        # backward direction a state that has only seen the last timestep.
        summary = torch.cat([h_n[-2], h_n[-1]], dim=-1)
        logits = self.fc(summary)
        # Squeeze only the feature axis so that a batch of one keeps shape (1,).
        return logits.squeeze(-1)

    def _shared_step(self, batch, stage):
        """Compute loss and metrics for one batch and log them under ``stage``."""
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y.float())

        probs = torch.sigmoid(logits)
        preds = (probs > 0.5).long()
        acc = (preds == y).float().mean()

        f1_metric = getattr(self, f"{stage}_f1")
        auc_metric = getattr(self, f"{stage}_auc")
        f1_metric.update(preds, y)
        # AUROC ranks scores, so it needs probabilities rather than thresholded labels.
        auc_metric.update(probs, y)

        self.log(f"{stage}_loss", loss)
        self.log(f"{stage}_acc", acc)
        # Logging the metric objects lets Lightning compute and reset them per epoch.
        self.log(f"{stage}_f1", f1_metric, on_step=False, on_epoch=True)
        self.log(f"{stage}_auc", auc_metric, on_step=False, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        # Lightning expects the loss tensor (or a dict containing it), not a tuple.
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, "test")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)


# Data Sampling

In [None]:
import pandas as pd
import numpy as np

# Load metadata from the archive configured in memmap_meta_path
# (the previously hardcoded "memmap.npz" did not match that path).
meta = np.load(memmap_meta_path, allow_pickle=True)
start = meta["start"]
length = meta["length"]
file_idx = meta["file_idx"]
filenames = meta["filenames"]

# Work on a copy so df_labels itself stays untouched
df = df_labels.copy()

# Sanity checks
assert "stroke_yn" in df.columns, "df_labels must provide a 'stroke_yn' column"
assert len(df) == len(start), "Mismatch between label and memmap metadata length"

# Add metadata into DataFrame
df['start'] = start
df['length'] = length
df['file_idx'] = file_idx

# Now you can split the DataFrame while keeping track of ECG data pointers
from sklearn.model_selection import train_test_split

# Split test set with preserved stroke ratio
train_val_df, test_df = train_test_split(
    df, test_size=0.10, stratify=df['stroke_yn'], random_state=42
)

# Separate stroke / non-stroke rows of the remaining data
stroke_df = train_val_df[train_val_df['stroke_yn'] == 1]
nonstroke_df = train_val_df[train_val_df['stroke_yn'] == 0]

# Balanced sampling: two non-stroke rows per stroke row, with disjoint train/val pools
train_stroke, val_stroke = train_test_split(stroke_df, test_size=0.1, random_state=42)
train_nonstroke = nonstroke_df.sample(n=len(train_stroke)*2, random_state=42)
val_nonstroke = nonstroke_df.drop(train_nonstroke.index).sample(n=len(val_stroke)*2, random_state=42)

# Final splits.
# The row index is kept on purpose: it is each record's position in the memmap
# metadata, and code that resolves signals through labels_df.index depends on it.
# Calling reset_index(drop=True) here would pair labels with the wrong ECG records.
train_df = pd.concat([train_stroke, train_nonstroke])
val_df = pd.concat([val_stroke, val_nonstroke])


# Model and Data Initialization

In [None]:
# Signal array plus the per-record pointers taken from the metadata archive.
memmap_data = np.load(memmap_path, allow_pickle=True)
starts = meta['start']
lengths = meta['length']

# Create the data module
ecg_dm = ECGDataModule(
    memmap=memmap_data,
    starts=starts,
    lengths=lengths,
    train_df=train_df,
    val_df=val_df,
    test_df=test_df,
    batch_size=64
)

# Initialize model.
# The classifier receives its hyperparameters through `hparams`; configure_optimizers
# reads the learning rate from there, so it has to be supplied explicitly.
model = LSTMSleepClassifier(hparams={"lr": 1e-3}, input_size=12)


# Model Training

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import MLFlowLogger
import mlflow

# MLFlowLogger does not inherit the URI set through mlflow.set_tracking_uri(): without
# an explicit tracking_uri it falls back to the MLFLOW_TRACKING_URI environment
# variable or a local ./mlruns directory. Passing the globally configured URI keeps
# the runs on the same tracking server as the experiment set up earlier.
mlf_logger = MLFlowLogger(
    experiment_name="ecg-lstm-experiment",
    tracking_uri=mlflow.get_tracking_uri(),
    log_model=True,  # logs the model checkpoint as artifact
)

# Pass the logger to the Trainer
trainer = Trainer(
    max_epochs=10,
    accelerator="auto",
    log_every_n_steps=10,
    deterministic=True,
    logger=mlf_logger,
)

# NOTE(review): nothing in the visible notebook calls trainer.fit(model, datamodule=ecg_dm),
# so no training actually runs - confirm whether that call lives in a later cell.


# Evaluation Metrics

In [None]:
from torchmetrics.classification import BinaryF1Score, BinaryAUROC

# Stand-alone metric instances for ad-hoc evaluation outside a LightningModule.
# (The earlier version assigned to `self` at module level, which raises NameError.)
f1_metric = BinaryF1Score()
auroc_metric = BinaryAUROC()

# Usage, given float probabilities `probs` and integer targets `y`:
#     f1_metric.update(probs, y)
#     auroc_metric.update(probs, y)   # AUROC needs scores, not thresholded labels
#     f1_metric.compute(), auroc_metric.compute()
