# Debugging autoreload

In [None]:
# Reload edited modules automatically before each cell runs, so changes to the
# project's own `src.*` helpers are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Load packages

In [1]:
from pytorch_tabular.utils import load_covertype_dataset
from rich.pretty import pprint
from sklearn.model_selection import BaseCrossValidator, ParameterGrid, ParameterSampler
import torch
import pickle
import shutil
import shap
from sklearn.model_selection import RepeatedStratifiedKFold
from glob import glob
import ast
import matplotlib.pyplot as plt
import seaborn as sns
import copy
from sklearn.model_selection import train_test_split
import numpy as np
from pytorch_tabular.utils import make_mixed_dataset, print_metrics
from pytorch_tabular import available_models
from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig, GANDALFConfig, TabNetModelConfig, FTTransformerConfig, DANetConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig
from pytorch_tabular.tabular_model_tuner import TabularModelTuner
from torchmetrics.functional.regression import mean_absolute_error, pearson_corrcoef
from pytorch_tabular import MODEL_SWEEP_PRESETS
import pandas as pd
from pytorch_tabular import model_sweep
from src.pt.model_sweep import model_sweep_custom
import warnings
from src.utils.configs import read_parse_config
from src.utils.hash import dict_hash
from src.pt.hyper_opt import train_hyper_opt
import pathlib
from tqdm import tqdm
import distinctipy
import matplotlib.patheffects as pe
import matplotlib.colors as mcolors
from statannotations.Annotator import Annotator
from scipy.stats import mannwhitneyu
from regression_bias_corrector import LinearBiasCorrector
import optuna
from sklearn.preprocessing import LabelEncoder
from plottable import ColumnDefinition, Table
from plottable.plots import bar
from plottable.cmap import normed_cmap, centered_cmap
import matplotlib.lines as mlines
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import scipy.stats


def make_rgb_transparent(rgb, bg_rgb, alpha):
    """Blend a foreground color over a background color.

    Args:
        rgb: Foreground color channels (any iterable of numbers).
        bg_rgb: Background color channels, paired position-wise with ``rgb``.
        alpha: Weight of the foreground; the background gets ``1 - alpha``.

    Returns:
        list: One blended value per channel pair. Pairing follows ``zip``,
        so the shorter input determines the length of the result.
    """
    blended = []
    for fg_channel, bg_channel in zip(rgb, bg_rgb):
        blended.append(alpha * fg_channel + (1 - alpha) * bg_channel)
    return blended

# Load data

In [2]:
# Feature subset to model; it selects both the data folder and the checkpoint folder.
feats_set = 'RheumatologyScreening_no-sex'

# NOTE: machine-specific absolute locations - adjust them when running elsewhere.
path = f"E:/YandexDisk/Work/bbd/atlas/subset_{feats_set}"
path_ckpts = f"E:/Git/bbs/notebooks/atlas/pt/{feats_set}"
path_configs = "E:/Git/bbs/notebooks/atlas/configs"

# Parameters of the outer (test) splits; they are part of the stratification file
# name, and `tst_split_id` picks the split that is used.
tst_n_splits, tst_n_repeats, tst_random_state = 5, 5, 1337
tst_split_id = 20

# Same for the inner (validation) folds; `val_fold_id` picks the fold that is used.
val_n_splits, val_n_repeats, val_random_state = 4, 4, 1337
val_fold_id = 5

data = pd.read_excel(f"{path}/data.xlsx", index_col=0)
df_feats = pd.read_excel(f"{path}/feats.xlsx", index_col=0)

# Target column; it is listed among the continuous features and has to be excluded.
feat_trgt = 'Возраст'
is_continuous = df_feats['Type'] == 'continuous'
is_categorical = df_feats['Type'] == 'categorical'
feats_cnt = df_feats.index[is_continuous].drop(feat_trgt).to_list()
feats_cat = df_feats.index[is_categorical].to_list()
feats = feats_cnt + feats_cat

In [None]:
# Translate feature names from Russian (the index of `df_feats`) to the names in
# its 'English' column, and rename the matching columns of `data`.
feats_ru_2_en = dict(zip(df_feats.index, df_feats['English']))
feats_en_2_ru = dict(zip(df_feats['English'], df_feats.index))


def _feat_to_en(name):
    """Return the English name of a feature.

    A name that is already English is returned unchanged, which keeps this cell
    safe to run more than once (a plain ``feats_ru_2_en[name]`` lookup raises
    ``KeyError`` on the second run because the lists are already translated).

    Raises:
        KeyError: If ``name`` is neither a known Russian nor a known English name.
    """
    if name in feats_ru_2_en:
        return feats_ru_2_en[name]
    if name in feats_en_2_ru:
        return name
    raise KeyError(name)


feats_cnt = [_feat_to_en(x) for x in feats_cnt]
feats_cat = [_feat_to_en(x) for x in feats_cat]
feats = [_feat_to_en(x) for x in feats]
feat_trgt = _feat_to_en(feat_trgt)
# Rebind instead of renaming in place; columns that are already English are left as is.
data = data.rename(columns=feats_ru_2_en)

# Load stratification

In [3]:
# The stratification file name encodes the test/validation split parameters.
# NOTE: unpickling can execute code stored in the file - only load files you created.
fn_samples = (
    f"{path}/samples_"
    f"tst({tst_random_state}_{tst_n_splits}_{tst_n_repeats})_"
    f"val({val_random_state}_{val_n_splits}_{val_n_repeats})"
    f".pickle"
)
with open(fn_samples, 'rb') as handle:
    samples = pickle.load(handle)

# Train, Validation, Test selection

In [4]:
# Pick one outer test split, then one inner train/validation fold inside it.
split_dict = samples[tst_split_id]
cols_model = feats + [feat_trgt]  # model inputs followed by the target

test = data.loc[split_dict['test'], cols_model]

train_ids = split_dict['trains'][val_fold_id]
train = data.loc[train_ids, cols_model]

validation_ids = split_dict['validations'][val_fold_id]
validation = data.loc[validation_ids, cols_model]

# Optuna training

## Models setup

In [5]:
# Seed used by every run below.
seed_target = 1337  # 1337 42 451 1984 1899 1408

# One entry per model family to tune: its config class, the number of Optuna
# trials, the seed and the TPE sampler settings. Commented-out entries are model
# families that are currently disabled.
models_runs = {
    # 'GANDALF': {
    #     'config': GANDALFConfig,
    #     'n_trials': 512,
    #     'seed': seed_target,
    #     'n_startup_trials': 128,
    #     'n_ei_candidates': 16
    # },
    # 'FTTransformer': {
    #     'config': FTTransformerConfig,
    #     'n_trials': 512,
    #     'seed': seed_target,
    #     'n_startup_trials': 128,
    #     'n_ei_candidates': 16
    # },
    'DANet': dict(
        config=DANetConfig,
        n_trials=1024,
        seed=seed_target,
        n_startup_trials=256,
        n_ei_candidates=16,
    ),
    # 'CategoryEmbeddingModel': {
    #     'config': CategoryEmbeddingModelConfig,
    #     'n_trials': 256,
    #     'seed': seed_target,
    #     'n_startup_trials': 64,
    #     'n_ei_candidates': 16
    # },
    # 'TabNetModel': {
    #     'config': TabNetModelConfig,
    #     'n_trials': 256,
    #     'seed': seed_target,
    #     'n_startup_trials': 64,
    #     'n_ei_candidates': 16
    # }
}

## Training

In [None]:
%%capture
# Hyper-parameter search: for every entry of `models_runs`, run an Optuna study
# and save the table of its trials next to the checkpoints.
# `%%capture` (which must stay on the first line) hides the output of this cell.

dfs_models = []

