# Autoreload setup (for debugging local modules)

In [None]:
# IPython autoreload extension; mode 2 reloads imported modules (e.g. the local `src.*` helpers) before each cell executes.
%load_ext autoreload
%autoreload 2

# Load packages

In [1]:
from pytorch_tabular.utils import load_covertype_dataset
from rich.pretty import pprint
from sklearn.model_selection import BaseCrossValidator, ParameterGrid, ParameterSampler
import torch
import pickle
import shutil
import shap
from sklearn.model_selection import RepeatedStratifiedKFold
from glob import glob
import ast
import matplotlib.pyplot as plt
import seaborn as sns
import copy
import itertools
from sklearn.model_selection import train_test_split
import numpy as np
from pytorch_tabular.utils import make_mixed_dataset, print_metrics
from pytorch_tabular import available_models
from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig, GANDALFConfig, TabNetModelConfig, FTTransformerConfig, DANetConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig
from pytorch_tabular.tabular_model_tuner import TabularModelTuner
from torchmetrics.functional.regression import mean_absolute_error, pearson_corrcoef
import sklearn.metrics
from scipy import stats
from pytorch_tabular import MODEL_SWEEP_PRESETS
import pandas as pd
from pytorch_tabular import model_sweep
from src.pt.model_sweep import model_sweep_custom
import warnings
from src.utils.configs import read_parse_config
from src.pt.hyper_opt import train_hyper_opt
from src.utils.hash import dict_hash
import pathlib
from tqdm import tqdm
import distinctipy
import matplotlib.patheffects as pe
import matplotlib.colors as mcolors
from statannotations.Annotator import Annotator
from scipy.stats import mannwhitneyu
from plottable import ColumnDefinition, Table
from plottable.plots import bar
from plottable.cmap import normed_cmap, centered_cmap
import optuna
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.cm
import matplotlib as mpl
from statsmodels.stats.multitest import multipletests
import shap
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import scipy
import scipy.stats


def make_rgb_transparent(rgb, bg_rgb, alpha):
    """Blend a colour over a background colour with the given opacity.

    Every channel becomes ``alpha * fg + (1 - alpha) * bg``, i.e. the opaque
    colour that ``rgb`` would appear as when drawn at opacity ``alpha`` on top of
    ``bg_rgb``. Channels are paired positionally, so the result is as long as
    the shorter input.

    Args:
        rgb: Foreground colour channels (e.g. an RGB tuple).
        bg_rgb: Background colour channels with the same layout as ``rgb``.
        alpha: Foreground weight; 1 returns ``rgb``, 0 returns ``bg_rgb``.

    Returns:
        list: The blended channel values.
    """
    bg_weight = 1 - alpha
    return [alpha * fg + bg_weight * bg for fg, bg in zip(rgb, bg_rgb)]

def form_bar(base):
    """Build a label formatter for plottable bar cells.

    The returned callable scales a value ``x`` by ``base``, rounds it to the
    nearest integer and renders it as ``"<count>/<base>"``; for example with
    ``base=8`` the value ``0.5`` is shown as ``"4/8"``.

    Args:
        base: Denominator shown in the label, also used as the scale factor.

    Returns:
        A one-argument function mapping ``x`` to its label string.
    """
    def formatter(x):
        count = int(round(x * base))
        return '{}/{}'.format(count, base)

    return formatter


# Load data

In [2]:
# Data variant selectors; all of them are encoded into `path` below.
epi_data_type = 'no_harm'
imm_data_type = 'imp_source(imm)_method(knn)_params(5)' # 'origin' 'imp_source(imm)_method(knn)_params(5)' 'imp_source(imm)_method(miceforest)_params(2)'

selection_method = 'mrmr' # 'f_regression' 'mrmr'
n_feats = 100

# Outer (test) split parameters: they identify the pickled splits file loaded
# later, and `tst_split_id` picks one of the tst_n_splits * tst_n_repeats splits.
tst_n_splits = 5
tst_n_repeats = 5
tst_random_state = 1337
tst_split_id = 20

# Inner (validation) fold parameters, analogous to the test ones;
# `val_fold_id` picks one of the val_n_splits * val_n_repeats folds.
val_n_splits = 4
val_n_repeats = 4
val_random_state = 1337
val_fold_id = 5

path = f"E:/YandexDisk/Work/bbd/immunology/003_EpImAge/{imm_data_type}/{epi_data_type}/{selection_method}_{n_feats}/EpImAge"
path_epi = "E:/YandexDisk/Work/bbd/immunology/003_EpImAge/epi"
path_configs = "E:/Git/bbs/notebooks/immunology/003_EpImAge/age_regression_configs"

data_full = pd.read_excel(f"{path}/data.xlsx", index_col=0)

# Filtering data: keep statuses with at least 10 samples and drop ICU samples.
# One boolean mask plus an explicit `.copy()` gives an independent frame, so the
# column assignments made later in the notebook ('Prediction', 'Error', ...) never
# operate on a filtered-frame copy that pandas tracks for chained assignment.
# Masking (rather than dropping by index label) also cannot remove non-ICU rows
# that happen to share an ID with an ICU row.
status_count = data_full['Status'].value_counts()
statuses_passed = status_count[status_count >= 10].index.values.tolist()
mask_keep = data_full['Status'].isin(statuses_passed) & (data_full['Status'] != 'ICU')
data_full = data_full.loc[mask_keep].copy()
# data_full.to_excel(f"{path}/data_filtered.xlsx", index_label='ID')

status_count = data_full['Status'].value_counts()

# Selected features; the model consumes the '<feature>_log' columns
# (presumably log-transformed values — confirm against data.xlsx).
feats = pd.read_excel(f"{path}/feats.xlsx", index_col=0).index.values.tolist()
feats = [f"{f}_log" for f in feats]

gse_preproc = pd.read_excel(f"{path_epi}/preproc.xlsx", index_col=0)

# ICD-11 group metadata and colours: one distinct pastel colour per chapter,
# a per-chapter colormap from a light tint (20% opacity over white) to the full
# colour, and a lavender -> dimgray colormap for totals.
df_groups = pd.read_excel(f"{path_epi}/groups.xlsx", index_col=0)
icd_chpts = np.sort(df_groups['ICD-11 chapter'].unique())
icd_codes = np.sort(df_groups['ICD-11 code'].unique())
colors = distinctipy.get_colors(len(icd_chpts), [mcolors.hex2color(mcolors.CSS4_COLORS['black']), mcolors.hex2color(mcolors.CSS4_COLORS['white'])], rng=1337, pastel_factor=0.5)
colors_icd_chpts = {icd_chpt: colors[icd_chpt_id] for icd_chpt_id, icd_chpt in enumerate(icd_chpts)}
colormaps_icd_chpts = {
    icd_chpt: LinearSegmentedColormap.from_list(
        name=f"ICD-11 Chapter {icd_chpt} cmap",
        colors=[make_rgb_transparent(colors_icd_chpts[icd_chpt], (1, 1, 1), 0.2), colors_icd_chpts[icd_chpt]], N=256
    )
    for icd_chpt in icd_chpts
}
colormap_total = LinearSegmentedColormap.from_list(
    name="ICD-11 Total cmap",
    colors=[
        mcolors.hex2color(mcolors.CSS4_COLORS['lavender']),
        mcolors.hex2color(mcolors.CSS4_COLORS['dimgray'])],
    N=256,
)

# Reference epigenetic clocks results; a few clocks are excluded from the comparison.
clocks_tests = pd.read_excel(f"{path_epi}/clocks_tests.xlsx", index_col="Clock Name")
clocks_tests['Year'] = clocks_tests['Year'].astype(str)
clocks_tests.drop(index=['Knight', 'LeeControl', 'LeeRefinedRobust', 'LeeRobust', 'PedBE', 'RepliTali', 'ENCen100'], inplace=True)
clocks_tests_raw = pd.read_excel(f"{path_epi}/clocks_tests_raw_after_correction.xlsx", index_col="Clock Name")

# Load stratification

In [3]:
# Pre-computed nested splits: samples[split_id] holds a 'test' id list plus
# per-fold 'trains' / 'validations' id lists (as indexed below).
# NOTE: pickle.load can execute arbitrary code — only load pickles produced by this project.
with open(f"{path}/samples_tst({tst_random_state}_{tst_n_splits}_{tst_n_repeats})_val({val_random_state}_{val_n_splits}_{val_n_repeats}).pickle", 'rb') as handle:
    samples = pickle.load(handle)

# Sanity check: train, validation and test ids must be pairwise disjoint for
# every (split, fold) pair; any overlap is reported. The test set depends only
# on the split, so it is built once per split rather than once per fold.
for split_id in range(tst_n_splits * tst_n_repeats):
    test_samples = samples[split_id]['test']
    test_set = set(test_samples)
    for fold_id in range(val_n_splits * val_n_repeats):
        train_samples = samples[split_id]['trains'][fold_id]
        validation_samples = samples[split_id]['validations'][fold_id]
        train_set = set(train_samples)
        validation_set = set(validation_samples)

        intxns = {
            'train_validation': train_set & validation_set,
            'validation_test': validation_set & test_set,
            'train_test': train_set & test_set,
        }

        for intxn_name, intxn_samples in intxns.items():
            if intxn_samples:
                print(f"Non-zero {intxn_name} intersection ({len(intxn_samples)}) for {split_id} Split and {fold_id} Fold!")

# Train, Validation, Test selection

In [4]:
# Materialise train / validation / test frames (selected features + target)
# for the chosen outer split `tst_split_id` and inner fold `val_fold_id`.
split_dict = samples[tst_split_id]
model_cols = [*feats, "Age"]

test = data_full.loc[split_dict['test'], model_cols]
train = data_full.loc[split_dict['trains'][val_fold_id], model_cols]
validation = data_full.loc[split_dict['validations'][val_fold_id], model_cols]

# Optuna training

## Models setup

In [None]:
# Seed for the TPE sampler and the trainer (other seeds tried: 42 451 1984 1899 1408).
seed_target = 1337

# Hyper-parameter search plan per model: the pytorch_tabular config class plus
# Optuna TPE settings (number of trials, seed, random start-up trials and
# expected-improvement candidates). Uncomment entries to search more models.
models_runs = dict(
    GANDALF=dict(
        config=GANDALFConfig,
        n_trials=1024,
        seed=seed_target,
        n_startup_trials=256,
        n_ei_candidates=16,
    ),
    # FTTransformer=dict(config=FTTransformerConfig, n_trials=512, seed=1337, n_startup_trials=256, n_ei_candidates=16),
    # DANet=dict(config=DANetConfig, n_trials=256, seed=1337, n_startup_trials=64, n_ei_candidates=16),
    # CategoryEmbeddingModel=dict(config=CategoryEmbeddingModelConfig, n_trials=256, seed=1337, n_startup_trials=64, n_ei_candidates=16),
    # TabNetModel=dict(config=TabNetModelConfig, n_trials=256, seed=1337, n_startup_trials=64, n_ei_candidates=16),
)

## Training

In [None]:
# Learning-rate finder settings forwarded to `train_hyper_opt`; identical for every model.
lr_find_min_lr = 1e-8
lr_find_max_lr = 10
lr_find_num_training = 256
lr_find_mode = "exponential"
lr_find_early_stop_threshold = 8.0

# Optuna directions: one per (data part, metric) pair, parts in the outer loop.
# NOTE(review): 'test' is among the optimised parts, so the held-out test split
# takes part in hyper-parameter selection — confirm this is intended.
opt_parts = ['test', 'validation']
opt_metrics = [('mean_absolute_error', 'minimize')]
# opt_metrics = [('mean_absolute_error', 'minimize'), ('pearson_corrcoef', 'maximize')]
# opt_metrics = [('pearson_corrcoef', 'maximize')]
opt_directions = [str(direction) for _part in opt_parts for _metric, direction in opt_metrics]

# Columns colour-graded (reversed red-yellow-green) in the exported Excel tables.
loss_cols_gradient = [
    "train_loss",
    "validation_loss",
    "validation_test_mean_loss",
    "train_validation_test_mean_loss",
    "test_loss",
    "time_taken",
    "time_taken_per_epoch",
]

dfs_models = []

