# Dependencies and parameters

In [None]:
import pandas as pd
import numpy as np
import umap
import hdbscan
from pandas.api.types import is_string_dtype, is_numeric_dtype
from sklearn.preprocessing import StandardScaler, MinMaxScaler, OrdinalEncoder
from sklearn.model_selection import RandomizedSearchCV, GridSearchCV
from sklearn.feature_selection import VarianceThreshold
from sklearn.cluster import KMeans
from sklearn.metrics import make_scorer
import matplotlib.pyplot as plt
import matplotlib.ticker as plticker
import matplotlib.font_manager
import seaborn as sns
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import itertools
import warnings
import scipy
from scipy import stats                  
import statsmodels.sandbox.stats.multicomp as mc
import statsmodels.stats.multitest as mt
from clusteval import clusteval
from scipy.stats import hypergeom
import scikit_posthocs as sp
import prince
from psmpy import PsmPy
from tableone import TableOne, load_dataset
from yellowbrick.cluster import KElbowVisualizer
from ipynb.fs.full.txt_to_analysis import txt_to_analysis

In [None]:
# Import Arial fonts
import matplotlib.font_manager as font_manager

# Register every font file found under font_dir with matplotlib's font manager,
# so that the 'arial' family requested via plt.rcParams in the settings cell can be resolved.
# NOTE(review): font_dir is an absolute, user-specific path; on another machine
# it must be edited for the fonts to be found.
font_dir = [r"/home/pps21@isd.csc.mrc.ac.uk/pps21/arial"]
for font in font_manager.findSystemFonts(font_dir):
    font_manager.fontManager.addfont(font)

In [None]:
# Table display: raise pandas' truncation limits to 500 columns / 500 rows
pd.set_option('display.max_columns', 500)
pd.set_option('display.max_rows', 500)

# Plot output format: render figures inline as PNG
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('png')
# matplotlib_inline.backend_inline.set_matplotlib_formats('pdf') # uncomment for pdf version in notebook

# Plot settings
# disc_color / cont_color hold plotly's qualitative and sequential palettes
disc_color = px.colors.qualitative.Plotly
cont_color = px.colors.sequential.Turbo
# Transparent axes background, visible bottom/left ticks, Arial text, 300 dpi figures
plt.rcParams['axes.facecolor'] = 'none'
plt.rcParams['xtick.bottom'] = True
plt.rcParams['ytick.left'] = True
plt.rcParams['font.family'] = 'arial'
plt.rcParams['figure.dpi'] = 300

# Warning display
# NOTE(review): both calls silence *every* warning for the rest of the session,
# including pandas/sklearn deprecation warnings raised by later cells.
warnings.filterwarnings(action='ignore')
warnings.simplefilter(action='ignore')

# Random seed for numpy's global random state
np.random.seed(1)

# Clinical features

## Original dataset

In [None]:
# Clinical features - non imputed
# Load the raw metadata, key it by patient_id (one row per patient, sorted)
# and keep only patients whose genotype is PLP, NEG or VUS.
HCM_METADATA = r"/home/pps21@isd.csc.mrc.ac.uk/pps21/cardiac-dimred/data/raw/rbh_superset_filamentinfo.csv"
df_hcm = (
    pd.read_csv(HCM_METADATA, encoding="ISO-8859-1")
    .rename(columns={'ID': 'patient_id'})
    .drop_duplicates(subset='patient_id')
    .sort_values(by='patient_id')
    .reset_index(drop=True)
    .set_index('patient_id')
)
df_hcm = df_hcm.loc[df_hcm['genotype'].isin(['PLP', 'NEG', 'VUS'])]

# Wrangle categorical data
# Columns whose observed (non-null) codes are exactly {1, 2, 3} are recoded to
# True / False / NaN; columns whose codes are exactly {1, ..., 5} get 5 -> NaN.
# Every other column is left untouched.
dic_123 = {1: True, 2: False, 3: np.nan}
var_123, var_15 = [], []
for field in df_hcm.columns:
    observed_codes = set(df_hcm[field].dropna().unique())
    if observed_codes == {1, 2, 3}:
        df_hcm[field] = df_hcm[field].replace(dic_123)
        var_123.append(field)
    elif observed_codes == {1, 2, 3, 4, 5}:
        df_hcm[field] = df_hcm[field].replace({5: np.nan})
        var_15.append(field)
print(f'Variables processed with 123 rule: {var_123}')
print(f'Variables processed with 1-5 rule: {var_15}')

# Code -> label dictionaries for the columns that need a dedicated mapping
dic_smoking = {1: True, 2: False, 3: np.nan, 4: 'Ex'}
dic_aff_region = {0: np.nan, 1: 'Septal', 2: 'Anterior', 3: 'Lateral', 4: 'Inferior'}
dic_aff_level = {0: np.nan, 1: 'Base', 2: 'Mid', 3: 'Apex'}
dic_sev = {1: 'None', 2: 'Minimal', 3: 'Moderate', 4: 'Severe'}

# Apply each mapping to its column (the severity mapping is shared by two columns)
for column_name, code_map in (
    ('Gen.Smoking', dic_smoking),
    ('Hcm.Lvmostaffectedsegment', dic_aff_region),
    ('Hcm.Mostaffectedlevel', dic_aff_level),
    ('Hcm.Mitralregurgitation', dic_sev),
    ('Hcm.Lvgadolinum', dic_sev),
):
    df_hcm[column_name] = df_hcm[column_name].replace(code_map)

# Wrangle continuous variables
# For the measurements listed below a value of 0 is replaced by NaN
# (i.e. zero is treated as "not recorded" rather than as a real measurement).
non_null_measurements = ['age_at_scan',
                         'Hcm.Bsa',
                         'Gen.Height',
                         'Gen.Weight',
                         'Gen.Pulserate',
                         'Gen.Systolic',
                         'Gen.Diastolic',
                         'Hcm.Edv',
                         'Hcm.Esv',
                         'Hcm.Sv',
                         'Hcm.Ef',
                         'Hcm.Lvm',
                         'Hcm.Maxlvwallthickness',
                         'Hcm.Lvotpeakvelocity']
df_hcm[non_null_measurements] = df_hcm[non_null_measurements].replace(0,np.nan)

# Ordinal to category
# Casting to 'category' makes these columns non-numeric, so the next step
# turns them into object columns together with the other non-numeric ones.
ord_var = ['Gen.Activityscore', 'Gen.Ccs', 'Gen.Nyha', 
          'Gen.Smoking']
for var in ord_var:
    df_hcm[var] = df_hcm[var].astype('category')
    
# All non-numerical to object
object_cols = df_hcm.select_dtypes(exclude=np.number)
df_hcm[object_cols.columns] = df_hcm[object_cols.columns].astype('object')

# Keep the wrangled, non-imputed data as df_org and reset df_hcm to an empty
# frame; df_hcm is re-populated from the imputed file further down.
df_org = df_hcm.copy(deep=True)
df_hcm = pd.DataFrame()

print(df_org.shape)
df_org.head(5)

In [None]:
# Order features
# Restrict df_org to the columns listed here, in this order. Columns that are
# not listed (e.g. genotype, Gen.Height, Gen.Weight) are dropped from df_org.
df_org = df_org[["age_at_scan", "race", "Hcm.Bsa", "sex", "Gen.Diastolic", "Gen.Systolic", "Gen.Pulserate", "Gen.Ht", "Gen.Dm", "Gen.Smoking", "Gen.Alcohol", 
                 "Gen.Activityscore", "Gen.Cad", "Gen.Mi", "Hcm.Familyhistoryofhcm", "Hcm.Familyhistoryofscd", "Gen.Ccs", "Gen.Nyha", "Hcm.Edv", "Hcm.Esv", 
                 "Hcm.Sv", "Hcm.Ef", "Hcm.Lvm", "Hcm.Maxlvwallthickness", "Hcm.Lvgadolinum", "Hcm.Lvmostaffectedsegment", "Hcm.Mitralregurgitation", 
                 "Hcm.Mostaffectedlevel", "Hcm.Rvhypertrophy", "Hcm.Coincidentinfarction", "Hcm.Lvoto", "Hcm.Lvotpeakvelocity", "Gen.Acearb", "Gen.Asaclopi", 
                 "Gen.Betablocker", "Gen.Diuretic", "Gen.Cabg", "Gen.Pci", "Hcm.Perfusiondeficit"]]

In [None]:
# Tableone for paper
# Summarise every column of df_org with TableOne's default settings (the
# optional arguments are left commented out) and print the table as LaTeX.
mytable = TableOne(df_org, 
                   # columns=columns, categorical=categorical, groupby=groupby, nonnormal=nonnormal, rename=labels, pval=False
                  )
print(mytable.tabulate(tablefmt = "latex"))