for model_name, model_run in models_runs.items():

    model_config_name = model_run['config']
    n_trials = model_run['n_trials']
    seed = model_run['seed']
    n_startup_trials = model_run['n_startup_trials']
    n_ei_candidates = model_run['n_ei_candidates']

    # Default configs are read from YAML files and then adjusted for this data set.
    data_config = read_parse_config(f"{path_configs}/DataConfig.yaml", DataConfig)
    data_config['target'] = [feat_trgt]
    data_config['continuous_cols'] = [str(x) for x in feats_cnt]
    data_config['categorical_cols'] = feats_cat
    trainer_config = read_parse_config(f"{path_configs}/TrainerConfig.yaml", TrainerConfig)
    trainer_config['checkpoints_path'] = path_ckpts
    optimizer_config = read_parse_config(f"{path_configs}/OptimizerConfig.yaml", OptimizerConfig)

    # Learning-rate finder settings; they are only forwarded to `train_hyper_opt` below.
    lr_find_min_lr = 1e-8
    lr_find_max_lr = 10
    lr_find_num_training = 256
    lr_find_mode = "exponential"
    lr_find_early_stop_threshold = 8.0

    trainer_config['seed'] = seed
    trainer_config['checkpoints'] = 'valid_loss'
    trainer_config['load_best'] = True
    trainer_config['auto_lr_find'] = False

    # A model with the default config is built only to prepare the datamodule,
    # which is then passed to every trial.
    model_config_default = read_parse_config(f"{path_configs}/models/{model_name}Config.yaml", model_config_name)
    tabular_model_default = TabularModel(
        data_config=data_config,
        model_config=model_config_default,
        optimizer_config=optimizer_config,
        trainer_config=trainer_config,
        verbose=False,
    )
    datamodule = tabular_model_default.prepare_dataloader(train=train, validation=validation, seed=seed)

    # One optimisation direction per (part, metric) pair; more than one direction
    # makes the study multi-objective.
    # NOTE(review): 'test' is listed among the optimised parts, so the search
    # presumably uses test-set metrics as an objective - confirm this is intended.
    opt_parts = ['test', 'validation']
    opt_metrics = [('mean_absolute_error', 'minimize')]
    # opt_metrics = [('mean_absolute_error', 'minimize'), ('pearson_corrcoef', 'maximize')]
    # opt_metrics = [('pearson_corrcoef', 'maximize')]
    opt_directions = []
    for part in opt_parts:
        for metric_pair in opt_metrics:
            opt_directions.append(f"{metric_pair[1]}")

    # Passed to `train_hyper_opt`, which presumably appends one record per trial
    # (verify in src.pt.hyper_opt); the list becomes `df_trials` after the study.
    trials_results = []

    study = optuna.create_study(
        study_name=model_name,
        sampler=optuna.samplers.TPESampler(
            n_startup_trials=n_startup_trials,
            n_ei_candidates=n_ei_candidates,
            seed=seed,
        ),
        directions=opt_directions
    )
    study.optimize(
        func=lambda trial: train_hyper_opt(
            trial=trial,
            trials_results=trials_results,
            opt_metrics=opt_metrics,
            opt_parts=opt_parts,
            model_config_default=model_config_default,
            data_config_default=data_config,
            optimizer_config_default=optimizer_config,
            trainer_config_default=trainer_config,
            experiment_config_default=None,
            train=train,
            validation=validation,
            test=test,
            datamodule=datamodule,
            min_lr=lr_find_min_lr,
            max_lr=lr_find_max_lr,
            num_training=lr_find_num_training,
            mode=lr_find_mode,
            early_stop_threshold=lr_find_early_stop_threshold
        ),
        n_trials=n_trials,
        show_progress_bar=True
    )

    # File name encodes the model, the search budget/seed and the split that was used.
    fn_trials = (
        f"model({model_name})_"
        f"trials({n_trials}_{seed}_{n_startup_trials}_{n_ei_candidates})_"
        f"tst({tst_split_id})_"
        f"val({val_fold_id})"
    )

    df_trials = pd.DataFrame(trials_results)
    df_trials['split_id'] = tst_split_id
    df_trials['fold_id'] = val_fold_id
    # Flag trials whose train loss is still above the test or the validation loss.
    df_trials["train_more"] = False
    df_trials.loc[(df_trials["train_loss"] > df_trials["test_loss"]) | (
            df_trials["train_loss"] > df_trials["validation_loss"]), "train_more"] = True
    df_trials["validation_test_mean_loss"] = (df_trials["validation_loss"] + df_trials["test_loss"]) / 2.0
    df_trials["train_validation_test_mean_loss"] = (df_trials["train_loss"] + df_trials["validation_loss"] + df_trials["test_loss"]) / 3.0
    # Export with a color gradient on the loss and timing columns.
    df_trials.style.background_gradient(
        subset=[
            "train_loss",
            "validation_loss",
            "validation_test_mean_loss",
            "train_validation_test_mean_loss",
            "test_loss",
            "time_taken",
            "time_taken_per_epoch"
        ], cmap="RdYlGn_r"
    ).to_excel(f"{trainer_config['checkpoints_path']}/{fn_trials}.xlsx")

    dfs_models.append(df_trials)

# Merge the trial tables of all models, add a 'Selected' column (filled with 0)
# in front, and export the result with the same color gradient as the per-model files.
df_models = pd.concat(dfs_models, ignore_index=True)
df_models.insert(0, 'Selected', 0)
fn = f"models_tst({tst_split_id})_val({val_fold_id})"
cols_gradient_models = [
    "train_loss",
    "validation_loss",
    "validation_test_mean_loss",
    "train_validation_test_mean_loss",
    "test_loss",
    "time_taken",
    "time_taken_per_epoch",
]
styled_models = df_models.style.background_gradient(subset=cols_gradient_models, cmap="RdYlGn_r")
styled_models.to_excel(f"{trainer_config['checkpoints_path']}/{fn}.xlsx")


## Best models analysis

In [None]:
# Settings of the feature-attribution step that is run for every candidate model.
explain_method = "GradientShap"
explain_baselines = "b|1000"
explain_n_feats_to_plot = 25

# Trials (rows of the trials table) chosen for a closer look; duplicates are
# dropped and the ids are processed in ascending order.
models_type = 'DANet' # 'GANDALF'
models_ids = sorted(set([208, 48, 1015, 348]))

# The trials table is located through the same naming scheme that the training
# cell uses when it writes the file.
run_cfg = models_runs[models_type]
fn_sweeps = (
    f"model({models_type})_"
    f"trials({run_cfg['n_trials']}_{run_cfg['seed']}_{run_cfg['n_startup_trials']}_{run_cfg['n_ei_candidates']})_"
    f"tst({tst_split_id})_"
    f"val({val_fold_id})"
)
df_sweeps = pd.read_excel(f"{path_ckpts}/{fn_sweeps}.xlsx", index_col=0)
# df_sweeps = pd.read_excel(
#     (
#         f"{path_ckpts}/"
#         f"progress"
#         f".xlsx"
#     ),
#     index_col=0
# )

# Output folder for the candidates, plus a color-graded table of the selected trials.
path_to_candidates = f"{path_ckpts}/candidates/{models_type}"
pathlib.Path(path_to_candidates).mkdir(parents=True, exist_ok=True)
cols_gradient_selected = [
    "train_loss",
    "validation_loss",
    "test_loss",
    "time_taken",
    "time_taken_per_epoch",
]
styled_selected = df_sweeps.loc[models_ids, :].style.background_gradient(
    subset=cols_gradient_selected, cmap="RdYlGn_r"
)
styled_selected.to_excel(f"{path_to_candidates}/selected.xlsx")