for model_name, model_run in models_runs.items():

    model_config_name = model_run['config']
    n_trials = model_run['n_trials']
    seed = model_run['seed']
    n_startup_trials = model_run['n_startup_trials']
    n_ei_candidates = model_run['n_ei_candidates']

    # Configs are re-read from YAML for every model, then overridden for this experiment.
    data_config = read_parse_config(f"{path_configs}/DataConfig.yaml", DataConfig)
    data_config['target'] = ['Age']
    data_config['continuous_cols'] = feats
    trainer_config = read_parse_config(f"{path_configs}/TrainerConfig.yaml", TrainerConfig)
    trainer_config['checkpoints_path'] = f"{path}/pytorch_tabular"
    optimizer_config = read_parse_config(f"{path_configs}/OptimizerConfig.yaml", OptimizerConfig)

    trainer_config['seed'] = seed
    trainer_config['checkpoints'] = 'valid_loss'
    trainer_config['load_best'] = True
    trainer_config['auto_lr_find'] = False

    # The default TabularModel is used here only to build the datamodule handed to every trial.
    model_config_default = read_parse_config(f"{path_configs}/models/{model_name}Config.yaml", model_config_name)
    tabular_model_default = TabularModel(
        data_config=data_config,
        model_config=model_config_default,
        optimizer_config=optimizer_config,
        trainer_config=trainer_config,
        verbose=False,
    )
    datamodule = tabular_model_default.prepare_dataloader(train=train, validation=validation, seed=seed)

    # Filled by `train_hyper_opt` (one record per trial, judging by how it is tabulated below).
    trials_results = []

    study = optuna.create_study(
        study_name=model_name,
        sampler=optuna.samplers.TPESampler(
            n_startup_trials=n_startup_trials,
            n_ei_candidates=n_ei_candidates,
            seed=seed,
        ),
        directions=opt_directions
    )
    # The lambda is consumed synchronously by `optimize`, so capturing the loop
    # variables of this iteration is safe.
    study.optimize(
        func=lambda trial: train_hyper_opt(
            trial=trial,
            trials_results=trials_results,
            opt_metrics=opt_metrics,
            opt_parts=opt_parts,
            model_config_default=model_config_default,
            data_config_default=data_config,
            optimizer_config_default=optimizer_config,
            trainer_config_default=trainer_config,
            experiment_config_default=None,
            train=train,
            validation=validation,
            test=test,
            datamodule=datamodule,
            min_lr=lr_find_min_lr,
            max_lr=lr_find_max_lr,
            num_training=lr_find_num_training,
            mode=lr_find_mode,
            early_stop_threshold=lr_find_early_stop_threshold
        ),
        n_trials=n_trials,
        show_progress_bar=True
    )

    fn_trials = (
        f"model({model_name})_"
        f"trials({n_trials}_{seed}_{n_startup_trials}_{n_ei_candidates})_"
        f"tst({tst_split_id})_"
        f"val({val_fold_id})"
    )

    df_trials = pd.DataFrame(trials_results)
    df_trials['split_id'] = tst_split_id
    df_trials['fold_id'] = val_fold_id
    # Flag trials whose train loss is above the test or the validation loss.
    df_trials["train_more"] = (df_trials["train_loss"] > df_trials["test_loss"]) | (
            df_trials["train_loss"] > df_trials["validation_loss"])
    df_trials["validation_test_mean_loss"] = (df_trials["validation_loss"] + df_trials["test_loss"]) / 2.0
    df_trials["train_validation_test_mean_loss"] = (df_trials["train_loss"] + df_trials["validation_loss"] + df_trials["test_loss"]) / 3.0
    df_trials.style.background_gradient(
        subset=loss_cols_gradient, cmap="RdYlGn_r"
    ).to_excel(f"{trainer_config['checkpoints_path']}/{fn_trials}.xlsx")

    dfs_models.append(df_trials)

# Combined table of all trials of all models; 'Selected' is an initially-zero column at the front.
df_models = pd.concat(dfs_models, ignore_index=True)
df_models.insert(0, 'Selected', 0)
fn = (
    f"models_"
    f"tst({tst_split_id})_"
    f"val({val_fold_id})"
)
df_models.style.background_gradient(
    subset=loss_cols_gradient, cmap="RdYlGn_r"
).to_excel(f"{path}/pytorch_tabular/{fn}.xlsx")


# Perform tests on models

In [None]:
%%capture

path_sweep = f"{path}/pytorch_tabular"

fn_sweep = (
    f"models_"
    f"tst({tst_split_id})_"
    f"val({val_fold_id})"
)

df_sweeps = pd.read_excel(f"{path_sweep}/{fn_sweep}.xlsx", index_col=0)

# "Passed" summary columns inserted at a fixed position of the sweep table:
# the total first, then one column per ICD-11 chapter.
# NOTE(review): position 15 assumes the current sweep-table column layout.
df_sweeps.insert(15, 'Passed\nICD-11\nTotal', None)
for ins_pos, icd_chpt in enumerate(icd_chpts, start=16):
    df_sweeps.insert(ins_pos, f'Passed\nICD-11\nChapter {icd_chpt}', None)

# Group comparisons must use the original status labels. The 'Best models
# processing' cell rewrites data_full['Status'] for plotting and keeps the
# originals in 'Status Origin', so prefer that column when it exists.
status_col = 'Status Origin' if 'Status Origin' in data_full.columns else 'Status'

# 1) Per model: predict all samples, then for every (section, group pair) run a
#    two-sided Mann-Whitney U test on prediction errors and store the mean error
#    of each group. Results are collected per model and turned into a DataFrame
#    once, instead of enlarging it cell by cell with `.at` (one new column block
#    per new key, which fragments the frame).
models_tests_rows = []
for model_id, model_row in (pbar := tqdm(df_sweeps.iterrows())):
    pbar.set_description(f"Processing {model_id}")

    checkpoint = pathlib.Path(df_sweeps.at[model_id, 'checkpoint'])
    model_dir = str(checkpoint.parent).replace('\\', '/') + '/' + checkpoint.stem

    model = TabularModel.load_model(model_dir)

    data_full['Prediction'] = model.predict(data_full)
    data_full['Error'] = data_full['Prediction'] - data_full['Age']

    model_tests = {}
    for section_id, section_row in df_groups.iterrows():
        section_statuses = ast.literal_eval(section_row['Statuses'])
        section_groups = ast.literal_eval(section_row['Groups'])
        df_section = data_full.loc[(data_full['GSE'] == section_row['GSE']) & (data_full[status_col].isin(section_statuses)), [status_col, 'Error']]

        for section_group in section_groups:
            errors_0 = df_section.loc[df_section[status_col] == section_group[0], 'Error']
            errors_1 = df_section.loc[df_section[status_col] == section_group[1], 'Error']

            _, pval = mannwhitneyu(
                errors_0.values,
                errors_1.values,
                alternative="two-sided",
            )

            model_tests[f"pval\n{section_id}\n{section_group}"] = pval
            model_tests[f"bias_0\n{section_id}\n{section_group}"] = np.mean(errors_0)
            model_tests[f"bias_1\n{section_id}\n{section_group}"] = np.mean(errors_1)
    models_tests_rows.append(model_tests)
# List-of-dicts construction keeps the columns in first-seen key order.
models_tests_raw = pd.DataFrame(models_tests_rows, index=df_sweeps.index)
models_tests_raw.to_excel(f"{path_sweep}/models_tests_raw_before_correction.xlsx", index_label='Model ID')

# 2) Benjamini-Hochberg FDR correction over all comparisons of one model (row-wise);
#    corrected p-values overwrite the raw ones.
pvals_cols = [col for col in models_tests_raw.columns if 'pval' in col]
for model_id, model_row in (pbar := tqdm(df_sweeps.iterrows())):
    _, models_tests_raw.loc[model_id, pvals_cols], _, _ = multipletests(models_tests_raw.loc[model_id, pvals_cols], 0.05, method='fdr_bh')
models_tests_raw.to_excel(f"{path_sweep}/models_tests_raw_after_correction.xlsx", index_label='Model ID')

# 3) A comparison is "passed" when its corrected p-value is < 0.05 and the mean
#    error moves in the expected direction ('Increasing': bias_1 > bias_0,
#    'Decreasing': bias_1 < bias_0). Counts are stored per chapter and in total.
for model_id, model_row in (pbar := tqdm(df_sweeps.iterrows())):
    pbar.set_description(f"Processing {model_id}")
    passed_icd = {icd_chpt: 0 for icd_chpt in icd_chpts}
    for icd_chpt in icd_chpts:
        df_chpt = df_groups[df_groups['ICD-11 chapter'] == icd_chpt]
        for section_id, section_row in df_chpt.iterrows():
            section_groups = ast.literal_eval(section_row['Groups'])
            section_directions = ast.literal_eval(section_row['Directions'])

            for section_group_id, section_group in enumerate(section_groups):

                pval = models_tests_raw.at[model_id, f"pval\n{section_id}\n{section_group}"]
                bias_0 = models_tests_raw.at[model_id, f"bias_0\n{section_id}\n{section_group}"]
                bias_1 = models_tests_raw.at[model_id, f"bias_1\n{section_id}\n{section_group}"]

                group_direction = section_directions[section_group_id]

                if pval < 0.05:
                    if group_direction == 'Increasing' and bias_1 > bias_0:
                        passed_icd[icd_chpt] += 1
                    elif group_direction == 'Decreasing' and bias_1 < bias_0:
                        passed_icd[icd_chpt] += 1

        df_sweeps.at[model_id, f'Passed\nICD-11\nChapter {icd_chpt}'] = passed_icd[icd_chpt]

    df_sweeps.at[model_id, 'Passed\nICD-11\nTotal'] = sum(passed_icd.values())

df_sweeps.to_excel(f"{path_sweep}/models.xlsx", index_label='Model ID')

# Best models processing

In [None]:
# Control samples per GSE (GSEs ordered by number of controls) and one distinct colour per GSE.
gse_controls_count = data_full.loc[data_full['Status'] == 'Control', 'GSE'].value_counts()
gses_controls = gse_controls_count.index.values
gse_controls_ids = {gse: data_full.index[(data_full['Status'] == 'Control') & (data_full['GSE'] == gse)].values for gse in gses_controls}
colors = distinctipy.get_colors(len(gses_controls), [mcolors.hex2color(mcolors.CSS4_COLORS['white']), mcolors.hex2color(mcolors.CSS4_COLORS['black'])], rng=1337)
colors_gse = dict(zip(gses_controls, colors))

# Original status labels are kept in 'Status Origin' (written only once, so
# re-running this cell does not overwrite them with already-renamed labels);
# 'Status' receives plot-friendly labels with spaces/hyphens turned into line breaks.
if 'Status Origin' not in data_full.columns:
    data_full['Status Origin'] = data_full['Status']
statuses_rename = {x: x.replace(' ', '\n').replace('-', '\n') for x in data_full['Status Origin'].value_counts().index.values}
data_full['Status'] = data_full['Status Origin'].replace(statuses_rename)
status_count = data_full['Status'].value_counts()
statuses = status_count.index.values

# 'Control' is fixed to dodgerblue; every other status gets its own distinct
# colour (avoiding dodgerblue, white and black). Colours are paired with the
# non-control statuses directly, so the mapping stays one-to-one wherever
# 'Control' falls in the frequency ranking.
statuses_non_control = [status for status in statuses if status != 'Control']
colors = distinctipy.get_colors(len(statuses_non_control), [mcolors.hex2color(mcolors.CSS4_COLORS['dodgerblue']), mcolors.hex2color(mcolors.CSS4_COLORS['white']), mcolors.hex2color(mcolors.CSS4_COLORS['black'])], rng=1337)
colors_status = {'Control': mcolors.hex2color(mcolors.CSS4_COLORS['dodgerblue'])}
colors_status.update(zip(statuses_non_control, colors))

# Violin-plot mosaic: one row per 'Row on violin plot' value; each df_groups
# section occupies as many consecutive panels as it lists statuses.
mosaic_violins = []
mosaic_rows = np.sort(df_groups['Row on violin plot'].unique())
for mosaic_row in mosaic_rows:
    df_mosaic_row = df_groups[df_groups['Row on violin plot'] == mosaic_row].sort_values(by=['ICD-11 chapter', 'ICD-11 code', 'GSE'], ascending=[True, True, True])
    mosaic_row_labels = []
    for plot_id, plot_row in df_mosaic_row.iterrows():
        n_violins = len(ast.literal_eval(plot_row['Statuses']))
        mosaic_row_labels += [plot_id] * n_violins
    mosaic_violins.append(mosaic_row_labels)

# Pad shorter rows with a per-row placeholder label so the mosaic is rectangular;
# the placeholder labels are collected in `violons_empty_panels`.
max_mosaic_row = max(len(x) for x in mosaic_violins)
violons_empty_panels = set()
for row_id, row in enumerate(mosaic_violins):
    n_missing = max_mosaic_row - len(row)
    if n_missing > 0:
        empty_label = f'Empty row {row_id}'
        violons_empty_panels.add(empty_label)
        row.extend([empty_label] * n_missing)

