In [None]:
import pandas as pd
import numpy as np
from plotly.subplots import make_subplots
from scipy import stats
import plotly.express as px
from scripts.python.routines.plot.scatter import add_scatter_trace
import plotly.graph_objects as go
from scripts.python.routines.plot.save import save_figure
from scripts.python.routines.plot.layout import add_layout, get_axis
from statsmodels.stats.multitest import multipletests
import plotly.io as pio
pio.kaleido.scope.mathjax = None
from plotly.offline import init_notebook_mode, iplot
init_notebook_mode(connected=False)
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
import seaborn as sns
from glob import glob
import pathlib
from sklearn.metrics import mean_absolute_error
from scipy import stats
import patchworklib as pw
import os
import functools
from statannotations.Annotator import Annotator
from scipy.stats import mannwhitneyu
import shap
from slugify import slugify


def conjunction(conditions):
    """Element-wise logical AND of a sequence of boolean masks.

    Args:
        conditions: iterable of boolean arrays / Series of compatible shape.

    Returns:
        The combined mask ``c0 & c1 & ... & cn``; with a single element, that
        element itself is returned unchanged.

    Raises:
        TypeError: if ``conditions`` is empty (propagated from ``functools.reduce``).
    """
    def _and_pair(acc, mask):
        return np.logical_and(acc, mask)

    return functools.reduce(_and_pair, conditions)


def disjunction(conditions):
    """Element-wise logical OR of a sequence of boolean masks.

    Args:
        conditions: iterable of boolean arrays / Series of compatible shape.

    Returns:
        The combined mask ``c0 | c1 | ... | cn``; with a single element, that
        element itself is returned unchanged.

    Raises:
        TypeError: if ``conditions`` is empty (propagated from ``functools.reduce``).
    """
    def _or_pair(acc, mask):
        return np.logical_or(acc, mask)

    return functools.reduce(_or_pair, conditions)

# 0. Setup

In [None]:
# Root directory for every input and output of this notebook.
# Set the SIMAGE_PATH environment variable to run on another machine; the default
# is the original author's local location.
path = os.environ.get(
    "SIMAGE_PATH",
    "D:/YandexDisk/Work/pydnameth/datasets/GPL21145/GSEUNN/special/044_small_immuno_clocks_revision",
)
pathlib.Path(path).mkdir(parents=True, exist_ok=True)

# 1. Prepare additional test data

In [None]:
# Append every sample of the full table that is not already in the origin table,
# labelled as test ("tst"), and save the combined table.
# NOTE(review): section 2 reads the same two inputs from {path}/origin/ — confirm
# which location is current before re-running this cell.
origin_file = f"{path}/data_origin/260_imp(fast_knn)_replace(quarter).xlsx"
all_file = f"{path}/df_samples(all_1052_121222)_proc(raw)_imp(fast_knn)_replace(quarter).xlsx"
df_origin = pd.read_excel(origin_file, index_col=0)
df_all = pd.read_excel(all_file, index_col=0)

# parts_danet = "<Region>_<Status>"
df_all = df_all.assign(
    parts_danet=df_all["Region"].str.cat(df_all[["Status"]].astype(str), sep="_"),
    Split="tst",
)
indexes_not_origin = df_all.index.difference(df_origin.index)

df_new = pd.concat([df_origin, df_all.loc[indexes_not_origin, :]])
df_new.to_excel(f"{path}/all_for_test.xlsx", index_label="index")

# 2. Create new dataset

In [None]:
# Build the modelling table: origin samples plus three labelled test parts
# (central controls, Yakutia controls, ESRD), with dead samples flagged.
df_origin = pd.read_excel(f"{path}/origin/260_imp(fast_knn)_replace(quarter).xlsx", index_col=0)
df_all = pd.read_excel(f"{path}/origin/df_samples(all_1052_121222)_proc(raw)_imp(fast_knn)_replace(quarter).xlsx", index_col=0)

# Central test controls: ids from the selection file with the "_copy" suffix removed.
# NOTE: Series.str.rstrip('_copy') strips any trailing run of the characters
# '_', 'c', 'o', 'p', 'y' (e.g. "Bo_copy" -> "B", "Sampy" -> "Sam"), so remove the
# literal suffix with an anchored regex instead.
df_tst_central_include = pd.read_excel(f"{path}/origin/samples_test_slctd.xlsx", index_col=0)
df_tst_central_include["index"] = df_tst_central_include.index.values
df_tst_central_include["index"] = df_tst_central_include["index"].str.replace(r"_copy$", "", regex=True)
indexes_test_ctrl_central = df_tst_central_include["index"].values
print(len(indexes_test_ctrl_central))
df_test_ctrl_central = df_all.loc[indexes_test_ctrl_central, :].copy()
df_test_ctrl_central["parts_danet"] = 'tst_ctrl_central'
df_test_ctrl_central["Split"] = 'tst_ctrl_central'

# Yakutia test controls: samples labelled "Yakutia_Control" in an earlier DANet inference run.
df_res = pd.read_excel(f"{path}/origin/models/danet_inference/runs/2023-04-12_12-16-05/df.xlsx", index_col=0)
indexes_test_ctrl_yakutia = df_res.index[(df_res["parts_danet"] == "Yakutia_Control")].values
print(len(indexes_test_ctrl_yakutia))
df_test_ctrl_yakutia = df_all.loc[indexes_test_ctrl_yakutia, :].copy()
df_test_ctrl_yakutia["parts_danet"] = 'tst_ctrl_yakutia'
df_test_ctrl_yakutia["Split"] = 'tst_ctrl_yakutia'

# ESRD test part: every ESRD sample. Copy like the other parts so the label
# assignments below do not write into a slice of df_all (SettingWithCopyWarning).
indexes_test_esrd = df_all.index[df_all["Status"] == "ESRD"].values
print(len(indexes_test_esrd))
df_test_esrd = df_all.loc[indexes_test_esrd, :].copy()
df_test_esrd["parts_danet"] = 'tst_esrd'
df_test_esrd["Split"] = 'tst_esrd'

df_new = pd.concat([df_origin, df_test_ctrl_central, df_test_ctrl_yakutia, df_test_esrd])

# Flag samples marked "Dead" in the dead/alive table.
df_dead_alive = pd.read_excel(f"{path}/origin/df_samples_dead_or_alive.xlsx", index_col=0)
inds_dead_alive = df_dead_alive.index[df_dead_alive["Dead_Alive"] == "Dead"].values
df_new.loc[inds_dead_alive, "Dead_Alive"] = "Dead"

df_new.to_excel(f"{path}/data_wtf.xlsx", index_label="index")

# 3. Collect ML results

In [None]:
# Summarise every run of one model's multirun: the run's logged metrics, its hydra
# overrides, and how many central test controls pass the running-MAE check.
model = 'widedeep_ft_transformer_trn_val_tst'

part_check = "tst_ctrl_central"
# A sample passes while the mean |error| of all samples up to and including it
# (sorted best-first) stays below this threshold.
part_check_thld_mean = 7.5
df = pd.read_excel(f"{path}/data_final.xlsx", index_col=0)
samples_tst_ctrl_central = df.index[df["Split"] == "tst_ctrl_central"].values

path_runs = f"{path}/models/10_trn_val_tst/best_{model}/multiruns"

files = glob(f"{path_runs}/*/*/metrics_all_best_*.xlsx")
if not files:
    raise FileNotFoundError(f"No metrics_all_best_*.xlsx files found under {path_runs}")
df_res = pd.DataFrame(index=files)

# Parts reported in metrics_all_best_*.xlsx; this order fixes the output column order.
metric_parts = ["val", "trn", "tst_ctrl_central", "trn_val_tst_ctrl_central", "val_tst_ctrl_central"]

