In [1]:
cd ..

# Multi-omics stratification on PDAC patients

In [3]:
import os

import dill
import matplotlib.pyplot as plt
import numpy as np
import optuna
import pandas as pd
import plotly.express as px
import seaborn as sns
import torch
from lifelines import KaplanMeierFitter, CoxPHFitter
from lifelines.plotting import add_at_risk_counts
from lifelines.statistics import multivariate_logrank_test
from pytorch_lightning.utilities.seed import isolate_rng
from scipy.stats import chi2_contingency, kruskal, wilcoxon
from sklearn.cluster import AgglomerativeClustering, KMeans
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader
from umap import UMAP

from src import settings
from utils import MultiViewDataset, transform_full_dataset
from optimization import Optimization

## Load dataset

In [4]:
# Clinical table: tab-separated file, re-keyed by the "Patient ID" column.
clinical_data = (
    pd.read_csv(settings.clinical_data_path, sep="\t")
    .set_index("Patient ID")
    .rename_axis(None)  # clear the index name
)
print("clinical_data.shape", clinical_data.shape)
clinical_data.head()

In [5]:
# Methylation matrix: ';'-separated, ',' as decimal mark, first column used as the row index.
methylation_data = pd.read_csv(settings.methylation_data_path, sep=";", index_col=0, decimal=",")
# Replace '.' with '-' in the column labels (presumably so sample IDs match the other tables — confirm).
# regex=False pins the '.' as a literal character: the default of `regex` differs between pandas
# versions, and as a regular expression '.' would match every character.
methylation_data.columns = methylation_data.columns.str.replace(".", "-", regex=False)
# Transpose so that the original columns become rows, and store values as float32.
methylation_data = methylation_data.T
methylation_data = methylation_data.astype(np.float32)
print("methylation_data.shape", methylation_data.shape)
methylation_data.head()

In [8]:
# RNA-seq matrix: same file layout as the methylation table (';' separator, ',' decimal mark).
# Transposed after loading and stored as float32.
rnaseq_data = (
    pd.read_csv(settings.rnaseq_data_path, sep=";", index_col=0, decimal=",")
    .transpose()
    .astype(np.float32)
)
print("rnaseq_data.shape", rnaseq_data.shape)
rnaseq_data.head()

In [9]:
# Keep only the samples present in both omics views AND in the clinical table.
# Intersecting with the clinical index as well avoids a KeyError from `clinical_data.loc[samples]`
# when an omics sample has no clinical record; when every sample has one, `samples` is unchanged.
omics_samples = methylation_data.index.intersection(rnaseq_data.index)
samples = omics_samples.intersection(clinical_data.index)
if len(samples) < len(omics_samples):
    print("omics samples without clinical data (dropped):", len(omics_samples) - len(samples))

# Re-index all three tables with the same labels, in the same order.
methylation_data = methylation_data.loc[samples]
rnaseq_data = rnaseq_data.loc[samples]
assert methylation_data.index.equals(rnaseq_data.index)
clinical_data = clinical_data.loc[samples]
assert methylation_data.index.equals(clinical_data.index)

# View order matters downstream: index 0 is RNA-seq, index 1 is methylation.
Xs= [rnaseq_data, methylation_data]
print("common samples:", len(samples))

In [10]:
# Timestamp suffix identifying which optimization run to load.
date = "2023070315"
study_path = os.path.join(settings.optimization_path, f'optimization_optuna_{date}.pkl')
# NOTE: dill.load can execute arbitrary code while unpickling — only load files you trust.
with open(study_path, 'rb') as study_file:
    optimization_study = dill.load(study_file)

In [11]:
# Tabular export of the optimization trials for the same run (`date`).
results_csv = os.path.join(settings.optimization_path, f"optimization_results_{date}.csv")
optimization_results = pd.read_csv(results_csv)
# The first row is taken as the best trial.
# NOTE(review): this presumes the CSV was written sorted best-first — confirm.
best_trial = optimization_results.iloc[0]
print("optimization_results.shape", optimization_results.shape)
optimization_results.head()

In [70]:
# Hyperparameter importances computed by Optuna from the loaded study;
# plotted as a bar chart in the optimization-history figure below.
param_importances = optuna.importance.get_param_importances(optimization_study)

In [27]:
# Best-so-far objective value after each trial, walking the trials in trial-number order.
# The running maximum starts from the overall minimum so the first trial always sets it.
max_value = optimization_results["value"].min()
max_values = []
values_by_trial = optimization_results.sort_values(by= "number")["value"]
for trial_value in values_by_trial:
    max_value = trial_value if trial_value > max_value else max_value
    max_values.append(max_value)

In [79]:
# Two-panel summary of the hyperparameter search.
# (The study is already loaded in an earlier cell, so it is not re-read from disk here.)
fig, axes = plt.subplots(1,2, figsize= (20, 3))

# Left panel: objective value of every trial plus the best-so-far curve in red.
ax= axes[0]
optimization_results.plot.scatter(x='number', y='value', ax= ax, s= 3, xlabel= "Trial", ylabel= "Silhouette Score")
# `max_values` was built in trial-number order; index it by the sorted trial numbers so the line
# stays aligned with the scatter's x-axis even if trial numbers are not a contiguous 0..n-1 range.
trial_numbers = optimization_results["number"].sort_values().to_numpy()
pd.Series(max_values, index= trial_numbers).plot(ax= ax, c= "red")

# Right panel: hyperparameter importances.
# Tick labels are derived from the parameter names rather than a fixed list, so each bar keeps
# its own label whatever order (or number of parameters) Optuna returns.
param_labels = {
    "n_clusters": "Number of clusters",
    "n_epochs": "Number of epochs",
    "latent_space": "Latent space",
    "divisor_units": "Inverse proportion of units",
    "lambda_coeff": "Lambda coefficient",
    "num_layers": "Number of layers",
    "features_per_component": "Features per component",
}
ax= axes[1]
importances = pd.Series(param_importances).rename(index= lambda name: param_labels.get(name, name))
# NOTE(review): the original `rot` value was unreadable in the source; 45 degrees is assumed.
_ = importances.plot.bar(ax= ax, xlabel= "Hyperparameter", ylabel= "Importance for Silhouette Score", rot= 45)

In [19]:
# One scatter per hyperparameter (x) against the trial's silhouette score (y), coloured by trial number.
fig, axes = plt.subplots(1,7, sharey=True, figsize= (20, 3), gridspec_kw = {"wspace": 0.1})