## Imputed dataset (with mice package)

In [None]:
# Load the imputed metadata; the first CSV column becomes the index, named patient_id
df_hcm = pd.read_csv(
    r"/home/pps21@isd.csc.mrc.ac.uk/pps21/cardiac-dimred/data/raw/metadata_imputed.csv",
    index_col=0,
).rename_axis('patient_id')

In [None]:
# Boolean flag: True where genotype is 'PLP', False otherwise (including missing genotype).
# Built with a single comparison instead of writing True into a float NaN column
# and filling the rest, which relies on deprecated implicit dtype upcasting.
df_hcm['plp'] = df_hcm['genotype'].eq('PLP')

In [None]:
# Wrangle categorical variables
# Columns whose observed (non-null) codes are exactly {1, 2} are recoded to
# True / False; columns whose codes are exactly {1, 2, 3, 4} are recorded in
# var_15 (the 5 -> NaN replacement cannot match anything for such a column).
dic_123 = {1: True, 2: False, 3: np.nan}
var_123, var_15 = [], []
for field in df_hcm.columns:
    observed_codes = set(df_hcm[field].dropna().unique())
    if observed_codes == {1, 2}:
        df_hcm[field] = df_hcm[field].replace(dic_123)
        var_123.append(field)
    elif observed_codes == {1, 2, 3, 4}:
        df_hcm[field] = df_hcm[field].replace({5: np.nan})
        var_15.append(field)
print(f'Variables processed with 123 rule: {var_123}')
print(f'Variables processed with 1-5 rule: {var_15}')

In [None]:
# Convert ordinal variables to category
# Casting to 'category' makes these columns non-numeric, so they are later
# converted to object and binned/one-hot encoded as categorical features.
ord_var = ['Gen.Activityscore', 'Gen.Ccs', 'Gen.Nyha', 'Hcm.Lvgadolinum', 'Hcm.Mitralregurgitation',
          'Gen.Smoking']
for var in ord_var:
    df_hcm[var] = df_hcm[var].astype('category')

In [None]:
# Dimensions and first five rows of the imputed dataset
print(df_hcm.shape)
df_hcm.head()

In [None]:
# Outcome dataframe
# Move the outcome columns into df_outcome and remove them from the feature frame.
list_outcome = ['genotype', 'Deceased', 'filaments', 'type', 'plp']
df_outcome = df_hcm.loc[:, list_outcome]
df_hcm = df_hcm.drop(list_outcome, axis=1)

In [None]:
# Convert all non-numerical variables to object type
object_cols = df_hcm.select_dtypes(exclude=np.number)
non_numeric_names = object_cols.columns
df_hcm[non_numeric_names] = df_hcm[non_numeric_names].astype(object)

# Process mixed-type data

## FAMD (not used)

In [None]:
# FAMD estimator with 15 components. It is only instantiated here: the
# fit / row_coordinates calls below are commented out, so no FAMD embedding
# is computed.
famd = prince.FAMD(n_components=15, n_iter=3,
                   copy=True, check_input=True,
                   engine='auto',random_state=0)

# famd = famd.fit(df_hcm)
# df_famd = famd.row_coordinates(df_hcm)
# df_famd.columns = [f'umap_{int(x)+1}' for x in df_famd.columns]

In [None]:
# df_famd.head(5)

## UMAP with one-hot encoding and normalisation (not used)

In [None]:
def calculate_zscore(df, columns):
  '''
  Return a copy of df in which each column named in `columns` is replaced by
  its z-score: (value - column mean) / population standard deviation (ddof=0).
  The input frame is not modified; other columns are copied unchanged.
  '''
  scaled = df.copy()
  for name in columns:
    series = scaled[name]
    centred = series - series.mean()
    scaled[name] = centred / series.std(ddof=0)
  return scaled


def one_hot_encode(df, columns):
  '''
  One-hot encode each column named in `columns` (first level dropped, column
  name used as prefix) and concatenate the dummy columns side by side.

  Returns (dummy_frame, dummy_column_index). The returned frame holds only the
  dummy columns; the original columns of df are not included.
  '''
  dummy_frames = []
  for name in columns:
    dummy_frames.append(pd.get_dummies(df[name], prefix=name, drop_first=True))
  encoded = pd.concat(dummy_frames, axis=1)
  return encoded, encoded.columns

# df = df_hcm.copy(deep=True)
# numeric_cols = df.select_dtypes(include=np.number)
# cat_cols = df.select_dtypes(include='object')
  
# # numeric process
# normalized_df = calculate_zscore(df, numeric_cols)
# normalized_df = normalized_df[numeric_cols.columns]

# # categorical process
# cat_one_hot_df, one_hot_cols = one_hot_encode(df, cat_cols)
# cat_one_hot_norm_df = calculate_zscore(cat_one_hot_df, one_hot_cols)

# # Merge DataFrames
# processed_df = pd.concat([normalized_df, cat_one_hot_norm_df], axis=1)

In [None]:
# processed_df.head(5)

## UMAP by transforming all to categorical

In [None]:
def get_knn_bins(df, cols, bins=5, drop_cols=True, random_state=None):
  '''
  Discretise numeric columns by 1-D k-means clustering.

  For every column in `cols` a KMeans model with `bins` clusters is fitted on
  that column alone, and a new column '<col>_value' is appended whose entries
  are the strings '<col>_<centroid>', where <centroid> is the centre of the
  cluster the row was assigned to, truncated to an integer.

  Parameters
  ----------
  df : DataFrame holding the columns to bin (not modified).
  cols : iterable of column names (a DataFrame also works: its column names
      are used).
  bins : number of k-means clusters per column.
  drop_cols : if True the original columns in `cols` are removed from the result.
  random_state : forwarded to KMeans; None keeps the previous behaviour of
      drawing from numpy's global random state.

  Returns
  -------
  DataFrame with a fresh positional (0..n-1) index; callers that need the
  original index must re-attach it. Works for any index name.

  NOTE(review): centroids are truncated with astype(int), so for variables on
  a small scale several clusters can end up with the same label and be merged.
  '''
  cols = list(cols)
  if not cols:
    return df.drop(cols, axis=1) if drop_cols else df

  # Work on a positionally indexed copy so that labels line up row by row
  binned = df.reset_index(drop=True)
  for col in cols:
    kmeans = KMeans(n_clusters=bins, random_state=random_state).fit(binned[[col]].to_numpy())
    # Integer-truncated centre of each cluster, looked up per row via its label
    centroids = kmeans.cluster_centers_.ravel().astype(int)
    binned[col + '_value'] = [col + '_' + str(centre) for centre in centroids[kmeans.labels_]]

  if drop_cols:
    return binned.drop(cols, axis=1)

  else:
    return binned

# numeric_cols is the numeric sub-frame of df_hcm; iterating over it yields its
# column names, which is what get_knn_bins loops over.
numeric_cols = df_hcm.select_dtypes(include=np.number)
# The index is temporarily named 'index' for the call and restored afterwards.
df_hcm.index.names = ['index']
print('Numerical variables converted to categorical with bins:')
print(numeric_cols.columns.to_list())
recoded_df = get_knn_bins(df_hcm, numeric_cols, bins=5)
# get_knn_bins returns a positionally indexed frame: re-attach the patient index
recoded_df.index = df_hcm.index
df_hcm.index.names = ['patient_id']

In [None]:
recoded_df.head()

In [None]:
# Convert all variables to binary with one hot encoding
def one_hot_encode(df, columns):
  '''
  Build drop-first dummy variables for every column named in `columns`
  (prefixed with the column name) and join them column-wise.

  Returns a tuple: the frame of dummy columns only (the original columns of df
  are not part of it) and the Index of its column names.
  '''
  encoded = pd.concat(
      (pd.get_dummies(df[name], prefix=name, drop_first=True) for name in columns),
      axis=1,
  )
  return encoded, encoded.columns

# One-hot encode every column of recoded_df; the second return value (the
# dummy column index) is not needed and is discarded.
cat_cols = recoded_df.columns
df_one_hot, _ = one_hot_encode(recoded_df, cat_cols)

In [None]:
# Dimensions and first five rows of the one-hot encoded matrix
print(df_one_hot.shape)
df_one_hot.head()

# Dimensionality reduction with UMAP

In [None]:
# Independent (deep) copy of the one-hot matrix used as UMAP input
df_all = df_one_hot.copy()

## Impact of parameters: number of neighbors and minimum distance

In [None]:
# Explore impact of number of neighbors and minimum distance on UMAP1 VS UMAP 2 plot
# Shared UMAP settings (PARAM is reused by the number-of-components cell below)
PARAM = {
        'metric' : 'dice',
        'random_state' : 1,
        'transform_seed': 42
        }
