In [1]:
import os, glob
import numpy as np
import h5py
from scipy.interpolate import interp1d

# =========================
# Config
# =========================
# Folder containing the per-patient PulseDB files matched by "p*.mat".
PULSEDB_DIR = "/content/drive/MyDrive/Colab Notebooks/PulseDB"
# Upper bound on segments read per patient file; None reads every segment.
# Set an int here to limit it if needed.
SEGMENT_LIMIT = None
# The audit is repeated once per entry, in this order (unfiltered, then filtered).
FILTERING_SWEEP = [False, True]

# Segment-wise BP thresholds (units as stored in SegSBP/SegDBP — presumably mmHg).
# Hyper: SBP >= SBP_HYPER or DBP >= DBP_HYPER; hypo: SBP < SBP_HYPO or DBP < DBP_HYPO.
SBP_HYPER, DBP_HYPER = 140, 90
SBP_HYPO, DBP_HYPO = 90, 60

# =========================
# Segment-quality filter (a copy of the existing preprocessing filter; keep the two in sync)
# =========================
def preprocess_ensemble_by_rpeaks(ppg_raw, rpeaks_raw, sbp, dbp, target_len=125, threshold_corr=0.7,
                                  *, min_beats=5, min_beat_len=20, min_consistent_ratio=0.7):
    """Build a normalized ensemble-average PPG beat, or reject the segment.

    The PPG is cut into beats between consecutive R-peak sample indices. Each
    beat is linearly resampled to ``target_len`` points, and the beats are
    averaged. The average is min-max scaled to [0, 1] unless it is essentially
    flat. The segment is accepted only if enough individual beats correlate
    with the average.

    Args:
        ppg_raw: PPG samples; singleton dimensions are squeezed away.
        rpeaks_raw: R-peak positions, used (after int cast and sort) as sample
            indices into the PPG. NOTE(review): if these were exported from
            MATLAB they may be 1-based — confirm.
        sbp, dbp: Segment BP labels. The segment is rejected unless
            50 <= sbp <= 250 and 30 <= dbp <= 160.
        target_len: Number of points each beat is resampled to.
        threshold_corr: Minimum Pearson correlation for a beat to count as
            consistent with the ensemble average.
        min_beats: Minimum number of usable beats (default 5).
        min_beat_len: Beats shorter than this many samples are dropped (default 20).
        min_consistent_ratio: Minimum fraction of consistent beats (default 0.7).

    Returns:
        float32 array of shape ``(target_len,)``, or None if the segment is rejected.
    """
    # BP plausibility window. load_labels_from_mat repeats this exact range to
    # attribute rejections, so keep the two in sync.
    if not (50 <= sbp <= 250) or not (30 <= dbp <= 160):
        return None

    # atleast_1d: squeeze() collapses a single R-peak (or single sample) to a
    # 0-d array, on which np.sort/len() raise instead of rejecting the segment.
    ppg = np.atleast_1d(np.squeeze(ppg_raw))
    rpeaks = np.sort(np.atleast_1d(np.squeeze(rpeaks_raw)).astype(int))

    # Common resampling grid, hoisted out of the loop. Every beat grid also
    # spans [0, 1], so plain linear interpolation needs no extrapolation.
    x_new = np.linspace(0.0, 1.0, target_len)
    beats = []
    for start, end in zip(rpeaks[:-1], rpeaks[1:]):
        if start < 0 or end > len(ppg):
            continue
        beat_segment = ppg[start:end]
        if len(beat_segment) < min_beat_len:
            continue
        x_old = np.linspace(0.0, 1.0, len(beat_segment))
        beats.append(np.interp(x_new, x_old, beat_segment))

    if len(beats) < min_beats:
        return None

    beats = np.asarray(beats)
    ensemble_avg = beats.mean(axis=0)

    # Normalize to 0~1 unless the ensemble is (numerically) flat.
    e_min, e_max = ensemble_avg.min(), ensemble_avg.max()
    if e_max - e_min > 1e-6:
        ensemble_avg = (ensemble_avg - e_min) / (e_max - e_min)

    # Consistency check. Pearson correlation ignores scale and offset, so the
    # normalized ensemble can be compared with raw beats. A constant beat or
    # ensemble yields NaN, which fails the >= test and counts as inconsistent.
    # errstate silences the RuntimeWarning that NaN case would otherwise emit.
    with np.errstate(invalid="ignore", divide="ignore"):
        correlations = np.array([np.corrcoef(ensemble_avg, b)[0, 1] for b in beats])
        consistent_beats_count = int(np.count_nonzero(correlations >= threshold_corr))
    if consistent_beats_count / len(beats) < min_consistent_ratio:
        return None

    return ensemble_avg.astype(np.float32)

