# Debugging autoreload

In [None]:
# IPython magics (not plain Python): load the `autoreload` extension and switch
# it to mode 2. NOTE(review): mode 2 presumably re-imports changed modules before
# each cell runs - confirm against the IPython autoreload documentation.
%load_ext autoreload
%autoreload 2

# Load packages

In [None]:
from pytorch_tabular.utils import load_covertype_dataset
from rich.pretty import pprint
from sklearn.model_selection import BaseCrossValidator, ParameterGrid, ParameterSampler
import torch
import pickle
from sklearn.preprocessing import LabelEncoder 
import shutil
import shap
from sklearn.model_selection import RepeatedStratifiedKFold
from glob import glob
import ast
import matplotlib.pyplot as plt
import seaborn as sns
import copy
from sklearn.model_selection import train_test_split
import numpy as np
from pytorch_tabular.utils import make_mixed_dataset, print_metrics
from pytorch_tabular import available_models
from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig, GANDALFConfig, TabNetModelConfig, FTTransformerConfig, DANetConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig
from pytorch_tabular.tabular_model_tuner import TabularModelTuner
from torchmetrics.functional.regression import mean_absolute_error, pearson_corrcoef
from pytorch_tabular import MODEL_SWEEP_PRESETS
import pandas as pd
from pytorch_tabular import model_sweep
from src.pt.model_sweep import model_sweep_custom
import warnings
from src.utils.configs import read_parse_config
from src.utils.hash import dict_hash
import pathlib
from tqdm import tqdm
import distinctipy
import matplotlib.patheffects as pe
import matplotlib.colors as mcolors
from statannotations.Annotator import Annotator
from scipy.stats import mannwhitneyu
from regression_bias_corrector import LinearBiasCorrector
import missingno as msno
from scipy import stats
from statsmodels.stats.multitest import multipletests
from sklearn.decomposition import PCA
from sklearn.random_projection import GaussianRandomProjection, SparseRandomProjection
from sklearn.manifold import MDS, Isomap, TSNE
from pytorch_tabular.utils import get_balanced_sampler
from sklearn.cluster import DBSCAN, HDBSCAN
from torchmetrics.functional.classification import (
    multiclass_accuracy,
    multiclass_f1_score,
    multiclass_precision,
    multiclass_recall,
    multiclass_specificity,
    multiclass_cohen_kappa,
    multiclass_auroc
)
from sklearn.metrics import confusion_matrix
from src.plot.radar import radar_factory


def make_rgb_transparent(rgb, bg_rgb, alpha):
    """Blend a colour into a background colour to imitate transparency.

    Every channel is computed as ``alpha * fg + (1 - alpha) * bg``. Channels
    are paired positionally, so the result is as long as the shorter input.

    Args:
        rgb: Foreground colour channels.
        bg_rgb: Background colour channels.
        alpha: Weight given to the foreground colour.

    Returns:
        list: The blended channel values.
    """
    blended = []
    for fg_channel, bg_channel in zip(rgb, bg_rgb):
        blended.append(alpha * fg_channel + (1 - alpha) * bg_channel)
    return blended

# Load data

In [None]:
# Feature-set identifier; it is also the sub-directory used for data and models.
feats_set = 'paper_sex_hormones/F/classification'
feats_title = ''

path = f"E:/YandexDisk/Work/bbd/millennium/{feats_set}"
path_models = f"E:/Git/bbs/notebooks/millennium/pt/{feats_set}"

# Outer (test) and inner (validation) repeated stratified K-fold settings.
tst_n_splits, tst_n_repeats, tst_random_state = 5, 5, 1337
val_n_splits, val_n_repeats, val_random_state = 4, 4, 1337

# Samples table and feature list; the first column of each sheet is the index.
data = pd.read_excel(f"{path}/data.xlsx", index_col=0)
df_feats = pd.read_excel(f"{path}/feats.xlsx", index_col=0)

# Target column plus an integer-encoded copy of it used for stratification.
feat_trgt = 'Menopause'
data[f'{feat_trgt} code'] = LabelEncoder().fit_transform(data[feat_trgt])

