In [1]:
import ast
import os

import numpy as np
import pandas as pd
import torch
import wfdb
from sklearn.preprocessing import LabelEncoder

In [None]:
# === CONFIG ===
BASE_PATH = "./ptb-xl/"
CSV_PATH = os.path.join(BASE_PATH, "ptbxl_database.csv")
SAVE_PATH = "./datasets/ptbxl_all_data.pt"
SEGMENT_LENGTH = 2500  # samples per record after downsampling: 10 seconds at 250 Hz
TARGET_LABELS = ['NORM', 'AFIB', 'PAC', 'PVC', 'SBRAD', 'STACH']
ARTIFACT_COLUMNS = ['baseline_drift', 'static_noise', 'burst_noise', 'electrodes_problems']

# === LOAD METADATA ===
df = pd.read_csv(CSV_PATH)
# .copy() so the column assignment below writes to a standalone frame rather than
# a filtered view of the original (avoids SettingWithCopyWarning).
df = df[df['ecg_id'].notna()].copy()
# scp_codes holds the text of a Python dict literal (code -> likelihood). Parse it with
# ast.literal_eval rather than eval so the CSV contents can never execute code.
df['scp_codes'] = df['scp_codes'].apply(ast.literal_eval)

# === Known SCP codes ===
# Every code listed in scp_statements.csv (no filtering on statement type is applied);
# select_top_label ignores codes that are not listed here.
scp_statements = pd.read_csv(os.path.join(BASE_PATH, 'scp_statements.csv'), index_col=0)
diagnostic_classes = set(scp_statements.index)  # set: only used for membership tests

# === Pick the single most likely target label for each ECG
def select_top_label(code_dict):
    """Return the highest-likelihood SCP code that is both known and a target label.

    Args:
        code_dict: mapping of SCP code -> likelihood, as parsed from ``scp_codes``.

    Returns:
        The winning code, or None if no code in ``code_dict`` is in both
        ``diagnostic_classes`` and ``TARGET_LABELS``. On ties, the code that
        appears first in ``code_dict`` wins.
    """
    best_code, best_score = None, None
    for code, score in code_dict.items():
        if not (code in diagnostic_classes and code in TARGET_LABELS):
            continue
        # Strictly greater: keep the earliest of equally likely codes.
        if best_code is None or score > best_score:
            best_code, best_score = code, score
    return best_code

# === Assign one target label per record and drop records without one
# .assign returns a new frame, so no assignment is made into a filtered view.
df = df.assign(label=df['scp_codes'].map(select_top_label))
df = df[df['label'].notna()]

# === Remove rows where artifacts in Lead II are noted
# NOTE(review): this only matches the literal token "II". Notes that cover Lead II
# implicitly (e.g. a lead range or a catch-all entry) are not caught; confirm against
# the annotation format used in ptbxl_database.csv.
for col in ARTIFACT_COLUMNS:
    df = df[~df[col].fillna('').str.contains(r'\bII\b', regex=True)]

# === Encode final labels
# fit() stores classes_ sorted (LabelEncoder's own invariant) instead of hand-assigning
# the unsorted TARGET_LABELS order. The index -> label mapping is therefore alphabetical:
# 0:'AFIB' 1:'NORM' 2:'PAC' 3:'PVC' 4:'SBRAD' 5:'STACH'.
le = LabelEncoder()
le.fit(TARGET_LABELS)
df = df.assign(label_idx=le.transform(df['label']))

# === Extract ECG Lead II and group by subject
subject_data = {}
print("Extracting and grouping ECGs by subject...")

for _, row in df.iterrows():
    patient_id = row['patient_id']
    file_path = os.path.join(BASE_PATH, row['filename_hr'])

    try:
        record = wfdb.rdrecord(file_path)
        # The [::2] decimation below is only correct for 500 Hz input.
        if record.fs != 500:
            raise ValueError(f"expected 500 Hz, got {record.fs} Hz")
        # Look Lead II up by name instead of assuming it is the second channel.
        lead_names = [name.upper() for name in record.sig_name]
        lead2 = record.p_signal[:, lead_names.index('II')]  # p_signal: (N_samples, n_leads)
        downsampled = lead2[::2]  # 500 Hz -> 250 Hz (plain decimation, no anti-alias filter)
        # Every stored segment must have the same length so torch.stack works per subject.
        if downsampled.shape[0] < SEGMENT_LENGTH:
            raise ValueError(
                f"only {downsampled.shape[0]} samples after downsampling, need {SEGMENT_LENGTH}"
            )
        segment = downsampled[:SEGMENT_LENGTH]

        # A flat or NaN-containing lead would make z-scoring produce NaN/inf.
        std = np.std(segment)
        if not np.isfinite(std) or std == 0:
            raise ValueError("flat or non-finite Lead II signal")
        normalized = (segment - np.mean(segment)) / std

        x_tensor = torch.tensor(normalized, dtype=torch.float32)
        y_tensor = torch.tensor(int(row['label_idx']), dtype=torch.long)

        entry = subject_data.setdefault(patient_id, {'x': [], 'y': []})
        entry['x'].append(x_tensor)
        entry['y'].append(y_tensor)

    except Exception as e:
        # Best effort: a bad or unreadable record is reported and skipped.
        print(f"Skipping {row['filename_hr']} (patient {patient_id}): {e}")

