# Debugging autoreload

In [None]:
%load_ext autoreload
%autoreload 2

# Load packages

In [None]:
from pytorch_tabular.utils import load_covertype_dataset
from rich.pretty import pprint
from sklearn.model_selection import BaseCrossValidator, ParameterGrid, ParameterSampler
import torch
import pickle
import shutil
import shap
from sklearn.model_selection import RepeatedStratifiedKFold
from glob import glob
import ast
import matplotlib.pyplot as plt
import seaborn as sns
import copy
import itertools
from sklearn.model_selection import train_test_split
import numpy as np
from pytorch_tabular.utils import make_mixed_dataset, print_metrics
from pytorch_tabular import available_models
from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig, GANDALFConfig, TabNetModelConfig, FTTransformerConfig, DANetConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig
from pytorch_tabular.tabular_model_tuner import TabularModelTuner
from torchmetrics.functional.regression import mean_absolute_error, pearson_corrcoef
from pytorch_tabular import MODEL_SWEEP_PRESETS
import pandas as pd
from pytorch_tabular import model_sweep
from src.pt.model_sweep import model_sweep_custom
import warnings
from src.utils.configs import read_parse_config
from src.pt.hyper_opt import train_hyper_opt
from src.utils.hash import dict_hash
import pathlib
from tqdm import tqdm
import distinctipy
import matplotlib.patheffects as pe
import matplotlib.colors as mcolors
from statannotations.Annotator import Annotator
from scipy.stats import mannwhitneyu
import optuna


def make_rgb_transparent(rgb, bg_rgb, alpha):
    """Blend a foreground color over a background color.

    Every channel is computed as ``alpha * foreground + (1 - alpha) * background``.
    Channels are paired with ``zip``, so the result is as long as the shorter
    of the two inputs.

    Args:
        rgb: Iterable with the foreground channel values.
        bg_rgb: Iterable with the background channel values.
        alpha: Weight given to the foreground color.

    Returns:
        List with the blended channel values.
    """
    blended = []
    for fg_channel, bg_channel in zip(rgb, bg_rgb):
        blended.append(alpha * fg_channel + (1 - alpha) * bg_channel)
    return blended


# Load data

In [None]:
# --- Data variant and feature selection settings (they define the input folder) ---
epi_data_type = 'no_harm'
# Other options: 'origin' 'imp_source(imm)_method(knn)_params(5)' 'imp_source(imm)_method(miceforest)_params(2)'
imm_data_type = 'imp_source(imm)_method(knn)_params(5)'

# Other options: 'f_regression' 'mrmr'
selection_method = 'mrmr'
n_feats = 100

# --- Identification of the stored test splits and the split used here ---
tst_n_splits, tst_n_repeats, tst_random_state = 5, 5, 1337
tst_split_id = 20

# --- Identification of the stored validation folds and the fold used here ---
val_n_splits, val_n_repeats, val_random_state = 4, 4, 1337
val_fold_id = 5

# --- Locations of the data and of the YAML configs ---
path = f"D:/YandexDisk/Work/bbd/immunology/003_EpImAge/{imm_data_type}/{epi_data_type}/{selection_method}_{n_feats}/EpImAge"
path_epi = "D:/YandexDisk/Work/bbd/immunology/003_EpImAge/epi"
path_configs = "D:/Work/bbs/notebooks/immunology/003_EpImAge/age_regression_configs"

data_full = pd.read_excel(f"{path}/data.xlsx", index_col=0)

# Keep only statuses with at least 10 samples, then remove the 'ICU' samples
status_count = data_full['Status'].value_counts()
statuses_passed = status_count.loc[status_count >= 10].index.tolist()
data_full = data_full.loc[data_full['Status'].isin(statuses_passed)]
data_full = data_full.drop(index=data_full.index[data_full['Status'] == 'ICU'])
# data_full.to_excel(f"{path}/data_filtered.xlsx", index_label='ID')

# Status counts after filtering
status_count = data_full['Status'].value_counts()

# Model inputs are the "<feature>_log" columns of the features listed in feats.xlsx
feats = [
    f"{feat}_log"
    for feat in pd.read_excel(f"{path}/feats.xlsx", index_col=0).index.values.tolist()
]

gse_preproc = pd.read_excel(f"{path_epi}/preproc.xlsx", index_col=0)

