In [21]:
#1. IMPORTS AND PATHS
import os, ast
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import wfdb
from PIL import Image
import glob

import matplotlib
matplotlib.use("Agg")   # headless; prevents notebook/GUI rendering
import matplotlib.pyplot as plt
plt.ioff()              # just in case; disables interactive mode

current_dir = os.getcwd()

#set based folder
DB = os.path.join(current_dir, "ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3")
OUT_ROOT = os.path.join(current_dir, "ecg_images")

print(f"Dataset Path: {DB}")
print(f"Output Path: {OUT_ROOT}")

FNAME_COL = "filename_lr"                      # use low-rate (100 Hz) files
SR = 100      # sampling rate for filename_lr
# 1600 x 1200
TARGET_SIZE = (800, 600)                       # output image width × height (pixels)

os.makedirs(OUT_ROOT, exist_ok=True)


#2. READ METADATA CSVs
df = pd.read_csv(os.path.join(DB, "ptbxl_database.csv"))
scp = pd.read_csv(os.path.join(DB, "scp_statements.csv"), index_col=0)

# keep only diagnostic statements
diagnostic_codes = scp[scp["diagnostic"] == 1]


#3. MAP SCP CODES TO DIAGNOSTIC SUPERCLASS
def to_superclasses(scp_codes_str):
    codes = ast.literal_eval(scp_codes_str)  # dict: code → weight
    diags = [c for c in codes.keys() if c in diagnostic_codes.index]
    supers = sorted({diagnostic_codes.loc[c, "diagnostic_class"] for c in diags})
    return supers

df["superclasses"] = df["scp_codes"].apply(to_superclasses)


#4. OFFICIAL STRATIFIED FOLDS
train_df = df[df["strat_fold"].isin(range(1, 9))].copy()
val_df   = df[df["strat_fold"] == 9].copy()
test_df  = df[df["strat_fold"] == 10].copy()

#5. SINGLE-LABEL PRIMARY CLASS PER RECORD
# basically if an entry has multiple super classes, pick one based on the priority shown below
PRIORITY = ["MI", "STTC", "HYP", "CD", "NORM"]

# function that picks a super class 
def choose_primary_superclass(superclasses):
    if not superclasses:
        return None
    for c in PRIORITY:
        if c in superclasses:
            return c
    return superclasses[0]

# create a new column called primary class and apply the result of the super class function to it
for split in (train_df, val_df, test_df):
    split["primary_class"] = split["superclasses"].apply(choose_primary_superclass)

# drop the rows that have no primary class 
train_df = train_df.dropna(subset=["primary_class"])
val_df   = val_df.dropna(subset=["primary_class"])
test_df  = test_df.dropna(subset=["primary_class"])


#6. HELPERS FOR READING WFDB AND SAVING PLOTS
def load_signal_and_leads(rec_rel_path, base_dir=DB):
    """Read WFDB record. Returns (signal[T,12], lead_names[list])."""
    rec_path = os.path.join(base_dir, rec_rel_path)
    sig, meta = wfdb.rdsamp(rec_path)
    names = list(meta.sig_name) if hasattr(meta, "sig_name") else [f"Lead{i+1}" for i in range(sig.shape[1])]
    return sig.astype("float32"), names

import gc

def save_12lead_strip(signal, lead_names, out_path, sr=SR, target_size=TARGET_SIZE):
    """
    Plot 12 leads in a 3×4 grid and save as PNG/JPG without blowing RAM.
    """
    T, C = signal.shape

    # Keep the pixel buffer bounded: (figsize * dpi) ≈ target_size
    fig_w, fig_h = 10, 6
    dpi = max(72, min(target_size[0]/fig_w, target_size[1]/fig_h))

    fig, axes = plt.subplots(3, 4, figsize=(fig_w, fig_h), dpi=dpi)
    axes = axes.ravel()
    try:
        t = np.arange(T, dtype=np.float32) / float(sr)
        for i in range(min(C, 12)):
            ax = axes[i]
            ax.plot(t, signal[:, i], linewidth=0.8)
            ax.set_xlim(t[0], t[-1])
            ax.set_ylim(-3.0, 3.0) #added this to accomodate large hypertrophy spikes without clipping, small signals look small, big signals look big
            ax.axis("off")
        for j in range(C, len(axes)):
            axes[j].axis("off")

        # Avoid heavy auto-layout work on huge batches
        # plt.tight_layout(pad=0.15)   # optional; comment out for speed
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        fig.savefig(out_path, bbox_inches="tight", pad_inches=0.03)
    finally:
        plt.close(fig)     # CRITICAL: frees Agg buffer
        del fig, axes      # drop references

