In [4]:
import os
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from collections import Counter

# Defining file path
ecg_folder = "../../../Datasets/12-lead electrocardiogram database/ECGDataDenoised"
attributes_file = "../../../Datasets/12-lead electrocardiogram database/AttributesDictionary.xlsx"
condition_names_file = "../../../Datasets/12-lead electrocardiogram database/ConditionNames.xlsx"
diagnostics_file = "../../../Datasets/12-lead electrocardiogram database/Diagnostics.xlsx"
rhythm_names_file = "../../../Datasets/12-lead electrocardiogram database/RhythmNames.xlsx"

# Loading metadata files
attributes_df = pd.read_excel(attributes_file)
condition_names_df = pd.read_excel(condition_names_file)
condition_names_df["Acronym Name"] = condition_names_df["Acronym Name"].str.strip()
diagnostics_df = pd.read_excel(diagnostics_file)
diagnostics_df["Rhythm"] = diagnostics_df["Rhythm"].str.strip()
rhythm_names_df = pd.read_excel(rhythm_names_file)
# Getting of the extra space which causes issue
rhythm_names_df["Acronym Name"] = rhythm_names_df["Acronym Name"].str.strip()

# Mapping acronyms to full names for condition and rhythm
condition_names = condition_names_df.set_index("Acronym Name")["Full Name"].to_dict()
rhythm_names = rhythm_names_df.set_index("Acronym Name")["Full Name"].to_dict()
print(condition_names)
print(rhythm_names)


# Mapping diagnostics to add full names of conditions and rhythms
diagnostics_df["RhythmFullName"] = diagnostics_df["Rhythm"].map(rhythm_names)
diagnostics_df["ConditionFullName"] = diagnostics_df["Beat"].map(condition_names)
diagnostics_df["RhythmFullName"]

