# Debugging autoreload

In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports edited modules
# before each cell runs, so changes to the local `src` package are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Load packages

In [1]:
from pytorch_tabular.utils import load_covertype_dataset
from rich.pretty import pprint
from sklearn.model_selection import BaseCrossValidator, ParameterGrid, ParameterSampler
import torch
import pickle
import shutil
import shap
from sklearn.model_selection import RepeatedStratifiedKFold
from glob import glob
import ast
import matplotlib.pyplot as plt
import seaborn as sns
import copy
from sklearn.model_selection import train_test_split
import numpy as np
from pytorch_tabular.utils import make_mixed_dataset, print_metrics
from pytorch_tabular import available_models
from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig, GANDALFConfig, TabNetModelConfig, FTTransformerConfig, DANetConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig
from pytorch_tabular.tabular_model_tuner import TabularModelTuner
from torchmetrics.functional.regression import mean_absolute_error, pearson_corrcoef
from pytorch_tabular import MODEL_SWEEP_PRESETS
import pandas as pd
from pytorch_tabular import model_sweep
from src.pt.model_sweep import model_sweep_custom
import warnings
from src.utils.configs import read_parse_config
from src.utils.hash import dict_hash
from src.pt.hyper_opt import train_hyper_opt
import pathlib
from tqdm import tqdm
import distinctipy
import matplotlib.patheffects as pe
import matplotlib.colors as mcolors
from statannotations.Annotator import Annotator
from scipy.stats import mannwhitneyu
from regression_bias_corrector import LinearBiasCorrector
import optuna
from sklearn.preprocessing import LabelEncoder
from plottable import ColumnDefinition, Table
from plottable.plots import bar
from plottable.cmap import normed_cmap, centered_cmap
import matplotlib.lines as mlines
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import scipy.stats
from functools import reduce
from sklearn.impute import KNNImputer
from scipy import stats


def make_rgb_transparent(rgb, bg_rgb, alpha):
    """Blend a colour with a background colour, channel by channel.

    Each output channel is ``alpha * foreground + (1 - alpha) * background``.

    Args:
        rgb: Iterable with the foreground channel values.
        bg_rgb: Iterable with the background channel values.
        alpha: Weight of the foreground colour in the blend.

    Returns:
        A list with one blended value per channel pair; if the two inputs
        differ in length the result is as long as the shorter one.
    """
    blended = []
    for fg_channel, bg_channel in zip(rgb, bg_rgb):
        blended.append(alpha * fg_channel + (1 - alpha) * bg_channel)
    return blended

# Load data and models for subsets of features

In [None]:
# Root folder holding the main data table and the per-subset folders.
path = "E:/YandexDisk/Work/bbd/atlas"

# Name of the target column (age) shared by every table in this notebook.
feat_trgt = "Возраст"

# Main data table; its rows are the samples scored in the prediction cell.
data_suffix = "_v3"
data_file = f"{path}/data{data_suffix}.xlsx"
data = pd.read_excel(data_file, index_col=0)

def _component_entry(name, subset_dir, model_run, color, bkg_count, likelihood):
    """Build the configuration dict of one feature-subset component.

    Args:
        name: Display name; also used to build per-component column names.
        subset_dir: Folder of the subset, relative to ``path``.
        model_run: Folder of the selected model, relative to ``<subset>/models``.
        color: Named colour associated with the component.
        bkg_count: Number of background rows used for the SHAP explainer.
        likelihood: Factor multiplied with the test correlation to weight the
            component's age acceleration.

    Returns:
        Dict with the keys ``name``, ``path``, ``path_model``, ``color``,
        ``bkg_count`` and ``likelihood`` (in that order).
    """
    subset_path = f"{path}/{subset_dir}"
    return {
        'name': name,
        'path': subset_path,
        'path_model': f"{subset_path}/models/{model_run}",
        'color': color,
        'bkg_count': bkg_count,
        'likelihood': likelihood,
    }