list_n_neigh = [2, 3, 5, 8, 15, 30]
list_min_dist = [1, 0.5, 0.01, 0.0001, 0.000001, 0.00000001]
# Every embedding is computed with 30 components although only the first two are plotted
n_comp = 30

# 6 x 6 grid: one row per n_neighbors value, one column per min_dist value
# NOTE(review): a 90 x 70 inch figure at the configured 300 dpi is very large to render.
fig, axs = plt.subplots(6, 6, figsize=(90, 70))
for i, n_neigh in enumerate(list_n_neigh):
    for j, min_dist in enumerate(list_min_dist):
        reducer = umap.UMAP(n_neighbors  = int(n_neigh),
                            min_dist     = float(min_dist),
                            metric       = PARAM['metric'],
                            random_state = int(PARAM['random_state']),
                            n_components = n_comp,
                            transform_seed = int(PARAM['transform_seed']))

        # A failing parameter combination is reported and its panel left empty
        try:
            proj = reducer.fit_transform(df_all.iloc[:,:].reset_index(drop=True))
        except Exception as e:
            print(e)
            continue
        df_umap = pd.DataFrame(data=proj, columns=[f'umap_{x+1}' for x in range(n_comp)])
        df_umap.index = df_all.index

        sns.scatterplot(data=df_umap, x='umap_1', y='umap_2', ax=axs[i,j])

# Column titles show min_dist, row labels show n_neighbors
for ax, col in zip(axs[0,:], list_min_dist):
    ax.set_title(col, size=60)
for ax, row in zip(axs[:,0], list_n_neigh):
    ax.set_ylabel(row, size=60)

## Impact of parameter: number of components

In [None]:
# Explore impact of number of components on UMAP1 VS UMAP 2 plot
# n_neighbors and min_dist are fixed; metric and random_state come from PARAM,
# which is defined in the previous cell (transform_seed is not passed here).
list_n_comp = [2, 3, 5, 8, 10, 15, 20, 25, 30, 50]
n_neighbors = 8
min_dist = 0.000001

# One embedding per n_components value is kept in list_df_umap (same order as list_n_comp)
list_df_umap = []
fig, axs = plt.subplots(2, 5, figsize=(50, 20))
for i, n_comp in enumerate(list_n_comp):
    reducer = umap.UMAP(n_neighbors  = n_neighbors,
                        min_dist     = min_dist,
                        metric       = PARAM['metric'],
                        random_state = int(PARAM['random_state']),
                        n_components = n_comp)
    
    #df_all = df_all[~df_all.index.isin(list_outlier)]
    proj = reducer.fit_transform(df_all.iloc[:,:].reset_index(drop=True))
    df_umap = pd.DataFrame(data=proj, columns=[f'umap_{x+1}' for x in range(n_comp)])
    # Restore the original row index on the embedding
    df_umap.index = df_all.index
    list_df_umap.append(df_umap)
    
    sns.scatterplot(data=df_umap, x='umap_1', y='umap_2', ax=axs.flat[i])
    axs.flat[i].set_title(n_comp, size=20)

In [None]:
# Define definitive dataframe with UMAP components
# ix indexes list_n_comp / list_df_umap; it is chosen arbitrarily depending on
# the preferred 2D space layout. ix = 7 selects the embedding with 25 components.
ix = 7
df_umap = list_df_umap[ix]
plt.figure(figsize=(15, 10))
ax = sns.scatterplot(data=df_umap, x='umap_1', y='umap_2')
ax.set_title(f'UMAP 2D projection: no. neighbours={n_neighbors}, min dist={min_dist}, no. components={list_n_comp[ix]}',
           pad=20, fontsize=20)
ax.set_xlabel('UMAP 1', fontsize=12)
ax.set_ylabel('UMAP 2', fontsize=12)
# plt.show() takes no positional figure/axes argument
plt.show()

In [None]:
# Import UMAP coordinates if needed (reproducibility)
# NOTE: this overwrites the df_umap selected in the previous cell with the
# coordinates stored on disk, so all downstream cells use the saved embedding.
df_umap = pd.read_csv(r"/home/pps21@isd.csc.mrc.ac.uk/pps21/cardiac-dimred/data/raw/umap.csv", index_col=0)

# Clustering

## K-means

In [None]:
# Elbow method
# Fit K-means for k = 1..9 on the UMAP coordinates and plot the inertia against k.
df_kmeans = df_umap.copy(deep=True)
K = range(1,10)
distortions = []
for k in K:
    kmeanModel = KMeans(n_clusters=k).fit(df_kmeans)
    distortions.append(kmeanModel.inertia_)

plt.figure(figsize=(16,8))
plt.plot(K, distortions, '-bx')
plt.title('The elbow method showing the optimal k')
plt.xlabel('k')
plt.ylabel('Distortion')
plt.show()

In [None]:
# Silhouette Score to find optimal number of clusters
model = KMeans()
# k is range of number of clusters.
df_kmeans = df_umap.copy(deep=True)
visualizer = KElbowVisualizer(model, k=(3,30),metric='silhouette', timings= True)
visualizer.fit(df_kmeans)
visualizer.show()
# The selected number of clusters is reused by the K-means cell below.
# NOTE(review): elbow_value_ is used as-is; confirm it is always set (not None)
# for this data, since it is passed straight to KMeans(n_clusters=...).
optimal_k = visualizer.elbow_value_

In [None]:
# Perform K-means
# The model is fitted on df_umap and df_kmeans is rebuilt from it, so the cell
# can be re-run safely: previously it fitted on df_kmeans, which after a first
# run already contained the 'kmeans_label' column as an extra feature.
n_clusters = optimal_k
kmeanModel = KMeans(n_clusters=n_clusters, random_state=1)
kmeanModel.fit(df_umap)

df_kmeans = df_umap.copy(deep=True)
# Cluster labels are shifted to start at 1
df_kmeans['kmeans_label'] = kmeanModel.predict(df_umap) + 1

plt.figure(figsize=(16,8))
sns.scatterplot(data=df_kmeans, x='umap_1', y='umap_2', hue='kmeans_label', palette='tab10')
plt.title(f'K-Means clusters on UMAP plot (K = {n_clusters} clusters)')
plt.show()

# Exploration

In [None]:
# Merge cluster results and outcomes
# Inner join on the index of both frames
df_exp = df_kmeans.merge(df_outcome, how='inner', left_index=True, right_index=True)
df_exp.head()

## Visualisation

In [None]:
# Choose variable
var = 'genotype'

# Table: distribution of variable in each cluster
# (freq_table is reused by the "Final plot for paper" cell)
df_plot = df_exp.copy(deep=True)
df_plot['genotype'] = df_plot['genotype'].replace({'PLP': 'P/LP'})
freq_table = df_plot.groupby(['kmeans_label'])[var].value_counts(normalize=True)
print(freq_table)

# Table: distribution of variable in entire dataset
freq = df_plot[var].value_counts(normalize=True)
freq = freq.sort_index()
print(freq)

fig, axs = plt.subplots(1, 4, figsize=(40, 10))

# Bar plot: distribution of variable in each cluster
freq_table.unstack().plot(kind='bar', color=['green', 'red', 'orange'], stacked=True, ax=axs[0])

# Bar plot: distribution of variable in entire dataset
freq.to_frame().T.plot.bar(stacked=True, color=['green', 'red', 'orange'], ax=axs[1])

# Scatter plot: clusters in UMAP space
sns.scatterplot(data=df_plot, x='umap_1', y='umap_2', hue='kmeans_label', palette='Set2', ax=axs[2])

# Scatter plot: clusters in UMAP space labelled by genotype
sns.scatterplot(data=df_plot, x='umap_1', y='umap_2', hue='kmeans_label', palette='Set2', style='genotype', s=100, ax=axs[3])

# Plot settings
fig.suptitle('Distribution of genotype profiles per cluster', fontsize=25, y=0.95)

axs[0].set_xlabel('K-means cluster', fontsize=12)
axs[0].get_legend().remove()
axs[1].legend(fontsize=14)
axs[2].legend(fontsize=14)
axs[2].set_xlabel('UMAP 1', fontsize=12)
axs[2].set_ylabel('UMAP 2', fontsize=12)

# plt.show() takes no positional figure/axes argument
plt.show()

## Statistical testing

In [None]:
def chi_sq_test(df, col1, col2):
    """Chi-square test of independence between two columns of df.

    Cross-tabulates df[col1] against df[col2] and runs
    scipy.stats.chi2_contingency on the resulting contingency table.

    Returns
    -------
    (statistic, pvalue) : the first two elements of the chi2_contingency result.
    """
    table = pd.crosstab(df[col1], df[col2])
    res = stats.chi2_contingency(table)
    return res[0], res[1]

