In [1]:
import os
from torch.utils.data.dataloader import default_collate
import torch
from torch.utils.data import Dataset
import pandas as pd
import numpy as np
import wfdb

def scaling(X, sigma=0.1):
    """Rescale every column of ``X`` by its own random gain.

    One gain per column is drawn from a normal distribution centred on 1.0
    with standard deviation ``sigma``; all rows of a column share that gain.

    Args:
        X: 2-D array; ``X.shape[1]`` determines how many gains are drawn.
        sigma: standard deviation of the gain distribution.

    Returns:
        A new array holding ``X`` multiplied element-wise by the gains.
    """
    n_rows, n_cols = X.shape[0], X.shape[1]
    gains = np.random.normal(loc=1.0, scale=sigma, size=(1, n_cols))
    # Repeat the single row of gains once per row of X.
    gain_grid = np.repeat(gains, n_rows, axis=0)
    return X * gain_grid

def shift(sig, interval=20):
    """Add an independent random baseline offset to every column of ``sig``.

    For each column one integer is drawn uniformly from
    ``range(-interval, interval)`` and divided by 1000 before being added to
    the whole column.

    The input array is never modified: the offsets are applied to a copy, so
    callers that keep a reference to the raw signal still see the raw values.
    Integer input is promoted to ``float64`` so the fractional offset can be
    added instead of raising a casting error.

    Args:
        sig: 2-D array-like; one offset is drawn per column.
        interval: half-width of the integer offset range. Values ``<= 0``
            leave the signal unshifted.

    Returns:
        A new array with the same shape as ``sig``.
    """
    shifted = np.array(sig, copy=True)
    if not np.issubdtype(shifted.dtype, np.inexact):
        # In-place addition of a float to an integer array is rejected by numpy.
        shifted = shifted.astype(np.float64)

    if interval <= 0:
        # Empty candidate range: nothing to draw from, so no shift is applied.
        return shifted

    # Built once instead of once per column; the draws themselves are unchanged.
    candidates = np.array(range(-interval, interval))
    for col in range(shifted.shape[1]):
        shifted[:, col] += np.random.choice(candidates) / 1000
    return shifted


def transform(sig, train=False):
    """Optionally apply random augmentations to a signal.

    When ``train`` is true, ``scaling`` and ``shift`` are each applied
    independently with probability 0.5. When ``train`` is false the signal is
    returned untouched and no random numbers are drawn.

    Args:
        sig: signal array, passed on to ``scaling`` / ``shift``.
        train: enable the random augmentations.

    Returns:
        The signal, augmented or not.
    """
    if not train:
        return sig

    # np.random.rand() is uniform on [0, 1), so "> 0.5" is a fair coin flip.
    # (np.random.randn() is a standard normal and exceeds 0.5 only ~31% of the time.)
    if np.random.rand() > 0.5:
        sig = scaling(sig)
    if np.random.rand() > 0.5:
        sig = shift(sig)
    return sig


