<a href="https://colab.research.google.com/github/EML-Labs/Dataset/blob/main/Dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [65]:
!pip install wfdb



In [85]:
import os
import os
import wfdb
import numpy as np
import glob
import csv
import torch
import os
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader

In [69]:
from google.colab import drive

drive.mount('/content/drive',force_remount=True)

Mounted at /content/drive


In [70]:
def segment_and_label(record_path, record_name, label, window_sec=30, overlap=0.0, save_dir='segments'):
    try:
        record = wfdb.rdrecord(os.path.join(record_path, record_name))
        signal = record.p_signal
        fs = record.fs
        window_size = int(window_sec * fs)
        step_size = int(window_size * (1 - overlap))

        os.makedirs(save_dir, exist_ok=True)

        segment_id = 0
        for start in range(0, len(signal) - window_size + 1, step_size):
            segment = signal[start:start + window_size]
            fname = f"{record_name}_seg{segment_id:03d}.npy"
            np.save(os.path.join(save_dir, fname), segment)
            csv_rows.append([fname, label])
            segment_id += 1

        print(f"Saved {segment_id} segments for {record_name} with label {label}")
    except Exception as e:
        print(f"Failed to process {record_name}: {e}")

In [79]:
raw_data_dir =  "/content/drive/MyDrive/Datasets/1.0.0" # Adjust the path according to the drive
segment_save_dir = "/content/drive/MyDrive/Datasets/processed/1.0.0" # Adjust the path according to the drive
os.makedirs(segment_save_dir, exist_ok=True)


label_file = os.path.join(segment_save_dir, 'labels.csv')

csv_rows = [['filename', 'label']]

In [80]:
print(os.listdir(raw_data_dir))

