<a href="https://colab.research.google.com/github/LiQuinChing/25-26J-522/blob/Ischemia_CAD_detect-Thisal/CAD_Ischemia_D2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Step 01: Environmental Setup**

In [10]:
pip install wfdb numpy scipy torch scikit-learn matplotlib pandas



# **Step 02: Imports**

In [33]:
import os
import wfdb
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from scipy.signal import butter, filtfilt, resample
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score


# **Step 03: Mount Google Drive**

In [34]:
from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# **Step 04: Paths**

In [35]:
PTBXL_PATH = "/content/drive/MyDrive/ECG_Datasets/PTBXL"
EDB_PATH   = "/content/drive/MyDrive/ECG_Datasets/European-STT"

SAVE_PATH = "/content/drive/MyDrive/ECG_Datasets"
PTB_X = os.path.join(SAVE_PATH, "X_ptb.npy")
PTB_Y = os.path.join(SAVE_PATH, "y_ptb.npy")
EDB_X = os.path.join(SAVE_PATH, "X_edb.npy")
EDB_Y = os.path.join(SAVE_PATH, "y_edb.npy")

PTB_CKPT = os.path.join(SAVE_PATH, "ptbxl_checkpoint.pt")
EDB_CKPT = os.path.join(SAVE_PATH, "edb_checkpoint.pt")


# **Step 05: Signal Preprocessing**

In [36]:
def bandpass_filter(signal, fs=250):
    b, a = butter(4, [0.5/(fs/2), 40/(fs/2)], btype='band')
    return filtfilt(b, a, signal, axis=0)

def normalize(signal):
    return (signal - np.mean(signal)) / (np.std(signal) + 1e-8)

def resample_signal(signal, orig_fs, target_fs=250):
    n_samples = int(len(signal) * target_fs / orig_fs)
    return resample(signal, n_samples)


# **Step 06: ECG Segmentation**

In [37]:
def segment_signal(signal, window=1250, step=1250):
    return np.array([
        signal[i:i+window]
        for i in range(0, len(signal)-window, step)
    ])


# **Step 07: PTB-XL Loader (PRETRAINING)**

Load metadata

In [38]:
meta = pd.read_csv(os.path.join(PTBXL_PATH, "ptbxl_database.csv"))
meta = meta.sample(3000, random_state=42)


Label mapping (ischemia-related)

In [39]:
def is_ischemia(row):
    scp = eval(row["scp_codes"])
    return int(any(k in scp for k in ["ISC_", "MI"]))


Load PTB-XL ECG

In [40]:
def load_ptbxl_record(record_path):
    record = wfdb.rdrecord(record_path)
    signal = record.p_signal[:, [1, 6]]  # Lead II, V5
    signal = resample_signal(signal, record.fs)
    signal = bandpass_filter(signal)
    signal = normalize(signal)
    return segment_signal(signal)



# **Step 08: Build or Load Preprocessed PTB-XL**

In [41]:
if os.path.exists(PTB_X) and os.path.exists(PTB_Y):
    print("Loading saved PTB-XL data...")
    X_ptb = np.load(PTB_X)
    y_ptb = np.load(PTB_Y)

else:
    print("Processing PTB-XL from scratch...")
    X_ptb, y_ptb = [], []

    for _, row in meta.iterrows():
        try:
            record_path = os.path.join(PTBXL_PATH, row["filename_hr"])
            segments = load_ptbxl_record(record_path)

            label = int(any(
                k.startswith(("ISC", "MI"))
                for k in eval(row["scp_codes"]).keys()
            ))

            X_ptb.extend(segments)
            y_ptb.extend([label] * len(segments))

        except:
            continue

    X_ptb = np.array(X_ptb)
    y_ptb = np.array(y_ptb)

    np.save(PTB_X, X_ptb)
    np.save(PTB_Y, y_ptb)

print("PTB-XL samples:", len(X_ptb))