for file in files:
    head, tail = os.path.split(file)
    df_pred = pd.read_excel(f"{head}/predictions.xlsx", index_col=0)
    df_pred["Error"] = df_pred["Prediction"] - df_pred["Age"]
    df_pred["AbsError"] = df_pred["Error"].abs()
    # The first column of predictions.xlsx holds the part label of each sample.
    part_col = df_pred.columns[0]
    df_pred_tst_ctrl_central = df_pred.loc[df_pred[part_col] == part_check, :].copy()
    df_pred_tst_ctrl_central.sort_values(["AbsError"], ascending=[True], inplace=True)
    df_pred_tst_ctrl_central["MeanAbsErrorExpanding"] = df_pred_tst_ctrl_central["AbsError"].expanding().mean()
    samples_tst_ctrl_central_passed = df_pred_tst_ctrl_central.index[df_pred_tst_ctrl_central["MeanAbsErrorExpanding"] < part_check_thld_mean].values
    n_samples_passed = len(samples_tst_ctrl_central_passed)

    df_res.at[file, "passed_test_samples"] = n_samples_passed

    # tst_ctrl_central metrics recomputed from the predictions.
    # NOTE: if the logged metrics include "mean_absolute_error" / "pearson_corr_coef",
    # the loop below overwrites these two columns with the logged values.
    y_real = df_pred.loc[samples_tst_ctrl_central, "Age"]
    y_pred = df_pred.loc[samples_tst_ctrl_central, "Prediction"]
    df_res.at[file, 'mean_absolute_error_tst_ctrl_central'] = mean_absolute_error(y_real, y_pred)
    df_res.at[file, 'pearson_corr_coef_tst_ctrl_central'] = stats.pearsonr(y_real, y_pred).statistic

    # Metrics logged by the run itself, one column per (metric, part).
    df_metrics = pd.read_excel(file, index_col="metric")
    for metric in df_metrics.index.values:
        for part in metric_parts:
            df_res.at[file, f"{metric}_{part}"] = df_metrics.at[metric, part]

    # Hydra overrides ("key=value"); split once so values containing '=' survive.
    cfg = OmegaConf.load(f"{head}/.hydra/overrides.yaml")
    for param_pair in cfg:
        param, val = param_pair.split('=', 1)
        df_res.at[file, param] = val

df_res["train_more_val"] = False
df_res["selected"] = False
# Flag runs whose train MAE exceeds the validation MAE.
df_res.loc[df_res["mean_absolute_error_trn"] > df_res["mean_absolute_error_val"], "train_more_val"] = True

first_columns = [
    'selected',
    'passed_test_samples',
    'train_more_val',
    'mean_absolute_error_trn',
    'mean_absolute_error_val',
    'mean_absolute_error_tst_ctrl_central',
    'mean_absolute_error_val_tst_ctrl_central',
    'mean_absolute_error_trn_val_tst_ctrl_central',
    'pearson_corr_coef_trn',
    'pearson_corr_coef_val',
    'pearson_corr_coef_tst_ctrl_central',
    'pearson_corr_coef_val_tst_ctrl_central',
    'pearson_corr_coef_trn_val_tst_ctrl_central',
    'mean_absolute_error_cv_mean_trn',
    'mean_absolute_error_cv_std_trn',
    'mean_absolute_error_cv_mean_val',
    'mean_absolute_error_cv_std_val',
    'pearson_corr_coef_cv_mean_trn',
    'pearson_corr_coef_cv_std_trn',
    'pearson_corr_coef_cv_mean_val',
    'pearson_corr_coef_cv_std_val',
]
df_res = df_res[first_columns + [col for col in df_res.columns if col not in first_columns]]
df_res.to_excel(f"{path_runs}/summary.xlsx", index=True, index_label="file")

# 4. Decider for central controls

In [None]:
# Select the test controls that all main models predict well, then report baseline
# metrics of every model on that common subset.
part_check = "tst_ctrl_all"
# A sample passes while the mean |error| of all samples up to and including it
# (sorted best-first) stays below this threshold.
part_check_thld_mean = 8.43

df = pd.read_excel(f"{path}/data.xlsx", index_col=0)
samples_test = df.index[df["Split"] == "tst_ctrl_all"].values

models_all = [
    "elastic_net",
    "xgboost",
    "lightgbm",
    "catboost",
    "widedeep_tab_mlp",
    "nam",
    "nbm_spam_nam",
    "pytorch_tabular_node",
    "danet",
    "widedeep_tab_net",
    "pytorch_tabular_autoint",
    "widedeep_saint",
    "widedeep_ft_transformer"
]

# A test sample is kept only if it passes for every one of these models.
models_main = [
    "danet",
    "widedeep_saint",
    "widedeep_ft_transformer"
]

path_models = f"{path}/models/46_trn_val_tst"


def _selected_run_dir(path_models, model_name):
    """Return the directory of the single run flagged ``selected`` in a model's summary.xlsx.

    Raises:
        ValueError: if the summary does not flag exactly one run.
    """
    df_summary = pd.read_excel(f"{path_models}/{model_name}_trn_val_tst/multiruns/summary.xlsx", index_col=0)
    files_slctd = df_summary.index[df_summary["selected"] == True].values
    if len(files_slctd) != 1:
        raise ValueError(f"{model_name} model selection error")
    path_head, _ = os.path.split(files_slctd[0])
    # Summary paths may still point at models/46/; remap to models/46_trn_val_tst/.
    # Matching the "models/" prefix (not a bare "/46/") keeps a run sub-directory
    # that happens to be named "46" from being rewritten.
    return path_head.replace('models/46/', 'models/46_trn_val_tst/', 1)


df_res = pd.DataFrame(index=models_all)
df_samples_test = pd.DataFrame(index=samples_test, columns=models_all)

for m in models_all:
    path_head = _selected_run_dir(path_models, m)

    file_val = glob(f"{path_head}/metrics_val_best_*.xlsx")[0]
    df_res_val = pd.read_excel(file_val, index_col=0)

    df_res.at[m, 'val_mae_best'] = df_res_val.at['mean_absolute_error', 'val']
    df_res.at[m, 'val_mae_mean'] = df_res_val.at['mean_absolute_error_cv_mean', 'val']
    df_res.at[m, 'val_mae_std'] = df_res_val.at['mean_absolute_error_cv_std', 'val']
    df_res.at[m, 'val_rho_best'] = df_res_val.at['pearson_corr_coef', 'val']
    df_res.at[m, 'val_rho_mean'] = df_res_val.at['pearson_corr_coef_cv_mean', 'val']
    df_res.at[m, 'val_rho_std'] = df_res_val.at['pearson_corr_coef_cv_std', 'val']

    df_pred = pd.read_excel(f"{path_head}/predictions.xlsx", index_col=0)
    df_pred["Error"] = df_pred["Prediction"] - df_pred["Age"]
    df_pred["AbsError"] = df_pred["Error"].abs()
    part_col = df_pred.columns[0]
    # Central controls count as part of tst_ctrl_all. Assign the result back: a chained
    # Series.replace(..., inplace=True) does not update df_pred under Copy-on-Write.
    df_pred[part_col] = df_pred[part_col].replace({"tst_ctrl_central": part_check})
    # Copy so the sort and the new column below act on an independent frame.
    df_pred = df_pred.loc[df_pred[part_col] == part_check, :].copy()
    df_pred.sort_values(["AbsError"], ascending=[True], inplace=True)
    df_pred["MeanAbsErrorExpanding"] = df_pred["AbsError"].expanding().mean()
    samples_passed = df_pred.index[df_pred["MeanAbsErrorExpanding"] < part_check_thld_mean].values
    df_samples_test.loc[:, m] = 0
    df_samples_test.loc[samples_passed, m] = 1
    n_samples_passed = len(samples_passed)
    print(f"{m}: {n_samples_passed}")