# Per-column plotting options; anything not listed falls back to the shared defaults below.
scatter_specs = {
    "params_n_clusters": dict(xlabel= "Number of clusters", xticks= list(range(2,7))),
    "params_n_epochs": dict(xlabel= "Number of epochs", xticks= [20,40,60,80,100]),
    "params_latent_space": dict(xlabel= "Latent space"),
    "params_divisor_units": dict(xlabel= "Inverse proportion of units"),
    "params_lambda_coeff": dict(xlabel= "Lambda coefficient", logx= True, xticks= [0.001, 0.01, 0.1, 1]),
    "params_num_layers": dict(xlabel= "Number of layers", xticks= [1,2]),
    # Only the last panel draws the colour bar.
    "params_features_per_component": dict(xlabel= "Features per component", xticks= [1,4,7,10], colorbar= True),
}

for i, (col, spec) in enumerate(scatter_specs.items()):
    ax = axes[i]
    options = {"colorbar": False, **spec}
    optimization_results.plot.scatter(x= col, y= "value", s= 7, ax= ax, ylabel= "Silhouette score",
                                      c= "number", cmap= 'viridis', **options)
    if col == "params_features_per_component":
        ax.collections[-1].colorbar.set_label("Trial")
    elif col == "params_num_layers":
        # Pad the x-range around the two tick values.
        ax.set_xlim(0.5,2.5)

# The y-axis is shared, so setting the limits on the last axes applies to every panel.
_ = ax.set_ylim(0,0.31)

In [504]:
# Box plots of the per-split metric lists stored for the first row of `optimization_results`,
# one panel per metric, with Wilcoxon signed-rank p-values between neighbouring splits
# (training vs validation, validation vs testing) and a red dashed reference line.
# The four panels share all their logic, so they are driven by one spec table.
fig, axes = plt.subplots(1, 4, figsize= (20, 4))
metric_index = ["training", "validation", "testing"]

# Per-panel settings:
#   metric/suffix -> column stem, e.g. "user_attrs_train_au_loss_list" or "user_attrs_train_silhscore_list"
#   bar_shift     -> manual vertical offset applied to the significance brackets
#   ylim          -> optional fixed y-limits applied after the brackets are drawn
#   reference     -> value of the red dashed line, computed from the given test-split column
panel_specs = [
    dict(metric= "au", suffix= "_loss", ylabel= 'MAE', bar_shift= 0, ylim= (0, 1),
         reference= lambda col: optimization_results[col].max()),
    dict(metric= "dist", suffix= "_loss", ylabel= 'Inertia', bar_shift= 0, ylim= None,
         reference= lambda col: optimization_results[optimization_results["params_n_clusters"] == 2][col].mean()),
    dict(metric= "total", suffix= "_loss", ylabel= 'Joint loss function', bar_shift= -1, ylim= None,
         reference= lambda col: optimization_results[optimization_results["params_lambda_coeff"].between(0.012, 0.016)][col].max()),
    dict(metric= "silhscore", suffix= "", ylabel= 'Silhouette score', bar_shift= 0, ylim= (0, 1),
         reference= lambda col: optimization_results[col].mean()),
]

for ax, spec in zip(axes, panel_specs):
    metric = spec["metric"]
    stem = f"{metric}{spec['suffix']}"
    list_columns = [f"user_attrs_{split}_{stem}_list" for split in ("train", "val", "test")]
    # NOTE(review): `eval` parses the stringified lists read from the results CSV. It executes
    # arbitrary code, so only use this on trusted files (consider ast.literal_eval).
    metric_result = pd.DataFrame(optimization_results.iloc[0][list_columns].apply(eval).to_list(),
                                 index= metric_index).T
    metric_result.boxplot(grid=False, showmeans= True, ax= ax)

    # Place the significance brackets just above the automatic y-range of the boxes.
    bottom, top = ax.get_ylim()
    y_range = top - bottom
    bar_height = (y_range * 0.07 * 3) + top + spec["bar_shift"]
    bar_tips = bar_height - (y_range * 0.02)
    text_height = bar_height + (y_range * 0.01)
    for idx,met in enumerate(metric_index[:2]):
        # Paired test between split `idx` and the next one; boxes sit at x = 1, 2, 3.
        pv = wilcoxon(metric_result[met], metric_result[metric_result.columns[idx + 1]]).pvalue
        ax.plot([idx+1, idx+1, idx+2, idx+2], [bar_tips, bar_height, bar_height, bar_tips], lw=1, c='k')
        ax.text((idx + idx + 3) * 0.5, text_height, 'p={:.3g}'.format(pv), ha='center', c='k')

    if spec["ylim"] is not None:
        ax.set_ylim(*spec["ylim"])
    ax.axhline(spec["reference"](f"user_attrs_test_{stem}"), c= "r", linestyle= '--')
    ax.set_ylabel(spec["ylabel"])

In [10]:
# Apply the preprocessing pipelines to both views.
# The commented-out call below is the variant that fits the pipelines (fit_pipelines=True);
# the active call passes fit_pipelines=False.
# NOTE(review): presumably the active call reuses pipelines previously saved under
# settings.results_path — confirm against transform_full_dataset.
# transformed_Xs = transform_full_dataset(Xs=Xs, fit_pipelines = True, results_folder = settings.results_path, 
#                                         features_per_component = optimization_study.best_params["features_per_component"],
#                                         optimization_folder = settings.optimization_path)
transformed_Xs = transform_full_dataset(Xs=Xs, fit_pipelines = False, results_folder = settings.results_path)
# Show the shape and first rows of every transformed view.
for transformed_X in transformed_Xs:
    print("transformed_X.shape", transformed_X.shape)
    display(transformed_X.head())

In [11]:
# Train on the full dataset with the best hyperparameters found by the study.
# The same loader is passed as train, validation and test loader.
batch_size = int(best_trial["user_attrs_batch_size"])
full_data = MultiViewDataset(Xs=transformed_Xs)
full_dataloader = DataLoader(dataset=full_data, batch_size=batch_size, shuffle=True)

best_params = optimization_study.best_params
# NOTE(review): `eval` parses stringified lists read from the results CSV; it executes arbitrary
# code, so only use it on trusted files.
training_kwargs = dict(
    n_clusters = best_params["n_clusters"],
    latent_space = best_params["latent_space"],
    in_channels_list = eval(best_trial["user_attrs_num_features"]),
    # The first entry of each view's unit list is dropped.
    hidden_channels_list = [view_hidden[1:] for view_hidden in eval(best_trial["user_attrs_num_units"])],
    train_dataloader = full_dataloader,
    val_dataloader = full_dataloader,
    test_dataloader = full_dataloader,
    n_epochs = best_params["n_epochs"],
    # Number of batches per epoch, i.e. log once per epoch.
    log_every_n_steps = np.ceil(len(full_data) / batch_size).astype(int),
    lambda_coeff = best_params["lambda_coeff"],
)
# NOTE(review): no explicit seed is set in this cell; unless Optimization.training seeds
# internally, re-runs may produce different clusters — confirm.
with isolate_rng():
    total_results = Optimization().training(**training_kwargs)