# =========================
# Load labels (SegSBP/SegDBP), optionally applying the same inclusion logic as the filter above
# =========================
def load_labels_from_mat(mat_path, segment_limit=None, filtering=False):
    """Read per-segment SBP/DBP labels from one PulseDB patient file.

    The file is opened with h5py, so it must be HDF5-based (MATLAB v7.3
    ``.mat``). Labels come from ``Subj_Wins/SegSBP`` and ``Subj_Wins/SegDBP``.
    Each entry is an object reference whose dataset's ``[0][0]`` element is
    the label value.

    Args:
        mat_path: Path to the patient ``.mat`` file.
        segment_limit: If not None, examine at most this many segments
            (``0`` examines none). None examines every segment.
        filtering: If False, keep every segment, with no BP-range or signal
            check. If True, keep only segments for which
            ``preprocess_ensemble_by_rpeaks`` (applied to ``PPG_F`` and
            ``ECG_RPeaks``) does not return None.

    Returns:
        ``(sbp, dbp, meta)``: float32 arrays of the kept SBP/DBP values, plus a
        dict with ``kept``, ``total`` (segments examined), ``skip_bp`` (rejected
        for out-of-range BP) and ``skip_noise`` (rejected by the waveform checks).

    Raises:
        KeyError: if ``filtering`` is True and ``Subj_Wins`` has no ``ECG_RPeaks``.
    """
    def _read_scalar(h5file, ref):
        # Dereference one object reference and take its first element as a float.
        return float(h5file[ref][()][0][0])

    sbp_list, dbp_list = [], []
    skip_bp, skip_noise = 0, 0

    with h5py.File(mat_path, "r") as f:
        sw = f["Subj_Wins"]
        # Row 0 of each reference array is used; this assumes the references are
        # stored as (1, n_segments) — TODO confirm against the PulseDB layout.
        ppg_refs = sw["PPG_F"][0]
        sbp_refs = sw["SegSBP"][0]
        dbp_refs = sw["SegDBP"][0]

        total = len(ppg_refs)
        # `is not None` rather than truthiness, so segment_limit=0 is honoured
        # instead of silently meaning "no limit".
        if segment_limit is not None:
            total = min(total, segment_limit)

        if not filtering:
            for i in range(total):
                sbp_list.append(_read_scalar(f, sbp_refs[i]))
                dbp_list.append(_read_scalar(f, dbp_refs[i]))
        else:
            # filtering=True: keep only segments that pass preprocess_ensemble_by_rpeaks
            if "ECG_RPeaks" not in sw:
                raise KeyError("This patient file has no 'ECG_RPeaks' in Subj_Wins. filtering=True cannot be applied.")

            ecg_refs = sw["ECG_RPeaks"][0]
            for i in range(total):
                sbp = _read_scalar(f, sbp_refs[i])
                dbp = _read_scalar(f, dbp_refs[i])

                # preprocess_ensemble_by_rpeaks rejects this BP range before it
                # looks at the waveform, so skip the PPG/R-peak reads here. The
                # counts are identical; keep the range in sync with that function.
                if not (50 <= sbp <= 250) or not (30 <= dbp <= 160):
                    skip_bp += 1
                    continue

                ppg_raw = f[ppg_refs[i]][()]
                rpeaks_raw = f[ecg_refs[i]][()]
                if preprocess_ensemble_by_rpeaks(ppg_raw, rpeaks_raw, sbp, dbp) is None:
                    skip_noise += 1
                    continue

                sbp_list.append(sbp)
                dbp_list.append(dbp)

    meta = {"kept": len(sbp_list), "total": total, "skip_bp": skip_bp, "skip_noise": skip_noise}
    return np.array(sbp_list, np.float32), np.array(dbp_list, np.float32), meta