In [None]:
# Column names of the ICD-11 "passed" counts, ordered chapter by chapter:
# each chapter column is followed by the columns of the codes it contains.
icd_cols = []
icd_code_chpt = {}  # code column name -> chapter column name
icd_codes_for_chpts = {}  # chapter -> sorted ICD-11 codes present in df_groups
for icd_chpt in icd_chpts:
    chpt_col = f'Passed\nICD-11\nChapter {icd_chpt}'
    icd_codes_for_chpts[icd_chpt] = np.sort(df_groups.loc[df_groups['ICD-11 chapter'] == icd_chpt, 'ICD-11 code'].unique())
    code_cols = [f"Passed\nICD-11\nCode {x}" for x in icd_codes_for_chpts[icd_chpt]]
    icd_code_chpt.update(dict.fromkeys(code_cols, chpt_col))
    icd_cols.append(chpt_col)
    icd_cols += code_cols
icd_cols_max = [f"Max\n{x}" for x in icd_cols]


def _passed_bar_kw(cmap, max_passed):
    """Return the plottable `bar` keyword arguments shared by every 'Passed' column.

    Args:
        cmap: Colormap used to fill the bar.
        max_passed: Denominator shown in the bar label (taken from the 'Max'
            columns of the 'Hannum' row of clocks_tests by the callers below).
    """
    return {
        "cmap": cmap,
        "plot_bg_bar": True,
        "annotate": True,
        "height": 0.95,
        "lw": 0.5,
        "formatter": form_bar(max_passed),
    }


# Columns shared by both table layouts.
col_names_common = ["Year", "Total Rho", "Total MAE", "Passed\nICD-11\nTotal"]
col_defs_common = [
    ColumnDefinition(
        name="Clock Name",
        title="Clocks",
        textprops={"ha": "right", "weight": "bold"},
        width=2.25,
    ),
    ColumnDefinition(
        name="Year",
        title="Year",
        textprops={"ha": "center"},
        width=1.0,
        border="left"
    ),
    ColumnDefinition(
        name="Total Rho",
        title="Total\n" + r"Pearson $\rho$",
        textprops={"ha": "center"},
        formatter="{:.3f}",
        cmap=normed_cmap(clocks_tests["Total Rho"].dropna(), cmap=matplotlib.cm.Greens, num_stds=2.5),
        width=1.0,
        border="left"
    ),
    ColumnDefinition(
        name="Total MAE",
        title="Total\nMAE",
        textprops={"ha": "center"},
        formatter="{:.3f}",
        cmap=normed_cmap(clocks_tests["Total MAE"].dropna(), cmap=matplotlib.cm.Reds, num_stds=2.5),
        width=1.0,
    ),
    ColumnDefinition(
        name="Passed\nICD-11\nTotal",
        title="Passed\nICD-11",
        width=1.5,
        border="left",
        textprops={"ha": "center"},
        plot_fn=bar,
        plot_kw=_passed_bar_kw(colormap_total, clocks_tests.at['Hannum', 'Max\nPassed\nICD-11\nTotal']),
    ),
]

# Layout 1: common columns + one bar column per ICD-11 chapter.
icd_chpt_col_defs = copy.deepcopy(col_defs_common)
icd_chpt_col_names = copy.deepcopy(col_names_common)
for icd_chpt in icd_chpts:
    # Separator line before the first chapter column (previously hard-coded to chapter 1).
    border = 'left' if icd_chpt == icd_chpts[0] else None
    max_passed = clocks_tests.at['Hannum', f'Max\nPassed\nICD-11\nChapter {icd_chpt}']
    icd_chpt_col_names.append(f'Passed\nICD-11\nChapter {icd_chpt}')
    col_def = ColumnDefinition(
        name=f'Passed\nICD-11\nChapter {icd_chpt}',
        title=f'Chapter {icd_chpt}',
        width=1.0,
        plot_fn=bar,
        border=border,
        textprops={"ha": "center"},
        plot_kw=_passed_bar_kw(colormaps_icd_chpts[icd_chpt], max_passed),
    )
    icd_chpt_col_defs.append(col_def)

# Layout 2: common columns + one bar column per ICD-11 code, grouped under a chapter header.
icd_code_col_defs = copy.deepcopy(col_defs_common)
icd_code_col_names = copy.deepcopy(col_names_common)
for icd_chpt in icd_chpts:
    # Chapters 11 and 12 get a number-only header (presumably their code groups
    # are too narrow for the "Chapter N" label).
    if icd_chpt in [11, 12]:
        group = fr"$\mathbf{{{icd_chpt}}}$"
    else:
        group = fr"$\mathbf{{Chapter\,{icd_chpt}}}$"

    for code_in_chpt in icd_codes_for_chpts[icd_chpt]:
        max_passed = clocks_tests.at['Hannum', f'Max\nPassed\nICD-11\nCode {code_in_chpt}']
        icd_code_col_names.append(f'Passed\nICD-11\nCode {code_in_chpt}')
        col_def = ColumnDefinition(
            name=f'Passed\nICD-11\nCode {code_in_chpt}',
            title=f'{code_in_chpt}',
            width=1.0,
            plot_fn=bar,
            border=None,
            textprops={"ha": "center"},
            plot_kw=_passed_bar_kw(colormaps_icd_chpts[icd_chpt], max_passed),
            group=group,
        )
        icd_code_col_defs.append(col_def)

In [None]:
# Attribution settings referenced by the `model.explain` SHAP code later in this cell.
explain_method = "GradientShap"
explain_baselines = "b|100000"

models_type = 'DANet' # 'FTTransformer' 'GANDALF' 'DANet'
# Hand-picked candidate model ids (row ids of the `models_type` sweep table);
# commented ids are earlier choices kept for reference.
models_ids = [
    # 39,
    # 43,
    # 25,
    # 0,

    #231,
    194,
    #22,
]
models_ids = sorted(set(models_ids))

df_sweeps = pd.read_excel(f"{path}/pytorch_tabular/{models_type}/models.xlsx", index_col=0)
path_models = f"{path}/pytorch_tabular/{models_type}/candidates"
pathlib.Path(path_models).mkdir(parents=True, exist_ok=True)
df_sweeps.loc[models_ids, :].style.background_gradient(
    subset=[
        "train_loss",
        "validation_loss",
        "test_loss",
        "time_taken",
        "time_taken_per_epoch"
    ], cmap="RdYlGn_r"
).to_excel(f"{path_models}/selected.xlsx")

# Accumulators filled by the per-model loop below:
#   df_models_metrics - one row per candidate model
#   df_models_check   - sweep-table metrics next to the recomputed ones
#   df_gses_models_metrics['<metric> <part>'] - per-GSE table (control samples),
#       'Count' = number of control samples in that GSE, plus one column per model
df_models_metrics = pd.DataFrame(index=models_ids)
df_models_check = pd.DataFrame(index=models_ids)
df_gses_models_metrics = {}
for md in (f"{metric} {part_name}" for metric, part_name in itertools.product(['MAE', 'Rho', 'Bias'], ['Train', 'Validation', 'Test', 'Total'])):
    df_gses_models_metrics[md] = pd.DataFrame(index=gses_controls, columns=['Count'] + list(models_ids))
    df_gses_models_metrics[md].loc[gses_controls, 'Count'] = gse_controls_count[gses_controls]

