In [None]:
import os
import pickle
from pathlib import Path

import numpy as np
import shap
from scipy.spatial import distance
from sklearn.metrics import silhouette_score, davies_bouldin_score, calinski_harabasz_score
from sklearn.mixture import GaussianMixture

import load_data as ld
from piecewise import *

In [None]:
def calculate_inertia(predictions, data, centroids):
    """Sum the Euclidean distances between each sample and its assigned centroid.

    Unlike the k-means definition of inertia, the distances are NOT squared.

    Parameters
    ----------
    predictions : array-like of int, shape (n_samples,)
        Index of the centroid assigned to each sample.
    data : pandas.DataFrame or array-like, shape (n_samples, n_features)
        The samples. Anything ``numpy.asarray`` can convert is accepted.
    centroids : array-like, shape (n_clusters, n_features)
        Centroid coordinates, indexed by the values in ``predictions``.

    Returns
    -------
    float
        Total sample-to-centroid distance (``0.0`` for empty input).

    Raises
    ------
    ValueError
        If ``predictions`` and ``data`` do not have the same number of rows.
    """
    samples = np.asarray(data, dtype=float)
    labels = np.asarray(predictions, dtype=np.intp)
    centres = np.asarray(centroids, dtype=float)

    # A length mismatch would otherwise be silently truncated or broadcast.
    if len(labels) != len(samples):
        raise ValueError(
            f"predictions has {len(labels)} entries but data has {len(samples)} rows"
        )

    # Fancy indexing lines up one centroid row per sample, so the per-sample
    # distance is a single row-wise norm instead of a Python loop.
    return float(np.linalg.norm(samples - centres[labels], axis=1).sum())

def _plot_cluster_score(cluster_counts, scores, ylabel, file_stem, path):
    """Plot one score against the number of clusters, save it as SVG and PNG, then show it."""
    plt.figure()
    plt.plot(cluster_counts, scores, "-*")
    plt.xlabel("Number of Clusters")
    plt.ylabel(ylabel)
    plt.xticks(cluster_counts, cluster_counts)
    for extension in ("svg", "png"):
        plt.savefig(os.path.join(path, f"{file_stem}.{extension}"))
    plt.show()


def define_number_of_clusters(z, path, min_clusters=2, max_clusters=10, n_init=100):
    """Fit Gaussian mixtures over a range of cluster counts and plot model-selection scores.

    For every cluster count in ``[min_clusters, max_clusters]`` a diagonal-covariance
    ``GaussianMixture`` (``random_state=0``) is fitted to ``z``. Five curves are then
    plotted against the number of clusters, each saved in ``path`` as both ``.svg``
    and ``.png`` and shown: normalised sum of sample-to-centroid distances,
    silhouette, Davies-Bouldin, BIC and Calinski-Harabasz.

    Parameters
    ----------
    z : pandas.DataFrame
        Feature matrix, one row per sample.
    path : str
        Existing directory in which the figures are written.
    min_clusters, max_clusters : int, optional
        Inclusive range of cluster counts to evaluate (default 2..10).
    n_init : int, optional
        Number of initialisations per mixture fit (default 100).

    Raises
    ------
    ValueError
        If the requested range is empty or starts below 2; the silhouette,
        Davies-Bouldin and Calinski-Harabasz scores need at least two clusters.
    """
    if min_clusters < 2 or max_clusters < min_clusters:
        raise ValueError(
            f"invalid cluster range [{min_clusters}, {max_clusters}]: "
            "need 2 <= min_clusters <= max_clusters"
        )
    cluster_counts = list(range(min_clusters, max_clusters + 1))

    inertia_score = []
    sh_score = []
    db_score = []
    bic_score = []
    ch_score = []
    for cl in cluster_counts:
        clustering = GaussianMixture(n_components=cl, n_init=n_init, random_state=0, covariance_type="diag").fit(z)
        predictions = clustering.predict(z)

        bic_score.append(clustering.bic(z))
        inertia_score.append(calculate_inertia(predictions, z, clustering.means_))
        sh_score.append(silhouette_score(z, predictions))
        db_score.append(davies_bouldin_score(z, predictions))
        ch_score.append(calinski_harabasz_score(z, predictions))

    # Scale the distance curve to a maximum of 1. An all-zero curve is left
    # untouched rather than turned into NaNs by a 0/0 division.
    inertia_score = np.array(inertia_score)
    largest_inertia = inertia_score.max()
    if largest_inertia > 0:
        inertia_score = inertia_score / largest_inertia

    # (scores, y-axis label, output file stem), in the order the figures are produced.
    figures = [
        (inertia_score, "Sum of Distances (Normalised)", "Define Clusters (Inertia Score)"),
        (sh_score, "Silhouette Score", "Define Clusters (Silhouette)"),
        (db_score, "Davies-Bouldin Score", "Define Clusters (Davies-Bouldin)"),
        (bic_score, "BIC Score", "Define Clusters (BIC)"),
        (ch_score, "Calinski-Harabasz Score", "Define Clusters (CH)"),
    ]
    for scores, ylabel, file_stem in figures:
        _plot_cluster_score(cluster_counts, scores, ylabel, file_stem, path)

