# Debugging autoreload

In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports modified modules before each
# cell runs, so edits to project code (e.g. the `src.*` helpers) apply without a kernel restart.
%load_ext autoreload
%autoreload 2

# Load packages

In [None]:
from pytorch_tabular.utils import load_covertype_dataset
from rich.pretty import pprint
import torch
import shap
from sklearn.model_selection import RepeatedStratifiedKFold
from glob import glob
import ast
import matplotlib.pyplot as plt
import seaborn as sns
import copy
from sklearn.model_selection import train_test_split
import numpy as np
from pytorch_tabular.utils import make_mixed_dataset, print_metrics
from pytorch_tabular import available_models
from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig, GANDALFConfig, TabNetModelConfig, FTTransformerConfig, DANetConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig
from pytorch_tabular.tabular_model_tuner import TabularModelTuner
from torchmetrics.functional.regression import mean_absolute_error, pearson_corrcoef
from pytorch_tabular import MODEL_SWEEP_PRESETS
import pandas as pd
from pytorch_tabular import model_sweep
from src.pt.model_sweep import model_sweep_custom
import warnings
from src.utils.configs import read_parse_config
from src.utils.hash import dict_hash
import pathlib


# Load data

In [None]:
# Experiment selection: which preprocessed dataset / feature selection / target to load.
epi_dataa_type = 'no_harm'

selection_method = 'f_regression'  # 'f_regression' 'spearman'
n_feats = 100

imm = 'CXCL9'

# Outer test split column to use, and the inner cross-validation setup on its 'trn_val' part.
test_split_id = 2
val_n_splits = 4
val_random_state = 1337
val_fold_id = 0

path_data = f"D:/YandexDisk/Work/bbd/immunology/003_EpImAge/{epi_dataa_type}/{selection_method}_{n_feats}/{imm}"
pathlib.Path(f"{path_data}/pytorch_tabular/trials").mkdir(parents=True, exist_ok=True)
path_configs = "D:/Work/bbs/notebooks/immunology/003_EpImAge/immuno_regression_configs"
data = pd.read_excel(f"{path_data}/data.xlsx", index_col=0)
feats = pd.read_excel(f"{path_data}/feats_con.xlsx", index_col=0).index.values.tolist()

# Every Fold_<k> column starts as a copy of the outer split column; rows of the
# 'trn_val' stratification parts below are then relabelled 'trn' or 'val' per fold.
for fold_id in range(val_n_splits):
    data[f"Fold_{fold_id}"] = data[f"Split_{test_split_id}"]

# Each part (restricted to the 'trn_val' split) is split into folds independently,
# so every fold keeps the parts' proportions.
stratify_cat_parts = {
    'ctrl_central': data.index[(data['Status'] == 'Control') & (data['Region'] == 'Central') & (data[f"Split_{test_split_id}"] == 'trn_val')].values,
    'ctrl_yakutia': data.index[(data['Status'] == 'Control') & (data['Region'] == 'Yakutia') & (data[f"Split_{test_split_id}"] == 'trn_val')].values,
    'esrd': data.index[(data['Status'] == 'ESRD') & (data[f"Split_{test_split_id}"] == 'trn_val')].values,
}
# Within a part, stratify on Age binned into `num_bins` equal-width bins whose range is
# padded by 10% of the age span on both sides.
num_bins = 5
for part, ids in stratify_cat_parts.items():
    print(f"{part}: {len(ids)}")
    if len(ids) == 0:
        # np.ptp raises on an empty array, and there is nothing to split anyway.
        continue
    con = data.loc[ids, 'Age'].values
    ptp = np.ptp(con)
    bins = np.linspace(np.min(con) - 0.1 * ptp, np.max(con) + 0.1 * ptp, num_bins + 1)
    binned = np.digitize(con, bins) - 1
    k_fold = RepeatedStratifiedKFold(
        n_splits=val_n_splits,
        n_repeats=1,
        random_state=val_random_state
    )
    # StratifiedKFold-based splitters ignore `groups`, so it is not passed.
    for fold_id, (ids_trn, ids_val) in enumerate(k_fold.split(X=ids, y=binned)):
        data.loc[ids[ids_trn], f"Fold_{fold_id}"] = "trn"
        data.loc[ids[ids_val], f"Fold_{fold_id}"] = "val"