for model_id in models_ids:
    print(model_id)

    split_id = df_sweeps.at[model_id, 'split_id']
    fold_id = df_sweeps.at[model_id, 'fold_id']
    split_dict = samples[split_id]
    ids_test = split_dict['test']
    ids_train = split_dict['trains'][fold_id]
    ids_validation = split_dict['validations'][fold_id]
    ids_total = np.concatenate([ids_train, ids_validation, ids_test])
    ids_dict = {
        'Test': ids_test,
        'Train': ids_train,
        'Validation': ids_validation,
        'Total': ids_total
    }

    model_dir = str(pathlib.Path(df_sweeps.at[model_id, 'checkpoint']).parent).replace('D:\\', 'E:\\').replace('\\', '/') + f'/{models_type}/' + pathlib.Path(df_sweeps.at[model_id, 'checkpoint']).stem
    model = TabularModel.load_model(model_dir)
    pathlib.Path(f"{path_models}/{model_id}").mkdir(parents=True, exist_ok=True)
    shutil.copytree(model_dir, f"{path_models}/{model_id}", dirs_exist_ok=True)
    
    def predict_func(X):
        X_df = pd.DataFrame(data=X, columns=feats)
        y = model.predict(X_df)['Age_prediction'].values
        return y

    data_full['Group'] = ''
    data_full.loc[ids_train, 'Group'] = 'Train'
    data_full.loc[ids_validation, 'Group'] = 'Validation'
    data_full.loc[ids_test, 'Group'] = 'Test'
    data_full['Prediction'] = model.predict(data_full)
    data_full['Error'] = data_full['Prediction'] - data_full['Age']
    data_full[['GPL', 'GSE', 'Age', 'Sex', 'Status', 'Group', 'Prediction', 'Error']].to_excel(f"{path_models}/{model_id}/data.xlsx")

    df_metrics = pd.DataFrame(
        index=list(ids_dict.keys()),
        columns=['MAE', 'Rho', 'Bias']
    )

    for part, ids_part in ids_dict.items():

        pred = torch.from_numpy(data_full.loc[ids_part, 'Prediction'].values)
        real = torch.from_numpy(data_full.loc[ids_part, 'Age'].values)

        df_metrics.at[part, 'MAE'] = mean_absolute_error(pred, real).numpy()
        df_metrics.at[part, 'Rho'] = pearson_corrcoef(pred, real).numpy()
        df_metrics.at[part, 'Bias'] = np.mean(data_full.loc[ids_part, 'Error'].values)

        df_models_metrics.at[model_id, f'MAE\n{part}'] = df_metrics.at[part, 'MAE']
        df_models_metrics.at[model_id, f'Rho\n{part}'] = df_metrics.at[part, 'Rho']
        df_models_metrics.at[model_id, f'Bias\n{part}'] = df_metrics.at[part, 'Bias']

        if part != 'Total':
            df_models_check.at[model_id, f'{part} MAE Before'] = df_sweeps.at[model_id, f'{part.lower()}_mean_absolute_error']
            df_models_check.at[model_id, f'{part} Rho Before'] = df_sweeps.at[model_id, f'{part.lower()}_pearson_corrcoef']
            df_models_check.at[model_id, f'{part} MAE After'] = df_metrics.at[part, 'MAE']
            df_models_check.at[model_id, f'{part} Rho After'] = df_metrics.at[part, 'Rho']

    df_metrics.to_excel(f"{path_models}/{model_id}/metrics.xlsx", index_label="Metrics")

    for gse, ids_gse in gse_controls_ids.items():
        for part, ids_part in ids_dict.items():
            ids_intxn = list(set.intersection(set(ids_gse), set(ids_part)))
            if len(ids_intxn) > 0:
                pred = torch.from_numpy(data_full.loc[ids_intxn, 'Prediction'].values)
                real = torch.from_numpy(data_full.loc[ids_intxn, 'Age'].values)
                df_gses_models_metrics[f'MAE {part}'].at[gse, model_id] = mean_absolute_error(pred, real).numpy()
                df_gses_models_metrics[f'Rho {part}'].at[gse, model_id] = pearson_corrcoef(pred, real).numpy()
                df_gses_models_metrics[f'Bias {part}'].at[gse, model_id] = np.mean(data_full.loc[ids_intxn, 'Error'].values)
                
    # # SHAP age acceleration table

    # df_imm_shap_raw = pd.DataFrame(index=feats)
    # for section_id, section_row in df_groups.iterrows():
    #     section_statuses = ast.literal_eval(section_row['Statuses'])
    #     section_groups = ast.literal_eval(section_row['Groups'])
    #     section_directions = ast.literal_eval(section_row['Directions'])
    #     df_section = data_full.loc[(data_full['GSE'] == section_row['GSE']) & (data_full['Status Origin'].isin(section_statuses)), :]

    #     for section_group_id, section_group in enumerate(section_groups):
    #         group_direction = section_directions[section_group_id]

    #         df_section_group_0 = df_section.loc[df_section["Status Origin"] == section_group[0], :]
    #         explainer_0 = shap.SamplingExplainer(predict_func, df_section_group_0.loc[:, feats].values)
    #         shap_values_0 = explainer_0.shap_values(df_section_group_0.loc[:, feats].values)
    #         section_group_expl_0 = pd.DataFrame(data=shap_values_0, index=df_section_group_0.index, columns=feats)
    #         shap_values_mean_0 = section_group_expl_0.mean()

    #         df_section_group_1 = df_section.loc[df_section["Status Origin"] == section_group[1], :]
    #         explainer_1 = shap.SamplingExplainer(predict_func, df_section_group_1.loc[:, feats].values)
    #         shap_values_1 = explainer_1.shap_values(df_section_group_1.loc[:, feats].values)
    #         section_group_expl_1 = pd.DataFrame(data=shap_values_1, index=df_section_group_1.index, columns=feats)
    #         shap_values_mean_1 = section_group_expl_1.mean()

    #         if group_direction == 'Increasing':
    #             df_imm_shap_raw.loc[feats, f"shap_mean_diff\n{section_id}\n{section_group}"] = shap_values_mean_1[feats] - shap_values_mean_0[feats]
    #         else:
    #             df_imm_shap_raw.loc[feats, f"shap_mean_diff\n{section_id}\n{section_group}"] = shap_values_mean_0[feats] - shap_values_mean_1[feats]
    # df_imm_shap_raw.to_excel(f"{path_models}/{model_id}/shap_mean_diff_raw.xlsx", index_label='Immunomarker')  

    # df_imm_shap = pd.DataFrame(index=feats)
    # for icd_chpt in icd_chpts:
    #     df_imm_shap.loc[feats, f'Chapter {icd_chpt}'] = 0.0
    #     df_chpt = df_groups[df_groups['ICD-11 chapter'] == icd_chpt]
    #     for section_id, section_row in df_chpt.iterrows():
    #         section_groups = ast.literal_eval(section_row['Groups'])
    #         for section_group_id, section_group in enumerate(section_groups):
    #             df_imm_shap.loc[feats, f'Chapter {icd_chpt}'] += df_imm_shap_raw.loc[feats, f"shap_mean_diff\n{section_id}\n{section_group}"]   
    # for icd_code in icd_codes:
    #     df_imm_shap.loc[feats, f'Code {icd_code}'] = 0.0
    #     df_code = df_groups[df_groups['ICD-11 code'] == icd_code]
    #     for section_id, section_row in df_code.iterrows():
    #         section_groups = ast.literal_eval(section_row['Groups'])
    #         for section_group_id, section_group in enumerate(section_groups):
    #             df_imm_shap.loc[feats, f'Code {icd_code}'] += df_imm_shap_raw.loc[feats, f"shap_mean_diff\n{section_id}\n{section_group}"]
    # df_imm_shap.to_excel(f"{path_models}/{model_id}/shap_mean_diff.xlsx", index_label='Immunomarker')  

    # df_imm_shap_table = pd.read_excel(f"{path_models}/{model_id}/shap_mean_diff.xlsx", index_col=0)
    # df_imm_shap_table = df_imm_shap_table[[f'Chapter {icd_chpt}' for icd_chpt in icd_chpts]]
    # df_imm_shap_table.index = df_imm_shap_table.index.str.replace('_log', '')
    
    # df_fig = df_imm_shap_table.astype(float)
    # x_ticks_colors = {f'Chapter {icd_chpt}': colors_icd_chpts[icd_chpt] for icd_chpt in icd_chpts}
    # sns.set_theme(style='ticks')
    # clustermap = sns.clustermap(
    #     df_fig,
    #     annot=True,
    #     col_cluster=False,
    #     row_cluster=True,
    #     fmt=".2f",
    #     center=0.0,
    #     cmap='seismic',
    #     linewidth=0.1,
    #     linecolor='black',
    #     tree_kws=dict(linewidths=1.5),
    #     figsize=(16, 12),
    #     cbar_kws={'orientation': 'horizontal'}
    # )
    # clustermap.ax_heatmap.set_xlabel('')
    # clustermap.ax_heatmap.set_ylabel('')
    # for spine in clustermap.ax_cbar.spines.values():
    #     spine.set(visible=True, lw=0.25, edgecolor="black")
    # clustermap.ax_heatmap.set_xticklabels(clustermap.ax_heatmap.get_xmajorticklabels(), rotation=0, path_effects=[pe.withStroke(linewidth=0.5, foreground="black")])
    # for tick_label in clustermap.ax_heatmap.get_xticklabels():
    #     tick_label.set_color(x_ticks_colors[tick_label.get_text()])
    # clustermap_pos = clustermap.ax_heatmap.get_position()
    # clustermap.ax_cbar.set_position([clustermap_pos.x0, clustermap_pos.y1 + 0.05, clustermap_pos.width, 0.03])
    # clustermap.ax_cbar.set_title("XAI age acceleration difference", fontsize='large')
    # clustermap.ax_cbar.tick_params(labelsize='large')
    # for spine in clustermap.ax_cbar.spines:
    #     clustermap.ax_cbar.spines[spine].set_linewidth(1)
    # plt.savefig(f"{path_models}/{model_id}/shap_mean_diff.png", bbox_inches='tight', dpi=200)
    # plt.savefig(f"{path_models}/{model_id}/shap_mean_diff.pdf", bbox_inches='tight')
    # plt.close(clustermap.figure)

    # # SHAP age acceleration table

    # df_imm_shap_raw = pd.DataFrame(index=feats)
    # for section_id, section_row in df_groups.iterrows():
    #     section_statuses = ast.literal_eval(section_row['Statuses'])
    #     section_groups = ast.literal_eval(section_row['Groups'])
    #     section_directions = ast.literal_eval(section_row['Directions'])
    #     df_section = data_full.loc[(data_full['GSE'] == section_row['GSE']) & (data_full['Status Origin'].isin(section_statuses)), :]

    #     for section_group_id, section_group in enumerate(section_groups):
    #         group_direction = section_directions[section_group_id]

    #         df_section_group_0 = df_section.loc[df_section["Status Origin"] == section_group[0], :]
    #         section_group_expl_0 = model.explain(df_section_group_0, method=explain_method, baselines=explain_baselines)
    #         shap_values_mean_0 = section_group_expl_0.mean()

    #         df_section_group_1 = df_section.loc[df_section["Status Origin"] == section_group[1], :]
    #         section_group_expl_1 = model.explain(df_section_group_1, method=explain_method, baselines=explain_baselines)
    #         shap_values_mean_1 = section_group_expl_1.mean()

    #         if group_direction == 'Increasing':
    #             df_imm_shap_raw.loc[feats, f"shap_mean_diff\n{section_id}\n{section_group}"] = shap_values_mean_1[feats] - shap_values_mean_0[feats]
    #         else:
    #             df_imm_shap_raw.loc[feats, f"shap_mean_diff\n{section_id}\n{section_group}"] = shap_values_mean_0[feats] - shap_values_mean_1[feats]
    # df_imm_shap_raw.to_excel(f"{path_models}/{model_id}/shap_mean_diff_raw.xlsx", index_label='Immunomarker')  

    # df_imm_shap = pd.DataFrame(index=feats)
    # for icd_chpt in icd_chpts:
    #     df_imm_shap.loc[feats, f'Chapter {icd_chpt}'] = 0.0
    #     df_chpt = df_groups[df_groups['ICD-11 chapter'] == icd_chpt]
    #     for section_id, section_row in df_chpt.iterrows():
    #         section_groups = ast.literal_eval(section_row['Groups'])
    #         for section_group_id, section_group in enumerate(section_groups):
    #             df_imm_shap.loc[feats, f'Chapter {icd_chpt}'] += df_imm_shap_raw.loc[feats, f"shap_mean_diff\n{section_id}\n{section_group}"]   
    # for icd_code in icd_codes:
    #     df_imm_shap.loc[feats, f'Code {icd_code}'] = 0.0
    #     df_code = df_groups[df_groups['ICD-11 code'] == icd_code]
    #     for section_id, section_row in df_code.iterrows():
    #         section_groups = ast.literal_eval(section_row['Groups'])
    #         for section_group_id, section_group in enumerate(section_groups):
    #             df_imm_shap.loc[feats, f'Code {icd_code}'] += df_imm_shap_raw.loc[feats, f"shap_mean_diff\n{section_id}\n{section_group}"]
    # df_imm_shap.to_excel(f"{path_models}/{model_id}/shap_mean_diff.xlsx", index_label='Immunomarker')  

    # df_imm_shap_table = pd.read_excel(f"{path_models}/{model_id}/shap_mean_diff.xlsx", index_col=0)
    # df_imm_shap_table = df_imm_shap_table[[f'Chapter {icd_chpt}' for icd_chpt in icd_chpts]]
    # df_imm_shap_table.index = df_imm_shap_table.index.str.replace('_log', '')
    
    # df_fig = df_imm_shap_table.astype(float)
    # x_ticks_colors = {f'Chapter {icd_chpt}': colors_icd_chpts[icd_chpt] for icd_chpt in icd_chpts}
    # sns.set_theme(style='ticks')
    # clustermap = sns.clustermap(
    #     df_fig,
    #     annot=True,
    #     col_cluster=False,
    #     row_cluster=True,
    #     fmt=".2f",
    #     center=0.0,
    #     cmap='seismic',
    #     linewidth=0.1,
    #     linecolor='black',
    #     tree_kws=dict(linewidths=1.5),
    #     figsize=(16, 12),
    #     cbar_kws={'orientation': 'horizontal'}
    # )
    # clustermap.ax_heatmap.set_xlabel('')
    # clustermap.ax_heatmap.set_ylabel('')
    # for spine in clustermap.ax_cbar.spines.values():
    #     spine.set(visible=True, lw=0.25, edgecolor="black")
    # clustermap.ax_heatmap.set_xticklabels(clustermap.ax_heatmap.get_xmajorticklabels(), rotation=0, path_effects=[pe.withStroke(linewidth=0.5, foreground="black")])
    # for tick_label in clustermap.ax_heatmap.get_xticklabels():
    #     tick_label.set_color(x_ticks_colors[tick_label.get_text()])
    # clustermap_pos = clustermap.ax_heatmap.get_position()
    # clustermap.ax_cbar.set_position([clustermap_pos.x0, clustermap_pos.y1 + 0.05, clustermap_pos.width, 0.03])
    # clustermap.ax_cbar.set_title("XAI age acceleration difference", fontsize='large')
    # clustermap.ax_cbar.tick_params(labelsize='large')
    # for spine in clustermap.ax_cbar.spines:
    #     clustermap.ax_cbar.spines[spine].set_linewidth(1)
    # plt.savefig(f"{path_models}/{model_id}/shap_mean_diff.png", bbox_inches='tight', dpi=200)
    # plt.savefig(f"{path_models}/{model_id}/shap_mean_diff.pdf", bbox_inches='tight')
    # plt.close(clustermap.figure)

    # Table with all clocks

    clocks_tests.at['This work', 'Total Rho'] = df_metrics.at['Total', 'Rho']
    clocks_tests.at['This work', 'Total MAE'] = df_metrics.at['Total', 'MAE']
    clocks_tests.at['This work', 'Year'] = ''

    for section_id, section_row in df_groups.iterrows():
        section_statuses = ast.literal_eval(section_row['Statuses'])
        section_groups = ast.literal_eval(section_row['Groups'])
        section_directions = ast.literal_eval(section_row['Directions'])
        df_section = data_full.loc[(data_full['GSE'] == section_row['GSE']) & (data_full['Status Origin'].isin(section_statuses)), ['Status Origin', 'Error']]

        for section_group_id, section_group in enumerate(section_groups):

            _, pval = mannwhitneyu(
                df_section.loc[df_section["Status Origin"] == section_group[0], "Error"].values,
                df_section.loc[df_section["Status Origin"] == section_group[1], "Error"].values,
                alternative="two-sided",
            )
            bias_0 = np.mean(df_section.loc[df_section['Status Origin'] == section_group[0], 'Error'])
            bias_1 = np.mean(df_section.loc[df_section['Status Origin'] == section_group[1], 'Error'])

            clocks_tests_raw.at['This work', f"pval\n{section_id}\n{section_group}"] = pval
            clocks_tests_raw.at['This work', f"bias_0\n{section_id}\n{section_group}"] = bias_0
            clocks_tests_raw.at['This work', f"bias_1\n{section_id}\n{section_group}"] = bias_1

    # Here we can modify clocks' test results (p-values)
    pvals_cols = [col for col in clocks_tests_raw.columns if 'pval' in col]
    _, clocks_tests_raw.loc['This work', pvals_cols], _, _ = multipletests(clocks_tests_raw.loc['This work', pvals_cols], 0.05, method='fdr_bh')

    passed_icd_chpt = {icd_chpt: 0 for icd_chpt in icd_chpts}
    passed_icd_chpt_max = {icd_chpt: 0 for icd_chpt in icd_chpts}
    for icd_chpt in icd_chpts:
        df_chpt = df_groups[df_groups['ICD-11 chapter'] == icd_chpt]
        for section_id, section_row in df_chpt.iterrows():
            section_statuses = ast.literal_eval(section_row['Statuses'])
            section_groups = ast.literal_eval(section_row['Groups'])
            section_directions = ast.literal_eval(section_row['Directions'])

            for section_group_id, section_group in enumerate(section_groups):
                passed_icd_chpt_max[icd_chpt] += 1

                pval = clocks_tests_raw.at['This work', f"pval\n{section_id}\n{section_group}"]
                bias_0 = clocks_tests_raw.at['This work', f"bias_0\n{section_id}\n{section_group}"]
                bias_1 = clocks_tests_raw.at['This work', f"bias_1\n{section_id}\n{section_group}"]

                group_direction = section_directions[section_group_id]

                if pval < 0.05:
                    if group_direction == 'Increasing' and bias_1 > bias_0:
                        passed_icd_chpt[icd_chpt] += 1
                    elif group_direction == 'Decreasing' and bias_1 < bias_0:
                        passed_icd_chpt[icd_chpt] += 1
        clocks_tests.at['This work', f'Passed\nICD-11\nChapter {icd_chpt}'] = passed_icd_chpt[icd_chpt]
        clocks_tests.at['This work', f'Max\nPassed\nICD-11\nChapter {icd_chpt}'] = passed_icd_chpt_max[icd_chpt]
    clocks_tests.at['This work', f'Passed\nICD-11\nTotal'] = sum(passed_icd_chpt.values())
    clocks_tests.at['This work', f'Max\nPassed\nICD-11\nTotal'] = sum(passed_icd_chpt_max.values())

    passed_icd_code = {icd_code: 0 for icd_code in icd_codes}
    passed_icd_code_max = {icd_code: 0 for icd_code in icd_codes}
    for icd_code in icd_codes:
        df_code = df_groups[df_groups['ICD-11 code'] == icd_code]
        for section_id, section_row in df_code.iterrows():
            section_statuses = ast.literal_eval(section_row['Statuses'])
            section_groups = ast.literal_eval(section_row['Groups'])
            section_directions = ast.literal_eval(section_row['Directions'])

            for section_group_id, section_group in enumerate(section_groups):
                passed_icd_code_max[icd_code] += 1

                pval = clocks_tests_raw.at['This work', f"pval\n{section_id}\n{section_group}"]
                bias_0 = clocks_tests_raw.at['This work', f"bias_0\n{section_id}\n{section_group}"]
                bias_1 = clocks_tests_raw.at['This work', f"bias_1\n{section_id}\n{section_group}"]

                group_direction = section_directions[section_group_id]

                if pval < 0.05:
                    if group_direction == 'Increasing' and bias_1 > bias_0:
                        passed_icd_code[icd_code] += 1
                    elif group_direction == 'Decreasing' and bias_1 < bias_0:
                        passed_icd_code[icd_code] += 1
        clocks_tests.at['This work', f'Passed\nICD-11\nCode {icd_code}'] = passed_icd_code[icd_code]
        clocks_tests.at['This work', f'Max\nPassed\nICD-11\nCode {icd_code}'] = passed_icd_code_max[icd_code]

    clocks_tests.to_excel(f"{path_models}/{model_id}/clocks_test.xlsx", index_label='Model ID')
    clocks_tests_raw.to_excel(f"{path_models}/{model_id}/clocks_tests_raw.xlsx", index_label='Model ID')
    df_clocks = clocks_tests.copy()
    df_clocks[f"Passed\nICD-11\nTotal"] /= clocks_tests.at['Hannum', f'Max\nPassed\nICD-11\nTotal']
    for col in icd_cols:
        df_clocks[col] /= clocks_tests.at['Hannum', f'Max\n{col}']
    df_clocks = df_clocks.iloc[np.arange(-1, len(df_clocks)-1)] # Move the last row to the first position

    fig, ax = plt.subplots(figsize=(25, 17))
    table = Table(
        df_clocks[icd_chpt_col_names],
        column_definitions=icd_chpt_col_defs,
        row_dividers=True,
        footer_divider=False,
        odd_row_color="#ffffff",
        even_row_color="#f0f0f0",
        ax=ax,
        # textprops={"fontsize": 10},
        row_divider_kw={"linewidth": 1, "linestyle": (0, (1, 1))},
        col_label_divider_kw={"linewidth": 1, "linestyle": "-"},
        column_border_kw={"linewidth": 1, "linestyle": "-"},
    ).autoset_fontcolors(colnames=icd_chpt_col_names)
    fig.savefig(f"{path_models}/{model_id}/clocks_chpts.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_models}/{model_id}/clocks_chpts.pdf", bbox_inches='tight')
    plt.close(fig)
    
    fig, ax = plt.subplots(figsize=(45, 22))
    table = Table(
        df_clocks[icd_code_col_names],
        column_definitions=icd_code_col_defs,
        row_dividers=True,
        footer_divider=False,
        odd_row_color="#ffffff",
        even_row_color="#f0f0f0",
        ax=ax,
        # textprops={"fontsize": 10},
        row_divider_kw={"linewidth": 1, "linestyle": (0, (1, 1))},
        col_label_divider_kw={"linewidth": 1, "linestyle": "-"},
        column_border_kw={"linewidth": 1, "linestyle": "-"},
    ).autoset_fontcolors(colnames=icd_code_col_names)
    fig.savefig(f"{path_models}/{model_id}/clocks_codes.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_models}/{model_id}/clocks_codes.pdf", bbox_inches='tight')
    plt.close(fig)

    # Results on all controls

    sns.set_theme(style='ticks')
    fig = plt.figure(layout='constrained', figsize=(15, 10))
    subfigs = fig.subfigures(1, 2, width_ratios=[4, 9], wspace=0.05)
    
    axs = subfigs[0].subplots(4, 1, height_ratios=[0.2, 0.2, 0.8, 0.58], gridspec_kw={'wspace':0.25, 'hspace': 0.05}, sharey=False, sharex=False)

    data_ctrl = data_full.loc[data_full['Status'] == 'Control', ['Age', 'Status', 'Group', 'Prediction', 'Error']]
    row_id_table = 0
    row_id_hist = 1
    row_id_scatter = 2
    row_id_violin = 3

    df_table = pd.DataFrame(index=['MAE', fr"Pearson $\mathbf{{\rho}}$", "Bias"], columns=['Train', 'Validation', 'Test', 'Total'])
    for part in ['Train', 'Validation', 'Test', 'Total']:
        df_table.at['MAE', part] = f"{df_metrics.at[part, 'MAE']:0.3f}"
        df_table.at[fr"Pearson $\mathbf{{\rho}}$", part] = f"{df_metrics.at[part, 'Rho']:0.3f}"
        df_table.at["Bias", part] = f"{df_metrics.at[part, 'Bias']:0.3f}"

    col_defs = [
        ColumnDefinition(
            name="index",
            title='',
            textprops={"ha": "center", "weight": "bold"},
            width=2.5,
        ),
        ColumnDefinition(
            name="Train",
            textprops={"ha": "left"},
            width=1.5,
            border="left"
        ),
        ColumnDefinition(
            name="Validation",
            textprops={"ha": "left"},
            width=2.2,
        ),
        ColumnDefinition(
            name="Test",
            textprops={"ha": "left"},
            width=1.5,
        ),
        ColumnDefinition(
            name="Total",
            textprops={"ha": "left"},
            width=1.5,
        )
    ]
    
    axs[row_id_table].text(-2, -1, 'A', fontsize=30, fontfamily='arial')

    table = Table(
        df_table,
        column_definitions=col_defs,
        row_dividers=True,
        footer_divider=False,
        ax=axs[row_id_table],
        textprops={"fontsize": 8},
        row_divider_kw={"linewidth": 1, "linestyle": (0, (1, 1))},
        col_label_divider_kw={"linewidth": 1, "linestyle": "-"},
        column_border_kw={"linewidth": 1, "linestyle": "-"},
    ).autoset_fontcolors(colnames=['Train', 'Validation', 'Test', 'Total'])

    hist_bins = np.linspace(0, 120, 25)
    histplot = sns.histplot(
        data=data_ctrl,
        bins=hist_bins,
        edgecolor='k',
        linewidth=1,
        x="Age",
        color='lightslategray',
        ax=axs[row_id_hist]
    )
    axs[row_id_hist].set_xticks([])
    axs[row_id_hist].set_xlim(0, 105)
    axs[row_id_hist].set_ylabel("Count")
    axs[row_id_hist].set_xlabel("")

    kdeplot = sns.kdeplot(
        data=data_ctrl.loc[data_ctrl['Group'].isin(['Train', 'Validation']), :],
        x='Age',
        y='Prediction',
        fill=True,
        cbar=False,
        thresh=0.005,
        color=make_rgb_transparent(mcolors.hex2color(mcolors.CSS4_COLORS['lightslategray']), (1, 1, 1), 0.25),
        cut=0,
        legend=False,
        ax=axs[row_id_scatter]
    )
    scatter = sns.scatterplot(
        data=data_ctrl.loc[data_ctrl['Group'] == 'Test', :],
        x='Age',
        y="Prediction",
        linewidth=0.01,
        alpha=0.8,
        edgecolor="k",
        s=2,
        color='lightslategray',
        ax=axs[row_id_scatter],
    )
    axs[row_id_scatter].axline((0, 0), slope=1, color="black", linestyle=":")
    axs[row_id_scatter].set_xlim(0, 105)
    axs[row_id_scatter].set_ylim(0, 105)
    axs[row_id_scatter].set_ylabel("Prediction")
    axs[row_id_scatter].set_xlabel("Age")
    
    violin = sns.violinplot(
        data=data_ctrl,
        x='Group',
        y='Error',
        palette={'Train': 'lightgray', 'Validation': 'lightslategray', 'Test': 'dimgrey'},
        scale='width',
        order=['Train', 'Validation', 'Test'],
        saturation=0.75,
        legend=False,
        ax=axs[row_id_violin]
    )
    axs[row_id_violin].set_xlabel('')
    axs[row_id_violin].set_ylabel('Age Acceleration')
    
    ids_shap = ids_test
            
    # explanation = model.explain(data_full.loc[ids_shap, :], method=explain_method, baselines=explain_baselines)
    # explanation.index = data_full.loc[ids_shap, :].index
    # explanation.rename(columns=lambda s: s.replace("_log", ""), inplace=True)
    # explanation.to_excel(f"{path_models}/{model_id}/explanation.xlsx")
    explanation = pd.read_excel(f"{path_models}/{model_id}/shap_global/explanation.xlsx", index_col=0)
    
    path_imm = f"E:/YandexDisk/Work/bbd/immunology/003_EpImAge/{imm_data_type}/{epi_data_type}/{selection_method}_{n_feats}"
    df_imm_models = pd.read_excel(f"{path_imm}/best_models_v5.xlsx", index_col=0)
    df_imm_models.sort_values(['test_pearson_corrcoef'], ascending=[False], inplace=True)
    imm_colors = distinctipy.get_colors(n_colors=df_imm_models.shape[0], exclude_colors=[mcolors.hex2color(mcolors.CSS4_COLORS['gray'])], rng=42)
    imm_colors_dict = {}
    for imm_id, imm in enumerate(df_imm_models.index.values):
        imm_color = imm_colors[imm_id]
        imm_colors_dict[imm] = imm_color
    
    ds_fi = pd.DataFrame(index=explanation.columns.values, columns=['mean(|SHAP|)'])
    for f in explanation.columns.values:
        ds_fi.at[f, 'mean(|SHAP|)'] = explanation[f].abs().mean()
    ds_fi.sort_values(['mean(|SHAP|)'], ascending=[False], inplace=True)
    ds_fi['Features'] = ds_fi.index.values
    

    axs = subfigs[1].subplots(1, 2, width_ratios=[2, 7], gridspec_kw={'wspace':0.01, 'hspace': 0.05}, sharey=True, sharex=False)
    
    axs[0].text(-2, -1, 'B', fontsize=30, fontfamily='arial')
    
    barplot = sns.barplot(
        data=ds_fi,
        x='mean(|SHAP|)',
        y='Features',
        # color=ds_color,
        hue='Features',
        palette=imm_colors_dict,
        edgecolor='black',
        dodge=False,
        ax=axs[0]
    )
    for container in barplot.containers:
        barplot.bar_label(container, label_type='edge', color='gray', fmt='%0.2f', fontsize=12, padding=3.0)
    axs[0].set_ylabel('')
    axs[0].set(yticklabels=ds_fi.index.to_list())
    # axs[0].get_legend().remove()
    
    is_colorbar = False
    f_legends = []
    for f in ds_fi.index:
        
        f_shap_ll = explanation[f].quantile(0.01)
        f_shap_hl = explanation[f].quantile(0.99)
        f_shap_index = explanation.index[(explanation[f] >= f_shap_ll) & (explanation[f] <= f_shap_hl)].values
        
        df_f_vals = data_full.loc[f_shap_index, :]
        f_vals_ll = df_f_vals[f"{f}_log"].quantile(0.01)
        f_vals_hl = df_f_vals[f"{f}_log"].quantile(0.99)
        f_shap_index = df_f_vals.index[(df_f_vals[f"{f}_log"] >= f_vals_ll) & (df_f_vals[f"{f}_log"] <= f_vals_hl)].values
        
        f_shap = explanation.loc[f_shap_index, f].values
        f_vals = data_full.loc[f_shap_index, f"{f}_log"].values
        
        f_cmap = sns.color_palette("coolwarm", as_cmap=True)
        f_norm = mcolors.Normalize(vmin=min(f_vals), vmax=max(f_vals)) 
        f_colors = {}
        for cval in f_vals:
            f_colors.update({cval: f_cmap(f_norm(cval))})

        strip = sns.stripplot(
            x=f_shap,
            y=[f]*len(f_shap),
            hue=f_vals,
            palette=f_colors,
            jitter=0.35,
            alpha=0.5,
            edgecolor='gray',
            linewidth=0.00,
            size=2,
            legend=False,
            ax=axs[1],
        )
        
        if not is_colorbar:
            sm = plt.cm.ScalarMappable(cmap=f_cmap, norm=f_norm)
            sm.set_array([])
            cbar = strip.figure.colorbar(sm)
            cbar.set_label('Inflammatory markers', labelpad=-4, fontsize='large')
            cbar.set_ticks([min(f_vals), max(f_vals)])
            cbar.set_ticklabels(["Min", "Max"])
            is_colorbar = True
        
    axs[1].set_xlabel('SHAP')

    fig.savefig(f"{path_models}/{model_id}/controls.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_models}/{model_id}/controls.pdf", bbox_inches='tight')
    plt.close(fig)

    # Results on GSEs

    n_rows = 6 * 4
    n_cols = 12
    fig_height = 32
    fig_width = 46
    sns.set_theme(style='ticks')
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height), height_ratios=[0.2, 0.2, 0.8, 0.15] * 6, gridspec_kw={'wspace':0.25, 'hspace': 0.05}, sharey=False, sharex=False)

    for gse_id, (gse, gse_samples) in tqdm(enumerate(gse_controls_ids.items())):
        color_gse = colors_gse[gse]
        data_gse = data_full.loc[gse_samples, ['GSE', 'Age', 'Status', 'Group', 'Prediction', 'Error']]
        row_id, col_id = divmod(gse_id, n_cols)
        row_id_table = row_id * 4
        row_id_hist = row_id * 4 + 1
        row_id_scatter = row_id * 4 + 2
        row_id_empty = row_id * 4 + 3

        df_table = pd.DataFrame(index=['MAE', fr"Pearson $\mathbf{{\rho}}$", "Bias"], columns=['Train', 'Validation', 'Test', 'Total'])
        for part in ['Train', 'Validation', 'Test', 'Total']:
            df_table.at['MAE', part] = f"{df_gses_models_metrics[f'MAE {part}'].at[gse, model_id]:0.3f}"
            df_table.at[fr"Pearson $\mathbf{{\rho}}$", part] = f"{df_gses_models_metrics[f'Rho {part}'].at[gse, model_id]:0.3f}"
            df_table.at["Bias", part] = f"{df_gses_models_metrics[f'Bias {part}'].at[gse, model_id]:0.3f}"

        col_defs = [
            ColumnDefinition(
                name="index",
                title=gse if gse != 'GSEUNN' else 'This work',
                textprops={"ha": "center", "weight": "bold"},
                width=2.5,
            ),
            ColumnDefinition(
                name="Train",
                textprops={"ha": "left"},
                width=1.5,
                border="left"
            ),
            ColumnDefinition(
                name="Validation",
                textprops={"ha": "left"},
                width=2.2,
            ),
            ColumnDefinition(
                name="Test",
                textprops={"ha": "left"},
                width=1.5,
            ),
            ColumnDefinition(
                name="Total",
                textprops={"ha": "left"},
                width=1.5,
            )
        ]

        table = Table(
            df_table,
            column_definitions=col_defs,
            row_dividers=True,
            footer_divider=False,
            ax=axs[row_id_table, col_id],
            textprops={"fontsize": 8},
            row_divider_kw={"linewidth": 1, "linestyle": (0, (1, 1))},
            col_label_divider_kw={"linewidth": 1, "linestyle": "-"},
            column_border_kw={"linewidth": 1, "linestyle": "-"},
        ).autoset_fontcolors(colnames=['Train', 'Validation', 'Test', 'Total'])

        hist_bins = np.linspace(0, 120, 13)
        histplot = sns.histplot(
            data=data_gse,
            bins=hist_bins,
            edgecolor='k',
            linewidth=1,
            x="Age",
            color=color_gse,
            ax=axs[row_id_hist, col_id]
        )
        axs[row_id_hist, col_id].set_xticks([])
        axs[row_id_hist, col_id].set_xlim(0, 115)
        if col_id == 0:
            axs[row_id_hist, col_id].set_ylabel("Count")
        else:
            axs[row_id_hist, col_id].set_ylabel("")

        kdeplot = sns.kdeplot(
            data=data_gse.loc[data_gse['Group'].isin(['Train', 'Validation']), :],
            x='Age',
            y='Prediction',
            fill=True,
            cbar=False,
            color=make_rgb_transparent(color_gse, (1, 1, 1), 0.25),
            cut=0,
            legend=False,
            ax=axs[row_id_scatter, col_id]
        )
        scatter = sns.scatterplot(
            data=data_gse.loc[data_gse['Group'] == 'Test', :],
            x='Age',
            y="Prediction",
            linewidth=0.5,
            alpha=0.8,
            edgecolor="k",
            s=35,
            color=color_gse,
            ax=axs[row_id_scatter, col_id],
        )
        axs[row_id_scatter, col_id].axline((0, 0), slope=1, color="black", linestyle=":")
        axs[row_id_scatter, col_id].set_xlim(0, 115)
        axs[row_id_scatter, col_id].set_ylim(0, 115)
        if col_id == 0:
            axs[row_id_scatter, col_id].set_ylabel("Prediction")
        else:
            axs[row_id_scatter, col_id].set_ylabel("")
        if row_id_empty == n_rows - 1:
            axs[row_id_scatter, col_id].set_xlabel("Age")
        else:
            axs[row_id_scatter, col_id].set_xlabel("")
        axs[row_id_empty, col_id].axis('off')

    fig.tight_layout()
    fig.savefig(f"{path_models}/{model_id}/gses.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_models}/{model_id}/gses.pdf", bbox_inches='tight')
    plt.close(fig)

    # Mosaic violins plot

    sns.set_theme(style='ticks')
    fig_height = 5 * len(mosaic_violins)
    fig_width = 1.4 * max_mosaic_row
    fig, axs = plt.subplot_mosaic(mosaic=mosaic_violins, figsize=(fig_width, fig_height), gridspec_kw={}, sharey=False, sharex=False)

    for plot_id, plot_row in df_groups.iterrows():

        plot_statuses = [statuses_rename[x] for x in ast.literal_eval(plot_row['Statuses'])]
        plot_groups_raw = ast.literal_eval(plot_row['Groups'])
        plot_groups = [(statuses_rename[x[0]], statuses_rename[x[1]]) for x in ast.literal_eval(plot_row['Groups'])]

        df_plot = data_full.loc[(data_full['GSE'] == plot_row['GSE']) & (data_full['Status'].isin(plot_statuses)), ['Status', 'Error']]
        plot_status_count = df_plot['Status'].value_counts()
        plot_statuses_rename = {}
        for x in plot_statuses:
            plot_statuses_rename[x] = x + f"\nCount: {plot_status_count[x]}\nBias: {np.mean(df_plot.loc[df_plot['Status'] == x, 'Error']):0.1f}"
        plot_statuses = [plot_statuses_rename[x] for x in plot_statuses]
        plot_groups = [(plot_statuses_rename[x[0]], plot_statuses_rename[x[1]]) for x in plot_groups]
        df_plot['Status'] = df_plot['Status'].replace(plot_statuses_rename)
        colors_plot_status = {plot_statuses_rename[x]: colors_status[x] for x in plot_statuses_rename}

        pval_formatted = []
        for plot_group_id, plot_group in enumerate(plot_groups):
            pval = clocks_tests_raw.at['This work', f"pval\n{plot_id}\n{plot_groups_raw[plot_group_id]}"]
            pval_formatted.append(f"{pval:.1e}")

        violinplot = sns.violinplot(
            data=df_plot,
            x='Status',
            y='Error',
            palette=colors_plot_status,
            scale='width',
            order=plot_statuses,
            saturation=0.75,
            legend=False,
            ax=axs[plot_id]
        )
        annotator = Annotator(
            ax=axs[plot_id],
            pairs=plot_groups,
            data=df_plot,
            x="Status",
            y="Error",
            order=plot_statuses,
        )
        annotator.set_custom_annotations(pval_formatted)
        annotator.configure(loc='inside', verbose=0)
        annotator.annotate()

        axs[plot_id].set_xlabel('')
        axs[plot_id].set_ylabel('Age acceleration')
        if plot_row['GSE'] == 'GSEUNN':
            axs[plot_id].set_title(f"This work ({df_plot.shape[0]})")
        else:
            axs[plot_id].set_title(f"{plot_row['GSE']} ({df_plot.shape[0]})")
        # axs[plot_id].set_facecolor(make_rgb_transparent(colors_icd_chpts[plot_row['ICD-11 chapter']], (1, 1, 1), 0.33))

    for empty_panel in violons_empty_panels:
        axs[empty_panel].axis('off')
    fig.tight_layout()
    fig.savefig(f"{path_models}/{model_id}/violins_Status.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_models}/{model_id}/violins_Status.pdf", bbox_inches='tight')
    plt.close(fig)

    # Ridgeline (column of violins) for controls' error

    # for part, ids_part in ids_dict.items():
    #     # Errors in GSEs
    #     df_fig = data_full.loc[ids_part, ['Error', 'GSE']].copy()
    #     df_fig['GSE'] = pd.Categorical(df_fig.GSE, categories=gses_controls, ordered=True)
    #     df_fig = df_fig.sort_values('GSE')
    #     gses_rename = {
    #         gse: f"{gse} ({gse_controls_count[gse]})" + "\n" +
    #                fr"MAE: {df_gses_models_metrics[f'MAE {part}'].at[gse, model_id]:0.2f}" + "\n"
    #                fr"Pearson $\rho$: {df_gses_models_metrics[f'Rho {part}'].at[gse, model_id]:0.2f}" + "\n" +
    #                fr"Bias: {df_gses_models_metrics[f'Bias {part}'].at[gse, model_id]:0.2f}"  + "\n" +
    #                f"{gse_preproc.at[gse, 'Preproc']}"
    #         for gse in colors_gse
    #     }
    #     gse_colors_grid = {gses_rename[gse]: colors_gse[gse] for gse in colors_gse}
    #     df_fig['GSE'].replace(gses_rename, inplace=True)
    #     sns.set_theme(style="whitegrid", rc={"axes.facecolor": (0, 0, 0, 0)})
    #     g = sns.FacetGrid(df_fig, row="GSE", hue="GSE", aspect=8, height=0.5, palette=gse_colors_grid)
    #     g.map(
    #         sns.kdeplot,
    #         'Error',
    #         fill=True,
    #         alpha=1.0,
    #         linewidth=0.5
    #     )
    #     g.map(
    #         sns.kdeplot,
    #         'Error',
    #         color="black",
    #         linewidth=1.0,
    #     )
    #     # g.refline(y=0, linewidth=2.0, linestyle="-", color=None, clip_on=False)
    #     def label(x, color, label):
    #         ax = plt.gca()
    #         ax.text(-0.15, 0.2, label, size=4, fontweight="light", color=color, ha="left", va="center", transform=ax.transAxes, path_effects=[pe.withStroke(linewidth=0.5, foreground="black")])
    #     g.map(label, 'Error')
    #     # Set the subplots to overlap
    #     g.figure.subplots_adjust(hspace=0.0)
    #     # Remove axes details that don't play well with overlap
    #     g.set_titles("")
    #     g.set(yticks=[], ylabel="")
    #     g.despine(bottom=True, left=True)
    #     # Save
    #     g.savefig(f"{path_models}/{model_id}/ridgeline_GSEs_error_{part}.png", bbox_inches='tight', dpi=200)
    #     g.savefig(f"{path_models}/{model_id}/ridgeline_GSEs_error_{part}.pdf", bbox_inches='tight')
    #     plt.close(g.fig)