# Rows are samples and columns are models (1 = passed), so the index label is "index".
df_samples_test.to_excel(f"{path_models}/samples_test_full.xlsx", index_label="index")

conditions = [df_samples_test[m] == 1 for m in models_main]
df_samples_test = df_samples_test[conjunction(conditions)]
samples_test_final = df_samples_test.index.values
print(len(samples_test_final))

# Test metrics of every model on the common subset.
for m in models_all:
    path_head = _selected_run_dir(path_models, m)

    df_pred = pd.read_excel(f"{path_head}/predictions.xlsx", index_col=0)
    df_pred = df_pred.loc[samples_test_final, :]
    y_real = df_pred["Age"]
    y_pred = df_pred["Prediction"]
    df_res.at[m, 'tst_mae'] = mean_absolute_error(y_real, y_pred)
    df_res.at[m, 'tst_rho'] = stats.pearsonr(y_real, y_pred).statistic

df_res.to_excel(f"{path_models}/baseline_results.xlsx", index_label="model")
df_samples_test.to_excel(f"{path_models}/samples_test_slctd.xlsx", index_label="index")

# 5. Updating data with trn/val splits from best models

In [None]:
# For every model add a "best_<model>" column. It starts as a copy of "Split", and
# the samples that the model's selected run used for training / validation are
# overwritten with "trn" / "val".
df = pd.read_excel(f"{path}/data_final.xlsx", index_col=0)

models_all = [
    "elastic_net",
    "xgboost",
    "lightgbm",
    "catboost",
    "widedeep_tab_mlp",
    "nam",
    "nbm_spam_nam",
    "pytorch_tabular_node",
    "danet",
    "widedeep_tab_net",
    "pytorch_tabular_autoint",
    "widedeep_saint",
    "widedeep_ft_transformer"
]

path_models = f"{path}/models/46_trn_val_tst"
for m in models_all:
    best_col = f"best_{m}"
    df[best_col] = df["Split"]

    summary = pd.read_excel(f"{path_models}/{m}_trn_val_tst/multiruns/summary.xlsx", index_col=0)
    selected_files = summary.index[summary["selected"] == True].values
    if len(selected_files) != 1:
        raise ValueError(f"{m} model selection error")
    # Summary paths may still point at models/46/; remap to the current folder.
    run_dir = os.path.split(selected_files[0])[0].replace('models/46/', 'models/46_trn_val_tst/', 1)

    predictions = pd.read_excel(f"{run_dir}/predictions.xlsx", index_col=0)
    # The first column of predictions.xlsx holds each sample's part label.
    part_labels = predictions[predictions.columns[0]]
    for part in ("trn", "val"):
        df.loc[predictions.index[part_labels == part].values, best_col] = part

df.to_excel(f"{path}/data_final1.xlsx", index_label="index")

# 6. Inference checking and Yakutia thresholding

In [None]:
# Collect inference metrics of every model and evaluate all models on the Yakutia
# controls that pass the running-MAE check for every main model.
part_check = "tst_ctrl_yakutia"
# Very large threshold: effectively every Yakutia control passes the check below.
part_check_thld_mean = 99999

df = pd.read_excel(f"{path}/data_final.xlsx", index_col=0)
samples_test = df.index[df["Split"] == part_check].values

models_all = [
    "elastic_net",
    "xgboost",
    "lightgbm",
    "catboost",
    "widedeep_tab_mlp",
    "nam",
    "nbm_spam_nam",
    "pytorch_tabular_node",
    "danet",
    "widedeep_tab_net",
    "pytorch_tabular_autoint",
    "widedeep_saint",
    "widedeep_ft_transformer"
]

# A test sample is kept only if it passes for every one of these models.
models_main = [
    "danet",
    "widedeep_saint",
    "widedeep_ft_transformer"
]

path_models = f"{path}/models/46_inference"

df_res = pd.DataFrame(index=models_all)
df_samples_test = pd.DataFrame(index=samples_test, columns=models_all)

for m in models_all:
    file_val = glob(f"{path_models}/{m}_inference/runs/*/metrics_val.xlsx")[0]
    df_res_val = pd.read_excel(file_val, index_col=0)
    df_res.at[m, 'val_mae'] = df_res_val.at['mean_absolute_error', 'val']
    df_res.at[m, 'val_rho'] = df_res_val.at['pearson_corr_coef', 'val']

    file_tst_central = glob(f"{path_models}/{m}_inference/runs/*/metrics_tst_ctrl_central.xlsx")[0]
    df_res_tst_central = pd.read_excel(file_tst_central, index_col=0)
    df_res.at[m, 'tst_central_mae'] = df_res_tst_central.at['mean_absolute_error', 'tst_ctrl_central']
    df_res.at[m, 'tst_central_rho'] = df_res_tst_central.at['pearson_corr_coef', 'tst_ctrl_central']

    file_tst_esrd = glob(f"{path_models}/{m}_inference/runs/*/metrics_tst_esrd.xlsx")[0]
    df_res_tst_esrd = pd.read_excel(file_tst_esrd, index_col=0)
    df_res.at[m, 'tst_esrd_mae'] = df_res_tst_esrd.at['mean_absolute_error', 'tst_esrd']
    df_res.at[m, 'tst_esrd_rho'] = df_res_tst_esrd.at['pearson_corr_coef', 'tst_esrd']

    file_pred = glob(f"{path_models}/{m}_inference/runs/*/df.xlsx")[0]
    df_pred = pd.read_excel(file_pred, index_col=0)
    # Copy so the in-place sort and the new column act on an independent frame.
    df_pred = df_pred.loc[df_pred[f"best_{m}"] == part_check, :].copy()
    df_pred.sort_values(["Prediction error abs"], ascending=[True], inplace=True)
    df_pred["MeanAbsErrorExpanding"] = df_pred["Prediction error abs"].expanding().mean()
    samples_passed = df_pred.index[df_pred["MeanAbsErrorExpanding"] < part_check_thld_mean].values
    df_samples_test.loc[:, m] = 0
    df_samples_test.loc[samples_passed, m] = 1
    n_samples_passed = len(samples_passed)
    print(f"{m}: {n_samples_passed}")

df_samples_test.to_excel(f"{path_models}/samples_test_full.xlsx", index_label="model")

conditions = [df_samples_test[m] == 1 for m in models_main]
df_samples_test = df_samples_test[conjunction(conditions)]
samples_test_final = df_samples_test.index.values
print(len(samples_test_final))

for m in models_all:
    file_pred = glob(f"{path_models}/{m}_inference/runs/*/df.xlsx")[0]
    df_pred = pd.read_excel(file_pred, index_col=0)
    df_pred = df_pred.loc[samples_test_final, :]
    y_real = df_pred["Age"]
    y_pred = df_pred["Prediction"]
    df_res.at[m, 'tst_yakutia_mae'] = mean_absolute_error(y_real, y_pred)
    df_res.at[m, 'tst_yakutia_rho'] = stats.pearsonr(y_real, y_pred).statistic

df_res.to_excel(f"{path_models}/baseline_results_{part_check_thld_mean}.xlsx", index_label="model")
df_samples_test.to_excel(f"{path_models}/samples_test_slctd_{part_check_thld_mean}.xlsx", index_label="model")

# 7. Feature selection and dimensionality reduction via SHAP

In [None]:
df = pd.read_excel(f"{path}/data_final.xlsx", index_col=0)
df_trn_val = df.loc[df["Split"] == "trn_val", :]
feats = pd.read_excel(f"{path}/feats_con.xlsx", index_col=0).index.values
path_models = f"{path}/models/46_shap"

