# Import necessary libraries

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm
from pathlib import Path
from _utils.load_data import fetch_data
from _clusterers.dtw_kmedoids import DTWCluster
from tslearn.utils import to_time_series_dataset
from _utils.utils import visualize_clustering_results, write2excel

# Helper functions

In [None]:
def min_max_scaling(cvi):
    """Rescale a sequence of values linearly onto the interval [0, 1].

    Parameters
    ----------
    cvi : array-like of float
        Values to rescale; the smallest maps to 0 and the largest to 1.

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as ``cvi``. When every input value is
        identical the range is zero, and an all-zero array is returned instead
        of the NaNs a 0/0 division would produce.

    Raises
    ------
    ValueError
        If ``cvi`` is empty (raised by the NumPy min/max reduction).
    """
    values = np.asarray(cvi, dtype=float)
    lowest = np.min(values)
    spread = np.max(values) - lowest

    # A constant offset cancels out in (x - min) / (max - min), so none is
    # added; the only degenerate case is a zero range, handled explicitly.
    if spread == 0:
        return np.zeros_like(values)
    return (values - lowest) / spread


def _plot_cvi_curves(cluster_counts, sh_index, dunn_index, combined, combined_label, ylabel, path, stem):
    """Plot the scaled silhouette/Dunn curves plus their combination and save as PNG and SVG."""
    fig, ax = plt.subplots()
    try:
        ax.plot(cluster_counts, sh_index, "--b^", label="silhouette index")
        ax.plot(cluster_counts, dunn_index, "--g*", label="dunn index")
        ax.plot(cluster_counts, combined, "--r*", label=combined_label)
        ax.set_xlabel("Number of clusters")
        ax.set_ylabel(ylabel)
        ax.legend()
        for extension in ("png", "svg"):
            fig.savefig(os.path.join(path, f"{stem}.{extension}"))
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


def cluster_evaluation(training_data, path, max_clusters=10):
    """Score candidate cluster counts with combined cluster-validity indices (CVIs).

    For every number of clusters from 2 to ``max_clusters`` (inclusive) a
    ``DTWCluster`` is evaluated through ``return_cvi`` and its "silhouette"
    and "dunn" values are recorded. Both series are min-max scaled and then
    combined in two ways: their arithmetic mean and the square root of their
    product. Each combination is plotted together with the two scaled indices.

    Parameters
    ----------
    training_data
        Dataset forwarded unchanged as ``X`` to ``DTWCluster.return_cvi``.
    path : str
        Directory in which ``ccvi.png``/``ccvi.svg`` (mean) and
        ``pcvi.png``/``pcvi.svg`` (root of product) are written.
    max_clusters : int, default 10
        Largest number of clusters to evaluate; must be at least 2.

    Raises
    ------
    ValueError
        If ``max_clusters`` is smaller than 2, since no candidate would be evaluated.
    """
    if max_clusters < 2:
        raise ValueError(f"max_clusters must be at least 2, got {max_clusters}")

    cluster_counts = list(range(2, max_clusters + 1))
    sh_index = []
    dunn_index = []

    for cl in tqdm(cluster_counts):
        res, _ = DTWCluster(cl, normalize=False).return_cvi(X=training_data)
        sh_index.append(res["silhouette"])
        dunn_index.append(res["dunn"])

    # Bring both indices onto a common scale before combining them.
    sh_index = min_max_scaling(sh_index)
    dunn_index = min_max_scaling(dunn_index)
    ccvi = (sh_index + dunn_index) / 2
    pcvi = np.sqrt(sh_index * dunn_index)

    # Figures are only created once all scores exist, so a failure during
    # clustering does not leave open figures behind.
    _plot_cvi_curves(
        cluster_counts, sh_index, dunn_index, ccvi,
        combined_label="Average CVI",
        ylabel="Average value of individual CVIs",
        path=path, stem="ccvi",
    )
    _plot_cvi_curves(
        cluster_counts, sh_index, dunn_index, pcvi,
        # This curve is sqrt(silhouette * dunn), i.e. the geometric mean.
        combined_label="Geometric mean CVI",
        ylabel="Square root of the individual CVIs product.",
        path=path, stem="pcvi",
    )

# Load the data

