In [1]:
import os
import torch
import numpy as np
import scipy.io as scio
from sklearn.preprocessing import LabelEncoder

In [None]:
def standardization(data):
    """
    Z-score normalize an array using its global mean and standard deviation.

    The statistics are computed over all elements of ``data`` (no axis), so a
    multi-dimensional input is normalized as one population.

    Args:
        data: Array-like of numbers.

    Returns:
        np.ndarray: ``(data - mean) / std``. When the standard deviation is
        exactly zero (constant input) the centered values (all zeros) are
        returned instead of dividing, which would otherwise produce NaNs.
    """
    data = np.asarray(data)
    mu = np.mean(data)
    sigma = np.std(data)
    centered = data - mu
    if sigma == 0:
        # Constant signal: 0/0 would yield NaN for every sample.
        return centered
    return centered / sigma

# --- Configuration -----------------------------------------------------------
SAVE_PATH = './datasets/mitbih_all_data.pt'
DATASET_PATH = './mit-bih'
CLASSES = ['NSR', 'AFIB', 'APB', 'PVC', 'SDHB']
TARGET_LENGTH = 2500  # every recording is resampled to this many points

# The encoder is stored alongside the data so consumers can map label
# indices back to class names. Its classes_ are assigned by hand in CLASSES
# order (not fitted), so position i in CLASSES is label index i.
le = LabelEncoder()
le.classes_ = np.array(CLASSES)

# subject_id -> {'x': [signal tensors], 'y': [label tensors]}
subject_data = {}

for root, dirs, files in os.walk(DATASET_PATH):
    # Sort in place so traversal (and therefore sample order, which the
    # seeded train/test split depends on) is the same on every machine.
    dirs.sort()
    for fname in sorted(files):
        if not fname.endswith('.mat'):
            continue

        try:
            # The class is read from the containing folder: the text before
            # the first space is parsed as a 1-based class number.
            # NOTE(review): this implies folder names such as "1 NSR";
            # confirm against the actual dataset layout.
            folder_name = os.path.basename(root)
            class_index = int(folder_name.split(' ')[0]) - 1
            if not 0 <= class_index < len(CLASSES):
                # Without this check a folder numbered 0 would silently wrap
                # around to the last class via negative indexing.
                raise ValueError(
                    f"class number in folder '{folder_name}' is outside 1..{len(CLASSES)}"
                )

            fpath = os.path.join(root, fname)
            mat = scio.loadmat(fpath)
            signal = mat.get('val')
            if signal is None:
                raise KeyError("no 'val' variable in .mat file")
            signal = signal[0]  # first row of the 'val' matrix
            if len(signal) == 0:
                raise ValueError("empty signal")

            # Resample to TARGET_LENGTH by picking evenly spaced sample indices
            indices = np.linspace(0, len(signal) - 1, TARGET_LENGTH, dtype=int)
            signal = signal[indices]
            signal = standardization(signal)
            if not np.all(np.isfinite(signal)):
                # e.g. a constant recording whose std is 0
                raise ValueError("non-finite values after standardization")

            # Subject ID is the first three characters of the filename
            subject_id = int(fname[:3])
            # Use the position in CLASSES directly as the label index; this
            # does not depend on how LabelEncoder.transform handles a
            # hand-assigned, unsorted classes_ array.
            label_idx = class_index

            if subject_id not in subject_data:
                subject_data[subject_id] = {'x': [], 'y': []}

            subject_data[subject_id]['x'].append(torch.tensor(signal, dtype=torch.float32))
            subject_data[subject_id]['y'].append(torch.tensor(label_idx))

        except Exception as e:
            # Best effort: one unreadable or malformed file must not abort the run.
            print(f"Skipping {fname}: {e}")

if not subject_data:
    # Fail before writing, so an existing good file is never replaced by an empty one.
    raise FileNotFoundError(f"No usable .mat recordings found under {DATASET_PATH}")

# Convert each subject's list of x/y to tensor stacks
for sid in subject_data:
    subject_data[sid]['x'] = torch.stack(subject_data[sid]['x'])
    subject_data[sid]['y'] = torch.stack(subject_data[sid]['y'])

# torch.save does not create missing parent directories.
save_dir = os.path.dirname(SAVE_PATH)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)

torch.save({
    'data_by_subject': subject_data,
    'label_encoder': le
}, SAVE_PATH)

print(f"Saved MIT-BIH data grouped by subject to: {SAVE_PATH}")


