# Debugging autoreload

In [None]:
# Load the IPython `autoreload` extension and switch it to mode 2.
%load_ext autoreload
%autoreload 2

# Load packages

In [1]:
from pytorch_tabular.utils import load_covertype_dataset
from rich.pretty import pprint
from sklearn.model_selection import BaseCrossValidator, ParameterGrid, ParameterSampler
import torch
import pickle
import shutil
import shap
from sklearn.model_selection import RepeatedStratifiedKFold
from glob import glob
import ast
import matplotlib.pyplot as plt
import seaborn as sns
import copy
from sklearn.model_selection import train_test_split
import numpy as np
from pytorch_tabular.utils import make_mixed_dataset, print_metrics
from pytorch_tabular import available_models
from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig, GANDALFConfig, TabNetModelConfig, FTTransformerConfig, DANetConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig
from pytorch_tabular.tabular_model_tuner import TabularModelTuner
from torchmetrics.functional.regression import mean_absolute_error, pearson_corrcoef
from pytorch_tabular import MODEL_SWEEP_PRESETS
import pandas as pd
from pytorch_tabular import model_sweep
from src.pt.model_sweep import model_sweep_custom
import warnings
from src.utils.configs import read_parse_config
from src.utils.hash import dict_hash
import pathlib
from tqdm import tqdm
import distinctipy
import matplotlib.patheffects as pe
import matplotlib.colors as mcolors
from statannotations.Annotator import Annotator
from scipy.stats import mannwhitneyu


def make_rgb_transparent(rgb, bg_rgb, alpha):
    """Blend a foreground colour over a background colour.

    Every output channel is ``alpha * fg + (1 - alpha) * bg``. Channels are
    paired with ``zip``, so the result has as many entries as the shorter of
    the two inputs.

    Args:
        rgb: Foreground colour channels.
        bg_rgb: Background colour channels.
        alpha: Weight of the foreground colour.

    Returns:
        A list with the blended channels.
    """
    blended = []
    for fg_channel, bg_channel in zip(rgb, bg_rgb):
        blended.append(alpha * fg_channel + (1 - alpha) * bg_channel)
    return blended


# Load data

In [None]:
# --- Dataset selection --------------------------------------------------------
epi_data_type = 'no_harm'
imm_data_type = 'imp_source(imm)_method(knn)_params(5)' # 'origin' 'imp_source(imm)_method(knn)_params(5)' 'imp_source(imm)_method(miceforest)_params(2)'

selection_method = 'mrmr' # 'f_regression' 'mrmr'
n_feats = 100

# --- Repeated stratified K-fold settings (outer: test, inner: validation) -----
tst_n_splits = 5
tst_n_repeats = 5
tst_random_state = 1337

val_n_splits = 4
val_n_repeats = 4
val_random_state = 1337

path = f"D:/YandexDisk/Work/bbd/immunology/003_EpImAge/{imm_data_type}/{epi_data_type}/{selection_method}_{n_feats}/EpImAge"
path_epi = "D:/YandexDisk/Work/bbd/immunology/003_EpImAge/epi"

data_full = pd.read_excel(f"{path}/data.xlsx", index_col=0)

# Filtering data: keep statuses with at least 10 samples and exclude 'ICU'.
# An explicit copy is taken so that columns added to `data_full` later are
# written to an independent frame rather than to a slice of the Excel table.
status_count = data_full['Status'].value_counts()
statuses_passed = status_count[status_count >= 10].index.values.tolist()
keep_rows = data_full['Status'].isin(statuses_passed) & (data_full['Status'] != 'ICU')
data_full = data_full[keep_rows].copy()
data_full.to_excel(f"{path}/data_filtered.xlsx", index_label='ID')

status_count = data_full['Status'].value_counts()

# Control samples only. Copied because split / prediction columns are later
# assigned into `data`; writing into a boolean-mask slice of `data_full` would
# raise SettingWithCopy warnings and is unreliable under pandas Copy-on-Write.
data = data_full[data_full['Status'] == 'Control'].copy()

# Feature names come from the index of feats.xlsx; the "_log" suffix selects the
# corresponding columns of the data table.
feats = pd.read_excel(f"{path}/feats.xlsx", index_col=0).index.values.tolist()
feats = [f"{f}_log" for f in feats]

gse_preproc = pd.read_excel(f"{path_epi}/preproc.xlsx", index_col=0)

# One background colour per ICD-11 chapter: distinct colours (black and white
# excluded), each blended towards white with weight 0.75 for the colour.
df_groups = pd.read_excel(f"{path_epi}/groups.xlsx", index_col=0)
icd_chpts = np.sort(df_groups['ICD-11 chapter'].unique())
icd_codes = np.sort(df_groups['ICD-11 code'].unique())
colors = distinctipy.get_colors(len(icd_chpts), [mcolors.hex2color(mcolors.CSS4_COLORS['black']), mcolors.hex2color(mcolors.CSS4_COLORS['white'])], rng=1337, pastel_factor=0.5)
colors = [make_rgb_transparent(color, (1, 1, 1), 0.75) for color in colors]
colors_icd_chpts = {icd_chpt: color for icd_chpt, color in zip(icd_chpts, colors)}

# Load stratification

In [None]:
# Load previously generated train / validation / test sample ids.
# NOTE: pickle can execute arbitrary code while loading; only open files that
# were produced by this notebook.
samples_fn = (
    f"{path}/samples_tst({tst_random_state}_{tst_n_splits}_{tst_n_repeats})"
    f"_val({val_random_state}_{val_n_splits}_{val_n_repeats}).pickle"
)
with open(samples_fn, 'rb') as handle:
    samples = pickle.load(handle)

# Sanity check: within every (split, fold) pair the three subsets must not share ids.
for split_id in range(tst_n_splits * tst_n_repeats):
    for fold_id in range(val_n_splits * val_n_repeats):
        test_samples = samples[split_id]['test']
        train_samples = samples[split_id]['trains'][fold_id]
        validation_samples = samples[split_id]['validations'][fold_id]

        trn_set = set(train_samples)
        val_set = set(validation_samples)
        tst_set = set(test_samples)
        intxns = {
            'train_validation': trn_set & val_set,
            'validation_test': val_set & tst_set,
            'train_test': trn_set & tst_set,
        }

        for intxn_name, intxn_samples in intxns.items():
            if intxn_samples:
                print(f"Non-zero {intxn_name} intersection ({len(intxn_samples)}) for {split_id} Split and {fold_id} Fold!")