Processing PTB-XL from scratch...
PTB-XL samples: 3000


In [42]:
print("PTB-XL segments:", len(X_ptb))
print("PTB-XL labels:", len(y_ptb))


PTB-XL segments: 3000
PTB-XL labels: 3000


# **Step 09: PyTorch Dataset**

In [43]:
class ECGDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


# **Step 10: CNN-LSTM Model**

In [44]:
class CNN_LSTM(nn.Module):
    def __init__(self):
        super().__init__()

        # self.cnn = nn.Sequential(
        #     nn.Conv1d(2, 32, 7, padding=3),
        #     nn.ReLU(),
        #     nn.MaxPool1d(2),
        #     nn.Conv1d(32, 64, 5, padding=2),
        #     nn.ReLU(),
        #     nn.MaxPool1d(2)
        # )

        self.cnn = nn.Sequential(
            nn.Conv1d(2, 32, 7, padding=3),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Dropout(0.2),
            nn.Conv1d(32, 64, 5, padding=2),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Dropout(0.2)
        )

        self.lstm = nn.LSTM(64, 64, batch_first=True)
        self.fc = nn.Sequential(
            nn.Linear(64, 1),
            nn.Sigmoid()
        )
        # self.fc = nn.Linear(64, 1)

    def forward(self, x):
        x = x.permute(0, 2, 1)
        x = self.cnn(x)
        x = x.permute(0, 2, 1)
        _, (hn, _) = self.lstm(x)
        return self.fc(hn[-1]).squeeze()


# **Step 11: Pretraining on PTB-XL**

In [45]:
X_tr, X_val, y_tr, y_val = train_test_split(
    X_ptb, y_ptb, test_size=0.2, stratify=y_ptb, random_state=42
)

# train_ds = ECGDataset(X_ptb, y_ptb)
train_ds = ECGDataset(X_tr, y_tr)
val_ds   = ECGDataset(X_val, y_val)

# train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
# val_loader   = DataLoader(val_ds, batch_size=64, shuffle=False)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader   = DataLoader(val_ds, batch_size=32, shuffle=False)

model = CNN_LSTM()
# criterion = nn.BCELoss()
# criterion = nn.BCEWithLogitsLoss()
pos_weight = torch.tensor(
    (len(y_ptb) - np.sum(y_ptb)) / np.sum(y_ptb),
    dtype=torch.float32
)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

start_epoch = 0
if os.path.exists(PTB_CKPT):
    ckpt = torch.load(PTB_CKPT)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["opt"])
    start_epoch = ckpt["epoch"] + 1
    print("Resuming PTB-XL training from epoch", start_epoch)

for epoch in range(start_epoch, 5):
    model.train()
    for x, y in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()

    torch.save({
        "epoch": epoch,
        "model": model.state_dict(),
        "opt": optimizer.state_dict()
    }, PTB_CKPT)

    print(f"PTB-XL Epoch {epoch+1}, Loss: {loss.item():.4f}")

    model.eval()
    val_logits, val_labels = [], []

    with torch.no_grad():
        for x, y in val_loader:
            val_logits.append(model(x))
            val_labels.append(y)

    val_logits = torch.cat(val_logits)
    val_labels = torch.cat(val_labels)
    val_auc = roc_auc_score(
        val_labels.cpu().numpy(),
        torch.sigmoid(val_logits).cpu().numpy()
    )

    print("Validation AUC:", round(val_auc, 3))



PTB-XL Epoch 1, Loss: 1.4557
Validation AUC: 0.648
PTB-XL Epoch 2, Loss: 1.4415
Validation AUC: 0.684
PTB-XL Epoch 3, Loss: 1.2879
Validation AUC: 0.699
PTB-XL Epoch 4, Loss: 1.5478
Validation AUC: 0.682
PTB-XL Epoch 5, Loss: 1.4474
Validation AUC: 0.697


# **Step 12: Freeze CNN (Transfer Learning)**

