# Debugging autoreload

In [None]:
%load_ext autoreload
%autoreload 2

# Load packages

In [None]:
from pytorch_tabular.utils import load_covertype_dataset
from rich.pretty import pprint
from sklearn.model_selection import BaseCrossValidator, ParameterGrid, ParameterSampler
import torch
import pickle
import shutil
import shap
from sklearn.model_selection import RepeatedStratifiedKFold
from glob import glob
import ast
import matplotlib.pyplot as plt
import seaborn as sns
import copy
from sklearn.model_selection import train_test_split
import numpy as np
from pytorch_tabular.utils import make_mixed_dataset, print_metrics
from pytorch_tabular import available_models
from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig, GANDALFConfig, TabNetModelConfig, FTTransformerConfig, DANetConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig
from pytorch_tabular.tabular_model_tuner import TabularModelTuner
from torchmetrics.functional.regression import mean_absolute_error, pearson_corrcoef
from pytorch_tabular import MODEL_SWEEP_PRESETS
import pandas as pd
from pytorch_tabular import model_sweep
from src.pt.model_sweep import model_sweep_custom
import warnings
from src.utils.configs import read_parse_config
from src.utils.hash import dict_hash
from src.pt.hyper_opt import train_hyper_opt
import pathlib
from tqdm import tqdm
import distinctipy
import matplotlib.patheffects as pe
import matplotlib.colors as mcolors
from statannotations.Annotator import Annotator
from scipy.stats import mannwhitneyu
from regression_bias_corrector import LinearBiasCorrector
import optuna
from sklearn.preprocessing import LabelEncoder
from plottable import ColumnDefinition, Table
from plottable.plots import bar
from plottable.cmap import normed_cmap, centered_cmap
import matplotlib.lines as mlines
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import scipy.stats
from functools import reduce


def make_rgb_transparent(rgb, bg_rgb, alpha):
    """Blend a foreground colour over a background colour.

    Each channel of the result is ``alpha * fg + (1 - alpha) * bg``.
    Channels are paired positionally; the result has as many channels as
    the shorter of the two inputs.

    Args:
        rgb: Foreground colour channels.
        bg_rgb: Background colour channels.
        alpha: Weight of the foreground colour.

    Returns:
        A list with the blended channel values.
    """
    blended = []
    for fg_channel, bg_channel in zip(rgb, bg_rgb):
        blended.append(alpha * fg_channel + (1 - alpha) * bg_channel)
    return blended

# Load data and models for subsets of features

In [None]:
# Root folder that holds one sub-folder per feature subset.
path = "D:/YandexDisk/Work/bbd/atlas"

# Name of the target column in the data/result tables.
feat_trgt = 'Возраст'

# Static description of every feature subset: display name, data folder,
# folder of the trained model and the colour used in the plots.
datasets = {
    'inbody_mrmr': 
        {
            'name': 'Биоимпеданс (InBody), mRMR',
            'path': f"{path}/subset_inbody_mrmr",
            'path_model': f"{path}/subset_inbody_mrmr/models/DANet/1/261", 
            'color': 'lawngreen'
        },
    'lab': 
        {
            'name': 'Анализ Крови',
            'path': f"{path}/subset_lab",
            'path_model': f"{path}/subset_lab/models/DANet/446", 
            'color': 'crimson'
        }
}


def _make_predict_func(model, corrector, cat_encoders, feats, feats_cat, target):
    """Build a prediction callable that is bound to one dataset.

    The returned function takes a 2-D array whose columns follow ``feats``
    (categorical columns label-encoded), decodes the categorical columns with
    ``cat_encoders``, runs ``model`` and passes the predictions through
    ``corrector``.

    Everything the callable needs is bound through the factory arguments.
    A closure defined directly in the loop would look up the loop variables
    at call time, so every stored callable would silently use whatever
    dataset the notebook touched last.
    """
    def predict_func(X):
        X_df = pd.DataFrame(data=X, columns=feats)
        for f in feats_cat:
            X_df[f] = cat_encoders[f].inverse_transform(X_df[f].astype(int).values)
        y = model.predict(X_df)[f'{target}_prediction'].values
        return corrector.predict(y)

    return predict_func


for ds in datasets:
    # Tables and model stored for this subset.
    datasets[ds]['data'] = pd.read_excel(f"{datasets[ds]['path']}/data.xlsx", index_col=0)
    datasets[ds]['feats'] = pd.read_excel(f"{datasets[ds]['path']}/feats.xlsx", index_col=0)
    datasets[ds]['results'] = pd.read_excel(f"{datasets[ds]['path_model']}/df.xlsx", index_col=0)
    datasets[ds]['metrics'] = pd.read_excel(f"{datasets[ds]['path_model']}/metrics.xlsx", index_col=0)
    datasets[ds]['shap'] = pd.read_excel(f"{datasets[ds]['path_model']}/explanation.xlsx", index_col=0)
    datasets[ds]['model'] = TabularModel.load_model(f"{datasets[ds]['path_model']}")

    # The bias corrector is fitted on the training rows only.
    datasets[ds]['corrector'] = LinearBiasCorrector()
    ds_results = datasets[ds]['results']
    is_train = ds_results['Group'] == 'Train'
    datasets[ds]['corrector'].fit(
        ds_results.loc[is_train, feat_trgt].values,
        ds_results.loc[is_train, 'Prediction'].values,
    )

    # Copy the model outputs into the data table (rows matched by index).
    res_cols = ['Group', 'Prediction', 'Error', 'Prediction Unbiased', 'Error Unbiased']
    datasets[ds]['data'].loc[datasets[ds]['data'].index, res_cols] = ds_results.loc[datasets[ds]['data'].index, res_cols]

    # Feature lists without the target column.
    feats = datasets[ds]['feats'].index.values
    feats = feats[feats != feat_trgt]
    feats_cnt = datasets[ds]['feats'].index[datasets[ds]['feats']['Type'] == 'continuous'].to_list()
    # `feats_cnt` is a plain list: it must be filtered element-wise, a boolean
    # "mask" on a list would just index a single element.
    feats_cnt = [f for f in feats_cnt if f != feat_trgt]
    feats_cat = datasets[ds]['feats'].index[datasets[ds]['feats']['Type'] != 'continuous'].to_list()

    # Label-encoded copy of the data for SHAP explainers that need numeric input.
    datasets[ds]['data_shap'] = datasets[ds]['data'].copy()
    datasets[ds]['cat_encoders'] = {}
    for f in feats_cat:
        datasets[ds]['cat_encoders'][f] = LabelEncoder()
        datasets[ds]['data_shap'][f] = datasets[ds]['cat_encoders'][f].fit_transform(datasets[ds]['data_shap'][f])

    datasets[ds]['predict_func'] = _make_predict_func(
        model=datasets[ds]['model'],
        corrector=datasets[ds]['corrector'],
        cat_encoders=datasets[ds]['cat_encoders'],
        feats=feats,
        feats_cat=feats_cat,
        target=feat_trgt,
    )

