# Debugging autoreload

In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each cell is executed.
%load_ext autoreload
%autoreload 2

# Load packages

In [None]:
from pytorch_tabular.utils import load_covertype_dataset
from rich.pretty import pprint
from sklearn.model_selection import BaseCrossValidator, ParameterGrid, ParameterSampler
import torch
import pickle
import shutil
import shap
from sklearn.model_selection import RepeatedStratifiedKFold
from glob import glob
import ast
import matplotlib.pyplot as plt
import seaborn as sns
import copy
from sklearn.model_selection import train_test_split
import numpy as np
from pytorch_tabular.utils import make_mixed_dataset, print_metrics
from pytorch_tabular import available_models
from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig, GANDALFConfig, TabNetModelConfig, FTTransformerConfig, DANetConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig
from pytorch_tabular.tabular_model_tuner import TabularModelTuner
from torchmetrics.functional.regression import mean_absolute_error, pearson_corrcoef
from pytorch_tabular import MODEL_SWEEP_PRESETS
import pandas as pd
from pytorch_tabular import model_sweep
from src.pt.model_sweep import model_sweep_custom
import warnings
from src.utils.configs import read_parse_config
from src.utils.hash import dict_hash
import pathlib
from tqdm import tqdm
import distinctipy
import matplotlib.patheffects as pe
import matplotlib.colors as mcolors
from statannotations.Annotator import Annotator
from scipy.stats import mannwhitneyu
from regression_bias_corrector import LinearBiasCorrector
import missingno as msno
from scipy import stats
from statsmodels.stats.multitest import multipletests
from sklearn.decomposition import PCA
from sklearn.random_projection import GaussianRandomProjection, SparseRandomProjection
from sklearn.manifold import MDS, Isomap, TSNE
from sklearn.cluster import DBSCAN, HDBSCAN


def make_rgb_transparent(rgb, bg_rgb, alpha):
    """Blend a foreground color over a background color.

    Channels are paired positionally and each blended channel equals
    ``alpha * foreground + (1 - alpha) * background``. The result is as long
    as the shorter of the two inputs.

    Args:
        rgb: Foreground color channels.
        bg_rgb: Background color channels.
        alpha: Weight of the foreground color.

    Returns:
        A list with the blended channel values.
    """
    blended = []
    for fg_channel, bg_channel in zip(rgb, bg_rgb):
        blended.append(alpha * fg_channel + (1 - alpha) * bg_channel)
    return blended

# Load data

In [None]:
# Name of the feature set (also used as a directory name) and the title put on plots
feats_set = 'Электрокардиограмма (чекап)'
feats_title = 'Электрокардиограмма'

# Directory with the input tables and generated plots, and directory for model checkpoints
path = f"E:/YandexDisk/Work/bbd/millennium/models/{feats_set}"
path_models = f"E:/Git/bbs/notebooks/millennium/pt/{feats_set}"

# Repeated stratified K-fold settings for the train_validation/test splits
tst_n_splits = 5
tst_n_repeats = 5
tst_random_state = 1337

# Repeated stratified K-fold settings for the train/validation folds inside each split
val_n_splits = 4
val_n_repeats = 4
val_random_state = 1337

# Both tables are read with their first column as the index
data = pd.read_excel(f"{path}/data.xlsx", index_col=0)
df_feats = pd.read_excel(f"{path}/feats.xlsx", index_col=0)
feat_trgt = 'Возраст'  # target column
feats_cnt = df_feats.index.to_list()  # continuous features = index labels of feats.xlsx
feats_cat = []  # no categorical features for this set
feats = list(feats_cnt) + list(feats_cat)
feats_with_trgt = [feat_trgt] + feats

## Check duplicated indexes

In [None]:
# Show the target value of every row whose index label repeats an earlier one
# (duplicated() with default arguments does not mark the first occurrence).
data.loc[data.index.duplicated(), [feat_trgt]]

# Regenerate plots (if necessary)

In [None]:
color = 'crimson'  # base color shared by the plots of this cell
n_cols = 4  # number of columns of the feature regression-plot grid drawn further below

# Missing values matrix
# Rows are sorted by the target so missingness can be read against it
df_msno = data[feats + [feat_trgt]].copy()
df_msno.sort_values([feat_trgt], ascending=[True], inplace=True)
msno_mtx = msno.matrix(
    df=df_msno,
    label_rotation=90,
    color=mcolors.to_rgb(color),
    figsize=(0.7 * len(feats), 5),  # figure width grows with the number of features
)
# Center the rotated column labels on their ticks
plt.xticks(ha='center')
plt.setp(msno_mtx.xaxis.get_majorticklabels(), ha="center")
msno_mtx.set_title(feats_title, fontsize='large')
msno_mtx.set_ylabel("IDs", fontsize='large')
plt.savefig(f"{path}/msno.png", bbox_inches='tight', dpi=200)
plt.savefig(f"{path}/msno.pdf", bbox_inches='tight')
plt.clf()

# Age histogram
# 23 edges between 5 and 115 give 22 bins of width 5
hist_bins = np.linspace(5, 115, 23)
sns.set_theme(style='ticks')
fig, ax = plt.subplots(figsize=(6, 3.5), layout='constrained')
histplot = sns.histplot(
    data=data,
    bins=hist_bins,
    edgecolor='k',
    linewidth=1,
    x=feat_trgt,
    color=color,
    ax=ax
)
histplot.set(xlim=(0, 120))
histplot.set_ylabel('Количество')
histplot.set_title(feats_title)
plt.savefig(f"{path}/age_hist.png", bbox_inches='tight', dpi=200)
plt.savefig(f"{path}/age_hist.pdf", bbox_inches='tight')
plt.close(fig)

# Input features and output feature correlations
# One Pearson coefficient per feature, computed on the rows where both the
# target and the feature are present.
df_corr = pd.DataFrame(index=feats, columns=['rho'])
for f in tqdm(feats):
    df_tmp = data.loc[:, [feat_trgt, f]].dropna(axis=0, how='any')
    # Features with fewer than two complete rows are skipped (rho stays NaN)
    if df_tmp.shape[0] > 1:
        vals_1 = df_tmp.loc[:, feat_trgt].values
        vals_2 = df_tmp.loc[:, f].values
        df_corr.at[f, 'rho'], _ = stats.pearsonr(vals_1, vals_2)
