In [1]:
import os, ast
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import wfdb
from PIL import Image
import glob

# set base folder
DB = r"/Users/shayne/Documents/SUNWAY_UNI/sem8/capstone 1/dataset/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3" # <- path to the PTB-XL folder
OUT_ROOT = r"/Users/shayne/Documents/SUNWAY_UNI/sem8/Capstone2/ecg_images"        # png saved here

FNAME_COL = "filename_lr"                      
SR = 100      
# 1600 x 1200
TARGET_SIZE = (800, 600)                       # output image width × height (pixels)

os.makedirs(OUT_ROOT, exist_ok=True)


# read metadata CSVs
df = pd.read_csv(os.path.join(DB, "ptbxl_database.csv"))
scp = pd.read_csv(os.path.join(DB, "scp_statements.csv"), index_col=0)

# keep only diagnostic statements
diagnostic_codes = scp[scp["diagnostic"] == 1]


#map scp_codes convert to diagnostic superclasses
def to_superclasses(scp_codes_str):
    codes = ast.literal_eval(scp_codes_str)  # dict: code → weight
    diags = [c for c in codes.keys() if c in diagnostic_codes.index]
    supers = sorted({diagnostic_codes.loc[c, "diagnostic_class"] for c in diags})
    return supers

df["superclasses"] = df["scp_codes"].apply(to_superclasses)

# official stratified folds

train_df = df[df["strat_fold"].isin(range(1, 9))].copy()
val_df   = df[df["strat_fold"] == 9].copy()
test_df  = df[df["strat_fold"] == 10].copy()


#single-label primary class per record
# if an entry has multiple super classes, pick one based on the priority shown below
PRIORITY = ["MI", "STTC", "HYP", "CD", "NORM"]

# function that picks a super class 
def choose_primary_superclass(superclasses):
    if not superclasses:
        return None
    for c in PRIORITY:
        if c in superclasses:
            return c
    return superclasses[0]

# create a new column called primary class and apply the result of the super class function to it
for split in (train_df, val_df, test_df):
    split["primary_class"] = split["superclasses"].apply(choose_primary_superclass)

# drop the rows that have no primary class 
train_df = train_df.dropna(subset=["primary_class"])
val_df   = val_df.dropna(subset=["primary_class"])
test_df  = test_df.dropna(subset=["primary_class"])


# helpers for reading WFDB and saving plots

def load_signal_and_leads(rec_rel_path, base_dir=DB):
    """Read WFDB record. Returns (signal[T,12], lead_names[list])."""
    rec_path = os.path.join(base_dir, rec_rel_path)
    sig, meta = wfdb.rdsamp(rec_path)
    names = list(meta.sig_name) if hasattr(meta, "sig_name") else [f"Lead{i+1}" for i in range(sig.shape[1])]
    return sig.astype("float32"), names
    

# algorithm that turns the waveform data into plots
# 12 leads per image in this case
def save_12lead_strip(signal, lead_names, out_path, sr=SR, target_size=TARGET_SIZE):
    """Plot 12 leads in a 3×4 grid and save as PNG."""
    T, C = signal.shape
    fig_w, fig_h = 10, 6
    dpi = min(target_size[0]/fig_w, target_size[1]/fig_h)

    fig, axes = plt.subplots(3, 4, figsize=(fig_w, fig_h), dpi=dpi)
    axes = axes.ravel()
    t = np.arange(T) / float(sr)

    for i in range(min(C, 12)):
        ax = axes[i]
        ax.plot(t, signal[:, i], linewidth=0.8)
        # ax.set_title(lead_names[i] if i < len(lead_names) else f"Lead {i+1}", fontsize=8)
        ax.set_xlim([t[0], t[-1]])
        ax.axis("off")
    for j in range(C, len(axes)):
        axes[j].axis("off")

    plt.tight_layout(pad=0.15)
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig.savefig(out_path, bbox_inches="tight", pad_inches=0.03)
    plt.close(fig)

