# Debugging autoreload

In [None]:
%load_ext autoreload
%autoreload 2

# Load packages

In [None]:
from pytorch_tabular.utils import load_covertype_dataset
from rich.pretty import pprint
from sklearn.model_selection import BaseCrossValidator, ParameterGrid, ParameterSampler
import torch
import pickle
import shutil
import shap
from sklearn.model_selection import RepeatedStratifiedKFold
from glob import glob
import ast
import matplotlib.pyplot as plt
import seaborn as sns
import copy
from sklearn.model_selection import train_test_split
import numpy as np
from pytorch_tabular.utils import make_mixed_dataset, print_metrics
from pytorch_tabular import available_models
from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig, GANDALFConfig, TabNetModelConfig, FTTransformerConfig, DANetConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig
from pytorch_tabular.tabular_model_tuner import TabularModelTuner
from torchmetrics.functional.regression import mean_absolute_error, pearson_corrcoef
from pytorch_tabular import MODEL_SWEEP_PRESETS
import pandas as pd
from pytorch_tabular import model_sweep
from src.pt.model_sweep import model_sweep_custom
import warnings
from src.utils.configs import read_parse_config
from src.utils.hash import dict_hash
from src.pt.hyper_opt import train_hyper_opt
import pathlib
from tqdm import tqdm
import distinctipy
import matplotlib.patheffects as pe
import matplotlib.colors as mcolors
from statannotations.Annotator import Annotator
from scipy.stats import mannwhitneyu
from regression_bias_corrector import LinearBiasCorrector
import optuna
from sklearn.preprocessing import LabelEncoder
from plottable import ColumnDefinition, Table
from plottable.plots import bar
from plottable.cmap import normed_cmap, centered_cmap
import matplotlib.lines as mlines
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import scipy.stats


def make_rgb_transparent(rgb, bg_rgb, alpha):
    """Blend a foreground colour with a background colour, channel by channel.

    Every output channel is ``alpha * fg + (1 - alpha) * bg``.

    Args:
        rgb: Foreground channel values (any iterable of numbers).
        bg_rgb: Background channel values, paired positionally with ``rgb``.
        alpha: Weight given to the foreground channels.

    Returns:
        A list holding one blended value per channel pair. The inputs are
        zipped, so surplus channels of the longer input are ignored.
    """
    blended = []
    for fg_channel, bg_channel in zip(rgb, bg_rgb):
        blended.append(alpha * fg_channel + (1 - alpha) * bg_channel)
    return blended

# Best models processing

In [None]:
# --- Run configuration -------------------------------------------------------
path = "E:/YandexDisk/Work/bbd/millennium/models/Электрокардиограмма (чекап)"
path_model = f"{path}/models/DANet/424"
dataset = 'Электрокардиограмма'  # used as the figure title
expl_type = 'current'            # 'current' keeps the SHAP table loaded from disk
color = 'olive'                  # base colour shared by all panels

feat_trgt = 'Возраст'            # name of the target column

# --- Artefacts stored next to the dataset / the trained model ----------------
data = pd.read_excel(f"{path}/data.xlsx", index_col=0)
feats = pd.read_excel(f"{path}/feats.xlsx", index_col=0)
results = pd.read_excel(f"{path_model}/df.xlsx", index_col=0)
metrics = pd.read_excel(f"{path_model}/metrics.xlsx", index_col=0)
df_shap = pd.read_excel(f"{path_model}/explanation.xlsx", index_col=0)
model = TabularModel.load_model(f"{path_model}")

# Fit the bias corrector on the Train rows only: target values first, raw
# predictions second.
# NOTE(review): argument semantics of LinearBiasCorrector.fit are not visible
# here — confirm the (target, prediction) order against the package.
train_results = results.loc[results['Group'] == 'Train']
corrector = LinearBiasCorrector()
corrector.fit(train_results[feat_trgt].values, train_results['Prediction'].values)

