# Debugging autoreload

In [None]:
%load_ext autoreload
%autoreload 2

# Load packages

In [None]:
from pytorch_tabular.utils import load_covertype_dataset
from rich.pretty import pprint
from sklearn.model_selection import BaseCrossValidator, ParameterGrid, ParameterSampler
from statsmodels.stats.multitest import multipletests
from regression_bias_corrector import LinearBiasCorrector
import torch
import scipy.stats
import pickle
import shutil
import shap
from sklearn.model_selection import RepeatedStratifiedKFold
from glob import glob
import ast
import matplotlib.pyplot as plt
import seaborn as sns
import copy
from sklearn.model_selection import train_test_split
import numpy as np
from pytorch_tabular.utils import make_mixed_dataset, print_metrics
from pytorch_tabular import available_models
from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig, GANDALFConfig, TabNetModelConfig, FTTransformerConfig, DANetConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig
from pytorch_tabular.tabular_model_tuner import TabularModelTuner
from sklearn.metrics import mean_absolute_error
from pytorch_tabular import MODEL_SWEEP_PRESETS
import pandas as pd
from pytorch_tabular import model_sweep
from src.pt.model_sweep import model_sweep_custom
import warnings
from src.utils.configs import read_parse_config
from src.utils.hash import dict_hash
from src.pt.hyper_opt import train_hyper_opt
import pathlib
from tqdm import tqdm
import distinctipy
import matplotlib.patheffects as pe
import matplotlib.colors as mcolors
from statannotations.Annotator import Annotator
from scipy.stats import mannwhitneyu
from scipy import stats
import optuna
from plottable import ColumnDefinition, Table
from plottable.plots import bar
from plottable.cmap import normed_cmap, centered_cmap
import matplotlib.cm
from sklearn.preprocessing import LabelEncoder
import plotly.graph_objects as go
from plotly.subplots import make_subplots


def make_rgb_transparent(rgb, bg_rgb, alpha):
    """Blend a foreground colour over a background colour.

    Every channel of ``rgb`` is mixed with the matching channel of
    ``bg_rgb`` as ``alpha * fg + (1 - alpha) * bg``. Channels are paired
    with ``zip``, so the result has as many entries as the shorter input.

    Returns the blended channels as a list.
    """
    blended = []
    for fg, bg in zip(rgb, bg_rgb):
        blended.append(alpha * fg + (1 - alpha) * bg)
    return blended

# Plot all-in-one figures

In [None]:
# Root data directory and the directory holding the trained models.
path = "E:/YandexDisk/Work/bbd/mriya"
path_models = path + "/models/oct2025"

# 'current' reuses the stored explanation.xlsx; the plotting loop also
# recognises 'recalc_gradient' and 'recalc_sampling'.
expl_type = 'current'

# Name of the target column.
feat_trgt = 'Age'

# One row per feature set: (directory name, model sub-path, plot colour).
_feats_sets_table = [
    ('Эхокардиография', 'DANet/2', 'darkcyan'),
    ('Сфигмография', 'DANet/101', 'mediumorchid'),
    ('Биохимический анализ крови', 'DANet/2', 'goldenrod'),
    ('Электрокардиография', 'DANet/28', 'dodgerblue'),
    ('Антропометрия', 'DANet/148', 'chartreuse'),
    ('Общий анализ крови', 'DANet/394', 'crimson'),
    ('Все', 'DANet/283', 'gray'),
]

# feature set -> model sub-path below "<feature set>/models/"
feats_sets_models = {name: model_id for name, model_id, _ in _feats_sets_table}

# feature set -> matplotlib colour name
colors_feats_sets = {name: color for name, _, color in _feats_sets_table}

def _draw_metrics_table(ax, metrics, suffix=''):
    """Draw the Train/Validation/Test metrics table on ``ax``.

    ``metrics`` is indexed by data part and must provide the columns
    ``mean_absolute_error``, ``pearson_corrcoef`` and ``bias``, each with
    ``suffix`` appended ('' or '_unbiased' in this notebook).
    """
    parts = ['Train', 'Validation', 'Test']
    rho_label = fr"Pearson $\mathbf{{\rho}}$"
    mae_col = f"mean_absolute_error{suffix}"
    rho_col = f"pearson_corrcoef{suffix}"
    bias_col = f"bias{suffix}"

    df_table = pd.DataFrame(index=['MAE', rho_label, "Bias"], columns=parts)
    for part in parts:
        df_table.at['MAE', part] = f"{metrics.at[part, mae_col]:0.3f}"
        df_table.at[rho_label, part] = f"{metrics.at[part, rho_col]:0.3f}"
        df_table.at["Bias", part] = f"{metrics.at[part, bias_col]:0.3f}"

    col_defs = [
        ColumnDefinition(
            name="index",
            title='',
            textprops={"ha": "center", "weight": "bold"},
            width=2.5,
        ),
        ColumnDefinition(
            name="Train",
            textprops={"ha": "left"},
            width=1.5,
            border="left",
        ),
        ColumnDefinition(
            name="Validation",
            textprops={"ha": "left"},
            width=1.5,
        ),
        ColumnDefinition(
            name="Test",
            textprops={"ha": "left"},
            width=1.5,
        ),
    ]
    Table(
        df_table,
        column_definitions=col_defs,
        row_dividers=True,
        footer_divider=False,
        ax=ax,
        textprops={"fontsize": 8},
        row_divider_kw={"linewidth": 1, "linestyle": (0, (1, 1))},
        col_label_divider_kw={"linewidth": 1, "linestyle": "-"},
        column_border_kw={"linewidth": 1, "linestyle": "-"},
    ).autoset_fontcolors(colnames=parts)


def _draw_prediction_panel(subfig, title, results, metrics, color, target, pred_col, err_col, metrics_suffix=''):
    """Fill ``subfig`` with a metrics table, a target-vs-prediction plot and an error plot.

    Parameters:
        subfig: matplotlib (sub)figure that receives the three axes.
        title: suptitle of the panel.
        results: DataFrame with a 'Group' column (Train/Validation/Test),
            the ``target`` column, ``pred_col`` and ``err_col``.
        metrics: DataFrame passed on to the metrics table.
        color: base colour of the feature set.
        target: name of the target column (x axis of the scatter).
        pred_col: name of the prediction column (y axis of the scatter).
        err_col: name of the error column shown in the violin/swarm plot.
        metrics_suffix: suffix of the metric columns to display.
    """
    axs = subfig.subplot_mosaic(
        [
            ['table', 'table'],
            ['scatter', 'violin'],
        ],
        height_ratios=[1, 4],
        width_ratios=[3, 1.5],
        gridspec_kw={
            "wspace": 0.01,
            "hspace": 0.01,
        },
    )
    subfig.suptitle(title, fontsize='large')

    _draw_metrics_table(axs['table'], metrics, metrics_suffix)

    df_train = results.loc[results['Group'] == 'Train', :]
    df_train_val = results.loc[results['Group'].isin(['Train', 'Validation']), :]
    df_test = results.loc[results['Group'] == 'Test', :]

    # Common x/y limits: 1%..99% quantiles of target and prediction pooled
    # together, padded by 15% of that range on both sides.
    xy_min, xy_max = np.quantile(results[[target, pred_col]].values.flatten(), [0.01, 0.99])
    xy_ptp = xy_max - xy_min
    lim_lo = xy_min - 0.15 * xy_ptp
    lim_hi = xy_max + 0.15 * xy_ptp

    # Train+Validation as a filled density, Test as individual points.
    sns.kdeplot(
        data=df_train_val,
        x=target,
        y=pred_col,
        fill=True,
        cbar=False,
        thresh=0.05,
        color=color,
        legend=False,
        ax=axs['scatter'],
    )
    sns.scatterplot(
        data=df_test,
        x=target,
        y=pred_col,
        linewidth=0.5,
        alpha=0.8,
        edgecolor="k",
        s=25,
        color=color,
        ax=axs['scatter'],
    )
    # Identity line (prediction == target).
    sns.lineplot(
        x=[lim_lo, lim_hi],
        y=[lim_lo, lim_hi],
        linestyle='--',
        color='black',
        linewidth=1.0,
        ax=axs['scatter'],
    )
    # Linear fit of prediction on target, estimated on Train rows only.
    sns.regplot(
        data=df_train,
        x=target,
        y=pred_col,
        color='k',
        scatter=False,
        truncate=False,
        ax=axs['scatter'],
    )
    axs['scatter'].set_xlim(lim_lo, lim_hi)
    axs['scatter'].set_ylim(lim_lo, lim_hi)
    axs['scatter'].set_ylabel("Биологический возраст")
    axs['scatter'].set_xlabel("Возраст")

    # Error distribution: violin for Train+Validation, swarm for Test.
    sns.violinplot(
        data=df_train_val,
        x=[0] * df_train_val.shape[0],
        y=err_col,
        color=make_rgb_transparent(mcolors.to_rgb(color), (1, 1, 1), 0.5),
        density_norm='width',
        saturation=0.75,
        linewidth=1.0,
        ax=axs['violin'],
        legend=False,
    )
    sns.swarmplot(
        data=df_test,
        x=[0] * df_test.shape[0],
        y=err_col,
        color=color,
        linewidth=0.5,
        ax=axs['violin'],
        # Marker size shrinks with the square root of the Test sample count.
        size=50 / np.sqrt(df_test.shape[0]),
        legend=False,
    )
    axs['violin'].set_ylabel('Возрастная акселерация')
    axs['violin'].set_xlabel('')
    axs['violin'].set(xticklabels=[])
    axs['violin'].set(xticks=[])


