In [3]:

import os
import numpy as np
import wfdb
from scipy.signal import butter, filtfilt

# default setup params (module-level configuration read by the functions below)
LOCAL_DIR     = "./mitdb"       # directory holding the record files (*.dat) and their 'atr' annotation files
LEAD_IDX      = 0               # column of record.p_signal used as the single ECG lead
WIN_SEC       = 5.0             # window length in seconds (samples = int(WIN_SEC * fs))
STRIDE_SEC    = 2.5             # step between window starts in seconds; < WIN_SEC => overlapping windows (a 0-sample stride falls back to the window length)
HP_CUTOFF_HZ  = 0.5             # high-pass cutoff (Hz) for baseline removal; None or 0 disables filtering
COVERAGE_THR  = 0.5             # min fraction of a window covered by AFIB (-> y=1) or Normal (-> y=0) rhythm; AFIB is checked first; windows meeting neither are dropped
AFIB_TAGS     = {"AFIB"}        # cleaned rhythm labels (parens/NULs stripped, upper-cased) counted as AFib
NORMAL_TAGS   = {"N", "NSR", "SR"}  # cleaned rhythm labels counted as Normal
NORMAL_BEAT_FALLBACK_FRAC = 0.80     # record with no rhythm labels: harvest all its windows as Normal if >= this fraction of annotation symbols are 'N'
SAVE_PATH     = "mitdb_windows_5s_binary_afib.npz"  # output .npz with X, y, rec_ids and meta

def list_record_ids(local_dir):
    """Return the sorted, de-duplicated base names of every ``.dat`` file in *local_dir*."""
    stems = set()
    for fname in os.listdir(local_dir):
        if fname.endswith(".dat"):
            stem, _ext = os.path.splitext(fname)
            stems.add(stem)
    return sorted(stems)

def hp_filter(signal, fs, cutoff=0.5, order=2):
    """Zero-phase Butterworth high-pass filter.

    Args:
        signal: 1-D sequence of samples.
        fs: sampling frequency in Hz.
        cutoff: cutoff frequency in Hz; a falsy or non-positive value disables
            filtering and *signal* is returned as-is (same object).
        order: Butterworth filter order.

    Returns:
        The filtered signal (forward-backward via ``filtfilt``), or *signal* untouched.
    """
    filtering_disabled = not cutoff or cutoff <= 0
    if filtering_disabled:
        return signal
    nyquist = fs / 2
    numer, denom = butter(order, cutoff / nyquist, btype='highpass')
    return filtfilt(numer, denom, signal)

def normalize_window(x):
    """Z-score a window: subtract its mean, divide by its std (+1e-6 so flat windows don't divide by zero)."""
    centered = x - x.mean()
    return centered / (x.std() + 1e-6)

# Translation table that deletes parentheses and NUL padding in one pass.
_AUX_DELETE_TABLE = str.maketrans("", "", "()\x00")

def clean_aux(aux):
    """Normalize an aux_note value to a bare upper-case label.

    Parentheses and NUL characters are removed, surrounding whitespace is
    stripped, and the result is upper-cased. ``None`` maps to ``""``.
    """
    if aux is None:
        return ""
    return str(aux).translate(_AUX_DELETE_TABLE).strip().upper()

def load_record(rec_id):
    """Read record *rec_id* from LOCAL_DIR.

    Returns:
        (record, ann): the result of ``wfdb.rdrecord`` on the record path and
        the result of ``wfdb.rdann`` with the ``"atr"`` extension.
    """
    record_path = os.path.join(LOCAL_DIR, rec_id)
    # Tuple elements evaluate left to right: signal first, then annotations.
    return wfdb.rdrecord(record_path), wfdb.rdann(record_path, "atr")

