In [6]:
import os, ast
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import wfdb
from PIL import Image
import glob

# Base folders. Both can be overridden through environment variables so the
# notebook runs on another machine without editing this cell; when the
# variables are unset the original local paths are used.
DB = os.environ.get(
    "PTBXL_DIR",
    r"/Users/shayne/Documents/SUNWAY_UNI/sem8/capstone 1/dataset/ptb-xl-a-large-publicly-available-electrocardiography-dataset-1.0.3",
)  # path to the PTB-XL folder
OUT_ROOT = os.environ.get("PTBXL_OUT_ROOT", r"/Users/shayne")  # PNGs are saved under this folder

# Metadata column that holds each record's WFDB path relative to DB.
FNAME_COL = "filename_lr"
# Samples per second used to build the plot time axis.
# NOTE(review): must match the records selected by FNAME_COL (in PTB-XL the
# "filename_lr" records are the 100 Hz versions) — confirm if FNAME_COL changes.
SR = 100
# Output image width × height (pixels); 1600 x 1200 was noted as an alternative.
TARGET_SIZE = (800, 600)

os.makedirs(OUT_ROOT, exist_ok=True)


# Load the two metadata tables that live in the dataset folder.
database_csv = os.path.join(DB, "ptbxl_database.csv")
statements_csv = os.path.join(DB, "scp_statements.csv")
df = pd.read_csv(database_csv)
scp = pd.read_csv(statements_csv, index_col=0)

# Restrict the statement table to the rows flagged as diagnostic.
diagnostic_codes = scp.loc[scp["diagnostic"] == 1]


# map each record's scp_codes string to its diagnostic superclasses
def to_superclasses(scp_codes_str):
    """Return the sorted, de-duplicated diagnostic superclasses of one record.

    `scp_codes_str` is the string form of a dict (code → weight). Codes that
    do not appear in the index of `diagnostic_codes` are ignored; each
    remaining code is translated through the "diagnostic_class" column.
    """
    parsed = ast.literal_eval(scp_codes_str)
    known_codes = diagnostic_codes.index
    found = set()
    for code in parsed.keys():
        if code in known_codes:
            found.add(diagnostic_codes.loc[code, "diagnostic_class"])
    return sorted(found)

df["superclasses"] = df["scp_codes"].apply(to_superclasses)

# Official stratified folds: 1-8 for training, 9 for validation, 10 for test.
fold = df["strat_fold"]
train_df = df.loc[fold.isin(range(1, 9))].copy()
val_df = df.loc[fold.eq(9)].copy()
test_df = df.loc[fold.eq(10)].copy()


# Single-label "primary" class per record.
# When a record carries several superclasses, the one that appears earliest
# in this list wins.
PRIORITY = ["MI", "STTC", "HYP", "CD", "NORM"]


def choose_primary_superclass(superclasses):
    """Pick one superclass for a record.

    Returns None for an empty input, otherwise the first PRIORITY entry that
    is present in `superclasses`; when none of the PRIORITY entries is
    present, the first element of `superclasses` is returned.
    """
    if not superclasses:
        return None
    preferred = next((label for label in PRIORITY if label in superclasses), None)
    if preferred is not None:
        return preferred
    return superclasses[0]

# Attach a "primary_class" column to every split (computed from the
# "superclasses" column), then discard the rows for which no primary class
# could be chosen.
train_df, val_df, test_df = [
    part.assign(
        primary_class=part["superclasses"].apply(choose_primary_superclass)
    ).dropna(subset=["primary_class"])
    for part in (train_df, val_df, test_df)
]


# helpers for reading WFDB and saving plots

def load_signal_and_leads(rec_rel_path, base_dir=DB):
    """Read a WFDB record and return its samples and lead names.

    Parameters
    ----------
    rec_rel_path : record path relative to `base_dir`, as passed to
        `wfdb.rdsamp` (i.e. without a file extension).
    base_dir : folder the relative path is joined onto.

    Returns
    -------
    (signal, lead_names) : `signal` is the sample array cast to float32 and
        `lead_names` is a list with one label per signal column.
    """
    rec_path = os.path.join(base_dir, rec_rel_path)
    sig, meta = wfdb.rdsamp(rec_path)

    # wfdb.rdsamp returns the record metadata as a dict, so the names must be
    # looked up by key; attribute access is kept only as a fallback for
    # objects that expose `sig_name` directly.
    if isinstance(meta, dict):
        names = meta.get("sig_name")
    else:
        names = getattr(meta, "sig_name", None)

    # Fall back to generic labels when the record carries no lead names.
    if not names:
        names = [f"Lead{i+1}" for i in range(sig.shape[1])]
    return sig.astype("float32"), list(names)
    