# saves the plots into iamges with good file directory and names
def export_split_images(split_df, split_name, limit=None):
    """Save ECG plots into OUT_ROOT/split_name/<class>/<ecg_id>.png."""
    saved = 0
    for _, row in tqdm(split_df.iterrows(), total=len(split_df), desc=f"Export {split_name}"):
        if limit and saved >= limit:
            break
        label = row["primary_class"]
        if not label:
            continue
        try:
            signal, leads = load_signal_and_leads(row[FNAME_COL])
        except Exception as e:
            # print("Failed:", row[FNAME_COL], e)
            continue

        ecg_id = int(row["ecg_id"]) if "ecg_id" in row else _
        out_path = os.path.join(OUT_ROOT, split_name, label, f"{ecg_id}.png")
        save_12lead_strip(signal, leads, out_path)
        saved += 1
    print(f"[{split_name}] saved {saved} images → {os.path.join(OUT_ROOT, split_name)}")


# run export (try small limits first)
export_split_images(train_df, "train", limit=50)
export_split_images(val_df,   "val",   limit=20)
export_split_images(test_df,  "test",  limit=20)


# preview a few saved images
some = glob.glob(os.path.join(OUT_ROOT, "train", "*", "*.png"))[:5]
for p in some:
    print(p)
    img = Image.open(p)
    img.show()  # opens in default image viewer


Export train:   0%|                          | 50/17084 [00:02<15:19, 18.52it/s]


[train] saved 50 images → /Users/shayne/Documents/SUNWAY_UNI/sem8/Capstone2/ecg_images/train


Export val:   1%|▎                            | 20/2146 [00:01<01:54, 18.50it/s]


[val] saved 20 images → /Users/shayne/Documents/SUNWAY_UNI/sem8/Capstone2/ecg_images/val


Export test:   1%|▎                           | 20/2158 [00:01<01:59, 17.85it/s]


[test] saved 20 images → /Users/shayne/Documents/SUNWAY_UNI/sem8/Capstone2/ecg_images/test
/Users/shayne/Documents/SUNWAY_UNI/sem8/Capstone2/ecg_images/train/MI/77.png
/Users/shayne/Documents/SUNWAY_UNI/sem8/Capstone2/ecg_images/train/MI/50.png
/Users/shayne/Documents/SUNWAY_UNI/sem8/Capstone2/ecg_images/train/STTC/22.png
/Users/shayne/Documents/SUNWAY_UNI/sem8/Capstone2/ecg_images/train/STTC/54.png
/Users/shayne/Documents/SUNWAY_UNI/sem8/Capstone2/ecg_images/train/HYP/45.png


In [2]:
train_df.head(5)

Unnamed: 0,ecg_id,patient_id,age,sex,height,weight,nurse,site,device,recording_date,...,static_noise,burst_noise,electrodes_problems,extra_beats,pacemaker,strat_fold,filename_lr,filename_hr,superclasses,primary_class
0,1,15709.0,56.0,1,,63.0,2.0,0.0,CS-12 E,1984-11-09 09:17:34,...,", I-V1,",,,,,3,records100/00000/00001_lr,records500/00000/00001_hr,[NORM],NORM
1,2,13243.0,19.0,0,,70.0,2.0,0.0,CS-12 E,1984-11-14 12:55:37,...,,,,,,2,records100/00000/00002_lr,records500/00000/00002_hr,[NORM],NORM
2,3,20372.0,37.0,1,,69.0,2.0,0.0,CS-12 E,1984-11-15 12:49:10,...,,,,,,5,records100/00000/00003_lr,records500/00000/00003_hr,[NORM],NORM
3,4,17014.0,24.0,0,,82.0,2.0,0.0,CS-12 E,1984-11-15 13:44:57,...,,,,,,3,records100/00000/00004_lr,records500/00000/00004_hr,[NORM],NORM
4,5,17448.0,19.0,1,,70.0,2.0,0.0,CS-12 E,1984-11-17 10:43:15,...,,,,,,4,records100/00000/00005_lr,records500/00000/00005_hr,[NORM],NORM


In [3]:
import os, json
import pandas as pd

CLASS_ORDER = ["NORM", "MI", "STTC", "HYP", "CD"]