# For every feature set: load the stored model artefacts, draw one summary
# figure (raw vs. bias-corrected predictions on top, feature importance at the
# bottom) and save it next to the feature-set data as model.png / model.pdf.
for feats_set in feats_sets_models:
    model_dir = f"{path_models}/{feats_set}/models/{feats_sets_models[feats_set]}"
    fs_color = colors_feats_sets[feats_set]

    data = pd.read_excel(f"{path_models}/{feats_set}/data.xlsx", index_col=0)
    feats = pd.read_excel(f"{path_models}/{feats_set}/feats.xlsx", index_col=0)
    results = pd.read_excel(f"{model_dir}/df.xlsx", index_col=0)
    metrics = pd.read_excel(f"{model_dir}/metrics.xlsx", index_col=0)
    df_shap = pd.read_excel(f"{model_dir}/explanation.xlsx", index_col=0)
    model = TabularModel.load_model(model_dir)

    # The corrector is fitted on Train rows only (target vs. raw prediction);
    # it is used below when SHAP values are recomputed by sampling.
    is_train = results['Group'] == 'Train'
    corrector = LinearBiasCorrector()
    corrector.fit(results.loc[is_train, feat_trgt].values, results.loc[is_train, 'Prediction'].values)

    feat_names = feats.index.to_list()

    sns.set_theme(style='ticks')
    # The bottom row grows with the number of features.
    importance_height = 1.5 + 0.15 * feats.shape[0]
    fig = plt.figure(
        figsize=(15, 5 + importance_height),
        layout="constrained"
    )
    subfigs = fig.subfigures(
        nrows=2,
        ncols=1,
        height_ratios=[5, importance_height],
        wspace=0.01,
        hspace=0.01,
    )
    subfigs_row = subfigs[0].subfigures(
        nrows=1,
        ncols=2,
        width_ratios=[1, 1],
        wspace=0.15,
        hspace=0.01,
    )

    # Left panel: raw predictions ("before the rotation transform").
    _draw_prediction_panel(
        subfigs_row[0],
        'До преобразования поворота',
        results,
        metrics,
        fs_color,
        feat_trgt,
        pred_col='Prediction',
        err_col='Error',
        metrics_suffix='',
    )
    # Right panel: bias-corrected predictions ("after the rotation transform").
    _draw_prediction_panel(
        subfigs_row[1],
        'После преобразования поворота',
        results,
        metrics,
        fs_color,
        feat_trgt,
        pred_col='Prediction Unbiased',
        err_col='Error Unbiased',
        metrics_suffix='_unbiased',
    )

    # Optionally recompute the explanation instead of using explanation.xlsx.
    if expl_type == 'recalc_gradient':
        df_shap = model.explain(data, method="GradientShap", baselines="b|100000")
        df_shap.index = data.index
    elif expl_type == 'recalc_sampling':
        # SamplingExplainer works on a plain numeric matrix, so every feature
        # is label-encoded first and decoded again inside predict_func.
        ds_data_shap = data.copy()
        ds_cat_encoders = {}
        for f in feats.index:
            ds_cat_encoders[f] = LabelEncoder()
            ds_data_shap[f] = ds_cat_encoders[f].fit_transform(ds_data_shap[f])

        def predict_func(X):
            X_df = pd.DataFrame(data=X, columns=feats.index.to_list())
            for f in feats.index:
                X_df[f] = ds_cat_encoders[f].inverse_transform(X_df[f].astype(int).values)
            y = model.predict(X_df)[f'{feat_trgt}_prediction'].values
            y = corrector.predict(y)
            return y

        explainer = shap.SamplingExplainer(predict_func, ds_data_shap.loc[:, feat_names].values)
        print(explainer.expected_value)
        shap_values = explainer.shap_values(ds_data_shap.loc[:, feat_names].values)
        df_shap = pd.DataFrame(index=data.index, columns=feat_names, data=shap_values)

    # Per-feature importance (mean |SHAP|) and Pearson correlation with the
    # target, computed on rows where both values are present.
    ds_fi = pd.DataFrame(index=feat_names, columns=['mean(|SHAP|)', 'rho'])
    for f in feat_names:
        ds_fi.at[f, 'mean(|SHAP|)'] = df_shap[f].abs().mean()
        df_tmp = data.loc[:, [feat_trgt, f]].dropna(axis=0, how='any')
        if df_tmp.shape[0] > 1:
            vals_1 = df_tmp.loc[:, feat_trgt].values
            vals_2 = df_tmp.loc[:, f].values
            ds_fi.at[f, 'rho'], _ = scipy.stats.pearsonr(vals_1, vals_2)
    ds_fi.sort_values(['mean(|SHAP|)'], ascending=[False], inplace=True)
    ds_fi['Features'] = ds_fi.index.values

    # Bottom row: correlation heatmap | importance bars | SHAP strip plot.
    axs_importance = subfigs[1].subplots(1, 3, width_ratios=[1, 4, 8], gridspec_kw={'wspace': 0.02, 'hspace': 0.02}, sharey=False, sharex=False)

    heatmap = sns.heatmap(
        ds_fi.loc[:, ['rho']].apply(pd.to_numeric).values,
        yticklabels=ds_fi.index.to_list(),
        annot=True,
        fmt=".2f",
        vmin=-1.0,
        vmax=1.0,
        cmap='coolwarm',
        linewidth=0.1,
        linecolor='black',
        cbar=False,
        ax=axs_importance[0]
    )

    barplot = sns.barplot(
        data=ds_fi,
        x='mean(|SHAP|)',
        y='Features',
        color=fs_color,
        edgecolor='black',
        dodge=False,
        ax=axs_importance[1]
    )
    for container in barplot.containers:
        barplot.bar_label(container, label_type='edge', color='gray', fmt='%0.2f', fontsize=12, padding=4.0)
    axs_importance[1].set_ylabel('')

    f_cmap = sns.color_palette("Spectral_r", as_cmap=True)
    strip_size = 25 / np.sqrt(results.loc[results['Group'] == 'Test', :].shape[0])
    is_colorbar = False
    for f in ds_fi.index:
        # Clip extreme SHAP values to the 1%..99% quantile range only when the
        # feature has large contributions (|SHAP| above 10).
        if df_shap[f].abs().max() > 10:
            f_shap_ll = df_shap[f].quantile(0.01)
            f_shap_hl = df_shap[f].quantile(0.99)
        else:
            f_shap_ll = df_shap[f].min()
            f_shap_hl = df_shap[f].max()

        f_index = df_shap.index[(df_shap[f] >= f_shap_ll) & (df_shap[f] <= f_shap_hl)].values
        f_shap = df_shap.loc[f_index, f].values
        f_vals = data.loc[f_index, f].values

        # NaN-aware bounds: builtin min()/max() give order-dependent results
        # when the feature has missing values.
        f_vmin = np.nanmin(f_vals)
        f_vmax = np.nanmax(f_vals)
        f_norm = mcolors.Normalize(vmin=f_vmin, vmax=f_vmax)
        f_colors = {cval: f_cmap(f_norm(cval)) for cval in f_vals}

        strip = sns.stripplot(
            x=f_shap,
            y=[f] * len(f_shap),
            hue=f_vals,
            palette=f_colors,
            jitter=0.35,
            alpha=0.5,
            edgecolor='gray',
            linewidth=0.1,
            size=strip_size,
            legend=False,
            ax=axs_importance[2],
        )

        # A single Min/Max colorbar is shared by all features.
        if not is_colorbar:
            sm = plt.cm.ScalarMappable(cmap=f_cmap, norm=f_norm)
            sm.set_array([])
            # The mappable is not attached to any Axes, so the Axes to take
            # space from has to be named explicitly.
            cbar = strip.figure.colorbar(sm, ax=axs_importance[2])
            cbar.set_ticks([f_vmin, f_vmax])
            cbar.set_ticklabels(["Min", "Max"])
            is_colorbar = True
    axs_importance[2].set_xlabel('SHAP')

    df_shap.to_excel(f"{path_models}/{feats_set}/model_importance.xlsx")

    fig.savefig(f"{path_models}/{feats_set}/model.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_models}/{feats_set}/model.pdf", bbox_inches='tight')
    plt.close(fig)