# Drop features without a coefficient and order the rest by |rho|, strongest first
df_corr.dropna(axis=0, how='any', inplace=True)
df_corr.insert(1, "abs(rho)", df_corr['rho'].abs())
df_corr.sort_values(["abs(rho)"], ascending=[False], inplace=True)
# The frame was created without a dtype, so convert the columns to numeric for plotting
df_corr = df_corr.apply(pd.to_numeric)
sns.set_theme(style='ticks')
# Figure size scales with the longest feature name and with the number of features
fig, ax = plt.subplots(figsize=(0.9 + 0.039 * df_corr.index.str.len().max(), 0.9 + 0.4 * len(feats) + 0.04 * df_corr.index.str.len().max()) , layout='constrained')
heatmap = sns.heatmap(
    df_corr.loc[:, ['rho']],
    annot=True,
    fmt=".2f",
    vmin=-1.0,
    vmax=1.0,
    cmap='coolwarm',
    linewidth=0.1,
    linecolor='black',
    #annot_kws={"fontsize": 15},
    cbar_kws={
        # "shrink": 0.9,
        # "aspect": 30,
        #'fraction': 0.046, 
        #'pad': 0.04,
    },
    ax=ax
)
# The last axes of the figure is the colorbar: place it to the right of the
# heatmap with the same vertical extent, then label and outline it.
heatmap_pos = ax.get_position()
ax.figure.axes[-1].set_position([heatmap_pos.x1 + 0.05, heatmap_pos.y0, 0.1, heatmap_pos.height])
ax.figure.axes[-1].set_ylabel(r"Pearson $\rho$")
for spine in ax.figure.axes[-1].spines.values():
    spine.set(visible=True, lw=0.25, edgecolor="black")
ax.set_xlabel('')
ax.set_ylabel('')
ax.set_title(feats_title, fontsize=16)
# Single-column heatmap: the x tick and its label carry no information
ax.set(xticklabels=[])
ax.set(xticks=[])
plt.savefig(f"{path}/age_feats_pearsonr.png", bbox_inches='tight', dpi=200)
plt.savefig(f"{path}/age_feats_pearsonr.pdf", bbox_inches='tight')
plt.close(fig)


# Regression plots of every feature against the target, ordered by |rho|.
# Only the features that kept a correlation value in df_corr are drawn, so the
# grid is sized by them (not by len(feats)); at least one row is always created.
feats_plot = df_corr.index.values
n_rows = max(1, int(np.ceil(len(feats_plot) / n_cols)))
n_empty = n_rows * n_cols - len(feats_plot)
sns.set_theme(style='ticks')
fig, axs = plt.subplots(
    nrows=n_rows,
    ncols=n_cols,
    figsize=(n_cols * 3.0, n_rows * 2.5),
    gridspec_kw={'wspace':0.10, 'hspace': 0.05},
    sharex=True,
    squeeze=False,  # keep axs 2-D even for a single row so axs[row, col] always works
    layout='constrained'
)
for feat_id, feat in enumerate(feats_plot):
    row_id, col_id = divmod(feat_id, n_cols)
    regplot = sns.regplot(
        data=data,
        x=feat_trgt,
        y=feat,
        color=color,
        scatter_kws=dict(
            linewidth=0.5,
            alpha=0.75,
            edgecolor="k",
            s=16,
        ),
        ax=axs[row_id, col_id]
    )
    axs[row_id, col_id].set_title(fr"Pearson $\rho$: {df_corr.loc[feat, 'rho']:0.3f}")
    # Shrink the y label for long feature names (capped at 13 pt)
    y_label_fontsize = min(15 / (len(feat) / 20), 13)
    axs[row_id, col_id].set_ylabel(feat, fontsize=y_label_fontsize)
# Hide every grid cell that did not receive a plot
for ax_unused in axs.flat[len(feats_plot):]:
    ax_unused.axis('off')
fig.savefig(f"{path}/age_feats_regplot.png", bbox_inches='tight', dpi=200)
fig.savefig(f"{path}/age_feats_regplot.pdf", bbox_inches='tight')
plt.close(fig)

# Correlation heatmap
# Pairwise Pearson statistics of the target and all features packed into one
# square table: upper triangle = correlation coefficient, lower triangle =
# p-value, diagonal = NaN.
df_corr = pd.DataFrame(data=np.zeros(shape=(len(feats_with_trgt), len(feats_with_trgt))), index=feats_with_trgt, columns=feats_with_trgt)
for f_id_1, f_1 in enumerate(feats_with_trgt):
    df_corr.at[f_1, f_1] = np.nan
    for f_2 in feats_with_trgt[f_id_1 + 1:]:
        vals_1 = data.loc[:, f_1].values
        vals_2 = data.loc[:, f_2].values
        corr, pval = stats.pearsonr(vals_1, vals_2)
        df_corr.at[f_2, f_1] = pval
        df_corr.at[f_1, f_2] = corr
# Collect the p-values of the strict lower triangle into a long table
selection = np.tri(df_corr.shape[0], df_corr.shape[1], -1, dtype=bool)
df_fdr = df_corr.where(selection).stack().reset_index()
df_fdr.columns = ['row', 'col', 'pval']
# Benjamini-Hochberg correction over all feature pairs
_, df_fdr['pval_fdr_bh'], _, _ = multipletests(df_fdr.loc[:, 'pval'].values, 0.05, method='fdr_bh')
# Exact zeros would become +inf under -log10; replace them with half of the
# smallest non-zero corrected p-value. The result is assigned back instead of
# using a chained inplace replace, which does not update df_fdr under pandas
# Copy-on-Write.
nzmin = df_fdr['pval_fdr_bh'][df_fdr['pval_fdr_bh'].gt(0)].min(0) * 0.5
df_fdr['pval_fdr_bh'] = df_fdr['pval_fdr_bh'].replace({0.0: nzmin})
# Same layout as df_corr, but the lower triangle holds -log10(corrected p-value)
df_corr_fdr = df_corr.copy()
for row, col, pval_fdr_bh in zip(df_fdr['row'], df_fdr['col'], df_fdr['pval_fdr_bh']):
    df_corr_fdr.loc[row, col] = -np.log10(pval_fdr_bh)