# Continuous features are the index of the features table; no categorical ones.
feats_cnt = df_feats.index.to_list()
feats_cat = []
feats = [*feats_cnt, *feats_cat]
feats_with_trgt = [feat_trgt, *feats]

## Check duplicated indexes

In [None]:
# Rows whose index label repeats an earlier one. The stratification cells
# assign split labels by index label, so this table is expected to be empty.
is_repeated_index = data.index.duplicated(keep='first')
data.loc[is_repeated_index, [feat_trgt]]

# Generate stratification

In [None]:
# Index labels identify the rows; the encoded target drives the stratification.
ids_all = data.index.values
classes_all = data[f'{feat_trgt} code'].values

k_fold_all = RepeatedStratifiedKFold(
    n_splits=tst_n_splits, n_repeats=tst_n_repeats, random_state=tst_random_state
)
splits_all = k_fold_all.split(X=ids_all, y=classes_all, groups=classes_all)

# One "Split_<k>" column per outer split, marking each row "trn_val" or "tst".
for split_id, (ids_trn_val, ids_tst) in enumerate(splits_all):
    split_col = f"Split_{split_id}"
    # The splitter yields positions; translate them to index labels first.
    for positions, label in ((ids_trn_val, "trn_val"), (ids_tst, "tst")):
        data.loc[ids_all[positions], split_col] = label

In [None]:
# Nested stratification.
#
# For every outer split the "trn_val" rows are split again into train and
# validation folds. The result is stored twice:
#   * as "Split_<k>_Fold_<j>" columns of `data` ("trn" / "val" / "tst"), and
#   * in `samples[k]`, a dict with the index labels of the test rows, the
#     train+validation rows, the per-fold train / validation rows, and
#     `cv_indexes` - per-fold (train, validation) positions relative to the
#     train+validation subset of that split.
n_outer_splits = tst_n_splits * tst_n_repeats
n_inner_folds = val_n_splits * val_n_repeats

samples = {}
for split_id in range(n_outer_splits):
    split_col = f"Split_{split_id}"
    fold_cols = [f"{split_col}_Fold_{fold_id}" for fold_id in range(n_inner_folds)]

    # Seed every fold column with the outer labels so that test rows keep "tst".
    for fold_col in fold_cols:
        data[fold_col] = data[split_col]

    # The outer-split column is not modified below, so this mask and the
    # derived index are computed once per split instead of once per fold.
    is_trn_val = data[split_col] == "trn_val"
    index_trn_val = data.index[is_trn_val]
    ids_trnval = index_trn_val.values

    samples[split_id] = {
        'test': data.index[data[split_col] == "tst"].values,
        'train_validation': data.index[is_trn_val].values,
        'trains': {},
        'validations': {},
    }

    classes_trnval = data.loc[ids_trnval, f'{feat_trgt} code'].values
    k_fold_trnval = RepeatedStratifiedKFold(
        n_splits=val_n_splits,
        n_repeats=val_n_repeats,
        random_state=val_random_state
    )
    splits_trnval = k_fold_trnval.split(X=ids_trnval, y=classes_trnval, groups=classes_trnval)
    for fold_id, (ids_trn, ids_val) in enumerate(splits_trnval):
        # The splitter yields positions inside `ids_trnval`; map them to labels.
        data.loc[ids_trnval[ids_trn], fold_cols[fold_id]] = "trn"
        data.loc[ids_trnval[ids_val], fold_cols[fold_id]] = "val"

    cv_indexes = []
    for fold_id, fold_col in enumerate(fold_cols):
        fold_labels = data[fold_col]
        samples[split_id]['trains'][fold_id] = data.index[fold_labels == "trn"].values
        samples[split_id]['validations'][fold_id] = data.index[fold_labels == "val"].values
        # Positions of the fold's train / validation rows within the
        # train+validation subset of this split.
        cv_indexes.append(
            (
                np.where(index_trn_val.isin(data.index[is_trn_val & (fold_labels == 'trn')]))[0],
                np.where(index_trn_val.isin(data.index[is_trn_val & (fold_labels == 'val')]))[0],
            )
        )
    samples[split_id]['cv_indexes'] = cv_indexes