# Shared axis range: 1st and 99th percentiles of the pooled target and
# unbiased-prediction values, plus the span between them.
target_and_prediction = results[[feat_trgt, 'Prediction Unbiased']].values.flatten()
xy_min, xy_max = np.quantile(target_and_prediction, [0.01, 0.99])
xy_ptp = xy_max - xy_min


sns.set_theme(style='ticks')

n_feats = feats.shape[0]
# Each of the two per-feature panels gets a height that grows with the
# number of features.
panel_height = 1.5 + 0.15 * n_feats

# The total height is spelled out term by term so the floating-point sum is
# evaluated in the same order as before.
fig = plt.figure(
    figsize=(8, 5 + 1.5 + 0.15 * n_feats + 1.5 + 0.15 * n_feats),
    layout="constrained",
)

# Three stacked sub-figures: summary (table + scatter + violin) and two
# per-feature panels of equal height.
subfigs = fig.subfigures(
    nrows=3,
    ncols=1,
    height_ratios=[5, panel_height, panel_height],
    wspace=0.001,
    hspace=0.001,
)

# Top sub-figure: the table spans the full width above the scatter and
# violin panels.
summary_mosaic = [
    ['table', 'table'],
    ['scatter', 'violin'],
]
axs = subfigs[0].subplot_mosaic(
    summary_mosaic,
    height_ratios=[1, 4],
    width_ratios=[3, 1.5],
    gridspec_kw={"wspace": 0.01, "hspace": 0.01},
)

# Row label shown in the table -> column of `metrics` that supplies the value.
pearson_label = fr"Pearson $\mathbf{{\rho}}$"
metric_rows = {
    'MAE': 'mean_absolute_error_unbiased',
    pearson_label: 'pearson_corrcoef_unbiased',
    "Bias": 'bias_unbiased',
}
parts = ['Train', 'Validation', 'Test']

# Cells are stored as strings already formatted to three decimals.
df_table = pd.DataFrame(index=list(metric_rows), columns=parts)
for part in parts:
    for row_label, metric_col in metric_rows.items():
        df_table.at[row_label, part] = f"{metrics.at[part, metric_col]:0.3f}"

# Column layout: a bold, centred label column followed by one left-aligned
# value column per data split; only the first value column draws a left border.
col_defs = [
    ColumnDefinition(
        name="index",
        title='',
        textprops={"ha": "center", "weight": "bold"},
        width=2.5,
    ),
]
for col_name in parts:
    border_kwargs = {"border": "left"} if col_name == 'Train' else {}
    col_defs.append(
        ColumnDefinition(
            name=col_name,
            textprops={"ha": "left"},  # fresh dict per column, nothing shared
            width=1.5,
            **border_kwargs,
        )
    )

table = Table(
    df_table,
    column_definitions=col_defs,
    row_dividers=True,
    footer_divider=False,
    ax=axs['table'],
    textprops={"fontsize": 8},
    row_divider_kw={"linewidth": 1, "linestyle": (0, (1, 1))},
    col_label_divider_kw={"linewidth": 1, "linestyle": "-"},
    column_border_kw={"linewidth": 1, "linestyle": "-"},
).autoset_fontcolors(colnames=['Train', 'Validation', 'Test'])

# Row subsets and axis limits used by the scatter panel.
scatter_train_val = results.loc[results['Group'].isin(['Train', 'Validation']), :]
scatter_test = results.loc[results['Group'] == 'Test', :]
axis_lo = xy_min - 0.15 * xy_ptp
axis_hi = xy_max + 0.15 * xy_ptp
ax_scatter = axs['scatter']

# Filled density of Train + Validation rows.
kdeplot = sns.kdeplot(
    data=scatter_train_val,
    x=feat_trgt,
    y='Prediction Unbiased',
    fill=True,
    cbar=False,
    thresh=0.05,
    color=color,
    legend=False,
    ax=ax_scatter,
)