# One color per ICD-11 chapter: distinct from black and white, then blended with white
df_groups = pd.read_excel(f"{path_epi}/groups.xlsx", index_col=0)
icd_chpts = np.sort(df_groups['ICD-11 chapter'].unique())
icd_codes = np.sort(df_groups['ICD-11 code'].unique())
colors = distinctipy.get_colors(
    len(icd_chpts),
    [mcolors.hex2color(mcolors.CSS4_COLORS[color_name]) for color_name in ('black', 'white')],
    rng=1337,
    pastel_factor=0.5,
)
colors = [make_rgb_transparent(color, (1, 1, 1), 0.75) for color in colors]
colors_icd_chpts = dict(zip(icd_chpts, colors))

clocks_tests = pd.read_excel(f"{path_epi}/clocks.xlsx", index_col=0)

# Load stratification

In [None]:
# Load the precomputed sample IDs of every test split and validation fold.
# NOTE: pickle can execute arbitrary code while loading; only open trusted files.
with open(f"{path}/samples_tst({tst_random_state}_{tst_n_splits}_{tst_n_repeats})_val({val_random_state}_{val_n_splits}_{val_n_repeats}).pickle", 'rb') as handle:
    samples = pickle.load(handle)

# Sanity check: train, validation and test parts must not share samples
for split_id, fold_id in itertools.product(
    range(tst_n_splits * tst_n_repeats),
    range(val_n_splits * val_n_repeats),
):
    split_parts = {
        'test': set(samples[split_id]['test']),
        'train': set(samples[split_id]['trains'][fold_id]),
        'validation': set(samples[split_id]['validations'][fold_id]),
    }
    for first_part, second_part in (('train', 'validation'), ('validation', 'test'), ('train', 'test')):
        intxn_name = f"{first_part}_{second_part}"
        intxn_samples = split_parts[first_part] & split_parts[second_part]
        if intxn_samples:
            print(f"Non-zero {intxn_name} intersection ({len(intxn_samples)}) for {split_id} Split and {fold_id} Fold!")

# Train, Validation, Test selection

In [None]:
# Sample IDs of the chosen test split and validation fold
split_dict = samples[tst_split_id]

# Each part holds the model features plus the "Age" target
train, validation, test = (
    data_full.loc[part_ids, feats + ["Age"]]
    for part_ids in (
        split_dict['trains'][val_fold_id],
        split_dict['validations'][val_fold_id],
        split_dict['test'],
    )
)

# Optuna training

## Models setup

In [None]:
# Seed of the enabled run; alternative values: 1337 42 451 1984 1899 1408
seed_target = 1337

# Optuna settings per model: config class, number of trials, seed and TPE sampler parameters.
# Uncomment an entry to include that model in the search.
models_runs = {
    'GANDALF': dict(
        config=GANDALFConfig,
        n_trials=1024,
        seed=seed_target,
        n_startup_trials=256,
        n_ei_candidates=16,
    ),
    # 'FTTransformer': dict(
    #     config=FTTransformerConfig,
    #     n_trials=512,
    #     seed=1337,
    #     n_startup_trials=256,
    #     n_ei_candidates=16,
    # ),
    # 'DANet': dict(
    #     config=DANetConfig,
    #     n_trials=256,
    #     seed=1337,
    #     n_startup_trials=64,
    #     n_ei_candidates=16,
    # ),
    # 'CategoryEmbeddingModel': dict(
    #     config=CategoryEmbeddingModelConfig,
    #     n_trials=256,
    #     seed=1337,
    #     n_startup_trials=64,
    #     n_ei_candidates=16,
    # ),
    # 'TabNetModel': dict(
    #     config=TabNetModelConfig,
    #     n_trials=256,
    #     seed=1337,
    #     n_startup_trials=64,
    #     n_ei_candidates=16,
    # ),
}

## Training

In [None]:
# One Optuna study per entry of `models_runs`; the trials of every study are saved to Excel
# and then concatenated into a single table.
dfs_models = []

# Learning-rate finder settings passed to every trial
lr_find_min_lr = 1e-8
lr_find_max_lr = 10
lr_find_num_training = 256
lr_find_mode = "exponential"
lr_find_early_stop_threshold = 8.0

# Optimized objectives: every metric on every listed part, in this order
opt_parts = ['test', 'validation']
opt_metrics = [('mean_absolute_error', 'minimize')]
# opt_metrics = [('mean_absolute_error', 'minimize'), ('pearson_corrcoef', 'maximize')]
# opt_metrics = [('pearson_corrcoef', 'maximize')]
opt_directions = [str(direction) for _part in opt_parts for _metric, direction in opt_metrics]