df_corr_fdr.to_excel(f"{path}/feats_pearsonr.xlsx")
# Two heatmaps drawn on the same axes from df_corr_fdr: the upper triangle
# (with the diagonal) is colored by correlation, the lower triangle by
# -log10 of the corrected p-value.
sns.set_theme(style='ticks')
fig, ax = plt.subplots(figsize=(8.5 + 0.35 * len(feats_with_trgt), 6.5 + 0.25 * len(feats_with_trgt)), layout='constrained')
cmap_triu = plt.get_cmap("seismic").copy()
# Mask hides the strict lower triangle for the correlation layer
mask_triu=np.tri(len(feats_with_trgt), len(feats_with_trgt), -1, dtype=bool)
heatmap_diff = sns.heatmap(
    df_corr_fdr,
    mask=mask_triu,
    annot=True,
    fmt=".2f",
    center=0.0,
    cmap=cmap_triu,
    linewidth=0.1,
    linecolor='black',
    annot_kws={"fontsize": 32 / np.sqrt(len(df_corr_fdr.values) + 8)},  # smaller text for larger tables
    ax=ax
)
# The most recently added axes is the colorbar of the heatmap just drawn
ax.figure.axes[-1].set_ylabel(r"Pearson $\rho$")
for spine in ax.figure.axes[-1].spines.values():
    spine.set(visible=True, lw=0.25, edgecolor="black")
cmap_tril = plt.get_cmap("viridis").copy()
# Values below vmin (= -log10(0.05)) are drawn in black
cmap_tril.set_under('black')
# Transposed mask hides the strict upper triangle for the p-value layer
mask_tril=np.tri(len(feats_with_trgt), len(feats_with_trgt), -1, dtype=bool).T
heatmap_pval = sns.heatmap(
    df_corr_fdr,
    mask=mask_tril,
    annot=True,
    fmt=".1f",
    vmin=-np.log10(0.05),
    cmap=cmap_tril,
    linewidth=0.1,
    linecolor='black',
    annot_kws={"fontsize": 32 / np.sqrt(len(df_corr_fdr.values) + 8)},
    ax=ax
)
ax.figure.axes[-1].set_ylabel(r"$-\log_{10}(\mathrm{p-value})$")
for spine in ax.figure.axes[-1].spines.values():
    spine.set(visible=True, lw=0.25, edgecolor="black")
ax.set_xlabel('', fontsize=16)
ax.set_ylabel('', fontsize=16)
ax.set_title(feats_title, fontsize=16)
plt.savefig(f"{path}/feats_pearsonr.png", bbox_inches='tight', dpi=200)
plt.savefig(f"{path}/feats_pearsonr.pdf", bbox_inches='tight')
plt.close(fig)

# Work on a copy so that the columns added below do not end up in `data`
df_proc = data.copy()

# IQR outliers
# For the target and every feature add a 0/1 column flagging values outside
# [Q1 - 1.5 * IQR, Q3 + 1.5 * IQR]; rows for which neither comparison holds
# keep the initial flag 1.
out_columns = []
for f in tqdm(feats_with_trgt):
    q1 = df_proc[f].quantile(0.25)
    q3 = df_proc[f].quantile(0.75)
    iqr = q3 - q1
    out_column = f"{f} IQR Outlier"
    df_proc[out_column] = 1
    out_columns.append(out_column)
    # Named `is_inlier` rather than `filter` to avoid shadowing the builtin
    is_inlier = (df_proc[f] >= q1 - 1.5 * iqr) & (df_proc[f] <= q3 + 1.5 * iqr)
    df_proc.loc[is_inlier, out_column] = 0
# Per-row count of flagged columns
df_proc["Number of IQR Outliers"] = df_proc.loc[:, out_columns].sum(axis=1)

# Histogram of the per-row outlier counts with one unit-wide bin per integer
hist_bins = np.linspace(-0.5, len(feats_with_trgt) + 0.5, len(feats_with_trgt) + 2)
fig = plt.figure(figsize=(5, 3))
sns.set_theme(style='ticks')
histplot = sns.histplot(
    data=df_proc,
    x="Number of IQR Outliers",
    multiple="stack",
    bins=hist_bins,
    edgecolor='k',
    linewidth=1.0,
    color=color,
)
# Cut the x axis right after the largest observed count
histplot.set(xlim=(-0.5, max(df_proc['Number of IQR Outliers'] + 0.5)))
histplot.set_title(feats_title)
histplot.set_xlabel("Количество IQR выбросов")
histplot.set_ylabel("Количество записей")
plt.savefig(f"{path}/outs_iqr_hist.png", bbox_inches='tight', dpi=200)
plt.savefig(f"{path}/outs_iqr_hist.pdf", bbox_inches='tight')
plt.close(fig)

# Reuse missingno for outliers: turn every outlier flag (1) into NaN so that
# "missing" in the plots below means "is an IQR outlier", and rename the flag
# columns back to the plain feature names.
out_columns = [f"{f} IQR Outlier" for f in feats_with_trgt]
df_msno = df_proc.loc[:, out_columns].copy()
df_msno.replace({1: np.nan}, inplace=True)
df_msno.rename(columns=dict(zip(out_columns, feats_with_trgt)), inplace=True)

# Plot barplot for features with outliers
msno_bar = msno.bar(
    df=df_msno,
    label_rotation=90,
    color=color,
)
# Center the rotated column labels on their ticks
plt.xticks(ha='center')
plt.setp(msno_bar.xaxis.get_majorticklabels(), ha="center")
msno_bar.set_title(feats_title, fontsize='large')
msno_bar.set_ylabel("Записи без выбросов", fontsize='large')
plt.savefig(f"{path}/outs_iqr_bar.png", bbox_inches='tight', dpi=200)
plt.savefig(f"{path}/outs_iqr_bar.pdf", bbox_inches='tight')
plt.clf()

# Plot matrix of samples outliers distribution
msno_mtx = msno.matrix(
    df=df_msno,
    label_rotation=90,
    color=mcolors.to_rgb(color),
)
# Center the rotated column labels on their ticks. The labels of the matrix
# axes are the ones to align here, not those of the earlier bar plot.
plt.xticks(ha='center')
plt.setp(msno_mtx.xaxis.get_majorticklabels(), ha="center")
msno_mtx.set_title(feats_title, fontsize='large')
msno_mtx.set_ylabel("Записи", fontsize='large')
plt.savefig(f"{path}/outs_iqr_matrix.png", bbox_inches='tight', dpi=200)
plt.savefig(f"{path}/outs_iqr_matrix.pdf", bbox_inches='tight')
plt.clf()