def audit_rhythm_labels():
    """Print the cleaned aux_note labels found in each record's 'atr' file.

    Records are visited in ``list_record_ids(LOCAL_DIR)`` order. Every non-blank
    aux_note is passed through ``clean_aux``; empty results are dropped.

    Returns:
        dict mapping record id -> sorted list of cleaned labels (possibly empty).
    """
    record_names = list_record_ids(LOCAL_DIR)
    audit = {}
    print("=== Rhythm label audit (aux_note) ===")
    for name in record_names:
        notes = wfdb.rdann(os.path.join(LOCAL_DIR, name), "atr").aux_note
        found = set(clean_aux(note) for note in notes if note and str(note).strip())
        found.discard("")
        audit[name] = sorted(found)
        if not found:
            print(f"{name} → (no rhythm aux_note)")
            continue
        print(f"{name} → {sorted(found)}")
    print("=====================================\n")
    return audit

def find_afib_records():
    """Return ids (in list_record_ids order) of records whose cleaned aux_note labels include "AFIB"."""
    def _contains_afib(record_id):
        notes = wfdb.rdann(os.path.join(LOCAL_DIR, record_id), "atr").aux_note
        return any(clean_aux(note) == "AFIB" for note in notes if note and str(note).strip())

    return [rid for rid in list_record_ids(LOCAL_DIR) if _contains_afib(rid)]

def build_rhythm_intervals(ann, last_sample):
    """Build contiguous rhythm intervals from rhythm-change annotations.

    Only aux_note strings that start with "(" (after removing NUL padding and
    surrounding whitespace) are treated as rhythm changes. This follows the WFDB
    convention for rhythm labels such as "(N" or "(AFIB". Any other aux_note
    text is ignored. Previously every non-empty note started a new interval, so
    a non-rhythm note inside an AFIB episode cut that episode short and created
    a bogus interval labelled with the note text.

    Args:
        ann: annotation object with parallel ``sample`` and ``aux_note`` sequences.
        last_sample: end bound (exclusive) of the final interval.

    Returns:
        list of (start_sample, end_sample, label) tuples with start < end, where
        label is the ``clean_aux`` form of the note; empty if the annotation has
        no rhythm-change notes.
    """
    changes = []
    for samp, aux in zip(ann.sample, ann.aux_note):
        if aux is None:
            continue
        raw = str(aux).replace("\x00", "").strip()
        if not raw.startswith("("):
            continue  # not a rhythm label by the "(" convention
        lab = clean_aux(raw)
        if lab:  # a bare "(" cleans to "" and is skipped
            changes.append((samp, lab))
    if not changes:
        return []

    # Each interval runs until the next rhythm change; the last one runs to last_sample.
    ends = [samp for samp, _ in changes[1:]] + [last_sample]
    return [(s, e, lab) for (s, lab), e in zip(changes, ends) if e > s]

def label_window_by_coverage(start_samp, end_samp, intervals):
    """Label the half-open window [start_samp, end_samp) by rhythm coverage.

    Sums how many window samples overlap intervals labelled with an AFIB_TAGS
    label and with a NORMAL_TAGS label (AFIB_TAGS wins if a label is in both).
    Other labels count toward neither class.

    Returns:
        1 if the AFIB fraction is >= COVERAGE_THR, else 0 if the Normal fraction
        is >= COVERAGE_THR, else None (also None for an empty/inverted window).
    """
    window_len = end_samp - start_samp
    if window_len <= 0:
        return None

    afib_samples = 0
    normal_samples = 0
    for seg_start, seg_end, label in intervals:
        overlap = min(end_samp, seg_end) - max(start_samp, seg_start)
        if overlap <= 0:
            continue
        if label in AFIB_TAGS:
            afib_samples += overlap
        elif label in NORMAL_TAGS:
            normal_samples += overlap
        # other rhythm labels contribute to neither class

    if afib_samples / window_len >= COVERAGE_THR:
        return 1
    if normal_samples / window_len >= COVERAGE_THR:
        return 0
    return None

# Annotation codes that do not mark a QRS complex (rhythm change, signal
# quality, comment, artifact, waveform boundaries, P/T/U waves, ...), per the
# PhysioNet annotation code table. They must not dilute the 'N' beat fraction.
_NON_BEAT_SYMBOLS = frozenset({
    "+", "~", "|", '"', "[", "]", "x", "(", ")", "p", "t", "u",
    "`", "'", "^", "s", "T", "*", "D", "=", "@",
})