Saved MIT-BIH data grouped by subject to: ./mitbih_all_data.pt


In [5]:
# Sanity check: show the class names in the order stored on the encoder.
# `le` is created by the preprocessing cell above, which must have been run.
print(le.classes_)

['NSR' 'AFIB' 'APB' 'PVC' 'SDHB']


In [30]:
from torch.utils.data import Dataset
from typing import Union, List, Optional
from sklearn.model_selection import train_test_split

In [None]:
class MITBIH_Dataset(Dataset):
    """
    A PyTorch Dataset for ECG samples grouped by subject (MIT-BIH).

    The file at ``data_path`` is expected to hold a dict with the keys
    ``'data_by_subject'`` (``{subject_id: {'x': Tensor, 'y': Tensor}}``) and
    ``'label_encoder'``.

    Args:
        data_path (str): Path to the .pt file with data grouped by subject.
        normal (bool or None): If True, include only normal (label == 0);
            if False, only abnormal, with labels shifted down by one so they
            start from 0; if None, include all samples with original labels.
        subject_ids (scalar, iterable of scalars, or None): Subject(s) to
            include. A single id may be given as a bare scalar (int or
            float). If None, uses all subjects.
        split (str): 'train', 'test', or None — whether to return a subset.
        test_ratio (float): Proportion to reserve for the test split (passed
            as ``test_size``; the default of 0.7 puts 70% of samples in test).
        random_seed (int): Random seed for reproducibility of the split.

    Raises:
        AssertionError: If ``split`` is not None, 'train' or 'test'.
        KeyError: If a requested subject id is not present in the file.
        ValueError: If no samples remain after filtering.
    """
    def __init__(
        self,
        data_path: str,
        normal: Optional[bool] = None,
        subject_ids: Optional[Union[int, float, List[Union[int, float]]]] = None,
        split: Optional[str] = None,
        test_ratio: float = 0.7,
        random_seed: int = 42
    ):
        assert split in [None, 'train', 'test'], "split must be None, 'train', or 'test'"

        # The file contains a pickled LabelEncoder, which the restricted
        # (weights_only=True) unpickler refuses; that mode is the default
        # from PyTorch 2.6 on. Full unpickling can execute arbitrary code,
        # so only load files you produced yourself or otherwise trust.
        raw_data = torch.load(data_path, weights_only=False)
        all_subject_data = raw_data['data_by_subject']
        self.label_encoder = raw_data['label_encoder']

        # Normalize subject_ids to a list. Any scalar (Python or NumPy int or
        # float) is wrapped; previously only `float` was, so an int id fell
        # through and failed with "'int' object is not iterable".
        if subject_ids is None:
            selected_subjects = list(all_subject_data.keys())
        elif np.isscalar(subject_ids):
            selected_subjects = [subject_ids]
        else:
            selected_subjects = list(subject_ids)

        unknown = [sid for sid in selected_subjects if sid not in all_subject_data]
        if unknown:
            raise KeyError(f"Unknown subject id(s): {unknown}")

        # Collect and filter samples
        all_samples = []
        for sid in selected_subjects:
            subject_data = all_subject_data[sid]
            for x, y in zip(subject_data['x'], subject_data['y']):
                y_int = int(y.item())
                if normal is True and y_int != 0:
                    continue  # keep only label==0
                if normal is False:
                    if y_int == 0:
                        continue  # exclude label==0
                    y_int -= 1  # shift remaining labels so they start from 0

                all_samples.append((x, torch.tensor(y_int)))

        if len(all_samples) == 0:
            raise ValueError("No samples in the dataset. Check filtering conditions.")

        # Optional stratified split: indices are split rather than samples so
        # train and test are guaranteed to be disjoint views of all_samples.
        if split is not None:
            stratify_labels = [int(y) for _, y in all_samples]
            train_idx, test_idx = train_test_split(
                list(range(len(all_samples))),
                test_size=test_ratio,
                random_state=random_seed,
                stratify=stratify_labels
            )
            indices = train_idx if split == 'train' else test_idx
            self.samples = [all_samples[i] for i in indices]
        else:
            self.samples = all_samples

    def __len__(self):
        """Return the number of samples kept after filtering/splitting."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(signal, label)`` with a leading channel dimension on the signal."""
        x, y = self.samples[idx]
        return x.unsqueeze(0), y  # (length,) -> (1, length)
