# Autoreload (for debugging)

In [None]:
# Enable IPython's autoreload extension; mode 2 reloads all imported modules
# before executing each cell, so edits to project source files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Load packages

In [None]:
from pytorch_tabular.utils import load_covertype_dataset
from rich.pretty import pprint
from sklearn.model_selection import BaseCrossValidator, ParameterGrid, ParameterSampler
import torch
import pickle
import shutil
import shap
from sklearn.model_selection import RepeatedStratifiedKFold
from glob import glob
import ast
import matplotlib.pyplot as plt
import seaborn as sns
import copy
from sklearn.model_selection import train_test_split
import numpy as np
from pytorch_tabular.utils import make_mixed_dataset, print_metrics
from pytorch_tabular import available_models
from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig, GANDALFConfig, TabNetModelConfig, FTTransformerConfig, DANetConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig
from pytorch_tabular.tabular_model_tuner import TabularModelTuner
from torchmetrics.functional.regression import mean_absolute_error, pearson_corrcoef
from pytorch_tabular import MODEL_SWEEP_PRESETS
import pandas as pd
from pytorch_tabular import model_sweep
from src.pt.model_sweep import model_sweep_custom
import warnings
from src.utils.configs import read_parse_config
from src.utils.hash import dict_hash
from src.pt.hyper_opt import train_hyper_opt
import pathlib
from tqdm import tqdm
import distinctipy
import matplotlib.patheffects as pe
import matplotlib.colors as mcolors
from statannotations.Annotator import Annotator
from scipy.stats import mannwhitneyu
from regression_bias_corrector import LinearBiasCorrector
import optuna
from sklearn.preprocessing import LabelEncoder
from plottable import ColumnDefinition, Table
from plottable.plots import bar
from plottable.cmap import normed_cmap, centered_cmap
import matplotlib.lines as mlines
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import scipy.stats


def make_rgb_transparent(rgb, bg_rgb, alpha):
    """Blend colour ``rgb`` over background ``bg_rgb`` with opacity ``alpha``.

    Each output channel is ``alpha * fg + (1 - alpha) * bg``. Channels are
    paired with ``zip``, so the result has the length of the shorter input.
    Returns a new list; the inputs are not modified.
    """
    bg_weight = 1 - alpha
    return [alpha * fg + bg_weight * bg for fg, bg in zip(rgb, bg_rgb)]

# Processing of the best models

## Load data and models for all subsets of features

In [None]:
dataset = 'wo_outliers'
# Root folder with the prepared datasets and trained models (machine-specific path).
path = "E:/YandexDisk/Work/bbd/immunology/005_immuno_clocks_log_and_new_data"
feat_trgt = 'Age'

# One entry per dataset variant:
#   name       - title used on the figures
#   path       - folder with data.xlsx / feats.xlsx; figures are saved here
#   path_model - folder of the selected trained model and its exported tables
#   color      - matplotlib colour used for this variant's plots
datasets = {
    "full": {
        "name": "Full (736)",
        "path": f"{path}/full",
        "path_model": f"{path}/full/models/DANet/272",
        "color": "crimson",
    },
    "wo_outliers": {
        "name": "Without outliers (635)",
        "path": f"{path}/wo_outliers",
        "path_model": f"{path}/wo_outliers/models/DANet/269",
        "color": "dodgerblue",
    },
}

for ds, ds_info in datasets.items():
    # Tabular inputs and the exported model outputs (predictions, metrics, SHAP values).
    ds_info['data'] = pd.read_excel(f"{ds_info['path']}/data.xlsx", index_col=0)
    ds_info['feats'] = pd.read_excel(f"{ds_info['path']}/feats.xlsx", index_col=0)
    ds_info['results'] = pd.read_excel(f"{ds_info['path_model']}/df.xlsx", index_col=0)
    ds_info['metrics'] = pd.read_excel(f"{ds_info['path_model']}/metrics.xlsx", index_col=0)
    ds_info['shap'] = pd.read_excel(f"{ds_info['path_model']}/explanation.xlsx", index_col=0)
    ds_info['model'] = TabularModel.load_model(ds_info['path_model'])

    # Fit the bias corrector on the training split only (target vs. raw prediction),
    # so validation/test samples do not leak into the correction.
    ds_results = ds_info['results']
    is_train = ds_results['Group'] == 'Train'
    ds_info['corrector'] = LinearBiasCorrector()
    ds_info['corrector'].fit(
        ds_results.loc[is_train, feat_trgt].values,
        ds_results.loc[is_train, 'Prediction'].values,
    )