# Models inference

In [None]:
path = "E:/YandexDisk/Work/bbd/mriya"

path_models = f"{path}/models/oct2025"

expl_type = 'current'

# Name of the target column.
feat_trgt = 'Age'

# Table the models are applied to; the index is cast to str.
df = pd.read_excel(f"{path}/data_bioage_all.xlsx", index_col=0)
df.index = df.index.astype(str)

# feature set -> model sub-path below "<feature set>/models/"
feats_sets_models = {
    'Эхокардиография': 'DANet/2',
    'Сфигмография': 'DANet/101',
    'Биохимический анализ крови': 'DANet/2',
    'Электрокардиография': 'DANet/28',
    'Антропометрия': 'DANet/148',
    'Общий анализ крови': 'DANet/394',
    'Все': 'DANet/283',
}

colors_feats_sets = {
    'Эхокардиография': 'darkcyan',
    'Сфигмография': 'mediumorchid',
    'Биохимический анализ крови': 'goldenrod',
    'Электрокардиография': 'dodgerblue',
    'Антропометрия': 'chartreuse',
    'Общий анализ крови': 'crimson',
    'Все': 'gray',
}

# df[f'Modular Age Acceleration'] = 0.0
# df[f'Modular BioAge'] = 0.0
for feats_set in feats_sets_models:
    model_dir = f"{path_models}/{feats_set}/models/{feats_sets_models[feats_set]}"

    # data, feats, metrics and df_shap are only referenced by the disabled
    # "modular age" code below; they are still loaded so the notebook state
    # after this cell is unchanged.
    data = pd.read_excel(f"{path_models}/{feats_set}/data.xlsx", index_col=0)
    feats = pd.read_excel(f"{path_models}/{feats_set}/feats.xlsx", index_col=0)
    results = pd.read_excel(f"{model_dir}/df.xlsx", index_col=0)
    metrics = pd.read_excel(f"{model_dir}/metrics.xlsx", index_col=0)
    df_shap = pd.read_excel(f"{model_dir}/explanation.xlsx", index_col=0)
    model = TabularModel.load_model(model_dir)

    # Bias corrector fitted on Train rows only (target vs. raw prediction).
    is_train = results['Group'] == 'Train'
    corrector = LinearBiasCorrector()
    corrector.fit(results.loc[is_train, feat_trgt].values, results.loc[is_train, 'Prediction'].values)

    # model.predict returns a DataFrame; pick the prediction column explicitly
    # (same column the SHAP sampling code of this notebook reads) instead of
    # assigning the whole frame to a single column.
    predictions = model.predict(df)
    df[f'Prediction {feats_set}'] = predictions[f'{feat_trgt}_prediction']
    df[f'Prediction Corrected {feats_set}'] = corrector.predict(df[f'Prediction {feats_set}'])
    df[f'Error {feats_set}'] = df[f'Prediction {feats_set}'] - df[feat_trgt]
    df[f'Error Corrected {feats_set}'] = df[f'Prediction Corrected {feats_set}'] - df[feat_trgt]

    # feats_corr = pd.DataFrame(index=feats, columns=['Correlation'])
    # for f in feats.index.values:
    #     df_f_corr = df.dropna(subset=[feat_trgt, f])
    #     feats_corr.at[f, 'Correlation'], _ = scipy.stats.pearsonr(df_f_corr.loc[:, f].values, df_f_corr.loc[:, feat_trgt].values)

    # rho = metrics.at['Test', 'pearson_corrcoef_unbiased'] * feats_corr['Correlation'].abs().max()
    # print(f"rho = {rho:0.3f}")

    # if feats_set != 'Все':
    #     ids_not_na = df.index[df[f'Error Corrected {feats_set}'].notna()]
    #     df.loc[ids_not_na, f'Modular Age Acceleration {feats_set}'] = df.loc[ids_not_na, f'Error Corrected {feats_set}'] * rho
    #     df.loc[ids_not_na, f'Modular Age Acceleration'] += df.loc[ids_not_na,f'Modular Age Acceleration {feats_set}']

# df[f'Modular BioAge'] = df[feat_trgt] + df[f'Modular Age Acceleration']

# modular_mae = mean_absolute_error(df[feat_trgt].values, df[f'Modular BioAge'].values)
# modular_pearsonr = scipy.stats.pearsonr(df.loc[:, f'Modular BioAge'].values, df.loc[:, feat_trgt].values)

# Input table plus the per-feature-set prediction and error columns.
df.to_excel(f"{path}/result.xlsx")

# Comparing multiple feature sets

In [None]:
path = "E:/YandexDisk/Work/bbd/mriya"

path_models = f"{path}/models/oct2025"

# Feature sets compared by the cells below; commented entries are disabled.
# NOTE(review): the cells below read ".../models/DANet/{model_id}/...", so an
# id that already starts with 'DANet/' resolves to ".../models/DANet/DANet/2".
# Confirm whether the ids here should be bare numbers like the disabled ones.
feats_sets_models = {
    'Эхокардиография': 'DANet/2',
    # 'Сфигмография': '113',
    # 'Биохимический анализ крови': '125',
    # 'Электрокардиография': '70',
    # 'Антропометрия': '108',
    # 'Общий анализ крови': '112',
    # 'Все': '2',
}

# feature set -> matplotlib colour name
colors_feats_sets = {
    'Эхокардиография': 'darkcyan',
    # 'Сфигмография': 'mediumorchid',
    # 'Биохимический анализ крови': 'goldenrod',
    # 'Электрокардиография': 'dodgerblue',
    # 'Антропометрия': 'chartreuse',
    # 'Общий анализ крови': 'crimson',
    # 'Все': 'gray',
}

## Age histograms

In [None]:
# Grid of age histograms, one subplot per feature set (row-major order).
n_rows = 2
n_cols = 3
fig_width = 15
fig_height = 9
# 22 bins of width 5 covering ages 5..115.
hist_bins = np.linspace(5, 115, 23)

sns.set_theme(style='ticks')
fig, axs = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height), gridspec_kw={}, sharey=True, sharex=True)
for fs_id, (fs, model_id) in enumerate(feats_sets_models.items()):
    row_id, col_id = divmod(fs_id, n_cols)
    ax = axs[row_id, col_id]

    # NOTE(review): a model_id that already contains 'DANet/' resolves to
    # ".../models/DANet/DANet/..." here — confirm the intended layout.
    df_res = pd.read_excel(f"{path}/{fs}/models/DANet/{model_id}/df.xlsx", index_col=0)

    histplot = sns.histplot(
        data=df_res,
        bins=hist_bins,
        edgecolor='k',
        linewidth=1,
        x="Age",
        color=colors_feats_sets[fs],
        ax=ax
    )
    ax.set(xlim=(0, 120))
    # Fall back to the raw key when it has no entry in the rename mapping
    # (e.g. when the feature-set keys are already display names).
    fs_title = feats_set_rename.get(fs, fs)
    ax.set_title(f"{fs_title} ({df_res.shape[0]})")