# Persist the per-model summary metrics collected by the model loop above.
df_models_metrics.to_excel(f"{path_models}/models_metrics.xlsx", index_label='Model ID')

# Per-split change of MAE and Pearson rho between the "After" and "Before" evaluations.
for group in ['Train', 'Validation', 'Test']:
    df_models_check[f'{group} MAE Diff'] = df_models_check[f'{group} MAE After'] - df_models_check[f'{group} MAE Before']
    df_models_check[f'{group} Rho Diff'] = df_models_check[f'{group} Rho After'] - df_models_check[f'{group} Rho Before']
df_models_check.to_excel(f"{path_models}/models_check.xlsx", index_label='Model ID')

# One sheet per (metric, split) pair, e.g. "MAE Train". Each sheet is the GSE x model
# table of that metric, annotated with the GSE preprocessing in the second column.
with pd.ExcelWriter(f"{path_models}/gses_models_metrics.xlsx", engine='xlsxwriter') as writer:
    for metric, part in itertools.product(['MAE', 'Rho', 'Bias'], ['Train', 'Validation', 'Test', 'Total']):
        md = f"{metric} {part}"
        df_md = df_gses_models_metrics[md]
        # DataFrame.insert mutates df_md in place and raises ValueError if the column
        # already exists (e.g. when this cell is re-run), so only add it once.
        if 'Preproc' not in df_md.columns:
            df_md.insert(1, 'Preproc', gse_preproc.loc[df_md.index, 'Preproc'])
        df_md.to_excel(writer, sheet_name=md)