# For every selected trial: reload the saved model, predict on the whole data
# set, compute metrics per group and save tables and figures to the candidate folder.
for model_id in models_ids:

    # Saved-model folder: the checkpoint's parent directory plus the checkpoint
    # file name without its extension (with forward slashes).
    model_dir = str(pathlib.Path(df_sweeps.at[model_id, 'checkpoint']).parent).replace('\\', '/') + '/' + pathlib.Path(df_sweeps.at[model_id, 'checkpoint']).stem
    model = TabularModel.load_model(model_dir)
    # Keep a copy of the model files in the candidate folder.
    pathlib.Path(f"{path_to_candidates}/{model_id}").mkdir(parents=True, exist_ok=True)
    shutil.copytree(model_dir, f"{path_to_candidates}/{model_id}", dirs_exist_ok=True)
    
    # Per-sample table: target, group label, prediction and error (prediction - target).
    df = data.loc[:, [feat_trgt]]
    df.loc[train.index, 'Group'] = 'Train'
    df.loc[validation.index, 'Group'] = 'Validation'
    df.loc[test.index, 'Group'] = 'Test'
    df['Prediction'] = model.predict(data)
    df['Error'] = df['Prediction'] - df[feat_trgt]
    # The corrector is fitted on the Train rows only and then applied to all rows.
    # NOTE(review): presumably removes the linear dependence of the error on the
    # target - confirm against regression_bias_corrector.
    corrector = LinearBiasCorrector()
    corrector.fit(df.loc[df['Group'] == 'Train', feat_trgt].values, df.loc[df['Group'] == 'Train', 'Prediction'].values)
    df['Prediction Unbiased'] = corrector.predict(df['Prediction'].values)
    df['Error Unbiased'] = df['Prediction Unbiased'] - df[feat_trgt]
    df.to_excel(f"{path_to_candidates}/{model_id}/df.xlsx")
    
    # Plot color of each group; the key order also fixes the group order in tables and plots.
    colors_groups = {
        'Train': 'chartreuse',
        'Validation': 'dodgerblue',
        'Test': 'crimson',
    }
    
    # MAE, Pearson correlation and mean error per group, for the raw and the
    # corrected ('_unbiased') predictions.
    df_metrics = pd.DataFrame(
        index=list(colors_groups.keys()),
        columns=[
            'mean_absolute_error', 'pearson_corrcoef', 'bias',
            'mean_absolute_error_unbiased', 'pearson_corrcoef_unbiased', 'bias_unbiased'
        ]
    )
    for group in colors_groups.keys():
        pred = torch.from_numpy(df.loc[df['Group'] == group, 'Prediction'].values)
        pred_unbiased = torch.from_numpy(df.loc[df['Group'] == group, 'Prediction Unbiased'].values)
        real = torch.from_numpy(df.loc[df['Group'] == group, feat_trgt].values.astype(np.float32))
        df_metrics.at[group, 'mean_absolute_error'] = mean_absolute_error(pred, real).numpy()
        df_metrics.at[group, 'pearson_corrcoef'] = pearson_corrcoef(pred, real).numpy()
        df_metrics.at[group, 'bias'] = np.mean(df.loc[df['Group'] == group, 'Error'].values)
        df_metrics.at[group, 'mean_absolute_error_unbiased'] = mean_absolute_error(pred_unbiased, real).numpy()
        df_metrics.at[group, 'pearson_corrcoef_unbiased'] = pearson_corrcoef(pred_unbiased, real).numpy()
        df_metrics.at[group, 'bias_unbiased'] = np.mean(df.loc[df['Group'] == group, 'Error Unbiased'].values)
    df_metrics.to_excel(f"{path_to_candidates}/{model_id}/metrics.xlsx", index_label="Metrics")
    
    # Common value range of target and prediction; used for square plots with equal axes.
    xy_min = df[[feat_trgt, 'Prediction']].min().min()
    xy_max = df[[feat_trgt, 'Prediction']].max().max()
    xy_ptp = xy_max - xy_min
    
    # Same range for the corrected predictions.
    xy_min_unbiased = df[[feat_trgt, 'Prediction Unbiased']].min().min()
    xy_max_unbiased = df[[feat_trgt, 'Prediction Unbiased']].max().max()
    xy_ptp_unbiased = xy_max_unbiased - xy_min_unbiased
    
    # Prediction vs. target with one regression fit per group.
    sns.set_theme(style='whitegrid')
    fig, ax = plt.subplots(figsize=(4.5, 4))
    for group in colors_groups.keys():    
        regplot = sns.regplot(
            data=df.loc[df['Group'] == group, :],
            x=feat_trgt,
            y="Prediction",
            label=group,
            color=colors_groups[group],
            scatter_kws=dict(
                linewidth=0.2,
                alpha=0.75,
                edgecolor="k",
                s=20,
            ),
            ax=ax
        )
    # Dashed diagonal y = x, drawn over the data range padded by 10% on each side.
    bisect = sns.lineplot(
        x=[xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp],
        y=[xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp],
        linestyle='--',
        color='black',
        linewidth=1.0,
        ax=ax
    )
    ax.set_title(f"{df_sweeps.at[model_id, 'model']} ({df_sweeps.at[model_id, '# Params']} params)")
    ax.set_xlim(xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp)
    ax.set_ylim(xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp)
    plt.gca().set_aspect('equal', adjustable='box')
    fig.savefig(f"{path_to_candidates}/{model_id}/regplot.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_to_candidates}/{model_id}/regplot.pdf", bbox_inches='tight')
    plt.close(fig)
    
    # Corrected prediction ('Prediction Unbiased') vs. target, with one regression
    # fit and point cloud per group.
    sns.set_theme(style='whitegrid')
    fig, ax = plt.subplots(figsize=(4.5, 4))
    # Axis range: value range of target and corrected prediction, padded by 10% of
    # its span on each side. Both ends use the span of the *corrected* data; the
    # lower end used to take `xy_ptp` (span of the raw predictions), which made
    # the padding asymmetric and the diagonal end off the axis limits.
    lim_unbiased = (
        xy_min_unbiased - 0.1 * xy_ptp_unbiased,
        xy_max_unbiased + 0.1 * xy_ptp_unbiased,
    )
    for group in colors_groups.keys():
        regplot = sns.regplot(
            data=df.loc[df['Group'] == group, :],
            x=feat_trgt,
            y="Prediction Unbiased",
            label=group,
            color=colors_groups[group],
            scatter_kws=dict(
                linewidth=0.2,
                alpha=0.75,
                edgecolor="k",
                s=20,
            ),
            ax=ax
        )
    # Dashed diagonal y = x across the whole axis range.
    bisect = sns.lineplot(
        x=list(lim_unbiased),
        y=list(lim_unbiased),
        linestyle='--',
        color='black',
        linewidth=1.0,
        ax=ax
    )
    ax.set_title(f"{df_sweeps.at[model_id, 'model']} ({df_sweeps.at[model_id, '# Params']} params)")
    ax.set_xlim(*lim_unbiased)
    ax.set_ylim(*lim_unbiased)
    plt.gca().set_aspect('equal', adjustable='box')
    fig.savefig(f"{path_to_candidates}/{model_id}/regplot_unbiased.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_to_candidates}/{model_id}/regplot_unbiased.pdf", bbox_inches='tight')
    plt.close(fig)
    
    # Prediction vs. target as a plain scatter plot, colored by group (no regression fits).
    sns.set_theme(style='whitegrid')
    fig, ax = plt.subplots(figsize=(4.5, 4))   
    scatter = sns.scatterplot(
        data=df,
        x=feat_trgt,
        y="Prediction",
        hue="Group",
        palette=colors_groups,
        linewidth=0.2,
        alpha=0.75,
        edgecolor="k",
        s=20,
        hue_order=list(colors_groups.keys()),
        ax=ax
    )
    # Dashed diagonal y = x, drawn over the data range padded by 10% on each side.
    bisect = sns.lineplot(
        x=[xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp],
        y=[xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp],
        linestyle='--',
        color='black',
        linewidth=1.0,
        ax=ax
    )
    ax.set_title(f"{df_sweeps.at[model_id, 'model']} ({df_sweeps.at[model_id, '# Params']} params)")
    ax.set_xlim(xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp)
    ax.set_ylim(xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp)
    plt.gca().set_aspect('equal', adjustable='box')
    fig.savefig(f"{path_to_candidates}/{model_id}/scatter.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_to_candidates}/{model_id}/scatter.pdf", bbox_inches='tight')
    plt.close(fig)
    
    # Violin plots of the error per group, first for the raw predictions and then
    # for the corrected ones. Each tick label carries the group's MAE, Pearson
    # correlation and mean error. The two plots differ only in the error column,
    # the suffix of the metric names and the output file name.
    for error_col, metric_suffix, fn_violin in [
        ('Error', '', 'violin'),
        ('Error Unbiased', '_unbiased', 'violin_unbiased'),
    ]:
        groups_rename = {}
        for group in colors_groups:
            group_mae = df_metrics.at[group, 'mean_absolute_error' + metric_suffix]
            group_rho = df_metrics.at[group, 'pearson_corrcoef' + metric_suffix]
            group_bias = df_metrics.at[group, 'bias' + metric_suffix]
            groups_rename[group] = (
                f"{group}" + "\n"
                + f"MAE: {group_mae:0.2f}" + "\n"
                + fr"Pearson $\rho$: {group_rho:0.2f}" + "\n"
                + fr"$\langle$Error$\rangle$: {group_bias:0.2f}"
            )
        colors_groups_violin = {groups_rename[group]: colors_groups[group] for group in colors_groups}

        # Work on an explicit copy and assign the relabelled column back. Calling
        # `.replace(..., inplace=True)` on `df_fig['Group']` is a chained in-place
        # update: it warns on recent pandas and is a no-op under Copy-on-Write,
        # which would leave labels that match none of the `order` entries.
        df_fig = df.loc[:, [error_col, 'Group']].copy()
        df_fig['Group'] = df_fig['Group'].replace(groups_rename)

        sns.set_theme(style='whitegrid')
        fig, ax = plt.subplots(figsize=(7, 4))
        violin = sns.violinplot(
            data=df_fig,
            x='Group',
            y=error_col,
            # `palette` needs `hue`; mapping it to the x variable keeps one
            # violin per group and `legend=False` suppresses the redundant legend.
            hue='Group',
            palette=colors_groups_violin,
            # `density_norm` replaces the deprecated `scale` argument.
            density_norm='width',
            order=list(colors_groups_violin.keys()),
            hue_order=list(colors_groups_violin.keys()),
            saturation=0.75,
            legend=False,
            ax=ax
        )
        ax.set_xlabel('')
        fig.savefig(f"{path_to_candidates}/{model_id}/{fn_violin}.png", bbox_inches='tight', dpi=200)
        fig.savefig(f"{path_to_candidates}/{model_id}/{fn_violin}.pdf", bbox_inches='tight')
        plt.close(fig)
    
    # Feature attributions for the whole data set: saved as a table and as a bar
    # chart of the most important features.
    try:
        explanation = model.explain(data, method=explain_method, baselines=explain_baselines)
        explanation.index = data.index
        explanation.to_excel(f"{path_to_candidates}/{model_id}/explanation.xlsx")
        
        # sns.set_theme(style='whitegrid')
        # fig = shap.summary_plot(
        #     shap_values=explanation.loc[:, feats].values,
        #     features=data.loc[:, feats].values,
        #     feature_names=feats,
        #     max_display=explain_n_feats_to_plot,
        #     plot_type="violin",
        #     show=False,
        # )
        # plt.savefig(f"{path_to_candidates}/{model_id}/explain_beeswarm.png", bbox_inches='tight', dpi=200)
        # plt.savefig(f"{path_to_candidates}/{model_id}/explain_beeswarm.pdf", bbox_inches='tight')
        # plt.close(fig)
        
        sns.set_theme(style='ticks')
        shap.summary_plot(
            shap_values=explanation.loc[:, feats].values,
            features=data.loc[:, feats].values,
            feature_names=feats,
            max_display=explain_n_feats_to_plot,
            plot_type="bar",
            show=False,
            plot_size=[12,8]
        )
        # `summary_plot` draws on the current pyplot figure; its return value is
        # not a figure handle, so fetch the figure explicitly to save and close it.
        fig = plt.gcf()
        fig.savefig(f"{path_to_candidates}/{model_id}/explain_bar.png", bbox_inches='tight', dpi=200)
        fig.savefig(f"{path_to_candidates}/{model_id}/explain_bar.pdf", bbox_inches='tight')
        plt.close(fig)
    
    except NotImplementedError as err:
        # Best effort: keep processing the remaining models, but report the
        # skipped explanation instead of hiding it.
        warnings.warn(f"Explanation skipped for model {model_id}: {err!r}")