fig.tight_layout()
fig.savefig(f"{path}/hist_age_for_feats.png", bbox_inches='tight', dpi=200)
fig.savefig(f"{path}/hist_age_for_feats.pdf", bbox_inches='tight')
plt.close(fig)

## Scatters and KDEs with metrics

In [None]:
# Every feature set occupies three stacked axes: metrics table, scatter/KDE
# plot and an empty spacer. Two such bands of three columns are laid out.
n_rows = 2 * 3
n_cols = 3
fig_width = 12
fig_height = 12

sns.set_theme(style='ticks')
fig, axs = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height), height_ratios=[0.2, 0.8, 0.1]*2, gridspec_kw={'wspace':0.35, 'hspace': 0.05}, sharey=False, sharex=False)

# Test-set metrics of every feature set, collected for the summary table.
df_metrics_global = pd.DataFrame(index=list(feats_sets_models.keys()), columns=['MAE', 'Rho', 'Bias'])

for fs_id, (fs, model_id) in enumerate(feats_sets_models.items()):
    row_id, col_id = divmod(fs_id, n_cols)
    row_id_table = row_id * 3
    row_id_scatter = row_id * 3 + 1
    row_id_empty = row_id * 3 + 2

    fs_color = colors_feats_sets[fs]
    # Fall back to the raw key when it has no entry in the rename mapping
    # (e.g. when the feature-set keys are already display names).
    fs_title = feats_set_rename.get(fs, fs)

    # NOTE(review): a model_id that already contains 'DANet/' resolves to
    # ".../models/DANet/DANet/..." here — confirm the intended layout.
    df_results = pd.read_excel(f"{path}/{fs}/models/DANet/{model_id}/df.xlsx", index_col=0)
    df_metrics = pd.read_excel(f"{path}/{fs}/models/DANet/{model_id}/metrics.xlsx", index_col=0)

    df_metrics_global.at[fs, 'MAE'] = df_metrics.at['Test', 'mean_absolute_error_unbiased']
    df_metrics_global.at[fs, 'Rho'] = df_metrics.at['Test', 'pearson_corrcoef_unbiased']
    df_metrics_global.at[fs, 'Bias'] = df_metrics.at['Test', 'bias_unbiased']

    # Common axis range: overall min/max of age and corrected prediction.
    xy_min = df_results[['Age', 'Prediction Unbiased']].min().min()
    xy_max = df_results[['Age', 'Prediction Unbiased']].max().max()
    xy_ptp = xy_max - xy_min

    parts = ['Train', 'Validation', 'Test']
    rho_label = fr"Pearson $\mathbf{{\rho}}$"
    df_table = pd.DataFrame(index=['MAE', rho_label, "Bias"], columns=parts)
    for part in parts:
        df_table.at['MAE', part] = f"{df_metrics.at[part, 'mean_absolute_error_unbiased']:0.3f}"
        df_table.at[rho_label, part] = f"{df_metrics.at[part, 'pearson_corrcoef_unbiased']:0.3f}"
        df_table.at["Bias", part] = f"{df_metrics.at[part, 'bias_unbiased']:0.3f}"

    # All columns share one group header carrying the feature-set title.
    col_defs = [
        ColumnDefinition(
            name="index",
            title='',
            textprops={"ha": "center", "weight": "bold"},
            width=2.5,
            group=fs_title,
        ),
        ColumnDefinition(
            name="Train",
            textprops={"ha": "left"},
            width=1.5,
            border="left",
            group=fs_title,
        ),
        ColumnDefinition(
            name="Validation",
            textprops={"ha": "left"},
            width=1.5,
            group=fs_title,
        ),
        ColumnDefinition(
            name="Test",
            textprops={"ha": "left"},
            width=1.5,
            group=fs_title,
        )
    ]

    table = Table(
        df_table,
        column_definitions=col_defs,
        row_dividers=True,
        footer_divider=False,
        ax=axs[row_id_table, col_id],
        textprops={"fontsize": 8},
        row_divider_kw={"linewidth": 1, "linestyle": (0, (1, 1))},
        col_label_divider_kw={"linewidth": 1, "linestyle": "-"},
        column_border_kw={"linewidth": 1, "linestyle": "-"},
    ).autoset_fontcolors(colnames=parts)

    ax_scatter = axs[row_id_scatter, col_id]
    # Non-Test rows as a filled density, Test rows as individual points.
    kdeplot = sns.kdeplot(
        data=df_results.loc[df_results['Group'] != 'Test', :],
        x='Age',
        y='Prediction Unbiased',
        fill=True,
        cbar=False,
        color=fs_color,
        thresh=0.05,
        cut=0,
        legend=False,
        ax=ax_scatter
    )
    scatter = sns.scatterplot(
        data=df_results.loc[df_results['Group'] == 'Test', :],
        x='Age',
        y="Prediction Unbiased",
        linewidth=0.5,
        alpha=0.8,
        edgecolor="k",
        s=35,
        color=fs_color,
        ax=ax_scatter,
    )
    # Identity line and limits padded by 10% of the value range.
    ax_scatter.axline((0, 0), slope=1, color="black", linestyle=":")
    ax_scatter.set_xlim(xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp)
    ax_scatter.set_ylim(xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp)

    axs[row_id_empty, col_id].axis('off')

fig.tight_layout()
fig.savefig(f"{path}/models_for_feats.png", bbox_inches='tight', dpi=200)
fig.savefig(f"{path}/models_for_feats.pdf", bbox_inches='tight')
plt.close(fig)

# Relabel the rows through the rename mapping; keys without an entry keep
# their original label.
df_metrics_global = df_metrics_global.rename(index=feats_set_rename)


def _metric_column(name, title, cmap, **extra):
    """Column definition shared by the numeric metric columns of the table."""
    return ColumnDefinition(
        name=name,
        title=title,
        textprops={"ha": "center"},
        formatter="{:.3f}",
        cmap=cmap,
        width=1.0,
        **extra,
    )


col_defs = [
    ColumnDefinition(
        name="index",
        title="Features Set",
        textprops={"ha": "right", "weight": "bold"},
        width=2.25,
    ),
    _metric_column(
        "MAE",
        "Test\nMAE",
        normed_cmap(df_metrics_global["MAE"].dropna(), cmap=matplotlib.cm.Reds, num_stds=2.5),
    ),
    _metric_column(
        "Rho",
        "Test\n" + r"Pearson $\rho$",
        normed_cmap(df_metrics_global["Rho"].dropna(), cmap=matplotlib.cm.Greens, num_stds=2.5),
        border="left",
    ),
    _metric_column(
        "Bias",
        "Test\nBias",
        # Bias uses a colour map centred by centered_cmap.
        centered_cmap(df_metrics_global["Bias"].dropna(), cmap=matplotlib.cm.seismic, num_stds=2.5),
        border="left",
    ),
]

# Stand-alone figure holding only the summary table.
fig, ax = plt.subplots()
table = Table(
    df_metrics_global,
    column_definitions=col_defs,
    row_dividers=True,
    footer_divider=False,
    odd_row_color="#ffffff",
    even_row_color="#f0f0f0",
    ax=ax,
    row_divider_kw={"linewidth": 1, "linestyle": (0, (1, 1))},
    col_label_divider_kw={"linewidth": 1, "linestyle": "-"},
    column_border_kw={"linewidth": 1, "linestyle": "-"},
).autoset_fontcolors(colnames=list(df_metrics_global.columns))
for fig_ext, save_kw in (("png", {"dpi": 200}), ("pdf", {})):
    fig.savefig(f"{path}/test_metrics_for_feats.{fig_ext}", bbox_inches='tight', **save_kw)
plt.close(fig)

## Feature importance

In [None]:
# One bar chart of mean |SHAP| per feature set, laid out in a single row.
n_rows = 1
n_cols = 6
fig_width = 20
fig_height = 9

sns.set_theme(style='ticks')
fig, axs = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height), sharey=False, sharex=False)