In [46]:
for p in model.cnn.parameters():
    p.requires_grad = False

# for name, p in model.cnn.named_parameters():
#     if "0" in name:  # first conv only
#         p.requires_grad = False



# **Step 13: Load European ST-T**

In [47]:
def load_edb_record(name):
    rec = wfdb.rdrecord(os.path.join(EDB_PATH, name))
    ann = wfdb.rdann(os.path.join(EDB_PATH, name), 'atr')

    signal = normalize(bandpass_filter(rec.p_signal))
    segments = segment_signal(signal)
    labels = np.zeros(len(segments))

    for s in ann.sample:
        idx = s // 1250
        if idx < len(labels):
            labels[idx] = 1

    return segments, labels

if os.path.exists(EDB_X) and os.path.exists(EDB_Y):
    print("Loading cached European ST-T...")
    X_edb = np.load(EDB_X)
    y_edb = np.load(EDB_Y)
else:
    X_edb, y_edb = [], []
    files = [f for f in os.listdir(EDB_PATH) if f.endswith(".hea")]
    for f in files:
        s, l = load_edb_record(f.replace(".hea", ""))
        X_edb.extend(s)
        y_edb.extend(l)
    X_edb = np.array(X_edb)
    y_edb = np.array(y_edb)
    np.save(EDB_X, X_edb)
    np.save(EDB_Y, y_edb)


# **Step 14: Fine-Tune on European ST-T**

In [48]:
# Split into training and validation
X_train, X_val, y_train, y_val = train_test_split(
    X_edb, y_edb, test_size=0.2, stratify=y_edb, random_state=42
)

edb_ds = ECGDataset(X_train, y_train)
val_ds   = ECGDataset(X_val, y_val)

edb_loader = DataLoader(edb_ds, batch_size=32, shuffle=True)
val_loader   = DataLoader(val_ds, batch_size=32, shuffle=False)

# edb_ds = ECGDataset((X_edb), (y_edb))
# edb_loader = DataLoader(edb_ds, batch_size=32, shuffle=True)

# Freeze the entire CNN backbone
for p in model.cnn.parameters():
    p.requires_grad = False

# Only fine-tune LSTM + FC layers
optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-5
)

# optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4)

# BCEWithLogitsLoss with pos_weight for imbalance
pos_weight = torch.tensor((len(y_train) - np.sum(y_train)) / np.sum(y_train), dtype=torch.float32)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

start_epoch = 0
if os.path.exists(EDB_CKPT):
    ckpt = torch.load(EDB_CKPT)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["opt"])
    start_epoch = ckpt["epoch"] + 1
    print("Resuming EDB training from epoch", start_epoch)

# for epoch in range(start_epoch, 5):
#     model.train()
#     for x, y in edb_loader:
#         optimizer.zero_grad()
#         loss = criterion(model(x), y)
#         loss.backward()
#         optimizer.step()

#     torch.save({
#         "epoch": epoch,
#         "model": model.state_dict(),
#         "opt": optimizer.state_dict()
#     }, EDB_CKPT)

#     print(f"EDB Epoch {epoch+1}, Loss: {loss.item():.4f}")

# Fine-tuning with validation
best_val_loss = float('inf')
for epoch in range(start_epoch, 5):
    model.train()
    train_loss = 0.0

    for x, y in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * len(y)

    train_loss /= len(train_ds)

    # Validation
    model.eval()
    val_loss = 0.0
    val_preds, val_labels = [], []
    with torch.no_grad():
        for x, y in val_loader:
            logits = model(x)
            loss = criterion(logits, y)
            val_loss += loss.item() * len(y)
            val_preds.append(torch.sigmoid(logits))
            val_labels.append(y)
    val_loss /= len(val_ds)
    val_preds = torch.cat(val_preds)
    val_labels = torch.cat(val_labels)
    val_auc = roc_auc_score(val_labels.cpu().numpy(), val_preds.cpu().numpy())

    print(f"EDB Epoch {epoch+1} → Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val AUC: {val_auc:.3f}")

    # Save checkpoint if validation improves
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            "epoch": epoch,
            "model": model.state_dict(),
            "opt": optimizer.state_dict()
        }, EDB_CKPT)
        print("Saved best model checkpoint.")