# Best models processing

## Load data and models for all subsets of features

In [None]:
# Registry of the feature subsets and of the model chosen for each of them.
# NOTE: `path` and `feat_trgt` are rebound here; earlier cells use the same names
# for the folder of a single subset and for the target column of that subset.
path = f"E:/YandexDisk/Work/bbd/atlas"

feat_trgt = 'Возраст'

# Per subset:
#   name       - title of the subset's figure
#   path       - folder with the subset's data.xlsx and feats.xlsx
#   path_model - folder of the chosen model, with its df.xlsx, metrics.xlsx and
#                explanation.xlsx
#   color      - color used for the subset in the plots
datasets = {
    "RheumatologyScreening": {
        "name": "Ревматологический тест",
        "path": f"{path}/subset_RheumatologyScreening_no-sex",
        "path_model": f"{path}/subset_RheumatologyScreening_no-sex/models/DANet/1/48",
        "color": "gray",
    },
    "ProstateSpecificAntigenTest": {
        "name": "Простатический специфический антиген",
        "path": f"{path}/subset_ProstateSpecificAntigenTest_no-sex",
        "path_model": f"{path}/subset_ProstateSpecificAntigenTest_no-sex/models/DANet/1/529",
        "color": "lawngreen",
    },
    "HormoneProfile": {
        "name": "Гормональный профиль",
        "path": f"{path}/subset_HormoneProfile_no-sex",
        "path_model": f"{path}/subset_HormoneProfile_no-sex/models/DANet/1/425",
        "color": "chocolate",
    },
    "BloodBiochemical": {
        "name": "Биохимия крови",
        "path": f"{path}/subset_BloodBiochemical_no-sex",
        "path_model": f"{path}/subset_BloodBiochemical_no-sex/models/DANet/1/630",
        "color": "cyan",
    },
    "BloodPressure": {
        "name": "Кровяное Давление",
        "path": f"{path}/subset_BloodPressure_no-sex",
        "path_model": f"{path}/subset_BloodPressure_no-sex/models/DANet/1/75",
        "color": "orchid",
    },
    "CoagulationTest": {
        "name": "Коагулограмма",
        "path": f"{path}/subset_CoagulationTest_no-sex",
        "path_model": f"{path}/subset_CoagulationTest_no-sex/models/DANet/1/868",
        "color": "olive",
    },
    "CBC": {
        "name": "Общий анализ крови",
        "path": f"{path}/subset_CBC_no-sex",
        "path_model": f"{path}/subset_CBC_no-sex/models/DANet/2/67",
        "color": "crimson",
    },
    "InBody": {
        "name": "Биоимпеданс (InBody)",
        "path": f"{path}/subset_InBody-mRMR_no-sex",
        "path_model": f"{path}/subset_InBody-mRMR_no-sex/models/DANet/2/432",
        "color": "dodgerblue",
    },
    "LipidProfile": {
        "name": "Липидный профиль",
        "path": f"{path}/subset_LipidProfile_no-sex",
        "path_model": f"{path}/subset_LipidProfile_no-sex/models/DANet/1/798",
        "color": "gold",
    },
    
    
    # Disabled entries, kept for reference:
    # 'inbody':
    #     {
    #         'name': 'Биоимпеданс (InBody)',
    #         'path': f"{path}/subset_inbody",
    #         'path_model': f"{path}/subset_inbody/models/DANet/205",
    #         'color': 'dodgerblue'
    #     },
    # 'inbody_portable':
    #     {
    #         'name': 'Биоимпеданс (InBody), mRMR',
    #         'path': f"{path}/subset_inbody_portable",
    #         'path_model': f"{path}/subset_inbody_portable/models/DANet/857",
    #         'color': 'lawngreen'
    #     },
    # 'inbody_portable_cnt':
    #     {
    #         'name': 'Биоимпеданс (InBody), mRMR',
    #         'path': f"{path}/subset_inbody_portable_cnt",
    #         'path_model': f"{path}/subset_inbody_portable_cnt/models/DANet/414",
    #         'color': 'lawngreen'
    #     },
    # 'inbody_mrmr':
    #     {
    #         'name': 'Биоимпеданс (InBody), mRMR',
    #         'path': f"{path}/subset_inbody_mrmr",
    #         'path_model': f"{path}/subset_inbody_mrmr/models/DANet/1/261",
    #         'color': 'lawngreen'
    #     },
    # 'lab':
    #     {
    #         'name': 'Анализ Крови',
    #         'path': f"{path}/subset_lab",
    #         'path_model': f"{path}/subset_lab/models/DANet/446",
    #         'color': 'crimson'
    #     },
    # 'inbody_mrmr_lab':
    #     {
    #         'name': 'Анализ Крови + Биоимпеданс',
    #         'path': f"{path}/subset_inbody_mrmr_lab",
    #         'path_model': f"{path}/subset_inbody_mrmr_lab/models/DANet/3/368",
    #         'color': 'orange'
    #     },
}