# Individual Test rows on top of the density.
scatter = sns.scatterplot(
    data=scatter_test,
    x=feat_trgt,
    y="Prediction Unbiased",
    linewidth=0.5,
    alpha=0.8,
    edgecolor="k",
    s=25,
    color=color,
    ax=ax_scatter,
)

# Dashed identity line (prediction == target).
bisect = sns.lineplot(
    x=[axis_lo, axis_hi],
    y=[axis_lo, axis_hi],
    linestyle='--',
    color='black',
    linewidth=1.0,
    ax=ax_scatter,
)

# Regression line fitted on every row of `results`, without redrawing points.
regplot = sns.regplot(
    data=results,
    x=feat_trgt,
    y='Prediction Unbiased',
    color='k',
    scatter=False,
    truncate=False,
    ax=ax_scatter,
)

# Same limits on both axes so the identity line is the diagonal.
ax_scatter.set_xlim(axis_lo, axis_hi)
ax_scatter.set_ylim(axis_lo, axis_hi)
ax_scatter.set_ylabel("Биологический возраст")
ax_scatter.set_xlabel("Возраст")

# Row subsets for the error-distribution panel.
violin_rows = results.loc[results['Group'].isin(['Train', 'Validation']), :]
swarm_rows = results.loc[results['Group'] == 'Test', :]
n_swarm_rows = swarm_rows.shape[0]
ax_violin = axs['violin']

# Violin of the unbiased errors for Train + Validation rows, filled with the
# base colour blended half-way towards white. A constant x places every row
# in a single category.
violin = sns.violinplot(
    data=violin_rows,
    x=[0] * violin_rows.shape[0],
    y='Error Unbiased',
    color=make_rgb_transparent(mcolors.to_rgb(color), (1, 1, 1), 0.5),
    density_norm='width',
    saturation=0.75,
    linewidth=1.0,
    ax=ax_violin,
    legend=False,
)

# Test rows as a swarm on the same axes; the marker size shrinks with the
# square root of the number of Test rows.
swarm = sns.swarmplot(
    data=swarm_rows,
    x=[0] * n_swarm_rows,
    y='Error Unbiased',
    color=color,
    linewidth=0.5,
    ax=ax_violin,
    size=50 / np.sqrt(n_swarm_rows),
    legend=False,
)

ax_violin.set_ylabel('Возрастная акселерация')
ax_violin.set_xlabel('')
# The single dummy category needs neither tick labels nor ticks.
ax_violin.set_xticklabels([])
ax_violin.set_xticks([])

ax_heatmap = subfigs[1].subplots()

# Pearson correlation of every feature with the target, computed on the rows
# where both values are present.
feat_names = feats.index.to_list()
df_corr = pd.DataFrame(index=feat_names, columns=['rho'])
for f in tqdm(feat_names):
    df_tmp = data.loc[:, [feat_trgt, f]].dropna(axis=0, how='any')
    if df_tmp.shape[0] <= 1:
        # Not enough paired observations; the row stays NaN and is dropped below.
        continue
    target_vals = df_tmp.loc[:, feat_trgt].values
    feature_vals = df_tmp.loc[:, f].values
    rho, _ = scipy.stats.pearsonr(target_vals, feature_vals)
    df_corr.at[f, 'rho'] = rho

# Keep only features with a computed correlation, ordered by |rho| descending.
df_corr.dropna(axis=0, how='any', inplace=True)
df_corr.insert(1, "abs(rho)", df_corr['rho'].abs())
df_corr.sort_values("abs(rho)", ascending=False, inplace=True)
feats_cnt_wo_age = df_corr.index.to_list()
feats_cnt = ['Возраст'] + feats_cnt_wo_age
# The frame was filled cell by cell; convert the columns to numeric dtypes
# before plotting.
df_corr = df_corr.apply(pd.to_numeric)

