In [None]:
from tensorflow.keras.layers import Input, Dense, Flatten, Dropout, Lambda, Reshape
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.optimizers.legacy import Adam
from tensorflow.keras import backend as K
import numpy as np
import pandas as pd

In [None]:
# Load the preprocessed median beats for the main project.
# NOTE(review): the slicing below assumes a rank-3 array (cases, leads, samples) — confirm
# the lead order matches the titles used in the plotting cell.
distdr_medians = np.load('DISTDR_medians_precprocessed.npy')

# Keep the first 6 leads and trim 20 samples from each extremity of every trace.
distdr_medians_cropped = distdr_medians[:, :6, 20:-20]

# The autoencoder's Input layer expects (6, 210) per case; fail here with a clear message
# rather than deep inside model.fit if the preprocessing changes.
assert distdr_medians_cropped.shape[1:] == (6, 210), (
    f"unexpected cropped shape {distdr_medians_cropped.shape}; the model expects (n, 6, 210)"
)
distdr_medians_cropped.shape

In [None]:
# Encoder: flatten the (leads, samples) trace and compress it through two ReLU layers.
latent_dim = 16
input_signal = Input(shape=(6, 210))

x = Flatten()(input_signal)
x = Dense(128, activation='relu')(x)
x = Dense(64, activation='relu')(x)

# Two parallel heads parameterise the approximate posterior q(z|x) as a diagonal Gaussian.
z_mean = Dense(latent_dim)(x)
z_log_var = Dense(latent_dim)(x)


def sampling(args):
    """Draw z = mean + sigma * eps with eps ~ N(0, I) (reparameterisation trick).

    `args` is the pair (z_mean, z_log_var); sigma is recovered as exp(0.5 * log_var).
    Noise is drawn on every forward pass, so the model is stochastic at predict time too.
    """
    mu, log_var = args
    batch_size = K.shape(mu)[0]
    eps = K.random_normal(shape=(batch_size, latent_dim), mean=0., stddev=1.)
    return mu + K.exp(0.5 * log_var) * eps


z = Lambda(sampling)([z_mean, z_log_var])

# Decoder: mirror the encoder (latent -> 64 -> 128 -> 6*210) and restore the (leads, samples) layout.
# NOTE(review): the sigmoid output bounds reconstructions to (0, 1), which presumes the
# preprocessed medians are scaled into that range — confirm against the preprocessing step.
x = Dense(64, activation='relu')(z)
x = Dense(128, activation='relu')(x)
decoded = Reshape((6, 210))(Dense(6 * 210, activation='sigmoid')(x))

# End-to-end model: input trace -> sampled latent code -> reconstruction.
vae_model = Model(inputs=input_signal, outputs=decoded)

# Initialise from previously trained weights; presumably saved from a model with the same
# layer structure as this graph — confirm if the architecture above changes.
vae_model.load_weights('pretrained_autoencoder_weights.h5')

# VAE objective = reconstruction MSE + beta * KL(q(z|x) || N(0, I)).
# The KL term depends on the encoder's intermediate tensors (z_mean, z_log_var), not on
# (y_true, y_pred). It is therefore attached to the model with `add_loss` rather than read
# from inside the compiled loss. Under TF2 eager execution, referencing symbolic Keras
# tensors from within a compiled loss function fails during `fit`.
beta = 1

# KL divergence per sample, averaged (not summed) over latent dimensions as in the original
# formulation — this effectively scales the textbook KL term by 1 / latent_dim.
kl_per_sample = -0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
# Keras averages the compiled loss over the batch, so averaging the KL over the batch keeps
# the total equal to mean(mse + beta * kl), i.e. the same objective as before.
vae_model.add_loss(beta * K.mean(kl_per_sample))