# Leak check: train, validation and test must be pairwise disjoint in every fold.
for split_id in range(n_outer_splits):
    test_samples = samples[split_id]['test']
    test_set = set(test_samples)
    for fold_id in range(n_inner_folds):
        train_samples = samples[split_id]['trains'][fold_id]
        validation_samples = samples[split_id]['validations'][fold_id]
        train_set = set(train_samples)
        validation_set = set(validation_samples)

        intxns = {
            'train_validation': train_set & validation_set,
            'validation_test': validation_set & test_set,
            'train_test': train_set & test_set
        }

        for intxn_name, intxn_samples in intxns.items():
            if intxn_samples:
                print(f"Non-zero {intxn_name} intersection ({len(intxn_samples)}) for {split_id} Split and {fold_id} Fold!")

# Persist the partition; the file name encodes both K-fold configurations.
samples_fn = (
    f"{path}/samples_tst({tst_random_state}_{tst_n_splits}_{tst_n_repeats})"
    f"_val({val_random_state}_{val_n_splits}_{val_n_repeats}).pickle"
)
with open(samples_fn, 'wb') as handle:
    pickle.dump(samples, handle, protocol=pickle.HIGHEST_PROTOCOL)

# Load stratification

In [None]:
# Restore the partition written by the stratification cell; the file name
# encodes both K-fold configurations.
# NOTE: pickle can execute code while loading - only open files you produced.
samples_fn = (
    f"{path}/samples_tst({tst_random_state}_{tst_n_splits}_{tst_n_repeats})"
    f"_val({val_random_state}_{val_n_splits}_{val_n_repeats}).pickle"
)
with open(samples_fn, 'rb') as handle:
    samples = pickle.load(handle)

# Leak check: train, validation and test must be pairwise disjoint in every fold.
for split_id in range(tst_n_splits * tst_n_repeats):
    test_samples = samples[split_id]['test']
    for fold_id in range(val_n_splits * val_n_repeats):
        train_samples = samples[split_id]['trains'][fold_id]
        validation_samples = samples[split_id]['validations'][fold_id]

        intxns = {
            'train_validation': set(train_samples) & set(validation_samples),
            'validation_test': set(validation_samples) & set(test_samples),
            'train_test': set(train_samples) & set(test_samples),
        }

        for intxn_name, intxn_samples in intxns.items():
            if intxn_samples:
                print(f"Non-zero {intxn_name} intersection ({len(intxn_samples)}) for {split_id} Split and {fold_id} Fold!")

# PyTorch Tabular Model Sweep Training

## Load non-model configs

In [None]:
# Directory holding the YAML configuration files for the classification task.
path_configs = "E:/Git/bbs/notebooks/millennium/configs/classification"

# `read_parse_config` is a project helper; its result is used below as a mutable
# mapping (presumably a plain dict of constructor kwargs - TODO confirm).
data_config = read_parse_config(f"{path_configs}/DataConfig.yaml", DataConfig)
# Override the dataset-specific fields with the target and feature lists
# defined in the data-loading cell.
data_config['target'] = [feat_trgt]
data_config['continuous_cols'] = feats_cnt
data_config['categorical_cols'] = feats_cat
trainer_config = read_parse_config(f"{path_configs}/TrainerConfig.yaml", TrainerConfig)
# Make sure the checkpoint directory exists before pointing the trainer at it.
pathlib.Path(path_models).mkdir(parents=True, exist_ok=True)
trainer_config['checkpoints_path'] = path_models
optimizer_config = read_parse_config(f"{path_configs}/OptimizerConfig.yaml", OptimizerConfig)

# Learning-rate finder settings; the sweep cell forwards them to
# `model_sweep_custom` as min_lr, max_lr, num_training, mode and
# early_stop_threshold.
lr_find_min_lr = 1e-8
lr_find_max_lr = 1
lr_find_num_training = 128
lr_find_mode = "exponential"
lr_find_early_stop_threshold = 4.0

## Models Search Spaces

### GANDALF Search Space