# Columns highlighted with a color gradient in the Excel reports
gradient_cols = [
    "train_loss",
    "validation_loss",
    "validation_test_mean_loss",
    "train_validation_test_mean_loss",
    "test_loss",
    "time_taken",
    "time_taken_per_epoch",
]

for model_name, model_run in models_runs.items():

    model_config_name, n_trials, seed, n_startup_trials, n_ei_candidates = (
        model_run[key]
        for key in ('config', 'n_trials', 'seed', 'n_startup_trials', 'n_ei_candidates')
    )

    # Default configs are read from YAML and adjusted for this dataset
    data_config = read_parse_config(f"{path_configs}/DataConfig.yaml", DataConfig)
    data_config['target'] = ['Age']
    data_config['continuous_cols'] = feats
    optimizer_config = read_parse_config(f"{path_configs}/OptimizerConfig.yaml", OptimizerConfig)
    trainer_config = read_parse_config(f"{path_configs}/TrainerConfig.yaml", TrainerConfig)
    trainer_config['checkpoints_path'] = f"{path}/pytorch_tabular"
    trainer_config['seed'] = seed
    trainer_config['checkpoints'] = 'valid_loss'
    trainer_config['load_best'] = True
    trainer_config['auto_lr_find'] = False

    # A model with default settings is built only to prepare the shared datamodule
    model_config_default = read_parse_config(f"{path_configs}/models/{model_name}Config.yaml", model_config_name)
    tabular_model_default = TabularModel(
        data_config=data_config,
        model_config=model_config_default,
        optimizer_config=optimizer_config,
        trainer_config=trainer_config,
        verbose=False,
    )
    datamodule = tabular_model_default.prepare_dataloader(train=train, validation=validation, seed=seed)

    # Filled by train_hyper_opt with one record per trial
    trials_results = []

    def objective(trial):
        """Run one hyperparameter trial and return the values Optuna optimizes."""
        return train_hyper_opt(
            trial=trial,
            trials_results=trials_results,
            opt_metrics=opt_metrics,
            opt_parts=opt_parts,
            model_config_default=model_config_default,
            data_config_default=data_config,
            optimizer_config_default=optimizer_config,
            trainer_config_default=trainer_config,
            experiment_config_default=None,
            train=train,
            validation=validation,
            test=test,
            datamodule=datamodule,
            min_lr=lr_find_min_lr,
            max_lr=lr_find_max_lr,
            num_training=lr_find_num_training,
            mode=lr_find_mode,
            early_stop_threshold=lr_find_early_stop_threshold
        )

    sampler = optuna.samplers.TPESampler(
        n_startup_trials=n_startup_trials,
        n_ei_candidates=n_ei_candidates,
        seed=seed,
    )
    study = optuna.create_study(study_name=model_name, sampler=sampler, directions=opt_directions)
    study.optimize(func=objective, n_trials=n_trials, show_progress_bar=True)

    fn_trials = f"model({model_name})_trials({n_trials}_{seed}_{n_startup_trials}_{n_ei_candidates})_tst({tst_split_id})_val({val_fold_id})"

    df_trials = pd.DataFrame(trials_results)
    df_trials['split_id'] = tst_split_id
    df_trials['fold_id'] = val_fold_id
    # Flag trials whose train loss is above the test or the validation loss
    df_trials["train_more"] = (
        (df_trials["train_loss"] > df_trials["test_loss"])
        | (df_trials["train_loss"] > df_trials["validation_loss"])
    )
    df_trials["validation_test_mean_loss"] = (df_trials["validation_loss"] + df_trials["test_loss"]) / 2.0
    df_trials["train_validation_test_mean_loss"] = (df_trials["train_loss"] + df_trials["validation_loss"] + df_trials["test_loss"]) / 3.0
    df_trials = df_trials.sort_values(by='test_loss', ascending=True)
    trials_styled = df_trials.style.background_gradient(subset=gradient_cols, cmap="RdYlGn_r")
    trials_styled.to_excel(f"{trainer_config['checkpoints_path']}/{fn_trials}.xlsx")

    dfs_models.append(df_trials)

# Joint table of all trials with an empty 'Selected' column in front
df_models = pd.concat(dfs_models, ignore_index=True)
df_models.insert(0, 'Selected', 0)
fn = f"models_tst({tst_split_id})_val({val_fold_id})"
models_styled = df_models.style.background_gradient(subset=gradient_cols, cmap="RdYlGn_r")
models_styled.to_excel(f"{path}/pytorch_tabular/{fn}.xlsx")