## Plot models results

In [None]:
# For every dataset draw a 2x2 figure: metrics table (top-left), real vs
# predicted values (bottom-left) and error distribution (bottom-right), then
# save it as PNG and PDF next to the dataset.
for ds in datasets:
    ds_feats = datasets[ds]['feats']
    feats = ds_feats.index.to_list()
    feats_cnt = ds_feats.index[ds_feats['Type'] == 'continuous'].to_list()
    # `feats_cnt` is a plain list, so the target is removed element-wise
    # (a boolean "mask" on a list would just pick a single element).
    feats_cnt_wo_trgt = [f for f in feats_cnt if f != feat_trgt]
    ds_data = datasets[ds]['data']
    ds_results = datasets[ds]['results']
    ds_metrics = datasets[ds]['metrics']
    ds_shap = datasets[ds]['shap']
    ds_model = datasets[ds]['model']
    ds_corrector = datasets[ds]['corrector']
    ds_color = datasets[ds]['color']

    # Row subsets reused by several plots below.
    df_train_val = ds_results.loc[ds_results['Group'].isin(['Train', 'Validation']), :]
    df_test = ds_results.loc[ds_results['Group'] == 'Test', :]

    # Common axis limits from the 1%-99% quantiles of real and predicted
    # values, widened by 15% of that span on both sides.
    xy_min, xy_max = np.quantile(ds_results[[feat_trgt, 'Prediction Unbiased']].values.flatten(), [0.01, 0.99])
    xy_ptp = xy_max - xy_min
    xy_lim = [xy_min - 0.15 * xy_ptp, xy_max + 0.15 * xy_ptp]

    n_rows = 2
    n_cols = 2
    fig_height = 5
    fig_width = 7
    sns.set_theme(style='ticks')
    fig, axs = plt.subplots(
        n_rows,
        n_cols,
        figsize=(fig_width, fig_height),
        height_ratios=[2, 8],
        width_ratios=[4, 2],
        gridspec_kw={'wspace': 0.10, 'hspace': 0.05},
        layout='constrained',
    )

    # Table with the test-set metrics of the bias-corrected predictions.
    ds_table = pd.DataFrame(index=['Средняя абсолютная ошибка', 'Коэффициент корреляции Пирсона', 'Среднее смещение'], columns=['Тестовая выборка'])
    ds_table.at['Средняя абсолютная ошибка', 'Тестовая выборка'] = f"{ds_metrics.at['Test', 'mean_absolute_error_unbiased']:0.2f}"
    ds_table.at['Коэффициент корреляции Пирсона', 'Тестовая выборка'] = f"{ds_metrics.at['Test', 'pearson_corrcoef_unbiased']:0.2f}"
    ds_table.at['Среднее смещение', 'Тестовая выборка'] = f"{ds_metrics.at['Test', 'bias_unbiased']:0.2f}"

    col_defs = [
        ColumnDefinition(
            name="index",
            title='Метрики',
            textprops={"ha": "left"},
            width=4.5,
        ),
        ColumnDefinition(
            name="Тестовая выборка",
            textprops={"ha": "center"},
            width=2.0,
        ),
    ]
    table = Table(
        ds_table,
        column_definitions=col_defs,
        row_dividers=True,
        footer_divider=False,
        ax=axs[0, 0],
        textprops={"fontsize": 8},
        row_divider_kw={"linewidth": 1, "linestyle": (0, (1, 1))},
        col_label_divider_kw={"linewidth": 1, "linestyle": "-"},
        column_border_kw={"linewidth": 1, "linestyle": "-"},
    ).autoset_fontcolors(colnames=['Тестовая выборка'])

    # Train/validation rows as a density, test rows as individual points.
    kdeplot = sns.kdeplot(
        data=df_train_val,
        x=feat_trgt,
        y='Prediction Unbiased',
        fill=True,
        cbar=False,
        thresh=0.05,
        color=ds_color,
        legend=False,
        ax=axs[1, 0]
    )
    scatter = sns.scatterplot(
        data=df_test,
        x=feat_trgt,
        y="Prediction Unbiased",
        linewidth=0.5,
        alpha=0.8,
        edgecolor="k",
        s=25,
        color=ds_color,
        ax=axs[1, 0],
    )
    # Identity line (prediction == real value).
    bisect = sns.lineplot(
        x=xy_lim,
        y=xy_lim,
        linestyle='--',
        color='black',
        linewidth=1.0,
        ax=axs[1, 0]
    )
    # Linear fit over all rows.
    regplot = sns.regplot(
        data=ds_results,
        x=feat_trgt,
        y='Prediction Unbiased',
        color='k',
        scatter=False,
        truncate=False,
        ax=axs[1, 0]
    )
    axs[1, 0].set_xlim(xy_lim[0], xy_lim[1])
    axs[1, 0].set_ylim(xy_lim[0], xy_lim[1])
    axs[1, 0].set_ylabel("Биологический возраст")
    axs[1, 0].set_xlabel("Возраст")

    # The top-right cell stays empty.
    axs[0, 1].axis('off')

    # Error distribution: violin for train/validation, swarm for test.
    violin = sns.violinplot(
        data=df_train_val,
        x=[0] * df_train_val.shape[0],
        y='Error Unbiased',
        color=make_rgb_transparent(mcolors.to_rgb(ds_color), (1, 1, 1), 0.5),
        density_norm='width',
        saturation=0.75,
        linewidth=1.0,
        ax=axs[1, 1],
        legend=False,
    )
    swarm = sns.swarmplot(
        data=df_test,
        x=[0] * df_test.shape[0],
        y='Error Unbiased',
        color=ds_color,
        linewidth=0.5,
        ax=axs[1, 1],
        # Marker size shrinks with the number of test rows.
        size=50 / np.sqrt(df_test.shape[0]),
        legend=False,
    )
    axs[1, 1].set_ylabel('Возрастная акселерация')
    axs[1, 1].set_xlabel('')
    axs[1, 1].set(xticklabels=[])
    axs[1, 1].set(xticks=[])
    fig.suptitle(datasets[ds]['name'], fontsize='large')
    fig.savefig(f"{datasets[ds]['path']}/model.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{datasets[ds]['path']}/model.pdf", bbox_inches='tight')
    plt.close(fig)