for fs_id, (fs, model_id) in enumerate(feats_sets_models.items()):
    row_id, col_id = divmod(fs_id, n_cols)

    fs_color = colors_feats_sets[fs]

    # Only features typed 'decimal' or 'integer' in feats.xlsx are shown.
    fs_feats = pd.read_excel(f"{path}/{fs}/feats.xlsx", index_col=0)
    feats_cnt = fs_feats.index[fs_feats['data_type'].isin(['decimal', 'integer'])].to_list()

    # NOTE(review): a model_id that already contains 'DANet/' resolves to
    # ".../models/DANet/DANet/..." here — confirm the intended layout.
    df_results = pd.read_excel(f"{path}/{fs}/models/DANet/{model_id}/df.xlsx", index_col=0)
    df_metrics = pd.read_excel(f"{path}/{fs}/models/DANet/{model_id}/metrics.xlsx", index_col=0)
    df_explain = pd.read_excel(f"{path}/{fs}/models/DANet/{model_id}/explanation.xlsx", index_col=0)

    # Mean |SHAP| per feature, the 30 largest first. Building the frame in one
    # chain avoids adding a column to a head() slice afterwards.
    df_fi = (
        df_explain[feats_cnt]
        .abs()
        .mean()
        .sort_values(ascending=False)
        .head(30)
        .to_frame('mean(|SHAP|)')
    )
    df_fi['Features'] = df_fi.index.values

    barplot = sns.barplot(
        data=df_fi,
        x='mean(|SHAP|)',
        y='Features',
        color=fs_color,
        edgecolor='black',
        dodge=False,
        ax=axs[col_id]
    )
    for container in barplot.containers:
        barplot.bar_label(container, label_type='edge', fmt='%0.2f', fontsize=12, padding=4.0)
    # Fall back to the raw key when it has no entry in the rename mapping
    # (e.g. when the feature-set keys are already display names).
    axs[col_id].set_title(feats_set_rename.get(fs, fs), fontsize='large')
    axs[col_id].set_ylabel('')

fig.tight_layout()
fig.savefig(f"{path}/barplot_importance_for_feats.png", bbox_inches='tight', dpi=200)
fig.savefig(f"{path}/barplot_importance_for_feats.pdf", bbox_inches='tight')
plt.close(fig)

# Feature correlations

In [None]:
# English feature-set name -> Russian display name.
# The keys are used as directory names by the loop below
# ("{path}/{fs}/feats.xlsx"); other cells of this notebook look the mapping
# up to obtain plot titles and table group headers.
feats_set_rename = {
    'Anthropometry': 'Антропометрия',
    'Complete Blood Count': 'Общий анализ крови',
    'Blood Biochemical': 'Биохимический анализ крови',
    'Electrocardiography': 'Электрокардиография',
    'Echocardiography': 'Эхокардиография',
    'Sphygmocardiography': 'Сфигмография',
    'Arterial Stiffness': 'Cосудистая жесткость',
    'Best Correlation': 'Лучшие из разных групп'
}


for fs in feats_set_rename:
    fs_feats = pd.read_excel(f"{path}/{fs}/feats.xlsx", index_col=0)
    fs_df = pd.read_excel(f"{path}/{fs}/data.xlsx", index_col=0)
    
    feats_cnt = ['Age'] + fs_feats.index[fs_feats['data_type'].isin(['decimal', 'integer'])].to_list()
    df_corr = pd.DataFrame(data=np.zeros(shape=(len(feats_cnt), len(feats_cnt))), index=feats_cnt, columns=feats_cnt)
    for f_id_1 in range(len(feats_cnt)):
        for f_id_2 in range(f_id_1, len(feats_cnt)):
            f_1 = feats_cnt[f_id_1]
            f_2 = feats_cnt[f_id_2]
            if f_id_1 != f_id_2:
                vals_1 = fs_df.loc[:, f_1].values
                vals_2 = fs_df.loc[:, f_2].values
                corr, pval = stats.pearsonr(vals_1, vals_2)
                df_corr.at[f_2, f_1] = pval
                df_corr.at[f_1, f_2] = corr
            else:
                df_corr.at[f_2, f_1] = np.nan
    selection = np.tri(df_corr.shape[0], df_corr.shape[1], -1, dtype=bool)
    df_fdr = df_corr.where(selection).stack().reset_index()
    df_fdr.columns = ['row', 'col', 'pval']
    _, df_fdr['pval_fdr_bh'], _, _ = multipletests(df_fdr.loc[:, 'pval'].values, 0.05, method='fdr_bh')
    nzmin = df_fdr['pval_fdr_bh'][df_fdr['pval_fdr_bh'].gt(0)].min(0) * 0.5
    df_fdr['pval_fdr_bh'].replace({0.0: nzmin}, inplace=True)
    df_corr_fdr = df_corr.copy()
    for line_id in range(df_fdr.shape[0]):
        df_corr_fdr.loc[df_fdr.at[line_id, 'row'], df_fdr.at[line_id, 'col']] = -np.log10(df_fdr.at[line_id, 'pval_fdr_bh'])
    df_corr_fdr.to_excel(f"{path}/{fs}/feats_pearsonr.xlsx")
        
    sns.set_theme(style='ticks')
    # fig, ax = plt.subplots(figsize=(4.5 + 0.25 * len(feats_cnt), 2.5 + 0.25 * len(feats_cnt)))
    fig, ax = plt.subplots(figsize=(4.5 + 0.25 * len(feats_cnt), 2.5 + 0.25 * len(feats_cnt)))
    cmap_triu = plt.get_cmap("seismic").copy()
    mask_triu=np.tri(len(feats_cnt), len(feats_cnt), -1, dtype=bool)
    heatmap_diff = sns.heatmap(
        df_corr_fdr,
        mask=mask_triu,
        annot=True,
        fmt=".2f",
        center=0.0,
        cmap=cmap_triu,
        linewidth=0.1,
        linecolor='black',
        annot_kws={"fontsize": 25 / np.sqrt(len(df_corr_fdr.values))},
        ax=ax
    )
    ax.figure.axes[-1].set_ylabel(r"Pearson $\rho$", size=13)
    for spine in ax.figure.axes[-1].spines.values():
        spine.set(visible=True, lw=0.25, edgecolor="black")
        
    cmap_tril = plt.get_cmap("viridis").copy()
    cmap_tril.set_under('black')
    mask_tril=np.tri(len(feats_cnt), len(feats_cnt), -1, dtype=bool).T
    heatmap_pval = sns.heatmap(
        df_corr_fdr,
        mask=mask_tril,
        annot=True,
        fmt=".1f",
        vmin=-np.log10(0.05),
        cmap=cmap_tril,
        linewidth=0.1,
        linecolor='black',
        annot_kws={"fontsize": 25 / np.sqrt(len(df_corr_fdr.values))},
        ax=ax
    )
    ax.figure.axes[-1].set_ylabel(r"$-\log_{10}(\mathrm{p-value})$", size=13)
    for spine in ax.figure.axes[-1].spines.values():
        spine.set(visible=True, lw=0.25, edgecolor="black")
    ax.set_xlabel('', fontsize=16)
    ax.set_ylabel('', fontsize=16)
    ax.set_title(feats_set_rename[fs])
    plt.savefig(f"{path}/{fs}/feats_pearsonr.png", bbox_inches='tight', dpi=200)
    plt.savefig(f"{path}/{fs}/feats_pearsonr.pdf", bbox_inches='tight')
    plt.close(fig)

# Models with selected features

In [None]:
# Root directory of the exported results and the sub-directory with the model
# generation to plot. (Plain string: there is nothing to interpolate.)
path = "E:/YandexDisk/Work/bbd/mriya"
models_type = 'models_v2'

# Feature set (directory name) -> id of the DANet model directory to plot.
# Iteration order of this dict defines the processing order below.
feats_sets_models = {
    # 'Best Correlation': '45',
    # 'Arterial Stiffness': '86',
    # 'MaxSamplesMaxFeatures': '1',
    
    # 'Echocardiography': 102,
    # 'Complete Blood Count': 2,
    # 'Blood Biochemical': 6,
    # 'Sphygmocardiography': 12,
    # 'Anthropometry': '21',
    # 'Electrocardiography': '59',
    
    # 'Антропометрия': '4',
    'Биохимический анализ крови': '14',
    'Общий анализ крови': '0',
    'Сфигмография': '66',
    'Электрокардиография': '0',
    'Эхокардиография': '21'
}