test = data.loc[data[f"Split_{test_split_id}"] == "tst", feats + [f"{imm}_log"]]
train_validation = data.loc[data[f"Split_{test_split_id}"] == "trn_val", feats + [f"{imm}_log"] + [f"Fold_{i}" for i in range(val_n_splits)]]
train_only = data.loc[data[f"Fold_{val_fold_id}"] == "trn", feats + [f"{imm}_log"]]
validation_only = data.loc[data[f"Fold_{val_fold_id}"] == "val", feats + [f"{imm}_log"]]

# Positional (train, validation) row indices into `train_validation`, one pair per fold,
# in the form expected by the tuner's `cv` argument.
cv_indexes = [
    (
        np.where(train_validation[f"Fold_{i}"].values == 'trn')[0],
        np.where(train_validation[f"Fold_{i}"].values == 'val')[0],
    )
    for i in range(val_n_splits)
]

trains = {
    fold_id: data.loc[data[f"Fold_{fold_id}"] == "trn", feats + [f"{imm}_log"]]
    for fold_id in range(val_n_splits)
}
validations = {
    fold_id: data.loc[data[f"Fold_{fold_id}"] == "val", feats + [f"{imm}_log"]]
    for fold_id in range(val_n_splits)
}

## Load non-model configs

In [None]:
# Load the shared (model-independent) configs from YAML first ...
data_config = read_parse_config(f"{path_configs}/DataConfig.yaml", DataConfig)
trainer_config = read_parse_config(f"{path_configs}/TrainerConfig.yaml", TrainerConfig)
optimizer_config = read_parse_config(f"{path_configs}/OptimizerConfig.yaml", OptimizerConfig)

# ... then point them at this experiment's target, feature set and output directory.
data_config['target'] = [f"{imm}_log"]
data_config['continuous_cols'] = feats
trainer_config['checkpoints_path'] = f"{path_data}/pytorch_tabular"

# Model Search Spaces

## CategoryEmbeddingModel Search Space

In [None]:
# Hyper-parameter grid for CategoryEmbeddingModel; only the layer layout is varied.
search_space = {
    "model_config__layers": ["256-128-64", "512-256-128", "512-256-256-128", "32-16", "32-32-16", "64-32-16", "16-8", "32-16-8", "128-64", "128-128", "16-16", "128-128-64"],
    "model_config.head_config__dropout": [0.1],
    "model_config__learning_rate": [0.001],
    "model_config__seed": [1337],
}
model_config = read_parse_config(f"{path_configs}/models/CategoryEmbeddingModelConfig.yaml", CategoryEmbeddingModelConfig)
# Number of points in the full grid (product of the option counts).
grid_size = np.prod([len(p_vals) for p_vals in search_space.values()])
print(grid_size)

## GANDALF Search Space

In [None]:
# Hyper-parameter grid for GANDALF.
search_space = {
    "model_config__gflu_stages": [3, 6, 10, 15, 20, 25, 30, 35],
    "model_config__gflu_dropout": [0.0, 0.1],
    "model_config__gflu_feature_init_sparsity": [0.1, 0.2, 0.3, 0.4, 0.5],
    "model_config.head_config__dropout": [0.1],
    "model_config__learning_rate": [0.001],
    "model_config__seed": [1337],
}
model_config = read_parse_config(f"{path_configs}/models/GANDALFConfig.yaml", GANDALFConfig)
# Number of points in the full grid (product of the option counts).
grid_size = np.prod([len(p_vals) for p_vals in search_space.values()])
print(grid_size)

## TabNetModel Search Space

