# This notebook is for preprocessing PTBXL, CPSC2018, and CSN datasets for finetuning tasks.

In [2]:
import numpy as np
import pandas as pd
import wfdb
import os
import ast
from matplotlib import pyplot as plt
import seaborn as sns
from pprint import pprint
from tqdm import tqdm
from scipy.ndimage import zoom
from scipy.io import loadmat
from scipy import signal as scipy_signal
from sklearn.model_selection import train_test_split

In [3]:
# set the split file path to store your processed csv file
split_path = './data_split/'
# set the meta path for the raw ecg you download
meta_path = '../Dataset'

# Preprocessing PTB-XL dataset

In [None]:
'''
Since PTB-XL provide the offical split, we will use the offical split for the finetune dataset.
The offical preprocess code is shown in the orignal paper: https://www.nature.com/articles/s41597-020-0495-6
We also list the preprocessed csv file in MERL/finetune/data_split/ptbxl
'''

# Preprocessing CPSC2018 Dataset

In [None]:
'''
This dataset provide raw file in .mat format.
We first convert the .mat file to .hea and .dat file using the wfdb package.
Then we downsample the data to 100Hz and 500Hz.
All information of this dataset can be found in: http://2018.icbeb.org/Challenge.html
'''

# here is your original data folder, you should download the data from the website
ori_data_folder = os.path.join(meta_path, 'CPSC2018')

# here is the output folder to store the preprocessed data
output_folder = os.path.join(meta_path, 'icbeb2018')
output_datafolder_100 = output_folder+ '/records100/'
output_datafolder_500 = output_folder+ '/records500/'
if not os.path.exists(output_folder):
    os.makedirs(output_folder)
else:
    print('The folder already exists')
if not os.path.exists(output_datafolder_100):
    os.makedirs(output_datafolder_100)
else:
    print('The folder already exists')
if not os.path.exists(output_datafolder_500):
    os.makedirs(output_datafolder_500)
else:
    print('The folder already exists')