# Attach to every registry entry its tables, the saved model and a bias corrector
# fitted on the Train rows of the stored predictions.
for ds in datasets:
    ds_entry = datasets[ds]
    ds_dir = ds_entry['path']
    ds_model_dir = ds_entry['path_model']

    ds_entry['data'] = pd.read_excel(f"{ds_dir}/data.xlsx", index_col=0)
    ds_entry['feats'] = pd.read_excel(f"{ds_dir}/feats.xlsx", index_col=0)
    ds_results = pd.read_excel(f"{ds_model_dir}/df.xlsx", index_col=0)
    ds_entry['results'] = ds_results
    ds_entry['metrics'] = pd.read_excel(f"{ds_model_dir}/metrics.xlsx", index_col=0)
    ds_entry['shap'] = pd.read_excel(f"{ds_model_dir}/explanation.xlsx", index_col=0)
    ds_entry['model'] = TabularModel.load_model(ds_model_dir)

    ds_entry['corrector'] = LinearBiasCorrector()
    ds_train_rows = ds_results.loc[ds_results['Group'] == 'Train', :]
    ds_entry['corrector'].fit(ds_train_rows[feat_trgt].values, ds_train_rows['Prediction'].values)

## Plot models results

In [6]:
# One summary figure per subset: table of test metrics, predicted vs. real target,
# and the distribution of the corrected error.
for ds in datasets:
    ds_feats = datasets[ds]['feats']
    # NOTE: `feats` and `feats_cnt` are rebound here; earlier cells use the same
    # names for the features of a single subset.
    feats = ds_feats.index.to_list()
    feats_cnt = ds_feats.index[ds_feats['Type'] == 'continuous'].to_list()
    # `feats_cnt` is a plain list, so the target has to be filtered element by
    # element. The former `feats_cnt[feats_cnt != 'Возраст']` compared the whole
    # list with a string (always True), indexed the list with it (element 1) and
    # split that single feature name into characters.
    feats_cnt_wo_trgt = [f for f in feats_cnt if f != 'Возраст']
    ds_data = datasets[ds]['data']
    ds_results = datasets[ds]['results']
    ds_metrics = datasets[ds]['metrics']
    ds_shap = datasets[ds]['shap']
    ds_model = datasets[ds]['model']
    ds_corrector = datasets[ds]['corrector']
    ds_color = datasets[ds]['color']
    
    # Plot range: 1st to 99th percentile of target and corrected prediction taken
    # together, so that a few extreme values do not stretch the axes.
    xy_min, xy_max = np.quantile(ds_results[[feat_trgt, 'Prediction Unbiased']].values.flatten(), [0.01, 0.99])
    xy_ptp = xy_max - xy_min
    
    # 2x2 layout: metrics table (top left), predicted vs. real target (bottom
    # left), error distribution (bottom right); the top-right panel stays empty.
    n_rows = 2
    n_cols = 2
    fig_height = 5
    fig_width = 7
    sns.set_theme(style='ticks')
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height), height_ratios=[2, 8],  width_ratios=[4, 2], gridspec_kw={'wspace':0.10, 'hspace': 0.05}, layout='constrained')

    # Table with the test-set metrics of the corrected predictions, formatted to
    # two decimals (rows: mean absolute error, Pearson correlation, mean bias).
    ds_table = pd.DataFrame(index=['Средняя абсолютная ошибка', 'Коэффициент корреляции Пирсона', 'Среднее смещение'], columns=['Тестовая выборка'])
    ds_table.at['Средняя абсолютная ошибка', 'Тестовая выборка'] = f"{ds_metrics.at['Test', 'mean_absolute_error_unbiased']:0.2f}"
    ds_table.at['Коэффициент корреляции Пирсона', 'Тестовая выборка'] = f"{ds_metrics.at['Test', 'pearson_corrcoef_unbiased']:0.2f}"
    ds_table.at['Среднее смещение', 'Тестовая выборка'] = f"{ds_metrics.at['Test', 'bias_unbiased']:0.2f}"

    col_defs = [
        ColumnDefinition(
            name="index",
            title='Метрики',
            textprops={"ha": "left"},
            width=4.5,
        ),
        ColumnDefinition(
            name="Тестовая выборка",
            textprops={"ha": "center"},
            width=2.0,
        ),
    ]
    table = Table(
        ds_table,
        column_definitions=col_defs,
        row_dividers=True,
        footer_divider=False,
        ax=axs[0, 0],
        textprops={"fontsize": 8},
        row_divider_kw={"linewidth": 1, "linestyle": (0, (1, 1))},
        col_label_divider_kw={"linewidth": 1, "linestyle": "-"},
        column_border_kw={"linewidth": 1, "linestyle": "-"},
    ).autoset_fontcolors(colnames=['Тестовая выборка'])

    # Train and Validation samples are shown as a filled density ...
    kdeplot = sns.kdeplot(
        data=ds_results.loc[ds_results['Group'].isin(['Train', 'Validation']), :],
        x='Возраст',
        y='Prediction Unbiased',
        fill=True,
        cbar=False,
        thresh=0.05,
        color=ds_color,
        legend=False,
        ax=axs[1, 0]
    )
    # ... and the Test samples as individual points on top of it.
    scatter = sns.scatterplot(
        data=ds_results.loc[ds_results['Group'] == 'Test', :],
        x='Возраст',
        y="Prediction Unbiased",
        linewidth=0.5,
        alpha=0.8,
        edgecolor="k",
        s=25,
        color=ds_color,
        ax=axs[1, 0],
    )
    # Dashed diagonal y = x over the plot range padded by 15% on each side.
    bisect = sns.lineplot(
        x=[xy_min - 0.15 * xy_ptp, xy_max + 0.15 * xy_ptp],
        y=[xy_min - 0.15 * xy_ptp, xy_max + 0.15 * xy_ptp],
        linestyle='--',
        color='black',
        linewidth=1.0,
        ax=axs[1, 0]
    )
    # Regression line fitted on all samples (all groups), without markers.
    regplot = sns.regplot(
        data=ds_results,
        x='Возраст',
        y='Prediction Unbiased',
        color='k',
        scatter=False,
        truncate=False,
        ax=axs[1, 0]
    )
    axs[1, 0].set_xlim(xy_min - 0.15 * xy_ptp, xy_max + 0.15 * xy_ptp)
    axs[1, 0].set_ylim(xy_min - 0.15 * xy_ptp, xy_max + 0.15 * xy_ptp)
    axs[1, 0].set_ylabel("Биологический возраст")
    axs[1, 0].set_xlabel("Возраст")
    
    axs[0, 1].axis('off')
    
    # Corrected error: a single violin for Train + Validation (constant x puts all
    # samples at one position), drawn in a lightened version of the subset color ...
    violin = sns.violinplot(
        data=ds_results.loc[ds_results['Group'].isin(['Train', 'Validation']), :],
        x=[0] * ds_results.loc[ds_results['Group'].isin(['Train', 'Validation']), :].shape[0],
        y='Error Unbiased',
        color=make_rgb_transparent(mcolors.to_rgb(ds_color), (1, 1, 1), 0.5),
        density_norm='width',
        saturation=0.75,
        linewidth=1.0,
        ax=axs[1, 1],
        legend=False,
    )
    # ... with the Test samples as a swarm on top; the marker size shrinks with
    # the square root of the number of Test samples.
    swarm = sns.swarmplot(
        data=ds_results.loc[ds_results['Group'] == 'Test', :],
        x=[0] * ds_results.loc[ds_results['Group'] == 'Test', :].shape[0],
        y='Error Unbiased',
        color=ds_color,
        linewidth=0.5,
        ax=axs[1, 1],
        size= 50 / np.sqrt(ds_results.loc[ds_results['Group'] == 'Test', :].shape[0]),
        legend=False,
    )
    axs[1, 1].set_ylabel('Возрастная акселерация')
    axs[1, 1].set_xlabel('')
    axs[1, 1].set(xticklabels=[]) 
    axs[1, 1].set(xticks=[]) 
    # The figure is saved into the subset's own data folder.
    fig.suptitle(datasets[ds]['name'], fontsize='large')
    fig.savefig(f"{datasets[ds]['path']}/model.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{datasets[ds]['path']}/model.pdf", bbox_inches='tight')
    plt.close(fig)

## Plot models explainability

In [None]:
# Source of the SHAP values used for the plots:
#   'current'         - reuse the table stored in datasets[ds]['shap'];
#   'recalc_gradient' - recompute with ds_model.explain(..., method="GradientShap");
#   'recalc_sampling' - recompute with shap.SamplingExplainer on the bias-corrected prediction.
expl_type = 'current' # 'current' 'recalc_gradient' 'recalc_sampling'

