In [None]:
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy
import string
import seaborn as sns
from tqdm.auto import tqdm
tqdm.pandas()
from sklearn import metrics
from pyckmeans import MultiCKMeans
from scipy.optimize import linear_sum_assignment
from sklearn.metrics import confusion_matrix
from sklearn.cluster import AgglomerativeClustering
from scipy.cluster.hierarchy import dendrogram, fcluster, linkage

# try 3D t-SNE plot
import pylab
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d import proj3d
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from collections import Counter

# Restore a variable previously persisted with IPython's %store magic.
# NOTE(review): time_interval_before_ONSET is not referenced in the cells visible here.
%store -r time_interval_before_ONSET

import warnings
# Install a filter that applies the configured action to every FutureWarning.
warnings.simplefilter(action="...", category=FutureWarning)

In [None]:
# Resolution (dots per inch) passed as `dpi` to every plt.savefig call in this notebook.
figure_dpi = 300
# Persist the value with IPython's %store so it can be restored elsewhere via `%store -r`.
%store figure_dpi

In [None]:
# Integer cluster label (0-8) -> colour specification.
# Looked up by cluster label in plot_2D_tsne and plot_trajectory_cluster_trend.
color_key = {
    0: "...",
    1: "...",
    2: "...",
    3: "...",
    4: "...",
    5: "...",
    6: "...",
    7: "...",
    8: "..."
}

In [None]:
# Integer cluster label -> display name (legend entries, dendrogram annotations,
# cluster-size table). Only labels 0-2 are defined, so every helper that indexes
# this dict raises KeyError for any other label.
cls_label_mapping = {0: "...", 1: "...", 2: "..."}

# Read Data Processed in Notebook A

In [None]:
# Restore the SCr window definitions persisted with %store (per the heading above,
# produced in Notebook A). They are used below as lists of column names:
# `window_full + [...]` selects columns of full_data, and
# `window_before_and_onset + window_after_onset` is the x-axis of the trajectory plot.
%store -r window_before_and_onset
%store -r window_after_onset
%store -r window_full

In [None]:
# Read in the data table; the cluster assignment is later added to it as a new column.
full_data = pd.read_csv("...")

In [None]:
# Keep every row, restricted to the full-window columns plus one additional column.
full_SCR = full_data.loc[:, window_full + ["..."]]

In [None]:
# Display the selected frame as the cell output for a quick visual check.
full_SCR

# Cluster by Time-series Features Extracted from Raw Measurements

We focus on the SCr value at onset and on the 48h window before onset, so we extract the following four features:   
1. Onset value itself
2. The difference between onset SCr and 48h before
3. The difference between onset SCr and SCr baseline
4. The difference between the average SCr value of the first 4 days of the window and the baseline.

In [None]:
# Feature table: one row per row of full_SCR, one column per engineered feature.
# Filled with float zeros (not the integer 0) because the features assigned via
# `.loc[:, col] = ...` in the next cell are float-valued (one of them is a row mean).
# Writing floats into int64 columns triggers pandas' "incompatible dtype"
# FutureWarning and is slated to become an error.
ts_features = pd.DataFrame(0.0, index = full_SCR.index,
                           columns = ["...", "...",
                                      "...", "..."])

In [None]:
# Fill the four features listed in the markdown above.
# NOTE(review): column names are not visible here; the line-to-feature mapping below
# is inferred from the shape of each expression - confirm against the real names.
# A single SCr column copied as-is (presumably the onset value).
ts_features.loc[:, "..."] = full_SCR.loc[:, "..."]
# Two features that are each the difference between two columns of full_SCR.
ts_features.loc[:, "..."] = full_SCR.loc[:, "..."] - full_SCR.loc[:, "..."]
ts_features.loc[:, "..."] = full_SCR.loc[:, "..."] - full_SCR.loc[:, "..."]
# Row-wise mean of four columns minus one column.
ts_features.loc[:, "..."] = full_SCR.loc[:, ["...", "...", "...", "..."]].mean(axis = 1) - full_SCR.loc[:, "..."]