In [None]:
# Test for statistical differences with Kruskal-Wallis
# Numeric columns are compared across clusters with Kruskal-Wallis; any other
# dtype (as for the only column tested here, 'genotype') goes through the
# chi-square test of the cluster x value contingency table.
table = []
for col in ['genotype']:
    try:
        if np.issubdtype(df_plot[col].dtype, np.number):
            statistics, pvalue = stats.kruskal(*[group[col].dropna().values for name, group in df_plot.groupby('kmeans_label')])
            comment = ''
        else: 
            statistics, pvalue = chi_sq_test(df_plot, 'kmeans_label', col)
            comment = ''
    except Exception as e:
        # A failing test is recorded as NaN with the exception text as comment
        statistics = np.nan
        pvalue = np.nan
        comment = repr(e)
    table += [[col, statistics, pvalue, comment]]
df_kruskal = pd.DataFrame(data=table, columns=['field','stats','pvalue','comment'])
df_kruskal.pvalue = df_kruskal.pvalue.astype(float)
df_kruskal = df_kruskal.sort_values(by='pvalue',ascending=True)
# NOTE: dropna() removes the rows of failed tests, so their comment is not displayed
df_kruskal = df_kruskal.dropna()
df_kruskal

In [None]:
# Post hoc test for cluster-specific differences (exact Fisher test)
# One-sided hypergeometric test of over-representation of each value in each
# cluster: P(X >= k) when drawing N rows (the cluster) out of M rows of which
# n carry the value.
var_list = ['genotype']
L = []
cluster_labels = df_plot.kmeans_label
M = df_plot.shape[0]  # Total
for var in var_list:
    # Missing values are not a category: `== NaN` never matches, so testing
    # them would only add degenerate p = 1 rows to the multiple-testing family
    value_list = df_plot[var].dropna().unique()
    for value in value_list:
        bmask = df_plot[var] == value
        n = np.sum(bmask)   # white balls
        for i in np.unique(cluster_labels):
            cl_mask = cluster_labels == i
            N = np.sum(cl_mask) # selected
            k = np.sum((bmask) & (cl_mask)) # selected and white
            # Survival function instead of 1 - cdf: same quantity, but without
            # the cancellation that rounds very small tail probabilities to 0
            pvalue = hypergeom.sf(k-1, M, n, N, loc=0)
            L.append([var, value, i, pvalue])
df_pvalue = pd.DataFrame(data=L, columns=['var', 'value', 'cluster', 'pvalue'])
df_pvalue = df_pvalue.sort_values(by='pvalue',ascending=True)

# Correction with Benjamini-Hochberg
df_pvalue['pvalue_bh'] = mt.multipletests(df_pvalue.pvalue, alpha=0.05, method='fdr_bh')[1]
df_pvalue.sort_values(by='pvalue_bh', ascending=True)

## Final plot for paper

In [None]:
# Figure for paper
# Three panels: clusters in UMAP space, genotype in UMAP space, and stacked
# genotype proportions per cluster.
# NOTE: freq_table comes from the "Visualisation" cell above, and the cluster
# palette lists exactly three colours (one per cluster).
df_plot = df_exp.copy(deep=True)
df_plot['genotype'] = df_plot['genotype'].replace({'PLP': 'P/LP'})

fig, axs = plt.subplots(1, 3, figsize=(60,16))

# Scatter plot: clusters
sns.scatterplot(data=df_plot, x='umap_1', y='umap_2',
                hue='kmeans_label',
                palette=['#ff7f0e', '#9467bd',  '#17becf'],
                s=200,
                alpha=0.5,
                ax=axs[0]
               )

# Scatter plot: genotype status
sns.scatterplot(data=df_plot, x='umap_1', y='umap_2',
                hue='genotype',
                palette=['#d62728', '#1f77b4', '#2ca02c'],
                linewidth=0,
                s=200,
                alpha=0.5,
                ax=axs[1]
               )

# Bar plot: genotype status
freq_table.unstack().plot(kind='bar', color=['#d62728', '#1f77b4', '#2ca02c'], stacked=True, ax=axs[2])

# Plot parameters
# Legend: moved above each panel, then text and title sizes adjusted
for ax, legend_title in zip(axs, ['Cluster', 'Genotype', 'Genotype']):
    sns.move_legend(
        ax, "lower center", title = legend_title,
        bbox_to_anchor=(.5, 1), ncol=3, frameon=False, markerscale=2
    )
    plt.setp(ax.get_legend().get_texts(), fontsize='22') # for legend text
    plt.setp(ax.get_legend().get_title(), fontsize='28') # for legend title

# Axis ticks: regular intervals on the two scatter panels.
# Each axis gets its own locator instance; a Locator must not be shared
# between several axes.
for ax in axs[:2]:
    ax.xaxis.set_major_locator(plticker.MultipleLocator(base=1))
    ax.yaxis.set_major_locator(plticker.MultipleLocator(base=0.5))
axs[2].tick_params(labelrotation=0)

for ax in axs:
    ax.spines['left'].set_color('black')
    ax.spines['bottom'].set_color('black')

# Axis labels
for ax in axs[:2]:
    ax.set_xlabel('UMAP 1', fontsize=20)
    ax.set_ylabel('UMAP 2', fontsize=20)
axs[2].set_ylabel('Proportion of subjects', fontsize=20)
axs[2].set_xlabel('Cluster', fontsize=20)

# Significance markers
# Positions and star counts are hard-coded and must be kept in sync with the
# post hoc results by hand. plt.text draws on the current pyplot axes --
# presumably the bar chart (axs[2]); TODO confirm.
plt.text(0, 0.4, '**', ha='center', fontsize=35) # neg
plt.text(2, 0.9, '*', ha='center', fontsize=35) # vus
plt.text(2, 0.65, '**', ha='center', fontsize=35) # plp

plt.show()

# Statistical testing for feature importance

In [None]:
# Merge HCM variables and cluster results
# Inner join on the index, then discard every column whose name contains 'umap'
df = df_kmeans.merge(df_org, how='inner', left_index=True, right_index=True)
df = df.loc[:, [name for name in df.columns if 'umap' not in name]]
df.head()

In [None]:
# Independent (deep) copy used by the statistical tests
df_test = df.copy()

## Significant features: Kruskal-Wallis test (continuous) and Chi-square (discrete)

In [None]:
# Kruskal-Wallis and Chi-square tests
# Every feature is compared across clusters: Kruskal-Wallis for numeric
# columns, chi-square on the cluster x value table otherwise.
table = []
for col in df_test.columns:
    if col == 'kmeans_label':
        continue
    # Reset per column so that a failing test cannot inherit the previous
    # column's test name (or hit an undefined name on the first column)
    test = None
    try:
        if np.issubdtype(df_test[col].dtype, np.number):
            statistics, pvalue = stats.kruskal(*[group[col].dropna().values for name, group in df_test.groupby('kmeans_label')])
            comment = ''
            test = 'kruskal'
        else:
            statistics, pvalue = chi_sq_test(df_test, 'kmeans_label', col)
            comment = ''
            test = 'chi_sq'
    except Exception as e:
        statistics = np.nan
        pvalue = np.nan
        comment = repr(e)
    table += [[col, statistics, pvalue, comment, test]]

df_kruskal = pd.DataFrame(data=table, columns=['field','stats','pvalue','comment','test'])
df_kruskal.pvalue = df_kruskal.pvalue.astype(float)
# Rows of failed tests (NaN statistic / p-value) are removed before correction
df_kruskal = df_kruskal.dropna()
df_kruskal['pvalue_bh'] = mt.multipletests(df_kruskal.pvalue, alpha=0.05, method='fdr_bh')[1] # correction for Benjamini-Hochberg
df_kruskal = df_kruskal.sort_values(by='pvalue_bh',ascending=True).reset_index(drop=True)
# NOTE(review): features are kept for post hoc testing at adjusted p < 0.5,
# whereas every other threshold in this notebook is 0.05 -- confirm whether
# 0.5 is an intentional lenient screen or a typo.
list_sgnfc = df_kruskal.loc[df_kruskal.pvalue_bh < 0.5, 'field'].values
df_kruskal

## Cluster-specific significant associations

### Dunn test (continuous)

In [None]:
def dunn_test(df, col, group):
    """Pairwise Dunn post hoc test of `col` between the levels of `group`.

    The square p-value matrix returned by scikit_posthocs.posthoc_dunn is
    melted to long format with columns cluster_1, cluster_2 and pvalue. Rows
    with pvalue >= 1 are dropped, and of rows holding the same set of values
    (i.e. the mirrored pair) only the first is kept.
    """
    pairwise = (
        sp.posthoc_dunn(df, val_col=col, group_col=group)
        .reset_index()
        .rename(columns={'index': 'cluster_1'})
    )
    long_form = pairwise.melt('cluster_1', var_name='cluster_2', value_name='pvalue')
    long_form = long_form[long_form.pvalue < 1]
    # A row and its mirror contain the same values, hence the same frozenset
    is_mirror = long_form.apply(frozenset, axis=1).duplicated()
    return long_form[~is_mirror]