In [12]:
# Persist the trained model twice: as a state dict and as the full pickled module.
model = total_results["model"]
models_dir = os.path.join("outputs", "models")
# Create the output folder if needed; torch.save does not create missing directories.
os.makedirs(models_dir, exist_ok=True)
model_path = os.path.join(models_dir, "model_state_dict.pt")
torch.save(model.state_dict(), model_path)
model_path = os.path.join(models_dir, "model.pt")
torch.save(model, model_path)
# Switch to evaluation mode for the inference cells that follow.
model = model.eval()

In [13]:
# Rebuild the dataset with rows ordered like `clinical_data` and an unshuffled loader, so the
# predicted labels can be assigned to `clinical_data` positionally.
ordered_views = [X_.loc[clinical_data.index] for X_ in transformed_Xs]
full_data = MultiViewDataset(Xs=ordered_views)
full_dataloader = DataLoader(dataset=full_data, batch_size=batch_size, shuffle=False)
with torch.no_grad():
    # Encode batch by batch, then stack into a single embedding tensor.
    batch_embeddings = []
    for batch in full_dataloader:
        batch_embeddings.append(model.autoencoder.encode(batch))
    z_full = torch.vstack(batch_embeddings)
    clusters = model.predict_cluster_from_embedding(z_full).detach().cpu().numpy()
z_full = pd.DataFrame(z_full)
clinical_data["clusters"] = clusters

# Bar chart of cluster sizes with the count printed on each bar.
cluster_counts = clinical_data["clusters"].value_counts()
ax = cluster_counts.plot(kind="bar", title= "Count of samples in clusters", x= "clusters", ylabel= "Number of samples")
for container in ax.containers:
    ax.bar_label(container)

In [344]:
# Empty summary table: one row per clinical variable, one column per clustering method.
# The cells below fill it with p-values.
comparison_methods = ["Deep clustering", "K-Means", "HC"]
clinical_endpoints = ["Overall Survival", "Diagnosis age", "AJCC tumor stage",
                      "AJCC metastasis stage", "AJCC neoplasm histologic grade", "Sex"]
clustering_statistical_table = pd.DataFrame([], columns= comparison_methods, index= clinical_endpoints)
clustering_statistical_table

In [384]:
# Log-rank test of overall survival across the deep-clustering groups.
# NOTE(review): no `event_observed` is passed, so lifelines treats every duration as an observed
# event (no censoring). Confirm this is intended.
survival_months = clinical_data["Overall Survival (Months)"]
cluster_groups = clinical_data["clusters"]
logrank_test = multivariate_logrank_test(event_durations= survival_months, groups= cluster_groups)
logrank_test

In [480]:
# Kaplan-Meier curves per deep-clustering group. Uses `logrank_test` from the previous cell, so
# that cell must have been run for the deep-clustering labels immediately before this one.
ax = plt.subplot(111)
ax.text(0, 0, f"p= {round(logrank_test.p_value, 3)}")
clustering_statistical_table.loc["Overall Survival", "Deep clustering"] = logrank_test.p_value
colors = ['y', 'm']  # not used by the plotting calls in this cell

survival_months = clinical_data["Overall Survival (Months)"]
kmfs = []
for cluster in sorted(clinical_data["clusters"].unique()):
    in_cluster = clinical_data["clusters"] == cluster
    # NOTE(review): no `event_observed` is given, so every duration is treated as an observed event.
    kmf = KaplanMeierFitter().fit(survival_months[in_cluster], label = str(cluster))
    kmfs.append(kmf)
    ax = kmf.plot(ax=ax)

add_at_risk_counts(*kmfs, ax=ax)
ax.set_ylabel("Survival probability")
ax.set_xlabel("Timeline (months)")
# Relabel the legend entries by their position in the legend.
legend_texts = ax.legend().get_texts()
for position in range(len(legend_texts)):
    legend_texts[position].set_text(f'Cluster {position}')
plt.tight_layout()

In [271]:
# Cox proportional-hazards model with the cluster label as the only covariate.
cph = CoxPHFitter()
cox_input = clinical_data[['Overall Survival (Months)', "clusters"]]
# Swap the labels through a temporary value 2: original cluster 0 is coded 1 and cluster 1 is coded 0.
for old_label, new_label in [(0, 2), (1, 0), (2, 1)]:
    cox_input = cox_input.replace({"clusters": old_label}, new_label)
# NOTE(review): no `event_col` is given, so lifelines assumes every subject is uncensored.
cph.fit(cox_input, duration_col = 'Overall Survival (Months)').print_summary()
_ = cph.plot()

In [352]:
# Association between the AJCC tumor stage code and the deep-clustering assignment.
clinical_label = "AJN Tumor stage"
stage_codes = clinical_data["American Joint Committee on Cancer Tumor Stage Code"]
# The stage is the last character of the code, parsed as a number; missing entries are dropped.
clinical_parameter = stage_codes.str[-1].astype(float).dropna().astype(int).rename(clinical_label)

fig, axes = plt.subplots(1, 2, figsize= (12, 4))
heatmap_ax, count_ax = axes

# Left: share of each stage within every cluster (each column sums to 100).
crosstab = pd.crosstab(clinical_parameter, clinical_data['clusters'])
column_percent = crosstab * 100 / crosstab.sum(0)
sns.heatmap(column_percent, annot=True, fmt=".1f", ax= heatmap_ax)

# Chi-squared test of independence on the raw counts.
pval = chi2_contingency(crosstab).pvalue
clustering_statistical_table.loc["AJCC tumor stage", "Deep clustering"] = pval
heatmap_ax.set_title(f"p-value= {round(pval, 3)}", fontsize= 10)

# Right: absolute counts per stage, split by cluster, with the count on each bar.
clinical_parameter = clinical_parameter.to_frame().assign(clusters= clinical_data['clusters'])
sns.countplot(data= clinical_parameter, x= clinical_label, hue= "clusters", ax= count_ax)
for container in count_ax.containers:
    count_ax.bar_label(container)

In [353]:
# Association between the neoplasm histologic grade and the deep-clustering assignment.
clinical_label = "Neoplasm Histologic Grade"
# The grade is the last character of the value, parsed as a number; missing entries are dropped.
grade_digits = clinical_data[clinical_label].str[-1].astype(float)
clinical_parameter = grade_digits.dropna().astype(int)

fig, axes = plt.subplots(1, 2, figsize= (12, 4))
heatmap_ax, count_ax = axes

# Left: share of each grade within every cluster (each column sums to 100).
crosstab = pd.crosstab(clinical_parameter, clinical_data['clusters'])
column_percent = crosstab * 100 / crosstab.sum(0)
sns.heatmap(column_percent, annot=True, fmt=".1f", ax= heatmap_ax)