# Generate stratification

In [None]:
# Quantile edges used to bin 'Age' into strata.
quantiles = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]

# Sample ids of every GSE, in the order of decreasing GSE size given by value_counts().
gse_count = data['GSE'].value_counts()
stratify_cat_parts_all = {
    gse: data.index[data['GSE'] == gse].values
    for gse in gse_count.index
}

In [None]:
# For every GSE, label each sample as "trn_val" or "tst" in one column per
# repeated-K-fold split (Split_0 ... Split_{n_splits * n_repeats - 1}).
pbar = tqdm(stratify_cat_parts_all.items())
for part_all, ids_all in pbar:
    pbar.set_description(f"Processing {part_all} ({len(ids_all)})")

    # Age quantile bin of every sample in this GSE; the bins are the strata.
    quantiles_all = pd.qcut(data.loc[ids_all, 'Age'].values, quantiles, labels=False, duplicates='drop')
    unique_all, counts_all = np.unique(quantiles_all, return_counts=True)

    if max(counts_all) < len(quantiles):
        # No age bin is populated enough: the whole GSE goes to test in every split.
        for split_id in range(tst_n_splits * tst_n_repeats):
            data.loc[ids_all, f"Split_{split_id}"] = "tst"
        continue

    k_fold_all = RepeatedStratifiedKFold(
        n_splits=tst_n_splits,
        n_repeats=tst_n_repeats,
        random_state=tst_random_state,
    )
    splits_all = k_fold_all.split(X=ids_all, y=quantiles_all, groups=quantiles_all)
    for split_id, (ids_trn_val, ids_tst) in enumerate(splits_all):
        split_col = f"Split_{split_id}"
        # The splitter yields positions within `ids_all`; map them back to sample ids.
        data.loc[ids_all[ids_trn_val], split_col] = "trn_val"
        data.loc[ids_all[ids_tst], split_col] = "tst"

In [None]:
# Build `samples`: for every test split, the test / train+validation sample ids,
# the per-fold train and validation ids, and positional CV indexes; then pickle it.
samples = {}
for split_id in range(tst_n_splits * tst_n_repeats):
    # Start every fold column as a copy of the split column, so rows labelled
    # "tst" keep that label and "trn_val" rows are overwritten below.
    for fold_id in range(val_n_splits * val_n_repeats):
        data[f"Split_{split_id}_Fold_{fold_id}"] = data[f"Split_{split_id}"]
    
    samples[split_id] = {
        'test': data.index[data[f"Split_{split_id}"] == "tst"].values,
        'train_validation': data.index[data[f"Split_{split_id}"] == "trn_val"].values,
        'trains': {},
        'validations': {},
    }
    
    # Train+validation ids of this split, grouped by GSE (GSEs without any
    # "trn_val" sample are skipped).
    stratify_cat_parts_trnval = {}
    for gse, count in gse_count.items():
        gse_ids = data.index[(data['GSE'] == gse) & (data[f"Split_{split_id}"] == 'trn_val')].values
        if len(gse_ids) > 0:
            stratify_cat_parts_trnval[gse] = gse_ids

    # Inner repeated stratified K-fold per GSE, stratified by Age quantile bin.
    for part_trnval, ids_trnval in stratify_cat_parts_trnval.items():
        quantiles_trnval = pd.qcut(data.loc[ids_trnval, 'Age'].values, quantiles, labels=False, duplicates='drop')
        unique_trnval, counts_trnval = np.unique(quantiles_trnval, return_counts=True)
        k_fold_trnval = RepeatedStratifiedKFold(
            n_splits=val_n_splits,
            n_repeats=val_n_repeats,
            random_state=val_random_state
        )
        splits_trnval = k_fold_trnval.split(X=ids_trnval, y=quantiles_trnval, groups=quantiles_trnval)
        for fold_id, (ids_trn, ids_val) in enumerate(splits_trnval):
            # The splitter yields positions within `ids_trnval`; map them back to sample ids.
            data.loc[ids_trnval[ids_trn], f"Split_{split_id}_Fold_{fold_id}"] = "trn"
            data.loc[ids_trnval[ids_val], f"Split_{split_id}_Fold_{fold_id}"] = "val"
         
    # Collect the sample ids of each fold across all GSEs.
    for fold_id in range(val_n_splits * val_n_repeats):
        samples[split_id]['trains'][fold_id] = data.index[data[f"Split_{split_id}_Fold_{fold_id}"] == "trn"].values
        samples[split_id]['validations'][fold_id] = data.index[data[f"Split_{split_id}_Fold_{fold_id}"] == "val"].values

    # One (train positions, validation positions) pair per fold, where positions
    # are counted within the "trn_val" rows of this split in `data` order.
    samples[split_id]['cv_indexes'] = [
        (
            np.where(data.index[data[f"Split_{split_id}"] == "trn_val"].isin(data.index[(data[f"Split_{split_id}"] == "trn_val") & (data[f"Split_{split_id}_Fold_{i}"] == 'trn')]))[0],
            np.where(data.index[data[f"Split_{split_id}"] == "trn_val"].isin(data.index[(data[f"Split_{split_id}"] == "trn_val") & (data[f"Split_{split_id}_Fold_{i}"] == 'val')]))[0],
        )
        for i in range(val_n_splits * val_n_repeats)
    ]

# Persist the stratification; the file name encodes the seeds and split settings.
with open(f"{path}/samples_tst({tst_random_state}_{tst_n_splits}_{tst_n_repeats})_val({val_random_state}_{val_n_splits}_{val_n_repeats}).pickle", 'wb') as handle:
    pickle.dump(samples, handle, protocol=pickle.HIGHEST_PROTOCOL)

## Check samples intersection