{'1AVB': '1 degree atrioventricular block', '2AVB': '2 degree atrioventricular block', '2AVB1': '2 degree atrioventricular block(Type one)', '2AVB2': '2 degree atrioventricular block(Type two)', '3AVB': '3 degree atrioventricular block', 'ABI': 'atrial bigeminy', 'ALS': 'Axis left shift', 'APB': 'atrial\xa0premature\xa0beats', 'AQW': 'abnormal Q wave', 'ARS': 'Axis right shift', 'AVB': 'atrioventricular block', 'CCR': 'countercolockwise rotation', 'CR': 'colockwise rotation', 'ERV': 'Early repolarization of the ventricles', 'FQRS': 'fQRS Wave', 'IDC': 'Interior differences conduction', 'IVB': 'Intraventricular block', 'JEB': 'junctional escape beat', 'JPS': 'J point shift', 'JPT': 'junctional premature beat', 'LBBB': 'left bundle branch block', 'LBBBB': 'left back bundle branch block', 'LFBBB': 'left front bundle branch block', 'LRRI': 'Long RR interval', 'LVH': 'left ventricle hypertrophy', 'LVHV': 'left ventricle high voltage', 'LVQRSAL': 'lower voltage QRS in all lead', 'LVQRSCL': '

0                 Atrial Fibrillation
1                   Sinus Bradycardia
2                                 NaN
3                   Sinus Bradycardia
4                      Atrial Flutter
                     ...             
10641    Supraventricular Tachycardia
10642    Supraventricular Tachycardia
10643    Supraventricular Tachycardia
10644    Supraventricular Tachycardia
10645    Supraventricular Tachycardia
Name: RhythmFullName, Length: 10646, dtype: object

In [6]:
# Function to Load ECG data
def load_ecg_data(ecg_folder, diagnostics_df):
    data = []
    labels = []
    metadata = []

    for file_name in os.listdir(ecg_folder):
        if file_name.endswith('.csv'):
            ecg_data = pd.read_csv(os.path.join(ecg_folder, file_name), header=None)
            ecg_data = ecg_data.to_numpy()  # Convert to numpy array

            # Retrieve associated label and metadata
            record_id = file_name.replace('.csv', '')
            record_info = diagnostics_df[diagnostics_df['FileName'] == record_id]

            # Skip if no label information is available
            if record_info.empty:
                continue

            # Use rhythm as primary label, fallback to beat if unavailable
            rhythm_label = record_info['RhythmFullName'].values[0]
            condition_label = record_info['ConditionFullName'].values[0]
            label = rhythm_label if pd.notna(rhythm_label) else condition_label

            # Additional metadata
            patient_age = record_info['PatientAge'].values[0]
            ventricular_rate = record_info['VentricularRate'].values[0]
            atrial_rate = record_info['AtrialRate'].values[0]

            # Append data, label, and metadata
            data.append(ecg_data)
            labels.append(label)
            metadata.append({
                'patient_age': patient_age,
                'ventricular_rate': ventricular_rate,
                'atrial_rate': atrial_rate
            })

    return data, np.array(labels), metadata

# Load ECG data and labels
ecg_data, ecg_labels, ecg_metadata = load_ecg_data(ecg_folder, diagnostics_df)

# Standardize each ECG signal individually
scaler = StandardScaler()
ecg_data = [scaler.fit_transform(sample) for sample in ecg_data]

# Define a fixed sequence length (pad or truncate to this length)
sequence_length = 5000  # Adjust as per dataset requirements

# Function to pad or truncate each ECG signal to a fixed length
def preprocess_sequence(data, length):
    processed_data = []
    for sample in data:
        if sample.shape[0] > length:
            processed_data.append(sample[:length])
        else:
            pad_width = length - sample.shape[0]
            processed_data.append(np.pad(sample, ((0, pad_width), (0, 0)), mode='constant'))
    return np.array(processed_data)

# Preprocess ECG data sequences
ecg_data = preprocess_sequence(ecg_data, sequence_length)

# Count instances per class
label_counts = Counter(ecg_labels)
print("Label counts before filtering:", label_counts)

# Filter out classes with fewer than 2 instances
filtered_indices = [i for i, label in enumerate(ecg_labels) if label_counts[label] > 1]
filtered_ecg_data = ecg_data[filtered_indices]
filtered_ecg_labels = ecg_labels[filtered_indices]

# Print new counts after filtering
filtered_label_counts = Counter(filtered_ecg_labels)
print("Label counts after filtering:", filtered_label_counts)

# Train-test split with filtered data
X_train, X_test, y_train, y_test = train_test_split(filtered_ecg_data, filtered_ecg_labels, test_size=0.2, random_state=42, stratify=filtered_ecg_labels)

# Verify shapes and data summary
print("Training data shape:", X_train.shape)
print("Testing data shape:", X_test.shape)
print("Training labels shape:", y_train.shape)
print("Testing labels shape:", y_test.shape)


Label counts before filtering: Counter({'Sinus Bradycardia': 3889, 'Sinus Rhythm': 1826, 'Atrial Fibrillation': 1780, 'Sinus Tachycardia': 1568, 'Supraventricular Tachycardia': 587, 'Atrial Flutter': 445, 'nan': 321, 'Atrial Tachycardia': 121, 'Atrioventricular  Node Reentrant Tachycardia': 16, 'T wave Change': 16, 'ventricular premature beat': 11, 'atrial\xa0premature\xa0beats': 10, 'left ventricle high voltage': 9, 'Atrioventricular Reentrant Tachycardia': 8, 'ST-T Change': 7, 'Sinus Atrium to Atrial Wandering Rhythm': 7, 'Axis right shift': 5, 'right bundle branch block': 5, 'ST tilt up': 4, '1 degree atrioventricular block': 3, 'lower voltage QRS in chest lead': 2, 'ventricular preexcitation': 1, 'Wandering in the atrioventricalualr node': 1, 'ST drop down': 1, 'WPW': 1, 'Axis left shift': 1, 'Intraventricular block': 1})
Label counts after filtering: Counter({'Sinus Bradycardia': 3889, 'Sinus Rhythm': 1826, 'Atrial Fibrillation': 1780, 'Sinus Tachycardia': 1568, 'Supraventricular 