# Chi-squared test of independence on the raw counts.
pval = chi2_contingency(crosstab).pvalue
clustering_statistical_table.loc["AJCC neoplasm histologic grade", "Deep clustering"] = pval
heatmap_ax.set_title(f"p-value= {round(pval, 3)}", fontsize= 10)

# Right: absolute counts per grade, split by cluster, with the count on each bar.
clinical_parameter = clinical_parameter.to_frame().assign(clusters= clinical_data['clusters'])
sns.countplot(data= clinical_parameter, x= clinical_label, hue= "clusters", ax= count_ax)
for container in count_ax.containers:
    count_ax.bar_label(container)

In [356]:
# Association between sex and the deep-clustering assignment.
clinical_label = "Sex"
clinical_parameter = clinical_data[clinical_label]
fig, axes = plt.subplots(1, 2, figsize= (12, 4))

# Left: share of each sex within every cluster (each column sums to 100).
crosstab = pd.crosstab(clinical_parameter, clinical_data['clusters'])
sns.heatmap(crosstab * 100 / crosstab.sum(0), annot=True, fmt=".1f", ax= axes[0])

# The chi-squared test runs on the same contingency table that is plotted. The earlier
# get_dummies(drop_first=True) recode counted missing values as the reference category, so the
# p-value could disagree with the displayed table; with complete two-level data the result is identical.
pval = chi2_contingency(crosstab).pvalue
clustering_statistical_table.loc["Sex", "Deep clustering"] = pval
axes[0].set_title(f"p-value= {round(pval, 3)}", fontsize= 10)

# Right: absolute counts per sex, split by cluster, with the count on each bar.
sns.countplot(data= clinical_data, x= clinical_label, hue= "clusters", ax= axes[1])
for container in axes[1].containers:
    axes[1].bar_label(container)

In [357]:
# Association between the AJCC metastasis stage code and the deep-clustering assignment.
clinical_label = "American Joint Committee on Cancer Metastasis Stage Code"
clinical_parameter = clinical_data[clinical_label].str[-1]
# Codes ending in "X" are excluded before the last character is parsed as a number.
clinical_parameter = clinical_parameter[clinical_parameter != "X"]
clinical_parameter = clinical_parameter.astype(float).dropna().astype(int)
clinical_label = "AJN Cancer Metastasis"
clinical_parameter = clinical_parameter.rename(clinical_label)
fig, axes = plt.subplots(1, 2, figsize= (12, 4))

# Left: share of each stage within every cluster (each column sums to 100).
crosstab = pd.crosstab(clinical_parameter, clinical_data['clusters'])
sns.heatmap(crosstab * 100 / crosstab.sum(0), annot=True, fmt=".1f", ax= axes[0])

# Chi-squared test of independence on the raw counts.
pval = chi2_contingency(crosstab).pvalue
clustering_statistical_table.loc["AJCC metastasis stage", "Deep clustering"] = pval
axes[0].set_title(f"p-value= {round(pval, 3)}", fontsize= 10)

# Right: absolute counts, drawn explicitly on the second axes like the other clinical-variable
# cells (previously this relied on it being the "current" axes). The p-value is shown once,
# rounded, on the left panel only.
clinical_parameter = clinical_parameter.to_frame()
clinical_parameter["clusters"] = clinical_data['clusters']
sns.countplot(data= clinical_parameter, x= clinical_label, hue= "clusters", ax= axes[1])
for container in axes[1].containers:
    axes[1].bar_label(container)

In [358]:
# Compare diagnosis age between the deep-clustering groups with a Kruskal-Wallis test.
clinical_label = "Diagnosis Age"
clinical_parameter = clinical_data[clinical_label].dropna()
cluster_labels = clinical_data['clusters']

# One sample of ages per cluster, in order of first appearance of each label.
age_groups = []
for cl in cluster_labels.unique():
    age_groups.append(clinical_parameter[cluster_labels == cl])
pval = kruskal(*age_groups).pvalue
clustering_statistical_table.loc["Diagnosis age", "Deep clustering"] = pval

# Box plot of age by cluster, annotated with the p-value at fixed data coordinates.
clinical_parameter = clinical_parameter.to_frame().assign(clusters= cluster_labels)
ax = clinical_parameter.boxplot(column= clinical_label, by= "clusters", grid=False)
_ = ax.text(1.3, 50, f"p= {round(pval, 3)}")

In [265]:
# Distribution of raw omics values per cluster, restricted to the features kept in the
# transformed views: for each view, all individual values and per-feature means.
fig, axes = plt.subplots(1, 4, figsize= (25, 4))

cluster_ids = sorted(clinical_data["clusters"].unique())
# transformed_Xs[1] is the methylation view, transformed_Xs[0] the RNA-seq view.
methylation_selected = methylation_data.loc[:, transformed_Xs[1].columns]
# RNA-seq values are shown on a log2(1 + x) scale.
rnaseq_selected_log = rnaseq_data.loc[:, transformed_Xs[0].columns].apply(lambda x: np.log2(1 + x))


def _all_values(block):
    """Flatten every sample-by-feature value of a cluster into one 1-D array."""
    return block.values.flatten()


def _feature_means(block):
    """Average each feature over the samples of a cluster."""
    return block.mean(0)


def _per_cluster_frame(values, summarise):
    """Build a frame with one column per cluster holding `summarise(rows of that cluster)`."""
    rows = [summarise(values.loc[clinical_data[clinical_data["clusters"] == cl].index]) for cl in cluster_ids]
    return pd.DataFrame(rows, index= cluster_ids).T


panels = [
    (methylation_selected, _all_values, "All methylation values by cluster"),
    (methylation_selected, _feature_means, "Mean methylation values by cluster"),
    (rnaseq_selected_log, _all_values, "All RNA-seq values by cluster"),
    (rnaseq_selected_log, _feature_means, "Mean RNA-seq values by cluster"),
]
for ax, (values, summarise, title) in zip(axes, panels):
    sns.violinplot(data= _per_cluster_frame(values, summarise), orient= "h", ax= ax)
    # The last title line was previously a syntax error (`= ax.set_title(...)`).
    _ = ax.set_title(title)

In [22]:
# Concatenate the transformed views side by side into one feature matrix for the baseline
# clustering methods below.
transformed_X = pd.concat(transformed_Xs, axis = "columns")
print("transformed_X.shape", transformed_X.shape)
transformed_X.head()

In [250]:
# Baseline: agglomerative (hierarchical) clustering on the concatenated transformed views.
# `AgglomerativeClustering` is imported in the notebook's top import cell.
# NOTE(review): the number of clusters is fixed at 2 here rather than read from the study's
# best parameters — confirm it should match the deep-clustering setting.
preds_hc = AgglomerativeClustering(n_clusters= 2).fit_predict(transformed_X)
clinical_data["preds_hc"] = preds_hc