def vae_loss(input_signal, decoded):
    """Per-sample reconstruction MSE, averaged over leads and time steps.

    Keras calls this as loss(y_true, y_pred). The parameter names shadow the module-level
    tensors of the same name but refer to the batch values Keras passes in.
    The KL term is added separately via `vae_model.add_loss` above.
    """
    return K.mean(K.square(input_signal - decoded), axis=(1, 2))


adam = Adam(learning_rate=0.001)
vae_model.compile(optimizer=adam, loss=vae_loss)

# Deterministic encoder: maps an input trace to the posterior mean z_mean (no sampling noise).
encoder_model = Model(input_signal, z_mean)

# Stop once validation loss has not improved for 2 epochs, and roll the weights back to the
# best epoch. Without restore_best_weights the model keeps the final (non-improving) epoch's
# weights, and every downstream evaluation cell would use those.
early_stopping = EarlyStopping(monitor='val_loss',
                               patience=2,
                               verbose=1,
                               mode='min',
                               restore_best_weights=True)

# Train the VAE to reconstruct its own input.
# NOTE(review): Keras' validation_split takes the LAST 15% of the arrays before shuffling, so
# if cases are stored in a meaningful order the validation set is not a random sample.
history = vae_model.fit(distdr_medians_cropped, distdr_medians_cropped,
                        epochs=128,
                        batch_size=128,
                        shuffle=True,
                        validation_split=0.15,
                        callbacks=[early_stopping])

# Reconstruct every case. The Lambda sampling layer draws fresh noise on each call, so this
# is one stochastic reconstruction, not a decode of the posterior mean.
reconstructed_signal = vae_model.predict(distdr_medians_cropped)


def _rowwise_pearson(a, b):
    """Pearson r between matching rows of `a` and `b` along the last axis.

    Matches np.corrcoef(a[k], b[k])[0, 1] for each row k (computed in float64 and clipped
    to [-1, 1]); rows with zero variance yield NaN, as np.corrcoef does.
    """
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    a_c = a - a.mean(axis=-1, keepdims=True)
    b_c = b - b.mean(axis=-1, keepdims=True)
    num = (a_c * b_c).sum(axis=-1)
    den = np.sqrt((a_c ** 2).sum(axis=-1) * (b_c ** 2).sum(axis=-1))
    with np.errstate(invalid='ignore', divide='ignore'):
        return np.clip(num / den, -1.0, 1.0)


# Score every case (a random subset only adds noise to a mean over the whole set), looping
# per case so memory stays small. Leads 1..5 are scored, consistent with the RMSE/PRD and
# DTW cells; lead 0 is excluded there as well.
pearson_all = [
    r
    for idx in range(len(distdr_medians_cropped))
    for r in _rowwise_pearson(distdr_medians_cropped[idx, 1:6], reconstructed_signal[idx, 1:6])
]

# Flat leads make r undefined (NaN); report and exclude them instead of returning a NaN mean.
n_undefined = int(np.isnan(pearson_all).sum())
if n_undefined:
    print(f"Excluded {n_undefined} lead(s) with a flat signal (Pearson r undefined)")
print("Pearson mean", np.nanmean(pearson_all))
print("Pearson std", np.nanstd(pearson_all))

# Reconstruction quality metrics: RMSE, PRD and DTW

In [None]:
def _minmax_rows(signals):
    """Min-max rescale each row of `signals` to [0, 1] along the last axis.

    Flat rows (max == min) become NaN instead of emitting divide-by-zero warnings; the
    NaN is then counted and excluded from the summary statistics below.
    """
    lo = signals.min(axis=-1, keepdims=True)
    span = signals.max(axis=-1, keepdims=True) - lo
    with np.errstate(invalid='ignore', divide='ignore'):
        return (signals - lo) / span


# Per-lead RMSE and PRD, in sample-major / lead-minor order.
rmse_all = []
prd_all = []

