# Debugging autoreload

In [None]:
# Load IPython's autoreload extension and switch it to mode 2, which reloads
# all imported modules before each cell is executed, so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Load packages

In [None]:
from pytorch_tabular.utils import load_covertype_dataset
from rich.pretty import pprint
from sklearn.model_selection import BaseCrossValidator, ParameterGrid, ParameterSampler
import torch
import pickle
import shutil
import shap
from sklearn.model_selection import RepeatedStratifiedKFold
from glob import glob
import ast
import matplotlib.pyplot as plt
import seaborn as sns
import copy
from sklearn.model_selection import train_test_split
import numpy as np
from pytorch_tabular.utils import make_mixed_dataset, print_metrics
from pytorch_tabular import available_models
from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig, GANDALFConfig, TabNetModelConfig, FTTransformerConfig, DANetConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig
from pytorch_tabular.tabular_model_tuner import TabularModelTuner
from torchmetrics.functional.regression import mean_absolute_error, pearson_corrcoef
from pytorch_tabular import MODEL_SWEEP_PRESETS
import pandas as pd
from pytorch_tabular import model_sweep
from src.pt.model_sweep import model_sweep_custom
import warnings
from src.utils.configs import read_parse_config
from src.utils.hash import dict_hash
from src.pt.hyper_opt import train_hyper_opt
from src.sa.hyper_opt import train_hyper_opt_sa_regression
import pathlib
from tqdm import tqdm
import distinctipy
import matplotlib.patheffects as pe
import matplotlib.colors as mcolors
from statannotations.Annotator import Annotator
from scipy.stats import mannwhitneyu
from scipy import stats
from regression_bias_corrector import LinearBiasCorrector
import optuna
from sklearn.preprocessing import LabelEncoder
from plottable import ColumnDefinition, Table
from plottable.plots import bar
from plottable.cmap import normed_cmap, centered_cmap
import matplotlib.lines as mlines
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import scipy.stats
from omegaconf import OmegaConf
import lightgbm


def make_rgb_transparent(rgb, bg_rgb, alpha):
    """Blend a foreground colour over a background colour.

    Each channel of the result is ``alpha * foreground + (1 - alpha) * background``.
    Channels are paired positionally; if the two sequences differ in length,
    the extra channels of the longer one are ignored.

    Args:
        rgb: Foreground colour channels.
        bg_rgb: Background colour channels.
        alpha: Weight of the foreground colour (1 keeps ``rgb``, 0 keeps ``bg_rgb``).

    Returns:
        A list with one blended value per channel pair.
    """
    blended = []
    for fg_channel, bg_channel in zip(rgb, bg_rgb):
        blended.append(alpha * fg_channel + (1 - alpha) * bg_channel)
    return blended

# Load data

In [None]:
# --- What to model ----------------------------------------------------------
# Feature-set sub-path; it is appended to both the data and checkpoint roots.
feats_set = 'small/Оценка состава тела/M'
# Name of the target column.
feat_trgt = 'Возраст'

# --- Where things live (absolute, machine-specific paths) -------------------
path = f"E:/YandexDisk/Work/bbd/millennium/models/{feats_set}"
path_ckpts = f"E:/Git/bbs/notebooks/millennium/pt/{feats_set}"
path_configs = "E:/Git/bbs/notebooks/millennium/configs"

# --- Split bookkeeping ------------------------------------------------------
# The tst_*/val_* numbers are part of the stratification file name, and the
# two ids pick one test split and one validation fold out of that file.
tst_n_splits, tst_n_repeats = 5, 5
tst_random_state = 1337
tst_split_id = 24

val_n_splits, val_n_repeats = 4, 4
val_random_state = 1337
val_fold_id = 1

# --- Tables -----------------------------------------------------------------
# The index of feats.xlsx provides the list of continuous feature names.
df_feats = pd.read_excel(f"{path}/feats.xlsx", index_col=0)
data = pd.read_excel(f"{path}/data.xlsx", index_col=0)

feats_cnt = df_feats.index.to_list()
feats_cat = []  # no categorical features in this feature set
feats = [*feats_cnt, *feats_cat]

# Load stratification

In [None]:
# The file name encodes the parameters of the test and validation splits.
samples_path = f"{path}/samples_tst({tst_random_state}_{tst_n_splits}_{tst_n_repeats})_val({val_random_state}_{val_n_splits}_{val_n_repeats}).pickle"
with open(samples_path, mode='rb') as fh:
    samples = pickle.load(fh)

# Train, Validation, Test selection

In [None]:
# Pick one test split and, inside it, one train/validation fold.
split_dict = samples[tst_split_id]

# All three parts carry the same columns: the features followed by the target.
cols_used = feats + [feat_trgt]
train = data.loc[split_dict['trains'][val_fold_id], cols_used]
validation = data.loc[split_dict['validations'][val_fold_id], cols_used]
test = data.loc[split_dict['test'], cols_used]

# Optuna PyTorch Tabular

## Models setup

In [None]:
# Seed shared by every model entry below; the trailing list holds other seed
# values that can be swapped in.
seed_target = 451  # 1337 42 451 1984 1899 1408

# One entry per model to tune. The training cell reads, for each entry:
#   'config'            - PyTorch Tabular model-config class,
#   'n_trials'          - number of Optuna trials,
#   'seed'              - seed for the TPE sampler and the trainer,
#   'n_startup_trials'  - passed to optuna.samplers.TPESampler,
#   'n_ei_candidates'   - passed to optuna.samplers.TPESampler.
# Commented-out entries are alternative models that are currently disabled.
models_runs = {
    'GANDALF': {
        'config': GANDALFConfig,
        'n_trials': 1024,
        'seed': seed_target,
        'n_startup_trials': 256,
        'n_ei_candidates': 16
    },
    # 'FTTransformer': {
    #     'config': FTTransformerConfig,
    #     'n_trials': 512,
    #     'seed': seed_target,
    #     'n_startup_trials': 128,
    #     'n_ei_candidates': 16
    # },
    # 'DANet': {
    #     'config': DANetConfig,
    #     'n_trials': 512,
    #     'seed': seed_target,
    #     'n_startup_trials': 256,
    #     'n_ei_candidates': 32
    # },
    # 'CategoryEmbeddingModel': {
    #     'config': CategoryEmbeddingModelConfig,
    #     'n_trials': 256,
    #     'seed': seed_target,
    #     'n_startup_trials': 64,
    #     'n_ei_candidates': 16
    # },
    # 'TabNetModel': {
    #     'config': TabNetModelConfig,
    #     'n_trials': 256,
    #     'seed': seed_target,
    #     'n_startup_trials': 64,
    #     'n_ei_candidates': 16
    # }
}

## Training

In [None]:
# Per-model tables of trial results; they are concatenated after the loop.
dfs_models = []