# Perform tests on models

In [None]:
# Table with all trained models of the selected split and fold
fn_sweep = f"models_tst({tst_split_id})_val({val_fold_id})"
df_sweeps = pd.read_excel(f"{path}/pytorch_tabular/{fn_sweep}.xlsx", index_col=0)

# Result columns go to fixed positions: the total at 15, then one column per ICD-11 chapter
passed_cols = ['Passed\nICD-11\nTotal'] + [f'Passed\nICD-11\nChapter {chpt}' for chpt in icd_chpts]
for ins_pos, passed_col in enumerate(passed_cols, start=15):
    df_sweeps.insert(ins_pos, passed_col, None)

for model_id, _model_row in (pbar := tqdm(df_sweeps.iterrows())):
    pbar.set_description(f"Processing {model_id}")

    # The model directory is the checkpoint path without its file extension
    ckpt_path = pathlib.Path(df_sweeps.at[model_id, 'checkpoint'])
    ckpt_parent = str(ckpt_path.parent).replace('\\', '/')
    model_dir = f"{ckpt_parent}/{ckpt_path.stem}"
    model = TabularModel.load_model(model_dir)

    data_full['Prediction'] = model.predict(data_full)
    data_full['Error'] = data_full['Prediction'] - data_full['Age']

    # Number of passed group comparisons per ICD-11 chapter
    passed_icd = dict.fromkeys(icd_chpts, 0)
    for icd_chpt in icd_chpts:
        df_ckpt = df_groups[df_groups['ICD-11 chapter'] == icd_chpt]
        for _section_id, section_row in df_ckpt.iterrows():
            section_statuses = ast.literal_eval(section_row['Statuses'])
            section_groups = ast.literal_eval(section_row['Groups'])
            section_directions = ast.literal_eval(section_row['Directions'])
            in_section = (data_full['GSE'] == section_row['GSE']) & (data_full['Status'].isin(section_statuses))
            df_section = data_full.loc[in_section, ['Status', 'Error']]

            for section_group_id, section_group in enumerate(section_groups):
                errors_0 = df_section.loc[df_section["Status"] == section_group[0], "Error"]
                errors_1 = df_section.loc[df_section["Status"] == section_group[1], "Error"]
                pval = mannwhitneyu(errors_0.values, errors_1.values, alternative="two-sided").pvalue
                bias_0 = np.mean(errors_0)
                bias_1 = np.mean(errors_1)
                group_direction = section_directions[section_group_id]
                # A comparison passes when it is significant and the mean error
                # shifts in the direction listed for the pair
                shifts_as_listed = (
                    (group_direction == 'Increasing' and bias_1 > bias_0)
                    or (group_direction == 'Decreasing' and bias_1 < bias_0)
                )
                if pval < 0.05 and shifts_as_listed:
                    passed_icd[icd_chpt] += 1

        df_sweeps.at[model_id, f'Passed\nICD-11\nChapter {icd_chpt}'] = passed_icd[icd_chpt]

    df_sweeps.at[model_id, 'Passed\nICD-11\nTotal'] = sum(passed_icd.values())

df_sweeps.to_excel(f"{path}/pytorch_tabular/models.xlsx", index_label='Model ID')

# Best models processing

In [None]:
# Control samples per dataset (GSE): counts, sample IDs and one distinct color per GSE
gse_controls_count = data_full.loc[data_full['Status'] == 'Control', 'GSE'].value_counts()
gses_controls = gse_controls_count.index.values
gse_controls_ids = {
    gse: data_full.index[(data_full['Status'] == 'Control') & (data_full['GSE'] == gse)].values
    for gse in gses_controls
}
colors = distinctipy.get_colors(len(gses_controls), [mcolors.hex2color(mcolors.CSS4_COLORS['white']), mcolors.hex2color(mcolors.CSS4_COLORS['black'])], rng=1337)
colors_gse = dict(zip(gses_controls, colors))

# Status names get line breaks instead of spaces and hyphens (used as plot labels)
statuses_rename = {x: x.replace(' ', '\n').replace('-', '\n') for x in data_full['Status'].value_counts().index.values}
data_full['Status'] = data_full['Status'].replace(statuses_rename)
status_count = data_full['Status'].value_counts()
statuses = status_count.index.values