class ECGDataset(Dataset):
    """Map-style dataset yielding fixed-length ECG tensors with multi-hot labels.

    Each item is looked up in the label table by position, its waveform is
    read with ``wfdb.rdsamp``, optionally augmented (``phase == 'train'``),
    restricted to the requested leads and left-padded with zeros / truncated
    to the last ``SIGNAL_LEN`` samples.

    ``__getitem__`` returns ``None`` for a record that contains NaNs or that
    fails to load, so a DataLoader over this dataset needs a ``collate_fn``
    that tolerates ``None`` items.

    Args:
        phase: ``'train'`` enables random augmentation; any other value disables it.
        data_dir: accepted for interface compatibility; currently not used.
        label_csv: path to a CSV with ``patient_id``, ``fold`` and the columns
            listed in ``self.classes``.
        folds: iterable of fold ids; only rows whose ``fold`` is in it are kept.
        leads: ``'all'`` or a collection of lead names to keep.
    """

    # NOTE(review): machine-specific absolute path kept from the original code.
    # `data_dir` is ignored in favour of this prefix -- confirm whether it
    # should replace it before pointing the dataset at another location.
    RECORD_ROOT = "C:\\Users\\oladipea\\Downloads\\mimic-iv-ecg-diagnostic-electrocardiogram-matched-subset-1.0\\"
    # CSV with `study_id` and `file_name` columns, read relative to the working directory.
    RECORDS_CSV = "records_filtered.csv"
    # Number of time steps in every returned signal (10 s, 500 Hz).
    SIGNAL_LEN = 5000

    def __init__(self, phase, data_dir, label_csv, folds, leads):
        super().__init__()
        self.phase = phase
        df = pd.read_csv(label_csv)
        df = df[df['fold'].isin(folds)]
        self.data_dir = ""
        self.labels = df
        self.leads = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
        if isinstance(leads, str) and leads == 'all':
            self.use_leads = np.arange(len(self.leads))
        else:
            # Column indices of the requested leads, in canonical lead order.
            self.use_leads = np.flatnonzero(np.isin(self.leads, leads))
        self.nleads = len(self.use_leads)
        self.classes = ['label_Atrial Fibrillation', 'label_Myocardial Infarction', 'label_Ventricular Tachycardia']
        self.n_classes = len(self.classes)
        self.data_dict = {}
        self.label_dict = {}
        # Filled on first item access so the records CSV is parsed once, not per item.
        self._records = None

    def _load_records(self):
        """Return the study_id -> file_name table, reading it from disk only once."""
        records = getattr(self, '_records', None)
        if records is None:
            records = pd.read_csv(self.RECORDS_CSV)
            self._records = records
        return records

    def __getitem__(self, index: int):
        """Return ``(signal, labels)`` for the row at position ``index``.

        Returns:
            ``signal``: float tensor of shape ``(nleads, SIGNAL_LEN)``;
            ``labels``: float tensor with one entry per class.
            ``None`` if the record contains NaNs or could not be loaded.

        Raises:
            IndexError: if ``index`` is outside the label table.
        """
        # Looked up outside the try-block so an out-of-range index raises
        # IndexError instead of being reported as a data error.
        row = self.labels.iloc[index]
        try:
            patient_id = row['patient_id']
            records = self._load_records()
            file_names = records.loc[records['study_id'] == patient_id, 'file_name']
            ecg_data, _ = wfdb.rdsamp(self.RECORD_ROOT + str(file_names.values[0]))
            if np.isnan(ecg_data).any():
                print(f"Skipping NaN ECG Data for patient {patient_id}")
                return None

            ecg_data = transform(ecg_data, self.phase == 'train')
            # Keep the most recent SIGNAL_LEN samples of the selected leads ...
            ecg_data = ecg_data[-self.SIGNAL_LEN:, self.use_leads]
            # ... and right-align them in a zero buffer so short records are left-padded.
            result = np.zeros((self.SIGNAL_LEN, self.nleads))
            result[self.SIGNAL_LEN - ecg_data.shape[0]:, :] = ecg_data

            labels = self.label_dict.get(patient_id)
            if labels is None:
                labels = row[self.classes].to_numpy(dtype=np.float32)
                self.label_dict[patient_id] = labels

            return torch.from_numpy(result.transpose()).float(), torch.from_numpy(labels).float()
        except Exception as exc:
            # Best-effort loading: report the cause and skip the item.
            # KeyboardInterrupt / SystemExit are deliberately not caught.
            print(f"An error occurred while loading index {index}: {exc!r}")
            return None

    def __len__(self):
        """Number of rows left in the label table after fold filtering."""
        return len(self.labels)

In [None]:
# Load the record index with explicit column names, then keep only the first
# row seen for each study_id. The bare last line displays the resulting frame.
records = pd.read_csv(
    "records_final_new.csv",
    names=[
        'index', 'file_name', 'study_id', 'subject_id', 'ecg_time',
        'hosp_diag_hosp', 'gender', 'age', 'anchor_age', 'diagnostic_class',
    ],
).drop_duplicates(subset=['study_id'], keep='first')
records

In [4]:
# Class balance: number of records per diagnostic class.
records.value_counts(subset='diagnostic_class')

diagnostic_class
Atrial Fibrillation        5109
Myocardial Infarction      4964
Ventricular Tachycardia    1731
Name: count, dtype: int64

In [5]:
# Collect the study_id of every record whose waveform contains at least one NaN.
empty_signals = []

for idx, row in records.iterrows():
    recordpath = row['file_name']
    print("Reading: ", recordpath)
    patient_id = recordpath.split('/')[-1]
    signal, meta_data = wfdb.rdsamp("C:\\Users\\oladipea\\Downloads\\mimic-iv-ecg-diagnostic-electrocardiogram-matched-subset-1.0\\" + recordpath)
    if not np.isnan(signal).any():
        # Clean record: nothing to flag.
        continue
    print(f"Skipping NaN ECG Data for patient {patient_id}")
    empty_signals.append(row['study_id'])


Reading:  mimic-iv-ecg-diagnostic-electrocardiogram-matched-subset-1.0/files/p1000/p10000764/s47218930/47218930
Reading:  mimic-iv-ecg-diagnostic-electrocardiogram-matched-subset-1.0/files/p1000/p10000980/s48376834/48376834
Reading:  mimic-iv-ecg-diagnostic-electrocardiogram-matched-subset-1.0/files/p1000/p10000980/s46293961/46293961
Reading:  mimic-iv-ecg-diagnostic-electrocardiogram-matched-subset-1.0/files/p1000/p10000980/s49245181/49245181
Reading:  mimic-iv-ecg-diagnostic-electrocardiogram-matched-subset-1.0/files/p1000/p10002013/s48217041/48217041
Reading:  mimic-iv-ecg-diagnostic-electrocardiogram-matched-subset-1.0/files/p1000/p10002013/s46241559/46241559
Reading:  mimic-iv-ecg-diagnostic-electrocardiogram-matched-subset-1.0/files/p1000/p10003502/s47886467/47886467
Reading:  mimic-iv-ecg-diagnostic-electrocardiogram-matched-subset-1.0/files/p1000/p10007058/s47527771/47527771
Reading:  mimic-iv-ecg-diagnostic-electrocardiogram-matched-subset-1.0/files/p1000/p10007058/s47979034/4