for ds in datasets:
    ds_feats = datasets[ds]['feats']
    feats = ds_feats.index.values
    feats_wo_trgt = feats[feats != 'Возраст']
    feats_cnt = ds_feats.index[ds_feats['Type'] == 'continuous'].to_list()
    # feats_cnt is a plain list, so `feats_cnt != 'Возраст'` is a single bool; indexing with it
    # returned one feature name that list() then split into characters. Filter element-wise.
    feats_cnt_wo_trgt = [feat for feat in feats_cnt if feat != 'Возраст']
    feats_cat_wo_trgt = ds_feats.index[ds_feats['Type'] != 'continuous'].to_list()
    ds_data = datasets[ds]['data']
    ds_results = datasets[ds]['results']
    ds_metrics = datasets[ds]['metrics']
    ds_shap = datasets[ds]['shap']
    ds_model = datasets[ds]['model']
    ds_corrector = datasets[ds]['corrector']
    ds_color = datasets[ds]['color']

    if expl_type == 'recalc_gradient':
        ds_shap = ds_model.explain(ds_data, method="GradientShap", baselines="b|100000")
        ds_shap.index = ds_data.index
    elif expl_type == 'recalc_sampling':
        # The explainer receives a plain value matrix, so categorical features are label-encoded
        # for it and decoded back to their original labels inside predict_func.
        ds_data_shap = ds_data.copy()
        ds_cat_encoders = {}
        for f in feats_cat_wo_trgt:
            ds_cat_encoders[f] = LabelEncoder()
            ds_data_shap[f] = ds_cat_encoders[f].fit_transform(ds_data_shap[f])

        def predict_func(X):
            """Return the bias-corrected model prediction for a label-encoded feature matrix X."""
            X_df = pd.DataFrame(data=X, columns=feats_wo_trgt)
            for f in feats_cat_wo_trgt:
                X_df[f] = ds_cat_encoders[f].inverse_transform(X_df[f].astype(int).values)
            y = ds_model.predict(X_df)['Возраст_prediction'].values
            y = ds_corrector.predict(y)
            return y

        explainer = shap.SamplingExplainer(predict_func, ds_data_shap.loc[:, feats_wo_trgt].values)
        print(explainer.expected_value)
        shap_values = explainer.shap_values(ds_data_shap.loc[:, feats_wo_trgt].values)
        ds_shap = pd.DataFrame(index=ds_data.index, columns=feats_wo_trgt, data=shap_values)

    # Global importance of a feature = mean absolute SHAP value over all samples.
    ds_fi = ds_shap.loc[:, feats_wo_trgt].abs().mean().to_frame('mean(|SHAP|)')
    ds_fi = ds_fi.sort_values('mean(|SHAP|)', ascending=False)

    # Every dataset except 'inbody_mrmr_lab' is limited to the 30 most important features.
    if ds != 'inbody_mrmr_lab':
        # .copy() so that adding the 'Features' column below does not write into a slice.
        ds_fi = ds_fi.head(30).copy()
    ds_fi['Features'] = ds_fi.index.values

    sns.set_theme(style='ticks')
    fig, axs = plt.subplots(1, 2, figsize=(12, 0.6 * ds_fi.shape[0]), width_ratios=[4, 8], gridspec_kw={'wspace':0.1, 'hspace': 0.05}, sharey=True, sharex=False)

    # Left panel: bar chart of mean(|SHAP|) per feature.
    barplot = sns.barplot(
        data=ds_fi,
        x='mean(|SHAP|)',
        y='Features',
        color=ds_color,
        edgecolor='black',
        dodge=False,
        ax=axs[0]
    )
    for container in barplot.containers:
        barplot.bar_label(container, label_type='edge', color='gray', fmt='%0.2f', fontsize=12, padding=4.0)
    axs[0].set_ylabel('')
    axs[0].set(yticklabels=ds_fi.index.to_list())

    # Right panel: per-sample SHAP values, colored by the feature value.
    # Marker size shrinks with the number of Test samples; it does not depend on the feature.
    strip_size = 25 / np.sqrt(ds_results.loc[ds_results['Group'] == 'Test', :].shape[0])
    cnt_cmap = sns.color_palette("Spectral_r", as_cmap=True)
    is_colorbar = False
    for f in ds_fi.index:

        # Features with very large |SHAP| values are clipped to the 1%-99% quantile range
        # so that a few extreme points do not stretch the axis.
        if ds_shap[f].abs().max() > 10:
            f_shap_ll = ds_shap[f].quantile(0.01)
            f_shap_hl = ds_shap[f].quantile(0.99)
        else:
            f_shap_ll = ds_shap[f].min()
            f_shap_hl = ds_shap[f].max()

        f_index = ds_shap.index[(ds_shap[f] >= f_shap_ll) & (ds_shap[f] <= f_shap_hl)].values
        f_shap = ds_shap.loc[f_index, f].values
        f_vals = ds_data.loc[f_index, f].values

        if ds_feats.at[f, 'Type'] == 'continuous':
            f_cmap = cnt_cmap
            f_norm = mcolors.Normalize(vmin=min(f_vals), vmax=max(f_vals))
            # One palette entry per observed value, normalized within this feature.
            f_colors = {cval: f_cmap(f_norm(cval)) for cval in f_vals}

            strip = sns.stripplot(
                x=f_shap,
                y=[f]*len(f_shap),
                hue=f_vals,
                palette=f_colors,
                jitter=0.35,
                alpha=0.5,
                edgecolor='gray',
                linewidth=0.1,
                size=strip_size,
                legend=False,
                ax=axs[1],
            )

            # A single shared colorbar labelled Min/Max is drawn for the first continuous feature only.
            if not is_colorbar:
                sm = plt.cm.ScalarMappable(cmap=f_cmap, norm=f_norm)
                sm.set_array([])
                cbar = strip.figure.colorbar(sm)
                cbar.set_label('Значения\nчисленных\nпризнаков', labelpad=-8, fontsize='large')
                cbar.set_ticks([min(f_vals), max(f_vals)])
                cbar.set_ticklabels(["Min", "Max"])
                is_colorbar = True
        else:
            if f == 'Пол':
                f_unique = ['жен', 'муж']
                f_palette = {'жен': 'crimson', 'муж': 'dodgerblue'}
            elif f == 'Уровень висцерального жира':
                # Levels look like 'Level N'; order them by the numeric part.
                f_unique = sorted(ds_data['Уровень висцерального жира'].unique(), key=lambda x: int(x.replace('Level ', '')))
                f_cmap = sns.color_palette("husl", n_colors=len(f_unique))
                f_palette = {x: f_cmap[x_id] for x_id, x in enumerate(f_unique)}
            else:
                f_unique = ds_data.loc[f_index, f].unique()
                f_unique_colors = distinctipy.get_colors(len(f_unique), [mcolors.hex2color(mcolors.CSS4_COLORS['black']), mcolors.hex2color(mcolors.CSS4_COLORS['white'])], rng=1337, pastel_factor=0.5)
                f_palette = {x: f_unique_colors[x_id] for x_id, x in enumerate(f_unique)}
            strip = sns.stripplot(
                x=f_shap,
                y=[f]*len(f_shap),
                hue=f_vals,
                palette=f_palette,
                hue_order=f_unique,
                jitter=0.35,
                label=f,
                alpha=0.5,
                edgecolor='gray',
                linewidth=0.1,
                size=strip_size,
                legend=True,
                ax=axs[1],
            )
            sns.move_legend(strip, "upper left", bbox_to_anchor=(1.3, 1), ncol=1, title='Категориальные\nпризнаки', title_font=dict(size='large'), frameon=False)

    axs[1].set_xlabel('SHAP: Влияние на предсказание модели')
    fig.savefig(f"{datasets[ds]['path']}/model_importance.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{datasets[ds]['path']}/model_importance.pdf", bbox_inches='tight')
    ds_shap.to_excel(f"{datasets[ds]['path']}/model_importance.xlsx")
    plt.close(fig)

## Stupid bias correction

In [None]:
# NOTE(review): machine-specific absolute path; adjust it before running elsewhere.
path = "D:/YandexDisk/Work/bbd/atlas"

feat_trgt = 'Возраст'

datasets = {
    'lab': 
        {
            'name': 'Анализ Крови',
            'path': f"{path}/subset_lab",
            'path_model': f"{path}/subset_lab/models/DANet/446", 
            'color': 'crimson'
        },
}