In [None]:
# Hyper-parameter grid for GANDALF. Every key currently lists a single
# candidate, so the grid below contains exactly one configuration.
search_space = {
    "model_config__gflu_stages": [6],
    "model_config__gflu_dropout": [0.1],
    "model_config__gflu_feature_init_sparsity": [0.3],
    "model_config.head_config__dropout": [0.1],
    "model_config__learning_rate": [0.001],
    "model_config__seed": [1899],
}
# Number of grid points = product of the number of candidates per parameter.
grid_size = np.prod([len(p_vals) for p_vals in search_space.values()])
print(grid_size)

# Template for the linear head, kept as a plain dict so it can be copied and
# tweaked for every grid point.
head_config = LinearHeadConfig(
    layers="",
    activation='ReLU',
    dropout=0.1,
    use_batch_norm=False,
    initialization="kaiming"
).__dict__

# One GANDALFConfig per grid point: start from the YAML defaults, then
# overwrite the swept fields.
model_list = []
for params in ParameterGrid(search_space):
    head_config_tmp = copy.deepcopy(head_config)
    head_config_tmp['dropout'] = params['model_config.head_config__dropout']

    model_config = read_parse_config(f"{path_configs}/models/GANDALFConfig.yaml", GANDALFConfig)
    for field_name in (
        'gflu_stages',
        'gflu_feature_init_sparsity',
        'gflu_dropout',
        'learning_rate',
        'seed',
    ):
        model_config[field_name] = params[f'model_config__{field_name}']
    model_config['head_config'] = head_config_tmp

    model_list.append(GANDALFConfig(**model_config))

## Perform model sweep

In [None]:
# Train every configuration in `model_list` on every (split, fold) pair and
# accumulate the per-run metrics into one table that is re-saved after each
# run, so partial results survive an interrupted sweep.

# Output directory for checkpoints and sweep tables (same location that the
# trainer config uses as `checkpoints_path`).
pathlib.Path(path_models).mkdir(parents=True, exist_ok=True)

common_params = {
    "task": "classification",
}

seed = 1899

# Metrics evaluated by the sweep and their keyword arguments. `cohen_kappa`
# is the only metric configured without an `average` argument.
sweep_metrics = [
    "accuracy",
    "f1_score",
    "precision",
    "recall",
    "specificity",
    "cohen_kappa",
    "auroc"
]
sweep_metrics_params = [
    {'task': 'multiclass', 'num_classes': 2}
    if metric_name == "cohen_kappa"
    else {'task': 'multiclass', 'num_classes': 2, 'average': 'macro'}
    for metric_name in sweep_metrics
]
sweep_metrics_prob_input = [True] * len(sweep_metrics)

# The results file name depends only on the sweep configuration.
fn_suffix = (f"models({len(model_list)})_"
             f"tst({tst_random_state}_{tst_n_splits}_{tst_n_repeats})_"
             f"val({val_random_state}_{val_n_splits}_{val_n_repeats})")

cols_model = feats + [feat_trgt]

