# 1. PTB ECG – Prepare Beat-Level Dataset

In [None]:
import os
from glob import glob

import numpy as np
import wfdb
from scipy.signal import find_peaks, butter, filtfilt


def list_ptb_records(ptb_root):
    """
    Collect every PTB record found under the per-patient folders.

    ptb_root: path to PTB root folder, e.g. 'data/ptbdb'
    Returns a list of record paths with the '.hea' extension removed,
    e.g. '.../patient001/s0010_re'. Patient folders ('patient*') are visited
    in sorted order, and the header files inside each folder are sorted too.
    """
    suffix_len = len('.hea')
    patient_dirs = sorted(glob(os.path.join(ptb_root, 'patient*')))
    return [
        header_path[:-suffix_len]
        for patient_dir in patient_dirs
        for header_path in sorted(glob(os.path.join(patient_dir, '*.hea')))
    ]


def bandpass_filter(signal, fs=1000, low=5.0, high=15.0, order=2):
    """
    Butterworth band-pass filter used ahead of QRS detection.

    signal: 1D numpy array
    fs: sampling frequency
    low, high: pass-band edges, in the same units as fs
    order: filter order handed to scipy.signal.butter
    Returns the filtered signal.
    """
    nyquist = 0.5 * fs
    # butter() expects the band edges as fractions of the Nyquist frequency.
    normalized_band = [edge / nyquist for edge in (low, high)]
    numerator, denominator = butter(order, normalized_band, btype='band')
    # filtfilt applies the filter forward and then backward over the signal.
    return filtfilt(numerator, denominator, signal)
    

def detect_r_peaks(ecg_lead, fs=1000, distance_ms=300, height_factor=0.5):
    """
    Very basic R-peak detector on a single lead.

    The lead is band-pass filtered (5-15 band edges, same units as fs) and
    peaks of the filtered signal are kept when they reach
    height_factor * max(filtered) and respect the minimum spacing.

    ecg_lead: 1D numpy array
    fs: sampling frequency
    distance_ms: minimum distance between peaks, in milliseconds
    height_factor: relative threshold on peak height (fraction of the
        maximum of the filtered signal)
    Returns indices of R-peaks (sample positions in ecg_lead).
    """
    filtered = bandpass_filter(ecg_lead, fs=fs, low=5.0, high=15.0)
    # find_peaks requires distance >= 1 sample; a small distance_ms or a low
    # fs would otherwise truncate to 0 and raise ValueError.
    distance = max(1, int(distance_ms * fs / 1000.0))
    # rough threshold: height_factor * max
    height = height_factor * np.max(filtered)
    peaks, _ = find_peaks(filtered, distance=distance, height=height)
    return peaks


In [None]:
import os
import numpy as np
import wfdb
from collections import defaultdict

# Build the PTB beat-level dataset: read every record, locate R-peaks on
# lead II, cut a fixed 2-second 12-lead window around each peak, normalize it
# per lead, and save everything to a compressed .npz file.

PTB_ROOT = "data/ptbdb"    # TODO: change to your path
OUT_FILE = "ptb_beats.npz" # output
PTB_FS = 1000              # sampling rate (Hz) every kept record must have

# standard 12 leads in PTB header
STANDARD_LEADS = ['i', 'ii', 'iii', 'avr', 'avl', 'avf',
                  'v1', 'v2', 'v3', 'v4', 'v5', 'v6']

all_beats = []
all_record_names = []
all_subject_ids = []
all_labels = []  # you can later fill with MI/non-MI, etc.

rec_paths = list_ptb_records(PTB_ROOT)
print(f"Found {len(rec_paths)} records")

# 2-second window: [-500ms, +1500ms] → 2000 samples
pre = int(0.5 * PTB_FS)
post = int(1.5 * PTB_FS)
win_len = pre + post

# row of lead II inside the (12, T) array built from STANDARD_LEADS
lead2_row = STANDARD_LEADS.index('ii')

for rec_path in rec_paths:
    try:
        record = wfdb.rdrecord(rec_path)
    except Exception as e:
        # best effort: one unreadable record must not abort the whole run
        print(f"Failed to read {rec_path}: {e}")
        continue

    fs = int(record.fs)
    if fs != PTB_FS:
        # A different rate would give windows of another length (np.stack
        # would fail) and would contradict the fs value stored in the output.
        print(f"Unexpected sampling rate {fs} Hz in {rec_path}, skipping record.")
        continue

    sig = record.p_signal  # shape (T, n_channels)
    ch_names = [ch.lower() for ch in record.sig_name]

    # map lead names to indices; a record lacking any standard lead is skipped
    missing_leads = [ln for ln in STANDARD_LEADS if ln not in ch_names]
    if missing_leads:
        print(f"Lead {missing_leads[0]} not found in {rec_path}, skipping record.")
        continue
    lead_indices = [ch_names.index(ln) for ln in STANDARD_LEADS]

    ecg_12 = sig[:, lead_indices].T  # shape (12, T)
    T = ecg_12.shape[1]

    # use lead II for R-peak detection
    r_peaks = detect_r_peaks(ecg_12[lead2_row], fs=fs, distance_ms=300, height_factor=0.4)

    if len(r_peaks) < 5:
        print(f"Few R-peaks in {rec_path}, skipping.")
        continue

    record_name = os.path.basename(rec_path)
    # subject id is patient folder name
    subject_id = os.path.basename(os.path.dirname(rec_path))

    for r in r_peaks:
        start = r - pre
        end = r + post
        # keep only windows lying fully inside the recording; this also
        # guarantees the slice below has exactly win_len samples
        if start < 0 or end > T:
            continue
        beat = ecg_12[:, start:end]  # (12, 2000)

        # simple per-beat normalization (median + MAD), computed per lead;
        # the small constant avoids division by zero on flat leads
        m = np.median(beat, axis=1, keepdims=True)
        mad = np.median(np.abs(beat - m), axis=1, keepdims=True) + 1e-6
        beat_norm = (beat - m) / mad

        all_beats.append(beat_norm.astype(np.float32))
        all_record_names.append(record_name)
        all_subject_ids.append(subject_id)
        all_labels.append(0)  # placeholder; you can parse .hea for diagnosis later

if not all_beats:
    # np.stack([]) would fail with an unhelpful message; say what went wrong
    raise ValueError(
        f"No beats were extracted from '{PTB_ROOT}'; "
        "check PTB_ROOT and the messages printed above."
    )

all_beats = np.stack(all_beats, axis=0)  # (N, 12, 2000)
all_record_names = np.array(all_record_names)
all_subject_ids = np.array(all_subject_ids)
all_labels = np.array(all_labels)

print("PTB beats shape:", all_beats.shape)

np.savez_compressed(
    OUT_FILE,
    beats=all_beats,
    record_names=all_record_names,
    subject_ids=all_subject_ids,
    labels=all_labels,
    fs=np.array([PTB_FS], dtype=np.int32),
    # plain string array (not dtype=object) so the file can be read back with
    # np.load's default allow_pickle=False
    description=np.array(["PTB 12-lead beats, 2s windows around R-peaks"])
)
print(f"Saved PTB beats to {OUT_FILE}")


# 2. Koch ECG–MCG – Prepare Paired Beat Dataset