def majority_is_normal_beats(ann, min_frac=NORMAL_BEAT_FALLBACK_FRAC):
    """Fallback: treat a record as Normal if >= min_frac of its beats are 'N'.

    Used when a record has no rhythm labels. Only beat annotations are counted.
    Non-beat codes (see _NON_BEAT_SYMBOLS) are excluded from the denominator.
    Previously they were counted and lowered the 'N' fraction.

    Args:
        ann: annotation object exposing a ``symbol`` sequence.
        min_frac: required fraction of 'N' among beat annotations.

    Returns:
        True if the 'N' fraction of beats is >= min_frac; False if it is lower
        or if there are no beat annotations.
    """
    beats = [s for s in ann.symbol if s is not None and s not in _NON_BEAT_SYMBOLS]
    if not beats:
        return False
    return beats.count('N') / len(beats) >= min_frac

def extract_windows_from_record(rec_id):
    """
    Return (X, y, ids) arrays for a single record, using rhythm coverage and fallback.
    - X: (n_win, win_len) z-scored windows of lead LEAD_IDX (after high-pass filtering)
    - y: (n_win,) 1 = AFib, 0 = Normal
    - ids: (n_win,) object array with value rec_id for each window

    Windows are half-open [start, start + win_len), stepped by the stride. If
    the record has rhythm intervals, each window is labelled by
    label_window_by_coverage and unlabelled windows are dropped. Otherwise, if
    majority_is_normal_beats(ann) holds, every window is labelled Normal.
    Otherwise nothing is returned.
    An empty result keeps the 2-D shape (0, win_len) for X.

    Raises:
        ValueError: if WIN_SEC * fs yields a window shorter than one sample.
    """
    record, ann = load_record(rec_id)
    fs = float(record.fs)  # keep non-integer sampling rates exact
    sig = record.p_signal[:, LEAD_IDX].astype(np.float32)
    # HPF / low-frequency (baseline) noise removal
    sig = hp_filter(sig, fs, cutoff=HP_CUTOFF_HZ)
    n_samples = len(sig)

    # Windows are half-open, so the final rhythm interval must end at n_samples
    # (exclusive) to cover the last sample; n_samples - 1 left it uncovered.
    intervals = build_rhythm_intervals(ann, n_samples)

    win_len = int(WIN_SEC * fs)
    if win_len <= 0:
        # The original while-loop never terminated in this case.
        raise ValueError(f"window length is {win_len} samples (WIN_SEC={WIN_SEC}, fs={fs})")
    stride = int(STRIDE_SEC * fs)
    if stride <= 0:
        stride = win_len

    # Pick one labelling rule for the whole record.
    if intervals:
        def label_of(start, end):
            return label_window_by_coverage(start, end, intervals)
    elif majority_is_normal_beats(ann):
        def label_of(start, end):
            return 0  # fallback: mostly-normal beats => Normal
    else:
        label_of = None

    X_list, y_list, ids_list = [], [], []
    if label_of is not None:
        for start in range(0, n_samples - win_len + 1, stride):
            end = start + win_len
            y = label_of(start, end)
            if y is None:
                continue
            X_list.append(normalize_window(sig[start:end]))
            y_list.append(y)
            ids_list.append(rec_id)

    if not X_list:
        return (np.empty((0, win_len), dtype=sig.dtype),
                np.empty((0,), dtype=int),
                np.empty((0,), dtype=object))
    return (np.stack(X_list), np.array(y_list, dtype=int), np.array(ids_list, dtype=object))