# render the waveform data as a plot image
# (all 12 leads are drawn into a single image)
def save_12lead_strip(signal, lead_names, out_path, sr=SR, target_size=TARGET_SIZE):
    """Plot up to 12 leads in a 3×4 grid and save the figure as a PNG.

    Parameters
    ----------
    signal : 2-D array indexed as signal[sample, lead]; only the first 12
        columns are drawn.
    lead_names : labels for the leads. Currently unused because the panel
        titles are disabled; kept so the call signature stays stable.
    out_path : destination file; its parent directory is created if needed.
    sr : samples per second, used to convert sample indices to seconds.
    target_size : (width, height) in pixels from which the figure DPI is
        derived. Because the figure is saved with bbox_inches="tight", the
        final pixel dimensions may differ from `target_size`.

    Raises
    ------
    ValueError : if `signal` is not 2-D or contains no samples.
    """
    T, C = signal.shape
    if T == 0:
        # Fail early with a clear message instead of an IndexError on t[0].
        raise ValueError("signal has no samples to plot")

    fig_w, fig_h = 10, 6
    dpi = min(target_size[0] / fig_w, target_size[1] / fig_h)

    fig, axes = plt.subplots(3, 4, figsize=(fig_w, fig_h), dpi=dpi)
    try:
        t = np.arange(T) / float(sr)

        for i, ax in enumerate(axes.ravel()):
            if i < C:
                ax.plot(t, signal[:, i], linewidth=0.8)
                ax.set_xlim([t[0], t[-1]])
            # Every panel is drawn without axes, including unused ones.
            ax.axis("off")

        fig.tight_layout(pad=0.15)

        # dirname is "" for a bare file name, and os.makedirs("") would raise.
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(out_path, bbox_inches="tight", pad_inches=0.03)
    finally:
        # Always release the figure, even when plotting or saving fails, so a
        # long export loop does not accumulate open figures.
        plt.close(fig)

# saves the plots as images, organised into per-split / per-class folders and named by ECG id
def export_split_images(split_df, split_name, limit=None):
    """Save ECG plots into OUT_ROOT/split_name/<class>/<ecg_id>.png.

    Parameters
    ----------
    split_df : DataFrame with a "primary_class" column, the FNAME_COL column
        and (optionally) an "ecg_id" column; when "ecg_id" is absent the row
        index is used as the file name.
    split_name : sub-folder name under OUT_ROOT (e.g. "train").
    limit : maximum number of images to save; None means no limit and 0
        saves nothing.

    Rows without a primary class are skipped. Records that cannot be read are
    skipped as well (best effort), but they are counted and reported at the
    end instead of being dropped silently.
    """
    saved = 0
    failed = 0
    first_failure = None  # (record path, exception) of the first unreadable record
    for idx, row in tqdm(split_df.iterrows(), total=len(split_df), desc=f"Export {split_name}"):
        if limit is not None and saved >= limit:
            break
        label = row["primary_class"]
        # NaN is truthy, so it needs an explicit check alongside None / "".
        if pd.isna(label) or not label:
            continue
        try:
            signal, leads = load_signal_and_leads(row[FNAME_COL])
        except Exception as e:
            failed += 1
            if first_failure is None:
                first_failure = (row[FNAME_COL], e)
            continue

        ecg_id = int(row["ecg_id"]) if "ecg_id" in row else idx
        out_path = os.path.join(OUT_ROOT, split_name, label, f"{ecg_id}.png")
        save_12lead_strip(signal, leads, out_path)
        saved += 1
    print(f"[{split_name}] saved {saved} images → {os.path.join(OUT_ROOT, split_name)}")
    if failed:
        bad_path, bad_err = first_failure
        print(f"[{split_name}] skipped {failed} unreadable records (first: {bad_path}: {bad_err!r})")