for model_name, model_run in models_runs.items():

    # Unpack the study settings of the current model.
    model_config_name, n_trials, seed, n_startup_trials, n_ei_candidates = (
        model_run[key]
        for key in ('config', 'n_trials', 'seed', 'n_startup_trials', 'n_ei_candidates')
    )

    # Default configs come from YAML files and are then adapted to this run.
    data_config = read_parse_config(f"{path_configs}/DataConfig.yaml", DataConfig)
    data_config['target'] = [feat_trgt]
    data_config['continuous_cols'] = list(map(str, feats_cnt))
    data_config['categorical_cols'] = feats_cat

    optimizer_config = read_parse_config(f"{path_configs}/OptimizerConfig.yaml", OptimizerConfig)

    trainer_config = read_parse_config(f"{path_configs}/TrainerConfig.yaml", TrainerConfig)
    trainer_overrides = {
        'checkpoints_path': path_ckpts,
        'seed': seed,
        'checkpoints': 'valid_loss',
        'load_best': True,
        'auto_lr_find': False,
    }
    for option, value in trainer_overrides.items():
        trainer_config[option] = value

    # Learning-rate-finder settings forwarded to train_hyper_opt.
    lr_find_min_lr = 1e-8
    lr_find_max_lr = 10
    lr_find_num_training = 128
    lr_find_mode = "exponential"
    lr_find_early_stop_threshold = 8.0

    # A model with default hyper-parameters is built only to prepare the
    # datamodule that is handed to every trial.
    model_config_default = read_parse_config(f"{path_configs}/models/{model_name}Config.yaml", model_config_name)
    tabular_model_default = TabularModel(
        data_config=data_config,
        model_config=model_config_default,
        optimizer_config=optimizer_config,
        trainer_config=trainer_config,
        verbose=False,
    )
    datamodule = tabular_model_default.prepare_dataloader(train=train, validation=validation, seed=seed)

    # Optimisation targets: one direction per (part, metric) combination.
    # Other metric sets that have been used here:
    #   [('mean_absolute_error', 'minimize'), ('pearson_corrcoef', 'maximize')]
    #   [('mean_absolute_error', 'minimize')]
    opt_parts = ['test', 'validation']
    opt_metrics = [('pearson_corrcoef', 'maximize')]
    opt_directions = [
        str(direction)
        for _part in opt_parts
        for (_metric, direction) in opt_metrics
    ]

    # Filled by train_hyper_opt with one record per trial.
    trials_results = []

    def objective(trial):
        """Run one Optuna trial through train_hyper_opt with the current configs."""
        return train_hyper_opt(
            trial=trial,
            trials_results=trials_results,
            opt_metrics=opt_metrics,
            opt_parts=opt_parts,
            model_config_default=model_config_default,
            data_config_default=data_config,
            optimizer_config_default=optimizer_config,
            trainer_config_default=trainer_config,
            experiment_config_default=None,
            train=train,
            validation=validation,
            test=test,
            datamodule=datamodule,
            min_lr=lr_find_min_lr,
            max_lr=lr_find_max_lr,
            num_training=lr_find_num_training,
            mode=lr_find_mode,
            early_stop_threshold=lr_find_early_stop_threshold
        )

    tpe_sampler = optuna.samplers.TPESampler(
        n_startup_trials=n_startup_trials,
        n_ei_candidates=n_ei_candidates,
        seed=seed,
    )
    study = optuna.create_study(
        study_name=model_name,
        sampler=tpe_sampler,
        directions=opt_directions
    )
    study.optimize(func=objective, n_trials=n_trials, show_progress_bar=True)

    # File name records the model, the study settings and the split used.
    fn_trials = (
        f"model({model_name})_"
        f"trials({n_trials}_{seed}_{n_startup_trials}_{n_ei_candidates})_"
        f"tst({tst_split_id})_"
        f"val({val_fold_id})"
    )

    df_trials = pd.DataFrame(trials_results)
    df_trials['split_id'] = tst_split_id
    df_trials['fold_id'] = val_fold_id

    # Flag trials whose train loss is above the test or the validation loss.
    train_loss = df_trials["train_loss"]
    validation_loss = df_trials["validation_loss"]
    test_loss = df_trials["test_loss"]
    df_trials["train_more"] = (train_loss > test_loss) | (train_loss > validation_loss)

    df_trials["validation_test_mean_loss"] = (df_trials["validation_loss"] + df_trials["test_loss"]) / 2.0
    df_trials["train_validation_test_mean_loss"] = (df_trials["train_loss"] + df_trials["validation_loss"] + df_trials["test_loss"]) / 3.0

    styled_trials = df_trials.style.background_gradient(cmap="RdYlGn_r")
    styled_trials.to_excel(f"{trainer_config['checkpoints_path']}/{fn_trials}.xlsx")

    dfs_models.append(df_trials)

# Merge the per-model trial tables into one sheet. The 'Selected' column is
# inserted first and initialised to 0.
df_models = pd.concat(dfs_models, ignore_index=True)
df_models.insert(0, 'Selected', 0)
fn = (
    f"models_"
    f"tst({tst_split_id})_"
    f"val({val_fold_id})"
)
# Write to path_ckpts directly instead of going through `trainer_config`,
# which only exists here because it leaks out of the training loop. The loop
# sets trainer_config['checkpoints_path'] = path_ckpts, so the location is
# unchanged, and the analysis cell reads this file back from path_ckpts.
df_models.style.background_gradient(cmap="RdYlGn_r").to_excel(f"{path_ckpts}/{fn}.xlsx")


## Best models analysis

In [None]:
# Settings for model.explain and for the SHAP summary plot.
explain_method = "GradientShap"
explain_baselines = "b|1000"
explain_n_feats_to_plot = 25

# Model family and the row ids (in the sweep table) of the candidates to analyse.
models_type = 'GANDALF'  # alternative: 'DANet'
models_ids = [
    18, 31, 171, 433, 473, 516, 540,
    762, 770, 794, 831, 834, 837, 840,
    842, 852, 859, 862, 865, 866, 867,
    869, 886, 890, 907, 914, 918, 987,
]
# De-duplicate and order the ids.
models_ids = sorted(set(models_ids))

# Samples that get a per-sample SHAP check, and the size of their background set.
ids_shap_check = test.index.values
bkg_count = 10

feat_trgt = 'Возраст'

# Copy of the data with a 'Group' label per row; prediction columns are added
# to it for every analysed model.
data_for_shap = data.copy()
for group_label, subset in (('Train', train), ('Validation', validation), ('Test', test)):
    data_for_shap.loc[subset.index, 'Group'] = group_label

# Pearson correlation of every feature with the target over the whole data set.
feats_corr = pd.DataFrame(index=feats, columns=['Correlation'])
trgt_values = data.loc[:, feat_trgt].values
for f in feats:
    corr_coef, _ = stats.pearsonr(data.loc[:, f].values, trgt_values)
    feats_corr.at[f, 'Correlation'] = corr_coef

# Table written by the training cell for the same test split / validation fold.
sweeps_path = (
    f"{path_ckpts}/"
    f"models_"
    f"tst({tst_split_id})_"
    f"val({val_fold_id})"
    f".xlsx"
)
df_sweeps = pd.read_excel(sweeps_path, index_col=0)