In [None]:
# normalized features
def min_max_norm(column):
    """Rescale ``column`` linearly so its minimum maps to 0 and its maximum to 1.

    Parameters
    ----------
    column : pandas.Series (or any object supporting ``min``/``max`` and arithmetic)
        Values to normalise.

    Returns
    -------
    Same type as ``column``
        ``(column - min) / (max - min)``. A constant column (``max == min``) is
        returned as all zeros instead of the NaN produced by dividing 0 by 0,
        so one degenerate feature does not poison downstream clustering.
    """
    col_min = column.min()
    value_range = column.max() - col_min
    # Only short-circuit for a scalar range; non-scalar inputs keep the plain formula.
    if np.ndim(value_range) == 0 and value_range == 0:
        return column - col_min
    return (column - col_min) / value_range

In [None]:
# Normalise each feature column independently (axis = 0 applies the function per column).
ts_features_norm = ts_features.apply(min_max_norm, axis = 0)

# Try Hierarchical Clustering

Fit the full hierarchical clustering tree to get the clustering structure of the features.

In [None]:
# Customise the link colours used when drawing hierarchical-clustering dendrograms
from scipy.cluster import hierarchy
import matplotlib as mpl
from matplotlib.colors import ListedColormap

# Base colours as RGB rows: green, red, blue
colors = np.array([[0, 1, 0],
                   [1, 0, 0],
                   [0, 0, 1]])
custom_cmap = ListedColormap(colors)

# Sample the colormap at the evenly spaced points 0, 0.5 and 1 so that each of the
# three listed colours is picked exactly once
indices = np.array([0, 1, 2])
colored_values = custom_cmap(indices / (len(indices) - 1))

# Register the sampled colours (converted to hex codes) as scipy's dendrogram link palette
hierarchy.set_link_color_palette(list(map(mpl.colors.rgb2hex, colored_values)))

In [None]:
# Function to plot dendrogram
def plot_dendrogram(ax, model, anno_line_pos, anno_line_test, color_threshold, **kwargs):
    """Draw the dendrogram of a fitted agglomerative clustering model on ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    model : fitted estimator exposing ``children_``, ``distances_`` and ``labels_``
        In this notebook an ``AgglomerativeClustering`` instance. ``distances_`` is
        read directly, so the model must have been fitted in a way that populates it.
    anno_line_pos : float or None
        Height (y) of the horizontal annotation line; ``None`` disables the annotation.
    anno_line_test : str
        Text written just above the annotation line; an empty value disables the
        annotation.
    color_threshold : float
        Forwarded to ``scipy.cluster.hierarchy.dendrogram``.
    **kwargs
        Forwarded to ``scipy.cluster.hierarchy.dendrogram`` (e.g. ``truncate_mode``, ``p``).

    Notes
    -----
    The three cluster-name annotations are placed at fixed data coordinates and use
    ``cls_label_mapping`` from the notebook namespace.
    NOTE(review): those coordinates look tuned to this notebook's data set - confirm
    before reusing the function on other data.
    """
    n_samples = len(model.labels_)

    # Number of original samples under each merge node. A child index below
    # n_samples is a leaf; otherwise it refers to an earlier merge whose count
    # has already been computed.
    counts = np.zeros(model.children_.shape[0])
    for i, merge in enumerate(model.children_):
        counts[i] = sum(
            1 if child_idx < n_samples else counts[child_idx - n_samples]
            for child_idx in merge
        )

    # scipy linkage layout: [child_a, child_b, merge distance, sample count]
    linkage_matrix = np.column_stack([model.children_, model.distances_,
                                      counts]).astype(float)

    ax.set_title("...", fontsize = 15)
    dendrogram(linkage_matrix, color_threshold = color_threshold, ax=ax, **kwargs)

    # Compare the position against None explicitly so a line at y = 0 is still drawn.
    if anno_line_pos is not None and anno_line_test:
        ax.axhline(y=anno_line_pos, color="...", linestyle="...")
        ax.text(20, anno_line_pos + 0.5, anno_line_test, verticalalignment="...")

    ax.set_xticks([])

    # Cluster names at fixed data coordinates (see Notes)
    ax.text(x=165, y=14, s=cls_label_mapping[0])
    ax.text(x=405, y=8, s=cls_label_mapping[2])
    ax.text(x=570, y=12, s=cls_label_mapping[1])

In [None]:
# Fit agglomerative clustering with no fixed cluster count and a distance threshold of 0.
# The fitted model's children_, distances_ and labels_ are consumed by plot_dendrogram.
hier_cluster = AgglomerativeClustering(n_clusters=None, 
                                       distance_threshold=0)