dfs_result = []
for split_id, split_dict in samples.items():
    for fold_id in split_dict['trains']:
        test = data.loc[split_dict['test'], cols_model]
        train = data.loc[split_dict['trains'][fold_id], cols_model]
        validation = data.loc[split_dict['validations'][fold_id], cols_model]

        # Sampler built from the training targets to balance the classes.
        train_sampler_balanced = get_balanced_sampler(train[feat_trgt].values.ravel())

        # Re-applied on every run, as in the original cell, in case the sweep
        # helper modifies the trainer config.
        trainer_config['seed'] = seed
        trainer_config['checkpoints'] = 'valid_loss'
        trainer_config['load_best'] = True
        trainer_config['auto_lr_find'] = True

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            sweep_df, best_model = model_sweep_custom(
                task="classification",
                train=train,
                validation=validation,
                test=test,
                data_config=data_config,
                optimizer_config=optimizer_config,
                trainer_config=trainer_config,
                model_list=model_list,
                common_model_args=common_params,
                # Fresh copies per call, so the helper never sees state left
                # over from a previous run.
                metrics=list(sweep_metrics),
                metrics_params=copy.deepcopy(sweep_metrics_params),
                metrics_prob_input=list(sweep_metrics_prob_input),
                rank_metric=("accuracy", "higher_is_better"),
                return_best_model=True,
                seed=seed,
                progress_bar=False,
                verbose=False,
                suppress_lightning_logger=True,
                min_lr=lr_find_min_lr,
                max_lr=lr_find_max_lr,
                num_training=lr_find_num_training,
                mode=lr_find_mode,
                early_stop_threshold=lr_find_early_stop_threshold,
                train_sampler=train_sampler_balanced
            )
        sweep_df['seed'] = seed
        sweep_df['split_id'] = split_id
        sweep_df['fold_id'] = fold_id
        # Flag runs whose train loss is still above the test or validation loss.
        sweep_df["train_more"] = (
            (sweep_df["train_loss"] > sweep_df["test_loss"])
            | (sweep_df["train_loss"] > sweep_df["validation_loss"])
        )
        sweep_df["validation_test_mean_loss"] = (sweep_df["validation_loss"] + sweep_df["test_loss"]) / 2.0
        sweep_df["train_validation_test_mean_loss"] = (sweep_df["train_loss"] + sweep_df["validation_loss"] + sweep_df["test_loss"]) / 3.0

        dfs_result.append(sweep_df)

        # Re-save the accumulated table after every run. The write stays
        # best-effort (the next run retries with a superset of the rows), but
        # a failure is reported instead of being swallowed silently.
        df_result = pd.concat(dfs_result, ignore_index=True)
        fn_result = f"{trainer_config['checkpoints_path']}/{fn_suffix}.xlsx"
        try:
            df_result.style.background_gradient(cmap="RdYlGn_r").to_excel(fn_result)
        except PermissionError as err:
            warnings.warn(
                f"Could not write sweep results to {fn_result} "
                f"(split {split_id}, fold {fold_id}): {err}"
            )

## Best models analysis

In [None]:
# Class labels in the order of their integer ids.
class_names = ["No", "Yes"]

# Number of model configurations in the sweep whose results are analysed; it
# is part of the sweep file name.
n_models = 1

# Settings for the model explanation step.
explain_method = "GradientShap"
explain_baselines = "b|1000"
explain_n_feats_to_plot = 25

# Name of the sweep results file (without extension).
fn_sweep = "_".join(
    [
        f"models({n_models})",
        f"tst({tst_random_state}_{tst_n_splits}_{tst_n_repeats})",
        f"val({val_random_state}_{val_n_splits}_{val_n_repeats})",
    ]
)

# Load the sweep table and store a colour-graded copy next to the candidates.
df_sweeps = pd.read_excel(f"{path_models}/{fn_sweep}.xlsx", index_col=0)
path_to_candidates = f"{path_models}/candidates/{fn_sweep}"
pathlib.Path(path_to_candidates).mkdir(parents=True, exist_ok=True)
df_sweeps.style.background_gradient(cmap="RdYlGn_r").to_excel(f"{path_to_candidates}/sweep.xlsx")

# Row labels of `df_sweeps` for the runs selected for detailed analysis.
models_ids = [6, 359, 39, 376]

df_selected = df_sweeps.loc[models_ids, :]
df_selected.style.background_gradient(cmap="RdYlGn_r").to_excel(f"{path_to_candidates}/selected.xlsx")

# Detailed analysis of every selected run: copy its checkpoint, store the
# per-sample predictions, compute classification metrics per data group, and
# draw confusion matrices, radar charts and explanation plots.

# Integer id of every class, in the order of `class_names`.
class_ids = {name: idx for idx, name in enumerate(class_names)}
n_classes = len(class_names)
cols_model = feats + [feat_trgt]

# Data groups and the colour used for each of them in the radar charts.
colors_groups = {
    'Train': 'chartreuse',
    'Validation': 'dodgerblue',
    'Test': 'crimson',
}

# Metrics computed once per averaging type, and metrics without averaging.
avg_types = ['macro', 'weighted']
metrics_w_avg = [
    "accuracy",
    "f1_score",
    "precision",
    "recall",
    "specificity",
    "auroc"
]
metrics_wo_avg = [
    "cohen_kappa"
]
metrics_names = {
    "accuracy": "Accuracy",
    "f1_score": "F-1 Score",
    "precision": "Precision",
    "recall": "Recall",
    "specificity": "Specificity",
    "auroc": "AUROC",
    "cohen_kappa": "Cohen Kappa"
}
averaged_metric_fns = {
    "accuracy": multiclass_accuracy,
    "f1_score": multiclass_f1_score,
    "precision": multiclass_precision,
    "recall": multiclass_recall,
    "specificity": multiclass_specificity,
    "auroc": multiclass_auroc,
}