for idx in range(len(distdr_medians_cropped)):
    # Leads 1..5 only, consistent with the Pearson and DTW cells.
    # Both traces are normalised independently, so these metrics compare waveform shape,
    # not absolute amplitude.
    original_norm = _minmax_rows(distdr_medians_cropped[idx, 1:6])
    recon_norm = _minmax_rows(reconstructed_signal[idx, 1:6])

    sq_err = (original_norm - recon_norm) ** 2
    rmse_all.extend(np.sqrt(sq_err.mean(axis=-1)))
    # PRD (%) = 100 * sqrt(sum((x - x_hat)^2) / sum(x^2)), computed on the normalised signals.
    prd_all.extend(100 * np.sqrt(sq_err.sum(axis=-1) / (original_norm ** 2).sum(axis=-1)))

# A flat original or reconstructed lead yields NaN for both metrics; report and exclude it.
n_flat = int(np.isnan(rmse_all).sum())
if n_flat:
    print(f"Excluded {n_flat} lead(s) with a flat original or reconstructed signal")

rmse_mean = np.nanmean(rmse_all)
rmse_std = np.nanstd(rmse_all)

prd_mean = np.nanmean(prd_all)
prd_std = np.nanstd(prd_all)

print("RMSE mean:", rmse_mean)
print("RMSE std:", rmse_std)

print("PRD mean:", prd_mean)
print("PRD std:", prd_std)


In [None]:
from fastdtw import fastdtw

# Reuse the reconstructions computed when the model was trained. The VAE samples its latent
# code on every predict() call, so predicting again would score DTW on a different random
# reconstruction than the one used for Pearson/RMSE/PRD, and would repeat a full inference pass.
dtw_distances = [
    fastdtw(distdr_medians_cropped[idx, lead], reconstructed_signal[idx, lead])[0]
    for idx in range(len(distdr_medians_cropped))
    for lead in range(1, 6)  # leads 1..5, consistent with the other metric cells
]

# Summary statistics over all scored leads.
dtw_mean = np.mean(dtw_distances)
dtw_std = np.std(dtw_distances)

print("DTW mean:", dtw_mean)
print("DTW std:", dtw_std)

In [None]:
# Project every case onto the posterior mean of the latent space: one row per case and one
# column per latent dimension (the encoder model outputs z_mean).
encoded_imgs = encoder_model.predict(distdr_medians_cropped)
encoded_imgs_df = pd.DataFrame(data=encoded_imgs)

# Reconstructed versus original ECG leads for one random case

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(16, 10))

# Pick one random case to compare visually (unseeded: a different case on each run).
case_idx = np.random.randint(len(distdr_medians_cropped))
titles = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF']

# Normalised time axis (0..1) spanning the cropped trace.
time_axis = np.linspace(0, 1, num=distdr_medians_cropped.shape[-1])

# Overlay original and reconstruction per lead, annotating each panel with its Pearson r.
lead_correlations = []
for i, ax in enumerate(axes.flatten()):
    original = distdr_medians_cropped[case_idx, i, :]
    reconstruction = reconstructed_signal[case_idx, i, :]
    r = np.corrcoef(original, reconstruction)[0, 1]
    lead_correlations.append(r)

    ax.plot(time_axis, original, linestyle='-', linewidth=3.5, color='darkblue',
            alpha=0.65, label='original')
    ax.plot(time_axis, reconstruction, linestyle='-', linewidth=3.5, color='darkorange',
            alpha=0.65, label='reconstruction')
    ax.grid(linewidth=0.5, color='gray', linestyle='--')
    ax.set_title(f'{titles[i]} (r = {r:.3f})')
    ax.tick_params(axis='both', which='both', length=0)
    ax.set_xticklabels([])
    ax.set_yticklabels([])

axes.flat[0].legend(loc='upper right')

# Mean correlation across all six leads of this case (not just the last panel's lead).
pearson_correlation = np.mean(lead_correlations)
print("Pearson mean:", pearson_correlation)
plt.tight_layout()
plt.show()
plt.close(fig)