# Plot colour of each data part; the key order is also the plotting order.
colors_groups = {
    'Train': 'chartreuse',
    'Validation': 'dodgerblue',
    'Test': 'crimson',
}

# Output root for everything produced for the candidate models.
path_to_candidates = f"{path_ckpts}/candidates/{models_type}"
pathlib.Path(path_to_candidates).mkdir(parents=True, exist_ok=True)

for model_id in models_ids:

    # The saved-model directory sits next to the checkpoint file and is named
    # after the checkpoint's stem.
    ckpt_path = pathlib.Path(df_sweeps.at[model_id, 'checkpoint'])
    model_dir = str(ckpt_path.parent).replace('\\', '/') + '/' + ckpt_path.stem
    model = TabularModel.load_model(model_dir)

    # Keep a copy of the model files in the candidate's own folder.
    candidate_dir = f"{path_to_candidates}/{model_id}"
    pathlib.Path(candidate_dir).mkdir(parents=True, exist_ok=True)
    shutil.copytree(model_dir, candidate_dir, dirs_exist_ok=True)

    # Per-sample table: target, data part, prediction and error.
    df = data.loc[:, [feat_trgt]]
    for group_label, subset in (('Train', train), ('Validation', validation), ('Test', test)):
        df.loc[subset.index, 'Group'] = group_label
    df['Prediction'] = model.predict(data)
    df['Error'] = df['Prediction'] - df[feat_trgt]

    # The bias corrector is fitted on the training part only and then applied
    # to the predictions of every sample.
    is_train = df['Group'] == 'Train'
    corrector = LinearBiasCorrector()
    corrector.fit(df.loc[is_train, feat_trgt].values, df.loc[is_train, 'Prediction'].values)
    df['Prediction Unbiased'] = corrector.predict(df['Prediction'].values)
    df['Error Unbiased'] = df['Prediction Unbiased'] - df[feat_trgt]
    df.to_excel(f"{path_to_candidates}/{model_id}/df.xlsx")
    
    def predict_func(X):
        """Return bias-corrected predictions of the current model for array ``X``.

        ``X`` is wrapped into a DataFrame whose columns are ``feats``, passed
        through ``model.predict``, and the resulting prediction column is run
        through ``corrector.predict``. ``model`` and ``corrector`` are the
        objects of the current loop iteration.
        """
        frame = pd.DataFrame(data=X, columns=feats)
        raw_prediction = model.predict(frame)[f'{feat_trgt}_prediction'].values
        return corrector.predict(raw_prediction)
    
    # Mirror the prediction/error columns of the current model into data_for_shap.
    for col in ('Prediction', 'Error', 'Prediction Unbiased', 'Error Unbiased'):
        data_for_shap[col] = df[col]

    # MAE, Pearson correlation and mean error per data part, before and after
    # the bias correction.
    df_metrics = pd.DataFrame(
        index=list(colors_groups),
        columns=[
            'mean_absolute_error', 'pearson_corrcoef', 'bias',
            'mean_absolute_error_unbiased', 'pearson_corrcoef_unbiased', 'bias_unbiased'
        ]
    )
    for group in colors_groups:
        rows = df.loc[df['Group'] == group]
        pred = torch.from_numpy(rows['Prediction'].values)
        pred_unbiased = torch.from_numpy(rows['Prediction Unbiased'].values)
        real = torch.from_numpy(rows[feat_trgt].values.astype(np.float32))
        group_metrics = {
            'mean_absolute_error': mean_absolute_error(pred, real).numpy(),
            'pearson_corrcoef': pearson_corrcoef(pred, real).numpy(),
            'bias': np.mean(rows['Error'].values),
            'mean_absolute_error_unbiased': mean_absolute_error(pred_unbiased, real).numpy(),
            'pearson_corrcoef_unbiased': pearson_corrcoef(pred_unbiased, real).numpy(),
            'bias_unbiased': np.mean(rows['Error Unbiased'].values),
        }
        for metric_name, metric_value in group_metrics.items():
            df_metrics.at[group, metric_name] = metric_value
    df_metrics.to_excel(f"{path_to_candidates}/{model_id}/metrics.xlsx", index_label="Metrics")

    # Test-set metrics of the bias-corrected prediction; they feed the
    # background-selection threshold of the per-sample SHAP check.
    mae = df_metrics.at['Test', 'mean_absolute_error_unbiased']
    rho = df_metrics.at['Test', 'pearson_corrcoef_unbiased']
    
    
    
    
    
    # Per-sample SHAP check. For every sample in ids_shap_check this loop
    #   1. selects a background set,
    #   2. computes SHAP values of the bias-corrected prediction against it,
    #   3. rescales and optionally "corrects" the SHAP values,
    #   4. records, per feature, whether the SHAP sign is consistent with the
    #      feature's correlation to the target and its position in the background.
    # One row per sample is collected in df_correspondence.
    df_correspondence = pd.DataFrame()
    for sample_id in ids_shap_check:
        
        # True target value and bias-corrected prediction of the sample.
        # (Despite its name, trgt_pred_raw holds 'Prediction Unbiased'.)
        trgt_age = data_for_shap.at[sample_id, feat_trgt]
        trgt_pred_raw = data_for_shap.at[sample_id, 'Prediction Unbiased']
        
        # Background candidates: rows whose absolute unbiased error is below
        # mae * rho * max|feature-target correlation|.
        # NOTE(review): this filter does not depend on sample_id and could be
        # computed once before the loop.
        data_closest = data_for_shap.loc[data_for_shap['Error Unbiased'].abs() < mae * rho * feats_corr['Correlation'].abs().max(), :]
        # Keep the bkg_count candidates whose unbiased prediction is closest
        # to the sample's true target value.
        data_closest = data_closest.iloc[(data_closest['Prediction Unbiased'] - trgt_age).abs().argsort()[:bkg_count]]
        
        # The min/max below are signed differences, not absolute ones.
        print(f"Background count: {data_closest.shape[0]}")
        print(f"Background min diff: {(data_closest['Prediction Unbiased'] - trgt_age).min()}")
        print(f"Background max diff: {(data_closest['Prediction Unbiased'] - trgt_age).max()}")
        
        explainer = shap.SamplingExplainer(predict_func, data_closest.loc[:, feats].values)
        # A single row is explained; [0] takes its SHAP vector.
        shap_values = explainer.shap_values(data_for_shap.loc[[sample_id], feats].values)[0]
        # Rescale by (prediction - target) / (prediction - expected value).
        # Presumably the intent is that the SHAP values sum to
        # (prediction - target); the print below reports the residual.
        # NOTE(review): divides by zero if the prediction equals
        # explainer.expected_value - confirm this cannot happen.
        shap_values = shap_values * (trgt_pred_raw - trgt_age) / (trgt_pred_raw - explainer.expected_value)
        print(f"SHAP values difference: {sum(shap_values) - (trgt_pred_raw - trgt_age)}")
        
        # SHAP values correction 1
        # Triggered when the mean |SHAP| exceeds shap_corr_thld while the
        # absolute signed sum stays below it. Positive and negative values are
        # each shrunk toward zero by half of shap_corr_diff in total, split
        # proportionally to their magnitude. shap_values is modified in place.
        shap_corr_thld = 3.0
        shap_corr_to = 1.0
        shap_mean_abs = np.mean(np.abs(shap_values))
        if shap_mean_abs > shap_corr_thld and abs(sum(shap_values)) < shap_corr_thld:
            print('SHAP values correction 1')
            shap_pos_ids = np.argwhere(shap_values >= 0).ravel()
            shap_neg_ids = np.argwhere(shap_values < 0).ravel()
            
            shap_pos_sum_abs = np.sum(np.abs(shap_values[shap_pos_ids]))
            shap_neg_sum_abs = np.sum(np.abs(shap_values[shap_neg_ids]))
            
            # Current total |SHAP| and the target total (shap_corr_to per feature).
            shap_sum_abs_from = np.sum(np.abs(shap_values))
            shap_sum_abs_to = shap_corr_to * len(shap_values)
            
            shap_corr_diff = shap_sum_abs_from - shap_sum_abs_to
            
            for pos_id in shap_pos_ids:
                curr_part = abs(shap_values[pos_id]) / shap_pos_sum_abs
                shap_values[pos_id] -= curr_part * shap_corr_diff * 0.5
            for neg_id in shap_neg_ids:
                curr_part = abs(shap_values[neg_id]) / shap_neg_sum_abs
                shap_values[neg_id] += curr_part * shap_corr_diff * 0.5
                
        # SHAP values correction 2
        # Triggered when the largest |SHAP| exceeds shap_corr_max_thld times
        # the absolute signed sum. Positive values are lowered and negative
        # values raised by shap_corr_diff in total on each side, again split
        # proportionally to magnitude. It operates on the (possibly already
        # corrected) values from correction 1.
        shap_corr_max_thld = 1.0
        shap_corr_max_to = 0.95
        shap_max_abs = np.max(np.abs(shap_values))
        shap_abs_sum = np.abs(np.sum(shap_values))
        if shap_max_abs > shap_corr_max_thld * shap_abs_sum:
            print('SHAP values correction 2')
            shap_pos_ids = np.argwhere(shap_values >= 0).ravel()
            shap_neg_ids = np.argwhere(shap_values < 0).ravel()
            
            shap_pos_sum_abs = np.sum(np.abs(shap_values[shap_pos_ids]))
            shap_neg_sum_abs = np.sum(np.abs(shap_values[shap_neg_ids]))
            
            # NOTE(review): the trigger compares against
            # shap_corr_max_thld * shap_abs_sum, but the amount removed is
            # based on shap_max_abs - shap_corr_max_thld; confirm this is intended.
            shap_corr_diff = (shap_max_abs - shap_corr_max_thld) * shap_corr_max_to
            
            for pos_id in shap_pos_ids:
                curr_part = abs(shap_values[pos_id]) / shap_pos_sum_abs
                shap_values[pos_id] -= curr_part * shap_corr_diff
            for neg_id in shap_neg_ids:
                curr_part = abs(shap_values[neg_id]) / shap_neg_sum_abs
                shap_values[neg_id] += curr_part * shap_corr_diff
        
        # Per-feature comparison table for this sample, sorted by ascending
        # |SHAP| so that a larger row position means a more influential feature.
        df_comp = pd.DataFrame(index=feats, columns=['SHAP', 'Values', 'Correlation', 'Percentile', 'Consistent'])
        df_comp['SHAP'] = shap_values
        df_comp.sort_values(by='SHAP', key=abs, inplace=True)
        df_comp.loc[df_comp.index.values, 'Values'] = data_for_shap.loc[sample_id, df_comp.index.values].values
        df_comp.loc[df_comp.index.values, 'Correlation'] = feats_corr.loc[df_comp.index.values, 'Correlation']
        for f_id, f in enumerate(df_comp.index.values):
            # Percentile of the sample's feature value within the background set.
            df_comp.at[f, 'Percentile'] = stats.percentileofscore(data_closest.loc[:, f].values, data_for_shap.at[sample_id, f])
            # "Consistent" = the SHAP sign matches what the feature's
            # correlation with the target and the value's position in the
            # background suggest (above the 55th or below the 45th percentile);
            # values between the 45th and 55th percentile are never consistent.
            if (
                ((df_comp.at[f, 'Correlation'] > 0) & (df_comp.at[f, 'SHAP'] > 0) & (df_comp.at[f, 'Percentile'] > 55)) or \
                ((df_comp.at[f, 'Correlation'] > 0) & (df_comp.at[f, 'SHAP'] < 0) & (df_comp.at[f, 'Percentile'] < 45)) or \
                ((df_comp.at[f, 'Correlation'] < 0) & (df_comp.at[f, 'SHAP'] > 0) & (df_comp.at[f, 'Percentile'] < 45)) or \
                ((df_comp.at[f, 'Correlation'] < 0) & (df_comp.at[f, 'SHAP'] < 0) & (df_comp.at[f, 'Percentile'] > 55))
                ):
                df_comp.at[f, 'Consistent'] = 1
            else:
                df_comp.at[f, 'Consistent'] = 0

            # Flatten the per-feature results into the sample's row;
            # "Order" is the feature's position in the ascending |SHAP| ranking.
            df_correspondence.at[sample_id, f"{f} Consistent"] = df_comp.at[f, 'Consistent']
            df_correspondence.at[sample_id, f"{f} Order"] = f_id
            df_correspondence.at[sample_id, f"{f} Percentile"] = df_comp.at[f, 'Percentile']
            df_correspondence.at[sample_id, f"{f} SHAP"] = df_comp.at[f, 'SHAP']
            df_correspondence.at[sample_id, f"{f} Value"] = df_comp.at[f, 'Values']
            df_correspondence.at[sample_id, f"{f} Correlation"] = df_comp.at[f, 'Correlation']
            
    # For every n, store in df_sweeps the share of "consistent" flags among the
    # n features with the largest |SHAP| of each checked sample. "Order" ranks
    # features by ascending |SHAP|, so the top n are those whose order exceeds
    # len(feats) - n - 1.
    n_feats_total = len(feats)
    n_checked = len(ids_shap_check)
    for n_top_feats in np.arange(1, n_feats_total + 1):
        top_col = f"SHAP top-{n_top_feats}"
        order_cutoff = n_feats_total - n_top_feats - 1
        consistent_total = 0.0
        for f in feats:
            in_top = df_correspondence[f"{f} Order"] > order_cutoff
            consistent_total += df_correspondence.loc[in_top, f"{f} Consistent"].sum()
        df_sweeps.at[model_id, top_col] = consistent_total / (n_top_feats * n_checked)

    df_correspondence.to_excel(f"{path_to_candidates}/{model_id}/correspondence.xlsx")
    
    
    
    
    # Common value range of target and prediction (raw and bias-corrected);
    # the plots below use it, padded by 10 %, for both axes.
    raw_cols = df[[feat_trgt, 'Prediction']]
    xy_min, xy_max = raw_cols.min().min(), raw_cols.max().max()
    xy_ptp = xy_max - xy_min

    unbiased_cols = df[[feat_trgt, 'Prediction Unbiased']]
    xy_min_unbiased, xy_max_unbiased = unbiased_cols.min().min(), unbiased_cols.max().max()
    xy_ptp_unbiased = xy_max_unbiased - xy_min_unbiased

    # Raw prediction vs. target: one scatter + regression line per data part,
    # plus the dashed identity line.
    sns.set_theme(style='whitegrid')
    fig, ax = plt.subplots(figsize=(4.5, 4))
    marker_style = {'linewidth': 0.2, 'alpha': 0.75, 'edgecolor': "k", 's': 20}
    for group, color in colors_groups.items():
        sns.regplot(
            data=df.loc[df['Group'] == group, :],
            x=feat_trgt,
            y="Prediction",
            label=group,
            color=color,
            scatter_kws=dict(marker_style),
            ax=ax
        )
    axis_lims = (xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp)
    sns.lineplot(
        x=list(axis_lims),
        y=list(axis_lims),
        linestyle='--',
        color='black',
        linewidth=1.0,
        ax=ax
    )
    ax.set_title(f"{df_sweeps.at[model_id, 'model']} ({df_sweeps.at[model_id, '# Params']} params)")
    ax.set_xlim(*axis_lims)
    ax.set_ylim(*axis_lims)
    ax.set_aspect('equal', adjustable='box')
    for suffix, extra in (("png", {'dpi': 200}), ("pdf", {})):
        fig.savefig(f"{path_to_candidates}/{model_id}/regplot.{suffix}", bbox_inches='tight', **extra)
    plt.close(fig)
    
    # Bias-corrected prediction vs. target: one scatter + regression line per
    # data part, plus the dashed identity line.
    sns.set_theme(style='whitegrid')
    fig, ax = plt.subplots(figsize=(4.5, 4))
    for group in colors_groups.keys():
        regplot = sns.regplot(
            data=df.loc[df['Group'] == group, :],
            x=feat_trgt,
            y="Prediction Unbiased",
            label=group,
            color=colors_groups[group],
            scatter_kws=dict(
                linewidth=0.2,
                alpha=0.75,
                edgecolor="k",
                s=20,
            ),
            ax=ax
        )
    # Axis range: the unbiased value range padded by 10 % of its OWN span on
    # both sides. The lower bound used to be padded with the raw span
    # (xy_ptp), so the axes did not match the identity line drawn below.
    lim_lo_unbiased = xy_min_unbiased - 0.1 * xy_ptp_unbiased
    lim_hi_unbiased = xy_max_unbiased + 0.1 * xy_ptp_unbiased
    bisect = sns.lineplot(
        x=[lim_lo_unbiased, lim_hi_unbiased],
        y=[lim_lo_unbiased, lim_hi_unbiased],
        linestyle='--',
        color='black',
        linewidth=1.0,
        ax=ax
    )
    ax.set_title(f"{df_sweeps.at[model_id, 'model']} ({df_sweeps.at[model_id, '# Params']} params)")
    ax.set_xlim(lim_lo_unbiased, lim_hi_unbiased)
    ax.set_ylim(lim_lo_unbiased, lim_hi_unbiased)
    plt.gca().set_aspect('equal', adjustable='box')
    fig.savefig(f"{path_to_candidates}/{model_id}/regplot_unbiased.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_to_candidates}/{model_id}/regplot_unbiased.pdf", bbox_inches='tight')
    plt.close(fig)
    
    # Raw prediction vs. target as a single scatter plot coloured by data part,
    # with the dashed identity line.
    sns.set_theme(style='whitegrid')
    fig, ax = plt.subplots(figsize=(4.5, 4))
    sns.scatterplot(
        data=df,
        x=feat_trgt,
        y="Prediction",
        hue="Group",
        hue_order=list(colors_groups),
        palette=colors_groups,
        linewidth=0.2,
        alpha=0.75,
        edgecolor="k",
        s=20,
        ax=ax
    )
    scatter_lims = (xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp)
    sns.lineplot(
        x=list(scatter_lims),
        y=list(scatter_lims),
        linestyle='--',
        color='black',
        linewidth=1.0,
        ax=ax
    )
    ax.set_title(f"{df_sweeps.at[model_id, 'model']} ({df_sweeps.at[model_id, '# Params']} params)")
    ax.set_xlim(*scatter_lims)
    ax.set_ylim(*scatter_lims)
    ax.set_aspect('equal', adjustable='box')
    for suffix, extra in (("png", {'dpi': 200}), ("pdf", {})):
        fig.savefig(f"{path_to_candidates}/{model_id}/scatter.{suffix}", bbox_inches='tight', **extra)
    plt.close(fig)
    
    # Violin plot of the raw prediction error per data part. Each tick label
    # also shows the part's MAE, Pearson correlation and mean error.
    df_fig = df.loc[:, ['Error', 'Group']].copy()
    groups_rename = {
        group: f"{group}" + "\n" +
               fr"MAE: {df_metrics.at[group, 'mean_absolute_error']:0.2f}" + "\n" +
               fr"Pearson $\rho$: {df_metrics.at[group, 'pearson_corrcoef']:0.2f}" + "\n" +
               fr"$\langle$Error$\rangle$: {df_metrics.at[group, 'bias']:0.2f}"
        for group in colors_groups
    }
    colors_groups_violin = {groups_rename[group]: colors_groups[group] for group in colors_groups}
    # Assign the replaced column back instead of calling
    # replace(..., inplace=True) on the column selection: the chained in-place
    # form is deprecated in pandas and leaves df_fig unchanged when
    # Copy-on-Write is enabled, which would empty the plot because `order`
    # below lists the renamed labels.
    df_fig['Group'] = df_fig['Group'].replace(groups_rename)
    sns.set_theme(style='whitegrid')
    fig, ax = plt.subplots(figsize=(7, 4))
    # NOTE(review): newer seaborn releases deprecate `scale` and passing
    # `palette` without `hue` - confirm against the installed version.
    violin = sns.violinplot(
        data=df_fig,
        x='Group',
        y='Error',
        palette=colors_groups_violin,
        scale='width',
        order=list(colors_groups_violin.keys()),
        saturation=0.75,
        legend=False,
        ax=ax
    )
    ax.set_xlabel('')
    fig.savefig(f"{path_to_candidates}/{model_id}/violin.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_to_candidates}/{model_id}/violin.pdf", bbox_inches='tight')
    plt.close(fig)
    
    # Violin plot of the bias-corrected prediction error per data part. Each
    # tick label also shows the part's MAE, Pearson correlation and mean error
    # after the correction.
    df_fig = df.loc[:, ['Error Unbiased', 'Group']].copy()
    groups_rename = {
        group: f"{group}" + "\n" +
               fr"MAE: {df_metrics.at[group, 'mean_absolute_error_unbiased']:0.2f}" + "\n" +
               fr"Pearson $\rho$: {df_metrics.at[group, 'pearson_corrcoef_unbiased']:0.2f}" + "\n" +
               fr"$\langle$Error$\rangle$: {df_metrics.at[group, 'bias_unbiased']:0.2f}"
        for group in colors_groups
    }
    colors_groups_violin = {groups_rename[group]: colors_groups[group] for group in colors_groups}
    # Assign the replaced column back instead of calling
    # replace(..., inplace=True) on the column selection: the chained in-place
    # form is deprecated in pandas and leaves df_fig unchanged when
    # Copy-on-Write is enabled, which would empty the plot because `order`
    # below lists the renamed labels.
    df_fig['Group'] = df_fig['Group'].replace(groups_rename)
    sns.set_theme(style='whitegrid')
    fig, ax = plt.subplots(figsize=(7, 4))
    # NOTE(review): newer seaborn releases deprecate `scale` and passing
    # `palette` without `hue` - confirm against the installed version.
    violin = sns.violinplot(
        data=df_fig,
        x='Group',
        y='Error Unbiased',
        palette=colors_groups_violin,
        scale='width',
        order=list(colors_groups_violin.keys()),
        saturation=0.75,
        legend=False,
        ax=ax
    )
    ax.set_xlabel('')
    fig.savefig(f"{path_to_candidates}/{model_id}/violin_unbiased.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_to_candidates}/{model_id}/violin_unbiased.pdf", bbox_inches='tight')
    plt.close(fig)
    
    # Model-level feature attributions via model.explain, saved as a table and
    # as a SHAP bar summary plot. A NotImplementedError raised anywhere in this
    # step is treated as "no explanation available": the step is skipped for
    # this model, but the skip is reported instead of being silently swallowed.
    try:
        explanation = model.explain(data, method=explain_method, baselines=explain_baselines)
        explanation.index = data.index
        explanation.to_excel(f"{path_to_candidates}/{model_id}/explanation.xlsx")

        # A violin-type summary (explain_beeswarm.png/.pdf) used to be produced
        # here as well; it is the same call as below with plot_type="violin",
        # the 'whitegrid' theme and no plot_size.

        sns.set_theme(style='ticks')
        shap.summary_plot(
            shap_values=explanation.loc[:, feats].values,
            features=data.loc[:, feats].values,
            feature_names=feats,
            max_display=explain_n_feats_to_plot,
            plot_type="bar",
            show=False,
            plot_size=[12, 8]
        )
        # With show=False the plot stays on the current pyplot figure; fetch it
        # explicitly rather than relying on summary_plot's return value.
        fig = plt.gcf()
        fig.savefig(f"{path_to_candidates}/{model_id}/explain_bar.png", bbox_inches='tight', dpi=200)
        fig.savefig(f"{path_to_candidates}/{model_id}/explain_bar.pdf", bbox_inches='tight')
        plt.close(fig)

    except NotImplementedError as err:
        print(f"Explanation skipped for model {model_id}: {err}")
    
