
# Novel ECG Classification Model: ResNet-16 and Transformer Approach

# This notebook implements a novel approach to ECG classification using a combination of ResNet-16 for ECG feature extraction and a Transformer for R-R interval processing.

## 1. Import Libraries and Load Data

In [1]:
# Import necessary libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns  # For plotting the confusion matrix

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report  # For evaluation

from imblearn.combine import SMOTETomek
import pywt
from wfdb import processing
from collections import Counter

import tensorflow as tf
from tensorflow.keras import layers, models, regularizers
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import (
    ModelCheckpoint, TensorBoard, ReduceLROnPlateau, EarlyStopping
)
from datetime import datetime

from models.resnet import build_resnet




## 2. Load the Dataset

In [2]:
# Load the dataset
dataset_path = 'data/raw/physionet2017.csv'  # Update with your dataset path
data = pd.read_csv(dataset_path)

# Extract ECG signals (features) and labels
X = data.drop(['name', 'label'], axis=1)  # Drop 'name' and 'label' columns
y = data['label'].values  # Use 'label' as the labels

# Convert X to a NumPy array
X = X.values

# Remove Class-2 ("Other") entries
mask = y != 2
X = X[mask]
y = y[mask]

# Convert X to numeric, replacing non-numeric values with NaN
X = pd.DataFrame(X)
X = X.apply(pd.to_numeric, errors='coerce')

# Check for NaN values
print("Number of NaN values in X:", X.isna().sum().sum())

# Fill NaN values with the mean of each column
X = X.fillna(X.mean())

# Convert back to NumPy array
X = X.values

# Check the shapes and data types
print("Data shape after cleaning:", X.shape)
print("Labels shape after cleaning:", y.shape)
print("Class distribution:", Counter(y))
print("Data types of X after cleaning:", X.dtype)


Number of NaN values in X: 0
Data shape after cleaning: (6113, 2000)
Labels shape after cleaning: (6113,)
Class distribution: Counter({0: 5076, 1: 758, 3: 279})
Data types of X after cleaning: float64


## 3. Noise Removal and Signal Detrending

In [3]:
# Step 1: Noise Removal and Signal Detrending
from scipy.signal import butter, filtfilt

def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    nyquist = 0.5 * fs
    low = lowcut / nyquist
    high = highcut / nyquist
    b, a = butter(order, [low, high], btype='band')
    y = filtfilt(b, a, data)
    return y

fs = 300  # Sampling frequency in Hz
lowcut = 0.5
highcut = 45.0

# Apply the filter to each ECG signal
X_filtered = np.array([butter_bandpass_filter(ecg, lowcut, highcut, fs) for ecg in X])

print("Filtered data shape:", X_filtered.shape)


Filtered data shape: (6113, 2000)


## 4. R-Peak Detection and Calculation of R-R Intervals

In [4]:
# Step 2: R-Peak Detection and Calculation of R-R Intervals
def calculate_r_r_intervals(ecg_signal, fs=300):
    r_peaks = processing.gqrs_detect(sig=ecg_signal, fs=fs)
    r_r_intervals = np.diff(r_peaks) / fs  # Calculate intervals between R-peaks
    return r_r_intervals

# Calculate R-R intervals for each ECG signal
X_r_r_intervals = [calculate_r_r_intervals(ecg) for ecg in X_filtered]

# Example of R-R intervals
print("Example R-R intervals:", X_r_r_intervals[0])


Example R-R intervals: [0.74333333 0.71666667 0.70666667 0.73       0.76       0.76666667]


## 5. Feature Extraction Using Wavelet Transform

In [5]:
# Step 3: Feature Extraction
def extract_wavelet_features(segment, wavelet='db1', level=3):
    coeffs = pywt.wavedec(segment, wavelet, level=level)
    features = []
    for coeff in coeffs:
        features.extend([
            np.mean(coeff),
            np.std(coeff),
            np.max(coeff),
            np.min(coeff)
        ])
    return np.array(features)

# Extract features for each ECG signal
X_pqrst_features = [extract_wavelet_features(ecg) for ecg in X_filtered]
X_pqrst_features = np.array(X_pqrst_features)