hier_cluster.fit(ts_features_norm)

In [None]:
# test plot function here
fig, ax = plt.subplots(1, 1)
plot_dendrogram(ax, hier_cluster, anno_line_pos = 16, 
                anno_line_test = "...", color_threshold = 16, truncate_mode="...", p=5)

# Clustering Visualization and Analysis Functions

In [None]:
def show_cluster_size(clusters_assign):
    """Print a table with the size and the percentage share of every named cluster.

    ``clusters_assign`` is an iterable of integer cluster labels with a length; each
    label is translated to its display name through ``cls_label_mapping`` (a label
    missing from that mapping raises ``KeyError``). Rows are ordered by display name.
    """
    named_assignments = pd.Series([cls_label_mapping[c] for c in clusters_assign])
    sizes = named_assignments.value_counts().sort_index()
    shares = 100 * sizes / len(clusters_assign)
    summary = pd.DataFrame({
        "...": sizes,
        "...": shares})
    print(summary)

In [None]:
def adjust_clustering_res(cluster_res_1, cluster_res_2):
    """Relabel ``cluster_res_2`` so its labels agree with ``cluster_res_1`` as much as possible.

    The two labelings are cross-tabulated and the one-to-one label matching that
    maximises the number of samples on which they agree is found with the
    Hungarian algorithm (``linear_sum_assignment`` on the negated counts).

    Parameters
    ----------
    cluster_res_1 : sequence of labels
        Reference labeling.
    cluster_res_2 : sequence of labels
        Labeling to be renamed; must have the same length as ``cluster_res_1``.

    Returns
    -------
    list
        ``cluster_res_2`` with every label replaced by its matched reference label.

    Raises
    ------
    AssertionError
        If the two labelings differ in length.
    """
    assert len(cluster_res_1) == len(cluster_res_2)

    # Fix the label order explicitly. Matrix row/column positions index into this
    # sorted label set, so they must be translated back to label values; using the
    # positions directly is only correct when the labels happen to be 0..k-1.
    labels = np.unique(np.concatenate([np.asarray(cluster_res_1),
                                       np.asarray(cluster_res_2)]))
    conf_mat = confusion_matrix(cluster_res_1, cluster_res_2, labels=labels)

    # Negate the counts: linear_sum_assignment minimises cost, we want maximum overlap.
    row_ind, col_ind = linear_sum_assignment(-conf_mat)
    mapping = {labels[col]: labels[row] for row, col in zip(row_ind, col_ind)}

    cluster_res_2_new = [mapping[label] for label in cluster_res_2]
    return cluster_res_2_new

In [None]:
def create_confusion_matrix(labels1, labels2):
    """Cross-tabulate two labelings as row-wise percentages.

    Parameters
    ----------
    labels1, labels2 : sequences of non-negative integers of equal length
        Cluster labels of the same samples under two labelings. Labels are used
        directly as row (``labels1``) and column (``labels2``) indices.

    Returns
    -------
    numpy.ndarray of shape ``(max(labels1) + 1, max(labels2) + 1)``
        Entry ``[i, j]`` is the percentage of samples labelled ``i`` in
        ``labels1`` that are labelled ``j`` in ``labels2``. A row whose label
        never occurs in ``labels1`` is all zeros (rather than NaN from 0/0).

    Raises
    ------
    AssertionError
        If the two labelings differ in length.
    ValueError
        If the labelings are empty.
    """
    assert len(labels1) == len(labels2), "..."
    rows = np.asarray(labels1, dtype=int)
    cols = np.asarray(labels2, dtype=int)

    counts = np.zeros((rows.max() + 1, cols.max() + 1))
    # Unbuffered accumulation so repeated (row, col) pairs are each counted.
    np.add.at(counts, (rows, cols), 1)

    # Normalize by rows to get percentages; skip rows without any sample.
    row_totals = counts.sum(axis=1, keepdims=True)
    fractions = np.divide(counts, row_totals,
                          out=np.zeros_like(counts), where=row_totals > 0)
    return fractions * 100