# maps the class into binary as below
def to_multihot(superclasses):
    s = set(superclasses)
    result = {}

    for c in CLASS_ORDER:
        if c in s:
            result[c] = 1
        else:
            result[c] = 0
    return result

def build_labels_csv_from_existing(split_df, split_name, out_root=OUT_ROOT):
    
    # empty rows to store all entries
    rows = []
    # the directory of the split (train/val/test)
    split_dir = os.path.join(out_root, split_name)
    # for each entry in the dataframe [id, name, superclass, primaryclass]
    for r in split_df.iterrows():
        r = r[1]
        # we get the super class
        supers = r["superclasses"]          # e.g., ['CD','HYP']
        # if no superclass, then this entry is meaningless
        if not supers:
            continue
        # get primary class
        primary = r["primary_class"]
        # get the id of the ecg
        ecg_id = int(r["ecg_id"])
        # find the image that we saved that corresponds to the entry we're lookniga t right now
        img_path = os.path.join(split_dir, primary, f"{ecg_id}.png")
        if not os.path.exists(img_path):
            # might not exist if you used a small 'limit' during export
            continue

        # creates a multihot row
        mh = to_multihot(supers)
        row = {
            "image_path": img_path.replace("\\", "/"),
            "labels": json.dumps(sorted(supers))
        }
        # add the columns
        row.update(mh)                      # add NORM/MI/STTC/HYP/CD columns
        # add the rows
        rows.append(row)
      

    df_out = pd.DataFrame(rows)
    out_csv = os.path.join(out_root, f"{split_name}_labels.csv")
    df_out.to_csv(out_csv, index=False)
    print(f"Wrote {len(df_out)} rows → {out_csv}")
    return out_csv

# build csvs for all splits
train_csv = build_labels_csv_from_existing(train_df, "train")
val_csv   = build_labels_csv_from_existing(val_df,   "val")
test_csv  = build_labels_csv_from_existing(test_df,  "test")


Wrote 50 rows → /Users/shayne/Documents/SUNWAY_UNI/sem8/Capstone2/ecg_images/train_labels.csv
Wrote 20 rows → /Users/shayne/Documents/SUNWAY_UNI/sem8/Capstone2/ecg_images/val_labels.csv
Wrote 20 rows → /Users/shayne/Documents/SUNWAY_UNI/sem8/Capstone2/ecg_images/test_labels.csv


In [4]:
import os, json
import pandas as pd
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import transforms, models

from sklearn.metrics import f1_score, precision_recall_fscore_support, roc_auc_score

# === Paths (update OUT_ROOT to your folder) ===
OUT_ROOT = r"/Users/shayne/Documents/SUNWAY_UNI/sem8/Capstone2/ecg_images"
train_csv = os.path.join(OUT_ROOT, "train_labels.csv")
val_csv   = os.path.join(OUT_ROOT, "val_labels.csv")
test_csv  = os.path.join(OUT_ROOT, "test_labels.csv")

# Class order used everywhere
CLASS_ORDER = ["NORM", "MI", "STTC", "HYP", "CD"]

# Device
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

device

device(type='mps')

In [5]:
class MultiLabelECGImages(Dataset):
    """
    Reads rows: image_path, labels(JSON list), and returns:
      image tensor, multi-hot target tensor of shape [5]
    """
    # The behavior
    def __init__(self, labels_csv, transform=None):
        self.df = pd.read_csv(labels_csv)
        # Read csv and turn paths and labels into lists
        # [a,b,c]
        # [x,y,z]
        self.paths = self.df["image_path"].tolist()
        self.labels_json = self.df["labels"].tolist()
        # transform function to apply on the data 
        self.transform = transform
        
    def __len__(self): 
        # returns how many images/data we have
        return len(self.paths)
        
    # actions
    def __getitem__(self, idx):
        # Load image
        # you can use RGB or true grayscale, but we mostly use grayscale
        img = Image.open(self.paths[idx]).convert("RGB")
        # if there is a transformation function, apply it
        if self.transform:
            img = self.transform(img)
        # Parse labels JSON → list[str] → multi-hot
        # load the labels
        # [NORM, MI] or other variations like [MI, STTC]
        # json loads turns a string into an list 
        labs = json.loads(self.labels_json[idx])
        # create a "hot" arrray of size CLASS_ORDER filled with zeros
        # [NORM, 
        # [0,0,0,0,0]
        target = torch.zeros(len(CLASS_ORDER), dtype=torch.float32)
        # target = [0.0, 1.0, 0.0, 0.0, 0.0]
        # for every label that matches with class order 
        # ["NORM", "MI", "STTC", "HYP", "CD"], fill with 1.0
        # [NORM, MI] -> [1.0, 1.0, 0.0, 0.0, 0.0]
        # labs = [NORM, MI]
        for lab in labs:
            if lab in CLASS_ORDER:
                target[CLASS_ORDER.index(lab)] = 1.0
        # [Image, label]
        # [Image, [0.0, 1.0, 0.0, 0.0, 0.0]]
        return img, target