# Feature set -> title shown on the figures. Must contain every key of
# `feats_sets_models`.
feats_set_rename = {
    # 'Arterial Stiffness': 'Cосудистая жесткость',
    # 'Best Correlation': 'Лучшие из разных групп',
    # 'MaxSamplesMaxFeatures': 'MaxSamplesMaxFeatures',
    
    # 'Echocardiography': 'Эхокардиография',
    # 'Complete Blood Count': 'Общий анализ крови',
    # 'Blood Biochemical': 'Биохимический анализ крови',
    # 'Sphygmocardiography': 'Сфигмография',
    # 'Anthropometry': 'Антропометрия',
    # 'Electrocardiography': 'Электрокардиография',
    
    # 'Антропометрия': 'Антропометрия',
    'Биохимический анализ крови': 'Биохимический анализ крови',
    'Общий анализ крови': 'Общий анализ крови',
    'Сфигмография': 'Сфигмография',
    'Электрокардиография': 'Электрокардиография',
    'Эхокардиография': 'Эхокардиография'
    
}

# Feature set -> matplotlib colour name used for all of its plots. Must contain
# every key of `feats_sets_models`.
colors_feats_sets = {
    # 'Best Correlation': 'crimson',
    # 'Arterial Stiffness': 'dodgerblue',
    # 'MaxSamplesMaxFeatures': 'crimson',
    
    # 'Echocardiography': 'crimson',
    # 'Complete Blood Count': 'crimson',
    # 'Blood Biochemical': 'crimson',
    # 'Sphygmocardiography': 'crimson',
    # 'Anthropometry': 'crimson',
    # 'Electrocardiography': 'crimson',
    
    # 'Антропометрия': 'chartreuse',
    'Эхокардиография': 'darkcyan',
    'Сфигмография': 'mediumorchid',
    'Биохимический анализ крови': 'goldenrod',
    'Электрокардиография': 'dodgerblue',
    'Общий анализ крови': 'crimson',
}

# For every configured feature set, load the tables saved for its DANet model
# and write two figures into the feature-set directory:
#   * model.png/.pdf                - metrics table, age histogram and
#                                     Age-vs-prediction plot;
#   * importance_for_feats.png/.pdf - mean(|SHAP|) per continuous feature.
for fs_id, (fs, model_id) in enumerate(feats_sets_models.items()):
    
    # NOTE(review): fs_data is only referenced by the commented-out 'РИСК'
    # code in this loop; the read is otherwise unused here.
    fs_data = pd.read_excel(f"{path}/{models_type}/{fs}/data.xlsx", index_col=0)
    fs_feats = pd.read_excel(f"{path}/{models_type}/{fs}/feats.xlsx", index_col=0)
    fs_results = pd.read_excel(f"{path}/{models_type}/{fs}/models/DANet/{model_id}/df.xlsx", index_col=0)
    # fs_results.loc[fs_results.index, 'РИСК'] = fs_data.loc[fs_results.index, 'РИСК']
    fs_metrics = pd.read_excel(f"{path}/{models_type}/{fs}/models/DANet/{model_id}/metrics.xlsx", index_col=0)
    df_explain = pd.read_excel(f"{path}/{models_type}/{fs}/models/DANet/{model_id}/explanation.xlsx", index_col=0)
    
    # Data-driven axis limits (disabled in favour of the fixed range below).
    # xy_min = fs_results[['Age', 'Prediction Unbiased']].min().min()
    # xy_max = fs_results[['Age', 'Prediction Unbiased']].max().max()
    # xy_ptp = xy_max - xy_min
    
    # Fixed axis range, identical for every feature set.
    xy_min = 5
    xy_max = 100
    xy_ptp = xy_max - xy_min
    
    # Features whose data_type is 'decimal' or 'integer'.
    feats_cnt = fs_feats.index[fs_feats['data_type'].isin(['decimal', 'integer'])].to_list()
    
    # Single-column figure with four stacked rows: table, histogram,
    # KDE/scatter, and a last row that is switched off further down.
    n_rows = 4
    n_cols = 1
    fig_height = 6
    fig_width = 4
    sns.set_theme(style='ticks')
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height), height_ratios=[0.2, 0.2, 0.8, 0.15], gridspec_kw={'wspace':0.25, 'hspace': 0.05}, sharey=False, sharex=False)

    row_id_table = 0
    row_id_hist = 1
    row_id_scatter = 2
    row_id_empty = 3

    # Metrics table: the '*_unbiased' metrics per data split, formatted as
    # strings with three decimals.
    df_table = pd.DataFrame(index=['MAE', fr"Pearson $\mathbf{{\rho}}$", "Bias"], columns=['Train', 'Validation', 'Test'])
    for part in ['Train', 'Validation', 'Test']:
        df_table.at['MAE', part] = f"{fs_metrics.at[part, 'mean_absolute_error_unbiased']:0.3f}"
        df_table.at[fr"Pearson $\mathbf{{\rho}}$", part] = f"{fs_metrics.at[part, 'pearson_corrcoef_unbiased']:0.3f}"
        df_table.at["Bias", part] = f"{fs_metrics.at[part, 'bias_unbiased']:0.3f}"

    # All columns share one group whose label is the feature-set title.
    col_defs = [
        ColumnDefinition(
            name="index",
            title='',
            textprops={"ha": "center", "weight": "bold"},
            width=2.5,
            group=feats_set_rename[fs],
        ),
        ColumnDefinition(
            name="Train",
            textprops={"ha": "left"},
            width=1.5,
            border="left",
            group=feats_set_rename[fs],
        ),
        ColumnDefinition(
            name="Validation",
            textprops={"ha": "left"},
            width=2.2,
            group=feats_set_rename[fs],
        ),
        ColumnDefinition(
            name="Test",
            textprops={"ha": "left"},
            width=1.5,
            group=feats_set_rename[fs],
        )
    ]

    table = Table(
        df_table,
        column_definitions=col_defs,
        row_dividers=True,
        footer_divider=False,
        ax=axs[row_id_table],
        textprops={"fontsize": 8},
        row_divider_kw={"linewidth": 1, "linestyle": (0, (1, 1))},
        col_label_divider_kw={"linewidth": 1, "linestyle": "-"},
        column_border_kw={"linewidth": 1, "linestyle": "-"},
    ).autoset_fontcolors(colnames=['Train', 'Validation', 'Test'])

    # Age histogram over all rows of fs_results; 25 bin edges on [0, 120]
    # give 5-year bins.
    hist_bins = np.linspace(0, 120, 25)
    histplot = sns.histplot(
        data=fs_results,
        bins=hist_bins,
        edgecolor='k',
        linewidth=1,
        x="Age",
        color=colors_feats_sets[fs],
        ax=axs[row_id_hist]
    )
    # Same x-limits as the plot below (5 % padding on both sides); the ticks
    # are hidden on this row.
    axs[row_id_hist].set_xticks([])
    axs[row_id_hist].set_xlim(xy_min-0.05*xy_ptp, xy_max+0.05*xy_ptp)
    axs[row_id_hist].set_ylabel("Count")

    # Train + Validation rows are drawn as a filled 2D density ...
    kdeplot = sns.kdeplot(
        data=fs_results.loc[fs_results['Group'].isin(['Train', 'Validation']), :],
        x='Age',
        y='Prediction Unbiased',
        fill=True,
        cbar=False,
        thresh=0.05,
        color=colors_feats_sets[fs],
        cut=0,
        legend=False,
        ax=axs[row_id_scatter]
    )
    # ... and Test rows as individual points on top of it.
    scatter = sns.scatterplot(
        data=fs_results.loc[fs_results['Group'] == 'Test', :],
        x='Age',
        y="Prediction Unbiased",
        linewidth=0.5,
        alpha=0.8,
        edgecolor="k",
        s=15,
        color=colors_feats_sets[fs],
        ax=axs[row_id_scatter],
    )
    # Identity line: prediction == age.
    axs[row_id_scatter].axline((0, 0), slope=1, color="black", linestyle=":")
    axs[row_id_scatter].set_xlim(xy_min-0.05*xy_ptp, xy_max+0.05*xy_ptp)
    axs[row_id_scatter].set_ylim(xy_min-0.05*xy_ptp, xy_max+0.05*xy_ptp)
    axs[row_id_scatter].set_ylabel("Prediction")
    axs[row_id_scatter].set_xlabel("Age")
    axs[row_id_empty].axis('off')
    fig.tight_layout()
    fig.savefig(f"{path}/{models_type}/{fs}/model.png", bbox_inches='tight', dpi=400)
    fig.savefig(f"{path}/{models_type}/{fs}/model.pdf", bbox_inches='tight')
    plt.close(fig)
    
    # Feature importance: mean absolute value of each feature's column in the
    # explanation table, sorted in descending order.
    df_fi = pd.DataFrame(index=feats_cnt, columns=['mean(|SHAP|)'])
    for f in feats_cnt:
        df_fi.at[f, 'mean(|SHAP|)'] = df_explain[f].abs().mean()
    df_fi.sort_values(['mean(|SHAP|)'], ascending=[False], inplace=True)
    
    # Figure height scales with the number of features (one bar per feature).
    fig, ax = plt.subplots(figsize=(8, 1.5 + 0.15 * df_fi.shape[0]))
    df_fi['Features'] = df_fi.index.values
    barplot = sns.barplot(
        data=df_fi,
        x='mean(|SHAP|)',
        y='Features',
        color=colors_feats_sets[fs],
        edgecolor='black',
        dodge=False,
        ax=ax
    )
    # Print the value at the end of every bar.
    for container in barplot.containers:
        barplot.bar_label(container, label_type='edge', fmt='%0.2f', fontsize=12, padding=4.0)
    ax.set_title(feats_set_rename[fs], fontsize='large')
    ax.set_ylabel('')
    fig.tight_layout()    
    fig.savefig(f"{path}/{models_type}/{fs}/importance_for_feats.png", bbox_inches='tight', dpi=300)
    fig.savefig(f"{path}/{models_type}/{fs}/importance_for_feats.pdf", bbox_inches='tight')
    plt.close(fig)
    
    # Disabled: violin plot of 'Error Unbiased' by the 'РИСК' column with a
    # Mann-Whitney test (requires the 'РИСК' merge commented out above).
    # Violin plot for risk
    # _, pval = mannwhitneyu(
    #     fs_results.loc[fs_results['РИСК'] == 'Высокий', 'Error Unbiased'].values,
    #     fs_results.loc[fs_results['РИСК'] == 'Низкий', 'Error Unbiased'].values,
    #     alternative="two-sided",
    # )
    # sns.set_theme(style='ticks')
    # fig, ax = plt.subplots(figsize=(4, 5), layout='constrained')
    # sns.violinplot(
    #     data=fs_results.loc[fs_results['РИСК'].isin(['Низкий', 'Высокий']), :],
    #     x='РИСК',
    #     y='Error Unbiased',
    #     palette={'Низкий': 'dodgerblue', 'Высокий': 'crimson'},
    #     scale='width',
    #     order=['Низкий', 'Высокий'],
    #     saturation=0.75,
    #     ax=ax,
    #     legend=False,
    #     cut=0,
    # )
    # ax.set_ylabel('Error')
    # title = f'Mann-Whitney: {pval:.2e}'
    # ax.set_title(title)
    # fig.savefig(f"{path}/{models_type}/{fs}/РИСК.png", bbox_inches='tight', dpi=200)
    # fig.savefig(f"{path}/{models_type}/{fs}/РИСК.pdf", bbox_inches='tight')
    # plt.close(fig)