# 'Control' is always dodgerblue; every other status gets its own generated color.
# The colors are paired with the non-control statuses directly, so the mapping stays
# one-to-one wherever 'Control' appears in the frequency-sorted status list
# (indexing with `status_id - 1` reused a color when 'Control' was not the first status).
statuses_non_control = [status for status in statuses if status != 'Control']
colors = distinctipy.get_colors(len(statuses_non_control), [mcolors.hex2color(mcolors.CSS4_COLORS['dodgerblue']), mcolors.hex2color(mcolors.CSS4_COLORS['white']), mcolors.hex2color(mcolors.CSS4_COLORS['black'])], rng=1337)
colors_status = {'Control': mcolors.hex2color(mcolors.CSS4_COLORS['dodgerblue'])}
colors_status.update(zip(statuses_non_control, colors))

# Layout of the violin plots: one list of panel labels per 'Row on violin plot' value.
# A panel label (index of df_groups) is repeated once per status listed in 'Statuses'.
mosaic_violins = []
mosaic_rows = np.sort(df_groups['Row on violin plot'].unique())
for mosaic_row in mosaic_rows:
    df_mosaic_row = df_groups[df_groups['Row on violin plot'] == mosaic_row].sort_values(by=['ICD-11 chapter', 'ICD-11 code', 'GSE'], ascending=[True, True, True])
    mosaic_row_labels = []
    for plot_id, plot_row in df_mosaic_row.iterrows():
        n_violins = len(ast.literal_eval(plot_row['Statuses']))
        mosaic_row_labels += [plot_id] * n_violins
    mosaic_violins.append(mosaic_row_labels)

# Pad shorter rows with an 'Empty row <id>' label so that all rows have the same length
max_mosaic_row = max(len(x) for x in mosaic_violins)
violons_empty_panels = set()
for row_id, row in enumerate(mosaic_violins):
    n_missing = max_mosaic_row - len(row)
    if n_missing > 0:
        empty_label = f'Empty row {row_id}'
        violons_empty_panels.add(empty_label)
        row.extend([empty_label] * n_missing)


In [None]:

# Settings of the optional model explanation step
is_explain = False
explain_method = "GradientShap"
explain_baselines = "b|1000"
explain_n_feats_to_plot = 25

# Model type and the IDs (index values of models.xlsx) of the models to process;
# duplicates are dropped and the IDs are sorted
models_type = 'FTTransformer'
models_ids = sorted({
    39,
    # 15,
    # 43,
    # 20,
    # 2,
    # 25,
    # 12,
    # 0,
    # 94,
})

df_sweeps = pd.read_excel(f"{path}/pytorch_tabular/{models_type}/models.xlsx", index_col=0)
path_models = f"{path}/pytorch_tabular/{models_type}/candidates"
pathlib.Path(path_models).mkdir(parents=True, exist_ok=True)

# Save the rows of the selected models with highlighted losses and timings
selected_styled = df_sweeps.loc[models_ids, :].style.background_gradient(
    subset=["train_loss", "validation_loss", "test_loss", "time_taken", "time_taken_per_epoch"],
    cmap="RdYlGn_r",
)
selected_styled.to_excel(f"{path_models}/selected.xlsx")

# Per-model metrics and the comparison with the values stored in models.xlsx
df_models_metrics = pd.DataFrame(index=models_ids)
df_models_check = pd.DataFrame(index=models_ids)

# One table per "<metric> <part>" pair: rows are GSEs with controls, columns are
# the control count followed by the model IDs
df_gses_models_metrics = {}
for metric_name, part_name in itertools.product(['MAE', 'Rho', 'Bias'], ['Train', 'Validation', 'Test', 'Total']):
    md = f"{metric_name} {part_name}"
    df_md = pd.DataFrame(index=gses_controls, columns=['Count'] + list(models_ids))
    df_md.loc[gses_controls, 'Count'] = gse_controls_count[gses_controls]
    df_gses_models_metrics[md] = df_md

def label(x, color, label):
    """Write the facet label inside the current axes (callback for ``FacetGrid.map``)."""
    ax = plt.gca()
    ax.text(-0.15, 0.2, label, size=4, fontweight="light", color=color, ha="left", va="center", transform=ax.transAxes, path_effects=[pe.withStroke(linewidth=0.5, foreground="black")])