for model_id in models_ids:

    split_id = df_sweeps.at[model_id, 'split_id']
    fold_id = df_sweeps.at[model_id, 'fold_id']
    split_dict = samples[split_id]

    test = data.loc[split_dict['test'], cols_model]
    train = data.loc[split_dict['trains'][fold_id], cols_model]
    validation = data.loc[split_dict['validations'][fold_id], cols_model]

    # The saved model lives in a directory named like the checkpoint file
    # without its extension; normalise Windows separators to '/'.
    checkpoint = pathlib.Path(df_sweeps.at[model_id, 'checkpoint'])
    model_dir = str(checkpoint.parent).replace('\\', '/') + '/' + checkpoint.stem
    model = TabularModel.load_model(model_dir)
    path_to_model = f"{path_to_candidates}/{model_id}"
    pathlib.Path(path_to_model).mkdir(parents=True, exist_ok=True)
    shutil.copytree(model_dir, path_to_model, dirs_exist_ok=True)

    # Per-sample table: target, data group, predictions and logits.
    df = data.loc[:, [feat_trgt]]
    df.loc[train.index, 'Group'] = 'Train'
    df.loc[validation.index, 'Group'] = 'Validation'
    df.loc[test.index, 'Group'] = 'Test'
    logits_cols = [f'logits_{idx}' for idx in range(n_classes)]
    df = pd.concat(
        [
            df,
            model.predict(data),
            model.predict(data, ret_logits=True).loc[:, logits_cols]
        ],
        axis=1
    )
    renamed_cols = {f'{feat_trgt}_prediction': 'Prediction'}
    renamed_cols.update({f'logits_{idx}': f'{name}_logits' for name, idx in class_ids.items()})
    df = df.rename(columns=renamed_cols)
    # Plain assignment instead of `df[col].replace(..., inplace=True)`: the
    # chained in-place form does not update `df` under pandas Copy-on-Write.
    df['Prediction ID'] = df['Prediction'].map(class_ids)
    df['Real ID'] = df[feat_trgt].map(class_ids)
    df.to_excel(f"{path_to_model}/df.xlsx")

    df_metrics_index = (
        [f"{m}_{avg_type}" for avg_type in avg_types for m in metrics_w_avg]
        + metrics_wo_avg
    )
    df_metrics = pd.DataFrame(
        index=df_metrics_index,
        columns=list(colors_groups),
        data=np.zeros((len(df_metrics_index), len(colors_groups))),
    )
    prob_cols = [f"{feat_trgt}_{name}_probability" for name in class_names]
    for group in colors_groups:
        in_group = df['Group'] == group
        # The explicit integer cast fails loudly on labels missing from `class_ids`.
        pred = torch.from_numpy(df.loc[in_group, 'Prediction ID'].astype(np.int64).values)
        real = torch.from_numpy(df.loc[in_group, 'Real ID'].astype(np.int64).values)
        probs = torch.from_numpy(df.loc[in_group, prob_cols].values)
        for avg_type in avg_types:
            for metric in metrics_w_avg:
                # AUROC is computed from class probabilities, the rest from labels.
                metric_input = probs if metric == "auroc" else pred
                df_metrics.at[f"{metric}_{avg_type}", group] = averaged_metric_fns[metric](
                    preds=metric_input, target=real, num_classes=n_classes, average=avg_type
                ).item()
        df_metrics.at["cohen_kappa", group] = multiclass_cohen_kappa(
            preds=pred, target=real, num_classes=n_classes
        ).item()

        # Explicit `labels` keeps the matrix n_classes x n_classes even when a
        # class is absent from the group.
        conf_mtx = confusion_matrix(real.numpy(), pred.numpy(), labels=list(range(n_classes)))
        cm_sum = np.sum(conf_mtx, axis=1, keepdims=True)
        # Row-wise percentages; rows without samples stay at 0 instead of NaN.
        cm_perc = np.divide(
            conf_mtx,
            cm_sum.astype(float),
            out=np.zeros(conf_mtx.shape, dtype=float),
            where=cm_sum > 0,
        ) * 100
        nrows, ncols = conf_mtx.shape
        annot_rows = []
        for i in range(nrows):
            annot_row = []
            for j in range(ncols):
                c = conf_mtx[i, j]
                p = cm_perc[i, j]
                if i == j:
                    # Diagonal cells also show the row total.
                    s = cm_sum[i, 0]
                    annot_row.append('%.1f%%\n%d/%d' % (p, c, s))
                elif c == 0:
                    annot_row.append('')
                else:
                    annot_row.append('%.1f%%\n%d' % (p, c))
            annot_rows.append(annot_row)
        # Building the array from the finished strings sizes it to the longest
        # annotation, so nothing gets truncated.
        annot = np.array(annot_rows)
        conf_mtx = pd.DataFrame(conf_mtx, index=class_names, columns=class_names)
        conf_mtx.index.name = 'Actual'
        conf_mtx.columns.name = 'Predicted'
        fig, ax = plt.subplots(figsize=(1.5*len(class_names), 0.8*len(class_names)))
        heatmap = sns.heatmap(conf_mtx, annot=annot, fmt='', ax=ax)
        heatmap.set_aspect('equal', adjustable='box')
        fig.savefig(f"{path_to_model}/confusion_matrix_{group}.png", bbox_inches='tight', dpi=200)
        fig.savefig(f"{path_to_model}/confusion_matrix_{group}.pdf", bbox_inches='tight')
        plt.close(fig)

    df_metrics.to_excel(f"{path_to_model}/metrics.xlsx", index_label="Metrics")

    # One radar chart per averaging type, with one polygon per data group.
    for avg_type in avg_types:
        n_categories = len(metrics_w_avg) + len(metrics_wo_avg)
        theta = radar_factory(n_categories, frame='polygon')

        case_data = df_metrics.loc[[f"{m}_{avg_type}" for m in metrics_w_avg] + metrics_wo_avg, list(colors_groups)].T.values

        fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(projection='radar'))
        ax.set_rgrids(list(np.linspace(0, 1, 21)))
        for d, group in zip(case_data, colors_groups):
            ax.plot(theta, d, color=colors_groups[group])
            ax.fill(theta, d, facecolor=colors_groups[group], alpha=0.25, label='_nolegend_')
        ax.set_varlabels([metrics_names[m_name] for m_name in metrics_w_avg + metrics_wo_avg])
        labels = list(colors_groups)
        legend = ax.legend(labels, loc=(0.9, .95), labelspacing=0.1, fontsize='small')
        fig.savefig(f"{path_to_model}/metrics_{avg_type}.png", bbox_inches='tight', dpi=200)
        fig.savefig(f"{path_to_model}/metrics_{avg_type}.pdf", bbox_inches='tight')
        plt.close(fig)

    # Explanation plots stay best-effort: a model or method that does not
    # support explanations is reported and skipped.
    try:
        explanation = model.explain(data, method=explain_method, baselines=explain_baselines)

        for plot_type, fn_stem in (
            ("violin", "explain_logits_beeswarm"),
            ("bar", "explain_logits_bar"),
        ):
            # NOTE(review): this theme call sits after the confusion-matrix and
            # radar plots, so it presumably also restyles those plots for the
            # following models - confirm whether that is intended.
            sns.set_theme(style='whitegrid')
            shap.summary_plot(
                shap_values=explanation.loc[:, feats].values,
                features=data.loc[:, feats].values,
                feature_names=feats,
                max_display=len(feats),
                plot_type=plot_type,
                show=False,
            )
            plt.savefig(f"{path_to_model}/{fn_stem}.png", bbox_inches='tight', dpi=200)
            plt.savefig(f"{path_to_model}/{fn_stem}.pdf", bbox_inches='tight')
            # Close the current figure, i.e. the one the summary plot drew on.
            plt.close()

    except NotImplementedError as err:
        print(f"Skipping explanation plots for model {model_id}: {err}")