# Local explainability

In [None]:
# Inputs of the local-explainability analysis: results root, target column,
# feature set and the id of the DANet model directory to explain.
path = "D:/YandexDisk/Work/bbd/mriya"
feat_trgt = 'Age'
fs = 'Best Correlation'
model_id = 45

# Tables exported for the feature set and for the selected model.
data = pd.read_excel(f"{path}/{fs}/data.xlsx", index_col=0)
feats = pd.read_excel(f"{path}/{fs}/feats.xlsx", index_col=0)
results = pd.read_excel(f"{path}/{fs}/models/DANet/{model_id}/df.xlsx", index_col=0)
metrics = pd.read_excel(f"{path}/{fs}/models/DANet/{model_id}/metrics.xlsx", index_col=0)
explain = pd.read_excel(f"{path}/{fs}/models/DANet/{model_id}/explanation.xlsx", index_col=0)
model = TabularModel.load_model(f"{path}/{fs}/models/DANet/{model_id}")

# Re-fit the bias corrector on the Train rows (target vs. raw prediction) so
# that new raw predictions can be corrected inside predict_func.
corrector = LinearBiasCorrector()
corrector.fit(results.loc[results['Group'] == 'Train', feat_trgt].values, results.loc[results['Group'] == 'Train', 'Prediction'].values)

# Attach the saved predictions/errors to the feature table (aligned on index).
res_cols = ['Group', 'Prediction', 'Error', 'Prediction Unbiased', 'Error Unbiased']
data.loc[data.index, res_cols] = results.loc[data.index, res_cols]

# Continuous features, without 'Возраст'. The list is filtered element-wise:
# comparing the list itself with a string yields a single bool, which would
# index one element and split that feature name into characters.
feats_cnt = feats.index[feats['data_type'].isin(['decimal', 'integer'])].to_list()
feats_cnt = [feat for feat in feats_cnt if feat != 'Возраст']
feats_cat = feats.index[feats['data_type'].isin(['enum'])].to_list()
feats_all = feats.index.values

# SHAP works on a purely numeric matrix: label-encode every categorical
# feature and keep the encoders so predict_func can decode them again.
data_shap = data.copy()
cat_encoders = {}
for f in feats_cat:
    cat_encoders[f] = LabelEncoder()
    data_shap[f] = cat_encoders[f].fit_transform(data_shap[f])
    
def predict_func(X):
    """Prediction callback handed to the SHAP explainer.

    Wraps the numeric matrix ``X`` (columns ordered as ``feats_all``) in a
    DataFrame, turns the label-encoded categorical columns back into their
    original labels, runs the loaded model and returns the bias-corrected
    predictions of the target.
    """
    frame = pd.DataFrame(data=X, columns=feats_all)
    for cat_feat in feats_cat:
        # The explainer passes the codes as numbers; decode them to labels.
        codes = frame[cat_feat].astype(int).values
        frame[cat_feat] = cat_encoders[cat_feat].inverse_transform(codes)
    raw_pred = model.predict(frame)[f'{feat_trgt}_prediction'].values
    return corrector.predict(raw_pred)

In [None]:
# Sample to explain.
trgt_id = 'I723'

# Target value, bias-corrected prediction and their difference for the sample.
trgt_age = data_shap.at[trgt_id, feat_trgt]
trgt_pred = data_shap.at[trgt_id, 'Prediction Unbiased']
trgt_aa = trgt_pred - trgt_age
print(trgt_age)
print(trgt_pred)

# Background set: the n_closest samples whose bias-corrected prediction is
# nearest to the target value of the explained sample.
n_closest = 16
data_closest = data_shap.iloc[(data_shap['Prediction Unbiased'] - trgt_age).abs().argsort()[:n_closest]]

explainer = shap.SamplingExplainer(predict_func, data_closest.loc[:, feats_all].values)
print(explainer.expected_value)
shap_values = explainer.shap_values(data_shap.loc[[trgt_id], feats_all].values)[0]
# Rescale the contributions from the (prediction - expected value) scale to
# the (prediction - target value) scale, i.e. to trgt_aa.
shap_values = shap_values * trgt_aa / (trgt_pred - explainer.expected_value)

# One row per feature, ordered by increasing |contribution|; 'cumsum' is the
# running total used later for the x-axis range of the waterfall.
df_shap = pd.DataFrame(index=feats_all, data=shap_values, columns=[trgt_id])
df_shap.sort_values(by=trgt_id, key=abs, inplace=True)
df_shap['cumsum'] = df_shap[trgt_id].cumsum()

