# Approach A Data Loading and Model Training

### This notebook contains the full machine-learning pipeline used to train the models

### Step 1: Import all libraries to be used

In [None]:
import sys

import h5py
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
import tensorflow as tf
from tensorflow.keras import layers, models, regularizers
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.optimizers.schedules import ExponentialDecay
from tensorflow.keras.optimizers import Adam
from scipy.signal import butter, filtfilt

### Step 2: Data preprocessing

This step loads all of the training data, passes each ECG tracing through a high-pass filter, undersamples the normal ECGs, and normalises the data, saving the normalisation constants as NumPy arrays.

In [None]:

# All combined HDF5 parts (ECG tracings merged with the CSV label data, with
# 80% normal_ecg undersampling), parts 0-13 in the order they are loaded.
combined_parts = [f'drive/MyDrive/Combined_Part_{part}.h5' for part in range(14)]


# This defines the High-pass Butterworth filter parameters
def highpass_filter(data, cutoff=0.5, fs=400, order=4):
    """Apply a zero-phase Butterworth high-pass filter along axis 0.

    Parameters
    ----------
    data : array_like
        Signal to filter. Filtering runs along axis 0, so for a
        (samples, leads) array every lead is filtered independently.
    cutoff : float, default 0.5
        Cut-off frequency, in the same units as ``fs``.
    fs : float, default 400
        Sampling frequency.
    order : int, default 4
        Butterworth filter order.

    Returns
    -------
    numpy.ndarray
        The filtered signal, with the same shape as ``data``.
    """
    # butter() takes the cut-off as a fraction of the Nyquist frequency (fs / 2).
    wn = cutoff / (fs / 2)
    numerator, denominator = butter(order, wn, btype='high', analog=False)
    # filtfilt runs the filter forwards then backwards, so no phase shift is introduced.
    return filtfilt(numerator, denominator, data, axis=0)


# Two independent empty module-level lists. load_combined_data_with_filter below
# accumulates into its own local lists and does not touch these.
all_tracings, all_labels = [], []



def load_combined_data_with_filter(
    hdf5_files,
    label_names=('1dAVb', 'RBBB', 'LBBB', 'SB', 'ST', 'AF', 'normal_ecg'),
):
    """Read every exam in ``hdf5_files``, high-pass filter it, and collect its labels.

    Each top-level group in a file is treated as one exam. Its ``'tracing'``
    dataset is read into memory and passed through ``highpass_filter``. One
    label value is read from the group's attributes for every name in
    ``label_names``.

    Parameters
    ----------
    hdf5_files : iterable of str
        Paths of the HDF5 files to read, processed in the given order.
    label_names : sequence of str, optional
        Group attribute names to read as label columns, in column order. The
        default reproduces the original seven labels with ``'normal_ecg'``
        last. The undersampling code relies on that order because it reads
        ``y[:, -1]``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(ecg_data, labels)``: the filtered tracings stacked along a new first
        axis, and (for non-empty input) an ``(n_exams, len(label_names))``
        array of label values, row-aligned with ``ecg_data``.

    Raises
    ------
    KeyError
        If a group lacks a ``'tracing'`` dataset or one of the label attributes.
    """
    label_names = list(label_names)
    ecg_data = []
    labels = []

    for hdf5_file in hdf5_files:
        with h5py.File(hdf5_file, 'r') as f:
            for exam_id in f.keys():
                group = f[exam_id]
                # NOTE(review): highpass_filter filters along axis 0, which assumes
                # tracings are stored as (samples, leads); confirm against the files.
                ecg_data.append(highpass_filter(group['tracing'][:]))
                labels.append([group.attrs[name] for name in label_names])

    return np.array(ecg_data), np.array(labels)

# Load data from combined HDF5 files with filtering
X, y = load_combined_data_with_filter(combined_parts)
assert len(X) == len(y), 'every tracing must have exactly one label row'

# Fixed seed so that the undersampling and shuffling below are reproducible.
# Keras' validation_split takes the *last* samples, so the train/validation
# partition is reproducible too.
RANDOM_SEED = 42
rng = np.random.default_rng(RANDOM_SEED)

# Separate normal ECGs and arrhythmias.
# NB! Column -1 is 'normal_ecg' only because it is the last label read by the loader.
normal_indices = np.where(y[:, -1] == 1)[0]
arrhythmia_indices = np.where(y[:, -1] == 0)[0]