# Colors of the data splits on the regression plots.
colors_groups = {
    'Train': 'chartreuse',
    'Validation': 'dodgerblue',
    'Test': 'crimson',
}

for ds in datasets:
    ds_data = pd.read_excel(f"{datasets[ds]['path']}/data.xlsx", index_col=0)
    ds_feats = pd.read_excel(f"{datasets[ds]['path']}/feats.xlsx", index_col=0)
    ds_results = pd.read_excel(f"{datasets[ds]['path_model']}/df.xlsx", index_col=0)

    # Fit the linear bias corrector only on Train samples from the tails of the
    # target range (target < 30 or target > 60), then apply it to every sample.
    ds_corrector = LinearBiasCorrector()
    stupid_selection = (ds_results['Group'] == 'Train') & ((ds_results[feat_trgt] < 30) | (ds_results[feat_trgt] > 60))
    stupid_data_for_fit = ds_results.loc[stupid_selection, :]
    print(stupid_data_for_fit.shape)
    ds_corrector.fit(stupid_data_for_fit.loc[:, feat_trgt].values, stupid_data_for_fit.loc[:, 'Prediction'].values)
    ds_results['Prediction Unbiased'] = ds_corrector.predict(ds_results['Prediction'].values)
    ds_results['Error Unbiased'] = ds_results['Prediction Unbiased'] - ds_results[feat_trgt]

    # Common value range of target and prediction, used for square axes and the bisector.
    xy_min = ds_results[[feat_trgt, 'Prediction']].min().min()
    xy_max = ds_results[[feat_trgt, 'Prediction']].max().max()
    xy_ptp = xy_max - xy_min

    xy_min_unbiased = ds_results[[feat_trgt, 'Prediction Unbiased']].min().min()
    xy_max_unbiased = ds_results[[feat_trgt, 'Prediction Unbiased']].max().max()
    xy_ptp_unbiased = xy_max_unbiased - xy_min_unbiased

    # savefig does not create missing directories, so make sure the output folder exists.
    pathlib.Path(f"{datasets[ds]['path']}/stupid").mkdir(parents=True, exist_ok=True)

    # The same figure is drawn twice: for the raw and for the bias-corrected prediction.
    # Each variant uses its own range for BOTH axis limits (the original mixed xy_ptp into
    # the lower limit of the unbiased plot, so its axes did not match the bisector).
    plot_variants = (
        ('Prediction', xy_min, xy_max, xy_ptp, 'regplot'),
        ('Prediction Unbiased', xy_min_unbiased, xy_max_unbiased, xy_ptp_unbiased, 'regplot_unbiased'),
    )
    for pred_col, lim_min, lim_max, lim_ptp, file_stem in plot_variants:
        # 10% padding on both sides of the data range.
        lim_lo = lim_min - 0.1 * lim_ptp
        lim_hi = lim_max + 0.1 * lim_ptp

        sns.set_theme(style='whitegrid')
        fig, ax = plt.subplots(figsize=(4.5, 4))
        for group in colors_groups.keys():
            sns.regplot(
                data=ds_results.loc[ds_results['Group'] == group, :],
                x=feat_trgt,
                y=pred_col,
                label=group,
                color=colors_groups[group],
                scatter_kws=dict(
                    linewidth=0.2,
                    alpha=0.75,
                    edgecolor="k",
                    s=20,
                ),
                ax=ax
            )
        # Dashed y = x line: a perfect prediction would lie on it.
        sns.lineplot(
            x=[lim_lo, lim_hi],
            y=[lim_lo, lim_hi],
            linestyle='--',
            color='black',
            linewidth=1.0,
            ax=ax
        )
        ax.set_xlim(lim_lo, lim_hi)
        ax.set_ylim(lim_lo, lim_hi)
        ax.set_aspect('equal', adjustable='box')
        fig.savefig(f"{datasets[ds]['path']}/stupid/{file_stem}.png", bbox_inches='tight', dpi=200)
        fig.savefig(f"{datasets[ds]['path']}/stupid/{file_stem}.pdf", bbox_inches='tight')
        plt.close(fig)

# Local explainability checking

In [None]:
# Load the saved model together with its per-sample results and copy the result
# columns into `data` (rows are aligned on the index of `data`).
model = TabularModel.load_model(f"{path}/models/DANet/1/261")
results = pd.read_excel(f"{path}/models/DANet/1/261/df.xlsx", index_col=0)
res_cols = ['Group', 'Prediction', 'Error', 'Prediction Unbiased', 'Error Unbiased']
data.loc[data.index, res_cols] = results.loc[data.index, res_cols]

# The bias corrector is fitted on the Train group only.
corrector = LinearBiasCorrector()
is_train = data['Group'] == 'Train'
corrector.fit(data.loc[is_train, feat_trgt].values, data.loc[is_train, 'Prediction'].values)

# Label-encoded copy of the data; one fitted encoder is kept per categorical feature
# so that predict_func can map the codes back to the original labels.
data_shap = data.copy()
cat_encoders = {}
for f in feats_cat:
    encoder = LabelEncoder()
    data_shap[f] = encoder.fit_transform(data_shap[f])
    cat_encoders[f] = encoder


def predict_func(X):
    """Return bias-corrected model predictions for a label-encoded feature matrix.

    X is turned into a DataFrame with columns `feats`, the categorical columns are
    decoded with the fitted encoders, the model is applied and its prediction column
    is passed through the bias corrector.
    """
    X_df = pd.DataFrame(data=X, columns=feats)
    for col in feats_cat:
        codes = X_df[col].astype(int).values
        X_df[col] = cat_encoders[col].inverse_transform(codes)
    raw_pred = model.predict(X_df)[f'{feat_trgt}_prediction'].values
    return corrector.predict(raw_pred)

In [None]:
# Sample to explain: its target value, bias-corrected prediction and their difference.
trgt_id = 82235 # 1159
trgt_age = data_shap.at[trgt_id, feat_trgt]
trgt_pred = data_shap.at[trgt_id, 'Prediction Unbiased']
trgt_aa = trgt_pred - trgt_age
print(trgt_age)
print(trgt_pred)

# Background set: the n_closest samples whose bias-corrected prediction is nearest
# to the target value of the explained sample.
n_closest = 32
closest_order = (data_shap['Prediction Unbiased'] - trgt_age).abs().argsort()
data_closest = data_shap.iloc[closest_order[:n_closest]]

background = data_closest.loc[:, feats].values
explainer = shap.SamplingExplainer(predict_func, background)
print(explainer.expected_value)

trgt_row = data_shap.loc[[trgt_id], feats].values
shap_values = explainer.shap_values(trgt_row)[0]
# Rescale the contributions by (prediction - target) / (prediction - expected value).
# NOTE(review): presumably so that they add up to prediction - target instead of
# prediction - expected value; this divides by zero if the two are equal - confirm.
shap_values = shap_values * (trgt_pred - trgt_age) / (trgt_pred - explainer.expected_value)

# shap.plots.waterfall(
#     shap.Explanation(
#         values=shap_values,
#         base_values=trgt_age,
#         data=data.loc[trgt_id, feats].values,
#         feature_names=feats
#     ),
#     max_display=len(feats),
#     show=True,
# )

# `scipy` itself is not imported by the notebook's import cell (only `mannwhitneyu` is),
# so the function is imported here; ideally this line belongs in the imports cell.
from scipy.stats import percentileofscore

# SHAP contributions of the explained sample, ordered by absolute value (smallest first),
# with their running sum.
df_shap = pd.DataFrame(index=feats, data=shap_values, columns=[trgt_id])
df_shap = df_shap.sort_values(by=trgt_id, key=abs)
df_shap['cumsum'] = df_shap[trgt_id].cumsum()