## Plot model results

In [None]:
# Metrics-table rows mapped to the metrics-sheet column that feeds each row.
# The Pearson label uses matplotlib mathtext for a bold rho.
metric_rows = {
    'MAE': 'mean_absolute_error_unbiased',
    r"Pearson $\mathbf{\rho}$": 'pearson_corrcoef_unbiased',
    "Bias": 'bias_unbiased',
}
table_parts = ['Train', 'Validation', 'Test']

for ds, ds_info in datasets.items():
    ds_feats = ds_info['feats']
    ds_results = ds_info['results']
    ds_metrics = ds_info['metrics']
    ds_shap = ds_info['shap']
    ds_color = ds_info['color']

    # Shared square range for target vs. bias-corrected prediction: the 1st-99th
    # percentile of both columns, padded by 15% of that span on each side.
    xy_min, xy_max = np.quantile(ds_results[[feat_trgt, 'Prediction Unbiased']].values.flatten(), [0.01, 0.99])
    xy_ptp = xy_max - xy_min
    xy_lims = (xy_min - 0.15 * xy_ptp, xy_max + 0.15 * xy_ptp)

    fig, axs = plt.subplot_mosaic(
        [
            ['table', 'bar'],
            ['scatter', 'bar'],
        ],
        layout='constrained',
        figsize=(8, 6),
        height_ratios=[1, 4],
        width_ratios=[3, 1.5],
        gridspec_kw={
            "wspace": 0.01,
            "hspace": 0.01,
        },
    )

    # Metrics table: rows = metrics, columns = data splits, values as 3-decimal strings.
    ds_table = pd.DataFrame(
        {part: [f"{ds_metrics.at[part, col]:0.3f}" for col in metric_rows.values()] for part in table_parts},
        index=list(metric_rows),
    )

    col_defs = [
        ColumnDefinition(
            name="index",
            title='',
            textprops={"ha": "center", "weight": "bold"},
            width=2.5,
        ),
        ColumnDefinition(
            name="Train",
            textprops={"ha": "left"},
            width=1.5,
            border="left",
        ),
        ColumnDefinition(
            name="Validation",
            textprops={"ha": "left"},
            width=1.5,
        ),
        ColumnDefinition(
            name="Test",
            textprops={"ha": "left"},
            width=1.5,
        )
    ]

    Table(
        ds_table,
        column_definitions=col_defs,
        row_dividers=True,
        footer_divider=False,
        ax=axs['table'],
        textprops={"fontsize": 8},
        row_divider_kw={"linewidth": 1, "linestyle": (0, (1, 1))},
        col_label_divider_kw={"linewidth": 1, "linestyle": "-"},
        column_border_kw={"linewidth": 1, "linestyle": "-"},
    ).autoset_fontcolors(colnames=table_parts)

    # Train+Validation as a filled density, Test as individual points on top.
    sns.kdeplot(
        data=ds_results.loc[ds_results['Group'].isin(['Train', 'Validation']), :],
        x=feat_trgt,
        y='Prediction Unbiased',
        fill=True,
        cbar=False,
        thresh=0.05,
        color=ds_color,
        legend=False,
        ax=axs['scatter']
    )
    sns.scatterplot(
        data=ds_results.loc[ds_results['Group'] == 'Test', :],
        x=feat_trgt,
        y="Prediction Unbiased",
        linewidth=0.5,
        alpha=0.8,
        edgecolor="k",
        s=25,
        color=ds_color,
        ax=axs['scatter'],
    )
    # Identity line (perfect prediction) across the whole visible range.
    sns.lineplot(
        x=list(xy_lims),
        y=list(xy_lims),
        linestyle='--',
        color='black',
        linewidth=1.0,
        ax=axs['scatter']
    )
    # Linear fit over all splits.
    sns.regplot(
        data=ds_results,
        x=feat_trgt,
        y='Prediction Unbiased',
        color='k',
        scatter=False,
        truncate=False,
        ax=axs['scatter']
    )
    axs['scatter'].set_xlim(*xy_lims)
    axs['scatter'].set_ylim(*xy_lims)
    axs['scatter'].set_ylabel('Biological age')
    axs['scatter'].set_xlabel(feat_trgt)

    # Top-30 features by mean absolute SHAP value. This is computed in one vectorised
    # pass, so the column is float rather than object dtype. Truncating the Series
    # before building the frame avoids assigning into a sliced DataFrame.
    feats_cnt = ds_feats.index.to_list()
    fi = (
        ds_shap[feats_cnt]
        .abs()
        .mean()
        .sort_values(ascending=False, kind='stable')
        .head(30)
    )
    df_fi = fi.to_frame('mean(|SHAP|)')
    df_fi['Features'] = df_fi.index.values
    barplot = sns.barplot(
        data=df_fi,
        x='mean(|SHAP|)',
        y='Features',
        color=ds_color,
        edgecolor='black',
        dodge=False,
        ax=axs['bar']
    )
    for container in barplot.containers:
        barplot.bar_label(container, label_type='edge', fmt='%0.2f', fontsize=12, padding=4.0)
    axs['bar'].set_ylabel('')

    fig.suptitle(ds_info['name'], fontsize='large')
    fig.savefig(f"{ds_info['path']}/model.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{ds_info['path']}/model.pdf", bbox_inches='tight')
    plt.close(fig)

