# Dataset Generation



## Info

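The pretraining dataset does not need labels, so every database can be used.
However, the downstream classification task needs labeled data for supervised training, so only labeled databases can be used there.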

Databases that can be processed successfully:

- `afdb` - MIT-BIH Atrial Fibrillation Database
- `apnea-ecg` - Apnea-ECG Database
- `edb` - European ST-T Database
- `fantasia` - Fantasia Database
- `incartdb` - St Petersburg INCART 12-lead Arrhythmia Database
- `ltafdb` - Long Term AF Database
- `mitdb` - MIT-BIH Arrhythmia
- `nsrdb` - MIT-BIH Normal Sinus Rhythm Database
- `ptbdb` - PTB Diagnostic ECG Database
- `qtdb` - QT Database
- `sddb` - Sudden Cardiac Death Holter Database

Databases waiting to be processed:

- `ptb-xl` - (1.7G)PTB-XL, a large publicly available
- `chapman` - (2.3G)A large scale 12-lead electrocardiogram database for arrhythmia study

Each of the following contains multiple sub-datasets, which makes processing cumbersome, so they are not considered for now:

- `challenge-2017` - (1.4G)AF Classification from a Short Single Lead ECG Recording: The PhysioNet/Computing in Cardiology Challenge 2017
- `challenge-2020` - (size unknown)Classification of 12-lead ECGs: The PhysioNet/Computing in Cardiology Challenge 2020
- `chfdb` - (580.5MB)BIDMC Congestive Heart Failure Database
- `challenge-2022` - (1.3GB, but contains multiple datasets)Heart Murmur Detection from Phonocardiogram Recordings: The George B. Moody PhysioNet Challenge 2022

The following datasets are too large and are not considered for now:

- `ltstdb` - (9.5GB)Long Term ST Database
- `mimic3wdb` - (6.7TB)MIMIC-III Waveform Database

In [1]:
import sys

sys.path.append("..")

## Part 1: Pretraining Dataset Builder (v1, single-lead)

### 1.1 Setup & Global Config

Goal of this notebook:

1. For multiple PhysioNet ECG databases, build a unified pretraining dataset:
   - Single-lead only (for now)
   - Unified sampling rate (TARGET_FS)
   - Fixed window length (WINDOW_SEC)
   - Per-window z-score normalization
2. For each database, save sharded `.npy` files to avoid huge files.
3. Optionally merge all databases into a global pretraining dataset (also sharded).

Later we will extend from single-lead to 12-lead, but v1 focuses on single-lead.

In [2]:
import os
import wfdb
import numpy as np

from tqdm import tqdm
from pathlib import Path
from scipy.signal import resample_poly

In [3]:
# --------------------
# Global config
# --------------------
TARGET_FS = 500  # target sampling rate in Hz
WINDOW_SEC = 10  # window length in seconds
WINDOW_STRIDE_SEC = 10  # stride for pretraining windows (no overlap for now)

# A shard is written once the buffer reaches this many windows (it can slightly
# exceed it, because all windows of a record are appended at once)
MAX_WINDOWS_PER_SHARD = 100_000

# Root paths (adapt these to your own environment)
DATA_ROOT = Path(r"../data")
OUT_ROOT = DATA_ROOT / f"pretrain/pretrain_singlelead_{TARGET_FS}hz_{WINDOW_SEC}s"

OUT_ROOT.mkdir(parents=True, exist_ok=True)

# Databases that are already confirmed to be readable
DBS = {
    "afdb": DATA_ROOT / "afdb",
    "apnea": DATA_ROOT / "apnea",
    "edb": DATA_ROOT / "edb",
    "fantasia": DATA_ROOT / "fantasia",
    "incartdb": DATA_ROOT / "incartdb",
    "ltafdb": DATA_ROOT / "ltafdb",
    "mitdb": DATA_ROOT / "mitdb",
    "nsrdb": DATA_ROOT / "nsrdb",
    "ptbdb": DATA_ROOT / "ptbdb",
    "qtdb": DATA_ROOT / "qtdb",
    "sddb": DATA_ROOT / "sddb",
    # "ptb-xl": DATA_ROOT / "ptb-xl",
    # "chapman": DATA_ROOT / "chapman",
}

# Databases you plan to handle later (non-WFDB, CSV-based, etc.)
DBS_TODO = {
    # "ptb-xl": DATA_ROOT / "ptb-xl",
    "chapman": DATA_ROOT / "chapman",
}

print("TARGET_FS:", TARGET_FS)
print("WINDOW_SEC:", WINDOW_SEC)
print("OUT_ROOT:", OUT_ROOT)
print("Ready DBs:", list(DBS.keys()))
print("TODO DBs:", list(DBS_TODO.keys()))

TARGET_FS: 500
WINDOW_SEC: 10
OUT_ROOT: ..\data\pretrain\pretrain_singlelead_500hz_10s
Ready DBs: ['afdb', 'apnea', 'edb', 'fantasia', 'incartdb', 'ltafdb', 'mitdb', 'nsrdb', 'ptbdb', 'qtdb', 'sddb']
TODO DBs: ['chapman']


### 1.2 Inspect lead names for each database

Different PhysioNet databases use different lead configurations and names:
- Some are ["MLII", "V5"]
- Some are ["I", "II", "III", "aVR", "aVL", "aVF", "V1", ...]
- Some long-term Holter recordings only have 1–2 channels.

We first implement a helper function to:
- Scan a few `.hea` files for each DB
- Print:
  - record name
  - number of signals
  - signal names (`sig_name`)

This helps us design a robust `pick_single_lead(db_name, rec)` function later.

In [4]:
def inspect_leads_for_db(db_name: str, db_dir: Path, max_records: int = 5):
    """
    Inspect a few records from a PhysioNet-style database and print lead info.
    """
    hea_files = sorted(db_dir.glob("**/*.hea"))
    if not hea_files:
        print(f"[{db_name}] No .hea files found in {db_dir}")
        return

    print(f"\n=== Inspecting DB: {db_name} ===")
    print("Total .hea files:", len(hea_files))

    for hea in hea_files[:max_records]:
        rec_name = str(hea.with_suffix(""))
        try:
            rec = wfdb.rdrecord(rec_name)
        except Exception as e:
            print(f"  [WARN] Failed to read {rec_name}: {e}")
            continue

        sig_names = getattr(rec, "sig_name", None)
        n_sig = getattr(rec, "n_sig", None)

        print(f"  Record: {hea.name}")
        print(f"    n_sig: {n_sig}")
        print(f"    sig_name: {sig_names}")


for name, path in DBS.items():
    inspect_leads_for_db(name, path, max_records=3)


=== Inspecting DB: afdb ===
Total .hea files: 25
  [WARN] Failed to read ..\data\afdb\1.0.0\00735: sampto must be greater than sampfrom
  [WARN] Failed to read ..\data\afdb\1.0.0\03665: sampto must be greater than sampfrom
  Record: 04015.hea
    n_sig: 2
    sig_name: ['ECG1', 'ECG2']

=== Inspecting DB: apnea ===
Total .hea files: 86
  Record: a01.hea
    n_sig: 1
    sig_name: ['ECG']
  [WARN] Failed to read ..\data\apnea\a01er: Samples were not loaded correctly
  Record: a01r.hea
    n_sig: 4
    sig_name: ['Resp C', 'Resp A', 'Resp N', 'SpO2']

=== Inspecting DB: edb ===
Total .hea files: 90
  Record: e0103.hea
    n_sig: 2
    sig_name: ['V4', 'MLIII']
  Record: e0104.hea
    n_sig: 2
    sig_name: ['MLIII', 'V4']
  Record: e0105.hea
    n_sig: 2
    sig_name: ['MLIII', 'V4']

=== Inspecting DB: fantasia ===
Total .hea files: 40
  Record: f1o01.hea
    n_sig: 2
    sig_name: ['RESP', 'ECG']
  Record: f1o02.hea
    n_sig: 2
    sig_name: ['RESP', 'ECG']
  Record: f1o03.hea
    n_
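The printed `sig_name` lists can also be summarized programmatically. Below is a minimal sketch, with hypothetical helpers `lead_name_counts` and `preference_coverage` (not part of wfdb), that counts how many records contain each lead name and checks what fraction of records a candidate preference list would cover. The example inputs are copied from the `edb` and `apnea` output above.

```python
from collections import Counter


def lead_name_counts(sig_names_per_record: list[list[str]]) -> Counter:
    """Number of records in which each lead name appears (repeats within a record count once)."""
    counts: Counter = Counter()
    for names in sig_names_per_record:
        counts.update(set(names))
    return counts


def preference_coverage(sig_names_per_record: list[list[str]], prefs: list[str]) -> float:
    """Fraction of records that contain at least one lead from `prefs`."""
    if not sig_names_per_record:
        return 0.0
    hits = sum(any(p in names for p in prefs) for names in sig_names_per_record)
    return hits / len(sig_names_per_record)


# sig_name lists copied from the edb and apnea output above
edb = [["V4", "MLIII"], ["MLIII", "V4"], ["MLIII", "V4"]]
apnea = [["ECG"], ["Resp C", "Resp A", "Resp N", "SpO2"]]

print(lead_name_counts(edb))
print(preference_coverage(apnea, ["ECG"]))  # 0.5
```

A coverage below 1.0 for a database that is not handled strictly means some multi-channel records will end up on the channel-0 fallback.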

### 1.3 Single-lead Selection Strategy

We create a `LEAD_PREFERENCE` mapping:

- For each database name, specify an ordered list of preferred lead names.
- When processing a record:
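  1. If the record contains any preferred lead, we use the first one found in preference order.
  2. If none of them exist, we fall back to the first channel and print a warning, except for databases handled strictly (e.g., Apnea-ECG, where some records contain only respiration/SpO2), whose records are skipped instead.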

This function will be used by all later preprocessing steps.

Later, when we move to 12-lead, we can replace this with a `pick_multi_lead(db_name, rec)` that returns [T, n_leads].
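A hypothetical sketch of the channel-reordering core that such a `pick_multi_lead` might use is shown below. The canonical order `LEADS_12`, the helper name `to_12_lead`, zero-filling of missing leads, and the presence mask are all assumptions for illustration, not decisions made in this notebook.

```python
import numpy as np

# Assumed canonical lead order for this sketch (hypothetical).
LEADS_12 = ["I", "II", "III", "aVR", "aVL", "aVF", "V1", "V2", "V3", "V4", "V5", "V6"]


def to_12_lead(sig: np.ndarray, sig_names: list[str]) -> tuple[np.ndarray, np.ndarray]:
    """
    Reorder channels into LEADS_12 order.

    Returns:
        signal of shape [T, 12] (float32), with missing leads zero-filled
        mask of shape [12] (bool), True where the lead was present
    """
    sig = np.asarray(sig, dtype=np.float32)
    if sig.ndim == 1:
        sig = sig[:, None]
    # Case-insensitive lookup so PTB's "i", "avr" match "I", "aVR";
    # setdefault keeps the first channel when a name repeats.
    name_to_idx: dict[str, int] = {}
    for idx, name in enumerate(sig_names):
        name_to_idx.setdefault(name.lower(), idx)

    out = np.zeros((sig.shape[0], len(LEADS_12)), dtype=np.float32)
    mask = np.zeros(len(LEADS_12), dtype=bool)
    for j, lead in enumerate(LEADS_12):
        idx = name_to_idx.get(lead.lower())
        if idx is not None:
            out[:, j] = sig[:, idx]
            mask[j] = True
    return out, mask


# Toy 3-channel record with PTB-style lowercase names
sig = np.arange(12, dtype=np.float32).reshape(4, 3)
out, mask = to_12_lead(sig, ["ii", "v1", "i"])
print(out.shape, np.flatnonzero(mask))  # (4, 12) [0 1 6]
```

Holter databases labelled `MLII`, `ECG1`, etc. would leave most slots empty under this mapping, so a real `pick_multi_lead` would still need per-database rules.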

In [5]:
# Adjust these preferences after inspecting the lead-name output above, if needed.
# Keys must match the keys of DBS (e.g., "apnea", not the PhysioNet slug "apnea-ecg").
LEAD_PREFERENCE = {
    # Long-term AF / Holter style databases
    "afdb": ["ECG1", "ECG2"],
    "ltafdb": ["ECG"],  # both channels named "ECG"; the first one is used
    "sddb": ["ECG"],  # both channels named "ECG"; the first one is used
    "nsrdb": ["ECG1", "ECG2"],
    # Apnea-ECG: only records with a real ECG channel
    # Some records (e.g., a01r) only have respiration/SpO2; they are skipped below
    "apnea": ["ECG"],
    # ST-T / ischemia related
    "edb": ["MLIII", "V4"],
    # Fantasia: [RESP, ECG]
    "fantasia": ["ECG"],
    # INCART: clean 12-lead, standard names
    "incartdb": ["I", "II", "MLII"],
    # Classic arrhythmia (MIT-BIH)
    "mitdb": ["MLII", "II", "V1", "V5"],
    # QT database: very similar style to MIT-BIH
    "qtdb": ["MLII", "II", "V5", "V2"],
    # PTB Diagnostic: 15-lead, lowercase names
    "ptbdb": ["i", "ii"],
}

# Databases where records without a preferred lead are SKIPPED instead of
# falling back to channel 0 (e.g., Apnea-ECG records with Resp/SpO2 only).
DBS_STRICT = {"apnea"}  # add more if needed


def pick_single_lead(db_name: str, rec: wfdb.Record) -> np.ndarray | None:
    """
    Select a single lead from a multi-lead ECG record.

    Returns:
        1D signal of shape [T], or None if no suitable lead is found.
    """
    sig = rec.p_signal  # shape: [T, n_sig] or [T]
    sig_names = getattr(rec, "sig_name", None)
    n_sig = getattr(rec, "n_sig", sig.shape[1] if sig.ndim == 2 else 1)

    # Single-channel case: just return it
    if sig.ndim == 1 or n_sig == 1:
        return sig.astype(np.float32).reshape(-1)

    # Multi-channel case: try preference list
    prefs = LEAD_PREFERENCE.get(db_name, None)
    if prefs is not None and sig_names is not None:
        # Keep the FIRST index for each name; a plain dict comprehension would keep
        # the last one when names repeat (e.g., ltafdb/sddb: ["ECG", "ECG"]).
        name_to_idx = {}
        for idx, name in enumerate(sig_names):
            name_to_idx.setdefault(name, idx)
        for pref in prefs:
            if pref in name_to_idx:
                idx = name_to_idx[pref]
                return sig[:, idx].astype(np.float32).reshape(-1)

    # Strict DBs: skip the record rather than use an arbitrary (non-ECG) channel
    if db_name in DBS_STRICT:
        print(
            f"[INFO] {db_name}: no preferred ECG lead found, skipping this record. "
            f"sig_names={sig_names}"
        )
        return None

    # Fallback: use first channel, but warn
    print(f"[WARN] {db_name}: using channel 0 as fallback, sig_names={sig_names}")
    return sig[:, 0].astype(np.float32).reshape(-1)

### 1.4 Core preprocessing functions

We implement three core utilities:

1. `resample_signal`: resample a 1D signal from its original sampling rate to `TARGET_FS`.
2. `segment_windows`: cut a long 1D signal into fixed-length windows.
3. `zscore_windows`: apply per-window z-score normalization.
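The resampling step relies on `scipy.signal.resample_poly`, which divides `up` and `down` by their greatest common divisor before filtering and produces `ceil(n * up / down)` output samples. The hypothetical helpers below only reproduce that arithmetic (they do not call SciPy), so the factors for each database can be checked by hand.

```python
import math


def poly_factors(fs_orig: int, fs_target: int) -> tuple[int, int]:
    """(up, down) after GCD reduction, as resample_poly computes internally."""
    g = math.gcd(fs_target, fs_orig)
    return fs_target // g, fs_orig // g


def resampled_length(n_samples: int, fs_orig: int, fs_target: int) -> int:
    """Number of output samples: ceil(n_samples * up / down)."""
    up, down = poly_factors(fs_orig, fs_target)
    return -(-n_samples * up // down)  # integer ceiling division


print(poly_factors(360, 500), resampled_length(3600, 360, 500))  # (25, 18) 5000
print(poly_factors(257, 500))  # (500, 257)
print(poly_factors(128, 500), resampled_length(1280, 128, 500))  # (125, 32) 5000
```

For example, MIT-BIH at 360 Hz becomes up=25, down=18, and a 10 s chunk maps to exactly 5,000 samples. INCART at 257 Hz is coprime with 500, so the factors stay 500/257; SciPy's default anti-aliasing filter length grows with `max(up, down)`, so that database is noticeably slower to resample.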

In [6]:
def resample_signal(sig: np.ndarray, fs_orig: int, fs_target: int) -> np.ndarray:
    """
    Resample a 1D signal from fs_orig to fs_target using polyphase resampling.
    """
    if fs_orig == fs_target:
        return sig.astype(np.float32)

    # For simplicity, directly use fs_target, fs_orig as up/down.
    up = fs_target
    down = fs_orig
    sig_res = resample_poly(sig, up, down)
    return sig_res.astype(np.float32)


def segment_windows(
    sig: np.ndarray, fs: int, window_sec: int, stride_sec: int
) -> np.ndarray:
    """
    Segment a long 1D signal into fixed-length windows.

    Args:
        sig: 1D array of shape [T]
        fs: sampling rate
        window_sec: window length in seconds
        stride_sec: stride in seconds

    Returns:
        windows: array of shape [N, window_len]
    """
    window_len = fs * window_sec
    stride = fs * stride_sec

    if len(sig) < window_len:
        return np.zeros((0, window_len), dtype=np.float32)

    windows = []
    start = 0
    while start + window_len <= len(sig):
        win = sig[start : start + window_len]
        windows.append(win)
        start += stride

    if not windows:
        return np.zeros((0, window_len), dtype=np.float32)

    return np.stack(windows, axis=0)


def zscore_windows(windows: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """
    Apply per-window z-score normalization.

    Args:
        windows: array of shape [N, L]

    Returns:
        normalized windows, same shape, float32
    """
    mean = windows.mean(axis=1, keepdims=True)
    std = windows.std(axis=1, keepdims=True)
    std = np.where(std < eps, 1.0, std)
    windows_norm = (windows - mean) / std
    return windows_norm.astype(np.float32)
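Because the stride is a whole number of samples, `segment_windows` can also be written without a Python loop using NumPy's `sliding_window_view` (NumPy 1.20+). The sketch below, under the hypothetical name `segment_windows_vectorized`, should return the same windows as the loop version: it takes a strided view of every possible window and copies only the ones that are kept.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def segment_windows_vectorized(
    sig: np.ndarray, fs: int, window_sec: int, stride_sec: int
) -> np.ndarray:
    """Loop-free variant of segment_windows with the same start positions."""
    window_len = fs * window_sec
    stride = fs * stride_sec
    if len(sig) < window_len:
        return np.zeros((0, window_len), dtype=np.float32)
    # View of all windows at step 1 (no copy), then keep every `stride`-th start;
    # the last kept start is <= len(sig) - window_len, as in the while loop.
    views = sliding_window_view(sig, window_len)[::stride]
    return np.ascontiguousarray(views, dtype=np.float32)


sig = np.arange(25, dtype=np.float32)
wins = segment_windows_vectorized(sig, fs=2, window_sec=5, stride_sec=3)
# Reference: the start indices the while-loop version would use
starts = list(range(0, len(sig) - 10 + 1, 6))
print(wins.shape, starts)  # (3, 10) [0, 6, 12]
```

With the current non-overlapping setting (`WINDOW_STRIDE_SEC == WINDOW_SEC`) a plain reshape of the truncated signal would also work, but the view-based version keeps overlap available.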

### 1.5 Sharded saving for a single database

For each database:

1. Iterate over all `.hea` files.
2. For each record:
   - Read the record (`wfdb.rdrecord`).
   - Select a single lead (using `pick_single_lead`).
   - Resample to `TARGET_FS`.
   - Segment into windows (`WINDOW_SEC`, `WINDOW_STRIDE_SEC`).
   - Apply per-window z-score.
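3. Accumulate windows in memory until the buffer holds at least `MAX_WINDOWS_PER_SHARD` windows (a shard may slightly exceed this, since all windows of a record are appended at once), then: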
   - Save them as a shard npy file.
   - Reset the buffer.
4. At the end, save any remaining windows as the last shard.

Shard filename format:
`{db_name}_singlelead_{TARGET_FS}hz_{WINDOW_SEC}s_shardXXX.npy`
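It is worth checking what the default shard limit means in bytes. A float32 window at `TARGET_FS = 500` and `WINDOW_SEC = 10` holds 5,000 samples (20,000 bytes), so a full shard of 100,000 windows is about 2 GB on disk; when the buffer is concatenated for saving, the buffer and the concatenated copy coexist, so peak RAM is at least twice that. The hypothetical helper below does the arithmetic.

```python
def shard_bytes(n_windows: int, fs: int, window_sec: int, bytes_per_sample: int = 4) -> int:
    """Payload size of one shard of float32 windows (ignores the small .npy header)."""
    return n_windows * fs * window_sec * bytes_per_sample


size = shard_bytes(100_000, fs=500, window_sec=10)
print(size, f"{size / 2**30:.2f} GiB")  # 2000000000 1.86 GiB

# Windows contributed by one 24 h single-lead recording with 10 s, non-overlapping windows
print(24 * 3600 // 10)  # 8640
```

If that is too much for the available memory, lowering `MAX_WINDOWS_PER_SHARD` (e.g., to 20,000, about 400 MB per shard) is the simplest lever.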

In [7]:
def save_shard(windows_list, db_name: str, shard_id: int, out_dir: Path):
    """
    Save accumulated windows as a shard npy file.
    """
    if not windows_list:
        return None

    data = np.concatenate(windows_list, axis=0)
    shard_name = (
        f"{db_name}_singlelead_{TARGET_FS}hz_{WINDOW_SEC}s_shard{shard_id:03d}.npy"
    )
    out_path = out_dir / shard_name
    np.save(out_path, data.astype(np.float32, copy=False))  # no extra copy if already float32
    print(f"  Saved shard {shard_id} -> {out_path}, windows={data.shape[0]}")
    return out_path


def process_db_to_shards(db_name: str, db_dir: Path, out_dir: Path):
    """
    Process a single PhysioNet-style database into sharded pretraining npy files.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    hea_files = sorted(db_dir.glob("**/*.hea"))
    if not hea_files:
        print(f"[{db_name}] No .hea files found at {db_dir}")
        return

    print(f"\n=== Processing DB: {db_name} ===")
    print("Total .hea files:", len(hea_files))

    shard_id = 0
    current_windows = []
    current_count = 0

    for hea in tqdm(hea_files, desc=f"{db_name}"):
        rec_name = str(hea.with_suffix(""))
        try:
            rec = wfdb.rdrecord(rec_name)
        except Exception as e:
            print(f"  [WARN] Failed to read {rec_name}: {e}")
            continue

        fs_orig = int(rec.fs)
        sig_1d = pick_single_lead(db_name, rec)

        # Skip records without valid ECG lead
        if sig_1d is None:
            continue

        sig_res = resample_signal(sig_1d, fs_orig, TARGET_FS)

        wins = segment_windows(
            sig_res,
            fs=TARGET_FS,
            window_sec=WINDOW_SEC,
            stride_sec=WINDOW_STRIDE_SEC,
        )

        if wins.size == 0:
            continue

        wins = zscore_windows(wins)
        current_windows.append(wins)
        current_count += wins.shape[0]

        if current_count >= MAX_WINDOWS_PER_SHARD:
            save_shard(current_windows, db_name, shard_id, out_dir)
            shard_id += 1
            current_windows = []
            current_count = 0

    if current_windows:
        save_shard(current_windows, db_name, shard_id, out_dir)

    print(f"Finished DB: {db_name}")
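One case the pipeline above does not guard against is missing data: wfdb returns invalid samples as NaN in `p_signal`, and since `resample_poly` is a FIR filter, a single NaN spreads to neighboring output samples and then turns the whole window into NaN after `zscore_windows`. A minimal guard, sketched below with the hypothetical name `drop_nonfinite_windows`, is to drop any window that is not entirely finite; calling it right before `zscore_windows(wins)` in `process_db_to_shards` is one possible placement.

```python
import numpy as np


def drop_nonfinite_windows(windows: np.ndarray) -> tuple[np.ndarray, int]:
    """Keep only windows whose samples are all finite; also return the drop count."""
    keep = np.isfinite(windows).all(axis=1)
    return windows[keep], int((~keep).sum())


wins = np.array([[0.1, 0.2, 0.3], [0.4, np.nan, 0.6], [0.7, 0.8, np.inf]], dtype=np.float32)
clean, n_dropped = drop_nonfinite_windows(wins)
print(clean.shape, n_dropped)  # (1, 3) 2
```

If every window is dropped, the result has shape `[0, window_len]`, so the existing `wins.size == 0` check still skips the record.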

### 1.6 Run preprocessing for all ready databases

Now we call `process_db_to_shards` for each database in `DBS`.

You can:
- First test on a single DB (e.g., "mitdb") by commenting out the others.
- After confirming everything is correct, run all.

In [8]:
# Quick test on a single DB first (e.g., mitdb)
# process_db_to_shards("mitdb", DBS["mitdb"], OUT_ROOT)

In [9]:
# for db_name, db_dir in DBS.items():
#     process_db_to_shards(db_name, db_dir, OUT_ROOT)

In [10]:
# for db_name, db_dir in DBS_TODO.items():
#     process_db_to_shards(db_name, db_dir, OUT_ROOT)
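After a run, a quick sanity check on the written shards can catch problems before pretraining. The sketch below (the helper name `check_shard` and the 1,000-window sample size are assumptions) memory-maps a shard with `np.load(..., mmap_mode="r")` so it is not read fully into RAM, and reports shape, dtype, finiteness, and whether the sampled windows look z-scored. The demo writes a small throwaway shard; in practice the path would be one of the files in `OUT_ROOT`.

```python
import tempfile
from pathlib import Path

import numpy as np


def check_shard(path: Path, window_len: int, n_check: int = 1000) -> dict:
    """Memory-map a shard and summarize the first n_check windows."""
    arr = np.load(path, mmap_mode="r")  # lazy: pages are read on access
    if arr.ndim != 2 or arr.shape[1] != window_len:
        raise ValueError(f"unexpected shape {arr.shape}")
    sample = np.asarray(arr[:n_check], dtype=np.float64)  # copy of a small slice
    means = sample.mean(axis=1)
    stds = sample.std(axis=1)
    return {
        "n_windows": arr.shape[0],
        "dtype": str(arr.dtype),
        "all_finite": bool(np.isfinite(sample).all()),
        "max_abs_mean": float(np.abs(means).max()),
        # z-scored windows have std ~1; flat windows come out as all zeros (std 0)
        "frac_unit_std": float(np.mean(np.isclose(stds, 1.0, atol=1e-3) | (stds == 0))),
    }


with tempfile.TemporaryDirectory() as tmp:
    rng = np.random.default_rng(0)
    w = rng.normal(size=(8, 5000))
    w = ((w - w.mean(axis=1, keepdims=True)) / w.std(axis=1, keepdims=True)).astype(np.float32)
    path = Path(tmp) / "demo_shard000.npy"
    np.save(path, w)
    stats = check_shard(path, window_len=5000)

print(stats)
```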

### 1.7 Merge all databases into global pretraining shards

We now:

1. List all per-DB shard files in `OUT_ROOT`.
2. Iterate over them in sorted order.
3. Load each shard, accumulate windows into a buffer.
4. When the buffer reaches `MAX_WINDOWS_PER_SHARD`, save a global shard:
   `all_singlelead_{TARGET_FS}hz_{WINDOW_SEC}s_shardXXX.npy`
5. Save any remaining windows at the end.

This step is optional:
- If your training pipeline can handle multiple per-DB shards directly, you do not need a global merge.
- If you prefer a single “global pretrain dataset”, global shards make it convenient.
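If the training pipeline reads the per-DB shards directly, all it needs is a mapping from a global window index to (shard, local index). A minimal sketch, assuming memory-mapped `.npy` shards and a hypothetical `ShardedWindows` class (not part of any library), is below; it could back a PyTorch-style `Dataset`, but that wiring is not shown.

```python
import bisect
import tempfile
from pathlib import Path

import numpy as np


class ShardedWindows:
    """Read-only random access over several [N_i, L] .npy shards (hypothetical helper)."""

    def __init__(self, paths):
        self.arrays = [np.load(p, mmap_mode="r") for p in sorted(paths)]
        # offsets[i] = global index of the first window in shard i; offsets[-1] = total
        self.offsets = [0]
        for arr in self.arrays:
            self.offsets.append(self.offsets[-1] + arr.shape[0])

    def __len__(self) -> int:
        return self.offsets[-1]

    def __getitem__(self, idx: int) -> np.ndarray:
        if idx < 0:
            idx += len(self)
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        # bisect_right skips empty shards whose offset equals the next one
        shard = bisect.bisect_right(self.offsets, idx) - 1
        local = idx - self.offsets[shard]
        return np.array(self.arrays[shard][local], dtype=np.float32)  # detached copy


with tempfile.TemporaryDirectory() as tmp:
    for i, n in enumerate([3, 5]):
        np.save(Path(tmp) / f"demo_shard{i:03d}.npy", np.full((n, 4), i, dtype=np.float32))
    ds = ShardedWindows(Path(tmp).glob("demo_shard*.npy"))
    n_total = len(ds)
    x = ds[4]  # second window of shard 1
    del ds  # release the memory maps before the temp dir is removed

print(n_total, x)  # 8 [1. 1. 1. 1.]
```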

In [11]:
# def build_global_pretrain_shards(
#     out_root: Path, max_windows_per_shard: int = MAX_WINDOWS_PER_SHARD
# ):
#     """
#     Merge all per-DB shards into global pretraining shards.
#     """
#     shard_files = sorted(
#         p
#         for p in out_root.glob(f"*singlelead_{TARGET_FS}hz_{WINDOW_SEC}s_shard*.npy")
#         if not p.name.startswith("all_")  # skip global shards from a previous run
#     )
#     if not shard_files:
#         print("No per-DB shard files found in", out_root)
#         return

#     print("Found per-DB shard files:", len(shard_files))

#     global_shard_id = 0
#     buffer = []
#     buffer_count = 0

#     for path in shard_files:
#         print("Loading", path)
#         arr = np.load(path)

#         buffer.append(arr)
#         buffer_count += arr.shape[0]

#         if buffer_count >= max_windows_per_shard:
#             data = np.concatenate(buffer, axis=0)
#             fname = f"all_singlelead_{TARGET_FS}hz_{WINDOW_SEC}s_shard{global_shard_id:03d}.npy"
#             out_path = out_root / fname
#             np.save(out_path, data.astype(np.float32))
#             print(
#                 f"  Saved global shard {global_shard_id} -> {out_path}, windows={data.shape[0]}"
#             )

#             global_shard_id += 1
#             buffer = []
#             buffer_count = 0

#     # Remaining windows
#     if buffer:
#         data = np.concatenate(buffer, axis=0)
#         fname = (
#             f"all_singlelead_{TARGET_FS}hz_{WINDOW_SEC}s_shard{global_shard_id:03d}.npy"
#         )
#         out_path = out_root / fname
#         np.save(out_path, data.astype(np.float32))
#         print(
#             f"  Saved final global shard {global_shard_id} -> {out_path}, windows={data.shape[0]}"
#         )

#     print("Done building global pretraining shards.")


# build_global_pretrain_shards(OUT_ROOT)