In [None]:
# Hyper-parameter grid for TabNetModel.
search_space = {
    "model_config__n_d": [4, 8, 16, 32],
    "model_config__n_a": [4, 8, 16, 32],
    "model_config__n_steps": [3, 5, 7],
    "model_config__gamma": [1.2, 1.3, 1.4, 1.5],
    "model_config__n_independent": [1, 2, 3, 4],
    "model_config__n_shared": [1, 2, 3, 4],
    "model_config__mask_type": ["sparsemax", "entmax"],
    "model_config.head_config__dropout": [0.1],
    "model_config__learning_rate": [0.001],
    "model_config__seed": [1337],
}
model_config = read_parse_config(f"{path_configs}/models/TabNetModelConfig.yaml", TabNetModelConfig)
# Number of points in the full grid (product of the option counts).
grid_size = np.prod([len(p_vals) for p_vals in search_space.values()])
print(grid_size)

## DANet Search Space

In [None]:
# Hyper-parameter grid for DANet.
search_space = {
    "model_config__n_layers": [4, 8, 16, 20, 32],
    "model_config__abstlay_dim_1": [16, 32, 64],
    "model_config__k": [4, 5, 6],
    "model_config__dropout_rate": [0.1],
    "model_config.head_config__dropout": [0.1],
    "model_config__learning_rate": [0.001],
    "model_config__seed": [1337],
}
model_config = read_parse_config(f"{path_configs}/models/DANetConfig.yaml", DANetConfig)
# Number of points in the full grid (product of the option counts).
grid_size = np.prod([len(p_vals) for p_vals in search_space.values()])
print(grid_size)

## FTTransformer Search Space

In [None]:
# Hyper-parameter grid for FTTransformer.
search_space = {
    "model_config__num_heads": [4, 8, 16],
    "model_config__num_attn_blocks": [4, 6, 8, 10],
    "model_config__attn_dropout": [0.0, 0.05, 0.1, 0.15, 0.2, 0.25],
    "model_config__add_norm_dropout": [0.1],
    "model_config__ff_dropout": [0.1],
    "model_config.head_config__dropout": [0.1],
    "model_config__learning_rate": [0.001],
    "model_config__seed": [1337],
}
model_config = read_parse_config(f"{path_configs}/models/FTTransformerConfig.yaml", FTTransformerConfig)
# Number of points in the full grid (product of the option counts).
grid_size = np.prod([len(p_vals) for p_vals in search_space.values()])
print(grid_size)

# Grid Search and Random Search

In [None]:
%%capture

strategy = 'random_search'  # 'grid_search'
seed = 1337
n_random_trials = 100
is_cross_validation = True

# With fewer grid points than the random-trial budget, search the whole grid instead.
if strategy == 'random_search' and grid_size < n_random_trials:
    strategy = 'grid_search'

# Tuning runs skip checkpointing / best-weight reloading but use the LR finder.
trainer_config['checkpoints'] = None
trainer_config['load_best'] = False
trainer_config['auto_lr_find'] = True

tuner = TabularModelTuner(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    suppress_lightning_logger=True,
)

# Data-dependent tuner arguments: K-fold CV over train_validation, or a single fixed
# train/validation split (fold `val_fold_id`).
if is_cross_validation:
    tune_data_args = dict(train=train_validation, validation=None, cv=cv_indexes, progress_bar=True)
else:
    tune_data_args = dict(train=train_only, validation=validation_only, cv=None, progress_bar=False)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    result = tuner.tune(
        search_space=search_space,
        metric="mean_absolute_error",
        mode="min",
        strategy=strategy,
        n_trials=n_random_trials,
        return_best_model=True,
        verbose=False,
        random_state=seed,
        **tune_data_args,
    )

trials_path = f"{trainer_config['checkpoints_path']}/trials/{model_config['_model_name']}_{strategy}_{seed}_{optimizer_config['lr_scheduler']}.xlsx"
result.trials_df.to_excel(trials_path)

# Model Sweep Training

## Generate model configs from trial files

In [None]:
# Number of best trials (lowest MAE) taken from each trials file.
n_top_trials = 5