# Best model routines

## Supplementary table with all predicted data for all samples

In [None]:
# Selected model whose per-sample predictions form the supplementary table.
model_type = 'DANet'  # other candidates: 'FTTransformer', 'GANDALF', 'DANet'
model_id = 194

df_model = pd.read_excel(
    f"{path}/pytorch_tabular/{model_type}/candidates/{model_id}/data.xlsx",
    index_col=0,
)
df_data = pd.read_excel(f"{path}/data_filtered.xlsx", index_col=0)
sample_ids = df_model.index

# Status and ICD-11 annotations are copied from the filtered dataset under the same names.
annotation_cols = [
    'Status',
    'ICD-11 chapter',
    'ICD-11 chapter and description',
    'ICD-11 code',
    'ICD-11 code and description',
]
for annotation_col in annotation_cols:
    df_model.loc[sample_ids, annotation_col] = df_data.loc[sample_ids, annotation_col]

# One column per clock: named after the clock's row label in clocks_tests, with values
# read from the df_data column given by that row's 'Model ID' entry.
# NOTE(review): the model loop above adds a 'This work' row to clocks_tests without
# setting 'Model ID' for it — confirm this lookup does not hit a missing value.
for clock in clocks_tests.index:
    source_col = clocks_tests.at[clock, 'Model ID']
    df_model.loc[sample_ids, clock] = df_data.loc[sample_ids, source_col]