# Plot heatmap of features outliers correlations
# (missingno's nullity-correlation heatmap applied to the outlier-as-NaN table)
msno_heatmap = msno.heatmap(
    df=df_msno,
    label_rotation=90,
    cmap="bwr",
    fontsize=12,
)
msno_heatmap.set_title(feats_title, fontsize='large')
plt.setp(msno_heatmap.xaxis.get_majorticklabels(), ha="center")
# Enlarge the tick labels of the heatmap's colorbar
msno_heatmap.collections[0].colorbar.ax.tick_params(labelsize=20)
plt.savefig(f"{path}/outs_iqr_heatmap.png", bbox_inches='tight', dpi=200)
plt.savefig(f"{path}/outs_iqr_heatmap.pdf", bbox_inches='tight')
plt.clf()
    
# Dimensionality reduction
# Six 2-D embeddings of the target + feature matrix.
# NOTE(review): none of the estimators is given a random_state, so the
# stochastic ones presumably differ between runs - confirm whether that is intended.
dim_red_models = {
    't-SNE': TSNE(n_components=2),
    'PCA': PCA(n_components=2, whiten=False),
    'IsoMap': Isomap(n_components=2, n_neighbors=5),
    'MDS': MDS(n_components=2, metric=True),
    'GRP': GaussianRandomProjection(n_components=2, eps=0.5),
    'SRP': SparseRandomProjection(n_components=2, density='auto', eps=0.5, dense_output=False),
}
feats_dim_red = []  # names of all embedding columns added to df_proc
for drm in dim_red_models:
    dim_red_res = dim_red_models[drm].fit_transform(df_proc.loc[:, feats_with_trgt].values)
    df_proc.loc[:, f"{drm} 1"] = dim_red_res[:, 0]
    df_proc.loc[:, f"{drm} 2"] = dim_red_res[:, 1]
    # Cluster labels on the 2-D embedding; the minimal cluster size is 20% of the rows
    df_proc.loc[:, f"{drm} HDBSCAN"] = HDBSCAN(min_cluster_size=int(df_proc.shape[0] * 0.2)).fit(df_proc.loc[:, [f"{drm} 1", f"{drm} 2"]].values).labels_
    feats_dim_red += [ f"{drm} 1",  f"{drm} 2"]
# 2 x 3 grid: one panel per embedding (this overrides the n_cols set earlier in the cell)
n_rows = 2
n_cols = 3
fig_height = 10
fig_width = 15
sns.set_theme(style='ticks')
fig, axs = plt.subplots(n_rows, n_cols, figsize=(fig_width, fig_height), gridspec_kw={}, sharey=False, sharex=False, layout='constrained')
for drm_id, drm in enumerate(dim_red_models.keys()):
    row_id, col_id = divmod(drm_id, n_cols)
    # Points are colored by the 'Пол' column of df_proc, which is not part of feats
    scatter = sns.scatterplot(
        data=df_proc,
        x=f"{drm} 1",
        y=f"{drm} 2",
        # hue=f"{drm} HDBSCAN",
        hue='Пол',
        palette={'М': 'deepskyblue', 'Ж': 'hotpink'},
        linewidth=0.25,
        alpha=0.75,
        edgecolor="k",
        s=40,
        ax=axs[row_id, col_id],
    )
    axs[row_id, col_id].set_title(drm)
    # axs[n_rows - 1, n_cols - 1].axis('off')
fig.suptitle(feats_title, fontsize='large')   
fig.savefig(f"{path}/dim_red.png", bbox_inches='tight', dpi=200)
fig.savefig(f"{path}/dim_red.pdf", bbox_inches='tight')
# Persist the processed table (outlier flags, embeddings, cluster labels)
df_proc.to_excel(f"{path}/df_proc.xlsx", index_label="ID")
plt.close(fig)


# Generate stratification

In [None]:
# Quantile edges used to bin the continuous target into strata (quintiles)
quantiles = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]

# Integer bin label of the target for every row; duplicate bin edges are dropped
quantiles_all = pd.qcut(data[feat_trgt].values, quantiles, labels=False, duplicates='drop')
unique_all, counts_all = np.unique(quantiles_all, return_counts=True)
ids_all = data.index.values

# Repeated K-fold stratified by the target bins: n_splits * n_repeats splits in total
k_fold_all = RepeatedStratifiedKFold(
    n_splits=tst_n_splits,
    n_repeats=tst_n_repeats,
    random_state=tst_random_state
)
splits_all = k_fold_all.split(X=ids_all, y=quantiles_all, groups=quantiles_all)
# The splitter yields positional indices; map them to index labels and store the
# role of every row in a "Split_<id>" column of `data`.
for split_id, (ids_trn_val, ids_tst) in enumerate(splits_all):
    data.loc[ids_all[ids_trn_val], f"Split_{split_id}"] = "trn_val"
    data.loc[ids_all[ids_tst], f"Split_{split_id}"] = "tst"

In [None]:
# samples[split_id] collects the index labels of the test, train_validation,
# per-fold train and per-fold validation rows, plus positional CV indices.
samples = {}
for split_id in range(tst_n_splits * tst_n_repeats):
    # Start every fold column as a copy of the split column, so test rows keep "tst"
    for fold_id in range(val_n_splits * val_n_repeats):
        data[f"Split_{split_id}_Fold_{fold_id}"] = data[f"Split_{split_id}"]
    
    samples[split_id] = {
        'test': data.index[data[f"Split_{split_id}"] == "tst"].values,
        'train_validation': data.index[data[f"Split_{split_id}"] == "trn_val"].values,
        'trains': {},
        'validations': {},
    }
    
    ids_trnval = data.index[(data[f"Split_{split_id}"] == 'trn_val')].values
    
    # Re-bin the target on the train_validation rows only and split them again
    quantiles_trnval = pd.qcut(data.loc[ids_trnval, feat_trgt].values, quantiles, labels=False, duplicates='drop')
    unique_trnval, counts_trnval = np.unique(quantiles_trnval, return_counts=True)
    k_fold_trnval = RepeatedStratifiedKFold(
        n_splits=val_n_splits,
        n_repeats=val_n_repeats,
        random_state=val_random_state
    )
    splits_trnval = k_fold_trnval.split(X=ids_trnval, y=quantiles_trnval, groups=quantiles_trnval)
    # Overwrite "trn_val" with the fold-specific role of each row
    for fold_id, (ids_trn, ids_val) in enumerate(splits_trnval):
        data.loc[ids_trnval[ids_trn], f"Split_{split_id}_Fold_{fold_id}"] = "trn"
        data.loc[ids_trnval[ids_val], f"Split_{split_id}_Fold_{fold_id}"] = "val"
         
    for fold_id in range(val_n_splits * val_n_repeats):
        samples[split_id]['trains'][fold_id] = data.index[data[f"Split_{split_id}_Fold_{fold_id}"] == "trn"].values
        samples[split_id]['validations'][fold_id] = data.index[data[f"Split_{split_id}_Fold_{fold_id}"] == "val"].values

    # For every fold: positions (within the train_validation rows, in `data`
    # order) of the train rows and of the validation rows.
    samples[split_id]['cv_indexes'] = [
        (
            np.where(data.index[data[f"Split_{split_id}"] == "trn_val"].isin(data.index[(data[f"Split_{split_id}"] == "trn_val") & (data[f"Split_{split_id}_Fold_{i}"] == 'trn')]))[0],
            np.where(data.index[data[f"Split_{split_id}"] == "trn_val"].isin(data.index[(data[f"Split_{split_id}"] == "trn_val") & (data[f"Split_{split_id}_Fold_{i}"] == 'val')]))[0],
        )
        for i in range(val_n_splits * val_n_repeats)
    ]
    