target_models_types = [
    # 'CategoryEmbeddingModel',
    'GANDALF',
    # 'TabNetModel',
    # 'FTTransformer',
    'DANet'
]

common_params = {
    "task": "regression",
}

head_config = LinearHeadConfig(
    layers="",
    activation='ReLU',
    dropout=0.1,
    use_batch_norm=False,
    initialization="kaiming"
).__dict__

# For each supported model type: its config class and the tuned hyper-parameters copied
# from a trial row (column `model_config__<name>`), each with the builtin type it is cast to.
# Rows read from Excel hold numpy scalars (numpy.int64 is not an `int`), so every value is
# cast explicitly and uniformly across model types.
tuned_params_by_model = {
    'CategoryEmbeddingModel': (CategoryEmbeddingModelConfig, {
        'layers': str,
        'learning_rate': float,
        'seed': int,
    }),
    'GANDALF': (GANDALFConfig, {
        'gflu_stages': int,
        'gflu_feature_init_sparsity': float,
        'gflu_dropout': float,
        'learning_rate': float,
        'seed': int,
    }),
    'TabNetModel': (TabNetModelConfig, {
        'n_steps': int,
        'n_shared': int,
        'n_independent': int,
        'n_d': int,
        'n_a': int,
        'mask_type': str,
        'gamma': float,
        'learning_rate': float,
        'seed': int,
    }),
    'FTTransformer': (FTTransformerConfig, {
        'num_heads': int,
        'num_attn_blocks': int,
        'attn_dropout': float,
        'add_norm_dropout': float,
        'ff_dropout': float,
        'learning_rate': float,
        'seed': int,
    }),
    'DANet': (DANetConfig, {
        'n_layers': int,
        'abstlay_dim_1': int,
        'k': int,
        'dropout_rate': float,
        'learning_rate': float,
        'seed': int,
    }),
}

model_list = []
for model_type in target_models_types:
    if model_type not in tuned_params_by_model:
        raise ValueError(f"Unsupported model type: {model_type!r}")
    config_cls, param_types = tuned_params_by_model[model_type]
    # glob() order depends on the filesystem; sorting keeps model_list order reproducible.
    trials_files = sorted(glob(f"{trainer_config['checkpoints_path']}/trials/{model_type}*.xlsx"))
    for trials_file in trials_files:
        df_trials = pd.read_excel(trials_file, index_col=0)
        df_trials = df_trials.sort_values(['mean_absolute_error'], ascending=[True]).head(n_top_trials)
        for _, row in df_trials.iterrows():
            head_config_tmp = copy.deepcopy(head_config)
            head_config_tmp['dropout'] = float(row['model_config.head_config__dropout'])
            model_config = read_parse_config(f"{path_configs}/models/{model_type}Config.yaml", config_cls)
            for param, cast in param_types.items():
                model_config[param] = cast(row[f'model_config__{param}'])
            model_config['head_config'] = head_config_tmp
            model_list.append(config_cls(**model_config))
print(len(model_list))

## Perform model sweep

In [None]:
%%capture
# Train every candidate config on each CV fold with several seeds, saving one
# colour-graded sweep summary per (fold, seed) run.
sweep_seeds = [1337, 55763, 40279, 87571, 234461]
sweep_runs = [(fold, run_seed) for fold in range(val_n_splits) for run_seed in sweep_seeds]
sweep_gradient_columns = [
    "train_loss",
    "validation_loss",
    "test_loss",
    "time_taken",
    "time_taken_per_epoch",
]

