In [18]:
import numpy as np
import pandas as pd
import os
from scipy.signal import find_peaks, butter, filtfilt
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.tree import DecisionTreeClassifier
from tqdm import tqdm
import tensorflow as tf
# noinspection PyUnresolvedReferences
from tensorflow.keras.models import Sequential
# noinspection PyUnresolvedReferences
from tensorflow.keras.layers import Dense, Dropout

# Constants
ecg_folder = "../../../../Datasets/12-lead electrocardiogram database/ECGData"
diagnostics_file = "../../../../Datasets/12-lead electrocardiogram database/Diagnostics.xlsx"

# Label mapping
rhythm_mapping = {
    'AFIB': 'AFIB',
    'AF': 'AFIB',
    'SVT': 'GSVT',
    'AT': 'GSVT',
    'SAAWR': 'GSVT',
    'ST': 'GSVT',
    'AVNRT': 'GSVT',
    'AVRT': 'GSVT',
    'SB': 'SB',
    'SR': 'SR',
    'SA': 'SR'
}

# Load diagnostics data
diagnostics_df = pd.read_excel(diagnostics_file)
diagnostics_df['Rhythm'] = diagnostics_df['Rhythm'].map(rhythm_mapping)

In [19]:
diagnostics_df

Unnamed: 0,FileName,Rhythm,Beat,PatientAge,Gender,VentricularRate,AtrialRate,QRSDuration,QTInterval,QTCorrected,RAxis,TAxis,QRSCount,QOnset,QOffset,TOffset
0,MUSE_20180113_171327_27000,AFIB,RBBB TWC,85,MALE,117,234,114,356,496,81,-27,19,208,265,386
1,MUSE_20180112_073319_29000,SB,TWC,59,FEMALE,52,52,92,432,401,76,42,8,215,261,431
2,MUSE_20180111_165520_97000,SR,NONE,20,FEMALE,67,67,82,382,403,88,20,11,224,265,415
3,MUSE_20180113_121940_44000,SB,NONE,66,MALE,53,53,96,456,427,34,3,9,219,267,447
4,MUSE_20180112_122850_57000,AFIB,STDD STTC,73,FEMALE,162,162,114,252,413,68,-40,26,228,285,354
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10641,MUSE_20181222_204306_99000,GSVT,NONE,80,FEMALE,196,73,168,284,513,258,244,32,177,261,319
10642,MUSE_20181222_204309_22000,GSVT,NONE,81,FEMALE,162,81,162,294,482,110,-75,27,173,254,320
10643,MUSE_20181222_204310_31000,GSVT,NONE,39,MALE,152,92,152,340,540,250,38,25,208,284,378
10644,MUSE_20181222_204312_58000,GSVT,NONE,76,MALE,175,178,128,310,529,98,-83,29,205,269,360


In [20]:
diagnostics_df = diagnostics_df.dropna(subset=['Rhythm'])  # Drop unmapped rows

In [21]:
# Define functions for preprocessing and feature extraction
def preprocess_signal(signal: np.ndarray, sampling_rate: int = 500) -> np.ndarray:
    """
    Preprocess ECG signal with filtering and normalization
    """
    nyquist = sampling_rate / 2
    low = 0.5 / nyquist
    high = 45 / nyquist
    b, a = butter(2, [low, high], btype='band')
    filtered = filtfilt(b, a, signal)
    normalized = (filtered - np.mean(filtered)) / np.std(filtered)
    return normalized