def export_split_images(split_df, split_name, limit=None, gc_every=200):
    """Save ECG plots into OUT_ROOT/split_name/<class>/<ecg_id>.jpg."""
    saved = 0
    root = os.path.join(OUT_ROOT, split_name)
    for idx, row in tqdm(split_df.iterrows(), total=len(split_df), desc=f"Export {split_name}"):
        if limit and saved >= limit:
            break

        label = row["primary_class"]
        if not label:
            continue

        try:
            signal, leads = load_signal_and_leads(row[FNAME_COL])
            ecg_id = int(row["ecg_id"]) if "ecg_id" in row.index else idx
            out_path = os.path.join(root, label, f"{ecg_id}.jpg")
            save_12lead_strip(signal, leads, out_path)
            saved += 1
        except Exception:
            # optional: log the path or index here
            pass
        finally:
            # Make sure big arrays are eligible for GC immediately
            try:
                del signal, leads
            except NameError:
                pass

        if saved % gc_every == 0:
            gc.collect()

    print(f"[{split_name}] saved {saved} images → {root}")

#7. RUN EXPORT
export_split_images(train_df, "train", limit=None) #only run this for first time generation, else comment out
export_split_images(val_df,   "val", limit=None)
export_split_images(test_df,  "test",limit=None)

Dataset Path: /home/user/21012125/Capstone2/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3
Output Path: /home/user/21012125/Capstone2/ecg_images


Export train: 100%|██████████| 17084/17084 [28:32<00:00,  9.98it/s]


[train] saved 17084 images → /home/user/21012125/Capstone2/ecg_images/train


Export val: 100%|██████████| 2146/2146 [03:34<00:00, 10.01it/s]


[val] saved 2146 images → /home/user/21012125/Capstone2/ecg_images/val


Export test: 100%|██████████| 2158/2158 [03:35<00:00, 10.01it/s]

[test] saved 2158 images → /home/user/21012125/Capstone2/ecg_images/test





In [22]:
#8 COMPUTE CLASS WEIGHTS 
CLASS_ORDER = ["NORM", "MI", "STTC", "HYP", "CD"]  # same order as model output

from collections import Counter
import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

#compute pos weight from train only
pos_counts = Counter()
total_train = len(train_df)

for supers in train_df["superclasses"]:
    s = set(supers)
    for cls in CLASS_ORDER:
        if cls in s:
            pos_counts[cls] += 1

# PyTorch expects per-class pos_weight = N_neg / N_pos
pos_weight_np = []
for cls in CLASS_ORDER:
    p = max(1, pos_counts[cls])  # avoid divide-by-zero
    n = total_train - p
    pos_weight_np.append(n / p)

#manually boosting weights for "HYP" because it is performing poorly (0.49)
print(f"Original HYP weight: {pos_weight_np[3]:.2f}")
pos_weight_np[3] *= 2  # Increase penalty for missing HYP 
print(f"Boosted HYP weight:  {pos_weight_np[3]:.2f}")

pos_weight = torch.tensor(pos_weight_np, dtype=torch.float32).to(device)
print("pos_weight (train-only):", dict(zip(CLASS_ORDER, [float(x) for x in pos_weight_np])))

example_weights = []
for supers in train_df["superclasses"]:
    labels = [cls for cls in CLASS_ORDER if cls in supers]
    if not labels:
        w = 1.0
    else:
        w = float(np.mean([pos_weight_np[CLASS_ORDER.index(cls)] for cls in labels]))
    example_weights.append(w)

from torch.utils.data import WeightedRandomSampler
train_sampler = WeightedRandomSampler(
    weights=torch.tensor(example_weights, dtype=torch.double),
    num_samples=len(example_weights),
    replacement=True,
)

Original HYP weight: 7.06
Boosted HYP weight:  14.12
pos_weight (train-only): {'NORM': 1.2490784623486044, 'MI': 2.9013473395752456, 'STTC': 3.0812231247013857, 'HYP': 14.124587069372346, 'CD': 3.372664448425902}


In [23]:
import os, json
import pandas as pd

CLASS_ORDER = ["NORM", "MI", "STTC", "HYP", "CD"]

def to_multihot(superclasses):
    s = set(superclasses)
    return {c: int(c in s) for c in CLASS_ORDER}