for fold_id, seed in sweep_runs:
    # Re-applied before every run: per-run seed, checkpointing on 'valid_loss',
    # best-weight loading and the LR finder.
    trainer_config['seed'] = seed
    trainer_config['checkpoints'] = 'valid_loss'
    trainer_config['load_best'] = True
    trainer_config['auto_lr_find'] = True

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        sweep_df, best_model = model_sweep_custom(
            task="regression",
            train=trains[fold_id],
            validation=validations[fold_id],
            test=test,
            data_config=data_config,
            optimizer_config=optimizer_config,
            trainer_config=trainer_config,
            model_list=model_list,
            common_model_args=common_params,
            metrics=["mean_absolute_error", "pearson_corrcoef"],
            metrics_params=[{}, {}],
            metrics_prob_input=[False, False],
            rank_metric=("mean_absolute_error", "lower_is_better"),
            return_best_model=True,
            seed=seed,
            progress_bar=False,
            verbose=False,
            suppress_lightning_logger=True,
        )

    run_config = best_model.config
    lr_scheduler_name = run_config['lr_scheduler']
    feature_transform = run_config['continuous_feature_transform']
    fn_suffix = f"{fold_id}_{seed}_{lr_scheduler_name}_{feature_transform}"

    styled_sweep = sweep_df.style.background_gradient(subset=sweep_gradient_columns, cmap="RdYlGn_r")
    styled_sweep.to_excel(f"{trainer_config['checkpoints_path']}/sweep_{fn_suffix}.xlsx")

## Save best models

In [None]:
%%capture

# Hand-picked sweep rows (by sweep_df index) to export, per (fold, seed) sweep run.
selected_models_sweeps = [
    {'fold_id': 0, 'seed': 1337, 'lr_scheduler': 'CosineAnnealingWarmRestarts', 'continuous_feature_transform': 'yeo-johnson', 'ids': [11]},
    {'fold_id': 0, 'seed': 40279, 'lr_scheduler': 'CosineAnnealingWarmRestarts', 'continuous_feature_transform': 'yeo-johnson', 'ids': [0]},
    {'fold_id': 0, 'seed': 55763, 'lr_scheduler': 'CosineAnnealingWarmRestarts', 'continuous_feature_transform': 'yeo-johnson', 'ids': [11, 0]},
    {'fold_id': 0, 'seed': 87571, 'lr_scheduler': 'CosineAnnealingWarmRestarts', 'continuous_feature_transform': 'yeo-johnson', 'ids': [13]},
    {'fold_id': 0, 'seed': 234461, 'lr_scheduler': 'CosineAnnealingWarmRestarts', 'continuous_feature_transform': 'yeo-johnson', 'ids': [11]},
    {'fold_id': 1, 'seed': 1337, 'lr_scheduler': 'CosineAnnealingWarmRestarts', 'continuous_feature_transform': 'yeo-johnson', 'ids': [1]},
    {'fold_id': 1, 'seed': 40279, 'lr_scheduler': 'CosineAnnealingWarmRestarts', 'continuous_feature_transform': 'yeo-johnson', 'ids': [10]},
    {'fold_id': 1, 'seed': 55763, 'lr_scheduler': 'CosineAnnealingWarmRestarts', 'continuous_feature_transform': 'yeo-johnson', 'ids': [13]},
    {'fold_id': 1, 'seed': 234461, 'lr_scheduler': 'CosineAnnealingWarmRestarts', 'continuous_feature_transform': 'yeo-johnson', 'ids': [4]},
    {'fold_id': 3, 'seed': 87571, 'lr_scheduler': 'CosineAnnealingWarmRestarts', 'continuous_feature_transform': 'yeo-johnson', 'ids': [7]},
]

explain_method = "GradientShap"
explain_baselines = "b|1000"
explain_n_feats_to_plot = 25

# Sample groups and their plot colours; key order also fixes metric rows and plot category order.
colors_groups = {
    'Train': 'chartreuse',
    'Validation': 'dodgerblue',
    'Test': 'crimson',
}

# Sweep-summary columns shaded with a colour gradient in the exported Excel file.
sweep_gradient_columns = [
    "train_loss",
    "validation_loss",
    "test_loss",
    "time_taken",
    "time_taken_per_epoch",
]