# =========================
# Stats
# =========================
def summarize_patient_labels(sbp, dbp, *, sbp_hyper=None, dbp_hyper=None, sbp_hypo=None, dbp_hypo=None):
    """Summarize one patient's segment-level SBP/DBP labels.

    Args:
        sbp, dbp: Equal-length sequences (lists or arrays) of per-segment labels.
        sbp_hyper, dbp_hyper: Hypertension thresholds. A segment counts when
            ``sbp >= sbp_hyper`` or ``dbp >= dbp_hyper``. None means use the
            module constants SBP_HYPER / DBP_HYPER.
        sbp_hypo, dbp_hypo: Hypotension thresholds. A segment counts when
            ``sbp < sbp_hypo`` or ``dbp < dbp_hypo``. None means use
            SBP_HYPO / DBP_HYPO.

    Returns:
        None for empty input. Otherwise a dict with the segment count ``n``;
        min/median/95th-percentile/max of SBP and DBP; and the percentage of
        segments classed as hyper (``hyper_pct``) and hypo (``hypo_pct``).

    Raises:
        ValueError: if ``sbp`` and ``dbp`` have different shapes. Without this
            check, a length-1 array would silently broadcast in the masks.
    """
    sbp = np.asarray(sbp)
    dbp = np.asarray(dbp)
    if sbp.size == 0:
        return None
    if sbp.shape != dbp.shape:
        raise ValueError(f"sbp and dbp shapes differ: {sbp.shape} vs {dbp.shape}")

    # Resolve thresholds at call time, so later edits to the module constants apply.
    sbp_hyper = SBP_HYPER if sbp_hyper is None else sbp_hyper
    dbp_hyper = DBP_HYPER if dbp_hyper is None else dbp_hyper
    sbp_hypo = SBP_HYPO if sbp_hypo is None else sbp_hypo
    dbp_hypo = DBP_HYPO if dbp_hypo is None else dbp_hypo

    hyper_mask = (sbp >= sbp_hyper) | (dbp >= dbp_hyper)
    hypo_mask = (sbp < sbp_hypo) | (dbp < dbp_hypo)

    return {
        "n": int(sbp.size),
        "sbp_min": float(np.min(sbp)),
        "sbp_med": float(np.median(sbp)),
        "sbp_p95": float(np.percentile(sbp, 95)),
        "sbp_max": float(np.max(sbp)),
        "dbp_min": float(np.min(dbp)),
        "dbp_med": float(np.median(dbp)),
        "dbp_p95": float(np.percentile(dbp, 95)),
        "dbp_max": float(np.max(dbp)),
        "hyper_pct": float(100.0 * np.mean(hyper_mask)),
        "hypo_pct": float(100.0 * np.mean(hypo_mask)),
    }

def print_patient_summary(patient_name, summ, meta):
    """Print a short label report for one patient file.

    Args:
        patient_name: Label printed at the start of the report.
        summ: Dict produced by summarize_patient_labels, or None when no
            segment was kept (only a one-line failure notice is printed).
        meta: Dict with "total", "skip_bp", "skip_noise" and, when summ is not
            None, "kept".
    """
    if summ is not None:
        print(f"  ✅ {patient_name}: kept={meta['kept']}/{meta['total']} "
              f"(skip_bp={meta['skip_bp']}, skip_noise={meta['skip_noise']})")
        # One line per pressure: min / median / p95 / max, one decimal each.
        for label, prefix in (("SBP", "sbp"), ("DBP", "dbp")):
            values = " / ".join(
                format(summ[f"{prefix}_{stat}"], ".1f") for stat in ("min", "med", "p95", "max")
            )
            print(f"     {label}  min/med/p95/max: {values}")
        print(f"     Hyper% (SBP≥{SBP_HYPER} or DBP≥{DBP_HYPER}): {summ['hyper_pct']:.2f}%")
        print(f"     Hypo%  (SBP<{SBP_HYPO} or DBP<{DBP_HYPO}): {summ['hypo_pct']:.2f}%")
    else:
        print(f"  ❌ {patient_name}: kept=0 (total={meta['total']}, "
              f"skip_bp={meta['skip_bp']}, skip_noise={meta['skip_noise']})")

def mean_std(arr):
    """Return ``(mean, std)`` of ``arr`` as Python floats, computed in float32.

    The sample standard deviation (ddof=1) is used when there are at least two
    values, ddof=0 for a single value. An empty input gives ``(nan, nan)``.
    """
    values = np.asarray(arr, np.float32)
    count = len(values)
    if count == 0:
        missing = float("nan")
        return missing, missing
    center = values.mean()
    spread = values.std(ddof=1 if count > 1 else 0)
    return float(center), float(spread)