# Bar chart of cluster sizes with the count printed on each bar.
ax = clinical_data["preds_hc"].value_counts().plot(kind="bar", title= "Count of samples in clusters", x= "preds_hc", ylabel= "Number of samples")
for container in ax.containers:
    ax.bar_label(container)

In [359]:
# Log-rank test of overall survival across the hierarchical-clustering groups.
# This overwrites the `logrank_test` computed earlier for the deep-clustering labels.
# NOTE(review): no `event_observed` is passed, so lifelines treats every duration as an observed
# event (no censoring). Confirm this is intended.
survival_months = clinical_data["Overall Survival (Months)"]
hc_groups = clinical_data["preds_hc"]
logrank_test = multivariate_logrank_test(event_durations= survival_months, groups= hc_groups)
logrank_test

In [360]:
# Kaplan-Meier curves per hierarchical-clustering group. Uses `logrank_test` from the previous
# cell, so that cell must have been run for `preds_hc` immediately before this one.
ax = plt.subplot(111)
ax.set_title("Survival plot")
ax.text(0, 0, f"p-value= {round(logrank_test.p_value, 3)}")
clustering_statistical_table.loc["Overall Survival", "HC"] = logrank_test.p_value

survival_months = clinical_data["Overall Survival (Months)"]
kmfs = []
for cluster in sorted(clinical_data["preds_hc"].unique()):
    in_cluster = clinical_data["preds_hc"] == cluster
    # NOTE(review): no `event_observed` is given, so every duration is treated as an observed event.
    kmf = KaplanMeierFitter().fit(survival_months[in_cluster], label = str(cluster))
    kmfs.append(kmf)
    ax = kmf.plot(ax=ax)

add_at_risk_counts(*kmfs, ax=ax)
plt.tight_layout()

In [361]:
# Association between the AJCC tumor stage code and the hierarchical-clustering assignment.
clinical_label = "AJN Tumor stage"
stage_codes = clinical_data["American Joint Committee on Cancer Tumor Stage Code"]
# The stage is the last character of the code, parsed as a number; missing entries are dropped.
clinical_parameter = stage_codes.str[-1].astype(float).dropna().astype(int).rename(clinical_label)

fig, axes = plt.subplots(1, 2, figsize= (12, 4))
heatmap_ax, count_ax = axes

# Left: share of each stage within every cluster (each column sums to 100).
crosstab = pd.crosstab(clinical_parameter, clinical_data['preds_hc'])
column_percent = crosstab * 100 / crosstab.sum(0)
sns.heatmap(column_percent, annot=True, fmt=".1f", ax= heatmap_ax)

# Chi-squared test of independence on the raw counts.
pval = chi2_contingency(crosstab).pvalue
clustering_statistical_table.loc["AJCC tumor stage", "HC"] = pval
heatmap_ax.set_title(f"p-value= {round(pval, 3)}", fontsize= 10)

# Right: absolute counts per stage, split by cluster, with the count on each bar.
clinical_parameter = clinical_parameter.to_frame().assign(preds_hc= clinical_data['preds_hc'])
sns.countplot(data= clinical_parameter, x= clinical_label, hue= "preds_hc", ax= count_ax)
for container in count_ax.containers:
    count_ax.bar_label(container)

In [362]:
# Association between the neoplasm histologic grade and the hierarchical-clustering assignment.
clinical_label = "Neoplasm Histologic Grade"
# The grade is the last character of the value, parsed as a number; missing entries are dropped.
grade_digits = clinical_data[clinical_label].str[-1].astype(float)
clinical_parameter = grade_digits.dropna().astype(int)

fig, axes = plt.subplots(1, 2, figsize= (12, 4))
heatmap_ax, count_ax = axes

# Left: share of each grade within every cluster (each column sums to 100).
crosstab = pd.crosstab(clinical_parameter, clinical_data['preds_hc'])
column_percent = crosstab * 100 / crosstab.sum(0)
sns.heatmap(column_percent, annot=True, fmt=".1f", ax= heatmap_ax)

# Chi-squared test of independence on the raw counts.
pval = chi2_contingency(crosstab).pvalue
clustering_statistical_table.loc["AJCC neoplasm histologic grade", "HC"] = pval
heatmap_ax.set_title(f"p-value= {round(pval, 3)}", fontsize= 10)

# Right: absolute counts per grade, split by cluster, with the count on each bar.
clinical_parameter = clinical_parameter.to_frame().assign(preds_hc= clinical_data['preds_hc'])
sns.countplot(data= clinical_parameter, x= clinical_label, hue= "preds_hc", ax= count_ax)
for container in count_ax.containers:
    count_ax.bar_label(container)

In [363]:
# Association between sex and the hierarchical-clustering assignment.
clinical_label = "Sex"
clinical_parameter = clinical_data[clinical_label]
fig, axes = plt.subplots(1, 2, figsize= (12, 4))

# Left: share of each sex within every cluster (each column sums to 100).
crosstab = pd.crosstab(clinical_parameter, clinical_data['preds_hc'])
sns.heatmap(crosstab * 100 / crosstab.sum(0), annot=True, fmt=".1f", ax= axes[0])

# The chi-squared test runs on the same contingency table that is plotted. The earlier
# get_dummies(drop_first=True) recode counted missing values as the reference category, so the
# p-value could disagree with the displayed table; with complete two-level data the result is identical.
pval = chi2_contingency(crosstab).pvalue
clustering_statistical_table.loc["Sex", "HC"] = pval
axes[0].set_title(f"p-value= {round(pval, 3)}", fontsize= 10)

# Right: absolute counts per sex, split by cluster, with the count on each bar.
sns.countplot(data= clinical_data, x= clinical_label, hue= "preds_hc", ax= axes[1])
for container in axes[1].containers:
    axes[1].bar_label(container)

In [364]:
# Association between the AJCC metastasis stage code and the hierarchical-clustering assignment.
clinical_label = "American Joint Committee on Cancer Metastasis Stage Code"
clinical_parameter = clinical_data[clinical_label].str[-1]
# Codes ending in "X" are excluded before the last character is parsed as a number.
clinical_parameter = clinical_parameter[clinical_parameter != "X"]
clinical_parameter = clinical_parameter.astype(float).dropna().astype(int)
clinical_label = "AJN Cancer Metastasis"
clinical_parameter = clinical_parameter.rename(clinical_label)
fig, axes = plt.subplots(1, 2, figsize= (12, 4))

# Left: share of each stage within every cluster (each column sums to 100).
crosstab = pd.crosstab(clinical_parameter, clinical_data['preds_hc'])
sns.heatmap(crosstab * 100 / crosstab.sum(0), annot=True, fmt=".1f", ax= axes[0])