def _group_metrics(df_pred, target_col, groups):
    """Return a DataFrame (one row per group) with MAE, Pearson correlation and mean error.

    ``df_pred`` must contain ``target_col``, 'Prediction', 'Error' and 'Group' columns.
    Values are stored as plain Python floats so they are written to Excel as numbers.
    """
    metrics = {}
    for group in groups:
        part = df_pred.loc[df_pred['Group'] == group]
        pred = torch.from_numpy(part['Prediction'].values)
        real = torch.from_numpy(part[target_col].values)
        metrics[group] = {
            'mean_absolute_error': mean_absolute_error(pred, real).item(),
            'pearson_corrcoef': pearson_corrcoef(pred, real).item(),
            'bias': float(np.mean(part['Error'].values)),
        }
    return pd.DataFrame.from_dict(
        metrics,
        orient='index',
        columns=['mean_absolute_error', 'pearson_corrcoef', 'bias'],
    )


def _plot_scatter(df_pred, target_col, colors, title, path_base):
    """Save a prediction-vs-target scatter plot with a dashed y = x line as PNG and PDF."""
    sns.set_theme(style='whitegrid')
    xy_min = df_pred[[target_col, 'Prediction']].min().min()
    xy_max = df_pred[[target_col, 'Prediction']].max().max()
    xy_ptp = xy_max - xy_min
    # Same padded limits on both axes (10% of the joint range) keep the plot square.
    lims = [xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp]
    fig, ax = plt.subplots(figsize=(4.5, 4))
    sns.scatterplot(
        data=df_pred,
        x=target_col,
        y="Prediction",
        hue="Group",
        palette=colors,
        linewidth=0.2,
        alpha=0.75,
        edgecolor="k",
        s=20,
        hue_order=list(colors.keys()),
        ax=ax
    )
    sns.lineplot(
        x=lims,
        y=lims,
        linestyle='--',
        color='black',
        linewidth=1.0,
        ax=ax
    )
    ax.set_title(title)
    ax.set_xlim(*lims)
    ax.set_ylim(*lims)
    ax.set_aspect('equal', adjustable='box')
    fig.savefig(f"{path_base}/scatter.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_base}/scatter.pdf", bbox_inches='tight')
    plt.close(fig)


def _plot_violin(df_pred, df_metrics, colors, path_base):
    """Save a per-group violin plot of prediction errors, labelled with each group's metrics."""
    groups_rename = {
        group: (
            f"{group}" + "\n"
            + fr"MAE: {df_metrics.at[group, 'mean_absolute_error']:0.2f}" + "\n"
            + fr"Pearson $\rho$: {df_metrics.at[group, 'pearson_corrcoef']:0.2f}" + "\n"
            + fr"$\langle$Error$\rangle$: {df_metrics.at[group, 'bias']:0.2f}"
        )
        for group in colors
    }
    colors_groups_violin = {groups_rename[group]: colors[group] for group in colors}
    df_fig = df_pred.loc[:, ['Error', 'Group']].copy()
    # Assign the result: an in-place replace on a selected column is a chained assignment
    # that pandas copy-on-write does not propagate back to df_fig.
    df_fig['Group'] = df_fig['Group'].replace(groups_rename)
    sns.set_theme(style='whitegrid')
    fig, ax = plt.subplots(figsize=(7, 4))
    sns.violinplot(
        data=df_fig,
        x='Group',
        y='Error',
        palette=colors_groups_violin,
        scale='width',
        order=list(colors_groups_violin.keys()),
        saturation=0.75,
        legend=False,
        ax=ax
    )
    ax.set_xlabel('')
    fig.savefig(f"{path_base}/violin.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_base}/violin.pdf", bbox_inches='tight')
    plt.close(fig)


def _plot_explanation(shap_values, feature_values, feature_names, max_display, path_base):
    """Save SHAP summary plots ("violin" as explain_beeswarm, "bar" as explain_bar) as PNG and PDF."""
    for plot_type, file_stem in (("violin", "explain_beeswarm"), ("bar", "explain_bar")):
        sns.set_theme(style='whitegrid')
        shap.summary_plot(
            shap_values=shap_values,
            features=feature_values,
            feature_names=feature_names,
            max_display=max_display,
            plot_type=plot_type,
            show=False,
        )
        plt.savefig(f"{path_base}/{file_stem}.png", bbox_inches='tight', dpi=200)
        plt.savefig(f"{path_base}/{file_stem}.pdf", bbox_inches='tight')
        # summary_plot draws on the current pyplot figure; close it once saved.
        plt.close()