In [6]:
# Image size to train (you can try 224 or 256)
# 224 is a legendary image size that comes from ImageNet, which later used in AlexNet, VGG16, ResNet,
# enough to preserve detail and small enough to fit in GPU memory with small batch sizes
IMG_SIZE = 224

train_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    # Light augmentation (avoid horizontal flip; it reverses time)
    # We want it to understand small shifts in width and height, or small changes in zoom, but we do not want to make it rotate
    # rotation matters in ecg
    transforms.RandomApply([transforms.RandomAffine(degrees=0, translate=(0.02,0.02), scale=(0.98,1.02))], p=0.5),
    # turns it into a tensor floatTensor to be exact
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std =[0.229, 0.224, 0.225]),
])

# repeat the same for testing images and evaluation images
eval_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std =[0.229, 0.224, 0.225]),
])

train_ds = MultiLabelECGImages(train_csv, transform=train_tfms)
val_ds   = MultiLabelECGImages(val_csv,   transform=eval_tfms)
test_ds  = MultiLabelECGImages(test_csv,  transform=eval_tfms)

BATCH_SIZE = 32
NUM_WORKERS = 0  # keep 0 on Windows to avoid issues; try 2 later

"""
# train_ds returns a dataset object (image, label)
# batch size = 32 calls train_ds 32 times to get 32 labels making it one batch
# when we call it in for batch_img, batch_labels in train_dl: we will split data into [32,32,32,32, ...] and for that list
# we'll loop through it
# num workers are just for parallel processing purposes (by default it'll run in parallel)
# pin memory locks a part of the cpu so that nothing can overwrite it, therefore gpu never accesses
# the wrong data
"""
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS, pin_memory=True)
val_dl   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
test_dl  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

len(train_ds), len(val_ds), len(test_ds)

(50, 20, 20)

In [7]:
"""
funny thing about imbalanced models 75-25

the classes are not balanced, there are way more NORM than everything else
assume a 75-25 ratio. due to this imbalance, we must make the model care 
just as much about the minority (25) through heavy penalties
"""
def compute_pos_weight(labels_csv):
    df = pd.read_csv(labels_csv)
    # Expect columns NORM, MI, STTC, HYP, CD as 0/1 flags
    # load training labels only because thats what we train on
    # count the number of positives per class
    # returns smth like [1200,80,250,200,90] just an example
    counts_pos = df[CLASS_ORDER].sum().values.astype(np.float32) 
    # find out which classs are rare or common
    counts_all = len(df)
    counts_neg = counts_all - counts_pos
    # Avoid division by zero
    # calculate the weights
    pos_weight = counts_neg / np.clip(counts_pos, 1.0, None)
    # the weights tell us how heavily we should penalize for failing to recognize 
    # certain classes (what does this remind u of?)
    return torch.tensor(pos_weight, dtype=torch.float32)

pos_weight = compute_pos_weight(train_csv).to(device)
pos_weight