EDB Epoch 1 → Train Loss: 0.6938, Val Loss: 0.0001, Val AUC: 0.001
Saved best model checkpoint.
EDB Epoch 2 → Train Loss: 0.6819, Val Loss: 0.0001, Val AUC: 0.001
EDB Epoch 3 → Train Loss: 0.6723, Val Loss: 0.0001, Val AUC: 0.001
EDB Epoch 4 → Train Loss: 0.6646, Val Loss: 0.0001, Val AUC: 0.001
EDB Epoch 5 → Train Loss: 0.6584, Val Loss: 0.0001, Val AUC: 0.001


# **Step 15: Inference**

In [49]:
# def predict_ecg(signal):
#     signal = bandpass_filter(signal)
#     signal = normalize(signal)
#     segments = segment_signal(signal)

#     model.eval()
#     preds = []
#     with torch.no_grad():
#         for seg in segments:
#             preds.append(model(torch.tensor(seg).unsqueeze(0).float()).item())
#     return np.mean(preds)

def predict_ecg(signal):
    signal = normalize(bandpass_filter(signal))
    segments = segment_signal(signal)

    model.eval()
    probs = []

    with torch.no_grad():
        for seg in segments:
            logit = model(torch.tensor(seg).unsqueeze(0).float())
            prob = torch.sigmoid(logit).item()
            probs.append(prob)

    return float(np.mean(probs))



# **Step 16: Testing CAD**

In [65]:
test_record = "e0611"
record = wfdb.rdrecord(os.path.join(EDB_PATH, test_record))

test_signal = record.p_signal[:, :2]  # Use first 2 leads
score = predict_ecg(test_signal)

print("CAD Probability:", round(score, 3))
confidence = score * 100
print(f"Confidence: {confidence:.1f}%")

if score >= 0.5:
    print("Result: CAD (Ischemia) Detected")
else:
    print("Result: No CAD Detected")


CAD Probability: 0.519
Confidence: 51.9%
Result: CAD (Ischemia) Detected


In [60]:
for f in ["e0207", "e0103", "e0110"]:
    record = wfdb.rdrecord(os.path.join(EDB_PATH, f))
    score = predict_ecg(record.p_signal[:, :2])  # first 2 leads
    print(f"{f} → CAD Probability: {score:.3f}")
    if score >= 0.5:
        print("Result: CAD (Ischemia) Detected")
    else:
        print("Result: No CAD")

e0207 → CAD Probability: 0.514
Result: CAD (Ischemia) Detected
e0103 → CAD Probability: 0.509
Result: CAD (Ischemia) Detected
e0110 → CAD Probability: 0.512
Result: CAD (Ischemia) Detected


In [52]:
for f in ["e0207", "e0103", "e0110"]:
    record = wfdb.rdrecord(os.path.join(EDB_PATH, f))
    score = predict_ecg(record.p_signal[:, :2])
    print(f, "→", round(score, 3))


e0207 → 0.514
e0103 → 0.509
e0110 → 0.512


# **Evaluations**

In [101]:
def evaluate_roc_auc(model, loader):
    model.eval()
    all_probs = []
    all_labels = []

    with torch.no_grad():
        for x, y in loader:
            logits = model(x)
            probs = torch.sigmoid(logits)  # IMPORTANT
            all_probs.extend(probs.cpu().numpy())
            all_labels.extend(y.cpu().numpy())

    auc = roc_auc_score(all_labels, all_probs)
    return auc

In [102]:
auc = evaluate_roc_auc(model, edb_loader)
print("European ST-T ROC-AUC:", round(auc, 4))


European ST-T ROC-AUC: 0.9369