# Input features are exported under their names with the '_log' marker removed.
for f in feats:
    df_model.loc[sample_ids, f.replace('_log', '')] = df_data.loc[sample_ids, f]

df_model.to_excel(f"{path}/EpImAge.xlsx", index_label="Sample ID")

## KDE plots of predicted vs. chronological age (controls) for all clocks

In [None]:
# Per-clock panels for controls: a metrics table (MAE, Pearson rho, mean bias)
# above a predicted-vs-chronological age density plot with the identity line
# (dashed) and a linear fit (crimson).
df = pd.read_excel(f"{path}/EpImAge.xlsx", index_col="Sample ID")
df_ctrl = df.loc[df['Status'] == 'Control', :]
epiages = clocks_tests.index[clocks_tests['Type'] == 'Age'].to_list()

n_cols = 4
# The layout was designed for a 7 x 4 grid (13 x 25 in). Add rows, and stretch the
# figure height proportionally, only when there are more than 28 clocks
# (previously this raised IndexError).
n_rows = max(7, (len(epiages) + n_cols - 1) // n_cols)

sns.set_theme(style='ticks')
fig = plt.figure(
    figsize=(13, 25 * n_rows / 7),
    layout="constrained"
)
subfigs = fig.subfigures(
    nrows=n_rows,
    ncols=n_cols,
)
for epiage_id, epiage in tqdm(enumerate(epiages), total=len(epiages)):
    row_id, col_id = divmod(epiage_id, n_cols)

    axs = subfigs[row_id, col_id].subplot_mosaic(
        [
            ['1'],
            ['2'],
        ],
        height_ratios=[1, 4],
        gridspec_kw={
            "bottom": 0.14,
            "top": 0.95,
            "wspace": 0.33,
            "hspace": 0.01,
        },
    )

    # Accuracy of this clock on controls, formatted as strings for the table.
    mae = sklearn.metrics.mean_absolute_error(df_ctrl['Age'].values, df_ctrl[epiage].values)
    rho, _ = stats.pearsonr(df_ctrl['Age'].values, df_ctrl[epiage].values)
    bias = np.mean(df_ctrl[epiage] - df_ctrl['Age'])
    ds_table = pd.DataFrame(index=['MAE', fr"Pearson $\rho$", "Bias"], columns=[epiage])
    ds_table.at['MAE', epiage] = f"{mae:0.3f}"
    ds_table.at[fr"Pearson $\rho$", epiage] = f"{rho:0.3f}"
    ds_table.at["Bias", epiage] = f"{bias:0.3f}"

    # Raw f-string: "\:" is a mathtext medium space. In a plain f-string it is an
    # invalid escape sequence (DeprecationWarning, SyntaxWarning on Python 3.12+).
    table_title = fr"{epiage}\:({clocks_tests.at[epiage, 'Year']})"
    col_defs = [
        ColumnDefinition(
            name="index",
            title=fr"$\mathbf{{{table_title}}}$",
            textprops={"ha": "left"},
            width=4.5,
        ),
        ColumnDefinition(
            name=epiage,
            title='',
            textprops={"ha": "center"},
            width=2.0,
        ),
    ]
    Table(
        ds_table,
        column_definitions=col_defs,
        row_dividers=True,
        footer_divider=False,
        ax=axs['1'],
        textprops={"fontsize": 7},
        row_divider_kw={"linewidth": 1, "linestyle": (0, (1, 1))},
        col_label_divider_kw={"linewidth": 1, "linestyle": "-"},
        column_border_kw={"linewidth": 1, "linestyle": "-"},
    ).autoset_fontcolors(colnames=[epiage])

    # Common square limits for both axes, padded by 10% of the joint data range.
    xy_min = df_ctrl[['Age', epiage]].min().min()
    xy_max = df_ctrl[['Age', epiage]].max().max()
    pad = 0.1 * (xy_max - xy_min)
    lims = [xy_min - pad, xy_max + pad]

    # Identity line (perfect prediction).
    sns.lineplot(
        x=lims,
        y=lims,
        linestyle='--',
        color='black',
        linewidth=1.0,
        zorder=0,
        ax=axs['2']
    )
    sns.regplot(
        data=df_ctrl,
        x='Age',
        y=epiage,
        color='crimson',
        scatter=False,
        truncate=False,
        ax=axs['2'],
    )
    sns.kdeplot(
        data=df_ctrl,
        x='Age',
        y=epiage,
        fill=True,
        cbar=False,
        color='gray',
        thresh=0.002,
        cut=0,
        legend=False,
        zorder=0,
        ax=axs['2']
    )
    axs['2'].set_xlim(*lims)
    axs['2'].set_ylim(*lims)

fig.savefig(f"{path}/epi_ages_distribution.png", bbox_inches='tight', dpi=200)
fig.savefig(f"{path}/epi_ages_distribution.pdf", bbox_inches='tight')
plt.close(fig)

## Correlation of all epigenetic clocks for controls and cases

In [None]:
# Pairwise Pearson correlation between EpImAge and every age clock:
# lower triangle = controls, upper triangle = cases, diagonal blank (NaN).
df = pd.read_excel(f"{path}/EpImAge.xlsx", index_col="Sample ID")
is_ctrl = df['Status'] == 'Control'
df_ctrls = df.loc[is_ctrl, :]
df_cases = df.loc[~is_ctrl, :]
epiages = ['EpImAge'] + clocks_tests.index[clocks_tests['Type'] == 'Age'].to_list()
n_epiages = len(epiages)

df_corr = pd.DataFrame(index=epiages, columns=epiages, data=np.zeros(shape=(n_epiages, n_epiages)))
for f_1 in epiages:
    df_corr.at[f_1, f_1] = np.nan
for f_1, f_2 in itertools.combinations(epiages, 2):
    corr_ctrls, _ = stats.pearsonr(df_ctrls[f_1].values, df_ctrls[f_2].values)
    corr_cases, _ = stats.pearsonr(df_cases[f_1].values, df_cases[f_2].values)
    df_corr.at[f_2, f_1] = corr_ctrls
    df_corr.at[f_1, f_2] = corr_cases
    if (f_1, f_2) == ('EpImAge', 'Horvath'):
        print(f"{f_1} vs {f_2} controls: {corr_ctrls}")
        print(f"{f_1} vs {f_2} cases: {corr_cases}")

sns.set_theme(style='ticks')
side_extra = 0.2 * n_epiages
fig, ax = plt.subplots(figsize=(4.5 + side_extra, 2.5 + side_extra), layout='constrained')
cmap_triu = plt.get_cmap("seismic").copy()
annot_size = 32 / np.sqrt(n_epiages + 10)
heatmap = sns.heatmap(
    df_corr,
    annot=True,
    fmt=".2f",
    center=0.0,
    cmap=cmap_triu,
    linewidth=0.1,
    linecolor='black',
    annot_kws={"fontsize": annot_size},
    ax=ax
)
# The colorbar is the last axes added to the figure by sns.heatmap.
cbar_ax = ax.figure.axes[-1]
cbar_ax.set_ylabel(r"Pearson $\rho$", fontsize='x-large')
for spine in cbar_ax.spines.values():
    spine.set(visible=True, lw=0.25, edgecolor="black")
ax.set(xlabel='', ylabel='', title='')
fig.savefig(f"{path}/epi_ages_correlation.png", bbox_inches='tight', dpi=200)
fig.savefig(f"{path}/epi_ages_correlation.pdf", bbox_inches='tight')
plt.close(fig)

## Correlation of age acceleration across all epigenetic clocks for controls and cases

In [None]:
# Pairwise Pearson correlation of age acceleration (clock prediction minus
# chronological age) between EpImAge and every age clock:
# lower triangle = controls, upper triangle = cases, diagonal blank (NaN).
df = pd.read_excel(f"{path}/EpImAge.xlsx", index_col="Sample ID")
is_ctrl = df['Status'] == 'Control'
df_ctrls = df.loc[is_ctrl, :]
df_cases = df.loc[~is_ctrl, :]
epiages = ['EpImAge'] + clocks_tests.index[clocks_tests['Type'] == 'Age'].to_list()
n_epiages = len(epiages)

df_corr = pd.DataFrame(index=epiages, columns=epiages, data=np.zeros(shape=(n_epiages, n_epiages)))
for f_1 in epiages:
    df_corr.at[f_1, f_1] = np.nan
for f_1, f_2 in itertools.combinations(epiages, 2):
    corr_ctrls, _ = stats.pearsonr(
        df_ctrls[f_1].values - df_ctrls['Age'].values,
        df_ctrls[f_2].values - df_ctrls['Age'].values,
    )
    corr_cases, _ = stats.pearsonr(
        df_cases[f_1].values - df_cases['Age'].values,
        df_cases[f_2].values - df_cases['Age'].values,
    )
    df_corr.at[f_2, f_1] = corr_ctrls
    df_corr.at[f_1, f_2] = corr_cases
    if (f_1, f_2) == ('EpImAge', 'Horvath'):
        print(f"{f_1} vs {f_2} controls: {corr_ctrls}")
        print(f"{f_1} vs {f_2} cases: {corr_cases}")

sns.set_theme(style='ticks')
side_extra = 0.2 * n_epiages
fig, ax = plt.subplots(figsize=(4.5 + side_extra, 2.5 + side_extra), layout='constrained')
cmap_triu = plt.get_cmap("seismic").copy()
annot_size = 32 / np.sqrt(n_epiages + 10)
heatmap = sns.heatmap(
    df_corr,
    annot=True,
    fmt=".2f",
    center=0.0,
    cmap=cmap_triu,
    linewidth=0.1,
    linecolor='black',
    annot_kws={"fontsize": annot_size},
    ax=ax
)
# The colorbar is the last axes added to the figure by sns.heatmap.
cbar_ax = ax.figure.axes[-1]
cbar_ax.set_ylabel(r"Age acceleration Pearson $\rho$", fontsize='x-large')
for spine in cbar_ax.spines.values():
    spine.set(visible=True, lw=0.25, edgecolor="black")
ax.set(xlabel='', ylabel='', title='')
fig.savefig(f"{path}/epi_ages_acceleration_correlation.png", bbox_inches='tight', dpi=200)
fig.savefig(f"{path}/epi_ages_acceleration_correlation.pdf", bbox_inches='tight')
plt.close(fig)

## Gradio background data

In [None]:
# Reference (background) table for the Gradio app: controls only, with the
# model's prediction and age acceleration next to the input features.
model_type = 'DANet' # 'FTTransformer' 'GANDALF' 'DANet'
model_id = 194

df_sweeps = pd.read_excel(f"{path}/pytorch_tabular/{model_type}/models.xlsx", index_col=0)
model = TabularModel.load_model(f"{path}/pytorch_tabular/{model_type}/candidates/{model_id}")

data_full['Group'] = ''
# model.predict returns a DataFrame. Select the prediction column explicitly instead of
# assigning a whole frame to a single column, which fails if extra columns are returned.
# Assumes the regression target is 'Age', so the column is named 'Age_prediction'.
data_full['EpImAge'] = model.predict(data_full)['Age_prediction']
data_full['Age Acceleration'] = data_full['EpImAge'] - data_full['Age']

data_gradio = data_full.loc[data_full['Status'] == 'Control', ['Age', 'EpImAge', 'Age Acceleration'] + list(feats)]
data_gradio.to_pickle(f"{path}/Background.pkl")

## Local Explainability

In [None]:
# Load the selected candidate model and reconstruct the train/validation/test split it used.
model_type = 'DANet' # 'FTTransformer' 'GANDALF' 'DANet'
model_id = 194

df_sweeps = pd.read_excel(f"{path}/pytorch_tabular/{model_type}/models.xlsx", index_col=0)
model = TabularModel.load_model(f"{path}/pytorch_tabular/{model_type}/candidates/{model_id}")

# The sweep table records which split and CV fold this model was trained on.
split_id = df_sweeps.at[model_id, 'split_id']
fold_id = df_sweeps.at[model_id, 'fold_id']
split_dict = samples[split_id]
ids_test = split_dict['test']
ids_train = split_dict['trains'][fold_id]
ids_validation = split_dict['validations'][fold_id]
ids_total = np.concatenate([ids_train, ids_validation, ids_test])
ids_dict = {
    'Test': ids_test,
    'Train': ids_train,
    'Validation': ids_validation,
    'Total': ids_total
}

# Tag each sample with its split; samples outside the split keep an empty label.
data_full['Group'] = ''
data_full.loc[ids_train, 'Group'] = 'Train'
data_full.loc[ids_validation, 'Group'] = 'Validation'
data_full.loc[ids_test, 'Group'] = 'Test'
# model.predict returns a DataFrame. Select the prediction column explicitly
# (assumes the regression target is 'Age', hence 'Age_prediction').
data_full['Prediction'] = model.predict(data_full)['Age_prediction']
data_full['Error'] = data_full['Prediction'] - data_full['Age']

# Controls only (pool from which explanation backgrounds are drawn).
data_ctrl = data_full.loc[data_full['Status'] == 'Control', :]

def predict_func(X):
    """Wrap the tabular model for SHAP: feature matrix in, 1-D prediction array out.

    ``X`` is array-like whose columns follow the order of the global ``feats``.
    Uses the global ``model`` and returns its ``'Age_prediction'`` column as an array.
    """
    frame = pd.DataFrame(data=X, columns=feats)
    predictions = model.predict(frame)
    return predictions['Age_prediction'].values

In [None]:
# Local SHAP explanation of one sample's age acceleration.
trgt_id = 'I1' # candidates: 'I1' 'I8' 'I1 (1)' 'I1 (2)'
trgt_age = data_full.at[trgt_id, 'Age']
trgt_pred = data_full.at[trgt_id, 'Prediction']
trgt_aa = trgt_pred-trgt_age  # age acceleration: prediction minus chronological age
print(trgt_age)
print(trgt_pred)

# Background set: the n_closest controls whose *predicted* age is nearest to the
# target's chronological age, so the explainer's baseline should land near trgt_age.
n_closest = 200
data_closest = data_ctrl.iloc[(data_ctrl['Prediction'] - trgt_age).abs().argsort()[:n_closest]]

explainer = shap.SamplingExplainer(predict_func, data_closest.loc[:, feats])
print(explainer.expected_value)
shap_values = explainer.shap_values(data_full.loc[[trgt_id], feats].values)[0]
# SHAP contributions sum (approximately, for a sampling estimate) to
# trgt_pred - expected_value; rescale them so they sum to trgt_pred - trgt_age instead.
# NOTE(review): unstable when trgt_pred is close to explainer.expected_value
# (denominator near zero), e.g. for samples with near-zero age acceleration.
shap_values = shap_values * (trgt_pred - trgt_age) / (trgt_pred - explainer.expected_value)

# The waterfall starts from chronological age, matching the rescaling above.
shap.plots.waterfall(
    shap.Explanation(
        values=shap_values,
        base_values=trgt_age,
        data=data_full.loc[trgt_id, feats].values,
        feature_names=[f.replace('_log', '') for f in feats]
    ),
    max_display=len(feats),
    show=True,
)

In [None]:
# Grid of per-feature KDEs over the background controls (data_closest), split at the
# target sample's value: dodgerblue fill to its left, crimson to its right, with the
# percentage of background values on each side written under the curve.
sns.set_theme(style='ticks')
n_rows = 4
n_cols = 6
fig_height = 10
fig_width = 20
fig, axs = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height), gridspec_kw={}, sharey=False, sharex=False)
for feat_id, feat in tqdm(enumerate(feats), total=len(feats)):
    row_id, col_id = divmod(feat_id, n_cols)
    ax = axs[row_id, col_id]

    sns.kdeplot(
        data=data_closest,
        x=feat,
        color='gray',
        linewidth=1,
        cut=0,
        ax=ax
    )
    # The KDE curve is the first line drawn on this fresh axes.
    kdeline = ax.lines[0]
    xs = kdeline.get_xdata()
    ys = kdeline.get_ydata()
    trgt_val = data_full.at[trgt_id, feat]
    # Use the `stats` alias (from scipy import stats) used elsewhere in this notebook;
    # a bare `scipy` name is not bound by that import.
    trgt_prctl = stats.percentileofscore(data_closest.loc[:, feat], trgt_val)
    ax.fill_between(xs, 0, ys, where=(xs <= trgt_val), interpolate=True, facecolor='dodgerblue', alpha=0.7)
    ax.fill_between(xs, 0, ys, where=(xs >= trgt_val), interpolate=True, facecolor='crimson', alpha=0.7)
    ax.vlines(trgt_val, 0, np.interp(trgt_val, xs, ys), color='black', linewidth=1.5)
    label_y = 0.1 * max(ys)
    ax.text(np.mean([min(xs), trgt_val]), label_y, f"{trgt_prctl:0.1f}%", fontstyle="oblique",
            color="black", ha="center", va="center")
    ax.text(np.mean([max(xs), trgt_val]), label_y, f"{100 - trgt_prctl:0.1f}%", fontstyle="oblique",
            color="black", ha="center", va="center")
    ax.ticklabel_format(style='scientific', scilimits=(-1, 1), axis='y', useOffset=True)