def retrieve_strain(root_path, full_path, avc, p_waves):
    """Load the strain recordings together with the manual AVC and P-wave annotations.

    Patients lacking either annotation are removed from the recordings before
    anything is returned.

    Parameters
    ----------
    root_path : passed to ``ld.read_avc_time`` and ``ld.read_p_wave_data``.
    full_path : passed to ``ld.read_data``.
    avc : passed to ``ld.read_avc_time``.
    p_waves : passed to ``ld.read_p_wave_data``.

    Returns
    -------
    tuple
        ``(original_data, patient_id, interval, avc_times, p_wave_times)``, the
        first three as filtered by ``exclude_patients``.
    """
    # Time, strain and ECG data read from the .txt files.
    raw_signals, signals, ids, intervals = ld.read_data(full_path)

    # Aortic valve closure (AVC) times from the .xlsx file, annotated manually
    # by an expert. The first return value lists the IDs without a measurement.
    ids_without_avc, avc_times = ld.read_avc_time(root_path, avc)

    # Time of the P-wave peak, annotated manually by an expert. The first
    # return value again lists the IDs without a measurement.
    ids_without_p_wave, p_wave_times = ld.read_p_wave_data(root_path, p_waves)

    # Drop the recordings that belong to any excluded ID.
    filtered = exclude_patients(
        ids_without_avc, ids_without_p_wave, raw_signals, signals, ids, intervals
    )
    raw_signals, signals, ids, intervals = filtered

    return raw_signals, ids, intervals, avc_times, p_wave_times

In [None]:
## Load the data
# The data root is the current working directory with its last two components
# removed, followed by "LV Strain Curves".
parent_folder = os.path.join(*Path(os.getcwd()).parts[:-2], "LV Strain Curves")
data_path_16 = os.path.join(
    parent_folder,
    "Data/FLEMENGHO/Strain curves - Sfile16 (Filtered)",
)

# Manual annotation of the Aortic Valve Closure
avc_files = [
    "Data/FLEMENGHO/AVC time_16_all.xlsx",
    "Data/FLEMENGHO/Patients Without AVC_TK.xlsx",
]

# Manual annotation of the P-wave in the ECG
marker_file = [
    "Data/FLEMENGHO/Patients for manual annotation of markers_TK.xlsx",
]

(
    LV_original_data,
    LV_patient_id,
    LV_interval,
    aortic_closure,
    p_wv,
) = retrieve_strain(
    root_path=parent_folder,
    full_path=data_path_16,
    avc=avc_files,
    p_waves=marker_file,
)

In [None]:
## Temporal Alignment
decision = "peak"
reference_patient_id = "1687"

# Keep the aligned ECG, the deformation curves and the reference time axis;
# the third and fourth return values of get_aligned_signals are discarded.
(
    LV_ecg_aligned,
    LV_deformation,
    _,
    _,
    LV_reference_time,
) = get_aligned_signals(
    LV_original_data,
    decision,
    LV_interval,
    LV_patient_id,
    reference_patient_id,
    aortic_closure,
    p_wv,
)

# Downstream cells work on a NumPy array of the deformation curves.
LV_deformation = np.array(LV_deformation)

In [None]:
## Extract the desired features
save_data_path = os.path.join(*Path(os.getcwd()).parts[:-2], "Results/Summary Index/GMM/LV")
# exist_ok=True creates the folder only when it is missing, without a separate
# (race-prone) existence check.
os.makedirs(save_data_path, exist_ok=True)

extracted_features = extract_time_series_features(LV_reference_time, LV_deformation, LV_patient_id, save_data_path, do_plot=False)