In [None]:
# Post hoc testing for cluster specific differences: Dunn for continuous
# Runs the pairwise Dunn test for every screened numeric feature and stacks
# the results. df_test is used throughout (it is the frame list_sgnfc was
# derived from).
dunn_columns = ['var', 'cluster_1', 'cluster_2', 'pvalue']
list_dunn = []
for col in list_sgnfc:
    if np.issubdtype(df_test[col].dtype, np.number):
        # assign() returns a new frame instead of writing into a filtered slice
        df_sgnfc = dunn_test(df_test, col, group='kmeans_label').assign(var=col)
        list_dunn.append(df_sgnfc)
if list_dunn:
    df_dunn = pd.concat(list_dunn, axis=0)[dunn_columns].reset_index(drop=True)
else:
    # pd.concat([]) raises; fall back to an empty table with the same columns
    df_dunn = pd.DataFrame(columns=dunn_columns)
df_dunn

### Fisher/hypergeometric test (discrete)

In [None]:
def hypergeom_test(df, var):
    """One-sided hypergeometric over-representation test per value and cluster.

    For every distinct non-missing value of df[var] and every cluster in
    df.kmeans_label, computes P(X >= k) for a hypergeometric variable with
    population size M = number of rows, n = rows carrying the value and
    N = rows in the cluster, where k is the number of rows in the cluster that
    carry the value.

    Missing values are skipped: `df[var] == NaN` never matches, so testing
    them would only add degenerate p = 1 rows.

    Returns
    -------
    list of [var, value, cluster, pvalue] rows.
    """
    cluster_labels = df.kmeans_label
    clusters = np.unique(cluster_labels)
    M = df.shape[0]  # Total
    table_p = []
    for value in df[var].dropna().unique():
        bmask = df[var] == value
        n = np.sum(bmask)   # white balls
        for i in clusters:
            cl_mask = cluster_labels == i
            N = np.sum(cl_mask) # selected
            k = np.sum((bmask) & (cl_mask)) # selected and white
            # sf(k-1) == 1 - cdf(k-1), but without the cancellation that
            # rounds very small tail probabilities to exactly 0
            pvalue = hypergeom.sf(k-1, M, n, N, loc=0)
            table_p.append([var, value, i, pvalue])
    return table_p

In [None]:
# Post hoc testing for cluster specific differences: hypergeometric test for discrete
# Numeric features are skipped here (they are handled by the Dunn test)
list_hypergeom = []
for col in list_sgnfc:
    if np.issubdtype(df_test[col].dtype, np.number):
        continue
    table = hypergeom_test(df, col)
    list_hypergeom.extend(table)
df_hgeom = (
    pd.DataFrame(data=list_hypergeom, columns=['var', 'value', 'cluster', 'pvalue'])
    .sort_values(by='pvalue', ascending=True)
    .reset_index(drop=True)
)
df_hgeom

### Correction with Benjamini-Hochberg

In [None]:
def convert_pvalue_to_asterisks(pvalue):
    """Map a p-value to its significance label.

    Thresholds are inclusive: <= 0.0001 '****', <= 0.001 '***', <= 0.01 '**',
    <= 0.05 '*'. Anything else (including NaN) gives 'ns'.
    """
    cutoffs = ((0.0001, "****"), (0.001, "***"), (0.01, "**"), (0.05, "*"))
    for upper_bound, stars in cutoffs:
        if pvalue <= upper_bound:
            return stars
    return "ns"

In [None]:
# Correction with Benjamini-Hochberg
# Dunn and hypergeometric p-values are corrected together as one family, then
# the adjusted values are split back by position.
dunn_size = df_dunn.shape[0]
hgeom_size = df_hgeom.shape[0]
pooled_pvalues = [*df_dunn.pvalue.values, *df_hgeom.pvalue.values]
pvalue_bh = mt.multipletests(pooled_pvalues, alpha=0.05, method='fdr_bh')[1]
df_dunn['pvalue_bh'], df_hgeom['pvalue_bh'] = pvalue_bh[:dunn_size], pvalue_bh[dunn_size:]

In [None]:
# Keep the Dunn comparisons that remain significant after correction.
# .copy() makes the result an independent frame, so columns added to it in
# later cells are not written into a slice of the unfiltered table.
df_dunn = df_dunn.sort_values(by='pvalue_bh', ascending=True)
df_dunn = df_dunn[df_dunn.pvalue_bh < 0.05].copy()

In [None]:
# Significance stars for the adjusted Dunn p-values
dunn_stars = df_dunn['pvalue_bh'].apply(convert_pvalue_to_asterisks)
df_dunn['asterisk'] = dunn_stars
df_dunn

In [None]:
# Keep the hypergeometric tests that remain significant after correction.
# .copy() makes the result an independent frame, so columns added to it in
# later cells are not written into a slice of the unfiltered table.
df_hgeom = df_hgeom.sort_values(by='pvalue_bh', ascending=True)
df_hgeom = df_hgeom[df_hgeom.pvalue_bh < 0.05].copy()

In [None]:
# Labels for plot
# var_red: column name -> readable feature label
var_red = {'sex': 'Sex',
            'Gen.Ht': 'Hypertension', 
            'Gen.Acearb': 'ACE/ARBs', 
            'Gen.Asaclopi' :'ASA/Clopi',
            'Hcm.Mitralregurgitation': 'Mitral regurgitation',
            'Gen.Betablocker': 'Beta blocker', 
            'Gen.Diuretic': 'Diuretic',
            'Hcm.Mostaffectedlevel': 'Most affected level',
            'Hcm.Familyhistoryofhcm': 'Family history of HCM', 
            'Gen.Nyha': 'NYHA score',
            'Hcm.Lvoto': 'LVOTO', 
            'Hcm.Lvgadolinum': 'LV gadolinium', 
            'Gen.Activityscore': 'Activity score', 
            'Gen.Cad': 'Coronary artery disease',
            }
# value_red: feature value -> readable value label
# NOTE(review): the labelling cell applies this mapping to values already cast
# to str, so the integer key 1 presumably never matches there -- confirm.
value_red = {'M': 'Male',
             'F': 'Female',
             'False': 'No',
             'Apex': 'Apex',
             1: 'ok',
             '2': '2',
             '3': '3',
             'None': 'No'
            }

In [None]:
def concat_var_val(df):
    """Build the plot legend text for one row.

    `df` is a single row exposing the attributes `var_red` (feature label) and
    `value_red` (value label):

    * value 'No'  -> 'No <feature label in lower case>'
    * value ''    -> the feature label on its own (previously this branch
                     produced the label with a stray leading space)
    * otherwise   -> '<feature label>: <value label in lower case>'
    """
    value_label = df.value_red
    feature_label = df.var_red
    if value_label == 'No':
        return value_label + ' ' + feature_label.lower()
    if value_label == '':
        return feature_label
    return feature_label + ': ' + value_label.lower()

In [None]:
# Create labels for plot
# Readable feature label, readable value label (values are compared as text;
# 'True' becomes an empty label), then the combined legend string per row.
df_hgeom['var_red'] = df_hgeom['var'].replace(var_red)
value_as_text = df_hgeom['value'].astype(str)
df_hgeom['value_red'] = value_as_text.replace(value_red).replace('True','').replace('ok','1')
df_hgeom['plot_legend'] = df_hgeom.apply(concat_var_val, axis=1)

In [None]:
# Significance stars for the adjusted hypergeometric p-values
hgeom_stars = df_hgeom['pvalue_bh'].apply(convert_pvalue_to_asterisks)
df_hgeom['asterisk'] = hgeom_stars
df_hgeom

## Significance plots

### Median + IQR plots (numerical features)

In [None]:
def q1_diff(x):
    """Distance from the first quartile up to the median of x."""
    lower_quartile, centre = x.quantile(0.25), x.median()
    return centre - lower_quartile

def q3_diff(x):
    """Distance from the median up to the third quartile of x."""
    centre, upper_quartile = x.median(), x.quantile(0.75)
    return upper_quartile - centre

# Aggregations applied per cluster: the median plus the two asymmetric
# error-bar lengths (the function names become the result column names)
f = ['median', q1_diff, q3_diff]

In [None]:
# Create summary for numerical features
# Features with at least one significant Dunn comparison, in sorted order:
# iterating a bare set gives an arbitrary order that can change between
# kernel sessions, which would reshuffle the panels of the summary figure.
dunn_vars = sorted(set(df_dunn['var'].values))
# .copy() so that the dtype change below does not write into a slice of df_test
df_mean_iqr = df_test[['kmeans_label', *dunn_vars]].copy()
df_mean_iqr['kmeans_label'] = df_mean_iqr['kmeans_label'].astype('category')
# Per-cluster median and distances to the first / third quartile
stats_mean_iqr = df_mean_iqr.groupby('kmeans_label').agg(f)

