In [None]:
import pandas as pd
import numpy as np
from plotly.subplots import make_subplots
from scipy import stats
import plotly.express as px
from scripts.python.routines.plot.scatter import add_scatter_trace
import plotly.graph_objects as go
from scripts.python.routines.plot.save import save_figure
from scripts.python.routines.plot.layout import add_layout, get_axis
from statsmodels.stats.multitest import multipletests
import plotly.io as pio
pio.kaleido.scope.mathjax = None
from plotly.offline import init_notebook_mode, iplot
init_notebook_mode(connected=False)
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
import seaborn as sns
from glob import glob
import pathlib
from sklearn.metrics import mean_absolute_error
from scipy import stats
import patchworklib as pw
import os
import functools


def conjunction(conditions):
    """Combine boolean masks with an element-wise logical AND.

    The masks are folded left to right with ``np.logical_and``.

    :param conditions: non-empty iterable of boolean masks (arrays or Series).
    :return: element-wise AND of all masks; a single mask is returned unchanged.
    :raises TypeError: if ``conditions`` is empty.
    """
    remaining = iter(conditions)
    try:
        combined = next(remaining)
    except StopIteration:
        raise TypeError("reduce() of empty iterable with no initial value") from None
    for mask in remaining:
        combined = np.logical_and(combined, mask)
    return combined


def disjunction(conditions):
    """Combine boolean masks with an element-wise logical OR.

    The masks are folded left to right with ``np.logical_or``.

    :param conditions: non-empty iterable of boolean masks (arrays or Series).
    :return: element-wise OR of all masks; a single mask is returned unchanged.
    :raises TypeError: if ``conditions`` is empty.
    """
    remaining = iter(conditions)
    try:
        combined = next(remaining)
    except StopIteration:
        raise TypeError("reduce() of empty iterable with no initial value") from None
    for mask in remaining:
        combined = np.logical_or(combined, mask)
    return combined

# 0. Setup

In [None]:
# Root directory that every cell below reads its inputs from and writes its outputs to.
path = "D:/YandexDisk/Work/pydnameth/datasets/GPL21145/GSEUNN/special/044_small_immuno_clocks_revision"
# Create the directory together with any missing parents; an existing directory is left as is.
os.makedirs(path, exist_ok=True)

# 1. Prepare additional test data

In [None]:
# Load the original table and the table with all samples; both are indexed by their first column.
df_origin = pd.read_excel(
    f"{path}/data_origin/260_imp(fast_knn)_replace(quarter).xlsx",
    index_col=0,
)
df_all = pd.read_excel(
    f"{path}/df_samples(all_1052_121222)_proc(raw)_imp(fast_knn)_replace(quarter).xlsx",
    index_col=0,
)

# Label every row of the full table with "<Region>_<Status>" and mark it as test data.
df_all["parts_danet"] = df_all["Region"].str.cat(df_all["Status"].astype(str), sep="_")
df_all["Split"] = "tst"

# Index labels present in the full table but absent from the original one.
indexes_not_origin = df_all.index.difference(df_origin.index)

# Original rows first, followed by the rows that only exist in the full table.
df_new = pd.concat([df_origin, df_all.loc[indexes_not_origin]])
df_new.to_excel(f"{path}/all_for_test.xlsx", index_label="index")

# 2. Create new dataset

In [None]:
# Build the extended dataset: the original table plus three test partitions
# (central controls, Yakutia controls, ESRD), then flag deceased samples.
df_origin = pd.read_excel(f"{path}/origin/260_imp(fast_knn)_replace(quarter).xlsx", index_col=0)
df_all = pd.read_excel(f"{path}/origin/df_samples(all_1052_121222)_proc(raw)_imp(fast_knn)_replace(quarter).xlsx", index_col=0)

# --- Test partition 1: selected central controls ---
df_tst_central_include = pd.read_excel(f"{path}/origin/samples_test_slctd.xlsx", index_col=0)
df_tst_central_include["index"] = df_tst_central_include.index.values
# Drop a trailing "_copy" suffix only. The previous str.rstrip('_copy') treated its
# argument as a character set and removed every trailing '_', 'c', 'o', 'p' or 'y',
# which truncates any id that happens to end with one of these characters.
df_tst_central_include["index"] = df_tst_central_include["index"].str.replace(r"_copy$", "", regex=True)
indexes_test_ctrl_central = df_tst_central_include["index"].values
print(len(indexes_test_ctrl_central))
df_test_ctrl_central = df_all.loc[indexes_test_ctrl_central, :].copy()
df_test_ctrl_central["parts_danet"] = 'tst_ctrl_central'
df_test_ctrl_central["Split"] = 'tst_ctrl_central'