## Plot models explainability

In [None]:
# For every dataset plot the global feature importance: mean(|SHAP|) bars
# on the left and a per-sample SHAP strip plot on the right (top 30 features).
for ds in datasets:
    ds_feats = datasets[ds]['feats']
    feats = ds_feats.index.values
    feats_wo_trgt = feats[feats != 'Возраст']
    feats_cnt = ds_feats.index[ds_feats['Type'] == 'continuous'].to_list()
    # `feats_cnt` is a plain list, so the target is removed element-wise
    # (a boolean "mask" on a list would just pick a single element).
    feats_cnt_wo_trgt = [f for f in feats_cnt if f != 'Возраст']
    feats_cat_wo_trgt = ds_feats.index[ds_feats['Type'] != 'continuous'].to_list()
    ds_data = datasets[ds]['data']
    ds_results = datasets[ds]['results']
    ds_metrics = datasets[ds]['metrics']
    ds_shap = datasets[ds]['shap']
    ds_model = datasets[ds]['model']
    ds_corrector = datasets[ds]['corrector']
    ds_color = datasets[ds]['color']
    
    
    # Alternative ways to recompute the SHAP values (kept for reference):
    # ds_shap = ds_model.explain(ds_data, method="GradientShap", baselines="b|100000")
    # ds_shap.index = ds_data.index
    
    # ds_data_shap = ds_data.copy()
    # ds_cat_encoders = {}
    # for f in feats_cat_wo_trgt:
    #     ds_cat_encoders[f] = LabelEncoder()
    #     ds_data_shap[f] = ds_cat_encoders[f].fit_transform(ds_data_shap[f])
    # def predict_func(X):
    #     X_df = pd.DataFrame(data=X, columns=feats_wo_trgt)
    #     for f in feats_cat_wo_trgt:
    #         X_df[f] = ds_cat_encoders[f].inverse_transform(X_df[f].astype(int).values)
    #     y = model.predict(X_df)[f'Возраст_prediction'].values
    #     y = corrector.predict(y)
    #     return y
    # explainer = shap.SamplingExplainer(predict_func, ds_data_shap.loc[:, feats_wo_trgt].values)
    # print(explainer.expected_value)
    # shap_values = explainer.shap_values(ds_data_shap.loc[:, feats_wo_trgt].values)
    # ds_shap = pd.DataFrame(index=ds_data.index, columns=feats_wo_trgt, data=shap_values)
    
    
    # Global importance = mean absolute SHAP value per feature.
    ds_fi = pd.DataFrame(index=feats_wo_trgt, columns=['mean(|SHAP|)'])
    for f in feats_wo_trgt:
        ds_fi.at[f, 'mean(|SHAP|)'] = ds_shap[f].abs().mean()
    ds_fi.sort_values(['mean(|SHAP|)'], ascending=[False], inplace=True)
    
    # Keep the 30 most important features; copy so that adding a column does
    # not write into a slice of the full table.
    ds_fi = ds_fi.head(30).copy()
    ds_fi['Features'] = ds_fi.index.values
    
    # Marker size for the strip plots shrinks with the number of test rows.
    point_size = 25 / np.sqrt(ds_results.loc[ds_results['Group'] == 'Test', :].shape[0])
    
    sns.set_theme(style='ticks')
    fig, axs = plt.subplots(1, 2, figsize=(12, 8), width_ratios=[4, 8], gridspec_kw={'wspace':0.1, 'hspace': 0.05}, sharey=True, sharex=False)
    
    barplot = sns.barplot(
        data=ds_fi,
        x='mean(|SHAP|)',
        y='Features',
        color=ds_color,
        edgecolor='black',
        dodge=False,
        ax=axs[0]
    )
    for container in barplot.containers:
        barplot.bar_label(container, label_type='edge', color='gray', fmt='%0.2f', fontsize=12, padding=4.0)
    axs[0].set_ylabel('')
    axs[0].set(yticklabels=ds_fi.index.to_list())
    
    is_colorbar = False
    f_legends = []
    for f in ds_fi.index:
        
        # Features with very large SHAP values are clipped to the 1%-99%
        # quantile range; all other features are shown in full.
        if ds_shap[f].abs().max() > 10:
            f_shap_ll = ds_shap[f].quantile(0.01)
            f_shap_hl = ds_shap[f].quantile(0.99)
        else:
            f_shap_ll = ds_shap[f].min()
            f_shap_hl = ds_shap[f].max()
        
        f_index = ds_shap.index[(ds_shap[f] >= f_shap_ll) & (ds_shap[f] <= f_shap_hl)].values
        f_shap = ds_shap.loc[f_index, f].values
        f_vals = ds_data.loc[f_index, f].values
        
        if ds_feats.at[f, 'Type'] == 'continuous':
            # Continuous features: colour every point by its (normalised) value.
            f_cmap = sns.color_palette("Spectral_r", as_cmap=True)
            f_norm = mcolors.Normalize(vmin=min(f_vals), vmax=max(f_vals)) 
            f_colors = {}
            for cval in f_vals:
                f_colors.update({cval: f_cmap(f_norm(cval))})

            strip = sns.stripplot(
                x=f_shap,
                y=[f]*len(f_shap),
                hue=f_vals,
                palette=f_colors,
                jitter=0.35,
                alpha=0.5,
                edgecolor='gray',
                linewidth=0.1,
                size=point_size,
                legend=False,
                ax=axs[1],
            )
            
            if not is_colorbar:
                # One shared Min/Max colour bar, created for the first
                # continuous feature only.
                sm = plt.cm.ScalarMappable(cmap=f_cmap, norm=f_norm)
                sm.set_array([])
                # The mappable is not attached to any Axes, so the Axes to
                # steal space from has to be named explicitly.
                cbar = strip.figure.colorbar(sm, ax=axs[1])
                cbar.set_label('Значения численных признаков', labelpad=-8, fontsize='large')
                cbar.set_ticks([min(f_vals), max(f_vals)])
                cbar.set_ticklabels(["Min", "Max"])
                is_colorbar = True
        else:
            # Categorical features: one colour per category.
            if f == 'Пол':
                f_unique = ['жен', 'муж']
                f_palette = {'жен': 'crimson', 'муж': 'dodgerblue'}
            elif f == 'Уровень висцерального жира':
                # Sort "Level N" labels by their numeric part.
                f_unique = sorted(ds_data['Уровень висцерального жира'].unique(), key=lambda x: int(x.replace('Level ', '')))
                f_cmap = sns.color_palette("husl", n_colors=len(f_unique))
                f_palette = {x: f_cmap[x_id] for x_id, x in enumerate(f_unique)}
            else:
                f_unique = ds_data.loc[f_index, f].unique()
                f_unique_colors = distinctipy.get_colors(len(f_unique), [mcolors.hex2color(mcolors.CSS4_COLORS['black']), mcolors.hex2color(mcolors.CSS4_COLORS['white'])], rng=1337, pastel_factor=0.5)
                f_palette = {x: f_unique_colors[x_id] for x_id, x in enumerate(f_unique)}
            strip = sns.stripplot(
                x=f_shap,
                y=[f]*len(f_shap),
                hue=f_vals,
                palette=f_palette,
                hue_order=f_unique,
                jitter=0.35,
                label=f,
                alpha=0.5,
                edgecolor='gray',
                linewidth=0.1,
                size=point_size,
                legend=True,
                ax=axs[1],
            )
            sns.move_legend(strip, "upper left", bbox_to_anchor=(1.3, 1), ncol=1, title='Категориальные\nпризнаки', title_font=dict(size='large'), frameon=False) 
          
    axs[1].set_xlabel('SHAP: Влияние на предсказание модели')
    fig.savefig(f"{datasets[ds]['path']}/model_importance.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{datasets[ds]['path']}/model_importance.pdf", bbox_inches='tight')
    ds_shap.to_excel(f"{datasets[ds]['path']}/model_importance.xlsx")
    plt.close(fig)