In [None]:
# Units for numerical features
# Column name -> panel title used by the summary plots below. Every numeric
# feature that is plotted must have an entry here (the title is looked up by key).
plot_title = {'Gen.Alcohol': 'Alcohol units/week',
 'Gen.Diastolic': 'DBP (mmHg)',
 'Gen.Pulserate': 'Pulse rate (bpm)',
 'Gen.Systolic': 'SBP (mmHg)',
 'Hcm.Bsa': 'BSA (m2)',
 'Hcm.Edv': 'LV ED volume',
 'Hcm.Ef': 'Ejection fraction',
 'Hcm.Esv': 'LV ES volume',
 'Hcm.Lvm': 'LV mass',
 'Hcm.Lvotpeakvelocity': 'LV outflow peak velocity',
 'Hcm.Maxlvwallthickness': 'LV max WT',
 'Hcm.Sv': 'Stroke volume',
 'age_at_scan': 'Age'}

In [None]:
# Summary plots
# One panel per numeric feature: per-cluster median (marker) with a bar from
# the first to the third quartile, plus brackets for significant Dunn pairs.
# The grid has rown * coln panels; unused ones are removed at the end.
rown=3
coln=5
fig, axs = plt.subplots(rown, coln, figsize=(20,10))
fig.tight_layout(pad=5)
axs = axs.flatten()

# Column 0 of df_mean_iqr is kmeans_label, hence the [1:]
for i, col in enumerate(df_mean_iqr.columns[1:]):

    col_iqr = stats_mean_iqr[col].reset_index()
    # Asymmetric horizontal error bars: (median - Q1) to the left, (Q3 - median) to the right
    axs[i].errorbar(
        data=col_iqr, x='median', y='kmeans_label',
        xerr=col_iqr[['q1_diff','q3_diff']].T,
        capsize=0, 
        elinewidth=5, 
        ls='None',
        ecolor=['#ff7f0e', '#9467bd',  '#17becf'],
    )
    axs[i].scatter(data=col_iqr, x='median', y='kmeans_label', s = 70, marker = "o", color = ['#ff7f0e', '#9467bd',  '#17becf'])
    
    axs[i].set_title(plot_title[col], 
                     fontsize=10,
                     loc='left'
                    )
    
    # Significance lines
    # Vertical brackets are drawn at the right edge of the panel, one per
    # significant cluster pair of this feature.
    xlim_min = axs[i].get_xlim()[0]
    xlim = axs[i].get_xlim()[1]
    range_x = xlim - xlim_min
    off = 0
    df_sgnfc = df_dunn[df_dunn['var']==col]
    for j, sgnfc in df_sgnfc.iterrows():
        cl1 = sgnfc.cluster_1
        cl2 = sgnfc.cluster_2
        ast = sgnfc.asterisk
        # NOTE(review): xlim itself is shifted by `off` and `off` then grows, so
        # the gap between successive brackets increases (0, 5%, 10% of the x-range, ...)
        xlim = xlim + off
        axs[i].plot([xlim, xlim, xlim, xlim], [cl1, cl2, cl2, cl1], linewidth=1.2, color='k') 
        axs[i].text(xlim+range_x*0.015, abs(cl2+cl1)/2, ast, rotation=90, fontsize=10)
        off += range_x*0.05
    
for ax in axs:
    ax.set_ylabel('Cluster', fontsize=10)
    ax.spines['left'].set_color('black')
    ax.spines['bottom'].set_color('black')
    ax.xaxis.label.set_color('black')
    ax.tick_params('both', length=6, width=1, which='major', colors='black', labelsize=8)
    ax.locator_params(axis='x', nbins=5)
    ax.locator_params(axis='y', nbins=3)
    ax.xaxis.label.set_visible(False)

# Remove the panels that were not used; relies on `i` still holding the index
# of the last feature plotted by the loop above
for j in range(i+1,rown*coln):
    fig.delaxes(axs[j])
    
# NOTE(review): presumably a panel letter for the paper figure; it is placed in
# data coordinates of whichever axes is current at this point -- confirm position.
plt.text(60, 10, 'D', weight='bold', size=20)

plt.show()

### Circular radial plot (categorical features)

In [None]:
# Processing for radial plot
# Adjusted p-values equal to 0 are replaced by 10**(-15) so that the log
# transform stays finite; bar heights are based on -log10(adjusted p).
df_hgeom['pvalue_bh'] = df_hgeom['pvalue_bh'].replace(0, 10**(-15))
df_hgeom['log_pvalue_bh'] = -np.log10(df_hgeom['pvalue_bh'])

In [None]:
# Six-step colour gradients keyed 0, 1, 2; the first entry of each list is the
# corresponding cluster colour used in the figures above, followed by
# progressively lighter tints.
COLOR_GRAD = {0: ["#ff7f0e", "#ff923c", "#ffa45e", "#ffb67e", "#ffc89d", "#ffdbbd"],
           1: ["#9467bd", "#a47cc7", "#b491d0", "#c3a6da", "#d2bce3", "#e1d2ec"],
           2: ["#17becf", "#56c7d6", "#7ad1dd", "#98dae3", "#b3e3ea", "#cdedf1"]
          }

In [None]:
# Leave 'Hcm.Mostaffectedlevel' out of the radial plot
df_hgeom = df_hgeom[df_hgeom['var'] != 'Hcm.Mostaffectedlevel']

In [None]:
def get_label_rotation(angle, offset):
    """Return the text rotation (degrees) and horizontal alignment for a bar label.

    For angles in (pi/2, 3*pi/2] the label is turned by an extra 180 degrees and
    right-aligned; for every other angle it keeps the bar's own angle and is
    left-aligned.

    Parameters
    ----------
    angle : float
        Angular position of the bar, in radians.
    offset : any
        Accepted so existing callers keep working; it does not take part in
        the computation.

    Returns
    -------
    tuple
        ``(rotation, alignment)`` where ``rotation`` is in degrees and
        ``alignment`` is ``"left"`` or ``"right"``.
    """
    needs_flip = np.pi/2 < angle <= 3*np.pi/2
    if needs_flip:
        return np.rad2deg(angle) + 180, "right"
    return np.rad2deg(angle), "left"

def add_labels(angles, values, labels, offset, ax):
    """Write one text label just beyond the outer end of each bar.

    Parameters
    ----------
    angles : iterable of float
        Angular position of each bar, in radians.
    values : iterable of float
        Height of each bar; the label is placed ``padding`` units above it.
    labels : iterable
        Text to draw for each bar.
    offset : any
        Forwarded unchanged to ``get_label_rotation``.
    ax : matplotlib Axes
        Axes whose ``text`` method is used to draw the labels.
    """
    padding = 4  # space between bar and legend

    for theta, height, caption in zip(angles, values, labels):
        rotation, alignment = get_label_rotation(theta, offset)
        ax.text(
            x=theta,
            y=height + padding,
            s=caption,
            ha=alignment,
            va="center",
            rotation=rotation,
            rotation_mode="anchor",
            fontsize=13,
        )

In [None]:
# One polar bar chart per cluster. Every bar is one row of df_hgeom, its height is
# 5 * log_pvalue_bh, and bars are grouped by their `asterisk` value.
fig, axs = plt.subplots(1, 3, figsize=(20, 10), subplot_kw={"projection": "polar"})