# Chi-squared test of independence on the raw counts.
pval = chi2_contingency(crosstab).pvalue
clustering_statistical_table.loc["AJCC metastasis stage", "HC"] = pval
axes[0].set_title(f"p-value= {round(pval, 3)}", fontsize= 10)

# Right: absolute counts, drawn explicitly on the second axes like the other clinical-variable
# cells (previously this relied on it being the "current" axes). The p-value is shown once,
# rounded, on the left panel only.
clinical_parameter = clinical_parameter.to_frame()
clinical_parameter["preds_hc"] = clinical_data['preds_hc']
sns.countplot(data= clinical_parameter, x= clinical_label, hue= "preds_hc", ax= axes[1])
for container in axes[1].containers:
    axes[1].bar_label(container)

In [367]:
# Compare diagnosis age between the hierarchical-clustering groups with a Kruskal-Wallis test.
clinical_label = "Diagnosis Age"
clinical_parameter = clinical_data[clinical_label].dropna()
hc_labels = clinical_data['preds_hc']

# One sample of ages per cluster, in order of first appearance of each label.
age_groups = []
for cl in hc_labels.unique():
    age_groups.append(clinical_parameter[hc_labels == cl])
pval = kruskal(*age_groups).pvalue
clustering_statistical_table.loc["Diagnosis age", "HC"] = pval

# Box plot of age by cluster, annotated with the p-value at fixed data coordinates.
clinical_parameter = clinical_parameter.to_frame().assign(preds_hc= hc_labels)
ax = clinical_parameter.boxplot(column= clinical_label, by= "preds_hc", grid=False)
_ = ax.text(1.3, 50, f"p-value= {round(pval, 3)}")

In [272]:
# K-Means (2 clusters) on transformed_X; the labels are stored next to the
# clinical data and the cluster sizes are shown as a bar chart.
# KMeans is already imported in the notebook's imports cell, so it is not re-imported here.
# NOTE(review): the seed is hard-coded to 42 while the projection cells use
# settings.RANDOM_STATE — confirm whether both should share the same seed.
preds_kmeans = KMeans(n_clusters= 2, random_state= 42).fit_predict(transformed_X)
clinical_data["preds_kmeans"] = preds_kmeans
# `x=` is not used by Series plots, so it is omitted.
ax = clinical_data["preds_kmeans"].value_counts().plot(kind="bar", title= "Count of samples in clusters", ylabel= "Number of samples")
for container in ax.containers:
    ax.bar_label(container)

In [373]:
# Log-rank test of overall survival between the K-Means clusters.
# NOTE(review): no `event_observed` is supplied, so every patient is treated as
# having experienced the event (no censoring) — confirm this is intended.
logrank_test = multivariate_logrank_test(
    clinical_data["Overall Survival (Months)"],
    clinical_data["preds_kmeans"],
)
logrank_test

In [374]:
# Kaplan-Meier curves of overall survival, one curve per K-Means cluster,
# with at-risk counts under the plot. The log-rank p-value is also recorded
# in the summary table.
clustering_statistical_table.loc["Overall Survival", "K-Means"] = logrank_test.p_value

ax = plt.subplot(111)
ax.set_title("Survival plot")
ax.text(0, 0, f"p-value= {round(logrank_test.p_value, 3)}")

kmfs = []
for cluster in sorted(clinical_data["preds_kmeans"].unique()):
    # Survival times of the patients assigned to this cluster.
    # NOTE(review): fitted without `event_observed`, i.e. without censoring information — confirm.
    duration = clinical_data.loc[clinical_data["preds_kmeans"] == cluster, "Overall Survival (Months)"]
    kmf = KaplanMeierFitter().fit(duration, label = str(cluster))
    ax = kmf.plot(ax=ax)
    kmfs.append(kmf)

add_at_risk_counts(*kmfs, ax=ax)
plt.tight_layout()

In [375]:
# AJCC tumor stage code vs. K-Means clusters: percentage heatmap with the
# chi-square p-value (left) and raw counts per cluster (right).
clinical_label = "AJN Tumor stage"
# Last character of the stage code, cast to an integer after dropping missing values.
clinical_parameter = (
    clinical_data["American Joint Committee on Cancer Tumor Stage Code"]
    .str[-1]
    .astype(float)
    .dropna()
    .astype(int)
    .rename(clinical_label)
)

fig, axes = plt.subplots(1, 2, figsize= (12, 4))
crosstab = pd.crosstab(clinical_parameter, clinical_data['preds_kmeans'])
pval = chi2_contingency(crosstab).pvalue
clustering_statistical_table.loc["AJCC tumor stage", "K-Means"] = pval

# Left: each cluster (column) normalised to 100 %.
sns.heatmap(crosstab * 100 / crosstab.sum(0), annot=True, fmt=".1f", ax= axes[0])
axes[0].set_title(f"p-value= {round(pval, 3)}", fontsize= 10)

# Right: counts per stage, split by cluster.
clinical_parameter = clinical_parameter.to_frame().assign(preds_kmeans= clinical_data['preds_kmeans'])
sns.countplot(data= clinical_parameter, x= clinical_label, hue= "preds_kmeans", ax= axes[1])
for container in axes[1].containers:
    axes[1].bar_label(container)

In [376]:
# Neoplasm histologic grade vs. K-Means clusters: percentage heatmap with the
# chi-square p-value (left) and raw counts per cluster (right).
clinical_label = "Neoplasm Histologic Grade"
# Last character of the grade, cast to an integer after dropping missing values.
grade_char = clinical_data[clinical_label].str[-1]
clinical_parameter = grade_char.astype(float).dropna().astype(int)

fig, axes = plt.subplots(1, 2, figsize= (12, 4))
crosstab = pd.crosstab(clinical_parameter, clinical_data['preds_kmeans'])
pval = chi2_contingency(crosstab).pvalue
clustering_statistical_table.loc["AJCC neoplasm histologic grade", "K-Means"] = pval

# Left: each cluster (column) normalised to 100 %.
sns.heatmap(crosstab * 100 / crosstab.sum(0), annot=True, fmt=".1f", ax= axes[0])
axes[0].set_title(f"p-value= {round(pval, 3)}", fontsize= 10)

# Right: counts per grade, split by cluster.
clinical_parameter = clinical_parameter.to_frame().assign(preds_kmeans= clinical_data['preds_kmeans'])
sns.countplot(data= clinical_parameter, x= clinical_label, hue= "preds_kmeans", ax= axes[1])
for container in axes[1].containers:
    axes[1].bar_label(container)

In [377]:
# Sex vs. K-Means clusters: percentage heatmap with the chi-square p-value
# (left) and raw counts per cluster (right).
clinical_label = "Sex"
clinical_parameter = clinical_data[clinical_label]