In [None]:
def plot_2D_tsne(ax, data, labels, title, random_state=None):
    """Scatter a 2-D t-SNE embedding of ``data`` on ``ax``, coloured by cluster label.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    data : array-like
        Input passed to ``TSNE.fit_transform``.
    labels : sequence of int
        One cluster label per row of ``data``. Each label must be a key of both
        ``color_key`` and ``cls_label_mapping`` (notebook-level dicts).
    title : str
        Axes title.
    random_state : int, RandomState instance or None, default None
        Forwarded to ``TSNE``. The default keeps the embedding unseeded; pass an
        integer to make the figure reproducible across runs.
    """
    embedding = TSNE(n_components=2, random_state=random_state).fit_transform(data)

    # Coerce to an array so `labels == label` is an element-wise mask whatever
    # container (list, Series, ndarray) the caller passed.
    labels = np.asarray(labels)

    # One scatter call per cluster so each gets its own colour and legend entry
    for label in np.unique(labels):
        members = embedding[labels == label]
        ax.scatter(members[:, 0], members[:, 1], c=color_key[label], label=cls_label_mapping[label], s=1.0)

    # Add a legend
    ax.legend()

    # Set labels and title
    ax.set_xlabel("...")
    ax.set_ylabel("...")
    ax.set_title(title, fontsize = 15)

In [None]:
def switch_labels_according_to_freq(clusters_assign):
    """Renumber cluster labels by descending cluster size.

    The most frequent label becomes 0, the next most frequent 1, and so on; labels
    with equal counts keep the order in which they first appear. Returns a new list
    with one relabelled entry per input entry.
    """
    by_size = Counter(clusters_assign).most_common()
    rank_of = {old_label: rank for rank, (old_label, _) in enumerate(by_size)}
    return [rank_of[old_label] for old_label in clusters_assign]

In [None]:
def plot_consensus_heatmap(ax, ckm_res_best_k_mtx, clusters_assign, cluster_n):
    """Draw a heatmap of a random 10% subsample of a consensus matrix, grouped by cluster.

    ``ckm_res_best_k_mtx`` is the square consensus matrix and ``clusters_assign`` the
    cluster label of each of its rows/columns. Rows and columns are ordered by cluster
    label before sampling. ``cluster_n`` is only used in the title. The subsample is
    drawn with ``np.random.choice``, i.e. from NumPy's global random state.
    """
    labelled = pd.DataFrame(ckm_res_best_k_mtx,
                            index=clusters_assign, columns=clusters_assign)
    # Order rows, then columns, by cluster label so members of a cluster are contiguous
    grouped = labelled.sort_index(axis=0).sort_index(axis=1)
    # Replace the (repeated) cluster labels by positional 0..n-1 labels
    positional = pd.DataFrame(grouped.values)

    # Keep a sorted random 10% of the positions, the same ones on both axes
    n_keep = int(0.1 * positional.shape[0])
    kept = sorted(np.random.choice(positional.index, size=n_keep, replace=False))
    subsample = positional.loc[kept, kept]

    sns.heatmap(subsample, annot = False, ax=ax)
    ax.set_title("..." + str(cluster_n), fontsize = 20)
    ax.set_xticks([])
    ax.set_yticks([])

In [None]:
def get_consensus_mtx_and_cluster_at_n(cluster_n, mckm_results):
    """Return ``(consensus matrix, cluster labels)`` for the run with ``cluster_n`` clusters.

    The labels are renumbered by descending cluster size. The result for
    ``cluster_n`` clusters is read from position ``cluster_n - 2`` of
    ``mckm_results.ckmeans_results``, which matches the list of k values built in
    this notebook (it starts at 2).
    """
    result_for_k = mckm_results.ckmeans_results[cluster_n - 2]
    return result_for_k.cmatrix, switch_labels_according_to_freq(result_for_k.cl)

# Apply Consensus K-Means

In [None]:
# Largest number of clusters tried by consensus K-Means (inclusive; the search starts at 2).
max_n_cluster_to_explore = 5

In [None]:
# Consensus K-Means for every k from 2 up to and including max_n_cluster_to_explore,
# fitted and evaluated on the normalised features.
mckm = MultiCKMeans(
    k=list(range(2, max_n_cluster_to_explore + 1)),
    n_rep=100,
    p_samp=0.8,
    p_feat=0.75,
)
mckm.fit(ts_features_norm)
mckm_results = mckm.predict(ts_features_norm)