# =========================
# Run all patients
# =========================
def run_label_audit(pulsedb_dir):
    """Print per-patient and cross-patient audits of SegSBP/SegDBP labels.

    For every setting in FILTERING_SWEEP, each ``p*.mat`` under ``pulsedb_dir``
    (sorted by path) goes through three steps: it is loaded with
    load_labels_from_mat (using SEGMENT_LIMIT), summarized with
    summarize_patient_labels, and printed. An exception raised for one file is
    reported and recorded; the run continues with the next file. After all
    files, the mean and std across patients of selected per-patient statistics
    are printed. Finally, files that kept no segment or raised are listed.

    Raises:
        FileNotFoundError: if no ``p*.mat`` file exists under ``pulsedb_dir``.
    """
    patient_files = sorted(glob.glob(os.path.join(pulsedb_dir, "p*.mat")))
    if not patient_files:
        raise FileNotFoundError(f"No patient .mat files under: {pulsedb_dir}")

    # Per-patient statistics aggregated across patients (one value per patient).
    agg_fields = (
        ("sbp_med", "SBP median"),
        ("sbp_p95", "SBP p95"),
        ("sbp_max", "SBP max"),
        ("dbp_med", "DBP median"),
        ("dbp_p95", "DBP p95"),
        ("dbp_max", "DBP max"),
        ("hyper_pct", "Hyper%"),
        ("hypo_pct", "Hypo%"),
    )
    banner = "#" * 90
    rule = "-" * 90

    for filtering in FILTERING_SWEEP:
        print("\n" + banner)
        print(f"[LABEL AUDIT] filtering={filtering}  (ABP-derived SegSBP/SegDBP)")
        print(banner)

        summaries = []  # (file name, summary dict)
        failures = []   # (file name, reason)

        for path in patient_files:
            name = os.path.basename(path)
            try:
                sbp, dbp, meta = load_labels_from_mat(path, segment_limit=SEGMENT_LIMIT, filtering=filtering)
                summ = summarize_patient_labels(sbp, dbp)
                print_patient_summary(name, summ, meta)
                if summ is None:
                    failures.append((name, "kept=0 after filtering"))
                else:
                    summaries.append((name, summ))
            except Exception as exc:
                # Top-level boundary: one broken file must not abort the whole audit.
                reason = f"{type(exc).__name__}: {exc}"
                print(f"  ❌ {name}: ERROR -> {reason}")
                failures.append((name, reason))

        print("\n" + rule)
        print(f"[PATIENT-LEVEL AGG] filtering={filtering} | used={len(summaries)}/{len(patient_files)}")
        print(rule)

        for key, title in agg_fields:
            m, sd = mean_std([summ[key] for _, summ in summaries])
            print(f"  {title:10s}: mean={m:.2f} | std={sd:.2f}")

        if failures:
            print("\n[SKIPPED]")
            for name, reason in failures:
                print(f"  - {name}: {reason}")

if __name__ == "__main__":
    # Run the full label audit over PULSEDB_DIR. Inside a notebook cell
    # __name__ is also "__main__", so executing the cell starts the audit.
    run_label_audit(PULSEDB_DIR)



##########################################################################################
[LABEL AUDIT] filtering=False  (ABP-derived SegSBP/SegDBP)
##########################################################################################
  ✅ p017795.mat: kept=1726/1726 (skip_bp=0, skip_noise=0)
     SBP  min/med/p95/max: 70.5 / 94.8 / 117.8 / 132.1
     DBP  min/med/p95/max: 45.1 / 60.8 / 83.2 / 96.0
     Hyper% (SBP≥140 or DBP≥90): 3.19%
     Hypo%  (SBP<90 or DBP<60): 44.96%
  ✅ p021706.mat: kept=1763/1763 (skip_bp=0, skip_noise=0)
     SBP  min/med/p95/max: 94.7 / 121.1 / 144.6 / 176.6
     DBP  min/med/p95/max: 41.9 / 49.7 / 65.4 / 78.1
     Hyper% (SBP≥140 or DBP≥90): 6.41%
     Hypo%  (SBP<90 or DBP<60): 84.34%
  ✅ p028758.mat: kept=1754/1754 (skip_bp=0, skip_noise=0)
     SBP  min/med/p95/max: 70.0 / 97.1 / 128.0 / 155.8
     DBP  min/med/p95/max: 29.6 / 40.1 / 51.6 / 62.9
     Hyper% (SBP≥140 or DBP≥90): 1.08%
     Hypo%  (SBP<90 or DBP<60): 99.94%
  ✅ p029712.mat: kept=173