In [None]:
# Verify that train, validation and test ids are pairwise disjoint for every
# (split, fold) combination; any overlap is reported on stdout.
for split_id in range(tst_n_splits * tst_n_repeats):
    for fold_id in range(val_n_splits * val_n_repeats):
        test_samples = samples[split_id]['test']
        train_samples = samples[split_id]['trains'][fold_id]
        validation_samples = samples[split_id]['validations'][fold_id]

        trn_set = set(train_samples)
        val_set = set(validation_samples)
        tst_set = set(test_samples)
        intxns = {
            'train_validation': trn_set & val_set,
            'validation_test': val_set & tst_set,
            'train_test': trn_set & tst_set,
        }

        for intxn_name, intxn_samples in intxns.items():
            if intxn_samples:
                print(f"Non-zero {intxn_name} intersection ({len(intxn_samples)}) for {split_id} Split and {fold_id} Fold!")

# Model Sweep Training

## Load non-model configs

In [None]:
path_configs = "D:/Work/bbs/notebooks/immunology/003_EpImAge/age_regression_configs"

# Read the three non-model configs from YAML.
data_config = read_parse_config(f"{path_configs}/DataConfig.yaml", DataConfig)
trainer_config = read_parse_config(f"{path_configs}/TrainerConfig.yaml", TrainerConfig)
optimizer_config = read_parse_config(f"{path_configs}/OptimizerConfig.yaml", OptimizerConfig)

# Notebook-specific overrides: regression target, feature columns, checkpoint directory.
data_config['target'] = ['Age']
data_config['continuous_cols'] = feats
trainer_config['checkpoints_path'] = f"{path}/pytorch_tabular"

# Learning-rate finder settings, passed to the sweep below.
lr_find_min_lr, lr_find_max_lr = 1e-8, 1
lr_find_num_training = 256
lr_find_mode = "exponential"
lr_find_early_stop_threshold = 4.0

## Models Search Spaces

### GANDALF Search Space

In [None]:
# Hyper-parameter grid for GANDALF; keys follow the "<config>__<parameter>" pattern.
search_space = {
    "model_config__gflu_stages": [4, 6, 8],
    "model_config__gflu_dropout": [0.1],
    "model_config__gflu_feature_init_sparsity": [0.3],
    "model_config.head_config__dropout": [0.1],
    "model_config__learning_rate": [0.001],
    "model_config__seed": [451, 1408],
}
# Number of grid points = product of the option counts.
grid_size = np.prod([len(p_vals) for p_vals in search_space.values()])
print(grid_size)

# Template head settings as a plain dict; copied and adjusted per grid point.
head_config = vars(LinearHeadConfig(
    layers="",
    activation='ReLU',
    dropout=0.1,
    use_batch_norm=False,
    initialization="kaiming"
))

# Model-config fields taken directly from the grid (same order as they are set).
gandalf_param_names = (
    'gflu_stages',
    'gflu_feature_init_sparsity',
    'gflu_dropout',
    'learning_rate',
    'seed',
)

# One GANDALFConfig per grid point, each starting from the YAML defaults.
model_list = []
for params in ParameterGrid(search_space):
    head_config_tmp = copy.deepcopy(head_config)
    head_config_tmp['dropout'] = params['model_config.head_config__dropout']
    model_config = read_parse_config(f"{path_configs}/models/GANDALFConfig.yaml", GANDALFConfig)
    for param_name in gandalf_param_names:
        model_config[param_name] = params[f"model_config__{param_name}"]
    model_config['head_config'] = head_config_tmp
    model_list.append(GANDALFConfig(**model_config))
    

## Perform model sweep

In [None]:
%%capture

# Run the model sweep for every (test split, validation fold) pair and keep a
# running Excel table of all results.
common_params = {
    "task": "regression",
}

seed = 451

# Result file name; it depends only on the grid size and the split settings.
fn_suffix = (f"models({len(model_list)})_"
             f"tst({tst_random_state}_{tst_n_splits}_{tst_n_repeats})_"
             f"val({val_random_state}_{val_n_splits}_{val_n_repeats})")

dfs_result = []
for split_id, split_dict in samples.items():
    for fold_id in split_dict['trains']:
        test = data.loc[split_dict['test'], feats + ["Age"]]
        train = data.loc[split_dict['trains'][fold_id], feats + ["Age"]]
        validation = data.loc[split_dict['validations'][fold_id], feats + ["Age"]]

        trainer_config['seed'] = seed
        trainer_config['checkpoints'] = 'valid_loss'
        trainer_config['load_best'] = True
        trainer_config['auto_lr_find'] = True

        # Warnings raised during the sweep are suppressed for this call only.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            sweep_df, best_model = model_sweep_custom(
                task="regression",
                train=train,
                validation=validation,
                test=test,
                data_config=data_config,
                optimizer_config=optimizer_config,
                trainer_config=trainer_config,
                model_list=model_list,
                common_model_args=common_params,
                metrics=["mean_absolute_error", "pearson_corrcoef"],
                metrics_params=[{}, {}],
                metrics_prob_input=[False, False],
                rank_metric=("mean_absolute_error", "lower_is_better"),
                return_best_model=True,
                seed=seed,
                progress_bar=False,
                verbose=False,
                suppress_lightning_logger=True,
                min_lr=lr_find_min_lr,
                max_lr=lr_find_max_lr,
                num_training=lr_find_num_training,
                mode=lr_find_mode,
                early_stop_threshold=lr_find_early_stop_threshold,
            )

        # Bookkeeping columns identifying the run.
        sweep_df['seed'] = seed
        sweep_df['split_id'] = split_id
        sweep_df['fold_id'] = fold_id
        # Flag models whose train loss is above the test or validation loss.
        sweep_df["train_more"] = False
        sweep_df.loc[(sweep_df["train_loss"] > sweep_df["test_loss"]) | (sweep_df["train_loss"] > sweep_df["validation_loss"]), "train_more"] = True
        sweep_df["validation_test_mean_loss"] = (sweep_df["validation_loss"] + sweep_df["test_loss"]) / 2.0
        sweep_df["train_validation_test_mean_loss"] = (sweep_df["train_loss"] + sweep_df["validation_loss"] + sweep_df["test_loss"]) / 3.0

        dfs_result.append(sweep_df)

        # Rewrite the cumulative table after every fold so partial results
        # survive an interrupted run.
        df_result = pd.concat(dfs_result, ignore_index=True)
        path_result = f"{trainer_config['checkpoints_path']}/{fn_suffix}.xlsx"
        try:
            df_result.style.background_gradient(
                subset=[
                    "train_loss",
                    "validation_loss",
                    "test_loss",
                    "time_taken",
                    "time_taken_per_epoch"
                ], cmap="RdYlGn_r"
            ).to_excel(path_result)
        except PermissionError as err:
            # Best effort: the sweep continues and the next fold retries the
            # write, but the failure is reported instead of being dropped.
            warnings.warn(f"Could not write sweep results to {path_result}: {err}")