for cluster_label in df_hgeom.cluster.sort_values().unique():

    df_cl = df_hgeom[df_hgeom.cluster==cluster_label]
    # Cluster labels are used as 1-based positions into `axs` and COLOR_GRAD.
    # NOTE(review): this assumes the labels are exactly 1, 2, 3 — confirm.
    j = cluster_label - 1

    # Grab the group values (original author's note: CHANGE TO FEATURE SUBGROUPS)
    GROUP = df_cl['asterisk'].values
    VALUES = 5*df_cl['log_pvalue_bh'].values
    LABELS = df_cl['plot_legend']

    # Group names in order of first appearance, and the number of bars in each.
    # The sizes are counted from the same name sequence, so names and sizes
    # cannot get out of step with each other.
    # NOTE(review): the slot arithmetic below assumes that rows sharing an
    # `asterisk` value are contiguous in df_cl — confirm how df_hgeom is ordered.
    GROUP_NAMES = df_cl['asterisk'].unique()
    GROUPS_SIZE = [int((df_cl['asterisk'] == name).sum()) for name in GROUP_NAMES]

    # Reserve PAD empty slots in front of every group to separate the groups
    PAD = 1
    ANGLES_N = len(VALUES) + PAD * len(GROUP_NAMES)
    ANGLES = np.linspace(0, 2 * np.pi, num=ANGLES_N, endpoint=False)
    WIDTH = (2 * np.pi) / len(ANGLES)

    # Slot indexes of the real bars (padding slots are skipped)
    offset = 0
    IDXS = []
    for group, size in zip(GROUP_NAMES, GROUPS_SIZE):

        IDXS += list(range(offset + PAD, offset + size + PAD))
        # Add line below bars
        x1 = np.linspace(ANGLES[offset + PAD], ANGLES[offset + size + PAD - 1], num=50)
        axs[j].plot(x1, [-5] * 50, color="#333333")

        # Add text to indicate group
        axs[j].text(
            np.mean(x1), -25, group, color="#333333", fontsize=14,
            fontweight="bold", ha="center", va="center"
        )
        offset += size + PAD

    # The negative lower radial limit leaves an empty disc in the centre
    axs[j].set_theta_offset(0)
    axs[j].set_ylim(-100, 100)
    axs[j].set_frame_on(False)
    axs[j].xaxis.grid(False)
    axs[j].yaxis.grid(False)
    axs[j].set_xticks([])
    axs[j].set_yticks([])
    axs[j].set_title(f'Cluster {j+1}',
                    fontsize=16,
                     loc='left',
                     y=1.1,
                    )

    # Use different colors for each cluster: the i-th group takes the i-th tint
    COLORS = [COLOR_GRAD[j][i] for i, size in enumerate(GROUPS_SIZE) for _ in range(size)]

    axs[j].bar(
        ANGLES[IDXS], VALUES, width=WIDTH,
        color=COLORS,
        edgecolor="white", linewidth=2
    )

    # `offset` here is the slot counter from the loop above; get_label_rotation ignores it
    add_labels(ANGLES[IDXS], VALUES, LABELS, offset, axs[j])

# Panel letter, positioned relative to the last axes created
plt.text(-2.5, 1.2, 'E', weight='bold',
        horizontalalignment='left',
        verticalalignment='top',
        transform=plt.gca().transAxes,
        size=20)
plt.show()

# Supplemental

In [None]:
# # Grouped circular radial plot
# ANGLES = np.linspace(0, 2 * np.pi, len(df), endpoint=False)
# VALUES = df["value"].values
# LABELS = df["name"].values

# # Determine the width of each bar. 
# # The circumference is '2 * pi', so we divide that total width over the number of bars.
# WIDTH = 2 * np.pi / len(VALUES)

# # Determines where to place the first bar. 
# # By default, matplotlib starts at 0 (the first bar is horizontal)
# # but here we say we want to start at pi/2 (90 deg)
# OFFSET = np.pi / 2

# # Initialize Figure and Axis
# fig, ax = plt.subplots(figsize=(20, 10), subplot_kw={"projection": "polar"})

# # Specify offset
# ax.set_theta_offset(OFFSET)

# # Set limits for radial (y) axis. The negative lower bound creates the hole in the middle.
# ax.set_ylim(-100, 100)

# # Remove all spines
# ax.set_frame_on(False)

# # Remove grid and tick marks
# ax.xaxis.grid(False)
# ax.yaxis.grid(False)
# ax.set_xticks([])
# ax.set_yticks([])

# # Add bars
# ax.bar(
#     ANGLES, VALUES, width=WIDTH, linewidth=2,
#     color="#61a4b2", edgecolor="white"
# )

# # Add labels
# add_labels(ANGLES, VALUES, LABELS, OFFSET, ax)

############

## Info table

In [None]:
# Seaborn pointplot (incorrect quantiles)
# Median per cluster with a 25th-75th percentile error bar, drawn on panel axs[i].
# NOTE(review): `col`, `i` and `axs` are not defined in this cell — it appears to
# rely on values left over from an earlier loop; confirm before re-running.
sns.pointplot(
    x=col,
    y='kmeans_label',
    data=df_mean_iqr,
    ax=axs[i],
    estimator=np.median,
    errorbar=lambda x: tuple(x.quantile(q) for q in (0.25, 0.75)),
    palette=['#ff7f0e', '#9467bd',  '#17becf'],
    join=False,
    capsize=0,
    errwidth=4,
)

In [None]:
# Table info for paper (deprecated)
# for var in ['race', 'Gen.Activityscore', 'Gen.Ccs', 'Gen.Nyha', 'Hcm.Lvgadolinum', 'Hcm.Mitralregurgitation', 'Hcm.Lvmostaffectedsegment', 'Hcm.Mostaffectedlevel']:
#     print(df_org[var].value_counts().sort_index())
#     print(round(df_org[var].value_counts(normalize=True)*100,1).sort_index())

In [None]:
# Table info for paper (deprecated)
# # Import for table in paper
# # Get info for table (paper)
# df_info = df_org.describe(include='all')
# df_info = df_info.round(1)
# df_info = df_info.T
# print(df_info.shape[0])
# # df_info.to_csv('clin_features_info_no_na.csv')
# df_info

## Figures

In [None]:
# Plot significant fields found with Kruskal Wallis test
# Each panel shows the UMAP embedding coloured by one field; marker style follows hdbscan_label.
top = df_kruskal.loc[df_kruskal.pvalue_bh < 0.05, 'field']
# Keep at most 13 significant fields and append genotype, i.e. at most 14 panels.
# A plain positional slice keeps every field when fewer than 14 are significant.
top = [*top.iloc[:13], 'genotype']
top = list(dict.fromkeys(top))  # drop duplicates while keeping order
df_top = pd.merge(df_test, df_umap, how='inner', left_index=True, right_index=True)
n_cluster = len(df_test['hdbscan_label'].unique())

