In [1]:
import numpy as np
import pandas as pd
import wfdb
import os
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
import seaborn as snss
from pprint import pprint
from tqdm import tqdm
from pathlib import Path
import sys

from torch_ecg._preprocessors import PreprocManager

  from .autonotebook import tqdm as notebook_tqdm


In [9]:
# Repository root: the parent of the notebook's working directory.
root_path = Path.cwd().parent

# Point this at your local copy of the MIMIC-IV-ECG matched-subset release.
meta_path = f'{root_path}/datasets/pretrain/mimic-iv-ecg-diagnostic-electrocardiogram-matched-subset-1.0'

# low_memory=False lets pandas infer each column's dtype from the whole file
# instead of chunk by chunk, which avoids mixed-type columns.
read_opts = {'low_memory': False}
report_csv = pd.read_csv(f'{meta_path}/machine_measurements.csv', **read_opts)
record_csv = pd.read_csv(f'{meta_path}/record_list.csv', **read_opts)

In [10]:
# Folder that receives one preprocessed ECG array (<study_id>.npy) per study.
npy_dir = (root_path / "datasets" / "pretrain" / "ecg_npy").resolve()
output_dir = str(npy_dir)
os.makedirs(output_dir, exist_ok=True)


In [11]:
# Signal preprocessing pipeline applied to every ECG record; the keys below are
# interpreted by torch_ecg's PreprocManager.
# NOTE(review): window1/window2 are presumably baseline-filter window widths in
# seconds — confirm against the torch_ecg documentation.
baseline_cfg = {"window1": 0.2, "window2": 0.6}
bandpass_cfg = dict(
    lowcut=0.5,   # presumably Hz
    highcut=45,   # presumably Hz
    filter_type="butter",
    filter_order=4,
)
preproc_config = {
    # NOTE(review): presumably makes every step apply deterministically — confirm.
    "random": False,
    "baseline_remove": baseline_cfg,
    "bandpass": bandpass_cfg,
}
ppm = PreprocManager.from_config(preproc_config)


In [12]:
def process_report(row):
    """Build one cleaned, lower-cased free-text report from a machine_measurements row.

    The non-null ``report_0`` ... ``report_17`` fields are joined with ``'. '``.
    ``EKG``/``ekg`` is rewritten to ``ECG``/``ecg``. Leading and trailing
    ``' '``/``'*'`` characters and then ``'='``/``'-'`` characters are trimmed,
    and the result is lower-cased.

    Args:
        row: pandas Series holding at least the ``report_0`` ... ``report_17`` fields.

    Returns:
        ``(word_count, text)``:

        - ``word_count`` is the number of whitespace-separated tokens in the
          cleaned report.
        - ``text`` is that report with newlines, carriage returns and tabs
          turned into spaces and a trailing ``'.'`` appended, or ``'empty'``
          when the report has no words.
    """
    fields = [f'report_{idx}' for idx in range(18)]
    text = '. '.join(row[fields].dropna())

    # NOTE: this runs before lower-casing, so mixed-case spellings such as
    # 'Ekg' are not rewritten (they end up as 'ekg').
    text = text.replace('EKG', 'ECG').replace('ekg', 'ecg')

    # str.strip treats its argument as a *set of characters*, so one call per
    # character set is enough: first spaces/asterisks, then '=' and '-'.
    text = text.strip(' *').strip('=-').lower()

    word_count = len(text.split())
    if word_count == 0:
        return word_count, 'empty'

    flattened = text
    for ws in ('\n', '\r', '\t'):
        flattened = flattened.replace(ws, ' ')
    return word_count, flattened + '.'

# Register DataFrame.progress_apply so the row-wise pass shows a progress bar.
tqdm.pandas()
per_row = report_csv.progress_apply(process_report, axis=1)
# Each row yields a (word_count, text) tuple; unzip them into two columns.
report_csv['report_length'], report_csv['total_report'] = zip(*per_row)