# --- Test partition 2: Yakutia controls, taken from a previous DANet inference run ---
df_res = pd.read_excel(f"{path}/origin/models/danet_inference/runs/2023-04-12_12-16-05/df.xlsx", index_col=0)
indexes_test_ctrl_yakutia = df_res.index[(df_res["parts_danet"] == "Yakutia_Control")].values
print(len(indexes_test_ctrl_yakutia))
df_test_ctrl_yakutia = df_all.loc[indexes_test_ctrl_yakutia, :].copy()
df_test_ctrl_yakutia["parts_danet"] = 'tst_ctrl_yakutia'
df_test_ctrl_yakutia["Split"] = 'tst_ctrl_yakutia'

# --- Test partition 3: all samples with ESRD status ---
indexes_test_esrd = df_all.index[df_all["Status"] == "ESRD"].values
print(len(indexes_test_esrd))
# Explicit copy, as for the two partitions above, so the column assignments below
# act on an independent frame instead of a slice of df_all.
df_test_esrd = df_all.loc[indexes_test_esrd, :].copy()
df_test_esrd["parts_danet"] = 'tst_esrd'
df_test_esrd["Split"] = 'tst_esrd'

df_new = pd.concat([df_origin, df_test_ctrl_central, df_test_ctrl_yakutia, df_test_esrd])

# Mark samples listed as "Dead" in the dead-or-alive table.
df_dead_alive = pd.read_excel(f"{path}/origin/df_samples_dead_or_alive.xlsx", index_col=0)
inds_dead_alive = df_dead_alive.index[df_dead_alive["Dead_Alive"] == "Dead"].values
df_new.loc[inds_dead_alive, "Dead_Alive"] = "Dead"

df_new.to_excel(f"{path}/data_wtf.xlsx", index_label="index")

# 3. Collect ML results

In [None]:
# Collect, for every hyper-parameter run of one model, its stored metrics, its
# metrics on the central-control test samples and its Hydra overrides, and write
# everything into a single summary table.
model = 'danet_trn_val_tst'

part_check = "tst_ctrl_central"
# Upper bound for the expanding mean of the sorted absolute errors (see loop below).
part_check_thld_mean = 7.5
df = pd.read_excel(f"{path}/data_final.xlsx", index_col=0)
samples_tst_ctrl_central = df.index[df["Split"] == "tst_ctrl_central"].values

path_runs = f"{path}/models/{model}/multiruns"

# Sorted so that the row order of the summary is the same on every execution
# (glob itself returns files in arbitrary order).
files = sorted(glob(f"{path_runs}/*/*/metrics_all_best_*.xlsx"))
if not files:
    # Fail with an explicit message instead of an IndexError on files[0].
    raise FileNotFoundError(f"No metrics_all_best_*.xlsx files found under {path_runs}")

# Names of the overridden parameters, taken from the first run.
head, tail = os.path.split(files[0])
cfg = OmegaConf.load(f"{head}/.hydra/overrides.yaml")
# maxsplit=1: only the first '=' separates the name from the value.
params = [param_pair.split('=', 1)[0] for param_pair in cfg]
df_res = pd.DataFrame(index=files)