print("Wavelet features shape:", X_pqrst_features.shape)


Wavelet features shape: (6113, 16)


## 6. Padding and Normalizing R-R Intervals

In [6]:
# Pad R-R intervals to a fixed length
max_r_r_length = max(len(intervals) for intervals in X_r_r_intervals)
X_r_r_padded = np.array([
    np.pad(intervals, (0, max_r_r_length - len(intervals)), 'constant')
    for intervals in X_r_r_intervals
])

print("Padded R-R intervals shape:", X_r_r_padded.shape)

# Normalize features
scaler_ecg = StandardScaler()
X_pqrst_normalized = scaler_ecg.fit_transform(X_pqrst_features)

scaler_rr = StandardScaler()
X_r_r_normalized = scaler_rr.fit_transform(X_r_r_padded)

# Combine normalized features
X_combined = np.hstack((X_pqrst_normalized, X_r_r_normalized))

print("Combined features shape:", X_combined.shape)


Padded R-R intervals shape: (6113, 17)
Combined features shape: (6113, 33)


## 7. Train/Test Split


In [7]:
# Split the data
X_train, X_temp, y_train, y_temp = train_test_split(
    X_combined, y, test_size=0.2, random_state=42
)
X_rr_train, X_rr_temp = train_test_split(
    X_r_r_padded, test_size=0.2, random_state=42
)

X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, random_state=42
)
X_rr_val, X_rr_test = train_test_split(
    X_rr_temp, test_size=0.5, random_state=42
)

print("Training set shape:", X_train.shape)
print("Validation set shape:", X_val.shape)
print("Test set shape:", X_test.shape)


Training set shape: (4890, 33)
Validation set shape: (611, 33)
Test set shape: (612, 33)


## 8. Class Balancing with SMOTETomek

In [8]:
# Step 5: Class Balancing with SMOTETomek
X_combined_train = np.hstack((X_train, X_rr_train))
smote_tomek = SMOTETomek(random_state=42)

X_combined_resampled, y_resampled = smote_tomek.fit_resample(X_combined_train, y_train)

# Split the resampled data back into ECG and R-R interval features
num_ecg_features = X_train.shape[1]
X_resampled = X_combined_resampled[:, :num_ecg_features]
X_rr_resampled = X_combined_resampled[:, num_ecg_features:]

print("Resampled data shape:", X_resampled.shape)
print("Resampled labels distribution:", Counter(y_resampled))


Resampled data shape: (12168, 33)
Resampled labels distribution: Counter({3: 4058, 1: 4056, 0: 4054})


## 9. Reshape Data for Model Input

In [9]:
# Reshape ECG features for CNN input
X_resampled = X_resampled.reshape(-1, X_resampled.shape[1], 1)
X_train = X_train.reshape(-1, X_train.shape[1], 1)
X_val = X_val.reshape(-1, X_val.shape[1], 1)
X_test = X_test.reshape(-1, X_test.shape[1], 1)

print("Reshaped ECG features shape:", X_resampled.shape)


Reshaped ECG features shape: (12168, 33, 1)


## 10. Define the Model Architecture

In [22]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, regularizers
from tensorflow.keras.optimizers import Adam

# Custom ResNet block
def resnet_block(input_data, filters, kernel_size, stride=1, dropout_rate=0.3, l2_lambda=0.001):
    x = layers.Conv1D(filters=filters, kernel_size=kernel_size, strides=stride, padding='same',
                      kernel_regularizer=regularizers.l2(l2_lambda))(input_data)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    x = layers.Conv1D(filters=filters, kernel_size=kernel_size, strides=1, padding='same',
                      kernel_regularizer=regularizers.l2(l2_lambda))(x)
    x = layers.BatchNormalization()(x)

    shortcut = layers.Conv1D(filters=filters, kernel_size=1, strides=stride, padding='same',
                             kernel_regularizer=regularizers.l2(l2_lambda))(input_data)
    shortcut = layers.BatchNormalization()(shortcut)

    x = layers.Add()([x, shortcut])
    x = layers.Activation('relu')(x)

    if dropout_rate:
        x = layers.Dropout(dropout_rate)(x)

    return x