# Model directory name -> label used in the figure legend.
models = {
    "danet": "DANet",
    "widedeep_tab_net": "TabNet",
    "widedeep_saint": "SAINT",
    "widedeep_ft_transformer": "FT-Transformer"
}

# Importance table: one column per model, each cell is mean(|SHAP|) of a feature
# over the trn_val SHAP table; "Summary" is the row sum across models.
df_fi = pd.DataFrame(index=feats)
for model_dir, model_label in models.items():
    shap_file = glob(f"{path_models}/{model_dir}_inference/runs/*/shap/trn_val/shap.xlsx")[0]
    shap_table = pd.read_excel(shap_file, index_col=0)
    for feat in feats:
        df_fi.at[feat, model_label] = shap_table[feat].abs().mean()

df_fi['Summary'] = df_fi.sum(axis=1)
df_fi = df_fi.sort_values(['Summary'], ascending=[False])
df_fi.to_excel(f"{path_models}/A.xlsx", index_label="features")

feats_top10 = df_fi.index.values[:10]
df_fig = df_fi.loc[:, list(models.values())]

# Panel A on its own: stacked bar chart of per-model importances.
sns.set_theme(style='ticks', font_scale=3)
df_fig.plot(kind='bar', stacked=True, color=px.colors.qualitative.D3, figsize=(34, 10), edgecolor='black')
plt.xticks(rotation=90)
plt.xlabel('Features')
plt.ylabel('Mean(|SHAP values|)')
sns.despine(left=False, right=True, bottom=False, top=True)
plt.savefig(f"{path_models}/fi.png", bbox_inches='tight', dpi=400)
plt.savefig(f"{path_models}/fi.pdf", bbox_inches='tight')

# Composite figure: panel A (bar chart) above panel B (pair grid of Age + top-10 features).
pw.overwrite_axisgrid()
sns.set_theme(style='ticks', font_scale=3)
ax1 = pw.Brick(figsize=(30, 10))
df_fig.plot(kind='bar', stacked=True, color=px.colors.qualitative.D3, ax=ax1, edgecolor='black')
sns.despine(left=False, right=True, bottom=False, top=True, ax=ax1)
ax1.set_xlabel('')
ax1.set_ylabel('Mean(|SHAP values|)')
# Panel letters, placed in ax1 data coordinates.
ax1.text(-4, 45,'A', fontsize=72)
ax1.text(-4, -15,'B', fontsize=72)

feats_plot = ["Age"] + list(feats_top10)
sns.set_theme(style="ticks", font_scale=2.0)
g1 = sns.PairGrid(df_trn_val, vars=feats_plot)
g1.map_diag(plt.hist, bins=15, color='darkred', edgecolor='k')
g1.map_upper(sns.scatterplot, color='darkred', s=10, alpha=0.5, edgecolor='k', linewidth=0.2)
g1.map_lower(sns.kdeplot, fill=True, cbar=False, cmap='rocket_r', thresh=-0.1)
g1.add_legend()
for grid_ax in g1.axes.flatten():
    grid_ax.get_yaxis().set_label_coords(-0.5, 0.5)
ax2 = pw.load_seaborngrid(g1, figsize=(30, 30))
(ax1/ax2).savefig(f"{path_models}/together.pdf", bbox_inches='tight')
(ax1/ax2).savefig(f"{path_models}/together.png", bbox_inches='tight', dpi=400)

df_trn_val.loc[:, feats_plot].to_excel(f"{path_models}/B.xlsx", index_label='index')

# 8. trn/val split

## 8.1 Check that trn/val folds are identical across models

In [None]:
# Check that the selected runs of all models used the same CV folds.
pathlib.Path(f"{path}/figure_splits").mkdir(parents=True, exist_ok=True)

df = pd.read_excel(f"{path}/data_final.xlsx", index_col=0)
samples_trn_val = df.index[df["Split"] == "trn_val"].values

models_all = [
    "elastic_net",
    "xgboost",
    "lightgbm",
    "catboost",
    "widedeep_tab_mlp",
    "nam",
    "nbm_spam_nam",
    "pytorch_tabular_node",
    "danet",
    "widedeep_tab_net",
    "pytorch_tabular_autoint",
    "widedeep_saint",
    "widedeep_ft_transformer"
]

path_models = f"{path}/models/46_trn_val_tst"
n_folds = 5
fold_names = [f"fold_{fold_idx:04d}" for fold_idx in range(n_folds)]

# One frame per CV fold. Rows are trn_val samples and columns are models; each value
# is the part label from that model's cv_ids.xlsx (e.g. "trn" / "val").
df_folds = {fold: pd.DataFrame(index=samples_trn_val) for fold in fold_names}

for m in models_all:
    df_summary = pd.read_excel(f"{path_models}/{m}_trn_val_tst/multiruns/summary.xlsx", index_col=0)
    files_slctd = df_summary.index[df_summary["selected"] == True].values
    if len(files_slctd) != 1:
        raise ValueError(f"{m} model selection error")
    path_head, _ = os.path.split(files_slctd[0])
    # Summary paths may still point at models/46/; remap to the current folder.
    path_head = path_head.replace('models/46/', 'models/46_trn_val_tst/', 1)

    df_cv_ids = pd.read_excel(f"{path_head}/cv_ids.xlsx", index_col=0)
    for fold in fold_names:
        df_folds[fold].loc[samples_trn_val, m] = df_cv_ids.loc[samples_trn_val, fold]

# A sample "matches" when every model assigned it the same label as the first model.
for fold in fold_names:
    df_fold = df_folds[fold]
    df_fold['matching'] = df_fold.eq(df_fold.iloc[:, 0], axis=0).all(1)
    df_fold.to_excel(f"{path}/figure_splits/{fold}.xlsx", index_label="index")
    n_matches = int(df_fold['matching'].sum())
    print(n_matches)

## 8.2 Plot split histograms

In [None]:
# Stacked age histogram of the trn_val samples, coloured by the CV fold in which
# each sample is in the validation part, using the folds of `model`'s selected run.
pathlib.Path(f"{path}/figure_splits").mkdir(parents=True, exist_ok=True)

df = pd.read_excel(f"{path}/data_final.xlsx", index_col=0)
samples_trn_val = df.index[df["Split"] == "trn_val"].values
# 0 = never in a validation part; k = validation part of fold k-1 (1-based for the legend).
df["split_id"] = 0

# Model whose selected run provides the folds. Previously this was declared but
# ignored: the cell looped over `models_all` (a leftover from an earlier cell), and
# the last model silently won.
model = "danet"
n_folds = 5

path_models = f"{path}/models/46_trn_val_tst"

df_summary = pd.read_excel(f"{path_models}/{model}_trn_val_tst/multiruns/summary.xlsx", index_col=0)
files_slctd = df_summary.index[df_summary["selected"] == True].values
if len(files_slctd) != 1:
    raise ValueError(f"{model} model selection error")
path_head, _ = os.path.split(files_slctd[0])
# Summary paths may still point at models/46/; remap to the current folder.
path_head = path_head.replace('models/46/', 'models/46_trn_val_tst/', 1)

df_cv_ids = pd.read_excel(f"{path_head}/cv_ids.xlsx", index_col=0)
for fold_idx in range(n_folds):
    val_ids = df_cv_ids.index[df_cv_ids[f"fold_{fold_idx:04d}"] == "val"].values
    df.loc[val_ids, "split_id"] = fold_idx + 1

df_fig = df.loc[samples_trn_val, ["Age", "split_id"]].copy()
df_fig.rename(columns={'split_id': 'Split'}, inplace=True)
df_fig.to_excel(f"{path}/figure_splits/df_fig.xlsx", index_label="index")