heatmap = sns.heatmap(
    df_corr.loc[:, ['rho']],
    annot=True,
    fmt=".2f",
    vmin=-1.0,
    vmax=1.0,
    cmap='coolwarm',
    linewidth=0.1,
    linecolor='black',
    cbar_kws={},
    ax=ax_heatmap,
)

# The colour bar is the last axes of the heatmap's figure: move it to the
# right of the heatmap, label it and outline it.
heatmap_pos = ax_heatmap.get_position()
cbar_ax = ax_heatmap.figure.axes[-1]
cbar_ax.set_position([heatmap_pos.x1 + 0.05, heatmap_pos.y0, 0.1, heatmap_pos.height])
cbar_ax.set_ylabel(r"Pearson $\rho$")
for spine in cbar_ax.spines.values():
    spine.set(visible=True, lw=0.25, edgecolor="black")

ax_heatmap.set_xlabel('')
ax_heatmap.set_ylabel('')
ax_heatmap.set_xticklabels([])
ax_heatmap.set_xticks([])


# Optionally recompute the SHAP table instead of using the one loaded from
# disk. Any other value of `expl_type` leaves `df_shap` untouched.
if expl_type == 'recalc_gradient':
    # Attribution through the model's own `explain` method.
    # NOTE(review): the meaning of baselines="b|100000" is defined by
    # pytorch_tabular and not visible here — confirm against its docs.
    df_shap = model.explain(data, method="GradientShap", baselines="b|100000")
    # Re-attach the sample index of the input frame to the returned table.
    df_shap.index = data.index
elif expl_type == 'recalc_sampling':
    # Work on a copy in which every feature column is replaced by the integer
    # codes of a per-feature LabelEncoder (presumably so the explainer gets a
    # purely numeric matrix — TODO confirm).
    ds_data_shap = data.copy()
    ds_cat_encoders = {}
    for f in feats.index:
        ds_cat_encoders[f] = LabelEncoder()
        ds_data_shap[f] = ds_cat_encoders[f].fit_transform(ds_data_shap[f])
    def predict_func(X):
        # Prediction callback handed to the explainer: `X` holds encoded
        # feature values, one column per entry of feats.index.
        X_df = pd.DataFrame(data=X, columns=feats.index.to_list())
        # Decode the integer codes back to the original feature values
        # before calling the model.
        for f in feats.index:
            X_df[f] = ds_cat_encoders[f].inverse_transform(X_df[f].astype(int).values)
        y = model.predict(X_df)[f'{feat_trgt}_prediction'].values
        # Pass the raw predictions through the fitted bias corrector so the
        # explained quantity is the corrected prediction.
        y = corrector.predict(y)
        return y
    # The encoded feature matrix is used both to build the explainer and as
    # the set of samples that get explained.
    explainer = shap.SamplingExplainer(predict_func, ds_data_shap.loc[:, feats.index.to_list()].values)
    print(explainer.expected_value)
    shap_values = explainer.shap_values(ds_data_shap.loc[:, feats.index.to_list()].values)
    # One row per sample (index of `data`), one column per feature.
    df_shap = pd.DataFrame(index=data.index, columns=feats.index.to_list(), data=shap_values)

# Global feature importance: mean absolute SHAP value of every feature listed
# in `feats`, most important first.
# Computed column-wise in one pass so that 'mean(|SHAP|)' is a float64 column;
# filling an empty frame cell by cell leaves it as object dtype.
ds_fi = (
    df_shap.loc[:, feats.index.to_list()]
    .abs()
    .mean()
    .to_frame('mean(|SHAP|)')
)
ds_fi.sort_values(['mean(|SHAP|)'], ascending=[False], inplace=True)
# Explicit name column used as the categorical axis of the bar plot.
ds_fi['Features'] = ds_fi.index.values


# Bottom sub-figure: importance bars on the left, per-sample SHAP values on
# the right; both share the feature (y) axis.
axs_importance = subfigs[2].subplots(
    1,
    2,
    width_ratios=[4, 8],
    gridspec_kw={'wspace': 0.02, 'hspace': 0.02},
    sharey=True,
    sharex=False,
)
ax_bars = axs_importance[0]