# Save the rows of the analysed candidates, including the "SHAP top-n" columns
# added in the loop above, with a colour gradient per column.
selected_models = df_sweeps.loc[models_ids, :]
selected_models.style.background_gradient(cmap="RdYlGn_r").to_excel(f"{path_to_candidates}/selected.xlsx")

# Optuna Stand Alone

## Training

In [None]:
# Stand-alone (non PyTorch Tabular) model to tune; its default config is a
# YAML file converted to a plain container.
model_name = 'LightGBM'  # alternative: 'ElasticNet'
model_config = OmegaConf.to_container(
    OmegaConf.load(f"{path_configs}/models/{model_name}.yaml"),
    resolve=True,
)

pathlib.Path(path_ckpts).mkdir(parents=True, exist_ok=True)

# Optimisation targets: one direction per (part, metric) combination.
# Other metric sets that have been used here:
#   [('mean_absolute_error', 'minimize'), ('pearson_corrcoef', 'maximize')]
#   [('mean_absolute_error', 'minimize')]
opt_parts = ['test', 'validation']
opt_metrics = [('pearson_corrcoef', 'maximize')]
opt_directions = [
    str(direction)
    for _part in opt_parts
    for (_metric, direction) in opt_metrics
]

# Study settings.
n_trials = 512
seed = 1337
n_startup_trials = 256
n_ei_candidates = 32