# Registry of feature-subset components, keyed by a short identifier.
components = {
    'InBody': _component_entry(
        'Биоимпеданс (InBody)', 'subset_InBody-mRMR_no-sex', 'DANet/2/432', 'dodgerblue', 50, 1.0
    ),
    'CompleteBloodCount': _component_entry(
        'Общий анализ крови', 'subset_CBC_no-sex', 'DANet/2/67', 'crimson', 300, 1.0
    ),
    'BloodBiochemical': _component_entry(
        'Биохимия крови', 'subset_BloodBiochemical_no-sex', 'DANet/1/630', 'cyan', 150, 1.0
    ),
    'LipidProfile': _component_entry(
        'Липидный профиль', 'subset_LipidProfile_no-sex', 'DANet/1/798', 'gold', 100, 0.75
    ),
    'CoagulationTest': _component_entry(
        'Коагулограмма', 'subset_CoagulationTest_no-sex', 'DANet/1/868', 'olive', 100, 0.75
    ),
    'HormoneProfile': _component_entry(
        'Гормональный профиль', 'subset_HormoneProfile_no-sex', 'DANet/1/425', 'chocolate', 50, 0.6
    ),
    'ProstateSpecificAntigenTest': _component_entry(
        'Простатический специфический антиген', 'subset_ProstateSpecificAntigenTest_no-sex', 'DANet/1/529', 'lawngreen', 50, 0.6
    ),
    'RheumatologyScreening': _component_entry(
        'Ревматологический тест', 'subset_RheumatologyScreening_no-sex', 'DANet/1/48', 'gray', 50, 0.5
    ),
    # Disabled component, kept for reference:
    # 'BloodPressure': _component_entry(
    #     'Кровяное Давление', 'subset_BloodPressure_no-sex', 'DANet/1/75', 'orchid', 300, 0.6
    # ),
}

# Load tables, model and bias corrector of every component, and collect the
# column names that the per-sample result table will need.
feats_all = []
for comp, comp_cfg in components.items():
    comp_dir = comp_cfg['path']
    model_dir = comp_cfg['path_model']

    comp_cfg['data'] = pd.read_excel(f"{comp_dir}/data.xlsx", index_col=0)
    comp_cfg['feats'] = pd.read_excel(f"{comp_dir}/feats.xlsx", index_col=0)
    comp_cfg['results'] = pd.read_excel(f"{model_dir}/df.xlsx", index_col=0)
    comp_cfg['metrics'] = pd.read_excel(f"{model_dir}/metrics.xlsx", index_col=0)
    comp_cfg['shap'] = pd.read_excel(f"{model_dir}/explanation.xlsx", index_col=0)
    comp_cfg['model'] = TabularModel.load_model(f"{model_dir}")

    # Fit the bias corrector on the 'Train' rows of the stored results.
    comp_cfg['corrector'] = LinearBiasCorrector()
    comp_results = comp_cfg['results']
    train_rows = comp_results['Group'] == 'Train'
    comp_cfg['corrector'].fit(
        comp_results.loc[train_rows, feat_trgt].values,
        comp_results.loc[train_rows, 'Prediction'].values,
    )

    # Attach the stored predictions/errors to the component data table.
    res_cols = ['Group', 'Prediction', 'Error', 'Prediction Unbiased', 'Error Unbiased']
    comp_data = comp_cfg['data']
    comp_data.loc[comp_data.index, res_cols] = comp_results.loc[comp_data.index, res_cols]
    comp_cfg['data_shap'] = comp_data.copy()

    # Input features are the listed features without the target column.
    feats_w_trgt = comp_cfg['feats'].index.values
    feats = feats_w_trgt[feats_w_trgt != feat_trgt]
    feats_all.extend(list(feats))
    feats_all.extend([
        f"Предсказание {comp_cfg['name']}",
        f"Возрастная Акселерация {comp_cfg['name']}",
    ])

    # Pearson correlation of every input feature with the target column.
    comp_cfg['feats_corr'] = pd.DataFrame(index=feats, columns=['Correlation'])
    target_values = comp_data.loc[:, feat_trgt].values
    for f in feats:
        comp_cfg['feats_corr'].at[f, 'Correlation'] = stats.pearsonr(comp_data.loc[:, f].values, target_values)[0]

In [None]:
# Print the number of rows in each component's data table.
for comp, comp_cfg in components.items():
    comp_size = len(comp_cfg['data'])
    print(f"{comp}: {comp_size}")

# Prediction for samples