# Local explainability checking

# Inbody mRMR

In [None]:
# Local explanation for one sample of a single dataset: SHAP waterfall from
# chronological to predicted age (left) and, for every feature, where the
# sample sits relative to the background samples (right).
ds = 'inbody_mrmr'

ds_feats = datasets[ds]['feats']
feats = ds_feats.index.values
feats = feats[feats != feat_trgt]
feats_cnt = ds_feats.index[ds_feats['Type'] == 'continuous'].to_list()
# `feats_cnt` is a plain list, so the target is removed element-wise
# (a boolean "mask" on a list would just pick a single element).
feats_cnt = [x for x in feats_cnt if x != feat_trgt]
feats_cat = ds_feats.index[ds_feats['Type'] != 'continuous'].to_list()

ds_data = datasets[ds]['data']
ds_results = datasets[ds]['results']
ds_metrics = datasets[ds]['metrics']
ds_shap = datasets[ds]['shap']
ds_model = datasets[ds]['model']
ds_corrector = datasets[ds]['corrector']
ds_color = datasets[ds]['color']
ds_data_shap = datasets[ds]['data_shap']
ds_predict_func = datasets[ds]['predict_func']

# Sample to explain.
trgt_id = 82235 # 1159
trgt_age = ds_data_shap.at[trgt_id, feat_trgt]
trgt_pred = ds_data_shap.at[trgt_id, 'Prediction Unbiased']
trgt_aa = trgt_pred - trgt_age
print(trgt_age)
print(trgt_pred)