for file in files:

    head, tail = os.path.split(file)
    df_pred = pd.read_excel(f"{head}/predictions.xlsx", index_col=0)
    df_pred["Error"] = df_pred["Prediction"] - df_pred["Age"]
    df_pred["AbsError"] = df_pred["Error"].abs()
    # The first column of predictions.xlsx holds the partition label of each sample.
    part_col = df_pred.columns[0]
    df_pred_tst_ctrl_central = df_pred.loc[df_pred[part_col] == part_check, :].copy()
    # Sort by absolute error and count how many of the best-predicted samples can be
    # kept while their running mean absolute error stays below the threshold.
    df_pred_tst_ctrl_central.sort_values(["AbsError"], ascending=[True], inplace=True)
    df_pred_tst_ctrl_central["MeanAbsErrorExpanding"] = df_pred_tst_ctrl_central["AbsError"].expanding().mean()
    samples_tst_ctrl_central_passed = df_pred_tst_ctrl_central.index[df_pred_tst_ctrl_central["MeanAbsErrorExpanding"] < part_check_thld_mean].values
    n_samples_passed = len(samples_tst_ctrl_central_passed)

    df_res.at[file, "passed_test_samples"] = n_samples_passed

    # tst_ctrl_central: metrics recomputed on the samples listed in data_final.xlsx.
    # They are written first and then overwritten by the stored metrics of the same
    # name in the loop below.
    y_real = df_pred.loc[samples_tst_ctrl_central, "Age"]
    y_pred = df_pred.loc[samples_tst_ctrl_central, "Prediction"]
    mae_tst = mean_absolute_error(y_real, y_pred)
    rho_tst = stats.pearsonr(y_real, y_pred).statistic
    df_res.at[file, 'mean_absolute_error_tst_ctrl_central'] = mae_tst
    df_res.at[file, 'pearson_corr_coef_tst_ctrl_central'] = rho_tst

    # Metrics stored by the run itself, one column per metric and partition.
    df_metrics = pd.read_excel(file, index_col="metric")
    for metric in df_metrics.index.values:
        df_res.at[file, metric + "_val"] = df_metrics.at[metric, "val"]
        df_res.at[file, metric + "_trn"] = df_metrics.at[metric, "trn"]
        df_res.at[file, metric + "_tst_ctrl_central"] = df_metrics.at[metric, "tst_ctrl_central"]
        df_res.at[file, metric + "_trn_val_tst_ctrl_central"] = df_metrics.at[metric, "trn_val_tst_ctrl_central"]
        df_res.at[file, metric + "_val_tst_ctrl_central"] = df_metrics.at[metric, "val_tst_ctrl_central"]

    # Params: one column per Hydra override of this run.
    cfg = OmegaConf.load(f"{head}/.hydra/overrides.yaml")
    for param_pair in cfg:
        # maxsplit=1 keeps values that themselves contain '=' intact.
        param, val = param_pair.split('=', 1)
        df_res.at[file, param] = val

# train_more_val: True where the training MAE exceeds the validation MAE.
df_res["train_more_val"] = False
# NOTE(review): "selected" is always written as False here; presumably the chosen
# run is marked by hand in the resulting Excel file afterwards - confirm.
df_res["selected"] = False
df_res.loc[df_res["mean_absolute_error_trn"] > df_res["mean_absolute_error_val"], "train_more_val"] = True

# Columns that should come first in the summary; all remaining columns follow.
first_columns = [
    'selected',
    'passed_test_samples',
    'train_more_val',
    'mean_absolute_error_trn',
    'mean_absolute_error_val',
    'mean_absolute_error_tst_ctrl_central',
    'mean_absolute_error_val_tst_ctrl_central',
    'mean_absolute_error_trn_val_tst_ctrl_central',
    'pearson_corr_coef_trn',
    'pearson_corr_coef_val',
    'pearson_corr_coef_tst_ctrl_central',
    'pearson_corr_coef_val_tst_ctrl_central',
    'pearson_corr_coef_trn_val_tst_ctrl_central',
    'mean_absolute_error_cv_mean_trn',
    'mean_absolute_error_cv_std_trn',
    'mean_absolute_error_cv_mean_val',
    'mean_absolute_error_cv_std_val',
    'pearson_corr_coef_cv_mean_trn',
    'pearson_corr_coef_cv_std_trn',
    'pearson_corr_coef_cv_mean_val',
    'pearson_corr_coef_cv_std_val',
]
df_res = df_res[first_columns + [col for col in df_res.columns if col not in first_columns]]
df_res.to_excel(f"{path_runs}/summary.xlsx", index=True, index_label="file")

# 4. Decider for central controls

In [None]:
# Decide which test control samples are kept: a sample passes for a model when the
# expanding mean of the sorted absolute errors is still below the threshold, and it
# is finally selected when it passes for every model in models_main.
part_check = "tst_ctrl_all"
part_check_thld_mean = 8.43

df = pd.read_excel(f"{path}/data.xlsx", index_col=0)
samples_test = df.index[df["Split"] == "tst_ctrl_all"].values

models_all = [
    "elastic_net",
    "xgboost",
    "lightgbm",
    "catboost",
    "widedeep_tab_mlp",
    "nam",
    "nbm_spam_nam",
    "pytorch_tabular_node",
    "danet",
    "widedeep_tab_net",
    "pytorch_tabular_autoint",
    "widedeep_saint",
    "widedeep_ft_transformer"
]