In [None]:
# Plot the metrics of the consensus K-Means results for the explored cluster counts.
mckm_results.plot_metrics(figsize=(10,5))

In [None]:
# Number of clusters used for the rest of the analysis (set by hand).
# cls_label_mapping defines exactly three names, so the helpers that use it assume 3 clusters.
best_cluster_n = 3

In [None]:
# Consensus K-Means result for the chosen cluster count; results are stored by
# position and the explored k values start at 2, hence the offset.
results_at_n = mckm_results.ckmeans_results[best_cluster_n - 2]

# Renumber the labels by descending cluster size, then report the cluster sizes.
clusters_at_n = switch_labels_according_to_freq(results_at_n.cl)
show_cluster_size(clusters_at_n)

In [None]:
# Attach the size-ordered consensus cluster label of every row to the full data table.
full_data["..."] = clusters_at_n

In [None]:
# Write the table, now including the cluster column, to CSV without the index.
full_data.to_csv("...", index = False)

In [None]:
fig, axs = plt.subplots(2,2, figsize = (12, 10))

# One consensus heatmap per cluster count (2 to 5), filling the 2x2 grid row by row
for heatmap_ax, n_clusters_shown in zip(axs.flat, range(2, 6)):
    consensus_mtx_at_n, clusters_at_n_heatmap = \
    get_consensus_mtx_and_cluster_at_n(n_clusters_shown, mckm_results)
    plot_consensus_heatmap(heatmap_ax, consensus_mtx_at_n, clusters_at_n_heatmap, n_clusters_shown)

# Adding sequential lowercase labels (a, b, c, d) to each subplot
labels = list(string.ascii_lowercase)
positions = [(0, 0), (0, 1), (1, 0), (1, 1)]

for label, pos in zip(labels, positions):
    panel = axs[pos]
    panel.text(-0.1, 1.1, label, transform=panel.transAxes,
               fontsize=16, fontweight="...", va="...", ha="...")

plt.tight_layout()
plt.savefig("...", format="...", dpi = figure_dpi)
plt.show()

In [None]:
# Hierarchical clustering cut at the same number of clusters as the consensus result.
hier_at_n = AgglomerativeClustering(n_clusters = best_cluster_n).fit(ts_features_norm)
hier_at_n_clusters = hier_at_n.labels_

In [None]:
# Rename the hierarchical labels so they line up with the consensus K-Means labels.
# Note: this rebinds hier_at_n_clusters, so the cell operates on its own previous output when re-run.
hier_at_n_clusters = adjust_clustering_res(clusters_at_n, hier_at_n_clusters)

In [None]:
def plot_two_alg_overlap(ax, clustering_1, clustering_2):
    """Draw the row-normalised overlap (in percent) of two clusterings as an annotated heatmap.

    Rows correspond to ``clustering_1`` and columns to ``clustering_2``; the matrix
    comes from ``create_confusion_matrix``. Exactly three tick labels are set on
    each axis, so both clusterings are expected to contain three clusters.
    """
    overlap_pct = create_confusion_matrix(clustering_1, clustering_2)

    sns.heatmap(
        overlap_pct,
        cmap="...",
        ax = ax,
        annot = True,
        cbar_kws={"...": "..."},
        annot_kws={"...": 20},
        fmt="...",
    )

    ax.set_xlabel("...")
    ax.set_ylabel("...")
    ax.set_title("...", fontsize = 15)

    # Replace the numeric ticks with custom labels on both axes
    ax.set_xticklabels(["...", "...", "..."], fontsize=12)
    ax.set_yticklabels(["...", "...", "..."], fontsize=12)

In [None]:
# 2x2 comparison figure: dendrogram (top left), overlap of the two clusterings
# (top right), and one t-SNE embedding per clustering (bottom row).
# Each plot_2D_tsne call fits its own t-SNE model, so the two embeddings are independent.
fig, axs = plt.subplots(2,2, figsize = (12, 10))
plot_dendrogram(axs[0, 0], hier_cluster, anno_line_pos = 16, 
                anno_line_test = "...", color_threshold = 16, truncate_mode="...", p=5)
plot_two_alg_overlap(axs[0, 1], hier_at_n_clusters, clusters_at_n)
plot_2D_tsne(axs[1, 0], ts_features_norm, hier_at_n_clusters, "...")
plot_2D_tsne(axs[1, 1], ts_features_norm, clusters_at_n, "...")