# Sanity check: for every (split, fold) pair the train, validation and test ID
# sets must be pairwise disjoint; every overlap that is found gets reported.
for split_id in range(tst_n_splits * tst_n_repeats):
    for fold_id in range(val_n_splits * val_n_repeats):
        test_samples = samples[split_id]['test']
        train_samples = samples[split_id]['trains'][fold_id]
        validation_samples = samples[split_id]['validations'][fold_id]

        train_set, validation_set, test_set = (
            set(ids) for ids in (train_samples, validation_samples, test_samples)
        )
        intxns = {
            'train_validation': train_set & validation_set,
            'validation_test': validation_set & test_set,
            'train_test': train_set & test_set,
        }

        for intxn_name in intxns:
            intxn_samples = intxns[intxn_name]
            if not intxn_samples:
                continue
            print(f"Non-zero {intxn_name} intersection ({len(intxn_samples)}) for {split_id} Split and {fold_id} Fold!")

# Store the stratification; the file name encodes the split and fold settings
fn_samples = f"{path}/samples_tst({tst_random_state}_{tst_n_splits}_{tst_n_repeats})_val({val_random_state}_{val_n_splits}_{val_n_repeats}).pickle"
with open(fn_samples, 'wb') as handle:
    pickle.dump(samples, handle, protocol=pickle.HIGHEST_PROTOCOL)

# Load stratification

In [None]:
# Load the stratification saved under the same split/fold settings.
# NOTE(review): pickle.load executes arbitrary code from the file - only load
# files produced by this notebook.
with open(f"{path}/samples_tst({tst_random_state}_{tst_n_splits}_{tst_n_repeats})_val({val_random_state}_{val_n_splits}_{val_n_repeats}).pickle", 'rb') as handle:
    samples = pickle.load(handle)
    
# Sanity check: train, validation and test ID sets of every (split, fold) pair
# must be pairwise disjoint; overlaps are reported.
for split_id in range(tst_n_splits * tst_n_repeats):
    for fold_id in range(val_n_splits * val_n_repeats):
        test_samples = samples[split_id]['test']
        train_samples = samples[split_id]['trains'][fold_id]
        validation_samples = samples[split_id]['validations'][fold_id]

        intxns = {
            'train_validation': set.intersection(set(train_samples), set(validation_samples)),
            'validation_test': set.intersection(set(validation_samples), set(test_samples)),
            'train_test': set.intersection(set(train_samples), set(test_samples))
        }
        
        for intxn_name, intxn_samples in intxns.items():
            if len(intxn_samples) > 0:
                print(f"Non-zero {intxn_name} intersection ({len(intxn_samples)}) for {split_id} Split and {fold_id} Fold!")

# PyTorch Tabular Model Sweep Training

## Load non-model configs

In [None]:
path_configs = "E:/Git/bbs/notebooks/millennium/configs"

# Data config from YAML, with target and feature columns taken from this notebook.
# The object returned by read_parse_config is used like a dict here.
data_config = read_parse_config(f"{path_configs}/DataConfig.yaml", DataConfig)
data_config['target'] = [feat_trgt]
data_config['continuous_cols'] = feats_cnt
data_config['categorical_cols'] = feats_cat
# Trainer config; checkpoints go to path_models, which is created if missing
trainer_config = read_parse_config(f"{path_configs}/TrainerConfig.yaml", TrainerConfig)
pathlib.Path(path_models).mkdir(parents=True, exist_ok=True)
trainer_config['checkpoints_path'] = path_models
optimizer_config = read_parse_config(f"{path_configs}/OptimizerConfig.yaml", OptimizerConfig)

# Learning-rate finder settings, forwarded to model_sweep_custom in the sweep cell
lr_find_min_lr = 1e-8
lr_find_max_lr = 1
lr_find_num_training = 128
lr_find_mode = "exponential"
lr_find_early_stop_threshold = 4.0

## Models Search Spaces

### GANDALF Search Space

In [None]:
# Hyper-parameter grid; every key currently has a single candidate value
search_space = {
    "model_config__gflu_stages": [6],
    "model_config__gflu_dropout": [0.1],
    "model_config__gflu_feature_init_sparsity": [0.3],
    "model_config.head_config__dropout": [0.1],
    "model_config__learning_rate": [0.001],
    "model_config__seed": [1899],
}
# Number of parameter combinations in the grid
grid_size = np.prod([len(p_vals) for _, p_vals in search_space.items()])
print(grid_size)

# Template head configuration as a plain dict; its dropout is overridden per grid point
head_config = LinearHeadConfig(
    layers="",
    activation='ReLU',
    dropout=0.1,
    use_batch_norm=False,
    initialization="kaiming"
).__dict__

# One GANDALFConfig per grid point: start from the YAML config and override the
# searched parameters and the head configuration.
model_list = []
for i, params in enumerate(ParameterGrid(search_space)):
    head_config_tmp = copy.deepcopy(head_config)
    head_config_tmp['dropout'] = params['model_config.head_config__dropout']
    model_config = read_parse_config(f"{path_configs}/models/GANDALFConfig.yaml", GANDALFConfig)
    model_config['gflu_stages'] = params['model_config__gflu_stages']
    model_config['gflu_feature_init_sparsity'] = params['model_config__gflu_feature_init_sparsity']
    model_config['gflu_dropout'] = params['model_config__gflu_dropout']
    model_config['learning_rate'] = params['model_config__learning_rate']
    model_config['seed'] = params['model_config__seed']
    model_config['head_config'] = head_config_tmp
    model_list.append(GANDALFConfig(**model_config))