def main():
    """Audit labels, window every record, and save the binary AFib/Normal dataset.

    Steps:
      1. Print the per-record aux_note audit and the records containing AFIB.
      2. Extract labelled windows from every record. A record that fails is
         reported and skipped, so one bad record does not abort the run.
      3. Print a class summary and write X, y, rec_ids and meta to SAVE_PATH.
    """
    # 1) Audit rhythm labels. The audit already returns every record's cleaned
    #    label list, so derive the AFIB record list from it. Calling
    #    find_afib_records() would read every annotation file a second time.
    label_map = audit_rhythm_labels()
    afib_records = [rid for rid, labels in label_map.items() if "AFIB" in labels]
    print("Records with AFIB:", afib_records, "\n")

    # 2) Build dataset across all records
    rec_ids = list_record_ids(LOCAL_DIR)
    X_all, y_all, id_all = [], [], []

    for rid in rec_ids:
        try:
            Xr, yr, idr = extract_windows_from_record(rid)
        except Exception as e:  # top-level boundary: report and move on to the next record
            print(f"{rid}: ERROR → {e}")
            continue
        if len(yr) == 0:
            print(f"{rid}: no labeled windows (skipped)")
            continue
        X_all.append(Xr); y_all.append(yr); id_all.append(idr)
        na = int(yr.sum()); nn = int(len(yr) - na)
        print(f"{rid}: kept {len(yr)} windows  (AFib {na} / Normal {nn})")

    if not X_all:
        print("No windows extracted. Consider lowering COVERAGE_THR, adding NORMAL_TAGS, or using smaller STRIDE_SEC.")
        return

    X = np.concatenate(X_all, axis=0)
    y = np.concatenate(y_all, axis=0)
    rec_ids_per_window = np.concatenate(id_all, axis=0)

    # 3) Report & Save
    af_frac = y.mean()
    print("\n=== Summary ===")
    print("Final dataset:", X.shape, y.shape)
    print(f"AFib windows: {int(y.sum())} | Normal windows: {int(len(y)-y.sum())}")
    print("AFib fraction:", round(af_frac, 3))

    # NOTE: meta is an object array (a pickled dict), so readers must use
    # np.load(SAVE_PATH, allow_pickle=True) and meta.item() to get it back.
    np.savez_compressed(
        SAVE_PATH,
        X=X,
        y=y,
        rec_ids=rec_ids_per_window,
        meta=np.array({
            "lead_idx": LEAD_IDX,
            "win_sec": WIN_SEC,
            "stride_sec": STRIDE_SEC,
            "hp_cutoff_hz": HP_CUTOFF_HZ,
            "coverage_thr": COVERAGE_THR,
            "afib_tags": list(AFIB_TAGS),
            "normal_tags": list(NORMAL_TAGS),
            "normal_fallback_frac": NORMAL_BEAT_FALLBACK_FRAC
        }, dtype=object)
    )
    print(f"Saved → {SAVE_PATH}")

# Run the audit + dataset build only when executed as a script, not on import.
if __name__ == "__main__":
    main()


=== Rhythm label audit (aux_note) ===
100 → ['N']
101 → ['N']
102 → ['N', 'P']
103 → ['N']
104 → ['N', 'P']
105 → ['N']
106 → ['B', 'N', 'T', 'VT']
107 → ['P']
108 → ['N']
109 → ['N']
111 → ['N']
112 → ['N']
113 → ['N']
114 → ['N', 'SVTA']
115 → ['N']
116 → ['N']
117 → ['N']
118 → ['N']
119 → ['B', 'N', 'T']
121 → ['N']
122 → ['N']
123 → ['N']
124 → ['IVR', 'N', 'NOD', 'T']
200 → ['B', 'N', 'VT']
201 → ['AFIB', 'N', 'NOD', 'SVTA', 'T']
202 → ['AFIB', 'AFL', 'N']
203 → ['AFIB', 'AFL', 'T', 'VT']
205 → ['N', 'VT']
207 → ['B', 'IVR', 'N', 'SVTA', 'VFL', 'VT']
208 → ['N', 'T']
209 → ['N', 'SVTA']
210 → ['AFIB', 'B', 'T', 'VT']
212 → ['N']
213 → ['B', 'N', 'VT']
214 → ['N', 'T', 'TS', 'VT']
215 → ['N', 'TS', 'VT']
217 → ['AFIB', 'B', 'P', 'VT']
219 → ['AFIB', 'B', 'MISSB', 'N', 'PSE', 'T']
220 → ['N', 'SVTA']
221 → ['AFIB', 'B', 'T', 'VT']
222 → ['AB', 'AFIB', 'AFL', 'N', 'NOD', 'SVTA']
223 → ['B', 'N', 'T', 'VT']
228 → ['B', 'N', 'TS']
230 → ['N', 'PREX']
231 → ['BII', 'MISSB', 'N']
232 → 