def build_labels_csv_from_existing(split_df, split_name, out_root=OUT_ROOT):
    """
    create scvs with multihots corresponding to each ecg image, this data should be ready
    to be passed as "tensors" into the neural network

    """

    # empty rows to store all entries
    rows = []
    # the directory of the split (train/val/test)
    split_dir = os.path.join(out_root, split_name)
    # for each entry in the dataframe [id, name, superclass, primaryclass]
    for _, r in split_df.iterrows():
        # we get the super class
        supers = r["superclasses"]          # e.g., ['CD','HYP']
        # supers = NORM
        # if no superclass, then this entry is meaningless
        if not supers:
            continue
        # get primary class
        
        primary = r["primary_class"]
        # get the id of the ecg
        ecg_id = int(r["ecg_id"])
        # find the image that we saved that corresponds to the entry we're lookniga t right now
        # train / NORM / 1.png
        img_path = os.path.join(split_dir, primary, f"{ecg_id}.jpg")
        if not os.path.exists(img_path):
            # might not exist if you used a small 'limit' during export
            continue
        # creates a multihot row
        mh = to_multihot(supers)
        row = {
            "image_path": img_path.replace("\\", "/"),
            "labels": json.dumps(sorted(supers))
        }   
        # add the columns
        row.update(mh)                      # add NORM/MI/STTC/HYP/CD columns
        # add the rows
        rows.append(row)
        # we're essentially building a dataframe that has [imagepath, superclass, [superclasses_multihot]]

    df_out = pd.DataFrame(rows)
    out_csv = os.path.join(out_root, f"{split_name}_labels.csv")
    df_out.to_csv(out_csv, index=False)
    print(f"Wrote {len(df_out)} rows → {out_csv}")
    return out_csv

# build csvs for all splits
train_csv = build_labels_csv_from_existing(train_df, "train")
val_csv   = build_labels_csv_from_existing(val_df,   "val")
test_csv  = build_labels_csv_from_existing(test_df,  "test")

Wrote 17084 rows → /home/user/21012125/Capstone2/ecg_images/train_labels.csv
Wrote 2146 rows → /home/user/21012125/Capstone2/ecg_images/val_labels.csv
Wrote 2158 rows → /home/user/21012125/Capstone2/ecg_images/test_labels.csv


In [24]:
import os, json
import pandas as pd
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import transforms, models

from sklearn.metrics import f1_score, precision_recall_fscore_support, roc_auc_score

OUT_ROOT = os.path.join(current_dir, "ecg_images")
train_csv = os.path.join(OUT_ROOT, "train_labels.csv")
val_csv   = os.path.join(OUT_ROOT, "val_labels.csv")
test_csv  = os.path.join(OUT_ROOT, "test_labels.csv")

# Class order used everywhere
CLASS_ORDER = ["NORM", "MI", "STTC", "HYP", "CD"]

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device


device(type='cuda')

In [25]:
# Transforms & DataLoaders 
IMG_SIZE = 512  # increased to 512 to capture exact height of spikes
IMG_H = 512
IMG_W = 512
from torchvision import transforms as T  # unify alias as T
from ecg_utils import MultiLabelECGImages

# Train transforms:
train_transforms = transforms.Compose([
    transforms.Resize((IMG_H, IMG_W)),
    
    # CHANGED: Added TrivialAugmentWide.
    transforms.TrivialAugmentWide(),
    
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std =[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.5, scale=(0.02, 0.15)), 
])


# Validation transforms: deterministic (no random augments)
val_transforms = T.Compose([
    T.Resize((IMG_H, IMG_W)),  # match train size / your model input
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std =[0.229, 0.224, 0.225]),
])

# Test/eval transforms: same as validation
eval_transforms = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std =[0.229, 0.224, 0.225]),
])

# Datasets
train_ds = MultiLabelECGImages(train_csv, transform=train_transforms)
val_ds   = MultiLabelECGImages(val_csv,   transform=val_transforms)
test_ds  = MultiLabelECGImages(test_csv,  transform=eval_transforms)

# DataLoaders
BATCH_SIZE = 32
NUM_WORKERS = 8 

# Use the balanced sampler if you created 'train_sampler' earlier.
# If you did NOT create it, set sampler=None and use shuffle=True.
train_dl = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,        # <- FIXED name
    sampler=train_sampler,        # <- requires you defined train_sampler; else replace with shuffle=True
    # shuffle=True,               # (use this instead of sampler if you don't have train_sampler)
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