fig, axes = plt.subplots(1, 2, figsize= (12, 4))
crosstab = pd.crosstab(clinical_parameter, clinical_data['preds_kmeans'])
# Each cluster (column) normalised to 100 %.
sns.heatmap(crosstab * 100 / crosstab.sum(0), annot=True, fmt=".1f", ax= axes[0])

# Test the very table that is displayed. The previous dummy-coding step
# (get_dummies(drop_first=True)) encoded a missing Sex as 0, i.e. silently
# counted it as the reference category, and only worked for exactly two levels.
pval = chi2_contingency(crosstab).pvalue
clustering_statistical_table.loc["Sex", "K-Means"] = pval
axes[0].set_title(f"p-value= {round(pval, 3)}", fontsize= 10)

# Counts are plotted straight from clinical_data, which holds both columns.
sns.countplot(data= clinical_data, x= clinical_label, hue= "preds_kmeans", ax= axes[1])
for container in axes[1].containers:
    axes[1].bar_label(container)

In [378]:
# Association between the AJCC metastasis stage code and the K-Means clusters:
# percentage heatmap + chi-square test (left) and raw counts (right).
clinical_label = "American Joint Committee on Cancer Metastasis Stage Code"
# Only the last character of the code is kept; "X" entries are removed because
# they cannot be cast to a number.
clinical_parameter = clinical_data[clinical_label].str[-1]
clinical_parameter = clinical_parameter[clinical_parameter != "X"]
clinical_parameter = clinical_parameter.astype(float).dropna().astype(int)
clinical_label = "AJN Cancer Metastasis"
clinical_parameter = clinical_parameter.rename(clinical_label)

fig, axes = plt.subplots(1, 2, figsize= (12, 4))
crosstab = pd.crosstab(clinical_parameter, clinical_data['preds_kmeans'])
# Each column (cluster) is normalised to 100 %.
sns.heatmap(crosstab * 100 / crosstab.sum(0), annot=True, fmt=".1f", ax= axes[0])
# The test is run on the same contingency table that is displayed.
pval = chi2_contingency(crosstab).pvalue
clustering_statistical_table.loc["AJCC metastasis stage", "K-Means"] = pval
axes[0].set_title(f"p-value= {round(pval, 3)}", fontsize= 10)

clinical_parameter = clinical_parameter.to_frame()
clinical_parameter["preds_kmeans"] = clinical_data['preds_kmeans']
# Draw on the right-hand panel explicitly instead of relying on pyplot's
# implicit "current axes", and round the p-value like the left panel does.
ax = sns.countplot(data= clinical_parameter, x= clinical_label, hue= "preds_kmeans", ax= axes[1])
ax.set_title(f"p-value= {round(pval, 3)}", fontsize= 10)
for container in ax.containers:
    ax.bar_label(container)

In [379]:
# Diagnosis age across the K-Means clusters: Kruskal-Wallis test and
# a per-cluster boxplot annotated with the p-value.
clinical_label = "Diagnosis Age"
clinical_parameter = clinical_data[clinical_label].dropna()
# One group of ages per cluster label; the mask is built on all patients while
# rows with a missing age were dropped above.
pval = kruskal(
    *[
        clinical_parameter[clinical_data['preds_kmeans'] == cl]
        for cl in clinical_data['preds_kmeans'].unique()
    ]
).pvalue
clustering_statistical_table.loc["Diagnosis age", "K-Means"] = pval
clinical_parameter = clinical_parameter.to_frame().assign(preds_kmeans= clinical_data['preds_kmeans'])
ax = clinical_parameter.boxplot(column= clinical_label, by= "preds_kmeans", grid=False)
# The annotation position (1.3, 50) is hard-coded.
_ = ax.text(1.3, 50, f"p-value= {round(pval, 3)}")

In [383]:
# Print the table of p-values as LaTeX, formatted with three decimals.
latex_table = clustering_statistical_table.to_latex(float_format="{:.3f}".format)
print(latex_table)

In [23]:
# 2-D PCA of the standardised z_full, coloured by the "clusters" column.
method = make_pipeline(
    StandardScaler(),
    PCA(n_components = 2, random_state = settings.RANDOM_STATE),
)
plot_data = method.fit_transform(z_full)
fig = px.scatter(
    x = plot_data[:, 0],
    y = plot_data[:, 1],
    color = clinical_data["clusters"].astype(str),
)
fig.show()

In [24]:
# 2-D PCA of the standardised z_full, coloured by overall survival (months).
method = make_pipeline(
    StandardScaler(),
    PCA(n_components = 2, random_state = settings.RANDOM_STATE),
)
plot_data = method.fit_transform(z_full)
fig = px.scatter(
    x = plot_data[:, 0],
    y = plot_data[:, 1],
    color = clinical_data["Overall Survival (Months)"],
    color_continuous_scale=px.colors.sequential.Bluered,
)
fig.show()

In [25]:
# 3-D PCA of the standardised z_full, coloured by `clusters`.
method = make_pipeline(
    StandardScaler(),
    PCA(n_components = 3, random_state = settings.RANDOM_STATE),
)
plot_data = method.fit_transform(z_full)
fig = px.scatter_3d(
    x = plot_data[:, 0],
    y = plot_data[:, 1],
    z = plot_data[:, 2],
    color = clusters.astype(str),
)
fig.show()

In [26]:
# 3-D PCA of the standardised z_full, coloured by overall survival (months).
method = make_pipeline(
    StandardScaler(),
    PCA(n_components = 3, random_state = settings.RANDOM_STATE),
)
plot_data = method.fit_transform(z_full)
fig = px.scatter_3d(
    x = plot_data[:, 0],
    y = plot_data[:, 1],
    z = plot_data[:, 2],
    color = clinical_data["Overall Survival (Months)"],
    color_continuous_scale=px.colors.sequential.Bluered,
)
fig.show()

In [27]:
# 2-D PCA of transformed_X.
method = PCA(n_components = 2, random_state = settings.RANDOM_STATE)
plot_data = method.fit_transform(transformed_X)
# Colours come from a K-Means model fitted on the standardised 2-D projection
# itself, with the number of clusters taken from the optimisation study.
# NOTE(review): these are not the cluster labels stored in clinical_data — confirm this is intended.
embedding_kmeans = make_pipeline(
    StandardScaler(),
    KMeans(n_clusters=optimization_study.best_params["n_clusters"], random_state = settings.RANDOM_STATE),
)
fig = px.scatter(
    x = plot_data[:, 0],
    y = plot_data[:, 1],
    color = embedding_kmeans.fit_predict(plot_data).astype(str),
)
fig.show()