assert subject_data, "No records were extracted; check BASE_PATH and the filters above."

# === Convert lists to tensors
for tensors in subject_data.values():
    tensors['x'] = torch.stack(tensors['x'])  # (n_records, SEGMENT_LENGTH)
    tensors['y'] = torch.stack(tensors['y'])  # (n_records,)

# === Save dataset
# torch.save does not create missing directories.
os.makedirs(os.path.dirname(SAVE_PATH) or '.', exist_ok=True)
torch.save({
    'data_by_subject': subject_data,
    'label_encoder': le
}, SAVE_PATH)

print(f"Saved PTBXL data grouped by subject to: {SAVE_PATH}")

Extracting and grouping ECGs by subject...
Saved PTBXL data grouped by subject to: ./ptbxl_normal_data.pt


In [44]:
from torch.utils.data import Dataset
from typing import Union, List, Optional
from sklearn.model_selection import train_test_split

In [None]:
class PTBXL_Dataset(Dataset):
    """
    A PyTorch Dataset for ECG samples grouped by subject (PTB-XL).

    Each item is ``(x, y)``: ``x`` is one stored signal with a leading channel
    dimension added, shape ``(1, segment_length)``, and ``y`` is its stored label.

    Args:
        data_path (str): Path to the .pt file with data grouped by subject. The file must
            hold a dict with 'data_by_subject' (subject id -> {'x': ..., 'y': ...}) and
            'label_encoder'. It is fully unpickled, so only load files you trust.
        subject_ids (Union[int, float, List[float], None]): Subject(s) to include, either
            a single id or a list of ids. If None, uses all subjects.
        split (str): 'train', 'val', 'test', or None (keep every selected sample).
        test_ratio (float): Fraction of the selected samples reserved for the test split.
        val_ratio (float): Fraction of the remaining (train+val) samples reserved for
            validation, so the overall validation share is (1 - test_ratio) * val_ratio.
        random_seed (int): Random seed for reproducibility.

    Raises:
        AssertionError: if ``split`` is not one of the accepted values.
        KeyError: if a requested subject id is not present in the file.

    Note:
        Splits are drawn over the pooled samples of all selected subjects, not over
        subjects. When several subjects are selected, one subject's recordings can
        therefore land in more than one split.
    """
    def __init__(
        self,
        data_path: str,
        subject_ids: Optional[Union[int, float, List[float]]] = None,
        split: Optional[str] = None,
        test_ratio: float = 0.2,
        val_ratio: float = 0.2,
        random_seed: int = 42
    ):
        assert split in [None, 'train', 'val', 'test'], "split must be None, 'train', 'val', or 'test'"

        raw_data = self._load(data_path)
        all_subject_data = raw_data['data_by_subject']
        # Class index -> name is self.label_encoder.classes_[idx]. Read it from there
        # rather than hard-coding it, because the order depends on how the file was built.
        self.label_encoder = raw_data['label_encoder']

        # Collect all (x, y) pairs of the selected subjects, subject by subject.
        all_samples = []
        for sid in self._as_subject_list(subject_ids, all_subject_data):
            subject_data = all_subject_data[sid]
            all_samples.extend(zip(subject_data['x'], subject_data['y']))

        if split is None:
            # self.samples: list of (x, y) pairs served by __getitem__.
            self.samples = all_samples
            return

        labels = [int(y) for _, y in all_samples]

        # First split into temp (train+val) and test, then split temp into train and val.
        temp_idx, test_idx = self._split(
            list(range(len(all_samples))), labels, test_ratio, random_seed
        )
        train_idx, val_idx = self._split(
            temp_idx, [labels[i] for i in temp_idx], val_ratio, random_seed
        )

        indices = {'train': train_idx, 'val': val_idx, 'test': test_idx}[split]
        self.samples = [all_samples[i] for i in indices]

    @staticmethod
    def _load(data_path):
        """Load the saved dict, which contains a pickled LabelEncoder."""
        try:
            # torch>=2.6 defaults to weights_only=True, which rejects the LabelEncoder
            # (and numpy scalars) stored in the file, so full unpickling is requested.
            return torch.load(data_path, weights_only=False)
        except TypeError:
            # Older torch releases do not accept the weights_only argument.
            return torch.load(data_path)

    @staticmethod
    def _as_subject_list(subject_ids, all_subject_data):
        """Return subject_ids as a list. None selects every subject in the file."""
        if subject_ids is None:
            return list(all_subject_data.keys())
        # A single id may be int or float; int keys match float keys of equal value.
        if isinstance(subject_ids, (int, float, np.number)):
            return [subject_ids]
        return list(subject_ids)

    @staticmethod
    def _split(indices, labels, test_size, random_seed):
        """Split indices into (keep, held_out), stratifying by labels when possible."""
        try:
            return train_test_split(
                indices,
                test_size=test_size,
                random_state=random_seed,
                stratify=labels
            )
        except ValueError:
            # Stratification is impossible when a class has fewer than 2 samples, which
            # is common with one or a few subjects. Fall back to a plain random split.
            return train_test_split(indices, test_size=test_size, random_state=random_seed)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x, y = self.samples[idx]
        return x.unsqueeze(0), y  # Add channel dimension: (1, segment_length)