val_dl  = DataLoader(val_ds,  batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

print(len(train_ds), len(val_ds), len(test_ds))


17084 2146 2158


In [26]:
import torch.cuda.amp as amp

# Create the scaler 
scaler = amp.GradScaler(enabled=(device.type == "cuda"))

def run_epoch(model, loader, optimizer=None, scheduler=None):
    train_mode = optimizer is not None
    model.train() if train_mode else model.eval()

    total_loss = 0.0
    all_targets, all_probs = [], []

    # Iterate over batches
    for imgs, targets in loader:
        imgs = imgs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if train_mode:
            optimizer.zero_grad(set_to_none=True) # Slightly faster than zero_grad()

        # --- SPEEDUP FIX: Use Automatic Mixed Precision (AMP) ---
        with torch.amp.autocast(device_type='cuda', enabled=(device.type == "cuda")):
            logits = model(imgs)
            loss = loss_function(logits, targets)

        if train_mode:
            # Use scaler to handle the lower precision math
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            
            if scheduler is not None:
                scheduler.step()

        total_loss += loss.item() * imgs.size(0)

        # Save predictions (move to CPU to save GPU RAM)
        with torch.no_grad():
            probs = torch.sigmoid(logits).detach().cpu().numpy()
            all_targets.append(targets.detach().cpu().numpy())
            all_probs.append(probs)

    # Combine all batches
    all_targets = np.concatenate(all_targets, axis=0)
    all_probs = np.concatenate(all_probs, axis=0)
    avg_loss = total_loss / len(loader.dataset)

    return avg_loss, all_targets, all_probs

def multilabel_metrics(y_true, y_prob, threshold=0.5):
    y_pred = (y_prob >= threshold).astype(int)
    # Per-class precision/recall/F1
    # prec, rec, f1, support (no. occurences of each label in y_true)
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average=None, zero_division=0)
    macro_f1 = f1.mean()
    micro_f1 = f1_score(y_true, y_pred, average='micro', zero_division=0)
    # AUROC per class (guard against classes with single label value)
    
    aurocs = []
    for i in range(y_true.shape[1]):
        if len(np.unique(y_true[:, i])) == 2:
            aurocs.append(roc_auc_score(y_true[:, i], y_prob[:, i]))
        else:
            aurocs.append(float('nan'))
    return {
        "macro_f1": macro_f1,
        "micro_f1": micro_f1,
        "per_class_f1": dict(zip(CLASS_ORDER, f1)),
        "per_class_precision": dict(zip(CLASS_ORDER, prec)),
        "per_class_recall": dict(zip(CLASS_ORDER, rec)),
        "per_class_auroc": dict(zip(CLASS_ORDER, aurocs)),
    }

  scaler = amp.GradScaler(enabled=(device.type == "cuda"))


In [27]:
# ===== Model & Loss =====
import copy
import matplotlib.pyplot as plt # Ensure we have this for plotting

class AsymmetricLoss(nn.Module):
    def __init__(self, gamma_pos=0.0, gamma_neg=4.0, clip=0.05, eps=1e-8):
        super().__init__()
        self.gamma_pos, self.gamma_neg, self.clip, self.eps = gamma_pos, gamma_neg, clip, eps

    def forward(self, logits, targets):
        x_sig = torch.sigmoid(logits)
        if self.clip and self.clip > 0:
            x_sig = torch.clamp(x_sig, self.clip, 1.0 - self.clip)
        xs_pos = x_sig
        xs_neg = 1.0 - x_sig
        w = (1 - xs_pos) ** self.gamma_pos * targets + (1 - xs_neg) ** self.gamma_neg * (1 - targets)
        loss = - (targets * torch.log(xs_pos + self.eps) + (1 - targets) * torch.log(xs_neg + self.eps))
        return (w * loss).mean()

loss_function = AsymmetricLoss(gamma_pos=0.0, gamma_neg=4.0, clip=0.05)

# # ===== ResNet34 Backbone =====
from torchvision import models

# # backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)  #restnet50 was too big, switching to resnet34
# backbone = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)
 
# in_features = backbone.fc.in_features
# backbone.fc = nn.Sequential(
#     nn.Dropout(p=0.5),
#     nn.Linear(in_features, len(CLASS_ORDER))
# )
# model = backbone.to(device)


# ===== DenseNet121 Backbone ===== 
backbone = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)

in_features = backbone.classifier.in_features
backbone.classifier = nn.Sequential(
    nn.Dropout(p=0.5),
    nn.Linear(in_features, len(CLASS_ORDER))
)

model = backbone.to(device)

# ===== Training setup: 2-phase (warmup + fine-tune) =====
WARMUP_EPOCHS = 5
TOTAL_EPOCHS  = 20

# Lists to store loss history for plotting 
train_losses = []
val_losses = []

# ---------- PHASE 1: freeze backbone, train only the head ----------
for name, p in model.named_parameters():
    if name.startswith("classifier."):     # final FC layer of ResNet18
        p.requires_grad = True
    else:
        p.requires_grad = False    # backbone frozen

# Only head params will be trainable
trainable_params = [p for p in model.parameters() if p.requires_grad]

optimizer = torch.optim.AdamW(
    trainable_params,
    lr=3e-4,
    weight_decay=1e-4
)

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=3e-4,
    steps_per_epoch=len(train_dl),
    epochs=WARMUP_EPOCHS,
    pct_start=0.1,
    anneal_strategy="cos",
    div_factor=10.0,
    final_div_factor=10.0,
)

best_val_macro = 0.0
best_state_dict = None