tensor([ 0.1628, 24.0000, 24.0000, 49.0000,  9.0000], device='mps:0')

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBlock(nn.Module):
    """Two 3×3 convs with BN+ReLU. Keeps H,W; downsampling happens in the following MaxPool."""
    # building neural networks with classes is we need to define two things
    # properties and behavior 
    # runner 
    def __init__(self, in_c, out_c):
        super().__init__()
        # input channel = 3 (RGB), filter number (3x3 stencil), padding to make sure stencil stays on edge
        # bias = False, we'll handle it in bn1
        # convolution runs a filter on an input noting its variations, patterns, nuances, etc
        # we do this many times across many layers to "learn" a pattern
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1, bias=False)
        # batch normalizer stabilizes and normalizes activations
        # recentering / rescaling inputs to a predictable range
        # activations (outputs) -> channel/node (layer) 
        self.bn1   = nn.BatchNorm2d(out_c)
        # repeats but this time instead of 3 -> 32, we'll look 32 -> 32
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1, bias=False)
        self.bn2   = nn.BatchNorm2d(out_c)

    def forward(self, x):
        # previously we defined how our model should work
        # forward function runs our model exactly n times (given you call it n times)
        # relu (adds non linearity)
        # image -> conv1 -> bn -> relu -> features
        x = F.relu(self.bn1(self.conv1(x)))  # conv → BN → ReLU
        x = F.relu(self.bn2(self.conv2(x)))  # conv → BN → ReLU
        return x

class ECGStripCNN(nn.Module):
    """
    Simple-but-solid CNN for 224x224 RGB images:
    4 stages (channels: 32, 64, 128, 256) with MaxPool after each stage,
    then global average pooling and a small classifier head to 5 logits.

    basically: 
    pattern spotters - convolutions
    zoom-out step - pooling 
    summary clipboard - global average pooling
    final judge - linear head that outputs 5 scores (1 per diagnosis)

    input: [B, 3, 224, 224] (B (optional), C, H, W)
    output: [B, 5] logits (5 classes)

    conv block basically marks every possible pattern/evidence it can find
    maxpool grabs the most convincing/relevant one from each segment (igores noise)
    """
    def __init__(self, n_classes=5, dropout_p=0.3):
        super().__init__()
        # conv block outputs 32 pattern maps
        self.stage1 = ConvBlock(3,32)   # [B,3,224,224] -> [B,32,224,224]
        # for all the pattterns maxpool (2x2) allows the maxpool
        # filter to traverse and generalize
        # after each step, the resolution gets smaller and smaller
        # but this is fine because even with fewer piels, the features are more generalized
        self.pool1  = nn.MaxPool2d(2)      #                  -> [B,32,112,112]

        # 32 - 64 - 128 - 256 (are the features)
        # why more channels as we move on? 
        self.stage2 = ConvBlock(32,  64)   # -> [B,64,112,112]
        self.pool2  = nn.MaxPool2d(2)      # -> [B,64,56,56]

        self.stage3 = ConvBlock(64, 128)   # -> [B,128,56,56]
        self.pool3  = nn.MaxPool2d(2)      # -> [B,128,28,28]

        self.stage4 = ConvBlock(128,256)   # -> [B,256,28,28]
        self.pool4  = nn.MaxPool2d(2)      # -> [B,256,14,14]

        # Global average pooling → [B,256,1,1]
        # we already have many pools, the function above just grabs the average
        # adaptive because regardless of inputsize it always returns a 1x1
        self.gap    = nn.AdaptiveAvgPool2d((1,1))
        # randomly zeros a fraction p
        # prevents classifier on depending on a small set of features 
        self.drop   = nn.Dropout(p=dropout_p)
        # linear layer to finish it off 
        # conv blocks: 32 -> 64 -> 128 -> 256
        # linear block: 256 -> 5 (5 classes)
        self.fc     = nn.Linear(256, n_classes)  # 5 logits for multi-label

    def forward(self, x):
        x = self.pool1(self.stage1(x))
        x = self.pool2(self.stage2(x))
        x = self.pool3(self.stage3(x))
        x = self.pool4(self.stage4(x))
        # calculates the average pooling
        x = self.gap(x)                 # shape: [B,256,1,1] output: 
        # flatten is just reshaping the output into torch understandable ones
        x = torch.flatten(x, 1)         # [B,256]
        # zeros features
        x = self.drop(x)
        # get the actual logits
        logits = self.fc(x)             # [B,5] (raw scores, not sigmoid) [0.42,0.456,0.5387,..,..
        return logits