# Filled by train_hyper_opt_sa_regression with one record per trial.
trials_results = []


def objective_sa(trial):
    """Run one Optuna trial through train_hyper_opt_sa_regression."""
    return train_hyper_opt_sa_regression(
        trial=trial,
        trials_results=trials_results,
        opt_metrics=opt_metrics,
        opt_parts=opt_parts,
        model_config_default=model_config,
        train=train,
        validation=validation,
        test=test,
        features=feats_cnt,
        target=feat_trgt,
        save_dir=path_ckpts
    )


tpe_sampler = optuna.samplers.TPESampler(
    n_startup_trials=n_startup_trials,
    n_ei_candidates=n_ei_candidates,
    seed=seed,
)
study = optuna.create_study(
    study_name=model_name,
    sampler=tpe_sampler,
    directions=opt_directions
)
study.optimize(func=objective_sa, n_trials=n_trials, show_progress_bar=True)

# File name records the model, the study settings and the split used.
fn_trials = (
    f"{model_name}_"
    f"trials({n_trials}_{seed}_{n_startup_trials}_{n_ei_candidates})_"
    f"tst({tst_split_id})_"
    f"val({val_fold_id})"
)

df_trials = pd.DataFrame(trials_results)
df_trials['split_id'] = tst_split_id
df_trials['fold_id'] = val_fold_id
styled_trials = df_trials.style.background_gradient(cmap="RdYlGn_r")
styled_trials.to_excel(f"{path_ckpts}/{fn_trials}.xlsx")