# For every non-categorical feature: rounded percentile of the explained sample's value
# among the background samples ('Меньше') and its complement to 100 ('Больше').
# The columns are created under the same names that are filled in and plotted later.
df_less_more = pd.DataFrame(index=df_shap.index, columns=['Меньше', 'Больше'])
# For every categorical feature: percentage of each category among the background samples,
# keyed by the feature's row position in df_shap.
df_cat_part = {}
for f_id, f in enumerate(df_less_more.index):
    if df_feats.at[f, 'Type'] != 'categorical':
        df_less_more.at[f, 'Меньше'] = round(percentileofscore(data_closest.loc[:, f].values, data_shap.at[trgt_id, f]))
        df_less_more.at[f, 'Больше'] = 100.0 - df_less_more.at[f, 'Меньше']
    else:
        df_less_more.at[f, 'Меньше'] = np.nan
        df_less_more.at[f, 'Больше'] = np.nan

        # Use the current feature (the original always read 'Пол' here).
        # NOTE(review): assumes every 'categorical' feature has an entry in cat_encoders.
        f_value_counts = data_closest.loc[:, f].value_counts()
        # data_closest holds label-encoded codes; map them back to the original labels.
        f_value_counts_rename = {x: cat_encoders[f].inverse_transform([x])[0] for x in f_value_counts.index.astype(int).values}
        f_value_counts = f_value_counts.rename(index=f_value_counts_rename)
        f_value_counts = np.rint(f_value_counts / f_value_counts.sum() * 100)

        if f == 'Пол':
            f_palette = {'жен': 'crimson', 'муж': 'dodgerblue'}
        else:
            # Every categorical feature needs a palette, otherwise the plotting code fails
            # on the lookup. Generated the same way as in the global explainability plot.
            f_palette_colors = distinctipy.get_colors(len(f_value_counts), [mcolors.hex2color(mcolors.CSS4_COLORS['black']), mcolors.hex2color(mcolors.CSS4_COLORS['white'])], rng=1337, pastel_factor=0.5)
            f_palette = {val: mcolors.to_hex(col) for val, col in zip(f_value_counts.index, f_palette_colors)}

        df_cat_part[f_id] = {
            'name': f,
            'distribution': f_value_counts.astype(int),
            'palette': f_palette,
        }

# Two panels sharing the y axis: SHAP waterfall (left) and feature distribution
# among the background samples (right).
fig = make_subplots(rows=1, cols=2, shared_yaxes=True, shared_xaxes=False, column_widths=[2.5, 1], horizontal_spacing=0.05, subplot_titles=['', "Распределение признаков у людей<br>в данном возрастном диапазоне"])

# Summary bars placed below and above the feature rows: the target value as an absolute
# bar and the total difference (prediction - target) as a relative bar.
fig.add_trace(
    go.Waterfall(
        hovertext=["Хронологический возраст", "Биологический возраст"],
        orientation="h",
        measure=['absolute', 'relative'],
        y=[-1.5, df_shap.shape[0] + 0.5],
        x=[trgt_age, trgt_aa],
        base=0,
        text=[f"{trgt_age:0.2f}", f"+{trgt_aa:0.2f}" if trgt_aa > 0 else f"{trgt_aa:0.2f}"],
        textposition = "auto",
        decreasing = {"marker":{"color": "deepskyblue", "line": {"color": "black", "width": 1}}},
        increasing = {"marker":{"color": "crimson", "line": {"color": "black", "width": 1}}},
        totals= {"marker":{"color": "dimgray", "line": {"color": "black", "width": 1}}},
        connector={
            "mode": "between",
            "line": {"width": 1, "color": "black", "dash": "dot"},
        },
    ),
    row=1,
    col=1,
)

# Per-feature contributions, one row per feature in df_shap order, starting from the target value.
fig.add_trace(
    go.Waterfall(
        hovertext=df_shap.index.values,
        orientation="h",
        # One entry per plotted row; tied to df_shap so it always matches x and y.
        measure=["relative"] * df_shap.shape[0],
        y=list(range(df_shap.shape[0])),
        x=df_shap[trgt_id].values,
        base=trgt_age,
        text=[f"+{x:0.2f}" if x > 0 else f"{x:0.2f}" for x in df_shap[trgt_id].values],
        textposition = "auto",
        decreasing = {"marker":{"color": "lightblue", "line": {"color": "black", "width": 1}}},
        increasing = {"marker":{"color": "lightcoral", "line": {"color": "black", "width": 1}}},
        connector={
            "mode": "between",
            "line": {"width": 1, "color": "black", "dash": "solid"},
        },
    ),
    row=1,
    col=1,
)
fig.update_traces(row=1, col=1, showlegend=False)

# Tick labels show "<feature> = <value of the explained sample>"; numeric values are
# formatted with two decimals, categorical ones are printed as is.
fig.update_yaxes(
    row=1,
    col=1,
    automargin=True,
    tickmode="array",
    tickvals=[-1.5] + list(range(df_shap.shape[0])) + [df_shap.shape[0] + 0.5],
    ticktext=["Хронологический возраст"] + [f"{x} = {data.at[trgt_id, x]:0.2f}" if df_feats.at[x, 'Type'] != 'categorical' else f"{x} = {data.at[trgt_id, x]}" for x in df_shap.index] + ["Биологический возраст"],
    tickfont=dict(size=18),
)

# x range is symmetric around the target value and wide enough for the largest running sum.
fig.update_xaxes(
    row=1,
    col=1,
    automargin=True,
    # `titlefont` is a deprecated alias; the font is set through title.font instead.
    title=dict(text='Возраст', font=dict(size=25)),
    range=[
        trgt_age - df_shap['cumsum'].abs().max() * 1.25,
        trgt_age + df_shap['cumsum'].abs().max() * 1.25
    ],
)

# Right panel, numeric features: stacked 'Меньше' / 'Больше' bars taken from df_less_more,
# one row per feature in df_shap order.
fig.add_trace(
    go.Bar(
        hovertext=df_shap.index.values,
        orientation="h",
        name='Меньше',
        x=df_less_more.loc[df_shap.index.values, 'Меньше'],
        y=list(range(df_shap.shape[0])),
        marker=dict(color='steelblue', line=dict(color="black", width=1)),
        text=df_less_more.loc[df_shap.index.values, 'Меньше'],
        textposition='auto'
    ),
    row=1,
    col=2,
)
fig.add_trace(
    go.Bar(
        hovertext=df_shap.index.values,
        orientation="h",
        name='Больше',
        x=df_less_more.loc[df_shap.index.values, 'Больше'],
        y=list(range(df_shap.shape[0])),
        marker=dict(color='violet', line=dict(color="black", width=1)),
        text=df_less_more.loc[df_shap.index.values, 'Больше'],
        textposition='auto',
    ),
    row=1,
    col=2
)

# Right panel, categorical features: one stacked bar segment per category value,
# placed on the row of the corresponding feature.
for f_cat_id, f_cat_dict in df_cat_part.items():
    # A palette may be missing for a feature; a color of None leaves the choice to plotly
    # instead of failing on the lookup.
    f_cat_palette = f_cat_dict.get('palette', {})
    for f_val in f_cat_dict['distribution'].index:
        fig.add_trace(
            go.Bar(
                hovertext=[f_cat_dict['name']],
                orientation="h",
                name=f_val,
                x=[f_cat_dict['distribution'][f_val]],
                y=[f_cat_id],
                marker=dict(color=f_cat_palette.get(f_val), line=dict(color="black", width=1)),
                text=[f_val],
                textposition='auto',
                showlegend=False
            ),
            row=1,
            col=2
        )

# The right panel carries no axis decoration: values are printed on the bars themselves.
fig.update_xaxes(
    row=1,
    col=2,
    automargin=True,
    showgrid=False,
    showline=False,
    zeroline=False,
    showticklabels=False,
)
fig.update_yaxes(
    row=1,
    col=2,
    automargin=True,
    showgrid=False,
    showline=False,
    zeroline=False,
    showticklabels=False,
    # tickmode="array",
    # tickvals=list(range(df_less_more.shape[0])),
    # ticktext=[f"{data.at[trgt_id, x]:0.2f}" if df_feats.at[x, 'Type'] != 'categorical' else data.at[trgt_id, x] for x in df_less_more.index],
    # tickfont=dict(size=18),
    # showticklabels=True
)
fig.update_layout(barmode="stack")
fig.update_layout(
    legend=dict(
        title=dict(side="top"),
        orientation="h",
        yanchor="bottom",
        y=0.95,
        xanchor="center",
        x=0.84
    ),
)

fig.update_layout(
    # `titlefont` is a deprecated alias; the font is set through title.font instead.
    title=dict(text=f"Возрастная акселерация для {trgt_id}", font=dict(size=25)),
    template="none",
    width=1200,
    height=1000,
    margin=go.layout.Margin(l=120, r=80, b=50, t=50, pad=0),
)

fig.show()