models_main = [
    "danet",
    "widedeep_saint",
    "widedeep_ft_transformer"
]

path_models = f"{path}/models/46_trn_val_tst"

df_res = pd.DataFrame(index=models_all)
df_samples_test = pd.DataFrame(index=samples_test, columns=models_all)

# Directory of the selected run of every model. It is resolved once in the first
# loop and reused by the second one, so both loops are guaranteed to read the same run.
run_dirs = {}

for m in models_all:
    df_summary = pd.read_excel(f"{path_models}/{m}_trn_val_tst/multiruns/summary.xlsx", index_col=0)
    files_slctd = df_summary.index[df_summary["selected"] == True].values
    if len(files_slctd) != 1:
        raise ValueError(f"{m} model selection error")
    file_slctd = files_slctd[0]
    path_head, _ = os.path.split(file_slctd)
    # The summary stores paths under models/46/; the runs now live under
    # models/46_trn_val_tst/. The pattern includes "models" so that it cannot match
    # an unrelated ".../46/..." path component.
    path_head = path_head.replace('models/46/', 'models/46_trn_val_tst/', 1)
    run_dirs[m] = path_head

    file_val = glob(f"{path_head}/metrics_val_best_*.xlsx")[0]
    df_res_val = pd.read_excel(file_val, index_col=0)

    df_res.at[m, 'val_mae_best'] = df_res_val.at['mean_absolute_error', 'val']
    df_res.at[m, 'val_mae_mean'] = df_res_val.at['mean_absolute_error_cv_mean', 'val']
    df_res.at[m, 'val_mae_std'] = df_res_val.at['mean_absolute_error_cv_std', 'val']
    df_res.at[m, 'val_rho_best'] = df_res_val.at['pearson_corr_coef', 'val']
    df_res.at[m, 'val_rho_mean'] = df_res_val.at['pearson_corr_coef_cv_mean', 'val']
    df_res.at[m, 'val_rho_std'] = df_res_val.at['pearson_corr_coef_cv_std', 'val']

    df_pred = pd.read_excel(f"{path_head}/predictions.xlsx", index_col=0)
    df_pred["Error"] = df_pred["Prediction"] - df_pred["Age"]
    df_pred["AbsError"] = df_pred["Error"].abs()
    # The first column of predictions.xlsx holds the partition label of each sample.
    part_col = df_pred.columns[0]
    # Assign the result back: replace(..., inplace=True) on a column selection is a
    # chained in-place update, which pandas warns about and which leaves df_pred
    # unchanged when Copy-on-Write is enabled.
    df_pred[part_col] = df_pred[part_col].replace({"tst_ctrl_central": part_check})
    # Explicit copy so that sorting and the new column act on an independent frame.
    df_pred = df_pred.loc[df_pred[part_col] == part_check, :].copy()
    df_pred = df_pred.sort_values(["AbsError"], ascending=[True])
    df_pred["MeanAbsErrorExpanding"] = df_pred["AbsError"].expanding().mean()
    samples_passed = df_pred.index[df_pred["MeanAbsErrorExpanding"] < part_check_thld_mean].values
    # 1 = sample passed for this model, 0 = it did not.
    df_samples_test.loc[:, m] = 0
    df_samples_test.loc[samples_passed, m] = 1
    n_samples_passed = len(samples_passed)
    print(f"{m}: {n_samples_passed}")

# NOTE(review): the index of df_samples_test holds sample ids although it is written
# with index_label="model"; the label is kept in case other code reads this header.
df_samples_test.to_excel(f"{path_models}/samples_test_full.xlsx", index_label="model")

# Keep only the samples that passed for all main models.
conditions = [df_samples_test[m] == 1 for m in models_main]
df_samples_test = df_samples_test[conjunction(conditions)]
samples_test_final = df_samples_test.index.values
print(len(samples_test_final))

# Test metrics of every model on the finally selected samples.
for m in models_all:
    df_pred = pd.read_excel(f"{run_dirs[m]}/predictions.xlsx", index_col=0)
    df_pred = df_pred.loc[samples_test_final, :]
    y_real = df_pred["Age"]
    y_pred = df_pred["Prediction"]
    mae_tst = mean_absolute_error(y_real, y_pred)
    rho_tst = stats.pearsonr(y_real, y_pred).statistic
    df_res.at[m, 'tst_mae'] = mae_tst
    df_res.at[m, 'tst_rho'] = rho_tst