## Perform model sweep

In [None]:
pathlib.Path(f"E:/Git/bbs/notebooks/millennium/pt/{feats_set}").mkdir(parents=True, exist_ok=True)

common_params = {
    "task": "regression",
}

seed = 1899

# Run the sweep over model_list for every (split, fold) pair and accumulate the
# per-run result tables; the combined table is re-exported after every run.
dfs_result = []
for split_id, split_dict in samples.items():
    for fold_id in split_dict['trains']:
        test = data.loc[split_dict['test'], feats + [feat_trgt]]
        train = data.loc[split_dict['trains'][fold_id], feats + [feat_trgt]]
        validation = data.loc[split_dict['validations'][fold_id], feats + [feat_trgt]]

        # Set again before every run: model_sweep_custom receives this object and
        # it is not visible here whether it modifies it.
        trainer_config['seed'] = seed
        trainer_config['checkpoints'] = 'valid_loss'
        trainer_config['load_best'] = True
        trainer_config['auto_lr_find'] = True
        
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            sweep_df, best_model = model_sweep_custom(
                task="regression",
                train=train,
                validation=validation,
                test=test,
                data_config=data_config,
                optimizer_config=optimizer_config,
                trainer_config=trainer_config,
                model_list=model_list,
                common_model_args=common_params,
                metrics=["mean_absolute_error", "pearson_corrcoef"],
                metrics_params=[{}, {}],
                metrics_prob_input=[False, False],
                rank_metric=("mean_absolute_error", "lower_is_better"),
                return_best_model=True,
                seed=seed,
                progress_bar=False,
                verbose=False,
                suppress_lightning_logger=True,
                min_lr = lr_find_min_lr,
                max_lr = lr_find_max_lr,
                num_training = lr_find_num_training,
                mode = lr_find_mode,
                early_stop_threshold = lr_find_early_stop_threshold,
            )
        # Tag the rows with the run they belong to and add derived loss columns
        sweep_df['seed'] = seed
        sweep_df['split_id'] = split_id
        sweep_df['fold_id'] = fold_id
        # Flag models whose train loss is still above the test or validation loss
        sweep_df["train_more"] = False
        sweep_df.loc[(sweep_df["train_loss"] > sweep_df["test_loss"]) | (sweep_df["train_loss"] > sweep_df["validation_loss"]), "train_more"] = True
        sweep_df["validation_test_mean_loss"] = (sweep_df["validation_loss"] + sweep_df["test_loss"]) / 2.0
        sweep_df["train_validation_test_mean_loss"] = (sweep_df["train_loss"] + sweep_df["validation_loss"] + sweep_df["test_loss"]) / 3.0
        
        dfs_result.append(sweep_df)
        
        fn_suffix = (f"models({len(model_list)})_"
                     f"tst({tst_random_state}_{tst_n_splits}_{tst_n_repeats})_"
                     f"val({val_random_state}_{val_n_splits}_{val_n_repeats})")
        df_result = pd.concat(dfs_result, ignore_index=True)
        df_result.sort_values(by=['test_loss'], ascending=[True], inplace=True)
        fn_result = f"{trainer_config['checkpoints_path']}/{fn_suffix}.xlsx"
        try:
            df_result.style.background_gradient(cmap="RdYlGn_r").to_excel(fn_result)
        except PermissionError as err:
            # Best effort: the file may be locked (e.g. opened in a spreadsheet
            # application). Keep sweeping - the next run rewrites the full
            # table - but report the skipped export instead of hiding it.
            print(f"Could not write {fn_result} after split {split_id}, fold {fold_id}: {err}")

## Best models analysis

In [None]:
# Number of models in the sweep; part of the results file name written by the sweep cell
n_models = 1

# Settings passed to model.explain() and to the SHAP summary plots below
explain_method = "GradientShap"
explain_baselines = "b|1000"
explain_n_feats_to_plot = 25

fn_sweep = (
    f"models({n_models})_"
    f"tst({tst_random_state}_{tst_n_splits}_{tst_n_repeats})_val({val_random_state}_{val_n_splits}_{val_n_repeats})"
)

# Load the sweep results and re-export them, color-graded on the loss and time
# columns, into a per-sweep "candidates" directory.
df_sweeps = pd.read_excel(f"{path_models}/{fn_sweep}.xlsx", index_col=0)
path_to_candidates = f"{path_models}/candidates/{fn_sweep}"
pathlib.Path(path_to_candidates).mkdir(parents=True, exist_ok=True)
df_sweeps.style.background_gradient(
    subset=[
        "train_loss",
        "validation_loss",
        "test_loss",
        "time_taken",
        "time_taken_per_epoch"
    ], cmap="RdYlGn_r"
).to_excel(f"{path_to_candidates}/sweep.xlsx")

# Index labels (rows of df_sweeps) of the models selected for detailed analysis
models_ids = [
296
]

df_sweeps.loc[models_ids, :].style.background_gradient(
    subset=[
        "train_loss",
        "validation_loss",
        "test_loss",
        "time_taken",
        "time_taken_per_epoch"
    ], cmap="RdYlGn_r"
).to_excel(f"{path_to_candidates}/selected.xlsx")