# Adding sequential lowercase labels (a, b, c, d) to each subplot
labels = list(string.ascii_lowercase)
positions = [(0, 0), (0, 1), (1, 0), (1, 1)]

for label, pos in zip(labels, positions):
    axs[pos].text(-0.1, 1.1, label, transform=axs[pos].transAxes, 
                  fontsize=16, fontweight="...", va="...", ha="...")

plt.tight_layout()
plt.savefig("...", format="...", dpi = figure_dpi)
plt.show()

In [None]:
def plot_trajectory_cluster_trend(pat_info, cluster_col_name, SCR_window, title,
                                  color_key, figure_dpi):
    """Plot the median trajectory of every cluster with a 1st-99th percentile band.

    Parameters
    ----------
    pat_info : pandas.DataFrame
        Table holding the trajectory columns and the cluster column.
    cluster_col_name : str
        Column with integer cluster labels; clusters ``0 .. max`` are plotted.
    SCR_window : list of str
        Trajectory columns, in plotting order; also used as the x positions.
    title : str
        Figure title.
    color_key : mapping
        Cluster label -> colour.
    figure_dpi : int
        Resolution of the saved figure.

    Legend names come from the notebook-level ``cls_label_mapping``. Percentiles
    ignore NaN values. The annotations are placed at fixed coordinates; the figure
    is saved to a fixed path and then shown.
    """
    n_clusters = np.max(pat_info[cluster_col_name]) + 1

    plt.figure(figsize = (10, 5))

    for cluster_id in range(n_clusters):
        in_cluster = pat_info[cluster_col_name] == cluster_id
        values = pat_info[in_cluster][SCR_window].values

        # Column-wise percentiles over the cluster members, NaNs ignored
        median = np.nanpercentile(values, 50, axis = 0)
        lower = np.nanpercentile(values, 1, axis = 0)
        upper = np.nanpercentile(values, 99, axis = 0)

        median_line, = plt.plot(SCR_window, median, "...",
                                label=cls_label_mapping[cluster_id],
                                color = color_key[cluster_id])
        # Shade the percentile band in the colour of the median line
        plt.fill_between(SCR_window, lower, upper,
                         color=median_line.get_color(), alpha=0.15, linewidth=2)

    # Text placed at 5/6 of the current y-range, next to a vertical marker line
    y_min, y_max = plt.ylim()
    y_annotation = y_min + (y_max - y_min) * 5/6
    plt.text(6.85, y_annotation, "...", fontsize=12,
             verticalalignment="...", horizontalalignment="...")
    plt.axvline(x = "...", color="...", linestyle="...", linewidth=1)

    # Horizontal arrow at y = 2.8 with its caption
    plt.annotate("...", xy=(6, 2.8), xytext=(0, 2.8),
                 arrowprops=dict(arrowstyle="...", linestyle="...", color="..."))
    plt.text(3, 2.8, "...", fontsize=12,
             verticalalignment="...", horizontalalignment="...")

    # First vertical arrow with its caption
    plt.annotate("...", xy=(0.8, 2.00), xytext=(0.8, 1.00),
                 arrowprops=dict(arrowstyle="...", linestyle="...", color="..."))
    plt.text(1.4, 1.4, "...", fontsize=12,
             verticalalignment="...", horizontalalignment="...")

    # Second vertical arrow with its caption
    plt.annotate("...", xy=(1.2, 1.30), xytext=(1.2, 0.6),
                 arrowprops=dict(arrowstyle="...", linestyle="...", color="..."))
    plt.text(1.8, 0.8, "...", fontsize=12,
             verticalalignment="...", horizontalalignment="...")

    plt.title(title)
    plt.xlabel("...")
    plt.ylabel("...")
    plt.legend(loc = "...")
    plt.tight_layout()
    plt.savefig("...", format="...", dpi = figure_dpi)
    plt.show()

In [None]:
# Plot per-cluster trajectories over the whole window (before/onset columns followed by
# the after-onset columns), grouped by a cluster column of full_data.
plot_trajectory_cluster_trend(full_data, "...", window_before_and_onset + window_after_onset, 
                              "...", color_key, figure_dpi)