df_res.to_excel(f"{path_models}/baseline_results.xlsx", index_label="model")
df_samples_test.to_excel(f"{path_models}/samples_test_slctd.xlsx", index_label="model")

# 5. Updating data with trn/val splits from best models

In [None]:
# For every model add a "best_<model>" column: it starts as a copy of "Split" and
# the samples of the selected run's training / validation parts are relabelled
# "trn" / "val".
df = pd.read_excel(f"{path}/data_final.xlsx", index_col=0)

models_all = [
    "elastic_net",
    "xgboost",
    "lightgbm",
    "catboost",
    "widedeep_tab_mlp",
    "nam",
    "nbm_spam_nam",
    "pytorch_tabular_node",
    "danet",
    "widedeep_tab_net",
    "pytorch_tabular_autoint",
    "widedeep_saint",
    "widedeep_ft_transformer"
]

path_models = f"{path}/models/46_trn_val_tst"
for m in models_all:
    best_col = f"best_{m}"
    df[best_col] = df["Split"]

    # Exactly one run per model must be flagged as selected in its summary.
    df_summary = pd.read_excel(f"{path_models}/{m}_trn_val_tst/multiruns/summary.xlsx", index_col=0)
    files_slctd = df_summary.index[df_summary["selected"] == True].values
    if len(files_slctd) != 1:
        raise ValueError(f"{m} model selection error")
    file_slctd = files_slctd[0]

    # Directory of the selected run, remapped from models/46/ to models/46_trn_val_tst/.
    path_head = os.path.dirname(file_slctd).replace('models/46/', 'models/46_trn_val_tst/', 1)

    df_pred = pd.read_excel(f"{path_head}/predictions.xlsx", index_col=0)
    # The first column of predictions.xlsx holds the partition label of each sample.
    part_col = df_pred.columns[0]
    ids_trn = df_pred.index[df_pred[part_col] == "trn"].values
    ids_val = df_pred.index[df_pred[part_col] == "val"].values
    df.loc[ids_trn, best_col] = "trn"
    df.loc[ids_val, best_col] = "val"

df.to_excel(f"{path}/data_final1.xlsx", index_label="index")

# 6. Inference checking and Yakutia thresholding

In [None]:
# Collect the inference metrics of every model and apply the expanding-mean
# threshold to the Yakutia control samples.
part_check = "tst_ctrl_yakutia"
# NOTE(review): with a bound this large presumably every sample passes, i.e. no
# thresholding is applied - confirm that this is intended.
part_check_thld_mean = 99999

df = pd.read_excel(f"{path}/data_final.xlsx", index_col=0)
samples_test = df.index[df["Split"] == part_check].values

models_all = [
    "elastic_net",
    "xgboost",
    "lightgbm",
    "catboost",
    "widedeep_tab_mlp",
    "nam",
    "nbm_spam_nam",
    "pytorch_tabular_node",
    "danet",
    "widedeep_tab_net",
    "pytorch_tabular_autoint",
    "widedeep_saint",
    "widedeep_ft_transformer"
]

models_main = [
    "danet",
    "widedeep_saint",
    "widedeep_ft_transformer"
]

path_models = f"{path}/models/46_inference"

df_res = pd.DataFrame(index=models_all)
df_samples_test = pd.DataFrame(index=samples_test, columns=models_all)

# Unfiltered prediction table of every model, kept so that the second loop does not
# have to read the same Excel file again.
preds_by_model = {}