# For every selected model: rebuild its data partition, load the checkpoint,
# predict on all rows and save predictions, metrics and plots.
for model_id in models_ids:

    # Recover the split/fold the model was trained on and the matching subsets
    split_id = df_sweeps.at[model_id, 'split_id']
    fold_id = df_sweeps.at[model_id, 'fold_id']
    split_dict = samples[split_id]

    test = data.loc[split_dict['test'], feats + [feat_trgt]]
    train = data.loc[split_dict['trains'][fold_id], feats + [feat_trgt]]
    validation = data.loc[split_dict['validations'][fold_id], feats + [feat_trgt]]

    # Model directory = checkpoint path without its extension, with forward slashes
    model_dir = str(pathlib.Path(df_sweeps.at[model_id, 'checkpoint']).parent).replace('\\', '/') + '/' + pathlib.Path(df_sweeps.at[model_id, 'checkpoint']).stem
    model = TabularModel.load_model(model_dir)
    # Keep a copy of the model files next to the analysis results
    pathlib.Path(f"{path_to_candidates}/{model_id}").mkdir(parents=True, exist_ok=True)
    shutil.copytree(model_dir, f"{path_to_candidates}/{model_id}", dirs_exist_ok=True)
    
    # Per-row table: target, partition label, prediction and error (prediction - target)
    df = data.loc[:, [feat_trgt]]
    df.loc[train.index, 'Group'] = 'Train'
    df.loc[validation.index, 'Group'] = 'Validation'
    df.loc[test.index, 'Group'] = 'Test'
    df['Prediction'] = model.predict(data)
    df['Error'] = df['Prediction'] - df[feat_trgt]
    # The corrector is fitted on Train rows only (target, prediction) and then
    # applied to the predictions of all rows.
    corrector = LinearBiasCorrector()
    corrector.fit(df.loc[df['Group'] == 'Train', feat_trgt].values, df.loc[df['Group'] == 'Train', 'Prediction'].values)
    df['Prediction Unbiased'] = corrector.predict(df['Prediction'].values)
    df['Error Unbiased'] = df['Prediction Unbiased'] - df[feat_trgt]
    df.to_excel(f"{path_to_candidates}/{model_id}/df.xlsx")
    
    # Plot color of every partition; the key order also fixes the group order in the plots
    colors_groups = {
        'Train': 'chartreuse',
        'Validation': 'dodgerblue',
        'Test': 'crimson',
    }
    
    # MAE, Pearson correlation and mean error per partition, for the raw and
    # for the bias-corrected predictions.
    df_metrics = pd.DataFrame(
        index=list(colors_groups.keys()),
        columns=[
            'mean_absolute_error', 'pearson_corrcoef', 'bias',
            'mean_absolute_error_unbiased', 'pearson_corrcoef_unbiased', 'bias_unbiased'
        ]
    )
    for group in colors_groups.keys():
        pred = torch.from_numpy(df.loc[df['Group'] == group, 'Prediction'].values)
        pred_unbiased = torch.from_numpy(df.loc[df['Group'] == group, 'Prediction Unbiased'].values)
        real = torch.from_numpy(df.loc[df['Group'] == group, feat_trgt].values.astype(np.float32))
        df_metrics.at[group, 'mean_absolute_error'] = mean_absolute_error(pred, real).numpy()
        df_metrics.at[group, 'pearson_corrcoef'] = pearson_corrcoef(pred, real).numpy()
        # "bias" is the mean signed error of the group
        df_metrics.at[group, 'bias'] = np.mean(df.loc[df['Group'] == group, 'Error'].values)
        df_metrics.at[group, 'mean_absolute_error_unbiased'] = mean_absolute_error(pred_unbiased, real).numpy()
        df_metrics.at[group, 'pearson_corrcoef_unbiased'] = pearson_corrcoef(pred_unbiased, real).numpy()
        df_metrics.at[group, 'bias_unbiased'] = np.mean(df.loc[df['Group'] == group, 'Error Unbiased'].values)
    df_metrics.to_excel(f"{path_to_candidates}/{model_id}/metrics.xlsx", index_label="Metrics")
    
    # Common value range of target and prediction (raw and bias-corrected);
    # the plots below pad it by 10% of its span on both sides.
    xy_min = df[[feat_trgt, 'Prediction']].min().min()
    xy_max = df[[feat_trgt, 'Prediction']].max().max()
    xy_ptp = xy_max - xy_min
    
    xy_min_unbiased = df[[feat_trgt, 'Prediction Unbiased']].min().min()
    xy_max_unbiased = df[[feat_trgt, 'Prediction Unbiased']].max().max()
    xy_ptp_unbiased = xy_max_unbiased - xy_min_unbiased
    
    # Prediction vs. target with one regression fit per partition
    sns.set_theme(style='whitegrid')
    fig, ax = plt.subplots(figsize=(4.5, 4))
    for group in colors_groups.keys():    
        regplot = sns.regplot(
            data=df.loc[df['Group'] == group, :],
            x=feat_trgt,
            y="Prediction",
            label=group,
            color=colors_groups[group],
            scatter_kws=dict(
                linewidth=0.2,
                alpha=0.75,
                edgecolor="k",
                s=20,
            ),
            ax=ax
        )
    # Dashed bisector y = x as the reference of a perfect prediction
    bisect = sns.lineplot(
        x=[xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp],
        y=[xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp],
        linestyle='--',
        color='black',
        linewidth=1.0,
        ax=ax
    )
    ax.set_title(f"{df_sweeps.at[model_id, 'model']} ({df_sweeps.at[model_id, '# Params']} params)")
    # Identical limits and equal aspect keep the bisector at 45 degrees
    ax.set_xlim(xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp)
    ax.set_ylim(xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp)
    plt.gca().set_aspect('equal', adjustable='box')
    fig.savefig(f"{path_to_candidates}/{model_id}/regplot.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_to_candidates}/{model_id}/regplot.pdf", bbox_inches='tight')
    plt.close(fig)
    
    # Bias-corrected prediction vs. target with one regression fit per partition
    sns.set_theme(style='whitegrid')
    fig, ax = plt.subplots(figsize=(4.5, 4))
    for group in colors_groups.keys():
        regplot = sns.regplot(
            data=df.loc[df['Group'] == group, :],
            x=feat_trgt,
            y="Prediction Unbiased",
            label=group,
            color=colors_groups[group],
            scatter_kws=dict(
                linewidth=0.2,
                alpha=0.75,
                edgecolor="k",
                s=20,
            ),
            ax=ax
        )
    # Axis range: the bias-corrected value range padded by 10% of its OWN span
    # on both sides (the lower bound must not use the span of the raw predictions).
    xy_lim_unbiased = (
        xy_min_unbiased - 0.1 * xy_ptp_unbiased,
        xy_max_unbiased + 0.1 * xy_ptp_unbiased,
    )
    # Dashed bisector y = x as the reference of a perfect prediction
    bisect = sns.lineplot(
        x=list(xy_lim_unbiased),
        y=list(xy_lim_unbiased),
        linestyle='--',
        color='black',
        linewidth=1.0,
        ax=ax
    )
    ax.set_title(f"{df_sweeps.at[model_id, 'model']} ({df_sweeps.at[model_id, '# Params']} params)")
    # Identical limits and equal aspect keep the bisector at 45 degrees
    ax.set_xlim(*xy_lim_unbiased)
    ax.set_ylim(*xy_lim_unbiased)
    plt.gca().set_aspect('equal', adjustable='box')
    fig.savefig(f"{path_to_candidates}/{model_id}/regplot_unbiased.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_to_candidates}/{model_id}/regplot_unbiased.pdf", bbox_inches='tight')
    plt.close(fig)
    
    # Prediction vs. target as a plain scatter colored by partition (no regression fits)
    sns.set_theme(style='whitegrid')
    fig, ax = plt.subplots(figsize=(4.5, 4))   
    scatter = sns.scatterplot(
        data=df,
        x=feat_trgt,
        y="Prediction",
        hue="Group",
        palette=colors_groups,
        linewidth=0.2,
        alpha=0.75,
        edgecolor="k",
        s=20,
        hue_order=list(colors_groups.keys()),
        ax=ax
    )
    # Dashed bisector y = x as the reference of a perfect prediction
    bisect = sns.lineplot(
        x=[xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp],
        y=[xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp],
        linestyle='--',
        color='black',
        linewidth=1.0,
        ax=ax
    )
    ax.set_title(f"{df_sweeps.at[model_id, 'model']} ({df_sweeps.at[model_id, '# Params']} params)")
    ax.set_xlim(xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp)
    ax.set_ylim(xy_min - 0.1 * xy_ptp, xy_max + 0.1 * xy_ptp)
    plt.gca().set_aspect('equal', adjustable='box')
    fig.savefig(f"{path_to_candidates}/{model_id}/scatter.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_to_candidates}/{model_id}/scatter.pdf", bbox_inches='tight')
    plt.close(fig)
    
    # Error distribution per partition; the x tick label of every group also
    # carries its MAE, Pearson correlation and mean error.
    # Explicit copy: the group labels are rewritten below and `df` must stay untouched.
    df_fig = df.loc[:, ['Error', 'Group']].copy()
    groups_rename = {
        group: (
            f"{group}" + "\n" +
            fr"MAE: {df_metrics.at[group, 'mean_absolute_error']:0.2f}" + "\n" +
            fr"Pearson $\rho$: {df_metrics.at[group, 'pearson_corrcoef']:0.2f}" + "\n" +
            fr"$\langle$Error$\rangle$: {df_metrics.at[group, 'bias']:0.2f}"
        )
        for group in colors_groups
    }
    colors_groups_violin = {groups_rename[group]: colors_groups[group] for group in colors_groups}
    # Assign the result back: a chained `df_fig['Group'].replace(..., inplace=True)`
    # does not update df_fig under pandas Copy-on-Write.
    df_fig['Group'] = df_fig['Group'].replace(groups_rename)
    sns.set_theme(style='whitegrid')
    fig, ax = plt.subplots(figsize=(7, 4))
    violin = sns.violinplot(
        data=df_fig,
        x='Group',
        y='Error',
        palette=colors_groups_violin,
        scale='width',
        order=list(colors_groups_violin.keys()),
        saturation=0.75,
        legend=False,
        ax=ax
    )
    ax.set_xlabel('')
    fig.savefig(f"{path_to_candidates}/{model_id}/violin.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_to_candidates}/{model_id}/violin.pdf", bbox_inches='tight')
    plt.close(fig)
    
    # Error distribution of the bias-corrected predictions per partition; the
    # x tick label of every group also carries the corresponding metrics.
    # Explicit copy: the group labels are rewritten below and `df` must stay untouched.
    df_fig = df.loc[:, ['Error Unbiased', 'Group']].copy()
    groups_rename = {
        group: (
            f"{group}" + "\n" +
            fr"MAE: {df_metrics.at[group, 'mean_absolute_error_unbiased']:0.2f}" + "\n" +
            fr"Pearson $\rho$: {df_metrics.at[group, 'pearson_corrcoef_unbiased']:0.2f}" + "\n" +
            fr"$\langle$Error$\rangle$: {df_metrics.at[group, 'bias_unbiased']:0.2f}"
        )
        for group in colors_groups
    }
    colors_groups_violin = {groups_rename[group]: colors_groups[group] for group in colors_groups}
    # Assign the result back: a chained `df_fig['Group'].replace(..., inplace=True)`
    # does not update df_fig under pandas Copy-on-Write.
    df_fig['Group'] = df_fig['Group'].replace(groups_rename)
    sns.set_theme(style='whitegrid')
    fig, ax = plt.subplots(figsize=(7, 4))
    violin = sns.violinplot(
        data=df_fig,
        x='Group',
        y='Error Unbiased',
        palette=colors_groups_violin,
        scale='width',
        order=list(colors_groups_violin.keys()),
        saturation=0.75,
        legend=False,
        ax=ax
    )
    ax.set_xlabel('')
    fig.savefig(f"{path_to_candidates}/{model_id}/violin_unbiased.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_to_candidates}/{model_id}/violin_unbiased.pdf", bbox_inches='tight')
    plt.close(fig)
    
    # Feature attributions for all rows; models for which explain() raises
    # NotImplementedError are silently skipped.
    try:
        explanation = model.explain(data, method=explain_method, baselines=explain_baselines)
        # Re-attach the sample IDs before saving
        explanation.index = data.index
        explanation.to_excel(f"{path_to_candidates}/{model_id}/explanation.xlsx")
        
        # Violin-type summary plot of the attributions (saved as explain_beeswarm.*)
        # NOTE(review): shap.summary_plot presumably returns None with show=False,
        # in which case plt.close(fig) closes the current figure - confirm.
        sns.set_theme(style='whitegrid')
        fig = shap.summary_plot(
            shap_values=explanation.loc[:, feats].values,
            features=data.loc[:, feats].values,
            feature_names=feats,
            max_display=explain_n_feats_to_plot,
            plot_type="violin",
            show=False,
        )
        plt.savefig(f"{path_to_candidates}/{model_id}/explain_beeswarm.png", bbox_inches='tight', dpi=200)
        plt.savefig(f"{path_to_candidates}/{model_id}/explain_beeswarm.pdf", bbox_inches='tight')
        plt.close(fig)
        
        # Bar-type summary plot of the same attributions
        sns.set_theme(style='ticks')
        fig = shap.summary_plot(
            shap_values=explanation.loc[:, feats].values,
            features=data.loc[:, feats].values,
            feature_names=feats,
            max_display=explain_n_feats_to_plot,
            plot_type="bar",
            show=False,
            plot_size=[12,8]
        )
        plt.savefig(f"{path_to_candidates}/{model_id}/explain_bar.png", bbox_inches='tight', dpi=200)
        plt.savefig(f"{path_to_candidates}/{model_id}/explain_bar.pdf", bbox_inches='tight')
        plt.close(fig)
    
    except NotImplementedError:
        pass