# Drop reports with fewer than 4 words.
report_csv = report_csv[report_csv['report_length'] >= 4]

# The reference run yields 771693 rows here.
print(report_csv.shape)

100%|██████████| 800035/800035 [02:47<00:00, 4781.77it/s]


(771693, 35)


In [13]:
# Renumber the filtered report table from 0.
report_csv = report_csv.reset_index(drop=True)

# Keep only waveform records whose study survived the report filter.
keep_mask = record_csv['study_id'].isin(report_csv['study_id'])
record_csv = record_csv[keep_mask].reset_index(drop=True)

In [14]:
# Preprocess every retained WFDB record and cache it as <output_dir>/<study_id>.npy.
# Records containing NaN/Inf, or that are constant (min == max, so min-max
# scaling would divide by zero), are skipped and counted in `skipped_count`.
N_SAMPLES_KEPT = 5000  # leading samples saved per record (10 s if fs is 500 Hz)
REL_NPY_DIR = "datasets/pretrain/ecg_npy"  # output_dir expressed relative to root_path

sid_to_rel_path = {}
skipped_count = 0

# zip over the two needed columns is much cheaper than building a Series per row.
for p, study_id in tqdm(zip(record_csv['path'], record_csv['study_id']), total=len(record_csv)):
    ecg_path = os.path.join(meta_path, p)
    raw_signal, fields = wfdb.rdsamp(ecg_path)
    # rdsamp returns (samples, channels); transpose to (channels, samples).
    record = raw_signal.T

    # Skip samples that contain NaN or Inf.
    if not np.isfinite(record).all():
        skipped_count += 1
        continue

    lo, hi = record.min(), record.max()
    # Skip constant samples: min-max normalization is undefined when min == max.
    if lo == hi:
        skipped_count += 1
        continue

    # Global (all-channel) min-max scaling to [0, 1].
    record = ((record - lo) / (hi - lo)).astype(np.float32)

    # Filter at the sampling rate stored in the WFDB header instead of
    # assuming 500 Hz.
    # NOTE(review): downstream code assumes N_SAMPLES_KEPT samples span a fixed
    # duration; a record with a different fs would silently break that.
    record, _ = ppm(record, fields['fs'])

    arr = record[:, :N_SAMPLES_KEPT]
    npy_path = os.path.join(output_dir, f"{study_id}.npy")
    np.save(npy_path, arr)

    sid_to_rel_path[study_id] = f"{REL_NPY_DIR}/{study_id}.npy"

print(f"\n总共跳过 {skipped_count} 个有问题的样本")
print(f"成功处理 {len(sid_to_rel_path)} 个样本")
print(f"成功处理 {len(sid_to_rel_path)} 个样本")

  0%|          | 0/771693 [00:00<?, ?it/s]

100%|██████████| 771693/771693 [19:10:18<00:00, 11.18it/s]   


总共跳过 11075 个有问题的样本
成功处理 760618 个样本





In [15]:
# Map each successfully preprocessed study to its cached .npy path.
path_df = pd.DataFrame({
    'study_id': list(sid_to_rel_path),
    'path': list(sid_to_rel_path.values()),
})

# Inner join keeps only studies that passed both the report filter and the
# signal checks.
merged_df = report_csv.merge(path_df, on='study_id', how='inner')

# Random 98/2 split with a fixed seed.
# NOTE(review): rows are split per study, so one patient may appear in both
# train and val. If val must be patient-disjoint, consider a grouped split
# (e.g. GroupShuffleSplit on subject_id, if that column is present).
train_df, val_df = train_test_split(merged_df, test_size=0.02, random_state=42)
train_df = train_df.reset_index(drop=True)
val_df = val_df.reset_index(drop=True)

csv_dir = str((root_path / "datasets/pretrain").resolve())
for file_name, split_df in (("train.csv", train_df), ("val.csv", val_df)):
    split_df.to_csv(os.path.join(csv_dir, file_name), index=False)