for m in models_all:
    file_val = glob(f"{path_models}/{m}_inference/runs/*/metrics_val.xlsx")[0]
    df_res_val = pd.read_excel(file_val, index_col=0)
    df_res.at[m, 'val_mae'] = df_res_val.at['mean_absolute_error', 'val']
    df_res.at[m, 'val_rho'] = df_res_val.at['pearson_corr_coef', 'val']

    file_tst_central = glob(f"{path_models}/{m}_inference/runs/*/metrics_tst_ctrl_central.xlsx")[0]
    df_res_tst_central = pd.read_excel(file_tst_central, index_col=0)
    df_res.at[m, 'tst_central_mae'] = df_res_tst_central.at['mean_absolute_error', 'tst_ctrl_central']
    df_res.at[m, 'tst_central_rho'] = df_res_tst_central.at['pearson_corr_coef', 'tst_ctrl_central']

    # Dedicated names for the ESRD metrics instead of reusing the validation ones.
    file_tst_esrd = glob(f"{path_models}/{m}_inference/runs/*/metrics_tst_esrd.xlsx")[0]
    df_res_tst_esrd = pd.read_excel(file_tst_esrd, index_col=0)
    df_res.at[m, 'tst_esrd_mae'] = df_res_tst_esrd.at['mean_absolute_error', 'tst_esrd']
    df_res.at[m, 'tst_esrd_rho'] = df_res_tst_esrd.at['pearson_corr_coef', 'tst_esrd']

    file_pred = glob(f"{path_models}/{m}_inference/runs/*/df.xlsx")[0]
    preds_by_model[m] = pd.read_excel(file_pred, index_col=0)
    # Explicit copy so that sorting and the new column act on an independent frame
    # rather than on a slice of the cached table.
    df_pred = preds_by_model[m].loc[preds_by_model[m][f"best_{m}"] == part_check, :].copy()
    df_pred = df_pred.sort_values(["Prediction error abs"], ascending=[True])
    df_pred["MeanAbsErrorExpanding"] = df_pred["Prediction error abs"].expanding().mean()
    samples_passed = df_pred.index[df_pred["MeanAbsErrorExpanding"] < part_check_thld_mean].values
    # 1 = sample passed for this model, 0 = it did not.
    df_samples_test.loc[:, m] = 0
    df_samples_test.loc[samples_passed, m] = 1
    n_samples_passed = len(samples_passed)
    print(f"{m}: {n_samples_passed}")

df_samples_test.to_excel(f"{path_models}/samples_test_full.xlsx", index_label="model")

# Keep only the samples that passed for all main models.
conditions = [df_samples_test[m] == 1 for m in models_main]
df_samples_test = df_samples_test[conjunction(conditions)]
samples_test_final = df_samples_test.index.values
print(len(samples_test_final))

# Metrics of every model on the finally selected Yakutia samples.
for m in models_all:
    df_pred = preds_by_model[m].loc[samples_test_final, :]
    y_real = df_pred["Age"]
    y_pred = df_pred["Prediction"]
    mae_tst = mean_absolute_error(y_real, y_pred)
    rho_tst = stats.pearsonr(y_real, y_pred).statistic
    df_res.at[m, 'tst_yakutia_mae'] = mae_tst
    df_res.at[m, 'tst_yakutia_rho'] = rho_tst

df_res.to_excel(f"{path_models}/baseline_results_{part_check_thld_mean}.xlsx", index_label="model")
df_samples_test.to_excel(f"{path_models}/samples_test_slctd_{part_check_thld_mean}.xlsx", index_label="model")

# 7. Feature selection and dimensionality reduction via SHAP

In [None]:
# Rank features by their mean absolute SHAP value summed over four models, save the
# ranking, and draw (a) a stacked bar chart of the importances and (b) a combined
# figure with the bar chart above a pair plot of Age and the ten top-ranked features.
df = pd.read_excel(f"{path}/data_final.xlsx", index_col=0)
# Rows whose Split is "trn_val"; used for the pair plot below.
df_trn_val = df.loc[df["Split"] == "trn_val", :]
# Feature names are the index of feats_con.xlsx.
feats = pd.read_excel(f"{path}/feats_con.xlsx", index_col=0).index.values
path_models = f"{path}/models/46_shap"

# Model directory prefix -> column name / legend label.
models = {
    "danet": "DANet",
    "widedeep_tab_net": "TabNet",
    "widedeep_saint": "SAINT",
    "widedeep_ft_transformer": "FT-Transformer"
}


# One row per feature, one column per model.
df_fi = pd.DataFrame(index=feats)

for m in models:
    # First shap.xlsx matched by the pattern for this model.
    file_shap = glob(f"{path_models}/{m}_inference/runs/*/shap/trn_val/shap.xlsx")[0]
    df_shap = pd.read_excel(file_shap, index_col=0)

    for f in feats:
        # Importance of a feature = mean of the absolute values in its SHAP column.
        df_fi.at[f, models[m]] = df_shap[f].abs().mean()