# Background set: the samples whose bias-corrected prediction is closest to
# the target's real age.
n_closest = 32
data_closest = ds_data_shap.iloc[(ds_data_shap['Prediction Unbiased'] - trgt_age).abs().argsort()[:n_closest]]

explainer = shap.SamplingExplainer(ds_predict_func, data_closest.loc[:, feats].values)
print(explainer.expected_value)
shap_values = explainer.shap_values(ds_data_shap.loc[[trgt_id], feats].values)[0]
# Rescale the contributions by (prediction - age) / (prediction - expected
# value) so that the waterfall can start at the real age instead of the
# explainer's expected value.
shap_values = shap_values * (trgt_pred - trgt_age) / (trgt_pred - explainer.expected_value)

# shap.plots.waterfall(
#     shap.Explanation(
#         values=shap_values,
#         base_values=trgt_age,
#         data=data.loc[trgt_id, feats].values,
#         feature_names=feats
#     ),
#     max_display=len(feats),
#     show=True,
# )

# Contributions ordered by absolute value (smallest first).
df_shap = pd.DataFrame(index=feats, data=shap_values, columns=[trgt_id])
df_shap.sort_values(by=trgt_id, key=abs, inplace=True)
df_shap['cumsum'] = df_shap[trgt_id].cumsum()

# Column names must match the ones that are filled and plotted below.
df_less_more = pd.DataFrame(index=df_shap.index, columns=['Меньше', 'Больше'])
df_cat_part = {}
for f_id, f in enumerate(df_less_more.index):
    if ds_feats.at[f, 'Type'] != 'categorical':
        # Percentile of the target's value among the background samples.
        df_less_more.at[f, 'Меньше'] = round(scipy.stats.percentileofscore(data_closest.loc[:, f].values, ds_data_shap.at[trgt_id, f]))
        df_less_more.at[f, 'Больше'] = 100.0 - df_less_more.at[f, 'Меньше']
    else:
        df_less_more.at[f, 'Меньше'] = np.nan
        df_less_more.at[f, 'Больше'] = np.nan
        
        # Share (in percent) of every category of *this* feature among the
        # background samples, with the encoded labels mapped back to the
        # original category names.
        f_value_counts = data_closest.loc[:, f].value_counts()
        f_value_counts_rename = {x: datasets[ds]['cat_encoders'][f].inverse_transform([x])[0] for x in f_value_counts.index.astype(int).values}
        f_value_counts.rename(index=f_value_counts_rename, inplace=True)
        f_value_counts = np.rint(f_value_counts / f_value_counts.sum() * 100)
        
        df_cat_part[f_id] = {
            'name': f,
            'distribution': f_value_counts.astype(int)
        }
        if f == 'Пол':
            df_cat_part[f_id]['palette'] = {'жен': 'crimson', 'муж': 'dodgerblue'}

fig = make_subplots(rows=1, cols=2, shared_yaxes=True, shared_xaxes=False, column_widths=[2.5, 1], horizontal_spacing=0.05, subplot_titles=['', "Распределение признаков у людей<br>в данном возрастном диапазоне"])
# Outer waterfall: real age and the total age acceleration.
fig.add_trace(
    go.Waterfall(
        hovertext=["Хронологический возраст", "Биологический возраст"],
        orientation="h",
        measure=['absolute', 'relative'],
        y=[-1.5, df_shap.shape[0] + 0.5],
        x=[trgt_age, trgt_aa],
        base=0,
        text=[f"{trgt_age:0.2f}", f"+{trgt_aa:0.2f}" if trgt_aa > 0 else f"{trgt_aa:0.2f}"],
        textposition = "auto",
        decreasing = {"marker":{"color": "deepskyblue", "line": {"color": "black", "width": 1}}},
        increasing = {"marker":{"color": "crimson", "line": {"color": "black", "width": 1}}},
        totals= {"marker":{"color": "dimgray", "line": {"color": "black", "width": 1}}},
        connector={
            "mode": "between",
            "line": {"width": 1, "color": "black", "dash": "dot"},
        },
    ),
    row=1,
    col=1,
)
# Inner waterfall: one bar per feature, starting from the real age.
fig.add_trace(
    go.Waterfall(
        hovertext=df_shap.index.values,
        orientation="h",
        measure=["relative"] * df_shap.shape[0],
        y=list(range(df_shap.shape[0])),
        x=df_shap[trgt_id].values,
        base=trgt_age,
        text=[f"+{x:0.2f}" if x > 0 else f"{x:0.2f}" for x in df_shap[trgt_id].values],
        textposition = "auto",
        decreasing = {"marker":{"color": "lightblue", "line": {"color": "black", "width": 1}}},
        increasing = {"marker":{"color": "lightcoral", "line": {"color": "black", "width": 1}}},
        connector={
            "mode": "between",
            "line": {"width": 1, "color": "black", "dash": "solid"},
        },
    ),
    row=1,
    col=1,
)
fig.update_traces(row=1, col=1, showlegend=False)
fig.update_yaxes(
    row=1,
    col=1,
    automargin=True,
    tickmode="array",
    tickvals=[-1.5] + list(range(df_shap.shape[0])) + [df_shap.shape[0] + 0.5],
    ticktext=["Хронологический возраст"] + [f"{x} = {ds_data.at[trgt_id, x]:0.2f}" if ds_feats.at[x, 'Type'] != 'categorical' else f"{x} = {ds_data.at[trgt_id, x]}" for x in df_shap.index] + ["Биологический возраст"],
    tickfont=dict(size=18),
)
fig.update_xaxes(
    row=1,
    col=1,
    automargin=True,
    # `title.font` is the supported spelling of the deprecated `titlefont`.
    title=dict(text='Возраст', font=dict(size=25)),
    range=[
        trgt_age - df_shap['cumsum'].abs().max() * 1.25,
        trgt_age + df_shap['cumsum'].abs().max() * 1.25
    ],
)