# Build custom ResNet model
def build_custom_resnet(input_shape, dropout_rate=0.3, l2_lambda=0.001):
    inputs = layers.Input(shape=input_shape)

    x = resnet_block(inputs, filters=32, kernel_size=16, stride=1, dropout_rate=dropout_rate, l2_lambda=l2_lambda)
    x = resnet_block(x, filters=64, kernel_size=16, stride=2, dropout_rate=dropout_rate, l2_lambda=l2_lambda)
    x = resnet_block(x, filters=128, kernel_size=16, stride=2, dropout_rate=dropout_rate, l2_lambda=l2_lambda)
    
    x = layers.GlobalAveragePooling1D()(x)
    model = models.Model(inputs, x, name="CustomResNet")
    return model

# Chain of Thought block
def chain_of_thought_block(input_data, units=128):
    x = layers.Dense(units, activation='relu')(input_data)
    x = layers.Dropout(0.3)(x)

    # Residual connection
    x_res = layers.Dense(units, activation='relu')(x)
    x_res = layers.Dropout(0.3)(x_res)

    x = layers.Add()([x_res, x])
    return x

# Multi-Head Attention block
def multi_head_attention_block(input_data, head_size, num_heads, ff_dim, dropout=0.3):
    x = layers.Reshape((1, -1))(input_data)  # Reshape to 3D tensor for attention
    x_att = layers.MultiHeadAttention(key_dim=head_size, num_heads=num_heads, dropout=dropout)(x, x)
    x = layers.Add()([x_att, x])
    x = layers.LayerNormalization(epsilon=1e-6)(x)

    x_ff = layers.Dense(ff_dim, activation='relu')(x)
    x_ff = layers.Dropout(dropout)(x_ff)
    x_ff = layers.Dense(input_data.shape[-1])(x_ff)
    x = layers.Add()([x_ff, x])
    x = layers.LayerNormalization(epsilon=1e-6)(x)

    x = layers.Flatten()(x)  # Flatten back to 2D tensor
    return x

# Build the final model
def build_final_model(input_shape_ecg, input_shape_rr, num_classes, head_size=128, num_heads=4, ff_dim=256, dropout=0.3, dropout_rate=0.3):
    # Input Layers
    input_ecg = layers.Input(shape=input_shape_ecg, name="ECG_Input")
    input_rr = layers.Input(shape=input_shape_rr, name="RR_Interval_Input")

    # Custom ResNet for ECG feature extraction
    resnet_model = build_custom_resnet(input_shape_ecg, dropout_rate=dropout_rate)
    resnet_features = resnet_model(input_ecg)

    # Chain of Thought Mechanism on ResNet features
    cot_output = chain_of_thought_block(resnet_features, units=128)

    # Multi-Head Attention Block with Self-Attention on ResNet features
    attention_output = multi_head_attention_block(cot_output, head_size=head_size, num_heads=num_heads, ff_dim=ff_dim, dropout=dropout)

    # Processing R-R Interval Features
    rr_dense = layers.Dense(128, activation="relu")(input_rr)
    rr_dense = layers.Reshape((1, 128))(rr_dense)

    # Combine ResNet features (ECG) and RR Interval features
    combined_features = layers.Concatenate()([attention_output, layers.Flatten()(rr_dense)])

    # Fully Connected Layers for Final Classification
    x = layers.Dense(128, activation='relu')(combined_features)
    x = layers.Dropout(0.5)(x)
    x = layers.Dense(64, activation='relu')(x)
    x = layers.Dropout(0.5)(x)
    output = layers.Dense(num_classes, activation='softmax')(x)

    # Final model with two inputs
    model = models.Model(inputs=[input_ecg, input_rr], outputs=output, name="ECG_Classification_Model")
    return model

# Prepare your data (make sure to replace with your actual data variables)
# Ensure that X_resampled and X_val are reshaped to (num_samples, sequence_length, 1)
# For example:
# X_resampled = X_resampled.reshape(-1, X_resampled.shape[1], 1)
# X_val = X_val.reshape(-1, X_val.shape[1], 1)

# Define input shapes for ECG and R-R intervals
input_shape_ecg = (X_resampled.shape[1], 1)  # ECG data shape
input_shape_rr = (X_rr_resampled.shape[1],)  # R-R interval data shape