# For every selected model: copy its checkpoint to the candidates folder, recompute the
# metrics on each data part, and save the violin plots, the ridgeline plots and,
# when `is_explain` is set, the feature attributions.
for model_id in models_ids:
    print(model_id)

    # Sample IDs of the split and fold stored for this model
    split_id = df_sweeps.at[model_id, 'split_id']
    fold_id = df_sweeps.at[model_id, 'fold_id']
    split_dict = samples[split_id]
    ids_test = split_dict['test']
    ids_train = split_dict['trains'][fold_id]
    ids_validation = split_dict['validations'][fold_id]
    ids_total = np.concatenate([ids_train, ids_validation, ids_test])
    ids_dict = {
        'Test': ids_test,
        'Train': ids_train,
        'Validation': ids_validation,
        'Total': ids_total
    }

    # The model directory is <checkpoint parent>/<models_type>/<checkpoint file name without extension>
    model_dir = str(pathlib.Path(df_sweeps.at[model_id, 'checkpoint']).parent).replace('\\', '/') + f'/{models_type}/' + pathlib.Path(df_sweeps.at[model_id, 'checkpoint']).stem
    model = TabularModel.load_model(model_dir)
    pathlib.Path(f"{path_models}/{model_id}").mkdir(parents=True, exist_ok=True)
    shutil.copytree(model_dir, f"{path_models}/{model_id}", dirs_exist_ok=True)

    # Predictions and errors for all samples; 'Group' marks the part of each sample
    data_full['Group'] = ''
    data_full.loc[ids_train, 'Group'] = 'Train'
    data_full.loc[ids_validation, 'Group'] = 'Validation'
    data_full.loc[ids_test, 'Group'] = 'Test'
    data_full['Prediction'] = model.predict(data_full)
    data_full['Error'] = data_full['Prediction'] - data_full['Age']
    data_full[['GPL', 'GSE', 'Age', 'Sex', 'Status', 'Group', 'Prediction', 'Error']].to_excel(f"{path_models}/{model_id}/data.xlsx")

    # MAE, Pearson correlation and mean error (bias) on every part
    df_metrics = pd.DataFrame(
        index=list(ids_dict.keys()),
        columns=['MAE', 'Rho', 'Bias']
    )

    for part, ids_part in ids_dict.items():

        pred = torch.from_numpy(data_full.loc[ids_part, 'Prediction'].values)
        real = torch.from_numpy(data_full.loc[ids_part, 'Age'].values)

        df_metrics.at[part, 'MAE'] = mean_absolute_error(pred, real).numpy()
        df_metrics.at[part, 'Rho'] = pearson_corrcoef(pred, real).numpy()
        df_metrics.at[part, 'Bias'] = np.mean(data_full.loc[ids_part, 'Error'].values)

        df_models_metrics.at[model_id, f'MAE\n{part}'] = df_metrics.at[part, 'MAE']
        df_models_metrics.at[model_id, f'Rho\n{part}'] = df_metrics.at[part, 'Rho']
        df_models_metrics.at[model_id, f'Bias\n{part}'] = df_metrics.at[part, 'Bias']

        # Keep the metrics stored in models.xlsx next to the recomputed ones
        if part != 'Total':
            df_models_check.at[model_id, f'{part} MAE Before'] = df_sweeps.at[model_id, f'{part.lower()}_mean_absolute_error']
            df_models_check.at[model_id, f'{part} Rho Before'] = df_sweeps.at[model_id, f'{part.lower()}_pearson_corrcoef']
            df_models_check.at[model_id, f'{part} MAE After'] = df_metrics.at[part, 'MAE']
            df_models_check.at[model_id, f'{part} Rho After'] = df_metrics.at[part, 'Rho']

    df_metrics.to_excel(f"{path_models}/{model_id}/metrics.xlsx", index_label="Metrics")

    # The same metrics for the control samples of every GSE within every part
    for gse, ids_gse in gse_controls_ids.items():
        for part, ids_part in ids_dict.items():
            ids_intxn = list(set.intersection(set(ids_gse), set(ids_part)))
            if len(ids_intxn) > 0:
                pred = torch.from_numpy(data_full.loc[ids_intxn, 'Prediction'].values)
                real = torch.from_numpy(data_full.loc[ids_intxn, 'Age'].values)
                df_gses_models_metrics[f'MAE {part}'].at[gse, model_id] = mean_absolute_error(pred, real).numpy()
                df_gses_models_metrics[f'Rho {part}'].at[gse, model_id] = pearson_corrcoef(pred, real).numpy()
                df_gses_models_metrics[f'Bias {part}'].at[gse, model_id] = np.mean(data_full.loc[ids_intxn, 'Error'].values)

    # Mosaic violins plot
    sns.set_theme(style='ticks')
    fig_height = 5 * len(mosaic_violins)
    fig_width = 1.4 * max_mosaic_row
    fig, axs = plt.subplot_mosaic(mosaic=mosaic_violins, figsize=(fig_width, fig_height), gridspec_kw={}, sharey=False, sharex=False)

    for plot_id, plot_row in df_groups.iterrows():

        plot_statuses = [statuses_rename[x] for x in ast.literal_eval(plot_row['Statuses'])]
        plot_groups = [(statuses_rename[x[0]], statuses_rename[x[1]]) for x in ast.literal_eval(plot_row['Groups'])]

        # Explicit copy: the 'Status' column of this selection is overwritten below
        df_plot = data_full.loc[(data_full['GSE'] == plot_row['GSE']) & (data_full['Status'].isin(plot_statuses)), ['Status', 'Error']].copy()
        plot_status_count = df_plot['Status'].value_counts()
        # Tick labels carry the sample count and the mean error of each status
        plot_statuses_rename = {}
        for x in plot_statuses:
            plot_statuses_rename[x] = x + f"\nCount: {plot_status_count[x]}\nBias: {np.mean(df_plot.loc[df_plot['Status'] == x, 'Error']):0.1f}"
        plot_statuses = [plot_statuses_rename[x] for x in plot_statuses]
        plot_groups = [(plot_statuses_rename[x[0]], plot_statuses_rename[x[1]]) for x in plot_groups]
        df_plot['Status'] = df_plot['Status'].replace(plot_statuses_rename)
        colors_plot_status = {plot_statuses_rename[x]: colors_status[x] for x in plot_statuses_rename}

        # Mann-Whitney U test p-value for every compared pair of statuses
        pval_formatted = []
        for plot_group in plot_groups:
            stat, pval = mannwhitneyu(
                df_plot.loc[df_plot["Status"] == plot_group[0], "Error"].values,
                df_plot.loc[df_plot["Status"] == plot_group[1], "Error"].values,
                alternative="two-sided",
            )
            pval_formatted.append(f"{pval:.1e}")

        violinplot = sns.violinplot(
            data=df_plot,
            x='Status',
            y='Error',
            palette=colors_plot_status,
            scale='width',
            order=plot_statuses,
            saturation=0.75,
            legend=False,
            ax=axs[plot_id]
        )
        annotator = Annotator(
            ax=axs[plot_id],
            pairs=plot_groups,
            data=df_plot,
            x="Status",
            y="Error",
            order=plot_statuses,
        )
        annotator.set_custom_annotations(pval_formatted)
        annotator.configure(loc='inside', verbose=0)
        annotator.annotate()

        axs[plot_id].set_xlabel('')
        axs[plot_id].set_title(f"{plot_row['GSE']} ({df_plot.shape[0]})")
        axs[plot_id].set_facecolor(colors_icd_chpts[plot_row['ICD-11 chapter']])

    for empty_panel in violons_empty_panels:
        axs[empty_panel].axis('off')
    fig.tight_layout()
    fig.savefig(f"{path_models}/{model_id}/violins_Status.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_models}/{model_id}/violins_Status.pdf", bbox_inches='tight')
    plt.close(fig)

    # Ridgeline (column of violins) for controls' error
    for part, ids_part in ids_dict.items():
        # Errors in GSEs
        df_fig = data_full.loc[ids_part, ['Error', 'GSE']].copy()
        df_fig['GSE'] = pd.Categorical(df_fig.GSE, categories=gses_controls, ordered=True)
        df_fig = df_fig.sort_values('GSE')
        gses_rename = {
            gse: f"{gse} ({gse_controls_count[gse]})" + "\n" +
                   fr"MAE: {df_gses_models_metrics[f'MAE {part}'].at[gse, model_id]:0.2f}" + "\n" +
                   fr"Pearson $\rho$: {df_gses_models_metrics[f'Rho {part}'].at[gse, model_id]:0.2f}" + "\n" +
                   fr"Bias: {df_gses_models_metrics[f'Bias {part}'].at[gse, model_id]:0.2f}" + "\n" +
                   f"{gse_preproc.at[gse, 'Preproc']}"
            for gse in colors_gse
        }
        gse_colors_grid = {gses_rename[gse]: colors_gse[gse] for gse in colors_gse}
        # Rename the categories through the `.cat` accessor and assign the result back:
        # a chained `replace(..., inplace=True)` on the column is deprecated for
        # categorical data and does not update the frame under copy-on-write
        df_fig['GSE'] = df_fig['GSE'].cat.rename_categories(gses_rename)
        sns.set_theme(style="whitegrid", rc={"axes.facecolor": (0, 0, 0, 0)})
        g = sns.FacetGrid(df_fig, row="GSE", hue="GSE", aspect=8, height=0.5, palette=gse_colors_grid)
        g.map(
            sns.kdeplot,
            'Error',
            fill=True,
            alpha=1.0,
            linewidth=0.5
        )
        g.map(
            sns.kdeplot,
            'Error',
            color="black",
            linewidth=1.0,
        )
        # g.refline(y=0, linewidth=2.0, linestyle="-", color=None, clip_on=False)
        g.map(label, 'Error')
        # Set the subplots to overlap
        g.figure.subplots_adjust(hspace=0.0)
        # Remove axes details that don't play well with overlap
        g.set_titles("")
        g.set(yticks=[], ylabel="")
        g.despine(bottom=True, left=True)
        # Save
        g.savefig(f"{path_models}/{model_id}/ridgeline_GSEs_error_{part}.png", bbox_inches='tight', dpi=200)
        g.savefig(f"{path_models}/{model_id}/ridgeline_GSEs_error_{part}.pdf", bbox_inches='tight')
        plt.close(g.figure)

    if is_explain:
        try:
            # Rows are selected with `.loc`: plain `data_full[ids_total]` would look
            # the sample IDs up among the columns
            data_total = data_full.loc[ids_total]
            explanation = model.explain(data_total, method=explain_method, baselines=explain_baselines)
            explanation.index = data_total.index
            explanation.to_excel(f"{path_models}/{model_id}/explanation.xlsx")

            # NOTE(review): the value returned by shap.summary_plot is passed to plt.close;
            # confirm it is a figure or None for the installed shap version
            sns.set_theme(style='whitegrid')
            fig = shap.summary_plot(
                shap_values=explanation.loc[:, feats].values,
                features=data_full.loc[ids_total, feats].values,
                feature_names=feats,
                max_display=explain_n_feats_to_plot,
                plot_type="violin",
                show=False,
            )
            plt.savefig(f"{path_models}/{model_id}/explain_beeswarm.png", bbox_inches='tight', dpi=200)
            plt.savefig(f"{path_models}/{model_id}/explain_beeswarm.pdf", bbox_inches='tight')
            plt.close(fig)

            sns.set_theme(style='whitegrid')
            fig = shap.summary_plot(
                shap_values=explanation.loc[:, feats].values,
                features=data_full.loc[ids_total, feats].values,
                feature_names=feats,
                max_display=explain_n_feats_to_plot,
                plot_type="bar",
                show=False,
            )
            plt.savefig(f"{path_models}/{model_id}/explain_bar.png", bbox_inches='tight', dpi=200)
            plt.savefig(f"{path_models}/{model_id}/explain_bar.pdf", bbox_inches='tight')
            plt.close(fig)

        except NotImplementedError as err:
            # Still best-effort, but a skipped explanation is now reported
            print(f"Explanation skipped for model {model_id}: {err}")

# Metrics of all processed models in one table
df_models_metrics.to_excel(f"{path_models}/models_metrics.xlsx", index_label='Model ID')

# Difference between the recomputed ('After') and the stored ('Before') metrics
for group in ('Train', 'Validation', 'Test'):
    for check_metric in ('MAE', 'Rho'):
        df_models_check[f'{group} {check_metric} Diff'] = (
            df_models_check[f'{group} {check_metric} After'] - df_models_check[f'{group} {check_metric} Before']
        )
df_models_check.to_excel(f"{path_models}/models_check.xlsx", index_label='Model ID')

# Per-GSE metrics: one sheet per "<metric> <part>" pair, with the 'Preproc' value
# of each GSE inserted as the second column
with pd.ExcelWriter(f"{path_models}/gses_models_metrics.xlsx", engine='xlsxwriter') as writer:
    for metric_name, part_name in itertools.product(['MAE', 'Rho', 'Bias'], ['Train', 'Validation', 'Test', 'Total']):
        md = f"{metric_name} {part_name}"
        df_md = df_gses_models_metrics[md]
        df_md.insert(1, 'Preproc', gse_preproc.loc[df_md.index, 'Preproc'])
        df_md.to_excel(writer, sheet_name=md)