for epoch in range(1, WARMUP_EPOCHS + 1):
    print(f"\n[Warmup] Epoch {epoch}/{WARMUP_EPOCHS}")

    # IMPORTANT: pass scheduler into run_epoch so it can step per batch
    train_loss, y_true_tr, y_prob_tr = run_epoch(model, train_dl, optimizer, scheduler)
    val_loss, y_true_val, y_prob_val = run_epoch(model, val_dl, optimizer=None)
    
    #  Save history
    train_losses.append(train_loss)
    val_losses.append(val_loss)

    metrics_val = multilabel_metrics(y_true_val, y_prob_val, threshold=0.5)
    macro_f1_val  = metrics_val["macro_f1"]
    micro_f1_val  = metrics_val["micro_f1"]
    per_class_au  = metrics_val["per_class_auroc"]  # this is a dict by class
    
    print(f"Val macro-F1: {macro_f1_val:.4f}")
    print(f"Val micro-F1: {micro_f1_val:.4f}")

    if macro_f1_val > best_val_macro:
        best_val_macro = macro_f1_val
        best_state_dict = copy.deepcopy(model.state_dict())


# ---------- PHASE 2: unfreeze and fine-tune entire network ----------
for p in model.parameters():
    p.requires_grad = True

backbone_params = []
head_params = []
for name, p in model.named_parameters():
    if not p.requires_grad:
        continue
    if name.startswith("classifier."):
        head_params.append(p)       # final classifier layer
    else:
        backbone_params.append(p)   # pretrained conv backbone

optimizer = torch.optim.AdamW(
    [
        {"params": backbone_params, "lr": 5e-5, "weight_decay": 0.01},  # smaller LR
        {"params": head_params,     "lr": 1e-4, "weight_decay": 0.01},  # larger LR
    ]
)

scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=5,      
    T_mult=1,
    eta_min=1e-6
)

for epoch in range(WARMUP_EPOCHS + 1, TOTAL_EPOCHS + 1):
    print(f"\n[Fine-tune] Epoch {epoch}/{TOTAL_EPOCHS}")

    train_loss, y_true_tr, y_prob_tr = run_epoch(model, train_dl, optimizer, scheduler)
    val_loss, y_true_val, y_prob_val = run_epoch(model, val_dl, optimizer=None)
    
    # ### NEW CODE: Save history ###
    train_losses.append(train_loss)
    val_losses.append(val_loss)

    metrics_val = multilabel_metrics(y_true_val, y_prob_val, threshold=0.5)
    macro_f1_val  = metrics_val["macro_f1"]
    micro_f1_val  = metrics_val["micro_f1"]
    per_class_au  = metrics_val["per_class_auroc"]  # this is a dict by class
    
    print(f"Val macro-F1: {macro_f1_val:.4f}")
    print(f"Val micro-F1: {micro_f1_val:.4f}")

    if macro_f1_val > best_val_macro:
        best_val_macro = macro_f1_val
        best_state_dict = copy.deepcopy(model.state_dict())

# After training: restore best weights before doing VAL threshold search / TEST
model.load_state_dict(best_state_dict)

#save to disk
save_path = os.path.join(OUT_ROOT, "best_model_densenet121.pth")
torch.save(model.state_dict(), save_path)
print(f"Saved model to: {save_path}")