hist_bins = np.linspace(0, 110, 12)

# Split k -> Plotly qualitative colour k+4.
palette = {x: px.colors.qualitative.Plotly[x+4] for x in range(1, n_folds + 1)}

fig = plt.figure()
sns.set_theme(style='ticks', font_scale=1.3)
hist = sns.histplot(
    data=df_fig,
    hue_order=list(range(1, n_folds + 1))[::-1],
    bins=hist_bins,
    x="Age",
    hue="Split",
    edgecolor='black',
    palette=palette,
    multiple="stack",
)
hist.set(xlim=(0, 110))
sns.despine(left=False, right=True, bottom=False, top=True)
plt.savefig(f"{path}/figure_splits/hist.png", bbox_inches='tight', dpi=400)
plt.savefig(f"{path}/figure_splits/hist.pdf", bbox_inches='tight')
plt.close(fig)

# 9. SImAge plots

In [None]:
pathlib.Path(f"{path}/figure_simage").mkdir(parents=True, exist_ok=True)
# The dead/alive table below is written into this sub-folder, which must exist first.
pathlib.Path(f"{path}/figure_simage/dead_alive").mkdir(parents=True, exist_ok=True)

palette = {
    "Train": 'chartreuse',
    "Validation": 'cornflowerblue',
    "Test Controls": "orange",
    "Test ESRD": "crimson"
}

# Part label in predictions.xlsx -> display name.
part_names = {
    "trn": 'Train',
    "val": 'Validation',
    "tst_ctrl_central": "Test Controls",
    "tst_esrd": "Test ESRD"
}

# Everyone is "Alive" except ESRD samples recorded as dead in the dead/alive table.
df = pd.read_excel(f"{path}/data_final.xlsx", index_col=0)
df_dead_alive_raw = pd.read_excel(f"{path}/origin/df_samples_dead_or_alive.xlsx", index_col=0)
df.loc[:, "Dead_Alive"] = "Alive"
ids_dead = df_dead_alive_raw.index[(df_dead_alive_raw["Status"] == "ESRD") & (df_dead_alive_raw["Dead_Alive"] == "Dead")].values
df.loc[ids_dead, "Dead_Alive"] = "Dead"

feats = pd.read_excel(f"{path}/feats_con.xlsx", index_col=0).index.values
df = df.loc[:, list(feats) + ["Age", "Sex", "Status", "Dead_Alive"]].copy()

df_pred = pd.read_excel(f"{path}/models/10_trn_val_tst/widedeep_ft_transformer_trn_val_tst/multiruns/2023-05-07_19-40-40_1337/64/predictions.xlsx", index_col=0)
df_pred["Age acceleration"] = df_pred["Prediction"] - df_pred["Age"]
# The first column of predictions.xlsx holds each sample's part label.
part_col = df_pred.columns[0]

df.loc[df.index.values, "Part"] = df_pred.loc[df.index.values, part_col]
df.loc[df.index.values, "Prediction"] = df_pred.loc[df.index.values, "Prediction"]
df.loc[df.index.values, "Age acceleration"] = df_pred.loc[df.index.values, "Age acceleration"]
# Keep all train/validation/central-control samples, and only dead ESRD samples whose
# acceleration lies in (-30, 100).
indexes_filtered = df.index[(df["Part"].isin(["trn", "val", "tst_ctrl_central"])) | ((df["Part"] == "tst_esrd") & (df["Dead_Alive"] == "Dead") & (df["Age acceleration"] > -30)  & (df["Age acceleration"] < 100))].values

# ESRD samples (dead and alive) with acceleration in (-30, 100), for the dead/alive plots.
df_dead_alive = df.loc[df["Part"] == "tst_esrd", :].copy()
df_dead_alive = df_dead_alive.loc[(df_dead_alive["Age acceleration"] < 100) & (df_dead_alive["Age acceleration"] > -30), :]
df_dead_alive.to_excel(f"{path}/figure_simage/dead_alive/df.xlsx", index_label="index")

# Copy and assign back: a chained Series.replace(..., inplace=True) on a .loc slice
# raises SettingWithCopyWarning and is a no-op under Copy-on-Write.
df = df.loc[indexes_filtered, :].copy()
df["Part"] = df["Part"].replace(part_names)
df.rename(columns={"Part": "Dataset"}, inplace=True)

# Scatter of SImAge vs. chronological age with the identity line; the lower
# bound gets an extra 7 units of padding.
plt.figure()
sns.set_theme(style='ticks')
xy_min = df[["Age", 'Prediction']].min().min() - 7
xy_max = df[["Age", 'Prediction']].max().max()
xy_ptp = xy_max - xy_min
plt.gca().plot(
    [xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp],
    [xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp],
    color='k',
    linestyle='dotted',
    linewidth=1
)
scatter = sns.scatterplot(
    data=df,
    x="Age",
    y="Prediction",
    hue="Dataset",
    palette=palette,
    linewidth=0.2,
    alpha=0.75,
    edgecolor="k",
    s=16,
    hue_order=list(palette.keys())
)
scatter.set_xlabel("Age")
scatter.set_ylabel("SImAge")
sns.despine(left=False, right=True, bottom=False, top=True)
scatter.set_xlim(xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp)
scatter.set_ylim(xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp)
plt.gca().set_aspect('equal', adjustable='box')
plt.savefig(f"{path}/figure_simage/scatter.png", bbox_inches='tight', dpi=400)
plt.savefig(f"{path}/figure_simage/scatter.pdf", bbox_inches='tight')
plt.close()
df.to_excel(f"{path}/figure_simage/df.xlsx", index_label="index")

# Violin of SImAge acceleration per dataset. Test controls vs. test ESRD are
# compared with a two-sided Mann-Whitney U test, shown as a custom annotation.
dataset_order = list(palette.keys())
acc_controls = df.loc[df['Dataset'] == "Test Controls", 'Age acceleration'].values
acc_esrd = df.loc[df['Dataset'] == "Test ESRD", 'Age acceleration'].values

plt.figure()
sns.set_theme(style='ticks')
violin = sns.violinplot(
    data=df,
    x="Dataset",
    y='Age acceleration',
    palette=palette,
    scale='width',
    order=dataset_order,
    saturation=0.75,
)
violin.set_xlabel("")
violin.set_ylabel("SImAge acceleration")
violin.axhline(0.00, color='k', linestyle=':', linewidth=0.5)
sns.despine(left=False, right=True, bottom=False, top=True)

pval = mannwhitneyu(acc_controls, acc_esrd, alternative="two-sided").pvalue
pval_formatted = [f'p-value: {pval:.2e}']
annotator = Annotator(
    violin,
    pairs=[("Test Controls", "Test ESRD")],
    data=df,
    x='Dataset',
    y='Age acceleration',
    order=dataset_order
)
annotator.set_custom_annotations(pval_formatted)
annotator.configure(loc='outside')
annotator.annotate()
plt.savefig(f"{path}/figure_simage/violin.png", bbox_inches='tight', dpi=400)
plt.savefig(f"{path}/figure_simage/violin.pdf", bbox_inches='tight')
plt.close()

# Violin of SImAge acceleration for alive vs. dead ESRD samples (all ages), with a
# two-sided Mann-Whitney U p-value annotation.
palette_dead_alive = {"Alive": "olive", "Dead": "crimson"}
status_order = ("Alive", "Dead")
acc_by_status = {
    status: df_dead_alive.loc[df_dead_alive['Dead_Alive'] == status, 'Age acceleration'].values
    for status in status_order
}