['n28c.dat', 'p07c.dat', 'n36c.dat', 'p38.dat', 'p22.dat', 'n50.dat', 't41.dat', 't91.qrs', 'p19c.dat', 'n25.qrs', 'n42c.hea', 'p26c.dat', 'n32c.hea', 'p14c.hea', 'p45c.dat', 'n12c.hea', 'p30c.qrs', 'n05c.dat', 't83.hea', 'p19.qrs', 'n30.hea', 'n14.qrs', 't44.qrs', 'n36c.hea', 'p02c.hea', 'p23c.hea', 'n02.dat', 't81.hea', 'n17c.hea', 'p15c.hea', 't93.dat', 'n04c.hea', 'n04c.qrs', 't80.dat', 't07.hea', 'n38c.hea', 'p09.qrs', 't80.qrs', 't08.hea', 'n34c.hea', 'n44c.qrs', 'n17.hea', 'n06.dat', 'p33c.hea', 'p29c.qrs', 'n06c.qrs', 'n30c.qrs', 't72.qrs', 'p31.qrs', 'n18.hea', 't96.dat', 'p12.dat', 't77.qrs', 'p33.dat', 'n11c.qrs', 't27.hea', 'p28c.dat', 'n19.hea', 'n23.dat', 'p12.qrs', 't73.hea', 't45.qrs', 'p09c.hea', 'p22c.dat', 't39.hea', 't06.hea', 'p12.hea', 'p23.hea', 'n21.dat', 't29.hea', 'n13c.hea', 'p25.hea', 'n31.dat', 'p05.qrs', 'p31.hea', 'p47c.dat', 't90.qrs', 'n39c.dat', 'n29.dat', 'n46c.dat', 't55.hea', 't03.dat', 'p27c.qrs', 'p50c.hea', 'p32.qrs', 'p16c.dat', 'p32.dat', 'n16c

In [81]:
# 5. Get list of records starting with 'p', excluding those ending with 'c'

hea_files = sorted(glob.glob(os.path.join(raw_data_dir, 'p*.hea')))
record_names = [
    os.path.splitext(os.path.basename(f))[0]
    for f in hea_files
    if not os.path.basename(f).endswith('c')
]

print(f"Found {len(record_names)} records.")

Found 100 records.


In [82]:
# Process records
for record_name in record_names:
    try:
        rec_num = int(record_name[1:])  # Get number after 'p'
        label = 1 if rec_num % 2 == 0 else 0  # Even = pre-af (1), Odd = non-af (0)
        segment_and_label(record_path=raw_data_dir,
                          record_name=record_name,
                          label=label,
                          window_sec=10,
                          overlap=0.0,
                          save_dir=segment_save_dir)
    except ValueError:
        print(f"Skipping malformed record name: {record_name}")


Saved 180 segments for p01 with label 0
Skipping malformed record name: p01c
Saved 180 segments for p02 with label 1
Skipping malformed record name: p02c
Saved 180 segments for p03 with label 0
Skipping malformed record name: p03c
Saved 180 segments for p04 with label 1
Skipping malformed record name: p04c
Saved 180 segments for p05 with label 0
Skipping malformed record name: p05c
Saved 180 segments for p06 with label 1
Skipping malformed record name: p06c
Saved 180 segments for p07 with label 0
Skipping malformed record name: p07c
Saved 180 segments for p08 with label 1
Skipping malformed record name: p08c
Saved 180 segments for p09 with label 0
Skipping malformed record name: p09c
Saved 180 segments for p10 with label 1
Skipping malformed record name: p10c
Saved 180 segments for p11 with label 0
Skipping malformed record name: p11c
Saved 180 segments for p12 with label 1
Skipping malformed record name: p12c
Saved 180 segments for p13 with label 0
Skipping malformed record name: p13c

In [88]:
# Write labels.csv
with open(label_file, 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerows(csv_rows)

print(f"Saved label file to: {label_file}")


Saved label file to: /content/drive/MyDrive/Datasets/processed/1.0.0/labels.csv


In [94]:

class ECGSegmentDataset(Dataset):
    def __init__(self, segment_dir, transform=None):
        self.df_path = os.path.join(segment_dir, 'labels.csv')
        df = pd.read_csv(self.df_path)
        self.df = df.reset_index(drop=True)
        self.segment_dir = segment_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        file_path = os.path.join(self.segment_dir, row['filename'])

        ecg_segment = np.load(file_path).astype(np.float32)
        ecg_tensor = torch.from_numpy(ecg_segment).transpose(0, 1)  # [channels, samples]
        label = torch.tensor(row['label'], dtype=torch.long)

        if self.transform:
            ecg_tensor = self.transform(ecg_tensor)

        return ecg_tensor, label



In [95]:
segment_dir = '/content/drive/MyDrive/Datasets/processed/1.0.0'
label_csv = os.path.join(segment_dir, 'labels.csv')

# Load labels
df = pd.read_csv(label_csv)
print(df.describe())

# Stratified split
train_df, test_df = train_test_split(df, test_size=0.1, stratify=df['label'], random_state=42)

# Print class balance
print("Train label distribution:\n", train_df['label'].value_counts())
print("Test label distribution:\n", test_df['label'].value_counts())

             label
count  9000.000000
mean      0.500000
std       0.500028
min       0.000000
25%       0.000000
50%       0.500000
75%       1.000000
max       1.000000
Train label distribution:
 label
0    4050
1    4050
Name: count, dtype: int64
Test label distribution:
 label
0    450
1    450
Name: count, dtype: int64


In [107]:
test_dir = '/content/drive/MyDrive/Datasets/test'
train_dir = '/content/drive/MyDrive/Datasets/train'

os.makedirs(test_dir, exist_ok=True)
os.makedirs(train_dir, exist_ok=True)

test_df.to_csv(os.path.join(test_dir, 'label.csv'), index=False)
train_df.to_csv(os.path.join(train_dir, 'label.csv'), index=False)


In [103]:
for idx, row in test_df.iterrows():
    src_path = os.path.join(segment_dir, row['filename'])
    dst_path = os.path.join(test_dir, row['filename'])
    os.makedirs(os.path.dirname(dst_path), exist_ok=True)
    os.rename(src_path, dst_path)


/content/drive/MyDrive/Datasets/processed/1.0.0/p19_seg095.npy
/content/drive/MyDrive/Datasets/test/p19_seg095.npy
/content/drive/MyDrive/Datasets/processed/1.0.0/p03_seg155.npy
/content/drive/MyDrive/Datasets/test/p03_seg155.npy
/content/drive/MyDrive/Datasets/processed/1.0.0/p32_seg035.npy
/content/drive/MyDrive/Datasets/test/p32_seg035.npy
/content/drive/MyDrive/Datasets/processed/1.0.0/p47_seg164.npy
/content/drive/MyDrive/Datasets/test/p47_seg164.npy
/content/drive/MyDrive/Datasets/processed/1.0.0/p42_seg119.npy
/content/drive/MyDrive/Datasets/test/p42_seg119.npy
/content/drive/MyDrive/Datasets/processed/1.0.0/p36_seg053.npy
/content/drive/MyDrive/Datasets/test/p36_seg053.npy
/content/drive/MyDrive/Datasets/processed/1.0.0/p39_seg162.npy
/content/drive/MyDrive/Datasets/test/p39_seg162.npy
/content/drive/MyDrive/Datasets/processed/1.0.0/p36_seg095.npy
/content/drive/MyDrive/Datasets/test/p36_seg095.npy
/content/drive/MyDrive/Datasets/processed/1.0.0/p34_seg010.npy
/content/drive/My

In [102]:
for idx, row in train_df.iterrows():
    src_path = os.path.join(segment_dir, row['filename'])
    dst_path = os.path.join(train_dir, row['filename'])
    os.makedirs(os.path.dirname(dst_path), exist_ok=True)
    os.rename(src_path, dst_path)

In [96]:
segment_dir = '/content/drive/MyDrive/Datasets/processed/1.0.0'
train_dataset = ECGSegmentDataset(segment_dir)
test_dataset = ECGSegmentDataset(segment_dir)

In [None]:
dir = '/content/drive/MyDrive/Datasets/2.0'

os.makedirs(dir,exist_ok=True)


test_csv_path = os.path.join(dir, 'test.csv')
train_csv_path = os.path.join(dir, 'train.csv')

train_df.to_csv(train_csv_path, index=False)
test_df.to_csv(test_csv_path, index=False)




In [106]:
import os

test_dir = '/content/drive/MyDrive/Datasets/test'
print(os.listdir(test_dir))

['p01_seg004.npy', 'p01_seg016.npy', 'p01_seg034.npy', 'p01_seg036.npy', 'p01_seg064.npy', 'p01_seg075.npy', 'p01_seg098.npy', 'p01_seg114.npy', 'p01_seg116.npy', 'p01_seg119.npy', 'p01_seg127.npy', 'p01_seg130.npy', 'p01_seg146.npy', 'p01_seg154.npy', 'p01_seg160.npy', 'p01_seg161.npy', 'p02_seg002.npy', 'p02_seg018.npy', 'p02_seg043.npy', 'p02_seg052.npy', 'p02_seg053.npy', 'p02_seg067.npy', 'p02_seg071.npy', 'p02_seg089.npy', 'p02_seg109.npy', 'p02_seg113.npy', 'p02_seg138.npy', 'p02_seg142.npy', 'p02_seg155.npy', 'p02_seg160.npy', 'p02_seg165.npy', 'p03_seg009.npy', 'p03_seg010.npy', 'p03_seg017.npy', 'p03_seg020.npy', 'p03_seg022.npy', 'p03_seg026.npy', 'p03_seg045.npy', 'p03_seg061.npy', 'p03_seg073.npy', 'p03_seg082.npy', 'p03_seg084.npy', 'p03_seg121.npy', 'p03_seg155.npy', 'p03_seg157.npy', 'p04_seg006.npy', 'p04_seg011.npy', 'p04_seg013.npy', 'p04_seg023.npy', 'p04_seg025.npy', 'p04_seg026.npy', 'p04_seg039.npy', 'p04_seg045.npy', 'p04_seg046.npy', 'p04_seg069.npy', 'p04_seg0