In [9]:
def run_epoch(model, loader, optimizer=None):
    """
    Train if optimizer is provided; else evaluate.
    Returns: average loss, list of all targets, list of all probs
    """
    train_mode = optimizer is not None
    model.train() if train_mode else model.eval()

    total_loss = 0.0
    all_targets, all_probs = [], []

    with torch.set_grad_enabled(train_mode):
        for imgs, targets in loader: # train_df
            # remember, we work on the gpu
            # so we move all the data that we're working with
            # to the gpu
            imgs = imgs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            # get our initial predictions
            logits = model(imgs)               # [B,5]
            # calculate loss
            loss = loss_function(logits, targets)  # scalar

            # our usual step 
            if train_mode:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            total_loss += loss.item() * imgs.size(0)

            # store probs for metrics 2.581 -> 0.43
            # turn logits into probabilities called at [0,1]
            probs = torch.sigmoid(logits).detach().cpu().numpy() # detach moves probs from gpu to cpu to turn it into numpy array [1,2,3,4]
            # we dont want to use gpu unless its for parallel tasks

            # caches the data
            all_probs.append(probs)
            all_targets.append(targets.detach().cpu().numpy())

    N = len(loader.dataset)
    avg_loss = total_loss / max(1, N)
    all_probs = np.concatenate(all_probs, axis=0)
    all_targets = np.concatenate(all_targets, axis=0)
    return avg_loss, all_targets, all_probs

def multilabel_metrics(y_true, y_prob, threshold=0.5):
    """
    y_true: [N,5] {0,1}
    y_prob: [N,5] [0..1]
    NORM, MI, STTC, HYP, CD
    75% 25%

    [0,1,1,0,0]
    [0,1,0,0,1]
    precision: of all the cases we called (flagged) positive, how many are truly positive?
    recall: of all the positive cases that exist, how many did we catch?
    
    f1 (micro and macro): harmonic mean of prec & rec (fairness/balance score)

    if u try to catch everything, we might get tons of false alarms or super picky so that u miss real cases
    your f1 score will suck, it needs balance
    
    f1 macro: used to deal with imbalance. coast on common cases and fail on rarities (75-25)
    f1 micro: similar to what we're used to. tally every answer/flags and judge

    evaluate behavior not only results
    
    if macro is low but micro is high: fix rare cases
    if both are low: more data (more batches, higher res, etc), ecg augmentations, better lr schedules
    """
    y_pred = (y_prob >= threshold).astype(int)
    # Per-class precision/recall/F1
    # prec, rec, f1, support (no. occurences of each label in y_true)
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average=None, zero_division=0)
    macro_f1 = f1.mean()
    micro_f1 = f1_score(y_true, y_pred, average='micro', zero_division=0)
    # AUROC per class (guard against classes with single label value)
    """
    thresholds are important too: yes/no decided by probability (0.7 +-) 
    low threshold: more data is accepted (precision down, recall up)
    high threshold: less data is accepted (recall down, precision up)

    ROC vs PR Curves (uses thresholds)
    NORM, MI, STTC, HYP, CD
    0.7, 0.3, 0.2, 0.5, 0.9 (threshold = 0.5)
    Y, N, N, Y, Y

    ROC: broad sense of seperability (principle) "do the beeps get louder for more dangerous items"
    1, 0.5, 0.7,0.2,0.1

    PR: precision vs recall (how clean are my positives as i try to catch more?)
    use when positives are rare (we obviously want to let more data flood)
    pr shows us the tradeoffs (more false alarms) when we raise our recall (accept more data)
    punishes "crying wolf"

    roc is like a thesis statement (do spam eails tend to get higher spam scores than normal emails)
    pr is a proposed solution (if i flag more emails as spam to catch all spam (high recall), do i accidentally flag
    a lot more real emails (low precision)?

    F1 at a tuned threshold: “given a chosen operating point, how balanced are my actual yes/no decisions?"    

    these 3 don't necessarily work together but looking at all of them gives us multiple perspectives
    """
    aurocs = []
    for i in range(y_true.shape[1]):
        if len(np.unique(y_true[:, i])) == 2:
            aurocs.append(roc_auc_score(y_true[:, i], y_prob[:, i]))
        else:
            aurocs.append(float('nan'))
    return {
        "macro_f1": macro_f1,
        "micro_f1": micro_f1,
        "per_class_f1": dict(zip(CLASS_ORDER, f1)),
        "per_class_precision": dict(zip(CLASS_ORDER, prec)),
        "per_class_recall": dict(zip(CLASS_ORDER, rec)),
        "per_class_auroc": dict(zip(CLASS_ORDER, aurocs)),
    }