# For every feature describe where the explained sample sits within the
# background set:
#   * non-enum features -> percentile of its value ('Меньше') and the
#     complement to 100 ('Больше');
#   * enum features     -> percentage share of each category, stored in
#     df_cat_part under the feature's row position in df_shap.
df_less_more = pd.DataFrame(index=df_shap.index, columns=['Меньше', 'Больше'])
df_cat_part = {}
for f_id, f in enumerate(df_less_more.index):
    if feats.at[f, 'data_type'] != 'enum':
        df_less_more.at[f, 'Меньше'] = round(scipy.stats.percentileofscore(data_closest.loc[:, f].values, data_shap.at[trgt_id, f]))
        df_less_more.at[f, 'Больше'] = 100.0 - df_less_more.at[f, 'Меньше']
    else:
        df_less_more.at[f, 'Меньше'] = np.nan
        df_less_more.at[f, 'Больше'] = np.nan

        # Use the current feature's own column and encoder (not a fixed
        # 'Пол'), so every categorical feature gets its own distribution.
        f_value_counts = data_closest.loc[:, f].value_counts()
        f_value_counts_rename = {x: cat_encoders[f].inverse_transform([x])[0] for x in f_value_counts.index.astype(int).values}
        f_value_counts.rename(index=f_value_counts_rename, inplace=True)
        f_value_counts = np.rint(f_value_counts / f_value_counts.sum() * 100)

        df_cat_part[f_id] = {
            'name': f,
            'distribution': f_value_counts.astype(int)
        }
        # Only 'Пол' has a dedicated colour per category.
        if f == 'Пол':
            df_cat_part[f_id]['palette'] = {'жен': 'crimson', 'муж': 'dodgerblue'}
            
# Two-panel figure: left - waterfall of the per-feature contributions,
# right - where the sample's feature values sit within the background set.
fig = make_subplots(rows=1, cols=2, shared_yaxes=True, shared_xaxes=False, column_widths=[2.5, 1], horizontal_spacing=0.05, subplot_titles=['', "Распределение признаков у людей<br>в данном возрастном диапазоне"])

# Summary bars: target value (placed below the features), then the total
# difference and the resulting prediction (placed above the features).
fig.add_trace(
    go.Waterfall(
        hovertext=["Хронологический возраст", "Возрастная акселерация", "Биологический возраст"],
        orientation="h",
        measure=['absolute', 'relative', 'absolute'],
        y=[-1.5, df_shap.shape[0] + 0.5, df_shap.shape[0] + 1.5],
        x=[trgt_age, trgt_aa, trgt_age+trgt_aa],
        base=0,
        text=[f"{trgt_age:0.2f}", f"+{trgt_aa:0.2f}" if trgt_aa > 0 else f"{trgt_aa:0.2f}", f"{trgt_age+trgt_aa:0.2f}"],
        textposition = "auto",
        decreasing = {"marker":{"color": "deepskyblue", "line": {"color": "black", "width": 1}}},
        increasing = {"marker":{"color": "crimson", "line": {"color": "black", "width": 1}}},
        totals= {"marker":{"color": "dimgray", "line": {"color": "black", "width": 1}}},
        connector={
            "mode": "between",
            "line": {"width": 1, "color": "black", "dash": "dot"},
        },
    ),
    row=1,
    col=1,
)
# Per-feature contributions, stacked one after another starting at trgt_age.
fig.add_trace(
    go.Waterfall(
        hovertext=df_shap.index.values,
        orientation="h",
        measure=["relative"] * len(feats_all),
        y=list(range(df_shap.shape[0])),
        x=df_shap[trgt_id].values,
        base=trgt_age,
        text=[f"+{x:0.2f}" if x > 0 else f"{x:0.2f}" for x in df_shap[trgt_id].values],
        textposition = "auto",
        decreasing = {"marker":{"color": "lightblue", "line": {"color": "black", "width": 1}}},
        increasing = {"marker":{"color": "lightcoral", "line": {"color": "black", "width": 1}}},
        connector={
            "mode": "between",
            "line": {"width": 1, "color": "black", "dash": "solid"},
        },
    ),
    row=1,
    col=1,
)
fig.update_traces(row=1, col=1, showlegend=False)
# Tick labels: "<description> = <value>" for non-enum features,
# "<name> = <label>" for enum features, framed by the three summary rows.
fig.update_yaxes(
    row=1,
    col=1,
    automargin=True,
    tickmode="array",
    tickvals=[-1.5] + list(range(df_shap.shape[0])) + [df_shap.shape[0] + 0.5, df_shap.shape[0] + 1.5],
    ticktext=(
        ["Хронологический возраст"]
        + [
            f"{feats.at[x, 'description']} = {data.at[trgt_id, x]:0.2f}"
            if feats.at[x, 'data_type'] != 'enum'
            else f"{x} = {data.at[trgt_id, x]}"
            for x in df_shap.index
        ]
        + ["Возрастная акселерация", "Биологический возраст"]
    ),
    tickfont=dict(size=18),
)
# Axis title and its font are passed as title.text / title.font; the flat
# `titlefont` attribute is deprecated in plotly.
fig.update_xaxes(
    row=1,
    col=1,
    automargin=True,
    title=dict(text='Возраст', font=dict(size=25)),
    range=[
        trgt_age + df_shap['cumsum'].min() * 1.2 - 2,
        trgt_age + df_shap['cumsum'].max() * 1.2 + 2
    ],
)

# Right panel, non-enum features: stacked 'Меньше' / 'Больше' percentages.
fig.add_trace(
    go.Bar(
        hovertext=df_shap.index.values,
        orientation="h",
        name='Меньше',
        x=df_less_more.loc[df_shap.index.values, 'Меньше'],
        y=list(range(df_shap.shape[0])),
        marker=dict(color='steelblue', line=dict(color="black", width=1)),
        text=df_less_more.loc[df_shap.index.values, 'Меньше'],
        textposition='auto'
    ),
    row=1,
    col=2,
)
fig.add_trace(
    go.Bar(
        hovertext=df_shap.index.values,
        orientation="h",
        name='Больше',
        x=df_less_more.loc[df_shap.index.values, 'Больше'],
        y=list(range(df_shap.shape[0])),
        marker=dict(color='violet', line=dict(color="black", width=1)),
        text=df_less_more.loc[df_shap.index.values, 'Больше'],
        textposition='auto',
    ),
    row=1,
    col=2
)

# Right panel, enum features: one stacked segment per category share.
for f_cat_id, f_cat_dict in df_cat_part.items():
    # Features without a dedicated palette fall back to plotly's default
    # colours (color=None) instead of failing on the missing 'palette' key.
    f_cat_palette = f_cat_dict.get('palette', {})
    for f_val in f_cat_dict['distribution'].index:
        fig.add_trace(
            go.Bar(
                hovertext=[f_cat_dict['name']],
                orientation="h",
                name=f_val,
                x=[f_cat_dict['distribution'][f_val]],
                y=[f_cat_id],
                marker=dict(color=f_cat_palette.get(f_val), line=dict(color="black", width=1)),
                text=[f_val],
                textposition='auto',
                showlegend=False
            ),
            row=1,
            col=2
        )

# The right panel shows no axes decoration at all.
fig.update_xaxes(
    row=1,
    col=2,
    automargin=True,
    showgrid=False,
    showline=False,
    zeroline=False,
    showticklabels=False,
)
fig.update_yaxes(
    row=1,
    col=2,
    automargin=True,
    showgrid=False,
    showline=False,
    zeroline=False,
    showticklabels=False,
)
fig.update_layout(barmode="stack")
fig.update_layout(
    legend=dict(
        title=dict(side="top"),
        orientation="h",
        yanchor="bottom",
        y=0.97,
        xanchor="center",
        x=0.86
    ),
)
fig.update_layout(
    title=dict(text=f"Возрастная акселерация для {trgt_id}", font=dict(size=25)),
    template="none",
    width=1800,
    height=1100,
    margin=go.layout.Margin(l=120, r=80, b=50, t=50, pad=0),
)
fig.show()
# Export the figure and the underlying contributions next to each other.
fig.write_image(f"{path}/{fs}/shap_local/{trgt_id}.pdf", format="pdf")
fig.write_image(f"{path}/{fs}/shap_local/{trgt_id}.png", scale=2)
df_shap.to_excel(f"{path}/{fs}/shap_local/{trgt_id}.xlsx")