plt.figure()
sns.set_theme(style='ticks')
violin = sns.violinplot(
    data=df_dead_alive,
    x="Dead_Alive",
    y='Age acceleration',
    scale='width',
    saturation=0.75,
    palette=palette_dead_alive,
    order=status_order,
)
violin.axhline(0.00, color='k', linestyle=':', linewidth=1)
violin.set_xlabel("")
violin.set_ylabel("SImAge acceleration")
sns.despine(left=False, right=True, bottom=False, top=True)

pval = mannwhitneyu(acc_by_status["Alive"], acc_by_status["Dead"], alternative="two-sided").pvalue
pval_formatted = [f'p-value: {pval:.2e}']
annotator = Annotator(
    violin,
    pairs=[("Alive", "Dead")],
    data=df_dead_alive,
    x='Dead_Alive',
    y='Age acceleration',
    order=status_order
)
annotator.set_custom_annotations(pval_formatted)
annotator.configure(loc='outside')
annotator.annotate()
plt.savefig(f"{path}/figure_simage/dead_alive/violin_global.png", bbox_inches='tight', dpi=400)
plt.savefig(f"{path}/figure_simage/dead_alive/violin_global.pdf", bbox_inches='tight')
plt.close()

# Stacked age histogram of the ESRD samples, split by alive / dead.
hist_bins = np.linspace(0, 110, 12)
plt.figure()
sns.set_theme(style='ticks')
hist = sns.histplot(
    data=df_dead_alive,
    x="Age",
    bins=hist_bins,
    edgecolor='k',
    linewidth=1,
    hue_order=list(palette_dead_alive.keys()),
    hue="Dead_Alive",
    palette=palette_dead_alive,
    multiple="stack",
)
sns.despine(left=False, right=True, bottom=False, top=True)
plt.savefig(f"{path}/figure_simage/dead_alive/hist.png", bbox_inches='tight', dpi=400)
plt.savefig(f"{path}/figure_simage/dead_alive/hist.pdf", bbox_inches='tight')
# Close rather than clear, like the other plots here, so no empty figure stays open.
plt.close()

# Same alive vs. dead comparison, restricted to ESRD samples older than 60.
age_old_min = 60
df_dead_alive_old = df_dead_alive.loc[df_dead_alive["Age"] > age_old_min, :]
status_order = ("Alive", "Dead")

plt.figure()
df_dead_alive_old.to_excel(f"{path}/figure_simage/dead_alive/df_old.xlsx", index_label="index")
sns.set_theme(style='ticks')
violin = sns.violinplot(
    data=df_dead_alive_old,
    x="Dead_Alive",
    y='Age acceleration',
    scale='width',
    saturation=0.75,
    palette=palette_dead_alive,
    order=status_order,
)
violin.set_xlabel("")
violin.axhline(0.00, color='k', linestyle=':', linewidth=1)
violin.set_ylabel("SImAge acceleration")
sns.despine(left=False, right=True, bottom=False, top=True)

acc_alive_old = df_dead_alive_old.loc[df_dead_alive_old['Dead_Alive'] == "Alive", 'Age acceleration'].values
acc_dead_old = df_dead_alive_old.loc[df_dead_alive_old['Dead_Alive'] == "Dead", 'Age acceleration'].values
pval = mannwhitneyu(acc_alive_old, acc_dead_old, alternative="two-sided").pvalue
pval_formatted = [f'p-value: {pval:.2e}']
annotator = Annotator(
    violin,
    pairs=[("Alive", "Dead")],
    data=df_dead_alive_old,
    x='Dead_Alive',
    y='Age acceleration',
    order=status_order
)
annotator.set_custom_annotations(pval_formatted)
annotator.configure(loc='outside')
annotator.annotate()
plt.savefig(f"{path}/figure_simage/dead_alive/violin_old.png", bbox_inches='tight', dpi=400)
plt.savefig(f"{path}/figure_simage/dead_alive/violin_old.pdf", bbox_inches='tight')
plt.close()

# 10. SImAge SHAP plots

In [None]:
pathlib.Path(f"{path}/figure_shap").mkdir(parents=True, exist_ok=True)

# Presumably sample ids mapped to panel letters for per-sample SHAP panels — confirm.
ids_local = {
    "A3": "A",
    "L40": "B",
    "166": "C",
    "F2b-F1b-L1/L2": "D",
    "H24": "E",
    "H69": "F",
}

df = pd.read_excel(f"{path}/figure_simage/df.xlsx", index_col=0)
df["AbsError"] = df["Age acceleration"].abs()
# Acceleration threshold used to split controls into three groups.
# NOTE(review): presumably the test MAE of the selected model — confirm the source of 6.94.
mae = 6.94
ids_esrd = df.index[df["Status"] == "ESRD"].values
print(len(ids_esrd))
ids_lft = df.index[(df["Status"] == "Control") & (df["Age acceleration"] < -mae)].values
print(len(ids_lft))
ids_ctr = df.index[(df["Status"] == "Control") & (df["Age acceleration"] < mae) & (df["Age acceleration"] > -mae)].values
print(len(ids_ctr))
ids_rgt = df.index[(df["Status"] == "Control") & (df["Age acceleration"] > mae)].values
print(len(ids_rgt))

# Group label per sample. Controls with |acc| exactly equal to mae keep "".
df["Part"] = ""
df.loc[ids_lft, "Part"] = "Controls:  Acc < -MAE"
df.loc[ids_ctr, "Part"] = "Controls: |Acc| < MAE"
df.loc[ids_rgt, "Part"] = "Controls:  Acc > MAE"
df.loc[ids_esrd, "Part"] = "ESRD"

parts = {
    "ESRD": ids_esrd,
    "Controls:  Acc < -MAE": ids_lft,
    "Controls: |Acc| < MAE": ids_ctr,
    "Controls:  Acc > MAE": ids_rgt,
}

palette = {
    "ESRD": "crimson",
    "Controls:  Acc < -MAE": 'cyan',
    "Controls: |Acc| < MAE": 'lime',
    "Controls:  Acc > MAE": "gold",
}
# Reversed palette order. (A stray trailing comma previously made this a
# one-element tuple wrapping the list.)
hue_order = list(palette.keys())[::-1]

# Bin edges are integer multiples of binwidth. With n_bins_ctr = 5, ±mae fall on
# bin edges and [-mae, mae] spans exactly 4 bins.
n_bins_ctr = 5
binwidth = 2 * mae / (n_bins_ctr - 1)
bins_to_lft = abs(df["Age acceleration"].min()) // binwidth + 1
hist_min = -binwidth * bins_to_lft
bins_to_rgt = abs(df["Age acceleration"].max()) // binwidth + 1
hist_max = binwidth * bins_to_rgt
# Note: seaborn's binwidth overrides bins, so hist_n_bins is informational only.
hist_n_bins = bins_to_lft + bins_to_rgt

plt.figure(figsize=(10, 4))
sns.set_theme(style='ticks')
hist = sns.histplot(
    data=df,
    x="Age acceleration",
    bins=hist_n_bins,
    binrange=(hist_min, hist_max),
    binwidth=binwidth,
    edgecolor='k',
    linewidth=1,
    hue_order=list(palette.keys()),
    hue="Part",
    palette=palette,
    multiple="stack",
)
sns.despine(left=False, right=True, bottom=False, top=True)
plt.savefig(f"{path}/figure_shap/hist.png", bbox_inches='tight', dpi=400)
plt.savefig(f"{path}/figure_shap/hist.pdf", bbox_inches='tight')
# Close rather than clear so the emptied figure does not stay open.
plt.close()
df_fig = df.loc[:, ["Age", "Prediction", "Age acceleration"]].copy()
df_fig.to_excel(f"{path}/figure_shap/hist.xlsx", index_label="index")