# Best models analysis

In [None]:
# --- Colours per GSE ------------------------------------------------------------
gse_count = data['GSE'].value_counts()
gses = gse_count.index.values
gse_ids = {gse: data.index[data['GSE'] == gse].values for gse in gses}
colors = distinctipy.get_colors(len(gses), [mcolors.hex2color(mcolors.CSS4_COLORS['white']), mcolors.hex2color(mcolors.CSS4_COLORS['black'])], rng=1337)
colors_gse = dict(zip(gses, colors))

# --- Status labels and colours --------------------------------------------------
# Spaces and hyphens become line breaks so that long status names wrap on plots.
# NOTE: the mapping is built from the current 'Status' values, so re-running this
# cell after the rename yields a mapping keyed by the already-renamed labels.
statuses_rename = {x: x.replace(' ', '\n').replace('-', '\n') for x in data_full['Status'].value_counts().index.values}

data_full['Status'] = data_full['Status'].replace(statuses_rename)
status_count = data_full['Status'].value_counts()
statuses = status_count.index.values

# 'Control' is always dodgerblue; every other status gets its own distinct colour.
# Colours are paired with the non-control statuses directly, so the assignment
# does not depend on where 'Control' sits in the frequency-ordered status list.
control_color = mcolors.hex2color(mcolors.CSS4_COLORS['dodgerblue'])
non_control_statuses = [status for status in statuses if status != 'Control']
colors = distinctipy.get_colors(len(non_control_statuses), [control_color, mcolors.hex2color(mcolors.CSS4_COLORS['white']), mcolors.hex2color(mcolors.CSS4_COLORS['black'])], rng=1337)
colors_status = {'Control': control_color}
colors_status.update(zip(non_control_statuses, colors))

# --- Mosaic layout for the violin figure ----------------------------------------
# Each row of `df_groups` is one panel; its label is repeated once per status so
# that the panel width is proportional to the number of violins it holds.
mosaic_violins = []
mosaic_rows = np.sort(df_groups['Row on violin plot'].unique())
for mosaic_row in mosaic_rows:
    df_mosaic_row = df_groups[df_groups['Row on violin plot'] == mosaic_row].sort_values(by=['ICD-11 chapter', 'ICD-11 code', 'GSE'], ascending=[True, True, True])
    mosaic_row_labels = []
    for plot_id, plot_row in df_mosaic_row.iterrows():
        n_violins = len(ast.literal_eval(plot_row['Statuses']))
        mosaic_row_labels += [plot_id] * n_violins
    mosaic_violins.append(mosaic_row_labels)

# Pad shorter rows with a per-row placeholder panel so that all rows have equal
# length; the placeholder names are remembered so those panels can be hidden.
max_mosaic_row = max([len(x) for x in mosaic_violins])
violons_empty_panels = set()
for row_id, row in enumerate(mosaic_violins):
    n_missing = max_mosaic_row - len(row)
    if n_missing > 0:
        empty_label = f'Empty row {row_id}'
        violons_empty_panels.add(empty_label)
        row.extend([empty_label] * n_missing)
        
# Whether to compute and plot feature attributions for every selected model.
is_explain = False

# Row ids (index of the sweep table) of the models to analyse.
# Other candidates: 199, 600, 858, 966, 1692, 1724, 1725, 1950, 1952
models_ids = [
    145,
]

models_ids = sorted(set(models_ids))

# Locate the sweep table written by the sweep step.
n_models = 6
fn_sweep = (
    f"models({n_models})_"
    f"tst({tst_random_state}_{tst_n_splits}_{tst_n_repeats})_"
    f"val({val_random_state}_{val_n_splits}_{val_n_repeats})"
)
df_sweeps = pd.read_excel(f"{path}/pytorch_tabular/{fn_sweep}.xlsx", index_col=0)
path_models = f"{path}/pytorch_tabular/candidates/{fn_sweep}"
pathlib.Path(path_models).mkdir(parents=True, exist_ok=True)

# Save the full sweep and the selected rows, both with a colour gradient on
# the loss and timing columns.
gradient_cols = [
    "train_loss",
    "validation_loss",
    "test_loss",
    "time_taken",
    "time_taken_per_epoch",
]
styled_sweep = df_sweeps.style.background_gradient(subset=gradient_cols, cmap="RdYlGn_r")
styled_sweep.to_excel(f"{path_models}/sweep.xlsx")
styled_selected = df_sweeps.loc[models_ids, :].style.background_gradient(subset=gradient_cols, cmap="RdYlGn_r")
styled_selected.to_excel(f"{path_models}/selected.xlsx")

# Settings for the optional explanation step.
explain_method = "GradientShap"
explain_baselines = "b|1000"
explain_n_feats_to_plot = 25

# Result tables, one row per selected model.
df_models_metrics = pd.DataFrame(index=models_ids)
df_models_check = pd.DataFrame(index=models_ids)

# One table per metric: rows are GSEs, columns are 'Count' plus one per model.
df_gses_models_metrics = {}
for md in ['MAE All', 'MAE Test', 'Rho All', 'Rho Test', 'Bias All', 'Bias Test']:
    gse_table = pd.DataFrame(index=gses, columns=['Count'] + list(models_ids))
    gse_table.loc[gses, 'Count'] = gse_count[gses]
    df_gses_models_metrics[md] = gse_table