In [6]:
# Drop every record whose study_id was flagged during the NaN scan.
records_filtered = records.loc[~records['study_id'].isin(empty_signals)]

In [7]:
# Persist the filtered record index (row index omitted from the file).
records_filtered.to_csv("records_filtered.csv", index=False)

In [None]:
import wfdb
import numpy as np
import pandas as pd

from sklearn.preprocessing import OneHotEncoder
from glob import glob
import argparse
import os


def gen_reference_csv(data_dir, reference_csv, records):
    """Write a CSV listing sample rate and signal length for every record.

    Each row of ``records`` is read with ``wfdb.rdsamp``; the last path
    component of its ``file_name`` is used as ``patient_id``. The table is
    sorted by ``patient_id`` before being written.

    Args:
        data_dir: not used by this function.
        reference_csv: destination path of the CSV.
        records: DataFrame with a ``file_name`` column.
    """
    rows = []
    for _, record in records.iterrows():
        record_path = record['file_name']
        record_id = record_path.split('/')[-1]
        # Only the metadata dict is needed; the samples are discarded.
        header = wfdb.rdsamp("C:\\Users\\oladipea\\Downloads\\mimic-iv-ecg-diagnostic-electrocardiogram-matched-subset-1.0\\" + record_path)[1]
        rows.append([record_id, header['fs'], header['sig_len']])

    reference = pd.DataFrame(data=rows, columns=['patient_id', 'sample_rate', 'signal_len'])
    reference = reference.sort_values('patient_id')
    reference.to_csv(reference_csv, index=None)
    print("Reference.csv was saved: ", reference_csv)
        
def gen_label_csv(reference_csv, dx_dict, classes, records, label_csv="label.csv"):
    """Write a one-hot label table with a random 10-fold assignment.

    One ``label_<diagnostic_class>`` indicator column is created per class
    present in ``records``; any column named in ``classes`` that has no
    records is added as all zeros so the class filter below cannot fail.
    Every row gets a fold id drawn from a shuffled, near-equal split into
    folds 1..10 (via ``np.random.permutation``; seed ``np.random`` for a
    reproducible split). Rows with none of ``classes`` set are dropped, and
    the rest is written sorted by ``patient_id``.

    Args:
        reference_csv: not used by this function.
        dx_dict: not used by this function.
        classes: names of the indicator columns a row must have at least one of.
        records: DataFrame with ``study_id`` and ``diagnostic_class`` columns.
        label_csv: destination path; defaults to the previously hard-coded
            ``"label.csv"``.
    """
    # Built from raw arrays so the frame gets a fresh 0..n-1 index regardless
    # of the index `records` carries (it may have gaps after filtering).
    results = pd.DataFrame({
        'patient_id': records['study_id'].to_numpy(),
        'label': records['diagnostic_class'].to_numpy(),
    })

    print(results)

    one_hot_results = pd.get_dummies(results['label'], prefix='label', dtype=int)
    for class_name in classes:
        if class_name not in one_hot_results.columns:
            # Requested class with no records: keep the column, all zeros.
            one_hot_results[class_name] = 0

    df = pd.concat([results['patient_id'], one_hot_results], axis=1)

    n = len(df)
    folds = np.zeros(n, dtype=np.int8)
    for i in range(10):
        start = int(n * i / 10)
        end = int(n * (i + 1) / 10)
        folds[start:end] = i + 1
    # Folds are assigned before filtering, so dropped rows still consumed a slot.
    df['fold'] = np.random.permutation(folds)
    columns = df.columns

    print(columns)
    has_target_class = df[list(classes)].sum(axis=1) > 0
    df = df.loc[has_target_class]
    df.sort_values("patient_id")[columns].to_csv(label_csv, index=None)
    print("Saved")

if __name__ == "__main__":
    # Driver: builds the reference and label CSVs from `records_filtered`,
    # which must already exist in the namespace (it is assigned at top level
    # earlier in this file).
    # NOTE(review): `leads` is assigned but not used anywhere in this block.
    leads = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
    classes = ['label_Atrial Fibrillation', 'label_Myocardial Infarction', 'label_Ventricular Tachycardia']
    parser = argparse.ArgumentParser()
    parser.add_argument('--data-dir', type=str, default='', help='Directory to dataset')
    # parse_known_args() collects unrecognised arguments in `unknown` instead of
    # exiting -- presumably so arguments passed to the notebook kernel are tolerated.
    args, unknown = parser.parse_known_args()
    data_dir = args.data_dir

    print("Data directory: ", data_dir)
    
    # NOTE(review): both paths are computed but never passed on -- the calls
    # below use the literal "reference.csv" and no label path at all, so
    # --data-dir does not change where the outputs are written.
    reference_csv = os.path.join(data_dir, 'reference.csv')
    label_csv = os.path.join(data_dir, 'labels.csv')

    print("Generating reference CSV")
    gen_reference_csv(data_dir, "reference.csv", records_filtered)

    print("Generating label CSV")
    # The empty dict fills the `dx_dict` parameter.
    gen_label_csv("reference.csv", {}, classes, records_filtered)