In [None]:
# A component is applied to a sample only when the share of its missing input
# features is below this threshold.
nan_part = 0.2

# One row per sample; new columns are added on the fly as results are written.
samples = pd.DataFrame(columns=feats_all)

for sample_id in (pbar := tqdm(data.index.values)):
    pbar.set_description(f"Sample {sample_id}")

    n_pos = 0  # components predicting an age above the target value
    n_neg = 0  # components predicting an age at or below the target value
    comp_present = []
    for comp in components:
        comp_name = components[comp]['name']
        feats_w_trgt = components[comp]['feats'].index.values
        feats = feats_w_trgt[feats_w_trgt != feat_trgt]
        n_feats = len(feats)
        n_nans = data.loc[sample_id, feats].isna().sum()
        # Component weight: test-set correlation scaled by the configured likelihood.
        rho = components[comp]['metrics'].at['Test', 'pearson_corrcoef_unbiased'] * components[comp]['likelihood']
        if n_nans / n_feats < nan_part:
            comp_present.append(comp)
            # Explicit copy: the row is modified below and must not alias `data`.
            data_sample = data.loc[[sample_id], feats_w_trgt].copy()
            if n_nans != n_feats:
                # Fill gaps with KNN imputation, using the component's own data
                # table as neighbours; row 0 of the result is the sample itself.
                data_bkcg = components[comp]['data'].loc[:, feats_w_trgt]
                data_imp = pd.concat([data_sample, data_bkcg], axis=0, ignore_index=True)
                imputer = KNNImputer(n_neighbors=5)
                data_sample.loc[sample_id, feats_w_trgt] = imputer.fit_transform(data_imp.loc[:, feats_w_trgt].values)[0, :]
            pred = components[comp]['model'].predict(data_sample)[f'{feat_trgt}_prediction'].values
            pred = components[comp]['corrector'].predict(pred)
            # The model is called on a single row: reduce the one-element result
            # to a plain float so comparisons and cell assignments use scalars.
            pred = float(np.ravel(pred)[0])
            data_sample.at[sample_id, f"Предсказание {comp_name}"] = pred

            gt = data_sample.at[sample_id, feat_trgt]
            aa = pred - gt
            if aa > 0:
                n_pos += 1
            else:
                # A zero difference is counted on the negative side.
                n_neg += 1

            data_sample.at[sample_id, f"Возрастная Акселерация {comp_name}"] = aa * rho

            samples.loc[sample_id, data_sample.columns] = data_sample.loc[sample_id, data_sample.columns]

    if len(comp_present) > 0:
        samples.at[sample_id, "Число моделей"] = len(comp_present)
        samples.at[sample_id, "Число моделей c отрицательной аккселерацией"] = n_neg
        samples.at[sample_id, "Число моделей c положительной аккселерацией"] = n_pos
        samples.at[sample_id, "Возрастная Акселерация"] = 0.0
        for comp in comp_present:
            comp_name = components[comp]['name']
            samples.at[sample_id, f"Модель {comp_name}"] = True
            # Normalise each weighted acceleration by the size of the larger
            # sign group, then accumulate the total acceleration of the sample.
            samples.at[sample_id, f"Возрастная Акселерация {comp_name}"] /= max(n_pos, n_neg)
            samples.at[sample_id, "Возрастная Акселерация"] += samples.at[sample_id, f"Возрастная Акселерация {comp_name}"]

# Move the summary columns to the end of the table, in this order.
# `insert` evaluates its position before `pop` removes the column, so the
# position equals the column count after removal, i.e. "append".
for summary_col in (
    "Число моделей",
    "Число моделей c отрицательной аккселерацией",
    "Число моделей c положительной аккселерацией",
    "Возрастная Акселерация",
):
    samples.insert(len(samples.columns) - 1, summary_col, samples.pop(summary_col))
samples['ATLAS Возраст'] = samples[feat_trgt] + samples["Возрастная Акселерация"]

# Make sure the output folder exists before writing the result table.
pathlib.Path(f"{path}/AtlasAge").mkdir(parents=True, exist_ok=True)
samples.to_excel(f"{path}/AtlasAge/data.xlsx")

## Load calculated data