In [28]:
# 2-D PCA of transformed_X, coloured by overall survival (months).
method = PCA(n_components = 2, random_state = settings.RANDOM_STATE)
plot_data = method.fit_transform(transformed_X)
fig = px.scatter(
    x = plot_data[:, 0],
    y = plot_data[:, 1],
    color = clinical_data["Overall Survival (Months)"],
    color_continuous_scale=px.colors.sequential.Bluered,
)
fig.show()

In [29]:
# 2-D t-SNE of transformed_X.
method = TSNE(n_components = 2, random_state = settings.RANDOM_STATE)
plot_data = method.fit_transform(transformed_X)
# Colours come from a K-Means model fitted on the standardised 2-D embedding
# itself, with the number of clusters taken from the optimisation study.
# NOTE(review): these are not the cluster labels stored in clinical_data — confirm this is intended.
embedding_kmeans = make_pipeline(
    StandardScaler(),
    KMeans(n_clusters=optimization_study.best_params["n_clusters"], random_state = settings.RANDOM_STATE),
)
fig = px.scatter(
    x = plot_data[:, 0],
    y = plot_data[:, 1],
    color = embedding_kmeans.fit_predict(plot_data).astype(str),
)
fig.show()

In [30]:
# 2-D t-SNE of transformed_X, coloured by overall survival (months).
method = TSNE(n_components = 2, random_state = settings.RANDOM_STATE)
plot_data = method.fit_transform(transformed_X)
fig = px.scatter(
    x = plot_data[:, 0],
    y = plot_data[:, 1],
    color = clinical_data["Overall Survival (Months)"],
    color_continuous_scale=px.colors.sequential.Bluered,
)
fig.show()

In [31]:
# 2-D UMAP of transformed_X.
method = UMAP(n_components = 2, random_state = settings.RANDOM_STATE)
plot_data = method.fit_transform(transformed_X)
# Colours come from a K-Means model fitted on the standardised 2-D embedding
# itself, with the number of clusters taken from the optimisation study.
# NOTE(review): these are not the cluster labels stored in clinical_data — confirm this is intended.
embedding_kmeans = make_pipeline(
    StandardScaler(),
    KMeans(n_clusters=optimization_study.best_params["n_clusters"], random_state = settings.RANDOM_STATE),
)
fig = px.scatter(
    x = plot_data[:, 0],
    y = plot_data[:, 1],
    color = embedding_kmeans.fit_predict(plot_data).astype(str),
)
fig.show()

In [32]:
# 2-D UMAP of transformed_X, coloured by overall survival (months).
method = UMAP(n_components = 2, random_state = settings.RANDOM_STATE)
plot_data = method.fit_transform(transformed_X)
fig = px.scatter(
    x = plot_data[:, 0],
    y = plot_data[:, 1],
    color = clinical_data["Overall Survival (Months)"],
    color_continuous_scale=px.colors.sequential.Bluered,
)
fig.show()

In [33]:
# 3-D PCA of transformed_X.
method = PCA(n_components = 3, random_state = settings.RANDOM_STATE)
plot_data = method.fit_transform(transformed_X)
# Colours come from a K-Means model fitted on the standardised 3-D projection
# itself, with the number of clusters taken from the optimisation study.
# NOTE(review): these are not the cluster labels stored in clinical_data — confirm this is intended.
embedding_kmeans = make_pipeline(
    StandardScaler(),
    KMeans(n_clusters=optimization_study.best_params["n_clusters"], random_state = settings.RANDOM_STATE),
)
fig = px.scatter_3d(
    x = plot_data[:, 0],
    y = plot_data[:, 1],
    z = plot_data[:, 2],
    color = embedding_kmeans.fit_predict(plot_data).astype(str),
)
fig.show()

In [34]:
# 3-D PCA of transformed_X, coloured by overall survival (months).
method = PCA(n_components = 3, random_state = settings.RANDOM_STATE)
plot_data = method.fit_transform(transformed_X)
fig = px.scatter_3d(
    x = plot_data[:, 0],
    y = plot_data[:, 1],
    z = plot_data[:, 2],
    color = clinical_data["Overall Survival (Months)"],
    color_continuous_scale=px.colors.sequential.Bluered,
)
fig.show()

In [35]:
# 3-D t-SNE of transformed_X.
method = TSNE(n_components = 3, random_state = settings.RANDOM_STATE)
plot_data = method.fit_transform(transformed_X)
# Colours come from a K-Means model fitted on the standardised 3-D embedding
# itself, with the number of clusters taken from the optimisation study.
# NOTE(review): these are not the cluster labels stored in clinical_data — confirm this is intended.
embedding_kmeans = make_pipeline(
    StandardScaler(),
    KMeans(n_clusters=optimization_study.best_params["n_clusters"], random_state = settings.RANDOM_STATE),
)
fig = px.scatter_3d(
    x = plot_data[:, 0],
    y = plot_data[:, 1],
    z = plot_data[:, 2],
    color = embedding_kmeans.fit_predict(plot_data).astype(str),
)
fig.show()

In [36]:
# 3-D t-SNE of transformed_X, coloured by overall survival (months).
method = TSNE(n_components = 3, random_state = settings.RANDOM_STATE)
plot_data = method.fit_transform(transformed_X)
fig = px.scatter_3d(
    x = plot_data[:, 0],
    y = plot_data[:, 1],
    z = plot_data[:, 2],
    color = clinical_data["Overall Survival (Months)"],
    color_continuous_scale=px.colors.sequential.Bluered,
)
fig.show()

In [37]:
# 3-D UMAP of transformed_X.
method = UMAP(n_components = 3, random_state = settings.RANDOM_STATE)
plot_data = method.fit_transform(transformed_X)
# Colours come from a K-Means model fitted on the standardised 3-D embedding
# itself, with the number of clusters taken from the optimisation study.
# NOTE(review): these are not the cluster labels stored in clinical_data — confirm this is intended.
embedding_kmeans = make_pipeline(
    StandardScaler(),
    KMeans(n_clusters=optimization_study.best_params["n_clusters"], random_state = settings.RANDOM_STATE),
)
fig = px.scatter_3d(
    x = plot_data[:, 0],
    y = plot_data[:, 1],
    z = plot_data[:, 2],
    color = embedding_kmeans.fit_predict(plot_data).astype(str),
)
fig.show()

In [38]:
# 3-D UMAP of transformed_X, coloured by overall survival (months).
method = UMAP(n_components = 3, random_state = settings.RANDOM_STATE)
plot_data = method.fit_transform(transformed_X)
fig = px.scatter_3d(
    x = plot_data[:, 0],
    y = plot_data[:, 1],
    z = plot_data[:, 2],
    color = clinical_data["Overall Survival (Months)"],
    color_continuous_scale=px.colors.sequential.Bluered,
)
fig.show()