# Per-model analysis. For every selected sweep row: reload the checkpointed
# model, predict on the full and control cohorts, recompute metrics, and write
# tables and figures into {path_models}/{model_id}/.
for model_id in models_ids:
    print(model_id)

    # Metrics stored by the sweep; compared later with the recomputed ones.
    df_models_check.at[model_id, 'Train MAE Before'] = df_sweeps.at[model_id, 'train_mean_absolute_error']
    df_models_check.at[model_id, 'Validation MAE Before'] = df_sweeps.at[model_id, 'validation_mean_absolute_error']
    df_models_check.at[model_id, 'Test MAE Before'] = df_sweeps.at[model_id, 'test_mean_absolute_error']
    df_models_check.at[model_id, 'Train Rho Before'] = df_sweeps.at[model_id, 'train_pearson_corrcoef']
    df_models_check.at[model_id, 'Validation Rho Before'] = df_sweeps.at[model_id, 'validation_pearson_corrcoef']
    df_models_check.at[model_id, 'Test Rho Before'] = df_sweeps.at[model_id, 'test_pearson_corrcoef']

    # Recover the train / validation / test subsets this model was fitted on.
    split_id = df_sweeps.at[model_id, 'split_id']
    fold_id = df_sweeps.at[model_id, 'fold_id']
    split_dict = samples[split_id]

    test = data.loc[split_dict['test'], feats + ["Age"]]
    train = data.loc[split_dict['trains'][fold_id], feats + ["Age"]]
    validation = data.loc[split_dict['validations'][fold_id], feats + ["Age"]]

    # The saved model lives in a directory named like the checkpoint file without its suffix.
    model_dir = str(pathlib.Path(df_sweeps.at[model_id, 'checkpoint']).parent).replace('\\', '/') + '/' + pathlib.Path(df_sweeps.at[model_id, 'checkpoint']).stem
    model = TabularModel.load_model(model_dir)
    pathlib.Path(f"{path_models}/{model_id}").mkdir(parents=True, exist_ok=True)
    shutil.copytree(model_dir, f"{path_models}/{model_id}", dirs_exist_ok=True)

    # Predictions and signed errors for all statuses.
    data_full['Prediction'] = model.predict(data_full)
    data_full['Error'] = data_full['Prediction'] - data_full['Age']
    data_full[['GPL', 'GSE', 'Age', 'Sex', 'Status', 'Prediction', 'Error']].to_excel(f"{path_models}/{model_id}/df_full.xlsx")

    # Mosaic violins plot: one panel per row of df_groups, error by status.
    sns.set_theme(style='ticks')
    fig_height = 5 * len(mosaic_violins)
    fig_width = 1.4 * max_mosaic_row
    fig, axs = plt.subplot_mosaic(mosaic=mosaic_violins, figsize=(fig_width, fig_height), gridspec_kw={}, sharey=False, sharex=False)

    for plot_id, plot_row in df_groups.iterrows():

        plot_statuses = [statuses_rename[x] for x in ast.literal_eval(plot_row['Statuses'])]
        plot_groups = [(statuses_rename[x[0]], statuses_rename[x[1]]) for x in ast.literal_eval(plot_row['Groups'])]

        # Explicit copy: the 'Status' column is overwritten below.
        df_plot = data_full.loc[(data_full['GSE'] == plot_row['GSE']) & (data_full['Status'].isin(plot_statuses)), ['Status', 'Error']].copy()
        plot_status_count = df_plot['Status'].value_counts()
        # Extend every status label with its sample count and mean error.
        plot_statuses_rename = {}
        for x in plot_statuses:
            plot_statuses_rename[x] = x + f"\nCount: {plot_status_count[x]}\nBias: {np.mean(df_plot.loc[df_plot['Status'] == x, 'Error']):0.1f}"
        plot_statuses = [plot_statuses_rename[x] for x in plot_statuses]
        plot_groups = [(plot_statuses_rename[x[0]], plot_statuses_rename[x[1]]) for x in plot_groups]
        df_plot['Status'] = df_plot['Status'].replace(plot_statuses_rename)
        colors_plot_status = {plot_statuses_rename[x]: colors_status[x] for x in plot_statuses_rename}

        # Two-sided Mann-Whitney U test for every requested pair of statuses.
        pval_formatted = []
        for plot_group in plot_groups:
            stat, pval = mannwhitneyu(
                df_plot.loc[df_plot["Status"] == plot_group[0], "Error"].values,
                df_plot.loc[df_plot["Status"] == plot_group[1], "Error"].values,
                alternative="two-sided",
            )
            pval_formatted.append(f"{pval:.1e}")

        violinplot = sns.violinplot(
            data=df_plot,
            x='Status',
            y='Error',
            palette=colors_plot_status,
            scale='width',
            order=plot_statuses,
            saturation=0.75,
            legend=False,
            ax=axs[plot_id]
        )
        annotator = Annotator(
            ax=axs[plot_id],
            pairs=plot_groups,
            data=df_plot,
            x="Status",
            y="Error",
            order=plot_statuses,
        )
        annotator.set_custom_annotations(pval_formatted)
        annotator.configure(loc='inside', verbose=0)
        annotator.annotate()

        axs[plot_id].set_xlabel('')
        axs[plot_id].set_title(f"{plot_row['GSE']} ({df_plot.shape[0]})")
        axs[plot_id].set_facecolor(colors_icd_chpts[plot_row['ICD-11 chapter']])

    # Hide the placeholder panels that only pad short mosaic rows.
    for empty_panel in violons_empty_panels:
        axs[empty_panel].axis('off')
    fig.tight_layout()
    fig.savefig(f"{path_models}/{model_id}/violins_Status.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_models}/{model_id}/violins_Status.pdf", bbox_inches='tight')
    plt.close(fig)

    # Control cohort: label every sample with its subset and predict.
    data['Group'] = ''
    data.loc[train.index, 'Group'] = 'Train'
    data.loc[validation.index, 'Group'] = 'Validation'
    data.loc[test.index, 'Group'] = 'Test'
    data['Prediction'] = model.predict(data)
    data['Error'] = data['Prediction'] - data['Age']
    data[['GPL', 'GSE', 'Age', 'Sex', 'Status', 'Group', 'Prediction', 'Error']].to_excel(f"{path_models}/{model_id}/data_ctrl.xlsx")

    # Metrics over all control samples.
    pred = torch.from_numpy(data.loc[:, 'Prediction'].values)
    real = torch.from_numpy(data.loc[:, 'Age'].values)
    df_models_metrics.at[model_id, 'MAE\nAll'] = mean_absolute_error(pred, real).numpy()
    df_models_metrics.at[model_id, 'Rho\nAll'] = pearson_corrcoef(pred, real).numpy()
    df_models_metrics.at[model_id, 'Bias\nAll'] = np.mean(data.loc[:, 'Error'].values)

    # Metrics per subset (bias = mean signed error).
    colors_groups = {
        'Train': 'chartreuse',
        'Validation': 'dodgerblue',
        'Test': 'crimson',
    }
    df_metrics = pd.DataFrame(
        index=list(colors_groups.keys()),
        columns=['mean_absolute_error', 'pearson_corrcoef', 'bias']
    )
    for group in colors_groups.keys():

        pred = torch.from_numpy(data.loc[data['Group'] == group, 'Prediction'].values)
        real = torch.from_numpy(data.loc[data['Group'] == group, 'Age'].values)

        df_metrics.at[group, 'mean_absolute_error'] = mean_absolute_error(pred, real).numpy()
        df_metrics.at[group, 'pearson_corrcoef'] = pearson_corrcoef(pred, real).numpy()
        df_metrics.at[group, 'bias'] = np.mean(data.loc[data['Group'] == group, 'Error'].values)

        df_models_metrics.at[model_id, f'MAE\n{group}'] = df_metrics.at[group, 'mean_absolute_error']
        df_models_metrics.at[model_id, f'Rho\n{group}'] = df_metrics.at[group, 'pearson_corrcoef']
        df_models_metrics.at[model_id, f'Bias\n{group}'] = df_metrics.at[group, 'bias']

        df_models_check.at[model_id, f'{group} MAE After'] = df_metrics.at[group, 'mean_absolute_error']
        df_models_check.at[model_id, f'{group} Rho After'] = df_metrics.at[group, 'pearson_corrcoef']

    df_metrics.to_excel(f"{path_models}/{model_id}/metrics.xlsx", index_label="Metrics")

    # Age vs. prediction scatter with the identity line; both axes share one range.
    sns.set_theme(style='whitegrid')
    fig, ax = plt.subplots(figsize=(4.5, 4))
    scatter = sns.scatterplot(
        data=data,
        x='Age',
        y="Prediction",
        hue="Group",
        palette=colors_groups,
        linewidth=0.2,
        alpha=0.75,
        edgecolor="k",
        s=20,
        hue_order=list(colors_groups.keys()),
        ax=ax
    )
    xy_min = data[['Age', 'Prediction']].min().min()
    xy_max = data[['Age', 'Prediction']].max().max()
    xy_ptp = xy_max - xy_min
    bisect = sns.lineplot(
        x=[xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp],
        y=[xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp],
        linestyle='--',
        color='black',
        linewidth=1.0,
        ax=ax
    )
    ax.set_title(f"{df_sweeps.at[model_id, 'model']} ({df_sweeps.at[model_id, '# Params']} params, {df_sweeps.at[model_id, 'epochs']} epochs)")
    ax.set_xlim(xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp)
    ax.set_ylim(xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp)
    ax.set_aspect('equal', adjustable='box')
    fig.savefig(f"{path_models}/{model_id}/scatter.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_models}/{model_id}/scatter.pdf", bbox_inches='tight')
    plt.close(fig)

    # Metrics per GSE, over all its control samples and per subset.
    for gse, ids in gse_ids.items():
        pred = torch.from_numpy(data.loc[ids, 'Prediction'].values)
        real = torch.from_numpy(data.loc[ids, 'Age'].values)
        df_models_metrics.at[model_id, f'MAE\n{gse}\nAll'] = mean_absolute_error(pred, real).numpy()
        df_models_metrics.at[model_id, f'Rho\n{gse}\nAll'] = pearson_corrcoef(pred, real).numpy()
        df_models_metrics.at[model_id, f'Bias\n{gse}\nAll'] = np.mean(data.loc[ids, 'Error'].values)
        df_gses_models_metrics['MAE All'].at[gse, model_id] = df_models_metrics.at[model_id, f'MAE\n{gse}\nAll']
        df_gses_models_metrics['Rho All'].at[gse, model_id] = df_models_metrics.at[model_id, f'Rho\n{gse}\nAll']
        df_gses_models_metrics['Bias All'].at[gse, model_id] = df_models_metrics.at[model_id, f'Bias\n{gse}\nAll']
        for group in colors_groups.keys():
            gse_group_ids = data.index[(data['Group'] == group) & (data['GSE'] == gse)].values
            if len(gse_group_ids) > 0:
                pred = torch.from_numpy(data.loc[gse_group_ids, 'Prediction'].values)
                real = torch.from_numpy(data.loc[gse_group_ids, 'Age'].values)
                df_models_metrics.at[model_id, f'MAE\n{gse}\n{group}'] = mean_absolute_error(pred, real).numpy()
                df_models_metrics.at[model_id, f'Rho\n{gse}\n{group}'] = pearson_corrcoef(pred, real).numpy()
                df_models_metrics.at[model_id, f'Bias\n{gse}\n{group}'] = np.mean(data.loc[gse_group_ids, 'Error'].values)
                if group == 'Test':
                    df_gses_models_metrics['MAE Test'].at[gse, model_id] = df_models_metrics.at[model_id, f'MAE\n{gse}\n{group}']
                    df_gses_models_metrics['Rho Test'].at[gse, model_id] = df_models_metrics.at[model_id, f'Rho\n{gse}\n{group}']
                    df_gses_models_metrics['Bias Test'].at[gse, model_id] = df_models_metrics.at[model_id, f'Bias\n{gse}\n{group}']

    # Error distribution per GSE, once for all control samples and once for test only.
    ids_parts = {
        'All': data.index.values,
        'Test': data.index[data['Group'] == 'Test'].values,
    }
    for part, ids_part in ids_parts.items():
        # Errors in GSEs
        df_fig = data.loc[ids_part, ['Error', 'GSE']].copy()
        df_fig['GSE'] = pd.Categorical(df_fig.GSE, categories=gses, ordered=True)
        df_fig = df_fig.sort_values('GSE')
        # Facet labels: GSE name and size, its metrics for this part, and its preprocessing note.
        gses_rename = {
            gse: f"{gse} ({gse_count[gse]})" + "\n" +
                   fr"MAE: {df_gses_models_metrics[f'MAE {part}'].at[gse, model_id]:0.2f}" + "\n" +
                   fr"Pearson $\rho$: {df_gses_models_metrics[f'Rho {part}'].at[gse, model_id]:0.2f}" + "\n" +
                   fr"Bias: {df_gses_models_metrics[f'Bias {part}'].at[gse, model_id]:0.2f}" + "\n" +
                   f"{gse_preproc.at[gse, 'Preproc']}"
            for gse in colors_gse
        }
        gse_colors_grid = {gses_rename[gse]: colors_gse[gse] for gse in colors_gse}
        # Rename the categories (order is kept) and assign the result back; an
        # in-place replace on the selected column is a chained assignment that
        # does not reach `df_fig` under pandas Copy-on-Write.
        df_fig['GSE'] = df_fig['GSE'].cat.rename_categories(gses_rename)
        sns.set_theme(style="whitegrid", rc={"axes.facecolor": (0, 0, 0, 0)})
        g = sns.FacetGrid(df_fig, row="GSE", hue="GSE", aspect=8, height=0.5, palette=gse_colors_grid)
        g.map(
            sns.kdeplot,
            'Error',
            fill=True,
            alpha=1.0,
            linewidth=0.5
        )
        g.map(
            sns.kdeplot,
            'Error',
            color="black",
            linewidth=1.0,
        )
        # g.refline(y=0, linewidth=2.0, linestyle="-", color=None, clip_on=False)
        def label(x, color, label):
            # Write the facet label inside the current axes, in the facet colour.
            ax = plt.gca()
            ax.text(-0.15, 0.2, label, size=4, fontweight="light", color=color, ha="left", va="center", transform=ax.transAxes, path_effects=[pe.withStroke(linewidth=0.5, foreground="black")])
        g.map(label, 'Error')
        # Set the subplots to overlap
        g.figure.subplots_adjust(hspace=0.0)
        # Remove axes details that don't play well with overlap
        g.set_titles("")
        g.set(yticks=[], ylabel="")
        g.despine(bottom=True, left=True)
        # Save
        g.savefig(f"{path_models}/{model_id}/GSEs_Error_{part}.png", bbox_inches='tight', dpi=200)
        g.savefig(f"{path_models}/{model_id}/GSEs_Error_{part}.pdf", bbox_inches='tight')
        plt.close(g.figure)

    # sns.set_theme(style='whitegrid')
    # fig, ax = plt.subplots(figsize=(4.5, 4))
    # kdeplot = sns.kdeplot(
    #     data=data,
    #     x='Age',
    #     y='Prediction',
    #     fill=True,
    #     cbar=False,
    #     color='paleturquoise',
    #     cut=0,
    #     legend=False,
    #     ax=ax
    # )
    # bisect = sns.lineplot(
    #     x=[xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp],
    #     y=[xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp],
    #     linestyle='--',
    #     color='black',
    #     linewidth=1.0,
    #     ax=ax
    # )
    # ax.set_title(f"{df_sweeps.at[model_id, 'model']} ({df_sweeps.at[model_id, '# Params']} params, {df_sweeps.at[model_id, 'epochs']} epochs)")
    # ax.set_xlim(xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp)
    # ax.set_ylim(xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp)
    # plt.gca().set_aspect('equal', adjustable='box')
    # fig.savefig(f"{path_models}/{model_id}/kde.png", bbox_inches='tight', dpi=200)
    # fig.savefig(f"{path_models}/{model_id}/kde.pdf", bbox_inches='tight')
    # plt.close(fig)

    # Error violins per subset; labels carry the subset metrics.
    # Explicit copy so that relabelling 'Group' cannot touch `data`.
    df_fig = data.loc[:, ['Error', 'Group']].copy()
    groups_rename = {
        group: f"{group}" + "\n" +
               fr"MAE: {df_metrics.at[group, 'mean_absolute_error']:0.2f}" + "\n" +
               fr"Pearson $\rho$: {df_metrics.at[group, 'pearson_corrcoef']:0.2f}" + "\n" +
               fr"$\langle$Error$\rangle$: {df_metrics.at[group, 'bias']:0.2f}"
        for group in colors_groups
    }
    colors_groups_violin = {groups_rename[group]: colors_groups[group] for group in colors_groups}
    df_fig['Group'] = df_fig['Group'].replace(groups_rename)
    sns.set_theme(style='whitegrid')
    fig, ax = plt.subplots(figsize=(7, 4))
    violin = sns.violinplot(
        data=df_fig,
        x='Group',
        y='Error',
        palette=colors_groups_violin,
        scale='width',
        order=list(colors_groups_violin.keys()),
        saturation=0.75,
        legend=False,
        ax=ax
    )
    ax.set_xlabel('')
    fig.savefig(f"{path_models}/{model_id}/violin.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_models}/{model_id}/violin.pdf", bbox_inches='tight')
    plt.close(fig)

    # Optional feature attributions on the control cohort.
    if is_explain:
        try:
            explanation = model.explain(data, method=explain_method, baselines=explain_baselines)
            explanation.index = data.index
            explanation.to_excel(f"{path_models}/{model_id}/explanation.xlsx")

            # NOTE(review): the value returned by summary_plot(show=False) is passed
            # to plt.close below - confirm it is a Figure for the installed shap version.
            sns.set_theme(style='whitegrid')
            fig = shap.summary_plot(
                shap_values=explanation.loc[:, feats].values,
                features=data.loc[:, feats].values,
                feature_names=feats,
                max_display=explain_n_feats_to_plot,
                plot_type="violin",
                show=False,
            )
            plt.savefig(f"{path_models}/{model_id}/explain_beeswarm.png", bbox_inches='tight', dpi=200)
            plt.savefig(f"{path_models}/{model_id}/explain_beeswarm.pdf", bbox_inches='tight')
            plt.close(fig)

            sns.set_theme(style='whitegrid')
            fig = shap.summary_plot(
                shap_values=explanation.loc[:, feats].values,
                features=data.loc[:, feats].values,
                feature_names=feats,
                max_display=explain_n_feats_to_plot,
                plot_type="bar",
                show=False,
            )
            plt.savefig(f"{path_models}/{model_id}/explain_bar.png", bbox_inches='tight', dpi=200)
            plt.savefig(f"{path_models}/{model_id}/explain_bar.pdf", bbox_inches='tight')
            plt.close(fig)

        except NotImplementedError:
            # Skipping is intentional, but say so instead of failing silently.
            print(f"Explanation is not implemented for model {model_id}; skipping.")