# Undersample normal ECGs again by 75%: keep a random 25% of them. This reduces
# the number of normal ECGs, limiting overfitting and the class imbalance.
rng.shuffle(normal_indices)
undersampled_normal_indices = normal_indices[:len(normal_indices) // 4]

# Combine undersampled normal ECGs with all arrhythmias, then shuffle the result.
balanced_indices = np.concatenate([undersampled_normal_indices, arrhythmia_indices])
rng.shuffle(balanced_indices)

# Create balanced dataset
X_balanced = X[balanced_indices]
y_balanced = y[balanced_indices]

# Normalise the ECG tracings.
# The statistics are taken over axis 0 (the exam axis). Assuming tracings are
# (samples, leads), that gives one mean/std per (timestep, lead) position rather
# than one per lead.
# NOTE(review): they are computed on all of X_balanced, before model.fit carves
# off its validation_split, so validation exams leak into the statistics.
# Consider computing them on the training portion only.
train_mean = np.mean(X_balanced, axis=0)
train_std = np.std(X_balanced, axis=0)
# A position with zero variance across exams would divide by zero and feed
# NaN/inf into the model. Scale such positions by 1 instead.
train_std = np.where(train_std == 0, 1.0, train_std)
X_balanced = (X_balanced - train_mean) / train_std

# Save normalisation parameters to be used for testing purposes
np.save('hpf_train_mean.npy', train_mean)
np.save('hpf_train_std.npy', train_std)

# Shape of one model input, presumably (samples, leads) = (4096, 12).
input_shape = (4096, 12)

# Label columns of y, in the order the loader reads them. 'normal_ecg' must stay
# last because the undersampling step indexes y[:, -1].
arrhythmia_columns = ['1dAVb', 'RBBB', 'LBBB', 'SB', 'ST', 'AF', 'normal_ecg']

# Base learning rate, passed as a constant to the optimiser when compiling.
initial_learning_rate = 1e-3

# Staircase exponential decay: multiply by 0.96 every 100,000 steps. There is
# no warm-up phase, and the compile step uses the constant rate above instead.
lr_schedule = ExponentialDecay(
    initial_learning_rate=initial_learning_rate,
    decay_steps=100_000,
    decay_rate=0.96,
    staircase=True,
)

### Step 3: Building the model, defining the training parameters, setting checkpoints, and training the model

In [None]:
def _conv_bn_relu(x, filters, kernel_size=3):
    """Apply a same-padded Conv1D, then BatchNormalization, then ReLU."""
    x = layers.Conv1D(filters=filters, kernel_size=kernel_size, padding='same')(x)
    x = layers.BatchNormalization()(x)
    return layers.Activation('relu')(x)


def _residual_block(x, filters):
    """Two conv-BN-ReLU stages added to a 1x1-conv shortcut, then max-pooled by 2."""
    # The shortcut is created first to keep the original layer creation order
    # (and hence Keras' auto-generated layer names).
    residual = layers.Conv1D(filters=filters, kernel_size=1, padding='same')(x)
    x = _conv_bn_relu(x, filters)
    x = _conv_bn_relu(x, filters)
    x = layers.Add()([x, residual])
    return layers.MaxPooling1D(pool_size=2)(x)


def build_model(input_shape, num_classes=None):
    """Build the CNN + BiLSTM + self-attention multi-label ECG classifier.

    Architecture:
    - a conv stem;
    - two residual blocks (64 and 128 filters);
    - one more conv block;
    - two bidirectional LSTMs;
    - a 2-head self-attention layer with a residual add;
    - global max pooling;
    - a regularised dense head with one sigmoid unit per label.

    Each of the four MaxPooling1D(2) stages halves the sequence length.

    Parameters
    ----------
    input_shape : tuple of int
        Shape of one input sample, excluding the batch dimension.
    num_classes : int, optional
        Number of sigmoid outputs. Defaults to ``len(arrhythmia_columns)``
        (the original behaviour).

    Returns
    -------
    tensorflow.keras.Model
        The uncompiled model.
    """
    if num_classes is None:
        num_classes = len(arrhythmia_columns)

    inputs = layers.Input(shape=input_shape)

    # Initial convolutional block
    x = _conv_bn_relu(inputs, 32)
    x = layers.SpatialDropout1D(0.2)(x)
    x = layers.MaxPooling1D(pool_size=2)(x)

    # Residual blocks
    x = _residual_block(x, 64)
    x = _residual_block(x, 128)

    # Additional convolutional block before the LSTMs
    x = _conv_bn_relu(x, 128)
    x = layers.MaxPooling1D(pool_size=2)(x)

    # Two stacked BiLSTMs; each emits 2 * 32 = 64 features per timestep.
    x = layers.Bidirectional(layers.LSTM(32, return_sequences=True, dropout=0.3))(x)
    x = layers.Bidirectional(layers.LSTM(32, return_sequences=True, dropout=0.3))(x)

    # Multi-head self-attention with a residual connection. MultiHeadAttention
    # already projects back to the query's feature size by default. The extra
    # Dense(64) is kept so the architecture (and any saved weights) is unchanged.
    attention = layers.MultiHeadAttention(num_heads=2, key_dim=16)(x, x)
    attention = layers.Dense(64)(attention)
    x = layers.Add()([x, attention])

    x = layers.GlobalMaxPooling1D()(x)

    # Dense head with dropout and L2 regularisation
    x = layers.Dense(64, activation='relu', kernel_regularizer=regularizers.l2(0.001))(x)
    x = layers.Dropout(0.5)(x)

    # Independent sigmoid per label (multi-label output)
    outputs = layers.Dense(num_classes, activation='sigmoid')(x)

    return models.Model(inputs=inputs, outputs=outputs)


model = build_model(input_shape)

# Adam with a constant learning rate (lr_schedule is not used here) and gradient
# value clipping at 1.0. Binary cross-entropy matches the independent sigmoid
# outputs of the multi-label head.
model.compile(
    loss='binary_crossentropy',
    optimizer=Adam(learning_rate=initial_learning_rate, clipvalue=1.0),
    metrics=['accuracy'],
)


# Calculate class weights to handle class imbalance and penalise incorrect 1dAVb predictions more harshly.
# Keras treats 2-D targets passed with `class_weight` as one-hot: each sample is
# weighted by the class at the argmax of its label row. The weights are
# therefore computed on that same argmax "primary" label.
# NOTE(review): a row with no positive label has argmax 0, so it is counted (and
# weighted) as '1dAVb'. Confirm such rows cannot occur in y_balanced.
primary_labels = y_balanced.argmax(axis=1)
present_classes = np.unique(primary_labels)
class_weights = compute_class_weight('balanced', classes=present_classes, y=primary_labels)

# Key each weight by its real class index. Enumerating `class_weights`
# positionally would shift every weight after a missing class onto the wrong
# class. Classes that never appear as a primary label get a neutral 1.0, so the
# keys cover 0..n-1 (tf.keras 2.x rejects non-contiguous keys).
class_weight_dict = {i: 1.0 for i in range(len(arrhythmia_columns))}
class_weight_dict.update(
    {int(cls): float(weight) for cls, weight in zip(present_classes, class_weights)}
)
class_weight_dict[arrhythmia_columns.index('1dAVb')] *= 1.5  # this made a big difference


# Where ModelCheckpoint writes the best model seen so far.
checkpoint_filepath = 'hpf_best_model_checkpoint.keras'

# Stop once val_loss has not improved for 10 epochs, restoring the best weights.
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=10,
    restore_best_weights=True,
)