In [4]:
# Reload the per-sample table written by the prediction cell, so the
# following cells can run without recomputing it.
samples_file = f"{path}/AtlasAge/data.xlsx"
samples = pd.read_excel(samples_file, index_col=0)

# Local explainability

In [None]:
# Sample to explain (other ids tried earlier: 14304, 144).
trgt_id = 681241 # 14304 144

# Per-component explanation: {'df': feature table, 'age_acceleration': float}.
local_exlp = {}

n_pos = samples.at[trgt_id, "Число моделей c положительной аккселерацией"]
n_neg = samples.at[trgt_id, "Число моделей c отрицательной аккселерацией"]

# SHAP correction settings: when the mean |SHAP| exceeds the threshold while
# the signed sum stays below it, the values are shrunk so that the total
# |SHAP| becomes `shap_corr_to` per feature.
shap_corr_thld = 5.0
shap_corr_to = 1.0

for comp in components:

    # Only explain components that were actually applied to this sample.
    if not (samples.at[trgt_id, f"Модель {components[comp]['name']}"] == True):
        continue

    feats_comp = components[comp]['feats']
    feats_w_trgt = components[comp]['feats'].index.values
    feats = feats_w_trgt[feats_w_trgt != feat_trgt]
    feats_corr = components[comp]['feats_corr']
    metrics = components[comp]['metrics']
    data_shap = components[comp]['data_shap']

    def predict_func(X):
        """Bias-corrected model prediction for a plain feature matrix (for SHAP)."""
        X_df = pd.DataFrame(data=X, columns=feats)
        y = components[comp]['model'].predict(X_df)[f'{feat_trgt}_prediction'].values
        y = components[comp]['corrector'].predict(y)
        return y

    mae = metrics.at['Test', 'mean_absolute_error_unbiased']
    rho = metrics.at['Test', 'pearson_corrcoef_unbiased'] * components[comp]['likelihood']

    color = components[comp]['color']
    bkg_count = components[comp]['bkg_count']

    trgt_age = samples.at[trgt_id, feat_trgt]
    trgt_pred_raw = samples.at[trgt_id, f"Предсказание {components[comp]['name']}"]

    # Background set: rows with a small unbiased error whose unbiased
    # prediction is closest to the target's age (at most `bkg_count` rows).
    data_closest = data_shap.loc[data_shap['Error Unbiased'].abs() < mae * rho, :]
    data_closest = data_closest.iloc[(data_closest['Prediction Unbiased'] - trgt_age).abs().argsort()[:bkg_count]]
    # Alternative without the error filter:
    # data_closest = data_shap.iloc[(data_shap['Prediction Unbiased'] - trgt_age).abs().argsort()[:bkg_count]]

    explainer = shap.SamplingExplainer(predict_func, data_closest.loc[:, feats].values)
    shap_values = explainer.shap_values(samples.loc[[trgt_id], feats].values)[0]
    # Rescale so that the values are expressed relative to the target's age
    # instead of the explainer's expected value; the printed number is the
    # residual between their sum and (prediction - age).
    # NOTE(review): the denominator is zero when the prediction equals the
    # explainer's expected value - confirm whether that case needs handling.
    shap_values = shap_values * (trgt_pred_raw - trgt_age) / (trgt_pred_raw - explainer.expected_value)
    print(f"{sum(shap_values) - (trgt_pred_raw - trgt_age)}")
    # Apply the same weighting/normalisation as the stored age acceleration.
    shap_values = shap_values * rho / max(n_pos, n_neg)

    # SHAP values correction
    shap_mean_abs = np.mean(np.abs(shap_values))
    if shap_mean_abs > shap_corr_thld and abs(sum(shap_values)) < shap_corr_thld:
        print('SHAP values correction')
        # flatnonzero always yields 1-D index arrays; squeezing argwhere output
        # collapsed a single match to a 0-d array that cannot be iterated.
        shap_pos_ids = np.flatnonzero(shap_values >= 0)
        shap_neg_ids = np.flatnonzero(shap_values < 0)

        shap_pos_sum_abs = np.sum(np.abs(shap_values[shap_pos_ids]))
        shap_neg_sum_abs = np.sum(np.abs(shap_values[shap_neg_ids]))

        shap_sum_abs_from = np.sum(np.abs(shap_values))
        shap_sum_abs_to = shap_corr_to * len(shap_values)

        shap_corr_diff = shap_sum_abs_from - shap_sum_abs_to

        # Remove half of the excess from each sign group, proportionally to
        # every value's share of its group (skip a group that sums to zero).
        if shap_pos_sum_abs > 0:
            shap_values[shap_pos_ids] -= np.abs(shap_values[shap_pos_ids]) / shap_pos_sum_abs * shap_corr_diff * 0.5
        if shap_neg_sum_abs > 0:
            shap_values[shap_neg_ids] += np.abs(shap_values[shap_neg_ids]) / shap_neg_sum_abs * shap_corr_diff * 0.5

    # Feature table of the component, ordered by increasing |SHAP|.
    df_comp = pd.DataFrame(index=feats, columns=['SHAP', 'Values', 'Correlation', 'Percentile', 'Class', 'Consistent', 'Show'])
    df_comp['SHAP'] = shap_values
    df_comp.sort_values(by='SHAP', key=abs, inplace=True)
    df_comp.loc[df_comp.index.values, 'Values'] = samples.loc[trgt_id, df_comp.index.values].values
    df_comp.loc[df_comp.index.values, 'Correlation'] = feats_corr.loc[df_comp.index.values, 'Correlation']
    for f in df_comp.index.values:
        # Percentile of the sample's value within the background set, split
        # into three classes: 0 = low, 1 = middle, 2 = high.
        df_comp.at[f, 'Percentile'] = scipy.stats.percentileofscore(data_closest.loc[:, f].values, samples.at[trgt_id, f])
        df_comp.at[f, 'Class'] = int(df_comp.at[f, 'Percentile'] // 33.333334)

        f_corr = df_comp.at[f, 'Correlation']
        f_shap = df_comp.at[f, 'SHAP']
        f_class = df_comp.at[f, 'Class']
        # "Consistent": the sign of the SHAP value agrees with what the
        # feature/age correlation implies for a high or low feature value.
        is_consistent = bool(
            (f_corr > 0 and f_shap > 0 and f_class == 2)
            or (f_corr > 0 and f_shap < 0 and f_class == 0)
            or (f_corr < 0 and f_shap > 0 and f_class == 0)
            or (f_corr < 0 and f_shap < 0 and f_class == 2)
        )
        df_comp.at[f, 'Consistent'] = is_consistent
        # Show consistent features and those whose value is in the middle class.
        df_comp.at[f, 'Show'] = bool(is_consistent or f_class == 1)
    df_comp['Class'] = df_comp['Class'].replace({0: 'Понижен', 1: 'Средний', 2: 'Повышен'})

    local_exlp[comp] = {
        'df': df_comp,
        'age_acceleration': samples.at[trgt_id, f"Возрастная Акселерация {components[comp]['name']}"],
    }

# Per-component age acceleration of the target sample, ordered by increasing
# magnitude; built directly with a float dtype.
df_comps = pd.DataFrame(
    {'age_acceleration': [local_exlp[comp]['age_acceleration'] for comp in local_exlp]},
    index=list(local_exlp.keys()),
    dtype=float,
)
df_comps['abs_age_acceleration'] = df_comps['age_acceleration'].abs()
df_comps = df_comps.sort_values(["abs_age_acceleration"], ascending=[True])
df_comps['cumsum'] = df_comps['age_acceleration'].cumsum()
comps_sorted = df_comps.index.values

trgt_aa = df_comps['age_acceleration'].sum()
trgt_age = samples.at[trgt_id, feat_trgt]

# Waterfall rows: chronological age, one step per component, resulting age.
summary_hover_text = ["Хронологический возраст"] + [f"{components[comp]['name']}" for comp in comps_sorted] + ["Биологический возраст"]
summary_measure = ['absolute'] + ['relative'] * len(comps_sorted) + ['absolute']
summary_ys = [-0.5] + [x + 0.5 for x in range(len(comps_sorted))] + [len(comps_sorted) + 0.5]
aas = df_comps['age_acceleration'].tolist()
summary_xs = [trgt_age] + aas + [trgt_age + trgt_aa]
summary_text = [f"{trgt_age:0.2f}"] + [f"+{aa:0.2f}" if aa > 0 else f"{aa:0.2f}" for aa in aas] + [f"{trgt_age+trgt_aa:0.2f}"]

fig = go.Figure()
fig.add_trace(
    go.Waterfall(
        hovertext=summary_hover_text,
        orientation="h",
        measure=summary_measure,
        y=summary_ys,
        x=summary_xs,
        base=0,
        text=summary_text,
        textposition = "auto",
        decreasing = {"marker":{"color": "deepskyblue", "line": {"color": "black", "width": 1}}},
        increasing = {"marker":{"color": "crimson", "line": {"color": "black", "width": 1}}},
        totals= {"marker":{"color": "dimgray", "line": {"color": "black", "width": 1}}},
        connector={
            "mode": "between",
            "line": {"width": 1, "color": "black", "dash": "dot"},
        },
    ),
)
fig.update_yaxes(
    automargin=True,
    tickmode="array",
    tickvals=summary_ys,
    ticktext=summary_hover_text,
    tickfont=dict(size=18),
)
# Title fonts are given through `title=dict(font=...)`; the old `titlefont`
# shortcut is deprecated and rejected by newer plotly releases.
fig.update_xaxes(
    automargin=True,
    title=dict(text='Возраст', font=dict(size=20)),
    range=[
        trgt_age + df_comps['cumsum'].min() * 1.2 - 2,
        trgt_age + df_comps['cumsum'].max() * 1.2 + 2
    ],
)
fig.update_layout(
    title=dict(text=f"Возрастная акселерация для {trgt_id}", font=dict(size=25)),
    template="none",
    width=800,
    height=len(comps_sorted)*30 + 200,
    margin=go.layout.Margin(l=120, r=100, b=50, t=50, pad=0),
)
fig.show()

def _signed(value):
    """Format a number with two decimals and an explicit '+' for positive values."""
    return f"+{value:0.2f}" if value > 0 else f"{value:0.2f}"


def _bar_style(color):
    """Waterfall bar style: the given fill colour with a thin black outline."""
    return {"marker": {"color": color, "line": {"color": "black", "width": 1}}}


# Maximum number of individual features drawn per component; when a component
# has more, the remaining ones are collapsed into a single bar.
n_show = 5

# One figure per component, largest absolute acceleration first.
for comp in comps_sorted[::-1]:

    df_comp = local_exlp[comp]['df']
    aa_comp = local_exlp[comp]['age_acceleration']

    df_comp.sort_values(['SHAP'], key=abs, inplace=True)
    df_comp['SHAP cumsum'] = df_comp['SHAP'].cumsum()

    collapse_rest = df_comp.shape[0] > n_show
    if collapse_rest:
        # Keep the (at most n_show) flagged features with the largest |SHAP|;
        # everything else is summed into one bar.
        feats_show = df_comp.index[df_comp['Show'] == True].tolist()[-n_show:]
        df_show = df_comp.loc[feats_show, :].sort_values(['SHAP'], key=abs)
        df_not_show = df_comp.drop(index=feats_show)
        shap_not_show_sum = df_not_show['SHAP'].sum()
    else:
        df_show = df_comp

    n_rows = df_show.shape[0]
    n_extra = 1 if collapse_rest else 0  # extra row taken by the collapsed bar
    feat_ys = list(range(n_rows))
    total_y = n_rows + 1

    # Bars, hover labels and axis ticks are built together so that they always
    # have matching lengths.
    bar_ys = list(feat_ys)
    bar_xs = df_show['SHAP'].tolist()
    bar_hover = df_show.index.tolist()
    tick_text = [f"{x} = {df_show.at[x, 'Values']:0.2f} ({df_show.at[x, 'Class']})" for x in df_show.index]
    if collapse_rest:
        bar_ys = [-1] + bar_ys
        bar_xs = [shap_not_show_sum] + bar_xs
        bar_hover = ['Остальные признаки'] + bar_hover
        tick_text = ['Остальные признаки'] + tick_text
    tick_vals = bar_ys + [total_y]
    tick_text = tick_text + ['Итоговая акселерация']

    fig = make_subplots(rows=1, cols=2, shared_yaxes=True, shared_xaxes=False, column_widths=[2.5, 0.5], horizontal_spacing=0.15, subplot_titles=['', 'Корреляция с возрастом'])

    # Total acceleration of the component (top bar).
    fig.add_trace(
        go.Waterfall(
            hovertext=["Итоговая акселерация"],
            orientation="h",
            measure=['relative'],
            y=[total_y],
            x=[aa_comp],
            base=0,
            text=[_signed(aa_comp)],
            textposition="auto",
            decreasing=_bar_style("deepskyblue"),
            increasing=_bar_style("crimson"),
            totals=_bar_style("dimgray"),
            connector={
                "mode": "between",
                "line": {"width": 1, "color": "black", "dash": "dot"},
            },
        ),
        row=1,
        col=1,
    )
    # Per-feature SHAP contributions (plus the collapsed bar, if any).
    fig.add_trace(
        go.Waterfall(
            hovertext=bar_hover,
            orientation="h",
            measure=["relative"] * len(bar_xs),
            y=bar_ys,
            x=bar_xs,
            base=0,
            text=[_signed(x) for x in bar_xs],
            textposition="auto",
            decreasing=_bar_style("lightblue"),
            increasing=_bar_style("lightcoral"),
            connector={
                "mode": "between",
                "line": {"width": 1, "color": "black", "dash": "solid"},
            },
        ),
        row=1,
        col=1,
    )
    fig.update_traces(row=1, col=1, showlegend=False)
    fig.update_yaxes(
        row=1,
        col=1,
        automargin=True,
        tickmode="array",
        tickvals=tick_vals,
        ticktext=tick_text,
        tickfont=dict(size=18),
    )
    # Title fonts are given through `title=dict(font=...)`; the old `titlefont`
    # shortcut is deprecated and rejected by newer plotly releases.
    fig.update_xaxes(
        row=1,
        col=1,
        automargin=True,
        title=dict(text='Возрастная акселерация', font=dict(size=20)),
    )

    # Feature/age correlation shown next to the corresponding feature bars.
    fig.add_trace(
        go.Heatmap(
            x=['Возраст'],
            y=feat_ys,
            z=[[x] for x in df_show['Correlation'].values],
            text=[[f"{x:0.2f}"] for x in df_show['Correlation'].values],
            texttemplate="%{text}",
            zmin=-1,
            zmax=1,
            colorbar_y=(n_rows + n_extra) / (n_rows + n_extra + 2) * 0.5,
        ),
        row=1,
        col=2,
    )

    fig.update_layout(
        title=dict(text=f"{components[comp]['name']}", font=dict(size=25)),
        template="none",
        width=1800,
        height=(n_rows + n_extra) * 40 + 100,
        margin=go.layout.Margin(l=120, r=100, b=50, t=50, pad=0),
    )

    # The heatmap column carries no axis decoration of its own.
    fig.update_xaxes(
        row=1,
        col=2,
        tickvals=[],
        automargin=True,
        showgrid=False,
        showline=False,
        zeroline=False,
        showticklabels=False,
    )
    fig.update_yaxes(
        row=1,
        col=2,
        tickvals=[],
        automargin=True,
        showgrid=False,
        showline=False,
        zeroline=False,
        showticklabels=False,
    )
    fig.show()

In [None]:
# Scratch check: absolute sums of the positive and negative SHAP groups and
# their difference, as left by the last component processed above.
# NOTE(review): these names are assigned only inside the 'SHAP values
# correction' branch of the local-explainability cell, so this cell raises
# NameError on a fresh kernel unless that branch was executed.
shap_pos_sum_abs, shap_neg_sum_abs, shap_pos_sum_abs - shap_neg_sum_abs
   

In [None]:
# Scratch check: the same group sums recomputed from the current (corrected)
# `shap_values`, for comparison with the pre-correction sums.
# NOTE(review): `shap_pos_ids` / `shap_neg_ids` exist only if the 'SHAP values
# correction' branch of the local-explainability cell was executed.
np.sum(np.abs(shap_values[shap_pos_ids])), np.sum(np.abs(shap_values[shap_neg_ids])), np.sum(np.abs(shap_values[shap_pos_ids])) - np.sum(np.abs(shap_values[shap_neg_ids]))