In [10]:
model = ECGStripCNN(n_classes=5, dropout_p=0.3).to(device)

# Loss: multi-label
loss_function = nn.BCEWithLogitsLoss(pos_weight=pos_weight)  # pos_weight from earlier step
# Optimizer & scheduler (tweak LR as you like)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
# start with somewhat high LR, this helps learn faster and detect bigger patterns
# once we start to reach a plateau on accuracy
# the lr is adaptively reduced, smller learning rates help with finer pattern detection (in this case)
# runs until the end of epochs if no plateau, but if plateau even after the below function, we quit process early to save time and resources
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

# Train as before
EPOCHS = 10
best_val = float('inf')
best_path = os.path.join(OUT_ROOT, "best_custom_cnn_multilabel.pt")

for epoch in range(1, EPOCHS+1):
    # training (as usual)
    train_loss, _, _ = run_epoch(model, train_dl, optimizer=optimizer)   # uses your earlier utility
    # validation is training on unseen data
    val_loss, y_true, y_prob = run_epoch(model, val_dl, optimizer=None)
    # our scheculer keeps a history of the losses
    # if the losses dont improve, we adjust LR
    scheduler.step(val_loss)

    metrics = multilabel_metrics(y_true, y_prob, threshold=0.5)
    print(f"Epoch {epoch:02d} | "
          f"train_loss={train_loss:.4f}  val_loss={val_loss:.4f}  "
          f"macroF1={metrics['macro_f1']:.3f}  microF1={metrics['micro_f1']:.3f}")

    if val_loss < best_val:
        best_val = val_loss
        torch.save({"model_state": model.state_dict(),
                    "class_order": CLASS_ORDER}, best_path)
        print(f"  ↳ saved best to {best_path}")

  return x.astype(dtype, copy=copy, casting=casting)


ValueError: Input y_true contains NaN.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

"""
FINE tuning:
0. epochs
1. dropout
2. lr
3. weight decay
4. factor (scheduler)
6. patience (scheduler)
7. more feature classes in ConvBlock
8. more layers for convblock
9. image augmentation
10. increase data size, increase image resolution, change to higher bitrate
"""

ckpt = torch.load(best_path, map_location=device)
model.load_state_dict(ckpt["model_state"])
model.eval()

test_loss, y_true_test, y_prob_test = run_epoch(model, test_dl, optimizer=None)
test_metrics = multilabel_metrics(y_true_test, y_prob_test, threshold=0.5)

# Compute simple accuracies
threshold = 0.5
y_pred_test = (y_prob_test >= threshold).astype(int)
exact_match = np.all(y_pred_test == y_true_test, axis=1)
exact_acc = exact_match.mean() * 100
labelwise_acc = (y_pred_test == y_true_test).mean() * 100

print(f"\nTest loss={test_loss:.4f}")
print(f"MacroF1={test_metrics['macro_f1']:.3f}  MicroF1={test_metrics['micro_f1']:.3f}")
print(f"Exact-match accuracy: {exact_acc:.2f}%")
print(f"Label-wise mean accuracy: {labelwise_acc:.2f}%")
print("Per-class F1:", test_metrics["per_class_f1"])
print("Per-class AUROC:", test_metrics["per_class_auroc"])