# SHAP values and base value of the selected inference run
path_shap = f"{path}/models/10_shap/widedeep_ft_transformer_inference/runs/2023-05-08_15-51-52/shap/all"
df_shap = pd.read_excel(f"{path_shap}/shap.xlsx", index_col=0)
# The expected (base) value is read from the first cell of its sheet
expval = pd.read_excel(f"{path_shap}/expected_value.xlsx", index_col=0).iloc[0, 0]
# The SHAP table's columns define the feature set used below
feats = df_shap.columns.values

# Shared styling for the per-feature scatter plots (feature value vs. age, coloured by group)
scatter_kwargs = dict(
    data=df,
    x="Age",
    hue="Part",
    hue_order=list(palette.keys()),
    palette=palette,
    s=16,
    alpha=0.75,
    linewidth=0.2,
    edgecolor="k",
)

for f in feats:
    plt.figure()
    sns.set_theme(style='ticks')
    scatter = sns.scatterplot(y=f, **scatter_kwargs)
    sns.despine(left=False, right=True, bottom=False, top=True)
    plt.savefig(f"{path}/figure_shap/feats/{f}.png", bbox_inches='tight', dpi=400)
    plt.savefig(f"{path}/figure_shap/feats/{f}.pdf", bbox_inches='tight')
    plt.close()

# Per-group SHAP summaries. For each group in `parts`: a violin summary plot,
# the mean absolute SHAP value of every feature, and a bar plot of those values.
# The table is float-typed so it holds numbers rather than object cells.
df_fi = pd.DataFrame(index=feats, columns=list(parts.keys()), dtype=float)
for p, ids_p in parts.items():
    if len(ids_p) == 0:
        # Nothing to summarise; the group's importance column stays NaN
        print(f"Group '{p}' is empty, skipping SHAP summary")
        continue
    shap_p = df_shap.loc[ids_p, feats]

    shap.summary_plot(
        shap_values=shap_p.values,
        features=df.loc[ids_p, feats].values,
        feature_names=feats,
        max_display=len(feats),
        plot_type="violin",
        show=False,
    )
    plt.savefig(f"{path}/figure_shap/violin_{palette[p]}.png", bbox_inches='tight', dpi=400)
    plt.savefig(f"{path}/figure_shap/violin_{palette[p]}.pdf", bbox_inches='tight')
    plt.close()

    # Mean(|SHAP|) per feature within this group, aligned on feature labels
    df_fi[p] = pd.Series(np.abs(shap_p.values).mean(axis=0), index=feats)
    # Sort a copy for plotting instead of re-sorting the shared table every iteration
    df_fig = (
        df_fi[[p]]
        .sort_values(p, ascending=False)
        .rename(columns={p: "Mean(|SHAP values|)"})
    )

    plt.figure(figsize=(2, 4))
    sns.set_theme(style='ticks', font_scale=1)
    bar = sns.barplot(
        data=df_fig,
        y=df_fig.index,
        x="Mean(|SHAP values|)",
        edgecolor='black',
        orient="h",
        dodge=False,
        color=palette[p],
    )
    bar.set_xlabel("Mean(|SHAP values|)")
    bar.set_ylabel("")
    sns.despine(left=False, right=True, bottom=False, top=True)
    plt.savefig(f"{path}/figure_shap/bar_{palette[p]}.png", bbox_inches='tight', dpi=400)
    plt.savefig(f"{path}/figure_shap/bar_{palette[p]}.pdf", bbox_inches='tight')
    plt.close()

# The saved table keeps the previous row order: descending by the last group
df_fi = df_fi.sort_values(list(parts.keys())[-1], ascending=False)
df_fi.to_excel(f"{path}/figure_shap/fi.xlsx", index_label="features")

# One SHAP waterfall plot per selected sample. `ids_local` maps a sample ID to a
# text label drawn on the figure (presumably a panel letter; defined in an earlier cell).
for idl in ids_local:
    real, pred, diff = (df.at[idl, col] for col in ("Age", "Prediction", "Age acceleration"))
    # String IDs are slugified so they are safe to use as file names
    id_save = slugify(idl) if isinstance(idl, str) else idl

    explanation = shap.Explanation(
        values=df_shap.loc[idl, feats].values,
        base_values=expval,
        data=df.loc[idl, feats].values,
        feature_names=feats,
    )
    shap.plots.waterfall(explanation, max_display=10, show=False)

    # Annotate chronological age, predicted age and the sample's label
    fig = plt.gcf()
    fig.text(0.01, 0.99, f"Age = {real:0.2f}", fontsize=20)
    fig.text(0.01, 0.94, f"SImAge = {pred:0.2f}", fontsize=20)
    fig.text(-0.1, 0.96, ids_local[idl], fontsize=40)
    fig.savefig(f"{path}/figure_shap/local/{id_save}.pdf", bbox_inches='tight')
    fig.savefig(f"{path}/figure_shap/local/{id_save}.png", bbox_inches='tight')
    plt.close()

# 11. Data description

## Hist

In [None]:
pathlib.Path(f"{path}/figure_data_desc").mkdir(parents=True, exist_ok=True)

# Dataset -> bar colour; the hue order used below is this order reversed
palette = {
    "Train/Validation": 'cyan',
    "Test Controls": "orange",
    "Test ESRD": "crimson"
}

df = pd.read_excel(f"{path}/figure_simage/df.xlsx", index_col=0)
# Train and Validation samples are shown as one combined group
trn_val_mask = df["Dataset"].isin(["Train", "Validation"])
df.loc[trn_val_mask, "Dataset"] = "Train/Validation"

# Age histogram with 10-year bins from 10 to 110, stacked by dataset
hist_bins = np.linspace(start=10, stop=110, num=11)
fig = plt.figure()
sns.set_theme(style='ticks', font_scale=1.0)
ax = sns.histplot(
    data=df,
    x="Age",
    hue="Dataset",
    hue_order=list(reversed(list(palette.keys()))),
    palette=palette,
    bins=hist_bins,
    multiple="stack",
    edgecolor='black',
    linewidth=1,
)
ax.set_xlim(5, 115)
sns.despine(left=False, right=True, bottom=False, top=True)
plt.savefig(f"{path}/figure_data_desc/hist.png", bbox_inches='tight', dpi=400)
plt.savefig(f"{path}/figure_data_desc/hist.pdf", bbox_inches='tight')
plt.close(fig)

## Matrix

In [None]:
pathlib.Path(f"{path}/figure_data_desc").mkdir(parents=True, exist_ok=True)
df = pd.read_excel(f"{path}/figure_simage/df.xlsx", index_col=0)
df = df.loc[df['Status'] == 'Control', :]
feats = pd.read_excel(f"{path}/feats_con.xlsx", index_col=0).index.values
feats = ["Age"] + list(feats)

# Pairwise Pearson correlations within controls, packed into one square frame:
# upper triangle = correlation coefficient, lower triangle = raw p-value, diagonal = NaN.
df_corr = pd.DataFrame(data=np.zeros(shape=(len(feats), len(feats))), index=feats, columns=feats)
for f_id_1, f_1 in enumerate(feats):
    df_corr.at[f_1, f_1] = np.nan
    for f_2 in feats[f_id_1 + 1:]:
        corr, pval = stats.pearsonr(df.loc[:, f_1].values, df.loc[:, f_2].values)
        df_corr.at[f_2, f_1] = pval
        df_corr.at[f_1, f_2] = corr