## Best models

In [None]:
# Explanation settings (same values as in the PyTorch Tabular analysis cell).
explain_method = "GradientShap"
explain_baselines = "b|1000"
explain_n_feats_to_plot = 25

# Row ids (in the progress table) of the models to analyse.
models_ids = [
    # paste the selected ids here
]
# De-duplicate and order the ids.
models_ids = sorted(set(models_ids))

progress_path = (
    f"{path_ckpts}/"
    f"progress"
    f".xlsx"
)
df_sweeps = pd.read_excel(progress_path, index_col=0)

# Output root for the candidate models; the selected rows are stored up front.
path_to_candidates = f"{path_ckpts}/candidates"
pathlib.Path(path_to_candidates).mkdir(parents=True, exist_ok=True)
selected_rows = df_sweeps.loc[models_ids, :]
selected_rows.to_excel(f"{path_to_candidates}/selected.xlsx")

for model_id in models_ids:
    
    # Locate and load the pickled model of this candidate.
    # NOTE(review): the directory name is built for elastic-net runs
    # ("elastic_net_<alpha>") while the training cell above currently selects
    # 'LightGBM' - confirm the naming matches what the trainer writes.
    model_dir = f"{path_ckpts}/elastic_net_{df_sweeps.at[model_id, 'alpha']:0.4e}"
    # Open the file in a context manager so the handle is closed right after
    # loading (it used to be left open). pickle executes code stored in the
    # file, so only load model files produced by this project.
    with open(f"{model_dir}/model.pkl", 'rb') as model_file:
        model = pickle.load(model_file)
    # Keep a copy of the model files in the candidate's own folder.
    pathlib.Path(f"{path_to_candidates}/{model_id}").mkdir(parents=True, exist_ok=True)
    shutil.copytree(model_dir, f"{path_to_candidates}/{model_id}", dirs_exist_ok=True)

    # Per-sample table: target, data part, prediction and error.
    df = data.loc[:, [feat_trgt]]
    df.loc[train.index, 'Group'] = 'Train'
    df.loc[validation.index, 'Group'] = 'Validation'
    df.loc[test.index, 'Group'] = 'Test'
    df['Prediction'] = model.predict(data[feats_cnt].values)
    df['Error'] = df['Prediction'] - df[feat_trgt]
    # The bias corrector is fitted on the training part only and then applied
    # to the predictions of every sample.
    corrector = LinearBiasCorrector()
    corrector.fit(df.loc[df['Group'] == 'Train', feat_trgt].values, df.loc[df['Group'] == 'Train', 'Prediction'].values)
    df['Prediction Unbiased'] = corrector.predict(df['Prediction'].values)
    df['Error Unbiased'] = df['Prediction Unbiased'] - df[feat_trgt]
    df.to_excel(f"{path_to_candidates}/{model_id}/df.xlsx")
    
    # Plot colour of each data part; the key order is also the plotting order.
    colors_groups = {
        'Train': 'chartreuse',
        'Validation': 'dodgerblue',
        'Test': 'crimson',
    }

    # MAE, Pearson correlation and mean error per data part, before and after
    # the bias correction.
    df_metrics = pd.DataFrame(
        index=list(colors_groups),
        columns=[
            'mean_absolute_error', 'pearson_corrcoef', 'bias',
            'mean_absolute_error_unbiased', 'pearson_corrcoef_unbiased', 'bias_unbiased'
        ]
    )
    for group in colors_groups:
        rows = df.loc[df['Group'] == group]
        pred = torch.from_numpy(rows['Prediction'].values)
        pred_unbiased = torch.from_numpy(rows['Prediction Unbiased'].values)
        real = torch.from_numpy(rows[feat_trgt].values.astype(np.float32))
        group_metrics = {
            'mean_absolute_error': mean_absolute_error(pred, real).numpy(),
            'pearson_corrcoef': pearson_corrcoef(pred, real).numpy(),
            'bias': np.mean(rows['Error'].values),
            'mean_absolute_error_unbiased': mean_absolute_error(pred_unbiased, real).numpy(),
            'pearson_corrcoef_unbiased': pearson_corrcoef(pred_unbiased, real).numpy(),
            'bias_unbiased': np.mean(rows['Error Unbiased'].values),
        }
        for metric_name, metric_value in group_metrics.items():
            df_metrics.at[group, metric_name] = metric_value
    df_metrics.to_excel(f"{path_to_candidates}/{model_id}/metrics.xlsx", index_label="Metrics")
    
    # Common value range of target and prediction (raw and bias-corrected);
    # the plots below use it, padded by 10 %, for both axes.
    raw_cols = df[[feat_trgt, 'Prediction']]
    xy_min, xy_max = raw_cols.min().min(), raw_cols.max().max()
    xy_ptp = xy_max - xy_min

    unbiased_cols = df[[feat_trgt, 'Prediction Unbiased']]
    xy_min_unbiased, xy_max_unbiased = unbiased_cols.min().min(), unbiased_cols.max().max()
    xy_ptp_unbiased = xy_max_unbiased - xy_min_unbiased

    # Raw prediction vs. target: one scatter + regression line per data part,
    # plus the dashed identity line.
    sns.set_theme(style='whitegrid')
    fig, ax = plt.subplots(figsize=(4.5, 4))
    marker_style = {'linewidth': 0.2, 'alpha': 0.75, 'edgecolor': "k", 's': 20}
    for group, color in colors_groups.items():
        sns.regplot(
            data=df.loc[df['Group'] == group, :],
            x=feat_trgt,
            y="Prediction",
            label=group,
            color=color,
            scatter_kws=dict(marker_style),
            ax=ax
        )
    axis_lims = (xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp)
    sns.lineplot(
        x=list(axis_lims),
        y=list(axis_lims),
        linestyle='--',
        color='black',
        linewidth=1.0,
        ax=ax
    )
    ax.set_xlim(*axis_lims)
    ax.set_ylim(*axis_lims)
    ax.set_aspect('equal', adjustable='box')
    for suffix, extra in (("png", {'dpi': 200}), ("pdf", {})):
        fig.savefig(f"{path_to_candidates}/{model_id}/regplot.{suffix}", bbox_inches='tight', **extra)
    plt.close(fig)
    
    # Regression plot of 'Prediction Unbiased' against the target column, one fit per split.
    sns.set_theme(style='whitegrid')
    fig, ax = plt.subplots(figsize=(4.5, 4))
    # Unbiased data range padded by 10% on each side. Both bounds use
    # xy_ptp_unbiased so the axes match the diagonal drawn below; the lower
    # bound previously used the raw-prediction range (xy_ptp) by mistake.
    lim_lo = xy_min_unbiased - 0.1 * xy_ptp_unbiased
    lim_hi = xy_max_unbiased + 0.1 * xy_ptp_unbiased
    for group in colors_groups.keys():
        regplot = sns.regplot(
            data=df.loc[df['Group'] == group, :],
            x=feat_trgt,
            y="Prediction Unbiased",
            label=group,
            color=colors_groups[group],
            scatter_kws=dict(
                linewidth=0.2,
                alpha=0.75,
                edgecolor="k",
                s=20,
            ),
            ax=ax
        )
    # Dashed y = x reference line spanning the padded range.
    bisect = sns.lineplot(
        x=[lim_lo, lim_hi],
        y=[lim_lo, lim_hi],
        linestyle='--',
        color='black',
        linewidth=1.0,
        ax=ax
    )
    ax.set_xlim(lim_lo, lim_hi)
    ax.set_ylim(lim_lo, lim_hi)
    plt.gca().set_aspect('equal', adjustable='box')
    fig.savefig(f"{path_to_candidates}/{model_id}/regplot_unbiased.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_to_candidates}/{model_id}/regplot_unbiased.pdf", bbox_inches='tight')
    plt.close(fig)
    
    # Scatter plot of 'Prediction' against the target column, coloured by split.
    sns.set_theme(style='whitegrid')
    fig, ax = plt.subplots(figsize=(4.5, 4))
    # Data range padded by 10% on each side; shared by the diagonal and both axes.
    lim_lo = xy_min - 0.1 * xy_ptp
    lim_hi = xy_max + 0.1 * xy_ptp
    scatter = sns.scatterplot(
        data=df,
        x=feat_trgt,
        y="Prediction",
        hue="Group",
        hue_order=list(colors_groups),
        palette=colors_groups,
        s=20,
        alpha=0.75,
        edgecolor="k",
        linewidth=0.2,
        ax=ax
    )
    # Dashed y = x reference line spanning the padded range.
    bisect = sns.lineplot(
        x=[lim_lo, lim_hi],
        y=[lim_lo, lim_hi],
        linestyle='--',
        color='black',
        linewidth=1.0,
        ax=ax
    )
    ax.set_xlim(lim_lo, lim_hi)
    ax.set_ylim(lim_lo, lim_hi)
    plt.gca().set_aspect('equal', adjustable='box')
    fig.savefig(f"{path_to_candidates}/{model_id}/scatter.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_to_candidates}/{model_id}/scatter.pdf", bbox_inches='tight')
    plt.close(fig)
    
    # Violin plot of the 'Error' column per split; each tick label carries the
    # split's MAE, Pearson correlation and mean error from df_metrics.
    # Explicit copy so relabelling the 'Group' column below cannot affect df.
    df_fig = df.loc[:, ['Error', 'Group']].copy()
    groups_rename = {
        group: "\n".join([
            f"{group}",
            fr"MAE: {df_metrics.at[group, 'mean_absolute_error']:0.2f}",
            fr"Pearson $\rho$: {df_metrics.at[group, 'pearson_corrcoef']:0.2f}",
            fr"$\langle$Error$\rangle$: {df_metrics.at[group, 'bias']:0.2f}",
        ])
        for group in colors_groups
    }
    colors_groups_violin = {groups_rename[group]: colors_groups[group] for group in colors_groups}
    # Assign the relabelled column back instead of calling replace(inplace=True)
    # on df_fig['Group']: that chained form acts on a temporary Series, which
    # pandas warns about and which leaves df_fig unchanged under copy-on-write
    # (the `order` labels below would then match no rows).
    df_fig['Group'] = df_fig['Group'].replace(groups_rename)
    sns.set_theme(style='whitegrid')
    fig, ax = plt.subplots(figsize=(7, 4))
    violin = sns.violinplot(
        data=df_fig,
        x='Group',
        y='Error',
        palette=colors_groups_violin,
        scale='width',
        order=list(colors_groups_violin.keys()),
        saturation=0.75,
        legend=False,
        ax=ax
    )
    # The tick labels already name the groups, so drop the axis title.
    ax.set_xlabel('')
    fig.savefig(f"{path_to_candidates}/{model_id}/violin.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_to_candidates}/{model_id}/violin.pdf", bbox_inches='tight')
    plt.close(fig)
    
    # Violin plot of the 'Error Unbiased' column per split; each tick label
    # carries the split's unbiased MAE, Pearson correlation and mean error
    # from df_metrics.
    # Explicit copy so relabelling the 'Group' column below cannot affect df.
    df_fig = df.loc[:, ['Error Unbiased', 'Group']].copy()
    groups_rename = {
        group: "\n".join([
            f"{group}",
            fr"MAE: {df_metrics.at[group, 'mean_absolute_error_unbiased']:0.2f}",
            fr"Pearson $\rho$: {df_metrics.at[group, 'pearson_corrcoef_unbiased']:0.2f}",
            fr"$\langle$Error$\rangle$: {df_metrics.at[group, 'bias_unbiased']:0.2f}",
        ])
        for group in colors_groups
    }
    colors_groups_violin = {groups_rename[group]: colors_groups[group] for group in colors_groups}
    # Assign the relabelled column back instead of calling replace(inplace=True)
    # on df_fig['Group']: that chained form acts on a temporary Series, which
    # pandas warns about and which leaves df_fig unchanged under copy-on-write
    # (the `order` labels below would then match no rows).
    df_fig['Group'] = df_fig['Group'].replace(groups_rename)
    sns.set_theme(style='whitegrid')
    fig, ax = plt.subplots(figsize=(7, 4))
    violin = sns.violinplot(
        data=df_fig,
        x='Group',
        y='Error Unbiased',
        palette=colors_groups_violin,
        scale='width',
        order=list(colors_groups_violin.keys()),
        saturation=0.75,
        legend=False,
        ax=ax
    )
    # The tick labels already name the groups, so drop the axis title.
    ax.set_xlabel('')
    fig.savefig(f"{path_to_candidates}/{model_id}/violin_unbiased.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_to_candidates}/{model_id}/violin_unbiased.pdf", bbox_inches='tight')
    plt.close(fig)
    
    def predict_func(X):
        """Return ``corrector.predict`` applied to ``model.predict(X)``.

        ``model`` and ``corrector`` are taken from the enclosing scope. This is
        the prediction callable handed to ``shap.SamplingExplainer``.
        """
        return corrector.predict(model.predict(X))
    
    # Explain predict_func with SHAP's sampling explainer. The feats_cnt columns
    # of `data` serve both as the explainer's data argument and as the samples
    # whose SHAP values are computed.
    shap_background = data[feats_cnt].values
    explainer = shap.SamplingExplainer(predict_func, shap_background)
    print(explainer.expected_value)
    # Extracted a second time so the two calls do not receive the same array object.
    shap_samples = data[feats_cnt].values
    shap_values = explainer.shap_values(shap_samples)
    
    # Bar-type SHAP summary plot (at most 25 features shown), saved as PNG and PDF.
    sns.set_theme(style='ticks')
    bar_summary_options = {
        'shap_values': shap_values,
        'features': data[feats_cnt].values,
        'feature_names': feats_cnt,
        'max_display': 25,
        'plot_type': "bar",
        'show': False,
        'plot_size': [14, 10],
    }
    fig = shap.summary_plot(**bar_summary_options)
    # Saved through pyplot, i.e. whatever figure is current after summary_plot.
    plt.savefig(f"{path_to_candidates}/{model_id}/explain_bar.png", bbox_inches='tight', dpi=200)
    plt.savefig(f"{path_to_candidates}/{model_id}/explain_bar.pdf", bbox_inches='tight')
    # NOTE(review): summary_plot(show=False) appears to return None, in which
    # case this closes the current figure rather than a specific one - confirm.
    plt.close(fig)
    
    # Violin-type SHAP summary plot (at most 25 features shown), saved as PNG and PDF.
    sns.set_theme(style='ticks')
    violin_summary_options = {
        'shap_values': shap_values,
        'features': data[feats_cnt].values,
        'feature_names': feats_cnt,
        'max_display': 25,
        'plot_type': "violin",
        'show': False,
        'plot_size': [14, 10],
    }
    fig = shap.summary_plot(**violin_summary_options)
    # Saved through pyplot, i.e. whatever figure is current after summary_plot.
    plt.savefig(f"{path_to_candidates}/{model_id}/explain_violin.png", bbox_inches='tight', dpi=200)
    plt.savefig(f"{path_to_candidates}/{model_id}/explain_violin.pdf", bbox_inches='tight')
    # NOTE(review): summary_plot(show=False) appears to return None, in which
    # case this closes the current figure rather than a specific one - confirm.
    plt.close(fig)
    