# Grid with up to five panels per row
n_row = ((len(top)-1) // 5) + 1
n_col = min(5, len(top))
fig, axs = plt.subplots(n_row, n_col, figsize=(40, 25))
# plt.subplots returns a bare Axes for a 1x1 grid; normalise to a flat array of panels
panels = np.atleast_1d(axs).ravel()
for i, field in enumerate(top):
    # Fields whose only non-missing values are 1 and 2 use tab10, all others hls
    if set(df_top[field].dropna().unique()) == {1, 2}:
        palette='tab10'
    else:
        palette='hls'
    sns.scatterplot(data=df_top, x='umap_1', y='umap_2', hue=field, palette=palette, style='hdbscan_label', s=300, ax=panels[i])
    h,l = panels[i].get_legend_handles_labels()
    size = len(h)
    # Keep all but the last n_cluster+1 legend entries (remove symbol from legend)
    panels[i].legend(h[:size-n_cluster-1],l[:size-n_cluster-1], fontsize=14)
    #bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
    # NOTE(review): fields are selected on pvalue_bh but the title shows the raw pvalue — confirm this is intended
    pvalue = df_kruskal.loc[df_kruskal.field==field,'pvalue'].iloc[0]
    pvalue = "%.2e"%pvalue
    panels[i].set_title(f'p-value={pvalue}', fontsize=16)
# Remove grid cells that did not receive a plot
for unused in panels[len(top):]:
    fig.delaxes(unused)
fig.suptitle('Significant associations between clusters and features', fontsize=30, y=0.92)
plt.show()

In [None]:
# Figure for paper (not used)
# Entries of `gp` whose 'genotype' index level is 'PLP', formatted as percentages.
# NOTE(review): `gp` is defined outside this cell; presumably per-cluster genotype proportions — confirm.
pos = gp.iloc[gp.index.get_level_values('genotype') == 'PLP']
pos_fq = [f'{round(x*100, 1)}%' for x in pos.values]
print(pos_fq)

pal = sns.color_palette('Set1_r')

markers = {'Negative': 'X', 'Positive': 's', 'Indeterminate': 'o'}

# Rename the plotted columns to the titles that should appear in the legend
kmeans_legend_title = 'K-means cluster and \n proportion of genotype-positive'
df_plot = df_top.copy(deep=True).rename(
    columns={'kmeans_label': kmeans_legend_title, 'genotype': 'Genotype status'}
)
df_plot['Genotype status'] = df_plot['Genotype status'].replace(
    {'NEG': 'Negative', 'PLP': 'Positive', 'VUS': 'Indeterminate'}
)
# Cluster labels carry the PLP percentage computed above
df_plot[kmeans_legend_title] = df_plot[kmeans_legend_title].replace({
    0: f'0 ({pos_fq[0]})',
    1: f'1 ({pos_fq[1]}) (*)',
    2: f'2 ({pos_fq[2]})',
})

fig, ax = plt.subplots(1, 2, figsize=(16,16))
# NOTE(review): only the second panel is drawn on; the first one stays empty.
sns.scatterplot(
    data=df_plot,
    x='umap_1',
    y='umap_2',
    hue=kmeans_legend_title,
    style='Genotype status',
    markers=markers,
    palette='Set1',
    s=100,
    ax=ax[1],
)

ax[1].set_xlabel('UMAP 1', fontsize=18)
ax[1].set_ylabel('UMAP 2', fontsize=18)
# Legend placed outside the panel, anchored to its upper-right corner
ax[1].legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
fig.suptitle('Distribution of genotype status per cluster', fontsize=25, y=0.92)

## Shape data

In [None]:
# Load the ES and ED wall-thickness tables, index them by patient_id and drop the genotype column.
df_es = (
    pd.read_csv(r"/home/pps21@isd.csc.mrc.ac.uk/pps21/cardiac-dimred/data/raw/wallthickness_all_vertices_ES_decimated_0.99.csv", index_col=0)
    .rename(columns={'ID':'patient_id'})
    .set_index('patient_id')
    .drop(columns='genotype')
)
df_ed = (
    pd.read_csv(r"/home/pps21@isd.csc.mrc.ac.uk/pps21/cardiac-dimred/data/raw/wallthickness_all_vertices_ED_decimated_0.99.csv", index_col=0)
    .rename(columns={'ID':'patient_id'})
    .set_index('patient_id')
    .drop(columns='genotype')
)

In [None]:
# Shape of the ES table, followed by a preview of its first five rows.
print(df_es.shape)
df_es.head()

In [None]:
# Pick ES or ED for analysis
# Work on an independent copy so the loaded table is left untouched.
df_all = df_es.copy()

In [None]:
# Min-max scale every column of df_all and rebuild a frame with the same index and columns.
scaler = MinMaxScaler()
scaled = scaler.fit_transform(df_all)
df_scaled = pd.DataFrame(data=scaled, index=df_all.index, columns=df_all.columns)
print(df_scaled.shape)

## Cox model

In [None]:
# # on psm subset
# cox predicting survival from u1 u2, on isdeceased + how many years of survival from enrollment (compute) 
# pygam cox? 
# pvalue between coeff and survival?
# predict probability of survival at 1,2,5,10

# df_umap12 = df_kmeans[['umap_1','umap_2','kmeans_label']]
# df_cox = pd.merge(df_psm, df_umap12, left_index=True, right_index=True)
# df_cox = pd.merge(df_cox, df_hcm[['Gen.Ht']], left_index=True, right_index=True)
# print(df_cox.shape)
# df_cox

In [None]:
# Cox model
# from lifelines import CoxPHFitter
# df_cox_model = df_cox[['umap_1','umap_2', 'is_deceased','time_censor']]

# cph = CoxPHFitter()
# cph.fit(df_cox_model, duration_col = 'time_censor', event_col = 'is_deceased')
# cph.print_summary()

In [None]:
# df_cox['cox_prob_1'] = cph.predict_survival_function(df_cox_model, 1).T
# df_cox['cox_prob_2'] = cph.predict_survival_function(df_cox_model, 2).T
# df_cox['cox_prob_5'] = cph.predict_survival_function(df_cox_model, 5).T
# df_cox['cox_prob_10'] = cph.predict_survival_function(df_cox_model, 10).T

In [None]:
# fig, axs = plt.subplots(1, 4, figsize=(40, 10))
# sns.scatterplot(data=df_cox, x='umap_1', y='umap_2', hue='cox_prob_1', style='is_plp', hue_norm=(0.7, 1), s=100, ax=axs[0])
# sns.scatterplot(data=df_cox, x='umap_1', y='umap_2', hue='cox_prob_2', style='is_plp', hue_norm=(0.7, 1), s=100, ax=axs[1])
# sns.scatterplot(data=df_cox, x='umap_1', y='umap_2', hue='cox_prob_5', style='is_plp', hue_norm=(0.7, 1), s=100, ax=axs[2])
# sns.scatterplot(data=df_cox, x='umap_1', y='umap_2', hue='cox_prob_10', style='is_plp', hue_norm=(0.7, 1), s=100, ax=axs[3])

# norm = plt.Normalize(0.5, 1)
# cmap = sns.cubehelix_palette(light=1, as_cmap=True)
# sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
# sm.set_array([])
# # cax = fig.add_axes([axs[0].get_position().x1+0.05, axs[0].get_position().y0, 0.06, axs[0].get_position().height / 2])
# axs[0].figure.colorbar(sm, 
#                        # cax=cax
#                       )
# fig.suptitle('Cox model on UMAP coordinates', fontsize=25, y=0.95)
# axs[0].set_title('Probability of survival (1 year)')
# axs[1].set_title('Probability of survival (2 years)')
# axs[2].set_title('Probability of survival (5 years)')
# axs[3].set_title('Probability of survival (10 years)')
# axs[0].get_legend().remove()
# axs[1].get_legend().remove()
# axs[2].get_legend().remove()
# axs[3].get_legend().remove()
# plt.show()
# fig.save('cox_umap.pdf')

## Logistic regression probability of plp

In [None]:
from sklearn.linear_model import LogisticRegression

# Logistic regression of `is_plp` on the two UMAP coordinates.
# NOTE(review): in this part of the notebook df_cox is only built inside a
# commented-out cell — confirm it is defined before this cell runs.
model = LogisticRegression(random_state=0, solver='liblinear')
model.fit(X=df_cox[['umap_1','umap_2']], y=df_cox['is_plp'])

In [None]:
# Column 1 of predict_proba holds the probability of the second class in model.classes_.
prob_plp = model.predict_proba(df_cox.loc[:, ['umap_1', 'umap_2']])[:, 1]
df_cox['prob_plp'] = prob_plp

In [None]:
# UMAP embedding coloured by the logistic-regression probability; marker style follows is_plp.
sns.scatterplot(
    data=df_cox,
    x='umap_1',
    y='umap_2',
    hue='prob_plp',
    style='is_plp',
    s=100,
)

## GAM

In [None]:
from pygam import LogisticGAM

# Logistic GAM of `is_plp` on the two UMAP coordinates, built with n_splines=4.
gam = LogisticGAM(n_splines=4).fit(
    df_cox.loc[:, ['umap_1', 'umap_2']],
    df_cox.loc[:, 'is_plp'],
)

In [None]:
# Predicted probabilities from the fitted GAM, stored next to the inputs.
gam_prob_plp = gam.predict_proba(df_cox.loc[:, ['umap_1', 'umap_2']])
df_cox['gam_prob_plp'] = gam_prob_plp

In [None]:
# UMAP embedding coloured by the GAM probability; marker style follows is_plp.
sns.scatterplot(
    data=df_cox,
    x='umap_1',
    y='umap_2',
    hue='gam_prob_plp',
    style='is_plp',
    s=100,
)

## Propensity matching (survival)

In [None]:
# Covariate table for matching: age at scan plus the `plp` column used as treatment.
# reset_index() turns the index into a regular column.
# NOTE(review): presumably that column is patient_id, which PsmPy is told to use as the id — confirm.
df_cov = df_hcm.loc[:, ['age_at_scan', 'plp']].reset_index()
psm = PsmPy(
    df_cov,
    treatment='plp',
    indx='patient_id',
    exclude=[],
)

In [None]:
# Compute propensity scores on the PsmPy object created above.
# NOTE(review): `balance = False` presumably turns off PsmPy's balanced fitting — confirm against the psmpy documentation.
psm.logistic_ps(balance = False)

In [None]:
# Nearest-neighbour matching on the 'propensity_logit' column.
# NOTE(review): how_many=1 presumably requests one match per treated row — confirm against the psmpy documentation.
psm.knn_matched_12n(matcher='propensity_logit', how_many=1)

In [None]:
# Display the matched ids stored on the PsmPy object.
psm.matched_ids

In [None]:
# Keep the matched table under its own name and display it.
df_matched = psm.df_matched
df_matched

In [None]:
# Patient ids of the matched table as a plain Python list.
list_psm = df_matched['patient_id'].tolist()

## Include survival prediction model

In [None]:
# Label by survival probability
df_prob = pd.read_csv(r'rbh_cb_prob_deceased.csv', index_col=0).set_index('patient_id')
# Min-max rescale prob_deceased to the [0, 1] range
df_prob['prob_norm'] = (
    (df_prob['prob_deceased'] - df_prob['prob_deceased'].min())
    / (df_prob['prob_deceased'].max() - df_prob['prob_deceased'].min())
)
# Keep only the patients present in both tables
df_prob_plot = df_top.merge(df_prob, how='inner', left_index=True, right_index=True)
plt.figure(figsize=(16,8))
# NOTE(review): `labels` is not defined in this cell — confirm where it comes from.
sns.scatterplot(
    data=df_prob_plot,
    x='umap_1',
    y='umap_2',
    hue='prob_deceased',
    palette='crest',
    style=labels,
    s=100,
)