# Candidate feature sets, keyed by the name used for `select_features`.
# "Peak" is selected with a list so the result stays a 2-D DataFrame: the
# scikit-learn estimators and scores used later reject a 1-D Series.
training_data = {
    "peak": extracted_features[["Peak"]],
    "peak_slopes": extracted_features[["Systolic Slope", "Diastolic Slope", "Peak"]],
    "all": extracted_features,
}

In [None]:
select_features = "all"

# Keep this cell idempotent: if save_data_path already ends in a feature-set
# name (because the cell ran before), step back to the parent first so a re-run
# does not create a nested ".../all/all" folder.
if os.path.basename(save_data_path) in training_data:
    save_data_path = os.path.dirname(save_data_path)
save_data_path = os.path.join(save_data_path, select_features)
os.makedirs(save_data_path, exist_ok=True)

In [None]:
## Search for the optimal number of clusters
# Writes the model-selection figures for the chosen feature set into save_data_path.
define_number_of_clusters(
    z=training_data[select_features],
    path=save_data_path,
)

In [None]:
## Perform clustering with 4 clusters
clustering_model = GaussianMixture(
    n_components=4, n_init=30, random_state=0, covariance_type="diag"
).fit(training_data[select_features])
clusters = clustering_model.predict(training_data[select_features])
centres = clustering_model.means_

# Persist the fitted model. The context manager closes the file handle even if
# pickling fails, instead of leaving it to the garbage collector.
with open(os.path.join(save_data_path, "gmm_model.pkl"), "wb") as model_file:
    pickle.dump(clustering_model, model_file)

In [None]:
## Explain the model with SHAP values
# Kernel SHAP on the mixture's membership probabilities. The selected training
# features are used both as the background data and as the samples explained.
explainer = shap.KernelExplainer(
    clustering_model.predict_proba,
    training_data[select_features],
)
shap_values_bsw = explainer.shap_values(
    training_data[select_features],
)

# Built from the cluster assignments and the strain curves; presumably one
# representative curve per cluster (TODO confirm against produce_centroids).
representative_centers = np.array(
    produce_centroids(clusters, LV_deformation)
)

In [None]:
## Plot the SHAP values per cluster
# Maps the model's component index to the cluster number / colour used in the outputs.
cluster_labels = {0:1, 1:4, 2:2, 3:3}
cluster_colours = {0:"green", 1:"red", 2:"blue", 3:"blueviolet"}

# Iterate over the model's components rather than the labels that happen to
# occur in `clusters`: the SHAP outputs follow the predict_proba columns, so an
# empty component must not shift or truncate the loop.
for component in range(clustering_model.n_components):
    # Older SHAP releases return one (n_samples, n_features) array per output
    # in a list; newer ones return a single array with the outputs on the last axis.
    if isinstance(shap_values_bsw, list):
        component_shap_values = shap_values_bsw[component]
    else:
        component_shap_values = np.asarray(shap_values_bsw)[..., component]

    shap.summary_plot(shap_values=component_shap_values, features=training_data[select_features], title=f"Cluster {cluster_labels[component]}", show=False)
    plt.savefig(os.path.join(save_data_path, f"Shap Cluster {cluster_labels[component]}.svg"))
    plt.savefig(os.path.join(save_data_path, f"Shap Cluster {cluster_labels[component]}.png"))
    plt.close()

In [None]:
## Visualize the clustering results

# Patient IDs grouped per cluster (returned as a list).
clustered_id = analyze_patient(clusters, LV_patient_id)

# Excel file listing each patient's ID and its assigned cluster label.
write2excel(clusters, LV_patient_id, save_data_path, cluster_labels)

# Label and colour mappings shared by all plotting calls below.
cluster_style = {
    "cluster_labels": cluster_labels,
    "cluster_colours": cluster_colours,
}

# Strain traces grouped per cluster, drawn with Matplotlib and plotly.
# Matplotlib produces png and svg files.
visualize_clustering_results(
    LV_reference_time,
    LV_deformation,
    clustered_id,
    clusters,
    LV_patient_id,
    representative_centers,
    save_data_path,
    **cluster_style,
)

# First three principal components of the strain curves.
plot_pca(
    clusters,
    LV_deformation,
    LV_patient_id,
    save_data_path,
    **cluster_style,
)

# Gradient of the strain traces, drawn with Matplotlib and plotly.
plot_gradients(
    LV_deformation,
    LV_reference_time,
    clusters,
    clustered_id,
    LV_patient_id,
    save_data_path,
    **cluster_style,
)