# Plotting the graph
print("\nGenerating Loss Graph...")
plt.figure(figsize=(10, 6))
plt.plot(train_losses, label='Training Loss', color='blue', marker='o')
plt.plot(val_losses, label='Validation Loss', color='orange', marker='o')
plt.title('Training vs Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

# Save the plot to the output folder
graph_path = os.path.join(OUT_ROOT, "loss_graph.png")
plt.savefig(graph_path)
print(f"Graph saved to: {graph_path}")
print("Please open this file in your file browser to view the graph.")


[Warmup] Epoch 1/5
Val macro-F1: 0.4652
Val micro-F1: 0.4744

[Warmup] Epoch 2/5
Val macro-F1: 0.4535
Val micro-F1: 0.4521

[Warmup] Epoch 3/5
Val macro-F1: 0.4684
Val micro-F1: 0.4759

[Warmup] Epoch 4/5
Val macro-F1: 0.4715
Val micro-F1: 0.4786

[Warmup] Epoch 5/5
Val macro-F1: 0.4645
Val micro-F1: 0.4698

[Fine-tune] Epoch 6/20
Val macro-F1: 0.6265
Val micro-F1: 0.6579

[Fine-tune] Epoch 7/20
Val macro-F1: 0.6449
Val micro-F1: 0.6731

[Fine-tune] Epoch 8/20
Val macro-F1: 0.6573
Val micro-F1: 0.6841

[Fine-tune] Epoch 9/20
Val macro-F1: 0.6460
Val micro-F1: 0.6702

[Fine-tune] Epoch 10/20
Val macro-F1: 0.6763
Val micro-F1: 0.7062

[Fine-tune] Epoch 11/20
Val macro-F1: 0.6652
Val micro-F1: 0.6924

[Fine-tune] Epoch 12/20
Val macro-F1: 0.6802
Val micro-F1: 0.7086

[Fine-tune] Epoch 13/20
Val macro-F1: 0.6802
Val micro-F1: 0.7095

[Fine-tune] Epoch 14/20
Val macro-F1: 0.6877
Val micro-F1: 0.7177

[Fine-tune] Epoch 15/20
Val macro-F1: 0.6864
Val micro-F1: 0.7156

[Fine-tune] Epoch 16/20

In [None]:
# RELOAD SAVED MODEL (SKIP TRAINING)
# Run this cell to skip training and load best model

model_path = os.path.join(OUT_ROOT, "best_model_densenet121.pth")

if os.path.exists(model_path):
    print(f"Found saved model at {model_path}. Loading...")
    
    # 1. Re-initialize the empty model structure 
    # (Since we ran Model Definition earlier, 'model' variable already exists)
    
    # 2. Load the weights from the file
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    
    model.to(device)
    model.eval() # Set to evaluation mode
    print("Model loaded successfully! Ready for diagnosis/testing.")
else:
    print("No saved model found. You must run the training cell first.")

In [28]:
train_loss_best, y_true_tr_best, y_prob_tr_best = run_epoch(model, train_dl, optimizer=None)
val_loss_best,   y_true_val_best, y_prob_val_best = run_epoch(model, val_dl, optimizer=None)

metrics_train = multilabel_metrics(y_true_tr_best, y_prob_tr_best, threshold=0.5)
metrics_val   = multilabel_metrics(y_true_val_best, y_prob_val_best, threshold=0.5)

print("Train macro-F1:", metrics_train["macro_f1"])
print("Val   macro-F1:", metrics_val["macro_f1"])
print("Val   micro-F1:", metrics_val["micro_f1"])

Train macro-F1: 0.7455217979270218
Val   macro-F1: 0.6994250161227579
Val   micro-F1: 0.7299802461631971


In [35]:
from sklearn.metrics import f1_score, precision_recall_fscore_support
import numpy as np

def best_thresholds_per_class(y_true_val, y_prob_val, grid=np.linspace(0.05, 0.95, 19)):
    C = y_true_val.shape[1]
    best = np.full(C, 0.5, dtype=np.float32)
    for c in range(C):
        yt, yp = y_true_val[:, c], y_prob_val[:, c]
        if len(np.unique(yt)) < 2:   # cannot tune if constant
            continue
        f1_best, thr_best = -1.0, 0.5
        for t in grid:
            f1 = f1_score(yt, (yp >= t).astype(int), zero_division=0)
            if f1 > f1_best:
                f1_best, thr_best = f1, t
        best[c] = thr_best
    return best

# ---- VAL: find best threshold per class ----
val_loss, y_true_val, y_prob_val = run_epoch(model, val_dl, optimizer=None)
per_class_thr = best_thresholds_per_class(y_true_val, y_prob_val)
print("Per-class thresholds (VAL, CLASS_ORDER):", dict(zip(CLASS_ORDER, per_class_thr)))

# ---- TEST: evaluate with base 0.5 threshold ----
test_loss, y_true_test, y_prob_test = run_epoch(model, test_dl, optimizer=None)

metrics_test = multilabel_metrics(y_true_test, y_prob_test, threshold=0.5)
macro_f1_05  = metrics_test["macro_f1"]
micro_f1_05  = metrics_test["micro_f1"]

print(f"\n[Test @ 0.5] Macro-F1: {macro_f1_05:.4f}")
print(f"[Test @ 0.5] Micro-F1: {micro_f1_05:.4f}")
print("Per-class F1 @ 0.5:", metrics_test["per_class_f1"])
print("Per-class AUROC:", metrics_test["per_class_auroc"])

# ---- TEST: evaluate with tuned thresholds from VAL ----
# per_class_thr is already aligned with CLASS_ORDER, same as y_prob_test columns
y_pred_test_tuned = (y_prob_test >= per_class_thr[None, :]).astype(int)

prec, rec, f1, _ = precision_recall_fscore_support(
    y_true_test, y_pred_test_tuned, average=None, zero_division=0
)

per_class_f1_tuned = dict(zip(CLASS_ORDER, f1))
macro_f1_tuned = float(np.mean(f1))

print(f"\n[Test @ tuned thresholds] Macro-F1 (tuned): {macro_f1_tuned:.4f}")
print("Per-class F1 (tuned):", per_class_f1_tuned)
print("Per-class thresholds (CLASS_ORDER):", dict(zip(CLASS_ORDER, per_class_thr)))


Per-class thresholds (VAL, CLASS_ORDER): {'NORM': np.float32(0.6), 'MI': np.float32(0.65), 'STTC': np.float32(0.65), 'HYP': np.float32(0.6), 'CD': np.float32(0.6)}

[Test @ 0.5] Macro-F1: 0.7016
[Test @ 0.5] Micro-F1: 0.7325
Per-class F1 @ 0.5: {'NORM': np.float64(0.8639391056137012), 'MI': np.float64(0.676056338028169), 'STTC': np.float64(0.709628506444276), 'HYP': np.float64(0.5671232876712329), 'CD': np.float64(0.6911764705882353)}
Per-class AUROC: {'NORM': 0.9414725600351066, 'MI': 0.8931829488919041, 'STTC': 0.9274719566830856, 'HYP': 0.9105781154378845, 'CD': 0.8835649091650168}

[Test @ tuned thresholds] Macro-F1 (tuned): 0.7248
Per-class F1 (tuned): {'NORM': np.float64(0.8578603716725264), 'MI': np.float64(0.6955736224028907), 'STTC': np.float64(0.746268656716418), 'HYP': np.float64(0.6039076376554174), 'CD': np.float64(0.7201783723522854)}
Per-class thresholds (CLASS_ORDER): {'NORM': np.float32(0.6), 'MI': np.float32(0.65), 'STTC': np.float32(0.65), 'HYP': np.float32(0.6), 'CD

In [36]:
import torch
import torch.nn as nn
import torch.nn.functional as F

"""
FINE tuning:
0. epochs
1. dropout
2. lr
3. weight decay
4. factor (scheduler)
6. patience (scheduler)
7. more feature classes in ConvBlock
8. more layers for convblock
9. image augmentation
10. increase data size, increase image resolution, change to higher bitrate
"""

# Extra metrics at 0.5 threshold 
threshold = 0.5
y_pred_test_05 = (y_prob_test >= threshold).astype(int)
exact_match = np.all(y_pred_test_05 == y_true_test, axis=1)
exact_acc = exact_match.mean() * 100
labelwise_acc = (y_pred_test_05 == y_true_test).mean() * 100

print(f"\nExact-match accuracy @ 0.5: {exact_acc:.2f}%")
print(f"Label-wise mean accuracy @ 0.5: {labelwise_acc:.2f}%")


Exact-match accuracy @ 0.5: 43.61%
Label-wise mean accuracy @ 0.5: 83.49%


In [37]:
threshold = 0.5
y_pred_val_05 = (y_prob_val >= threshold).astype(int)

# Exact-match: all labels for a sample must be correct
exact_match_val = np.all(y_pred_val_05 == y_true_val, axis=1)
exact_acc_val = exact_match_val.mean() * 100

# Label-wise mean accuracy: average over all label positions
labelwise_acc_val = (y_pred_val_05 == y_true_val).mean() * 100

print(f"\n[VAL @ 0.5] Exact-match accuracy: {exact_acc_val:.2f}%")
print(f"[VAL @ 0.5] Label-wise mean accuracy: {labelwise_acc_val:.2f}%")



[VAL @ 0.5] Exact-match accuracy: 44.50%
[VAL @ 0.5] Label-wise mean accuracy: 83.44%


In [38]:
# After compute test_loss, y_true_test, y_prob_test
metrics_test_05 = multilabel_metrics(y_true_test, y_prob_test, threshold=0.5)
print("=== Test @ threshold=0.5 ===")
print("MacroF1:", metrics_test_05["macro_f1"])
print("Per-class F1:", metrics_test_05["per_class_f1"])


=== Test @ threshold=0.5 ===
MacroF1: 0.701584741669123
Per-class F1: {'NORM': np.float64(0.8639391056137012), 'MI': np.float64(0.676056338028169), 'STTC': np.float64(0.709628506444276), 'HYP': np.float64(0.5671232876712329), 'CD': np.float64(0.6911764705882353)}


In [39]:
val_support = y_true_val.sum(axis=0)
test_support = y_true_test.sum(axis=0)
print("VAL positives:", dict(zip(CLASS_ORDER, val_support)))
print("TEST positives:", dict(zip(CLASS_ORDER, test_support)))


VAL positives: {'NORM': np.float32(955.0), 'MI': np.float32(540.0), 'STTC': np.float32(528.0), 'HYP': np.float32(268.0), 'CD': np.float32(495.0)}
TEST positives: {'NORM': np.float32(963.0), 'MI': np.float32(550.0), 'STTC': np.float32(521.0), 'HYP': np.float32(262.0), 'CD': np.float32(496.0)}


In [40]:
def best_global_threshold(y_true, y_prob, grid=np.linspace(0.1, 0.9, 17)):
    best_t, best_f1 = 0.5, -1
    for t in grid:
        f1 = f1_score(y_true, (y_prob >= t).astype(int), average="macro", zero_division=0)
        if f1 > best_f1:
            best_f1, best_t = f1, t
    return best_t, best_f1

t_global, f1_global_val = best_global_threshold(y_true_val, y_prob_val)
print("Best global t on VAL:", t_global, "macroF1:", f1_global_val)

y_pred_test_global = (y_prob_test >= t_global).astype(int)
_, _, f1_test_global, _ = precision_recall_fscore_support(
    y_true_test, y_pred_test_global, average=None, zero_division=0
)
print("MacroF1 TEST @ global t:", float(f1_test_global.mean()))
print("Per-class F1 TEST @ global t:", dict(zip(CLASS_ORDER, f1_test_global)))


Best global t on VAL: 0.6 macroF1: 0.7258694532955733
MacroF1 TEST @ global t: 0.7232039270172734
Per-class F1 TEST @ global t: {'NORM': np.float64(0.8578603716725264), 'MI': np.float64(0.6834170854271356), 'STTC': np.float64(0.7506561679790026), 'HYP': np.float64(0.6039076376554174), 'CD': np.float64(0.7201783723522854)}


In [41]:
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
import numpy as np
import torch

def diagnose_hyp_errors(model, loader, threshold=0.5):
    model.eval()
    all_preds = []
    all_targets = []
    
    print("Collecting predictions for diagnosis...")
    with torch.no_grad():
        for imgs, targets in loader:
            imgs = imgs.to(device)
            logits = model(imgs)
            probs = torch.sigmoid(logits)
            
            all_preds.append(probs.cpu().numpy())
            all_targets.append(targets.cpu().numpy())

    y_prob = np.concatenate(all_preds)
    y_true = np.concatenate(all_targets)
    
    # --- 1. Confusion Matrix for HYP ---
    # HYP is index 3 in CLASS_ORDER = ["NORM", "MI", "STTC", "HYP", "CD"]
    hyp_idx = 3
    hyp_true = y_true[:, hyp_idx]
    hyp_pred = (y_prob[:, hyp_idx] >= threshold).astype(int)
    
    cm = confusion_matrix(hyp_true, hyp_pred)
    
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=['Predicted No-HYP', 'Predicted HYP'],
                yticklabels=['Actual No-HYP', 'Actual HYP'])
    plt.title('Confusion Matrix for Hypertrophy (HYP)')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    # Save the Confusion Matrix
    plt.savefig(os.path.join(OUT_ROOT, "hyp_confusion_matrix.png"))
    plt.show()
    
    # --- 2. Visualize False Negatives (The "Missed" Cases) ---
    # Find indices where True=1 but Pred=0
    fn_indices = np.where((hyp_true == 1) & (hyp_pred == 0))[0]
    
    print(f"Found {len(fn_indices)} HYP cases that were MISSED (False Negatives).")
    print("Showing top 3 with highest confidence in being 'Normal'...")
    
    # Sort by how "wrong" the model was (lowest probability for HYP)
    worst_misses = fn_indices[np.argsort(y_prob[fn_indices, hyp_idx])[:3]]
    
    # We need to fetch images from the dataset wrapper to visualize them
    # Note: We will use the Validation Dataset (val_ds)
    for i, idx in enumerate(worst_misses):
        img_tensor, _ = val_ds[idx] 
        
        # Simple denormalize for visualization (approximate)
        img_disp = img_tensor.permute(1, 2, 0).numpy()
        img_disp = (img_disp * np.array([0.229, 0.224, 0.225])) + np.array([0.485, 0.456, 0.406])
        img_disp = np.clip(img_disp, 0, 1)
        
        plt.figure(figsize=(10, 4))
        plt.imshow(img_disp)
        plt.title(f"Missed HYP Case #{idx} | Model Prob for HYP: {y_prob[idx, hyp_idx]:.4f}")
        plt.axis('off')
        # Save the missed case image
        plt.savefig(os.path.join(OUT_ROOT, f"missed_hyp_case_{i+1}.png"))
        plt.show()

# Run the function using your validation loader
diagnose_hyp_errors(model, val_dl, threshold=0.6)

Collecting predictions for diagnosis...
Found 102 HYP cases that were MISSED (False Negatives).
Showing top 3 with highest confidence in being 'Normal'...


In [19]:
#this is to delete images in ecg_images for resetting

# import shutil

# if os.path.exists(OUT_ROOT):
#     shutil.rmtree(OUT_ROOT)
#     print(f"Deleted {OUT_ROOT}. Ready to regenerate.")
# os.makedirs(OUT_ROOT, exist_ok=True)

Deleted /home/user/21012125/Capstone2/ecg_images. Ready to regenerate.