# function to store 12 leads ECG data as wfdb format
def store_as_wfdb(signame, data, sigfolder, fs):
    channel_itos=['I', 'II', 'III', 'AVR', 'AVL', 'AVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
    wfdb.wrsamp(signame,
                fs=fs,
                sig_name=channel_itos, 
                p_signal=data,
                units=['mV']*len(channel_itos),
                fmt = ['16']*len(channel_itos), 
                write_dir=sigfolder)  

# load the reference csv file
reference_path = os.path.join(output_folder, 'REFERENCE.csv')
df_reference = pd.read_csv(reference_path)

# define the label dictionary
# label_dict = {1:'NORM', 2:'AFIB', 3:'1AVB', 4:'CLBBB', 5:'CRBBB', 6:'PAC', 7:'VPC', 8:'STD_', 9:'STE_'}
label_dict = {1:'NORM', 2:'AFIB', 3:'1AVB', 4:'CLBBB', 5:'CRBBB', 6:'PAC', 7:'VPC', 8:'STD', 9:'STE'}

data = {'ecg_id':[], 'filename':[], 'validation':[], 'age':[], 'sex':[], 'scp_codes':[]}

# read all .mat files from the folder then convert to .hea and .dat files
ecg_counter = 0
# filename = 'A0001.mat'
# mat = loadmat(ori_data_folder + '/' +  filename)
# print(mat['val'])


filenames = os.listdir(ori_data_folder)
for filename in tqdm(filenames):
    if filename.split('.')[1] == 'mat':
        name = filename.split('.')[0]
        hea_path = ori_data_folder + '/' +  name + '.hea'
        with open(hea_path, 'r') as f:
            lines = f.readlines()
        age = -1
        sex = 'Unknown'
        for line in lines:
            if '#Age' in line:
                age = int(line.split(':')[1].strip()) if line.split(':')[1].strip().isdigit() else -1
            if '#Sex' in line:
                sex = line.split(':')[1].strip()
        if age < 0 or sex == 'Unknown':
            continue
        ecg_counter += 1
        sig = loadmat(ori_data_folder + '/' + filename)['val']
        sig = sig.astype(np.float32)
        data['ecg_id'].append(ecg_counter)
        data['filename'].append(name)
        data['validation'].append(False)
        data['age'].append(age)
        data['sex'].append(1 if sex == 'Male' else 0)
        labels = df_reference[df_reference.Recording == name][['First_label' ,'Second_label' ,'Third_label']].values.flatten()
        labels = labels[~np.isnan(labels)].astype(int)
        data['scp_codes'].append({label_dict[key]:1 for key in labels})

        # resample to 500 hz data
        store_as_wfdb(str(ecg_counter), sig.T, output_datafolder_500, 500)
        # resample to 100 hz data
        down_sig = np.array([zoom(channel, .2) for channel in sig])
        store_as_wfdb(str(ecg_counter), down_sig.T, output_datafolder_100, 100)

df = pd.DataFrame(data)
df['patient_id'] = df.ecg_id
# df = stratisfy_df(df, 'strat_fold')
# df.to_csv(output_folder+'icbeb_database.csv')

In [None]:
# make the patient_id column the first column
cols = list(df.columns)
cols = [cols[-1]] + cols[:-1]
switched_df = df[cols]

In [None]:
# Extract all unique labels from the 'scp_codes' column
# all_labels = set()
# for item in switched_df['scp_codes']:
#     all_labels.update(item.keys())

all_labels = ['AFIB', 'VPC', 'NORM', '1AVB', 'CRBBB', 'STE', 'PAC', 'CLBBB', 'STD']


# # Create new columns for each label
for label in all_labels:
    switched_df[label] = switched_df['scp_codes'].apply(lambda x: x.get(label, 0))

cols = list(switched_df.columns)
print(cols)
# cols[-1] = 'STD'
# cols[-4] = 'STE'
# # replace columns name
# switched_df.columns = cols


In [None]:
# split train test val
train_df, test_df = train_test_split(switched_df, test_size=0.2, random_state=42)
train_df, val_df = train_test_split(train_df, test_size=0.1, random_state=42)

print(f'train_df shape: {train_df.shape}')
print(f'val_df shape: {val_df.shape}')
print(f'test_df shape: {test_df.shape}')

# save the csv files
train_df.to_csv(split_path+'/icbeb/icbeb_train.csv', index=False)
val_df.to_csv(split_path+'/icbeb/icbeb_val.csv', index=False)
test_df.to_csv(split_path+'/icbeb/icbeb_test.csv', index=False)


# Preprocessing CSN Dataset

In [4]:
# 重新整理你的CSN数据处理代码
your_path = '../Dataset/'
data_path = f'{your_path}CSN/WFDBRecords'

# 初始化数据字典
df = {'ecg_path': [], 'age': [], 'diagnose': []}

# 读取参考文件
ref = pd.read_csv(f'{your_path}CSN/ConditionNames_SNOMED-CT.csv')
ref['Snomed_CT'] = ref['Snomed_CT'].astype(str)

def read_header_file(file_path):
    with open(file_path, 'r') as file:
        lines = file.readlines()
        header_info = [line.strip() for line in lines]
    return header_info

def extract_age_from_hea(hea_lines):
    """从.hea文件提取年龄"""
    for line in hea_lines:
        if '#Age' in line:
            try:
                age_str = line.split(':')[1].strip()
                if age_str.lower() not in ['nan', 'unknown', '', 'null']:
                    return int(float(age_str))
            except (ValueError, IndexError):
                continue
    return -1

def extract_diagnose_from_hea(hea_lines, ref_df):
    """从.hea文件提取诊断信息"""
    diagnoses = []
    for line in hea_lines:
        if '#Dx' in line:
            try:
                dx_codes = line.split(':')[1].strip()
                if dx_codes and dx_codes != 'Unknown':
                    codes = dx_codes.split(',')
                    for code in codes:
                        code = code.strip()
                        matched = ref_df[ref_df['Snomed_CT'] == code]
                        if not matched.empty:
                            diagnoses.append(matched['Acronym Name'].iloc[0])
            except IndexError:
                continue
    return ','.join(diagnoses) if diagnoses else 'Unknown'

# 处理数据文件夹
folders = os.listdir(data_path)
folders = sorted([os.path.join(data_path, f) for f in folders if os.path.isdir(os.path.join(data_path, f))])

successful_records = 0
failed_records = 0

for i, folder in enumerate(tqdm(folders)):
    subfolders = os.listdir(folder)
    subfolders = sorted([os.path.join(folder, f) for f in subfolders if os.path.isdir(os.path.join(folder, f))])
    
    for subfolder in subfolders:
        try:
            files = os.listdir(subfolder)
            mat_files = sorted([f for f in files if f.endswith('.mat')])
            hea_files = sorted([f for f in files if f.endswith('.hea')])
            
            # 确保mat和hea文件一一对应
            for mat_file in mat_files:
                base_name = mat_file.replace('.mat', '')
                hea_file = base_name + '.hea'
                
                if hea_file in hea_files:
                    mat_path = os.path.join(subfolder, mat_file)
                    hea_path = os.path.join(subfolder, hea_file)
                    
                    try:
                        # 读取ECG数据
                        mat = loadmat(mat_path)
                        ecg = mat['val']
                        
                        # 读取头文件
                        hea_lines = read_header_file(hea_path)
                        
                        # 提取年龄
                        age = extract_age_from_hea(hea_lines)
                        if age < 0 or age > 120:  # 年龄合理性检查
                            failed_records += 1
                            continue
                        
                        # 提取诊断
                        diagnose = extract_diagnose_from_hea(hea_lines, ref)
                        
                        # 添加到数据字典 - 确保同时添加所有字段
                        relative_path = os.path.relpath(mat_path, start=your_path)
                        df['ecg_path'].append(relative_path)
                        df['age'].append(age)
                        df['diagnose'].append(diagnose)
                        
                        successful_records += 1
                        
                    except Exception as e:
                        print(f"处理文件 {mat_path} 时出错: {e}")
                        failed_records += 1
                        continue
                        
        except Exception as e:
            print(f"处理文件夹 {subfolder} 时出错: {e}")
            failed_records += 1
            continue

print(f"\n处理完成:")
print(f"成功记录: {successful_records}")
print(f"失败记录: {failed_records}")

  0%|          | 0/46 [00:00<?, ?it/s]

100%|██████████| 46/46 [06:20<00:00,  8.26s/it]


处理完成:
成功记录: 45097
失败记录: 55





In [5]:


# 验证数据一致性后创建DataFrame
print(f"\n数据验证:")
for key, value in df.items():
    print(f"{key}: {len(value)} 条记录")

# 确保所有列长度一致
if len(set(len(v) for v in df.values())) == 1:
    new_df = pd.DataFrame(df)
    
    # 过滤掉诊断为Unknown的记录
    new_df = new_df[new_df['diagnose'] != 'Unknown']
    new_df.reset_index(inplace=True, drop=True)
    
    print(f"有效诊断记录: {len(new_df)}")
    
    # 创建多标签列
    if len(new_df) > 0:
        unique_labels = []
        for labels in new_df['diagnose']:
            if labels and labels != 'Unknown':
                labels = labels.split(',')
                unique_labels.extend([label.strip() for label in labels])
        
        unique_labels = list(set(unique_labels))
        print(f"发现 {len(unique_labels)} 种诊断类别")
        
        # 为每个标签创建二进制列
        for label in unique_labels:
            new_df[label] = new_df['diagnose'].apply(lambda x: 1 if label in x else 0)
        
        print(f"最终DataFrame形状: {new_df.shape}")
        print("标签分布:")
        for label in unique_labels:
            count = new_df[label].sum()
            print(f"  {label}: {count}")
            
else:
    print("错误: 数据字典中各列长度不一致，无法创建DataFrame")
    for key, value in df.items():
        print(f"  {key}: {len(value)}")


数据验证:
ecg_path: 45097 条记录
age: 45097 条记录
diagnose: 45097 条记录
有效诊断记录: 44422
发现 51 种诊断类别
最终DataFrame形状: (44422, 54)
标签分布:
  SA: 3552
  UW: 136
  RVH: 109
  STE: 800
  CCR: 162
  MI: 120
  ALS: 1543
  2AVB1: 31
  VB: 1551
  VFW: 115
  JPT: 11
  ARS: 847
  WAVN: 2
  TWC: 7028
  PRIE: 52
  SR: 8123
  LBBB: 240
  AVB: 1548
  3AVB: 76
  LVH: 642
  TWO: 2875
  AVRT: 26
  VET: 8
  1AVB: 1140
  SVT: 700
  ST: 9858
  ERV: 366
  AQW: 1062
  APB: 1312
  QTIE: 391
  CR: 238
  AT: 297
  VEB: 56
  JEB: 75
  RAH: 36
  VPE: 12
  IDC: 767
  FQRS: 3
  2AVB: 97
  RBBB: 649
  LVQRSAL: 1039
  ABI: 3
  WPW: 72
  AF: 9809
  STDD: 1668
  STTU: 176
  AFIB: 1780
  SB: 16559
  PWC: 142
  VPB: 294
  STTC: 1158


In [6]:
# count the number of sample for each label
label_count = {}
for label in unique_labels:
    label_count[label] = new_df[label].sum()
# sort the label_count dictionary
label_count = dict(sorted(label_count.items(), key=lambda item: item[1], reverse=True))
# drop the label with less than 10 samples
for key in list(label_count.keys()):
    if label_count[key] < 10:
        del label_count[key]
# drop the columns not in label_count
essential_cols = ['ecg_path', 'age', 'diagnose']
for key in list(new_df.columns):
    if key not in label_count.keys() and key not in essential_cols:
        new_df.drop(key, axis=1, inplace=True)

In [14]:

# split train test val

train_df, test_df = train_test_split(new_df, test_size=0.2)
train_df, val_df = train_test_split(train_df, test_size=0.1)
train_df.reset_index(inplace=True, drop=True)
val_df.reset_index(inplace=True, drop=True)
test_df.reset_index(inplace=True, drop=True)

print(f'train_df shape: {train_df.shape}')
print(f'val_df shape: {val_df.shape}')
print(f'test_df shape: {test_df.shape}')

# save the csv files
train_df.to_csv(f'{split_path}chapman/'+'chapman_train.csv', index=False)
val_df.to_csv(f'{split_path}chapman/'+'chapman_val.csv', index=False)
test_df.to_csv(f'{split_path}chapman/'+'chapman_test.csv', index=False)

train_df shape: (31983, 50)
val_df shape: (3554, 50)
test_df shape: (8885, 50)