fig.tight_layout()
plt.show()


In [None]:
# Interactive XAI figure for the target sample. Left: waterfall from chronological age
# to EpImAge (rescaled SHAP contributions per feature). Right: the percentage of
# background controls (data_closest) below ("Less") / above ("More") the target's value.
marker_names = [f.replace('_log', '') for f in feats]

# Percentile of the target within the background, rounded to whole percents.
# Look features up by their real column name rather than rebuilding "<name>_log",
# which raised KeyError for features without that suffix.
df_less_more = pd.DataFrame(index=marker_names, columns=['Less', 'More'])
for feat, name in zip(feats, marker_names):
    less = int(round(stats.percentileofscore(data_closest.loc[:, feat].values, data_full.at[trgt_id, feat])))
    df_less_more.at[name, 'Less'] = less
    # Integer complement so both bar labels share the same format.
    df_less_more.at[name, 'More'] = 100 - less

# Contributions sorted by magnitude so the largest end up at the top of the bar chart.
df_shap = pd.DataFrame(index=marker_names, data=shap_values, columns=[trgt_id])
df_shap.sort_values(by=trgt_id, key=abs, inplace=True)
df_shap['cumsum'] = df_shap[trgt_id].cumsum()


def _signed(x):
    """Format a value with two decimals and an explicit '+' when positive."""
    return f"+{x:0.2f}" if x > 0 else f"{x:0.2f}"


fig = make_subplots(rows=1, cols=2, shared_yaxes=True, shared_xaxes=True, column_widths=[2, 1])
# Endpoints: chronological age (absolute bar from 0) and the total acceleration.
fig.add_trace(
    go.Waterfall(
        hovertext=["Chrono Age", "EpImAge"],
        orientation="h",
        measure=['absolute', 'relative'],
        y=[-1.5, df_shap.shape[0] + 0.5],
        x=[trgt_age, trgt_aa],
        base=0,
        text=[f"{trgt_age:0.2f}", _signed(trgt_aa)],
        textposition="auto",
        decreasing={"marker": {"color": "deepskyblue", "line": {"color": "black", "width": 1}}},
        increasing={"marker": {"color": "crimson", "line": {"color": "black", "width": 1}}},
        totals={"marker": {"color": "dimgray", "line": {"color": "black", "width": 1}}},
        connector={
            "mode": "between",
            "line": {"width": 1, "color": "black", "dash": "dot"},
        },
    ),
    row=1,
    col=1
)
# Per-feature contributions stacked on top of chronological age.
fig.add_trace(
    go.Waterfall(
        hovertext=df_shap.index.values,
        orientation="h",
        measure=["relative"] * len(feats),
        y=list(range(df_shap.shape[0])),
        x=df_shap[trgt_id].values,
        base=trgt_age,
        text=[_signed(x) for x in df_shap[trgt_id].values],
        textposition="auto",
        decreasing={"marker": {"color": "lightblue", "line": {"color": "black", "width": 1}}},
        increasing={"marker": {"color": "lightcoral", "line": {"color": "black", "width": 1}}},
        connector={
            "mode": "between",
            "line": {"width": 1, "color": "black", "dash": "solid"},
        },
    ),
    row=1,
    col=1,
)
fig.update_traces(row=1, col=1, showlegend=False)

# Stacked Less/More percentage bars, rows aligned with the waterfall's feature order.
for side, color in [('Less', 'steelblue'), ('More', 'violet')]:
    fig.add_trace(
        go.Bar(
            hovertext=df_shap.index.values,
            orientation="h",
            name=side,
            x=df_less_more.loc[df_shap.index.values, side],
            y=list(range(df_shap.shape[0])),
            marker=dict(color=color, line=dict(color="black", width=1)),
            text=df_less_more.loc[df_shap.index.values, side],
            textposition='auto',
        ),
        row=1,
        col=2
    )
fig.update_layout(barmode="relative")
fig.update_layout(legend=dict(
    title=dict(text="Immunomarkers' distribution<br>in samples with same age", side="top center"),
    orientation="h",
    yanchor="bottom",
    y=0.95,
    xanchor="center",
    x=0.84
))

# `titlefont` is deprecated (removed in plotly 6); use the nested title.font form.
fig.update_layout(
    title=dict(text=f"{trgt_id} XAI age acceleration", font=dict(size=25)),
    template="none",  # 'simple_white'
    width=800,
    height=1000,
    margin=go.layout.Margin(l=120, r=20, b=50, t=50, pad=0),
)
fig.update_yaxes(
    tickmode="array",
    tickvals=[-1.5] + list(range(df_shap.shape[0])) + [df_shap.shape[0] + 0.5],
    ticktext=["Chrono Age"] + df_shap.index.to_list() + ["EpImAge"],
    tickfont=dict(size=18),
    row=1,
    col=1
)
# Symmetric x-range around chronological age, wide enough for the largest running total.
half_range = df_shap['cumsum'].abs().max() * 1.25
fig.update_xaxes(
    title=dict(text='Age', font=dict(size=25)),
    range=[trgt_age - half_range, trgt_age + half_range],
    row=1,
    col=1
)
fig.update_xaxes(
    showgrid=False,
    showline=False,
    zeroline=False,
    showticklabels=False,
    row=1,
    col=2
)
fig.update_yaxes(
    showgrid=False,
    showline=False,
    zeroline=False,
    showticklabels=False,
    row=1,
    col=2
)
fig.show()