# Write the per-model metrics table to disk.
df_models_metrics.to_excel(f"{path_models}/models_metrics.xlsx", index_label='Model ID')

# For each data split and each metric, store the "After" minus "Before"
# difference in a new "<split> <metric> Diff" column, then export the table.
for group in ['Train', 'Validation', 'Test']:
    for metric in ['MAE', 'Rho']:
        after = df_models_check[f'{group} {metric} After']
        before = df_models_check[f'{group} {metric} Before']
        df_models_check[f'{group} {metric} Diff'] = after - before
df_models_check.to_excel(f"{path_models}/models_check.xlsx", index_label='Model ID')

# One worksheet per metric table; each table gets a 'Preproc' column
# (looked up in `gse_preproc` by the table's own index) inserted at position 1
# before it is written.
gses_sheets = ['MAE All', 'MAE Test', 'Rho All', 'Rho Test', 'Bias All', 'Bias Test']
with pd.ExcelWriter(f"{path_models}/gses_models_metrics.xlsx", engine='xlsxwriter') as writer:
    for md in gses_sheets:
        preproc = gse_preproc.loc[df_gses_models_metrics[md].index, 'Preproc']
        df_gses_models_metrics[md].insert(1, 'Preproc', preproc)
        df_gses_models_metrics[md].to_excel(writer, sheet_name=md)