# Right panel: stacked "less / more" percentages for numeric features.
fig.add_trace(
    go.Bar(
        hovertext=df_shap.index.values,
        orientation="h",
        name='Меньше',
        x=df_less_more.loc[df_shap.index.values, 'Меньше'],
        y=list(range(df_shap.shape[0])),
        marker=dict(color='steelblue', line=dict(color="black", width=1)),
        text=df_less_more.loc[df_shap.index.values, 'Меньше'],
        textposition='auto'
    ),
    row=1,
    col=2,
)
fig.add_trace(
    go.Bar(
        hovertext=df_shap.index.values,
        orientation="h",
        name='Больше',
        x=df_less_more.loc[df_shap.index.values, 'Больше'],
        y=list(range(df_shap.shape[0])),
        marker=dict(color='violet', line=dict(color="black", width=1)),
        text=df_less_more.loc[df_shap.index.values, 'Больше'],
        textposition='auto',
    ),
    row=1,
    col=2
)

# Right panel: category shares for categorical features.
for f_cat_id, f_cat_dict in df_cat_part.items():
    # Features without a predefined palette fall back to plotly's default
    # colours (color=None) instead of raising a KeyError.
    f_cat_palette = f_cat_dict.get('palette', {})
    for f_val in f_cat_dict['distribution'].index:
        fig.add_trace(
            go.Bar(
                hovertext=[f_cat_dict['name']],
                orientation="h",
                name=f_val,
                x=[f_cat_dict['distribution'][f_val]],
                y=[f_cat_id],
                marker=dict(color=f_cat_palette.get(f_val), line=dict(color="black", width=1)),
                text=[f_val],
                textposition='auto',
                showlegend=False
            ),
            row=1,
            col=2
        )

fig.update_xaxes(
    row=1,
    col=2,
    automargin=True,
    showgrid=False,
    showline=False,
    zeroline=False,
    showticklabels=False,
)
fig.update_yaxes(
    row=1,
    col=2,
    automargin=True,
    showgrid=False,
    showline=False,
    zeroline=False,
    showticklabels=False,
    # tickmode="array",
    # tickvals=list(range(df_less_more.shape[0])),
    # ticktext=[f"{data.at[trgt_id, x]:0.2f}" if df_feats.at[x, 'Type'] != 'categorical' else data.at[trgt_id, x] for x in df_less_more.index],
    # tickfont=dict(size=18),
    # showticklabels=True
)
fig.update_layout(barmode="stack")
fig.update_layout(
    legend=dict(
        title=dict(side="top"),
        orientation="h",
        yanchor="bottom",
        y=0.95,
        xanchor="center",
        x=0.84
    ),
)

fig.update_layout(
    # `title.font` is the supported spelling of the deprecated `titlefont`.
    title=dict(text=f"Возрастная акселерация для {trgt_id}", font=dict(size=25)),
    template="none",
    width=1200,
    height=1000,
    margin=go.layout.Margin(l=120, r=80, b=50, t=50, pad=0),
)

fig.show()

# Model combinations: Blood + Inbody

In [None]:
# Local explanation of one sample with several models at once: compute the
# per-dataset SHAP contributions, then merge them into one table in which the
# shared feature ('Пол') appears only once.
trgt_id = 20104

# Number of background samples used by the explainer for every dataset.
bkg_count = {
    'inbody_mrmr': 32,
    'lab': 256
}

data_all = []
feats_all = []
local_exlp = {}
for ds in datasets:
    ds_feats = datasets[ds]['feats']
    feats = ds_feats.index.values
    feats = feats[feats != feat_trgt]
    feats_cnt = ds_feats.index[ds_feats['Type'] == 'continuous'].to_list()
    # `feats_cnt` is a plain list, so the target is removed element-wise
    # (a boolean "mask" on a list would just pick a single element).
    feats_cnt = [x for x in feats_cnt if x != feat_trgt]
    feats_cat = ds_feats.index[ds_feats['Type'] != 'continuous'].to_list()

    ds_data = datasets[ds]['data']
    ds_results = datasets[ds]['results']
    ds_metrics = datasets[ds]['metrics']
    ds_shap = datasets[ds]['shap']
    ds_model = datasets[ds]['model']
    ds_corrector = datasets[ds]['corrector']
    ds_color = datasets[ds]['color']
    ds_data_shap = datasets[ds]['data_shap']
    ds_predict_func = datasets[ds]['predict_func']

    trgt_age = ds_data_shap.at[trgt_id, feat_trgt]
    trgt_pred = ds_data_shap.at[trgt_id, 'Prediction Unbiased']
    trgt_aa = trgt_pred - trgt_age
    # print(trgt_age)
    # print(trgt_pred)
    # print(trgt_aa * ds_metrics.at['Test', 'pearson_corrcoef_unbiased'] / len(datasets))

    # Weight of this model in the combination: its test-set Pearson
    # correlation divided by the number of models.
    ds_weight = ds_metrics.at['Test', 'pearson_corrcoef_unbiased'] / len(datasets)

    # Background set: the samples whose bias-corrected prediction is closest
    # to the target's real age.
    n_closest = bkg_count[ds]
    data_closest = ds_data_shap.iloc[(ds_data_shap['Prediction Unbiased'] - trgt_age).abs().argsort()[:n_closest]]

    explainer = shap.SamplingExplainer(ds_predict_func, data_closest.loc[:, feats].values)
    # print(explainer.expected_value)
    shap_values = explainer.shap_values(ds_data_shap.loc[[trgt_id], feats].values)[0]
    # Rescale so that the contributions are counted from the real age rather
    # than from the explainer's expected value, then apply the model weight.
    shap_values = shap_values * (trgt_pred - trgt_age) / (trgt_pred - explainer.expected_value)
    shap_values *= ds_weight
    # print(sum(shap_values))
    
    df_shap = pd.DataFrame(index=feats, data=shap_values, columns=[trgt_id])
    df_shap.sort_values(by=trgt_id, key=abs, inplace=True)
    # df_shap['cumsum'] = df_shap[trgt_id].cumsum()

    df_less_more = pd.DataFrame(index=df_shap.index, columns=['Меньше', 'Больше'])
    df_cat_part = {}
    for f in df_less_more.index:
        if ds_feats.at[f, 'Type'] != 'categorical':
            # Percentile of the target's value among the background samples.
            df_less_more.at[f, 'Меньше'] = round(scipy.stats.percentileofscore(data_closest.loc[:, f].values, ds_data_shap.at[trgt_id, f]))
            df_less_more.at[f, 'Больше'] = 100.0 - df_less_more.at[f, 'Меньше']
        else:
            df_less_more.at[f, 'Меньше'] = np.nan
            df_less_more.at[f, 'Больше'] = np.nan
            
            # Share (in percent) of every category of *this* feature among
            # the background samples, with encoded labels mapped back to the
            # original category names.
            f_value_counts = data_closest.loc[:, f].value_counts()
            f_value_counts_rename = {x: datasets[ds]['cat_encoders'][f].inverse_transform([x])[0] for x in f_value_counts.index.astype(int).values}
            f_value_counts.rename(index=f_value_counts_rename, inplace=True)
            f_value_counts = np.rint(f_value_counts / f_value_counts.sum() * 100)
            
            df_cat_part[f] = {
                'distribution': f_value_counts.astype(int)
            }
            if f == 'Пол':
                df_cat_part[f]['palette'] = {'жен': 'crimson', 'муж': 'dodgerblue'}  
        
    local_exlp[ds] = {
        'df_shap': df_shap,
        'df_less_more': df_less_more,
        'df_cat_part': df_cat_part,
        'age_acceleration': (trgt_pred - trgt_age) * ds_weight,
    }
    
    data_all.append(ds_data.loc[[trgt_id], :])
    feats_all.append(ds_feats.loc[feats, :])