# Total importance across models; features are ordered from most to least important.
df_fi['Summary'] = df_fi.sum(axis=1)
df_fi.sort_values(['Summary'], ascending=[False], inplace=True)
df_fi.to_excel(f"{path_models}/fi.xlsx", index_label="features")

# First ten features of the sorted ranking.
feats_top10 = df_fi.index.values[0:10]
# Per-model columns only (without 'Summary') for plotting.
df_fig = df_fi.loc[:, list(models.values())]

# Stand-alone stacked bar chart: one bar per feature, one segment per model.
sns.set_theme(style='ticks', font_scale=3)
df_fig.plot(kind='bar', stacked=True, color=px.colors.qualitative.D3, figsize=(34, 10), edgecolor='black')
plt.xticks(rotation=90)
plt.xlabel('Features')
plt.ylabel('Mean(|SHAP values|)')
sns.despine(left=False, right=True, bottom=False, top=True)
plt.savefig(f"{path_models}/fi.png", bbox_inches='tight', dpi=400)
plt.savefig(f"{path_models}/fi.pdf", bbox_inches='tight')

# Combined figure built with patchworklib.
# NOTE(review): overwrite_axisgrid() is presumably required before the PairGrid is
# created so that it can be loaded with load_seaborngrid below - confirm.
pw.overwrite_axisgrid()
sns.set_theme(style='ticks', font_scale=3)
ax1 = pw.Brick(figsize=(30, 10))
tmp = df_fig.plot(kind='bar', stacked=True, color=px.colors.qualitative.D3, ax=ax1, edgecolor='black')
sns.despine(left=False, right=True, bottom=False, top=True, ax=ax1)
ax1.set_xlabel('')
ax1.set_ylabel('Mean(|SHAP values|)')
# Panel letters 'A' and 'B', positioned with hard-coded coordinates on ax1.
ax1.text(-4, 45,'A', fontsize=72)
ax1.text(-4, -15,'B', fontsize=72)
# Pair plot of Age and the top-10 features: histograms on the diagonal, scatter
# plots above it and filled KDE plots below it.
feats_plot = ["Age"] + list(feats_top10)
sns.set_theme(style="ticks", font_scale=1.5)
g1 = sns.PairGrid(df_trn_val, vars=feats_plot)
g1.map_diag(plt.hist, bins=15, color='darkred', edgecolor='k')
g1.map_upper(sns.scatterplot, color='darkred', s=10, alpha=0.5, edgecolor='k', linewidth=0.2)
g1.map_lower(sns.kdeplot, fill=True, cbar=False, cmap='rocket_r', thresh=-0.1)
g1.add_legend()
# Move every y-axis label to the same position left of its axes.
for ax in g1.axes.flatten():
    ax.get_yaxis().set_label_coords(-0.5, 0.5)
ax2 = pw.load_seaborngrid(g1, figsize=(30, 30))
# ax1/ax2 combines the two bricks (presumably ax1 stacked above ax2 - confirm
# against patchworklib); the composition is built once per output format.
(ax1/ax2).savefig(f"{path_models}/together.pdf", bbox_inches='tight')
(ax1/ax2).savefig(f"{path_models}/together.png", bbox_inches='tight', dpi=400)

# 8. trn/val split

## 8.1 Check trn/val identity

In [None]:
# Check whether the selected runs of all models use identical cross-validation
# folds: one table per fold with one column per model, plus a 'matching' flag that
# is True where every model agrees with the first one.
pathlib.Path(f"{path}/figure_splits").mkdir(parents=True, exist_ok=True)

df = pd.read_excel(f"{path}/data_final.xlsx", index_col=0)
samples_trn_val = df.index[df["Split"] == "trn_val"].values

models_all = [
    "elastic_net",
    "xgboost",
    "lightgbm",
    "catboost",
    "widedeep_tab_mlp",
    "nam",
    "nbm_spam_nam",
    "pytorch_tabular_node",
    "danet",
    "widedeep_tab_net",
    "pytorch_tabular_autoint",
    "widedeep_saint",
    "widedeep_ft_transformer"
]

path_models = f"{path}/models/46_trn_val_tst"

# One empty table per fold, indexed by the trn_val samples.
df_folds = {}
for fold_idx in range(5):
    df_folds[f"fold_{fold_idx:04d}"] = pd.DataFrame(index=samples_trn_val)