# Define the number of classes
num_classes = len(np.unique(y_resampled))

# Build and compile the final model
model = build_final_model(input_shape_ecg, input_shape_rr, num_classes)

# Compile with gradient clipping
optimizer = Adam(learning_rate=1e-6, clipnorm=1.0)
model.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy']
)


# Display the model summary
model.summary()


Model: "ECG_Classification_Model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 ECG_Input (InputLayer)         [(None, 33, 1)]      0           []                               
                                                                                                  
 CustomResNet (Functional)      (None, 128)          522048      ['ECG_Input[0][0]']              
                                                                                                  
 dense_24 (Dense)               (None, 128)          16512       ['CustomResNet[0][0]']           
                                                                                                  
 dropout_27 (Dropout)           (None, 128)          0           ['dense_24[0][0]']               
                                                                           

## 12. Set Up Callbacks

In [23]:
# Callbacks
log_dir = f"logs/tensorboard/Novel_Model/{datetime.now().strftime('%Y%m%d-%H%M%S')}"
checkpoint_callback = ModelCheckpoint(
    filepath='novel_model_best.h5', save_best_only=True, verbose=1
)
tensorboard_callback = TensorBoard(log_dir=log_dir, histogram_freq=1)
lr_scheduler = ReduceLROnPlateau(
    monitor='val_loss', factor=0.5, patience=3, min_lr=1e-8
)
early_stopping = EarlyStopping(
    monitor='val_loss', patience=10, restore_best_weights=True
)

callbacks = [
    checkpoint_callback,
    tensorboard_callback,
    lr_scheduler,
    early_stopping
]


## 13. Train the Model


In [24]:
# Model Training
history = model.fit(
    [X_resampled, X_rr_resampled], y_resampled,
    validation_data=([X_val, X_rr_val], y_val),
    epochs=100,
    batch_size=32,
    callbacks=callbacks,
    verbose=1
)


Epoch 1/100
Epoch 1: val_loss did not improve from inf
Epoch 2/100
Epoch 2: val_loss did not improve from inf
Epoch 3/100

KeyboardInterrupt: 

## 14. Evaluate the Model

In [None]:
# Model Evaluation
evaluation = model.evaluate([X_test, X_rr_test], y_test)
print(f"Test Loss: {evaluation[0]:.4f}")
print(f"Test Accuracy: {evaluation[1]:.4f}")


## 15. Make Predictions on Test Data

In [None]:
# Predict on test data
y_pred_probs = model.predict([X_test, X_rr_test])
y_pred = np.argmax(y_pred_probs, axis=1)


## 16. Confusion Matrix

In [None]:
# Compute confusion matrix
cm = confusion_matrix(y_test, y_pred)

# Plot confusion matrix
plt.figure(figsize=(10, 7))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')

# Labeling the axes
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix')

# Display the plot
plt.show()

# Classification Report
print("Classification Report:")
print(classification_report(y_test, y_pred))


In [None]:
# Plot training & validation accuracy values
plt.figure(figsize=(12, 4))
plt.plot(history.history['accuracy'], label='Train Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')

plt.title('Model Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend(loc='upper left')

plt.show()


# Plot training & validation loss values
plt.figure(figsize=(12, 4))
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')

plt.title('Model Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(loc='upper left')

plt.show()


In [25]:
# Check for NaN and Inf in ECG data
print("NaNs in X_resampled:", np.isnan(X_resampled).any())
print("Infs in X_resampled:", np.isinf(X_resampled).any())

# Check for NaN and Inf in RR interval data
print("NaNs in X_rr_resampled:", np.isnan(X_rr_resampled).any())
print("Infs in X_rr_resampled:", np.isinf(X_rr_resampled).any())

# Check for NaN and Inf in labels
print("NaNs in y_resampled:", np.isnan(y_resampled).any())
print("Infs in y_resampled:", np.isinf(y_resampled).any())



NaNs in X_resampled: False
Infs in X_resampled: False
NaNs in X_rr_resampled: False
Infs in X_rr_resampled: False
NaNs in y_resampled: False
Infs in y_resampled: False