def detect_r_peaks(signal: np.ndarray, sampling_rate: int = 500) -> np.ndarray:
    """
    Detect R-peaks in the signal using find_peaks
    """
    peaks, _ = find_peaks(signal, distance=sampling_rate // 2, height=0.5)  # Adjust threshold as needed
    return peaks


def extract_features(signal: np.ndarray, sampling_rate: int = 500) -> dict:
    """
    Extract features from the ECG signal
    """
    r_peaks = detect_r_peaks(signal, sampling_rate)
    rr_intervals = np.diff(r_peaks) / sampling_rate  # Convert to seconds

    features = {}

    # Basic RR interval-based features
    features['ventricular_rate'] = 60 / np.mean(rr_intervals) if len(rr_intervals) > 0 else 0
    features['mean_rr_interval'] = np.mean(rr_intervals) if len(rr_intervals) > 0 else 0
    features['variance_rr_interval'] = np.var(rr_intervals) if len(rr_intervals) > 0 else 0
    features['qrs_count'] = len(r_peaks)
    features['rr_interval_count'] = len(rr_intervals)

    # QRS Duration
    qrs_durations = []
    for i, r_peak in enumerate(r_peaks):
        # Look for the Q and S points around the R peak
        left_idx = max(0, r_peak - int(0.1 * sampling_rate))  # 100 ms window before
        right_idx = min(len(signal), r_peak + int(0.1 * sampling_rate))  # 100 ms window after
        segment = signal[left_idx:right_idx]

        if len(segment) > 1:
            # Approximate QRS width as the duration of the segment above a threshold
            threshold = 0.5 * np.max(segment)  # 50% of the max amplitude
            significant_points = np.where(segment > threshold)[0]
            if len(significant_points) > 1:
                qrs_duration = (significant_points[-1] - significant_points[0]) / sampling_rate
                qrs_durations.append(qrs_duration)

    features['qrs_duration'] = np.mean(qrs_durations) if len(qrs_durations) > 0 else 0.1

    # QT Interval
    qt_intervals = []
    for i, r_peak in enumerate(r_peaks):
        # Approximate T wave as a prominent feature after the R peak
        left_idx = r_peak
        right_idx = min(len(signal), r_peak + int(0.4 * sampling_rate))  # Up to 400 ms after R peak
        segment = signal[left_idx:right_idx]

        if len(segment) > 1:
            # Find the max point (T peak) and use it to approximate QT interval
            t_peak_idx = np.argmax(segment)
            qt_interval = (t_peak_idx + left_idx - r_peak) / sampling_rate
            qt_intervals.append(qt_interval)

    features['qt_interval'] = np.mean(qt_intervals) if len(qt_intervals) > 0 else 0.35

    # R and T Axes (Placeholder, lead-specific calculations)
    features['r_axis'] = np.sum(signal[r_peaks])  # Sum of R peak amplitudes as a proxy
    features['t_axis'] = np.mean(signal[r_peaks])  # Mean T wave amplitude as a proxy

    return features


def load_and_extract_features(ecg_folder: str, diagnostics_df: pd.DataFrame, selected_leads: int = 1) -> pd.DataFrame:
    """
    Load ECG signals and extract features
    """
    feature_list = []
    labels = []

    for idx, row in tqdm(diagnostics_df.iterrows(), total=len(diagnostics_df), desc="Processing ECG files"):
        file_path = os.path.join(ecg_folder, f"{row['FileName']}.csv")
        if os.path.exists(file_path):
            try:
                signal = pd.read_csv(file_path).values[:, selected_leads - 1]  # Extract selected lead
                signal = preprocess_signal(signal)
                features = extract_features(signal)
                feature_list.append(features)
                labels.append(row['Rhythm'])  # Assuming 'Rhythm' column contains target labels
            except Exception as e:
                print(f"Error processing {file_path}: {e}")
                continue

    features_df = pd.DataFrame(feature_list)
    features_df['label'] = labels
    return features_df


In [22]:
# Load data and extract features
features_df = load_and_extract_features(ecg_folder, diagnostics_df)

features_df

Processing ECG files: 100%|██████████| 10646/10646 [00:53<00:00, 198.01it/s]


Unnamed: 0,ventricular_rate,mean_rr_interval,variance_rr_interval,qrs_count,rr_interval_count,qrs_duration,qt_interval,r_axis,t_axis,label
0,91.145833,0.658286,0.020302,15,14,0.068667,0.017333,32.343690,2.156246,AFIB
1,53.499777,1.121500,0.011444,9,8,0.027556,0.000000,56.863504,6.318167,SB
2,67.803575,0.884909,0.010797,12,11,0.085500,0.000000,30.735576,2.561298,SR
3,53.309640,1.125500,0.000257,9,8,0.022667,0.000000,57.205732,6.356192,SB
4,81.081081,0.740000,0.032294,14,13,0.022000,0.129857,49.035667,3.502548,AFIB
...,...,...,...,...,...,...,...,...,...,...
10641,82.122552,0.730615,0.027474,14,13,0.098857,0.050000,22.607509,1.614822,GSVT
10642,64.864865,0.925000,0.090102,11,10,0.137455,0.072364,14.568783,1.324435,GSVT
10643,85.089141,0.705143,0.009654,15,14,0.068133,0.120533,21.165194,1.411013,GSVT
10644,68.550062,0.875273,0.029908,12,11,0.079167,0.029333,18.627008,1.552251,GSVT


In [23]:
# Encode labels
le = LabelEncoder()
features_df['label'] = le.fit_transform(features_df['label'])
features_df

Unnamed: 0,ventricular_rate,mean_rr_interval,variance_rr_interval,qrs_count,rr_interval_count,qrs_duration,qt_interval,r_axis,t_axis,label
0,91.145833,0.658286,0.020302,15,14,0.068667,0.017333,32.343690,2.156246,0
1,53.499777,1.121500,0.011444,9,8,0.027556,0.000000,56.863504,6.318167,2
2,67.803575,0.884909,0.010797,12,11,0.085500,0.000000,30.735576,2.561298,3
3,53.309640,1.125500,0.000257,9,8,0.022667,0.000000,57.205732,6.356192,2
4,81.081081,0.740000,0.032294,14,13,0.022000,0.129857,49.035667,3.502548,0
...,...,...,...,...,...,...,...,...,...,...
10641,82.122552,0.730615,0.027474,14,13,0.098857,0.050000,22.607509,1.614822,1
10642,64.864865,0.925000,0.090102,11,10,0.137455,0.072364,14.568783,1.324435,1
10643,85.089141,0.705143,0.009654,15,14,0.068133,0.120533,21.165194,1.411013,1
10644,68.550062,0.875273,0.029908,12,11,0.079167,0.029333,18.627008,1.552251,1


In [24]:
# Prepare data
X = features_df.drop(columns=['label'])
y = features_df['label']

# Scale features
scaler = StandardScaler()
X = scaler.fit_transform(X)

# Train-test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)

(8516, 9) (2130, 9) (8516,) (2130,)


In [25]:
def create_mlp_model(input_dim, num_classes):
    mlp = Sequential([
        Dense(128, activation='relu', input_dim=input_dim),
        Dropout(0.3),
        Dense(64, activation='relu'),
        Dropout(0.3),
        Dense(num_classes, activation='softmax')
    ])

    mlp.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return mlp


num_classes = len(le.classes_)
mlp_model = create_mlp_model(X_train.shape[1], num_classes)

# Train MLP
mlp_model.fit(X_train, y_train, epochs=150, batch_size=128, validation_split=0.2, verbose=1)

# Evaluate MLP
mlp_loss, mlp_accuracy = mlp_model.evaluate(X_test, y_test, verbose=0)
mlp_y_pred = np.argmax(mlp_model.predict(X_test), axis=1)

print("\nTensorFlow MLP Classifier Results")
print(f"Accuracy: {mlp_accuracy:.4f}")
print(classification_report(y_test, mlp_y_pred, target_names=le.classes_, digits=5))


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/150
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 41ms/step - accuracy: 0.4684 - loss: 1.1426 - val_accuracy: 0.7394 - val_loss: 0.6925
Epoch 2/150
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step - accuracy: 0.7470 - loss: 0.6994 - val_accuracy: 0.7975 - val_loss: 0.5491
Epoch 3/150
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step - accuracy: 0.8008 - loss: 0.5673 - val_accuracy: 0.8110 - val_loss: 0.5059
Epoch 4/150
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2ms/step - accuracy: 0.8008 - loss: 0.5486 - val_accuracy: 0.8192 - val_loss: 0.4804
Epoch 5/150
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 984us/step - accuracy: 0.8212 - loss: 0.5028 - val_accuracy: 0.8357 - val_loss: 0.4590
Epoch 6/150
[1m54/54[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step - accuracy: 0.8310 - loss: 0.4790 - val_accuracy: 0.8333 - val_loss: 0.4517
Epoch 7/150
[1m54/54[0m [32m

In [26]:
dt = DecisionTreeClassifier(random_state=42)
dt.fit(X_train, y_train)
dt_y_pred = dt.predict(X_test)

# Evaluate Decision Tree
dt_accuracy = accuracy_score(y_test, dt_y_pred)
print("\nDecision Tree Classifier Results")
print(f"Accuracy: {dt_accuracy:.4f}")
print(f"Max Depth: {dt.get_depth()}")
print(f"Max Leaf Nodes: {dt.get_n_leaves()}")
print(classification_report(y_test, dt_y_pred, target_names=le.classes_, digits=5))


Decision Tree Classifier Results
Accuracy: 0.8178
Max Depth: 35
Max Leaf Nodes: 1050
              precision    recall  f1-score   support

        AFIB    0.63920   0.67689   0.65750       424
        GSVT    0.78788   0.75519   0.77119       482
          SB    0.92159   0.92278   0.92219       777
          SR    0.84807   0.83669   0.84234       447

    accuracy                        0.81784      2130
   macro avg    0.79919   0.79789   0.79830      2130
weighted avg    0.81969   0.81784   0.81857      2130