# After 5 epochs without val_loss improvement, scale the learning rate by 0.2,
# never going below 1e-4.
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.2,
    patience=5,
    min_lr=0.0001,
)

# Save to checkpoint_filepath only when val_loss reaches a new minimum,
# logging a message each time it does.
model_checkpoint = ModelCheckpoint(
    filepath=checkpoint_filepath,
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    verbose=1,
)

# Train with class weights. Keras holds out the final 20% of the (already
# shuffled) samples for validation, which drives all three callbacks.
model.fit(
    X_balanced,
    y_balanced,
    batch_size=32,
    epochs=100,
    validation_split=0.2,
    class_weight=class_weight_dict,
    callbacks=[early_stopping, reduce_lr, model_checkpoint],
)

# Persist the in-memory model in native Keras and legacy HDF5 formats,
# plus a weights-only file.
model.save('hpf_trained_model.keras')
model.save('hpf_trained_model.h5')
model.save_weights('hpf_model_weights.weights.h5')

# Save normalisation data to a human-readable text file.
# str() of a NumPy array summarises it with '...' once it has more than the
# print `threshold` (1000 elements by default). The axis-0 mean/std arrays have
# one value per tracing position (presumably 4096 x 12 = 49,152 each), so the
# threshold is lifted here. Otherwise most values would be silently left out.
with np.printoptions(threshold=sys.maxsize), open('hpf_normalization_data.txt', 'w') as f:
    f.write(f'Mean: {train_mean}\n')
    f.write(f'Std: {train_std}\n')