# One row with the target's values from all datasets; columns present in more
# than one dataset keep the first occurrence (later ones get the '_y' suffix).
data_all = reduce(lambda left, right: pd.merge(left, right, left_index=True, right_index=True, suffixes=('', '_y')), data_all)
feats_all = pd.concat(feats_all)
feats_all = feats_all[~feats_all.index.duplicated(keep='first')]

# Feature shared by all datasets: its contributions are summed into one row.
feat_cmn = 'Пол'

df_shap_cmn = pd.DataFrame(index=[feat_cmn], columns=[trgt_id], data=np.zeros(1))
dfs_shap = [df_shap_cmn]
# Placeholder row for the shared feature; it uses the same columns as the
# per-dataset "less / more" tables it is concatenated with.
df_less_more_cmn = pd.DataFrame(index=[feat_cmn], columns=['Меньше', 'Больше'], data=np.nan)
dfs_less_more_cmn = [df_less_more_cmn]
df_cat_part_cmn = {
    'distribution': pd.Series(index=['жен', 'муж'], data=[0.0, 0.0]),
    'palette': {'жен': 'crimson', 'муж': 'dodgerblue'}
}
for ds in datasets:
    print(local_exlp[ds]['age_acceleration'])
    df_shap_cmn.at[feat_cmn, trgt_id] += local_exlp[ds]['df_shap'].at[feat_cmn, trgt_id]
    # Average the category shares over the datasets. Each dataset contributes
    # its own distribution; a category missing in one of them counts as zero.
    ds_cmn_distribution = local_exlp[ds]['df_cat_part'][feat_cmn]['distribution']
    df_cat_part_cmn['distribution'] = df_cat_part_cmn['distribution'].add(ds_cmn_distribution / len(datasets), fill_value=0)
    dfs_shap.append(local_exlp[ds]['df_shap'].drop([feat_cmn]))
    dfs_less_more_cmn.append(local_exlp[ds]['df_less_more'].drop([feat_cmn]))
    
# Union of all contributions, ordered by absolute value (smallest first).
df_shap_union = pd.concat(dfs_shap)
df_less_more_union = pd.concat(dfs_less_more_cmn)
df_shap_union.sort_values(by=trgt_id, key=abs, inplace=True)
df_shap_union['cumsum'] = df_shap_union[trgt_id].cumsum()
df_less_more_union = df_less_more_union.loc[df_shap_union.index, :]

In [None]:
# Waterfall for the combination of models: real age, per-feature
# contributions from all datasets, per-model age acceleration and the
# resulting combined age (left); feature distributions (right).
trgt_aa = df_shap_union[trgt_id].sum()
trgt_age = data_all.at[trgt_id, feat_trgt]

# Weighted age acceleration contributed by every model.
aa_1 = local_exlp['lab']['age_acceleration']
aa_2 = local_exlp['inbody_mrmr']['age_acceleration']

# Number of feature rows in the combined table.
n_feats_union = df_shap_union.shape[0]