for m in models_all:
    df[f"best_{m}"] = df["Split"]

    # Exactly one run per model must be flagged as selected in its summary.
    df_summary = pd.read_excel(f"{path_models}/{m}_trn_val_tst/multiruns/summary.xlsx", index_col=0)
    files_slctd = df_summary.index[df_summary["selected"] == True].values
    if len(files_slctd) != 1:
        raise ValueError(f"{m} model selection error")
    file_slctd = files_slctd[0]

    # Directory of the selected run, remapped from models/46/ to models/46_trn_val_tst/.
    path_head = os.path.dirname(file_slctd).replace('models/46/', 'models/46_trn_val_tst/', 1)

    df_cv_ids = pd.read_excel(f"{path_head}/cv_ids.xlsx", index_col=0)
    for fold_idx in range(5):
        fold_name = f"fold_{fold_idx:04d}"
        # Fold assignment of this model for every trn_val sample.
        df_folds[fold_name].loc[samples_trn_val, m] = df_cv_ids.loc[samples_trn_val, fold_name]

for fold_idx in range(5):
    df_fold = df_folds[f"fold_{fold_idx:04d}"]
    # True where the labels of all models equal the label in the first column.
    df_fold['matching'] = df_fold.eq(df_fold.iloc[:, 0], axis=0).all(1)
    df_fold.to_excel(f"{path}/figure_splits/fold_{fold_idx:04d}.xlsx", index_label="index")
    # Number of samples on which all models agree for this fold.
    n_matches = int(df_fold['matching'].sum())
    print(n_matches)

## 8.2 Plot split histograms

In [None]:
# Stacked age histogram of the trn_val samples, coloured by the cross-validation
# fold in which each sample belongs to the validation part.
pathlib.Path(f"{path}/figure_splits").mkdir(parents=True, exist_ok=True)

df = pd.read_excel(f"{path}/data_final.xlsx", index_col=0)
samples_trn_val = df.index[df["Split"] == "trn_val"].values
# 0 = sample is not a validation sample of any fold.
df["split_id"] = 0

# Model whose selected run provides the fold assignment. The cell previously
# declared this variable but looped over `models_all` instead, a name that is not
# defined in this cell, so that the folds of the last model in that list were used.
model = "danet"

path_models = f"{path}/models/46_trn_val_tst"

# Exactly one run of the model must be flagged as selected in its summary.
df_summary = pd.read_excel(f"{path_models}/{model}_trn_val_tst/multiruns/summary.xlsx", index_col=0)
files_slctd = df_summary.index[df_summary["selected"] == True].values
if len(files_slctd) != 1:
    raise ValueError(f"{model} model selection error")
file_slctd = files_slctd[0]
path_head, _ = os.path.split(file_slctd)

# Remap the stored directory from models/46/ to models/46_trn_val_tst/.
path_head = path_head.replace('models/46/', 'models/46_trn_val_tst/', 1)

df_cv_ids = pd.read_excel(f"{path_head}/cv_ids.xlsx", index_col=0)
for fold_idx in range(5):
    # Samples labelled "val" in this fold get the 1-based fold number.
    val_ids = df_cv_ids.index[df_cv_ids[f"fold_{fold_idx:04d}"] == "val"].values
    df.loc[val_ids, "split_id"] = fold_idx + 1

df_fig = df.loc[samples_trn_val, ["Age", "split_id"]].copy()
df_fig.rename(columns={'split_id': 'Split'}, inplace=True)
df_fig.to_excel(f"{path}/figure_splits/df_fig.xlsx", index_label="index")

# Bin edges from 0 to 110 in steps of 10.
hist_bins = np.linspace(0, 110, 12)

# Folds 1..5 use entries 5..9 of the Plotly qualitative colour list.
palette = {x: px.colors.qualitative.Plotly[x+4] for x in range(1, 6)}

fig = plt.figure()
sns.set_theme(style='ticks', font_scale=1.3)
hist = sns.histplot(
    data=df_fig,
    hue_order=list(range(1, 6))[::-1],
    bins=hist_bins,
    x="Age",
    hue="Split",
    edgecolor='black',
    palette=palette,
    multiple="stack",
)
hist.set(xlim=(0, 110))
sns.despine(left=False, right=True, bottom=False, top=True)
plt.savefig(f"{path}/figure_splits/hist.png", bbox_inches='tight', dpi=400)
plt.savefig(f"{path}/figure_splits/hist.pdf", bbox_inches='tight')
plt.close(fig)