barplot = sns.barplot(
    data=ds_fi,
    x='mean(|SHAP|)',
    y='Features',
    color=color,
    edgecolor='black',
    dodge=False,
    ax=ax_bars,
)

# Print the numeric importance next to the end of every bar.
for container in barplot.containers:
    barplot.bar_label(
        container,
        label_type='edge',
        color='gray',
        fmt='%0.2f',
        fontsize=12,
        padding=4.0,
    )

ax_bars.set_ylabel('')
ax_bars.set_yticklabels(ds_fi.index.to_list())

# Per-feature strip plots of SHAP values, coloured by the feature value.
is_colorbar = False
f_legends = []

# Loop invariants: the colour map and the marker size do not depend on the
# feature being drawn.
f_cmap = sns.color_palette("Spectral_r", as_cmap=True)
# NOTE(review): the marker size is scaled by the number of Test rows although
# the strips are drawn from every row of `df_shap` — confirm this is intended.
strip_point_size = 25 / np.sqrt(results.loc[results['Group'] == 'Test', :].shape[0])

for f in ds_fi.index:

    # Features whose SHAP values exceed 10 in absolute value are clipped to
    # their 1%-99% quantile range; the others keep their full range.
    if df_shap[f].abs().max() > 10:
        f_shap_ll = df_shap[f].quantile(0.01)
        f_shap_hl = df_shap[f].quantile(0.99)
    else:
        f_shap_ll = df_shap[f].min()
        f_shap_hl = df_shap[f].max()

    # Samples inside the retained SHAP range, with their SHAP and raw values.
    f_index = df_shap.index[(df_shap[f] >= f_shap_ll) & (df_shap[f] <= f_shap_hl)].values
    f_shap = df_shap.loc[f_index, f].values
    f_vals = data.loc[f_index, f].values

    # NaN-aware extremes: the builtin min()/max() give order-dependent results
    # when the array contains NaN, which would corrupt the colour scale.
    f_val_min = np.nanmin(f_vals)
    f_val_max = np.nanmax(f_vals)

    # Map every observed feature value to a colour on a per-feature scale.
    f_norm = mcolors.Normalize(vmin=f_val_min, vmax=f_val_max)
    f_colors = {cval: f_cmap(f_norm(cval)) for cval in f_vals}

    strip = sns.stripplot(
        x=f_shap,
        y=[f] * len(f_shap),
        hue=f_vals,
        palette=f_colors,
        jitter=0.35,
        alpha=0.5,
        edgecolor='gray',
        linewidth=0.1,
        size=strip_point_size,
        legend=False,
        ax=axs_importance[1],
    )

    # A single colour bar labelled "Min"/"Max" is drawn for the first feature
    # only, since every feature uses its own value range.
    if not is_colorbar:
        sm = plt.cm.ScalarMappable(cmap=f_cmap, norm=f_norm)
        sm.set_array([])
        # The mappable is not attached to any Axes, so the Axes to take space
        # from must be named explicitly; Matplotlib >= 3.8 raises otherwise.
        cbar = strip.figure.colorbar(sm, ax=axs_importance[1])
        cbar.set_ticks([f_val_min, f_val_max])
        cbar.set_ticklabels(["Min", "Max"])
        is_colorbar = True

axs_importance[1].set_xlabel('SHAP')
# Store the SHAP table that was used for the importance panels.
df_shap.to_excel(f"{path}/model_importance.xlsx")

fig.suptitle(dataset, fontsize='large')

# Output file -> extra savefig options (only the raster image gets a dpi).
figure_outputs = {
    f"{path}/model.png": {"dpi": 200},
    f"{path}/model.pdf": {},
}
for out_file, extra_kwargs in figure_outputs.items():
    fig.savefig(out_file, bbox_inches='tight', **extra_kwargs)

# Release the figure once both files are written.
plt.close(fig)