In [None]:
# All input and output locations are resolved relative to the parent of the
# current working directory.
parent_folder = Path(os.getcwd()).parent.absolute()
data_path = os.path.join(parent_folder, "Data/iCOMPEER")
sex_file = os.path.join(parent_folder, "Data/s64901_06JUN2023_all.xlsx")
save_data_path = os.path.join(parent_folder, "Results/dtw_kmedoids")

# Create the results directory (and missing parents) in one call; exist_ok
# avoids the race between checking for the directory and creating it.
os.makedirs(save_data_path, exist_ok=True)

# fetch_data's result is unpacked into the male and the female CPET tables.
male_cpet, female_cpet = fetch_data(data_path, sex_file)

# Pre-select data with clinical data

In [None]:
# Read the clinical spreadsheet through `sex_file` (the same path that was
# handed to fetch_data) instead of spelling the file name out a second time,
# so both steps always refer to the same file.
data = pd.read_excel(sex_file)
pids = data["record_id"].astype(str).tolist()

# Keep only the CPET rows whose patient ID appears among the clinical record IDs.
male_cpet = male_cpet[male_cpet["Patient IDs"].isin(pids)]
female_cpet = female_cpet[female_cpet["Patient IDs"].isin(pids)]

# Variable and sex selection

In [None]:
# CPET variables that make up each multivariate time series.
variables = ["HR", "V'O2", "RER", "PETO2", "PETCO2"]
# Which cohort to cluster: "males" or "females" (the keys of `cpet`).
sex = "males"
cpet = {"males": male_cpet, "females": female_cpet}

# Iterate over the recordings directly rather than by positional index.
# NOTE(review): each "CPET Data" entry is assumed to be a table that can be
# indexed by the list of column names in `variables` — confirm against fetch_data.
data = [recording[variables].to_numpy() for recording in tqdm(cpet[sex]["CPET Data"])]
patient_ids = list(cpet[sex]["Patient IDs"])

formatted_dataset = to_time_series_dataset(data)

# Find optimal number of clusters

In [None]:
# Evaluate every candidate cluster count from 2 to 10 and write the CVI plots
# into the results directory.
cluster_evaluation(
    training_data=formatted_dataset,
    path=save_data_path,
    max_clusters=10,
)

# Perform clustering

In [None]:
from collections import Counter

# Fit the DTW clusterer with four clusters on the prepared dataset, then call
# its `save` method with the results directory.
model = DTWCluster(
    n_clusters=4,
    normalize=False,
).fit(X=formatted_dataset)
model.save(path=save_data_path)

# One label per time series; print how many series fall into each cluster.
clusters = model.labels_
print(Counter(clusters))

# Visualize clustering results

In [None]:
# Pairs of variables passed to the plotting helper when visualising the
# clustering results.
combinations = [
    ("V'O2", "HR"),
    ("V'O2", "V'E"),
    ("V'O2", "V'CO2"),
    ("V'CO2", "V'E"),
    ("Time", "Load"),
    ("Time", "V'E"),
    ("Time", "PETO2"),
    ("Time", "PETCO2"),
    ("Time", "RER"),
]

# Mapping from raw cluster index to reported cluster number, chosen so that
# cluster 1 is the one with the most favourable profile.
cluster_labels = {
    0: 3,
    1: 1,
    2: 2,
    3: 4,
}

# Colours used for the clusters; key 0 corresponds to cluster 1 and is drawn in green.
# NOTE(review): keys 3 and 4 both map to "red" — confirm which keys
# visualize_clustering_results actually looks up.
cluster_colours = {
    0: "green",
    1: "blue",
    2: "darkorange",
    3: "red",
    4: "red",
}

# Write the cluster annotations to an .xlsx file, then read the "Cluster"
# column back from that file to obtain the renamed assignments.
write2excel(
    clusters,
    list(cpet[sex]["Patient IDs"]),
    sex,
    save_data_path,
    cluster_labels=cluster_labels,
)
renamed_clusters = pd.read_excel(
    os.path.join(save_data_path, f'Clustering_assignments_{sex}.xlsx')
)["Cluster"]

# Visualise the clustering results for the selected cohort.
visualize_clustering_results(
    cpet[sex],
    renamed_clusters,
    combinations,
    sex,
    str(save_data_path),
    cluster_colours=cluster_colours,
)