# Run the export with a small per-split cap first.
for frame, name, cap in (
    (train_df, "train", 50),
    (val_df, "val", 20),
    (test_df, "test", 20),
):
    export_split_images(frame, name, limit=cap)


# Preview a handful of the exported training images.
preview_glob = os.path.join(OUT_ROOT, "train", "*", "*.png")
some = glob.glob(preview_glob)[:5]
for p in some:
    print(p)
    img = Image.open(p)
    img.show()  # opens in default image viewer


Export train:   0%|                          | 50/17084 [00:03<18:57, 14.98it/s]


[train] saved 50 images → /Users/shayne/train


Export val:   1%|▎                            | 20/2146 [00:01<01:57, 18.16it/s]


[val] saved 20 images → /Users/shayne/val


Export test:   1%|▎                           | 20/2158 [00:01<02:03, 17.36it/s]


[test] saved 20 images → /Users/shayne/test
/Users/shayne/train/MI/77.png
/Users/shayne/train/MI/50.png
/Users/shayne/train/STTC/22.png
/Users/shayne/train/STTC/54.png
/Users/shayne/train/HYP/45.png


In [5]:
import os, json
import pandas as pd

CLASS_ORDER = ["NORM", "MI", "STTC", "HYP", "CD"]


# Encodes a list of superclasses as one 0/1 flag per CLASS_ORDER entry.
def to_multihot(superclasses):
    """Return a dict keyed by CLASS_ORDER (in that order) with 1 for every
    class present in `superclasses` and 0 otherwise."""
    present = set(superclasses)
    flags = {}
    for name in CLASS_ORDER:
        flags[name] = 1 if name in present else 0
    return flags

def build_labels_csv_from_existing(split_df, split_name, out_root=OUT_ROOT):
    """Write <out_root>/<split_name>_labels.csv for images that already exist.

    For every row of `split_df` that has at least one superclass and whose
    image is found at <out_root>/<split_name>/<primary_class>/<ecg_id>.png,
    one CSV row is emitted with:
      - "image_path": the image path with forward slashes,
      - "labels": the JSON-encoded, sorted list of superclasses,
      - one 0/1 column per CLASS_ORDER entry.

    The header is always written, even when no image matched, so the file can
    still be read back as an empty table.

    Returns the path of the CSV that was written.
    """
    # Fixed column layout; also what an empty result is written with.
    columns = ["image_path", "labels", *CLASS_ORDER]

    rows = []
    # Directory of this split (train/val/test).
    split_dir = os.path.join(out_root, split_name)
    for _, r in split_df.iterrows():
        supers = r["superclasses"]          # e.g., ['CD','HYP']
        # A record without any superclass carries no label information.
        if not supers:
            continue
        primary = r["primary_class"]
        ecg_id = int(r["ecg_id"])
        # Locate the exported image that corresponds to the record we are
        # looking at right now.
        img_path = os.path.join(split_dir, primary, f"{ecg_id}.png")
        if not os.path.exists(img_path):
            # might not exist if you used a small 'limit' during export
            continue

        row = {
            "image_path": img_path.replace("\\", "/"),
            "labels": json.dumps(sorted(supers))
        }
        row.update(to_multihot(supers))     # add NORM/MI/STTC/HYP/CD columns
        rows.append(row)

    df_out = pd.DataFrame(rows, columns=columns)
    # The output folder may not exist yet when a custom out_root is passed.
    os.makedirs(out_root, exist_ok=True)
    out_csv = os.path.join(out_root, f"{split_name}_labels.csv")
    df_out.to_csv(out_csv, index=False)
    print(f"Wrote {len(df_out)} rows → {out_csv}")
    return out_csv

# Build one labels CSV per split, in train / val / test order.
train_csv, val_csv, test_csv = [
    build_labels_csv_from_existing(frame, name)
    for frame, name in ((train_df, "train"), (val_df, "val"), (test_df, "test"))
]


Wrote 50 rows → /Users/shayne/train_labels.csv
Wrote 20 rows → /Users/shayne/val_labels.csv
Wrote 20 rows → /Users/shayne/test_labels.csv