## Plot model explainability

In [None]:
# Source of the SHAP values shown in the plots:
#   'current'         - the explanation table loaded together with the model
#   'recalc_gradient' - recompute with the model's GradientShap explainer
#   'recalc_sampling' - recompute with shap.SamplingExplainer on bias-corrected predictions
expl_type = 'current'

for ds, ds_info in datasets.items():
    ds_feats = ds_info['feats']
    feats_wo_trgt = ds_feats.index.to_list()
    # All features are treated as continuous; list categorical ones here so the
    # sampling explainer label-encodes them.
    feats_cat_wo_trgt = []
    ds_data = ds_info['data']
    ds_results = ds_info['results']
    ds_shap = ds_info['shap']
    ds_model = ds_info['model']
    ds_corrector = ds_info['corrector']
    ds_color = ds_info['color']

    if expl_type == 'recalc_gradient':
        ds_shap = ds_model.explain(ds_data, method="GradientShap", baselines="b|100000")
        ds_shap.index = ds_data.index
    elif expl_type == 'recalc_sampling':
        # SamplingExplainer needs a numeric matrix, so categorical features are
        # label-encoded here and decoded back inside predict_func.
        ds_data_shap = ds_data.copy()
        ds_cat_encoders = {}
        for f in feats_cat_wo_trgt:
            ds_cat_encoders[f] = LabelEncoder()
            ds_data_shap[f] = ds_cat_encoders[f].fit_transform(ds_data_shap[f])
        # PyTorch Tabular names the prediction column "<target>_prediction". The old
        # hard-coded 'Возраст_prediction' did not match feat_trgt and raised KeyError.
        pred_col = f"{feat_trgt}_prediction"

        def predict_func(X):
            """Return bias-corrected model predictions for the raw feature matrix ``X``."""
            X_df = pd.DataFrame(data=X, columns=feats_wo_trgt)
            for f in feats_cat_wo_trgt:
                X_df[f] = ds_cat_encoders[f].inverse_transform(X_df[f].astype(int).values)
            y = ds_model.predict(X_df)[pred_col].values
            return ds_corrector.predict(y)

        explainer = shap.SamplingExplainer(predict_func, ds_data_shap.loc[:, feats_wo_trgt].values)
        print(explainer.expected_value)
        shap_values = explainer.shap_values(ds_data_shap.loc[:, feats_wo_trgt].values)
        ds_shap = pd.DataFrame(index=ds_data.index, columns=feats_wo_trgt, data=shap_values)

    # Mean |SHAP| per feature, computed in one vectorised pass (float dtype). Truncate
    # the Series before building the frame, so the 'Features' assignment below never
    # targets a sliced DataFrame.
    fi = ds_shap[feats_wo_trgt].abs().mean().sort_values(ascending=False, kind='stable')
    if ds != 'inbody_mrmr_lab':
        fi = fi.head(30)
    ds_fi = fi.to_frame('mean(|SHAP|)')
    ds_fi['Features'] = ds_fi.index.values

    sns.set_theme(style='ticks')
    fig, axs = plt.subplots(1, 2, figsize=(12, 0.6 * ds_fi.shape[0]), width_ratios=[4, 8], gridspec_kw={'wspace': 0.1, 'hspace': 0.05}, sharey=True, sharex=False)

    barplot = sns.barplot(
        data=ds_fi,
        x='mean(|SHAP|)',
        y='Features',
        color=ds_color,
        edgecolor='black',
        dodge=False,
        ax=axs[0]
    )
    for container in barplot.containers:
        barplot.bar_label(container, label_type='edge', color='gray', fmt='%0.2f', fontsize=12, padding=4.0)
    axs[0].set_ylabel('')
    axs[0].set(yticklabels=ds_fi.index.to_list())

    # Loop invariants: the colormap and the marker size. The marker size shrinks with
    # the number of test samples.
    f_cmap = sns.color_palette("Spectral_r", as_cmap=True)
    strip_size = 25 / np.sqrt((ds_results['Group'] == 'Test').sum())
    is_colorbar = False
    for f in ds_fi.index:
        # For features with large SHAP magnitudes, drop points outside the
        # 1st-99th percentile so a few extremes do not compress the strip.
        if ds_shap[f].abs().max() > 10:
            f_shap_ll = ds_shap[f].quantile(0.01)
            f_shap_hl = ds_shap[f].quantile(0.99)
        else:
            f_shap_ll = ds_shap[f].min()
            f_shap_hl = ds_shap[f].max()

        f_index = ds_shap.index[(ds_shap[f] >= f_shap_ll) & (ds_shap[f] <= f_shap_hl)].values
        f_shap = ds_shap.loc[f_index, f].values
        f_vals = ds_data.loc[f_index, f].values

        # Colour each point by its feature value, normalised per feature. The bounds
        # ignore NaNs, so one missing value does not break the whole norm.
        f_vmin, f_vmax = np.nanmin(f_vals), np.nanmax(f_vals)
        f_norm = mcolors.Normalize(vmin=f_vmin, vmax=f_vmax)
        f_colors = {cval: f_cmap(f_norm(cval)) for cval in f_vals}

        sns.stripplot(
            x=f_shap,
            y=[f]*len(f_shap),
            hue=f_vals,
            palette=f_colors,
            jitter=0.35,
            alpha=0.5,
            edgecolor='gray',
            linewidth=0.1,
            size=strip_size,
            legend=False,
            ax=axs[1],
        )

        if not is_colorbar:
            # A single colorbar is shown. Scales differ per feature, so it only labels Min/Max.
            sm = plt.cm.ScalarMappable(cmap=f_cmap, norm=f_norm)
            sm.set_array([])
            # Pass ax explicitly: calling colorbar() with an unattached mappable and no
            # ax/cax is deprecated since matplotlib 3.6. The old call fell back to gca(),
            # which is axs[1].
            cbar = fig.colorbar(sm, ax=axs[1])
            cbar.set_label('Значения\nчисленных\nпризнаков', labelpad=-8, fontsize='large')
            cbar.set_ticks([f_vmin, f_vmax])
            cbar.set_ticklabels(["Min", "Max"])
            is_colorbar = True

    axs[1].set_xlabel('SHAP: Влияние на предсказание модели')
    fig.savefig(f"{ds_info['path']}/model_importance.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{ds_info['path']}/model_importance.pdf", bbox_inches='tight')
    ds_shap.to_excel(f"{ds_info['path']}/model_importance.xlsx")
    plt.close(fig)