# All models processing

In [None]:
n_models = 6
# Sweep identifier built from the model count and the test/validation split settings.
# NOTE(review): `n_models` and `fn_sweep` are not read again inside this cell —
# confirm whether a later cell relies on them.
fn_sweep = (
    f"models({n_models})_"
    f"tst({tst_random_state}_{tst_n_splits}_{tst_n_repeats})_"
    f"val({val_random_state}_{val_n_splits}_{val_n_repeats})"
)
df_sweeps = pd.read_excel(f"{path}/pytorch_tabular/selected.xlsx", index_col=0)

# Reserve the result columns: a total counter at position 15 followed by one
# counter per ICD-11 chapter (positions 16, 17, ...).
df_sweeps.insert(15, 'Passed\nICD-11\nTotal', None)
ins_pos = 16
for icd_chpt in icd_chpts:
    df_sweeps.insert(ins_pos, f'Passed\nICD-11\nChapter {icd_chpt}', None)
    ins_pos += 1

# Iterate over the index (which has a length) so tqdm can display total/ETA;
# the row contents themselves are only accessed through `df_sweeps.at`.
for model_id in (pbar := tqdm(df_sweeps.index)):
    pbar.set_description(f"Processing {model_id}")

    # The model directory is the checkpoint path without its extension,
    # with backslashes normalised to forward slashes.
    ckpt_path = pathlib.Path(df_sweeps.at[model_id, 'checkpoint'])
    model_dir = str(ckpt_path.parent).replace('\\', '/') + '/' + ckpt_path.stem
    model = TabularModel.load_model(model_dir)

    # Refresh predictions and signed errors (prediction minus age) for this model.
    data_full['Prediction'] = model.predict(data_full)
    data_full['Error'] = data_full['Prediction'] - data_full['Age']

    # Per ICD-11 chapter: number of status-group comparisons whose error shift
    # is significant and points in the expected direction.
    passed_icd = {icd_chpt: 0 for icd_chpt in icd_chpts}
    for icd_chpt in icd_chpts:
        df_ckpt = df_groups[df_groups['ICD-11 chapter'] == icd_chpt]
        for section_id, section_row in df_ckpt.iterrows():
            # These cells hold Python literals serialised as strings.
            section_statuses = ast.literal_eval(section_row['Statuses'])
            section_groups = ast.literal_eval(section_row['Groups'])
            section_directions = ast.literal_eval(section_row['Directions'])
            df_section = data_full.loc[
                (data_full['GSE'] == section_row['GSE']) & (data_full['Status'].isin(section_statuses)),
                ['Status', 'Error']
            ]

            for section_group_id, section_group in enumerate(section_groups):
                # Select each group's errors once and reuse them for both the
                # test and the mean.
                errors_0 = df_section.loc[df_section['Status'] == section_group[0], 'Error']
                errors_1 = df_section.loc[df_section['Status'] == section_group[1], 'Error']
                # A comparison with an empty group cannot pass; skip it instead of
                # letting mannwhitneyu / np.mean fail or warn on empty input.
                if errors_0.empty or errors_1.empty:
                    continue

                pval = mannwhitneyu(
                    errors_0.values,
                    errors_1.values,
                    alternative="two-sided",
                ).pvalue
                # A NaN p-value compares False here, so it is never counted.
                if not pval < 0.05:
                    continue

                bias_0 = np.mean(errors_0)
                bias_1 = np.mean(errors_1)
                group_direction = section_directions[section_group_id]
                increased = group_direction == 'Increasing' and bias_1 > bias_0
                decreased = group_direction == 'Decreasing' and bias_1 < bias_0
                if increased or decreased:
                    passed_icd[icd_chpt] += 1

        df_sweeps.at[model_id, f'Passed\nICD-11\nChapter {icd_chpt}'] = passed_icd[icd_chpt]

    df_sweeps.at[model_id, 'Passed\nICD-11\nTotal'] = sum(passed_icd.values())

df_sweeps.to_excel(f"{path}/pytorch_tabular/models.xlsx", index_label='Model ID')