# Benjamini-Hochberg FDR correction over the strictly-lower-triangle p-values.
# Plain `bool` is used because the `np.bool` alias was removed in NumPy 1.24.
selection = np.tri(df_corr.shape[0], df_corr.shape[1], -1, dtype=bool)
df_fdr = df_corr.where(selection).stack().reset_index()
df_fdr.columns = ['row', 'col', 'pval']
_, df_fdr['pval_fdr_bh'], _, _ = multipletests(df_fdr.loc[:, 'pval'].values, 0.05, method='fdr_bh')
# Lower triangle becomes -log10(FDR-adjusted p-value); upper triangle keeps correlations
df_corr_fdr = df_corr.copy()
for row, col, pval_fdr in zip(df_fdr['row'], df_fdr['col'], df_fdr['pval_fdr_bh']):
    df_corr_fdr.loc[row, col] = -np.log10(pval_fdr)

sns.set_theme(style='whitegrid')
df_to_plot = df_corr_fdr.copy()
mtx_to_plot = df_to_plot.to_numpy()

# Upper triangle (correlations) on a diverging map; zeros (incl. the diagonal) are masked
mtx_triu = np.triu(mtx_to_plot, +1)
mtx_triu_mask = np.ma.masked_array(mtx_triu, mtx_triu == 0)
cmap_triu = plt.get_cmap("bwr").copy()

# Lower triangle (-log10 p) on viridis; values below -log10(0.05) render black
mtx_tril = np.tril(mtx_to_plot, -1)
mtx_tril_mask = np.ma.masked_array(mtx_tril, mtx_tril == 0)
cmap_tril = plt.get_cmap("viridis").copy()
cmap_tril.set_under('black')

fig, ax = plt.subplots()

im_triu = ax.imshow(mtx_triu_mask, cmap=cmap_triu, vmin=-1, vmax=1)
cbar_triu = ax.figure.colorbar(im_triu, ax=ax, location='right', shrink=0.80)
cbar_triu.ax.tick_params(labelsize=8)
cbar_triu.set_label(r"$\mathrm{Correlation\:coefficient}$", horizontalalignment='center', fontsize=8)

im_tril = ax.imshow(mtx_tril_mask, cmap=cmap_tril, vmin=-np.log10(0.05))
cbar_tril = ax.figure.colorbar(im_tril, ax=ax, location='right', shrink=0.80)
cbar_tril.ax.tick_params(labelsize=8)
cbar_tril.set_label(r"$-\log_{10}(\mathrm{p-value})$", horizontalalignment='center', fontsize=8)

ax.grid(None)
ax.set_aspect("equal")
ax.set_xticks(np.arange(df_to_plot.shape[1]))
ax.set_yticks(np.arange(df_to_plot.shape[0]))
ax.set_xticklabels(df_to_plot.columns.values)
ax.set_yticklabels(df_to_plot.index.values)
plt.setp(ax.get_xticklabels(), rotation=90)
# Text-colour threshold over finite values only: an FDR p-value of 0 gives
# -log10 = inf, which would otherwise make the threshold inf.
mtx_tril_finite = mtx_tril[np.isfinite(mtx_tril)]
threshold = np.ptp(mtx_tril_finite) * 0.5
ax.tick_params(axis='both', which='major', labelsize=5)
ax.tick_params(axis='both', which='minor', labelsize=5)
textcolors = ("black", "white")
for i in range(df_to_plot.shape[0]):
    for j in range(df_to_plot.shape[1]):
        color = "black"
        if i > j:
            # White text on the dark (low) end of viridis, black on the bright end
            color = textcolors[int(mtx_tril[i, j] < threshold)]
        if np.isinf(mtx_to_plot[i, j]) or np.isnan(mtx_to_plot[i, j]):
            ax.text(j, i, "", ha="center", va="center", color=color, fontsize=1.4)
        else:
            ax.text(j, i, f"{mtx_to_plot[i, j]:0.2f}", ha="center", va="center", color=color, fontsize=1.4)
fig.tight_layout()
plt.savefig(f"{path}/figure_data_desc/corr_mtx_fdr.png", bbox_inches='tight', dpi=400)
plt.savefig(f"{path}/figure_data_desc/corr_mtx_fdr.pdf", bbox_inches='tight', dpi=400)
plt.close(fig)

df_save = df_corr_fdr
df_save.to_excel(f"{path}/figure_data_desc/corr_mtx_fdr.xlsx", index=True)

# 12. Supplementary table MAE

In [None]:
pathlib.Path(f"{path}/supp_table_mae").mkdir(parents=True, exist_ok=True)
df = pd.read_excel(f"{path}/figure_simage/df.xlsx", index_col=0)

# Row groups: contiguous half-open age bins, so every age falls into exactly
# one bin (Age == 70 was previously in neither of the last two), plus sex and all.
cond_group = {
    "Age < 30": (df["Age"] < 30),
    "30 <= Age < 50": (df["Age"] < 50) & (df["Age"] >= 30),
    "50 <= Age < 70": (df["Age"] < 70) & (df["Age"] >= 50),
    "Age >= 70": (df["Age"] >= 70),
    "Females": (df["Sex"] == "F"),
    "Males": (df["Sex"] == "M"),
    "All": (df["Sex"].isin(['M', 'F']))
}

# Column groups: dataset splits
cond_dataset = {
    "Train": (df["Dataset"] == "Train"),
    "Validation": (df["Dataset"] == "Validation"),
    "Test Controls": (df["Dataset"] == "Test Controls"),
    "Overall Controls": (df["Dataset"].isin(['Train', 'Validation', 'Test Controls'])),
    "ESRD": (df["Dataset"] == "Test ESRD"),
}

# MAE between chronological age and prediction for each (group, dataset) cell
df_mae = pd.DataFrame(index=list(cond_group.keys()), columns=list(cond_dataset.keys()), dtype=float)
for g_name, g_cond in cond_group.items():
    for d_name, d_cond in cond_dataset.items():
        df_local = df[conjunction([g_cond, d_cond])]
        if df_local.empty:
            # mean_absolute_error raises on zero samples; leave the cell as NaN
            continue
        real = df_local.loc[:, "Age"]
        pred = df_local.loc[:, "Prediction"]
        mae = mean_absolute_error(real, pred)
        df_mae.at[g_name, d_name] = mae

df_mae.to_excel(f"{path}/supp_table_mae/df_mae.xlsx", index_label="Group")

# 13. Supplementary table data

In [None]:
pathlib.Path(f"{path}/supp_table_data").mkdir(parents=True, exist_ok=True)
df = pd.read_excel(f"{path}/figure_simage/df.xlsx", index_col=0)
# Merge Train and Validation into one group. The result is assigned back because
# a chained `df['Dataset'].replace(..., inplace=True)` does not update `df`
# under pandas Copy-on-Write.
df['Dataset'] = df['Dataset'].replace({'Train': 'Train/Validation', 'Validation': 'Train/Validation'})
parts = {
    'Train/Validation': 'trn_val',
    'Test Controls': 'tst_ctrl',
    'Test ESRD': 'tst_esrd'
}
# Anonymised sample IDs "<prefix>_<NNN>", numbered within each group in file order.
# Rows whose Dataset is not listed in `parts` keep 0, as before.
new_ids = {}
for part, p_name in parts.items():
    for id_new, id_old in enumerate(df.index[df['Dataset'] == part]):
        new_ids[id_old] = f"{p_name}_{id_new:03d}"
# Built in one step (object dtype) rather than writing strings into an int
# column cell by cell, which pandas deprecates as an incompatible-dtype setitem.
df['new_index'] = [new_ids.get(id_old, 0) for id_old in df.index]

df_mapping = df.loc[:, ["new_index"]].copy()
df_mapping.to_excel(f"{path}/supp_table_data/df_mapping.xlsx", index_label='old_index')

df = (
    df.set_index("new_index")
    .drop(columns=['Dead_Alive'])
    .rename(columns={'Prediction': 'SImAge', 'Age acceleration': 'SImAge acceleration'})
)
df.to_excel(f"{path}/supp_table_data/df.xlsx", index_label='index')