fig = make_subplots(rows=1, cols=2, shared_yaxes=True, shared_xaxes=False, column_widths=[2.5, 1], horizontal_spacing=0.15, subplot_titles=['', "Распределение признаков у людей<br>в данном возрастном диапазоне"])
# Outer waterfall: real age, the two model contributions and the final age.
fig.add_trace(
    go.Waterfall(
        hovertext=["Хронологический возраст", "Возрастная акселерация (Анализ Крови)", "Возрастная акселерация (Биоимпеданс)", "Биологический возраст"],
        orientation="h",
        measure=['absolute', 'relative', 'relative', 'absolute'],
        y=[-1.5, n_feats_union + 0.5, n_feats_union + 1.5, n_feats_union + 2.5],
        x=[trgt_age, aa_1, aa_2, trgt_age+trgt_aa],
        base=0,
        text=[f"{trgt_age:0.2f}", f"+{aa_1:0.2f}" if aa_1 > 0 else f"{aa_1:0.2f}", f"+{aa_2:0.2f}" if aa_2 > 0 else f"{aa_2:0.2f}", f"{trgt_age+trgt_aa:0.2f}"],
        textposition = "auto",
        decreasing = {"marker":{"color": "deepskyblue", "line": {"color": "black", "width": 1}}},
        increasing = {"marker":{"color": "crimson", "line": {"color": "black", "width": 1}}},
        totals= {"marker":{"color": "dimgray", "line": {"color": "black", "width": 1}}},
        connector={
            "mode": "between",
            "line": {"width": 1, "color": "black", "dash": "dot"},
        },
    ),
    row=1,
    col=1,
)
# Inner waterfall: one bar per feature, starting from the real age.
fig.add_trace(
    go.Waterfall(
        hovertext=df_shap_union.index.values,
        orientation="h",
        # One entry per plotted row of the combined table (not per feature
        # of whichever dataset was processed last).
        measure=["relative"] * n_feats_union,
        y=list(range(n_feats_union)),
        x=df_shap_union[trgt_id].values,
        base=trgt_age,
        text=[f"+{x:0.2f}" if x > 0 else f"{x:0.2f}" for x in df_shap_union[trgt_id].values],
        textposition = "auto",
        decreasing = {"marker":{"color": "lightblue", "line": {"color": "black", "width": 1}}},
        increasing = {"marker":{"color": "lightcoral", "line": {"color": "black", "width": 1}}},
        connector={
            "mode": "between",
            "line": {"width": 1, "color": "black", "dash": "solid"},
        },
    ),
    row=1,
    col=1,
)
fig.update_traces(row=1, col=1, showlegend=False)
fig.update_yaxes(
    row=1,
    col=1,
    automargin=True,
    tickmode="array",
    tickvals=[-1.5] + list(range(n_feats_union)) + [n_feats_union + 0.5, n_feats_union + 1.5, n_feats_union + 2.5],
    ticktext=["Хронологический возраст"] + [f"{x} = {data_all.at[trgt_id, x]:0.2f}" if feats_all.at[x, 'Type'] != 'categorical' else f"{x} = {data_all.at[trgt_id, x]}" for x in df_shap_union.index] + ["Возрастная акселерация (Анализ Крови)", "Возрастная акселерация (Биоимпеданс)", "Биологический возраст"],
    tickfont=dict(size=18),
)
fig.update_xaxes(
    row=1,
    col=1,
    automargin=True,
    # `title.font` is the supported spelling of the deprecated `titlefont`.
    title=dict(text='Возраст', font=dict(size=25)),
    range=[
        trgt_age - df_shap_union['cumsum'].abs().max() * 1.25,
        trgt_age + df_shap_union['cumsum'].abs().max() * 1.25
    ],
)

# Right panel: stacked "less / more" percentages for numeric features.
fig.add_trace(
    go.Bar(
        hovertext=df_shap_union.index.values,
        orientation="h",
        name='Меньше',
        x=df_less_more_union.loc[df_shap_union.index.values, 'Меньше'],
        y=list(range(n_feats_union)),
        marker=dict(color='steelblue', line=dict(color="black", width=1)),
        text=df_less_more_union.loc[df_shap_union.index.values, 'Меньше'],
        textposition='auto'
    ),
    row=1,
    col=2,
)
fig.add_trace(
    go.Bar(
        hovertext=df_shap_union.index.values,
        orientation="h",
        name='Больше',
        x=df_less_more_union.loc[df_shap_union.index.values, 'Больше'],
        y=list(range(n_feats_union)),
        marker=dict(color='violet', line=dict(color="black", width=1)),
        text=df_less_more_union.loc[df_shap_union.index.values, 'Больше'],
        textposition='auto',
    ),
    row=1,
    col=2
)

# Right panel: category shares of the shared categorical feature, drawn on
# the row that this feature occupies in the combined table.
for f_val in df_cat_part_cmn['distribution'].index:
    fig.add_trace(
        go.Bar(
            hovertext=[feat_cmn],
            orientation="h",
            name=f_val,
            x=[df_cat_part_cmn['distribution'][f_val]],
            y=[df_shap_union.index.get_loc(feat_cmn)],
            marker=dict(color=df_cat_part_cmn['palette'][f_val], line=dict(color="black", width=1)),
            text=[f_val],
            textposition='auto',
            showlegend=False
        ),
        row=1,
        col=2
    )

fig.update_xaxes(
    row=1,
    col=2,
    automargin=True,
    showgrid=False,
    showline=False,
    zeroline=False,
    showticklabels=False,
)
fig.update_yaxes(
    row=1,
    col=2,
    automargin=True,
    showgrid=False,
    showline=False,
    zeroline=False,
    showticklabels=False,
)
fig.update_layout(barmode="stack")
fig.update_layout(
    legend=dict(
        title=dict(side="top"),
        orientation="h",
        yanchor="bottom",
        y=0.98,
        xanchor="center",
        x=0.87
    ),
)

fig.update_layout(
    # `title.font` is the supported spelling of the deprecated `titlefont`.
    title=dict(text=f"Возрастная акселерация для {trgt_id}", font=dict(size=25)),
    template="none",
    width=1800,
    height=1300,
    margin=go.layout.Margin(l=120, r=100, b=50, t=50, pad=0),
)

fig.show()