for sweep in selected_models_sweeps:
    fold_id = sweep['fold_id']
    data_config['continuous_feature_transform'] = sweep['continuous_feature_transform']
    # Keep the saved model's trainer config consistent with the sweep run it came from.
    trainer_config['seed'] = sweep['seed']

    # Use the lr_scheduler recorded for this sweep (not the global optimizer config),
    # so the suffix matches the file written by that sweep run.
    sweep_suffix = f"{fold_id}_{sweep['seed']}_{sweep['lr_scheduler']}_{sweep['continuous_feature_transform']}"
    sweep_df = pd.read_excel(f"{trainer_config['checkpoints_path']}/sweep_{sweep_suffix}.xlsx", index_col=0)

    path_models = f"{trainer_config['checkpoints_path']}/candidates/{sweep_suffix}"
    pathlib.Path(path_models).mkdir(parents=True, exist_ok=True)
    sweep_df.style.background_gradient(
        subset=sweep_gradient_columns, cmap="RdYlGn_r"
    ).to_excel(f"{path_models}/sweep.xlsx")

    for model_id in sweep['ids']:
        path_model = f"{path_models}/{model_id}"

        # Rebuild the model from the sweep's params, load its checkpoint and save it.
        tabular_model = TabularModel(
            data_config=data_config,
            model_config=ast.literal_eval(sweep_df.at[model_id, 'params']),
            optimizer_config=optimizer_config,
            trainer_config=trainer_config,
            verbose=True,
            suppress_lightning_logger=True
        )
        datamodule = tabular_model.prepare_dataloader(
            train=trains[fold_id],
            validation=validations[fold_id],
            seed=sweep['seed'],
        )
        model = tabular_model.prepare_model(datamodule)
        tabular_model._prepare_for_training(model, datamodule)
        tabular_model.load_weights(sweep_df.at[model_id, 'checkpoint'])
        tabular_model.evaluate(test, verbose=False)
        tabular_model.save_model(path_model)

        loaded_model = TabularModel.load_model(path_model)

        # Predictions for every sample, tagged with the group it belongs to for this fold.
        target_col = data_config['target'][0]
        df = data.loc[:, data_config['target']].copy()
        df.loc[trains[fold_id].index, 'Group'] = 'Train'
        df.loc[validations[fold_id].index, 'Group'] = 'Validation'
        df.loc[test.index, 'Group'] = 'Test'
        df['Prediction'] = loaded_model.predict(data)
        df['Error'] = df['Prediction'] - df[target_col]
        df.to_excel(f"{path_model}/df.xlsx")

        df_metrics = _group_metrics(df, target_col, list(colors_groups.keys()))
        df_metrics.to_excel(f"{path_model}/metrics.xlsx", index_label="Metrics")

        title = f"{sweep_df.at[model_id, 'model']} ({sweep_df.at[model_id, '# Params']} params, {sweep_df.at[model_id, 'epochs']} epochs)"
        _plot_scatter(df, target_col, colors_groups, title, path_model)
        _plot_violin(df, df_metrics, colors_groups, path_model)

        try:
            explanation = loaded_model.explain(data, method=explain_method, baselines=explain_baselines)
        except NotImplementedError as err:
            # explain() can raise NotImplementedError (presumably for model types without
            # support for this method); skip the explanation but leave a trace.
            warnings.warn(f"Explanation skipped for {path_model}: {err}")
        else:
            explanation.index = data.index
            explanation.to_excel(f"{path_model}/explanation.xlsx")
            _plot_explanation(
                explanation.loc[:, feats].values,
                data.loc[:, feats].values,
                feats,
                explain_n_feats_to_plot,
                path_model,
            )