# Enable autoreload (for debugging)

In [None]:
# Reload edited modules (e.g. the local `src` package imported below) before
# executing each cell, so changes to helper code apply without a kernel restart.
%load_ext autoreload
%autoreload 2

# Load packages

In [None]:
from pytorch_tabular.utils import load_covertype_dataset
from rich.pretty import pprint
from sklearn.model_selection import BaseCrossValidator, ParameterGrid, ParameterSampler
import torch
import pickle
import shap
from sklearn.model_selection import RepeatedStratifiedKFold
from glob import glob
import ast
import matplotlib.pyplot as plt
import seaborn as sns
import copy
from sklearn.model_selection import train_test_split
import numpy as np
from pytorch_tabular.utils import make_mixed_dataset, print_metrics
from pytorch_tabular import available_models
from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig, GANDALFConfig, TabNetModelConfig, FTTransformerConfig, DANetConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig
from pytorch_tabular.tabular_model_tuner import TabularModelTuner
from torchmetrics.functional.regression import mean_absolute_error, pearson_corrcoef
from pytorch_tabular import MODEL_SWEEP_PRESETS
import pandas as pd
from pytorch_tabular import model_sweep
from src.pt.model_sweep import model_sweep_custom
import warnings
from src.utils.configs import read_parse_config
from src.utils.hash import dict_hash
import pathlib


# Load data and build cross-validation splits

In [None]:
epi_dataa_type = 'no_harm'

selection_method = 'f_regression' # 'f_regression' 'spearman'
n_feats = 100

imm = 'CXCL9'

# Outer CV (test folds): repeated stratified K-fold, run separately per cohort.
tst_n_splits = 5
tst_n_repeats = 5
tst_random_state = 1337

# Inner CV (validation folds): repeated stratified K-fold inside each outer
# train/validation part, again run separately per cohort.
val_n_splits = 4
val_n_repeats = 2
val_random_state = 1337

# Number of equal-width Age bins used as stratification labels (outer and inner CV).
n_age_bins = 5

path_data = f"D:/YandexDisk/Work/bbd/immunology/003_EpImAge/{epi_dataa_type}/{selection_method}_{n_feats}/{imm}"
pathlib.Path(f"{path_data}/pytorch_tabular").mkdir(parents=True, exist_ok=True)
path_configs = "D:/Work/bbs/notebooks/immunology/003_EpImAge/immuno_regression_configs"
data = pd.read_excel(f"{path_data}/data.xlsx", index_col=0)
feats = pd.read_excel(f"{path_data}/feats_con.xlsx", index_col=0).index.values.tolist()


def _age_bins(ages, num_bins):
    """Return integer stratification labels for ``ages`` via equal-width binning.

    The binning range is the observed age range padded by 10% of its span on
    both sides; labels are ``np.digitize(ages, edges) - 1``, i.e. in
    ``0..num_bins - 1`` whenever the ages are not all identical.
    """
    ages = np.asarray(ages)
    span = np.ptp(ages)
    edges = np.linspace(np.min(ages) - 0.1 * span, np.max(ages) + 0.1 * span, num_bins + 1)
    return np.digitize(ages, edges) - 1


def _cohort_ids(mask=None):
    """Return sample ids of the three stratification cohorts of ``data``.

    ``mask`` (boolean Series aligned with ``data.index``) optionally restricts
    the rows considered; by default all rows are used.
    """
    if mask is None:
        mask = pd.Series(True, index=data.index)
    is_control = data['Status'] == 'Control'
    return {
        'ctrl_central': data.index[is_control & (data['Region'] == 'Central') & mask].values,
        'ctrl_yakutia': data.index[is_control & (data['Region'] == 'Yakutia') & mask].values,
        'esrd': data.index[(data['Status'] == 'ESRD') & mask].values,
    }


def _stratified_splits(ids, n_splits, n_repeats, random_state):
    """Yield ``(split_number, train_ids, holdout_ids)`` for Age-stratified repeated K-fold over ``ids``."""
    labels = _age_bins(data.loc[ids, 'Age'].values, n_age_bins)
    k_fold = RepeatedStratifiedKFold(
        n_splits=n_splits,
        n_repeats=n_repeats,
        random_state=random_state
    )
    # StratifiedKFold ignores ``groups``, so only X and y are passed.
    for split_number, (trn_pos, holdout_pos) in enumerate(k_fold.split(X=ids, y=labels)):
        yield split_number, ids[trn_pos], ids[holdout_pos]


# Outer splits: every cohort is split with the same seed and the split numbers
# are shared, so column Split_<k> combines the k-th split of every cohort.
stratify_cat_parts_all = _cohort_ids()
for ids_all in stratify_cat_parts_all.values():
    for split_id, ids_trn_val, ids_tst in _stratified_splits(ids_all, tst_n_splits, tst_n_repeats, tst_random_state):
        data.loc[ids_trn_val, f"Split_{split_id}"] = "trn_val"
        data.loc[ids_tst, f"Split_{split_id}"] = "tst"

samples = {}
for split_id in range(tst_n_splits * tst_n_repeats):
    split_col = f"Split_{split_id}"
    is_trn_val = data[split_col] == "trn_val"
    samples[split_id] = {
        'test': data.index[data[split_col] == "tst"].values,
        'train_validation': data.index[is_trn_val].values,
        'trains': {},
        'validations': {},
    }

    # Inner folds: split each cohort's train/validation part of this outer split.
    stratify_cat_parts_trnval = _cohort_ids(is_trn_val)
    for ids_trnval in stratify_cat_parts_trnval.values():
        for fold_id, ids_trn, ids_val in _stratified_splits(ids_trnval, val_n_splits, val_n_repeats, val_random_state):
            data.loc[ids_trn, f"{split_col}_Fold_{fold_id}"] = "trn"
            data.loc[ids_val, f"{split_col}_Fold_{fold_id}"] = "val"

    # cv_indexes holds positions (not labels) of each fold's train/validation
    # samples within samples[split_id]['train_validation']. Fold columns are only
    # filled for this split's trn_val rows, so no extra trn_val filter is needed.
    trn_val_index = data.index[is_trn_val]
    cv_indexes = []
    for fold_id in range(val_n_splits * val_n_repeats):
        fold_col = f"{split_col}_Fold_{fold_id}"
        trn_ids = data.index[data[fold_col] == "trn"].values
        val_ids = data.index[data[fold_col] == "val"].values
        samples[split_id]['trains'][fold_id] = trn_ids
        samples[split_id]['validations'][fold_id] = val_ids
        cv_indexes.append((
            np.where(trn_val_index.isin(trn_ids))[0],
            np.where(trn_val_index.isin(val_ids))[0],
        ))
    samples[split_id]['cv_indexes'] = cv_indexes

samples_fn = (f"{path_data}/samples_tst({tst_random_state}_{tst_n_splits}_{tst_n_repeats})"
              f"_val({val_random_state}_{val_n_splits}_{val_n_repeats}).pickle")
with open(samples_fn, 'wb') as handle:
    pickle.dump(samples, handle, protocol=pickle.HIGHEST_PROTOCOL)


## Check that train, validation and test samples do not intersect

In [None]:

# Sanity check: within every outer split / inner fold the train, validation
# and test sample sets must be pairwise disjoint; any overlap is reported.
n_outer_splits = tst_n_splits * tst_n_repeats
n_inner_folds = val_n_splits * val_n_repeats
for split_id in range(n_outer_splits):
    test_set = set(samples[split_id]['test'])
    for fold_id in range(n_inner_folds):
        train_set = set(samples[split_id]['trains'][fold_id])
        validation_set = set(samples[split_id]['validations'][fold_id])
        overlaps = {
            'train_validation': train_set & validation_set,
            'validation_test': validation_set & test_set,
            'train_test': train_set & test_set,
        }
        for overlap_name, overlap in overlaps.items():
            if overlap:
                print(f"Non-zero {overlap_name} intersection ({len(overlap)}) for {split_id} Split and {fold_id} Fold!")

## Load non-model configs

In [None]:
# Learning-rate finder settings (passed to model_sweep_custom during the sweep).
lr_find_min_lr = 1e-8
lr_find_max_lr = 10
lr_find_num_training = 256
lr_find_mode = "exponential"
lr_find_early_stop_threshold = 8.0

# Data config: target is the "<imm>_log" column, inputs are the selected features.
data_config = read_parse_config(f"{path_configs}/DataConfig.yaml", DataConfig)
data_config['target'] = [f"{imm}_log"]
data_config['continuous_cols'] = feats

# Trainer config: checkpoints are written next to the data of this experiment.
trainer_config = read_parse_config(f"{path_configs}/TrainerConfig.yaml", TrainerConfig)
trainer_config['checkpoints_path'] = f"{path_data}/pytorch_tabular"

optimizer_config = read_parse_config(f"{path_configs}/OptimizerConfig.yaml", OptimizerConfig)

# Models Search Spaces

## GANDALF Search Space

In [None]:
# GANDALF hyper-parameter grid. Keys prefixed with "model_config__" are copied
# onto the GANDALF model config fields of the same name; the
# "model_config.head_config__dropout" key is routed into the linear head config.
search_space = {
    "model_config__gflu_stages": [3, 6, 9],
    "model_config__gflu_dropout": [0.0, 0.1],
    "model_config__gflu_feature_init_sparsity": [0.2, 0.3, 0.4],
    "model_config.head_config__dropout": [0.0, 0.1],
    "model_config__learning_rate": [0.001],
    "model_config__seed": [1337],
}
grid_size = np.prod([len(p_vals) for p_vals in search_space.values()])
print(grid_size)

head_config = LinearHeadConfig(
    layers="",
    activation='ReLU',
    dropout=0.1,  # overridden per grid point below
    use_batch_norm=False,
    initialization="kaiming"
).__dict__

# The YAML base config is identical for every grid point: parse it once and
# deep-copy it per combination instead of re-reading the file inside the loop.
model_config_base = read_parse_config(f"{path_configs}/models/GANDALFConfig.yaml", GANDALFConfig)
model_config_prefix = "model_config__"

model_list = []
for params in ParameterGrid(search_space):
    # Deep copies keep grid points independent (head_config is the live __dict__).
    head_config_tmp = copy.deepcopy(head_config)
    head_config_tmp['dropout'] = params['model_config.head_config__dropout']
    model_config = copy.deepcopy(model_config_base)
    for key, value in params.items():
        if key.startswith(model_config_prefix):
            model_config[key[len(model_config_prefix):]] = value
    model_config['head_config'] = head_config_tmp
    model_list.append(GANDALFConfig(**model_config))
    

# Model Sweep Training

## Perform model sweep

In [None]:
%%capture

common_params = {
    "task": "regression",
}

seeds = [1337] # [1337, 55763, 40279, 87571, 234461]

# Columns fed to the sweep (features + log target) and columns coloured in the Excel report.
sweep_cols = feats + [f"{imm}_log"]
gradient_cols = ["train_loss", "validation_loss", "test_loss", "time_taken", "time_taken_per_epoch"]
split_tag = (f"tst({tst_random_state}_{tst_n_splits}_{tst_n_repeats})"
             f"_val({val_random_state}_{val_n_splits}_{val_n_repeats})")

dfs_result = []
for seed in seeds:
    for split_id, split_dict in samples.items():
        # The test part depends only on the outer split, not on the inner fold.
        test = data.loc[split_dict['test'], sweep_cols]
        for fold_id in split_dict['trains']:
            train = data.loc[split_dict['trains'][fold_id], sweep_cols]
            validation = data.loc[split_dict['validations'][fold_id], sweep_cols]

            # Trainer settings are re-applied before every sweep call.
            for key, value in (
                ('seed', seed),
                ('checkpoints', 'valid_loss'),
                ('load_best', True),
                ('auto_lr_find', True),
            ):
                trainer_config[key] = value

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                sweep_df, best_model = model_sweep_custom(
                    task="regression",
                    train=train,
                    validation=validation,
                    test=test,
                    data_config=data_config,
                    optimizer_config=optimizer_config,
                    trainer_config=trainer_config,
                    model_list=model_list,
                    common_model_args=common_params,
                    metrics=["mean_absolute_error", "pearson_corrcoef"],
                    metrics_params=[{}, {}],
                    metrics_prob_input=[False, False],
                    rank_metric=("mean_absolute_error", "lower_is_better"),
                    return_best_model=True,
                    seed=seed,
                    progress_bar=False,
                    verbose=False,
                    suppress_lightning_logger=True,
                    min_lr=lr_find_min_lr,
                    max_lr=lr_find_max_lr,
                    num_training=lr_find_num_training,
                    mode=lr_find_mode,
                    early_stop_threshold=lr_find_early_stop_threshold,
                )

            sweep_df['seed'] = seed
            sweep_df['split_id'] = split_id
            sweep_df['fold_id'] = fold_id
            # Flag models whose train loss exceeds the test or validation loss.
            sweep_df["train_more"] = (
                (sweep_df["train_loss"] > sweep_df["test_loss"])
                | (sweep_df["train_loss"] > sweep_df["validation_loss"])
            )
            sweep_df["validation_test_mean_loss"] = (sweep_df["validation_loss"] + sweep_df["test_loss"]) / 2.0
            sweep_df["train_validation_test_mean_loss"] = (sweep_df["train_loss"] + sweep_df["validation_loss"] + sweep_df["test_loss"]) / 3.0

            dfs_result.append(sweep_df)

            # Rewrite the cumulative results table after every fold.
            fn_suffix = (f"models({len(model_list)})_{split_tag}_"
                         f"{best_model.config['lr_scheduler']}_{best_model.config['continuous_feature_transform']}")
            df_result = pd.concat(dfs_result, ignore_index=True)
            try:
                df_result.style.background_gradient(
                    subset=gradient_cols, cmap="RdYlGn_r"
                ).to_excel(f"{trainer_config['checkpoints_path']}/{fn_suffix}.xlsx")
            except PermissionError:
                # Best effort: if the workbook cannot be written (presumably open
                # elsewhere) skip it; the next fold rewrites the full table.
                pass

## Save best models

In [None]:
samples_fn = (f"{path_data}/samples_tst({tst_random_state}_{tst_n_splits}_{tst_n_repeats})"
              f"_val({val_random_state}_{val_n_splits}_{val_n_repeats}).pickle")
# This pickle is written by the data-loading cell of this notebook; never
# unpickle files from untrusted sources (loading can execute code).
with open(samples_fn, 'rb') as handle:
    samples = pickle.load(handle)

# Re-verify that train / validation / test sets are pairwise disjoint.
for split_id in range(tst_n_splits * tst_n_repeats):
    split_samples = samples[split_id]
    test_set = set(split_samples['test'])
    for fold_id in range(val_n_splits * val_n_repeats):
        train_set = set(split_samples['trains'][fold_id])
        validation_set = set(split_samples['validations'][fold_id])
        for intxn_name, intxn_samples in (
            ('train_validation', train_set & validation_set),
            ('validation_test', validation_set & test_set),
            ('train_test', train_set & test_set),
        ):
            if intxn_samples:
                print(f"Non-zero {intxn_name} intersection ({len(intxn_samples)}) for {split_id} Split and {fold_id} Fold!")

In [None]:
%%capture

# Must match len(model_list) of the sweep that wrote the results workbook,
# because that number is part of the workbook's file name.
n_models = 36

gradient_cols = ["train_loss", "validation_loss", "test_loss", "time_taken", "time_taken_per_epoch"]
split_tag = (f"tst({tst_random_state}_{tst_n_splits}_{tst_n_repeats})"
             f"_val({val_random_state}_{val_n_splits}_{val_n_repeats})")
fn_sweep = f"models({n_models})_{split_tag}_{optimizer_config['lr_scheduler']}_{data_config['continuous_feature_transform']}"

df_sweeps = pd.read_excel(f"{trainer_config['checkpoints_path']}/{fn_sweep}.xlsx", index_col=0)
path_models = f"{trainer_config['checkpoints_path']}/candidates/{fn_sweep}"
pathlib.Path(path_models).mkdir(parents=True, exist_ok=True)
sweep_styler = df_sweeps.style.background_gradient(subset=gradient_cols, cmap="RdYlGn_r")
sweep_styler.to_excel(f"{path_models}/sweep.xlsx")

# Selected candidate rows (index labels of df_sweeps) to export.
models_ids = [
    1620, 1626, 5942, 5940, 1629, 1632, 5943,
    1623, 5950, 5951, 5941, 5948, 1637, 1622,
]

selected_styler = df_sweeps.loc[models_ids, :].style.background_gradient(subset=gradient_cols, cmap="RdYlGn_r")
selected_styler.to_excel(f"{path_models}/selected.xlsx")

# Settings for the model explanations exported below.
explain_method = "GradientShap"
explain_baselines = "b|1000"
explain_n_feats_to_plot = 25

# Colours per sample group; the key order also fixes the group order in plots and metrics.
colors_groups = {
    'Train': 'chartreuse',
    'Validation': 'dodgerblue',
    'Test': 'crimson',
}
target_name = data_config['target'][0]

for model_id in models_ids:
    # Rebuild the train/validation/test partition this candidate was trained on.
    # Values read back from Excel may be numpy scalars; cast for the dict lookups.
    split_id = int(df_sweeps.at[model_id, 'split_id'])
    fold_id = int(df_sweeps.at[model_id, 'fold_id'])
    split_dict = samples[split_id]

    test = data.loc[split_dict['test'], feats + [f"{imm}_log"]]
    train = data.loc[split_dict['trains'][fold_id], feats + [f"{imm}_log"]]
    validation = data.loc[split_dict['validations'][fold_id], feats + [f"{imm}_log"]]

    # Re-create the model from its sweep params, load its checkpoint and save it.
    tabular_model = TabularModel(
        data_config=data_config,
        model_config=ast.literal_eval(df_sweeps.at[model_id, 'params']),
        optimizer_config=optimizer_config,
        trainer_config=trainer_config,
        verbose=True,
        suppress_lightning_logger=True
    )
    datamodule = tabular_model.prepare_dataloader(
        train=train,
        validation=validation,
        seed=int(df_sweeps.at[model_id, 'seed']),
    )
    model = tabular_model.prepare_model(datamodule)
    tabular_model._prepare_for_training(model, datamodule)
    tabular_model.load_weights(df_sweeps.at[model_id, 'checkpoint'])
    tabular_model.evaluate(test, verbose=False)
    tabular_model.save_model(f"{path_models}/{model_id}")

    loaded_model = TabularModel.load_model(f"{path_models}/{model_id}")

    # Predictions for every sample, labelled with the partition it belongs to.
    # .copy() makes df an independent frame, so the column assignments below do
    # not go through pandas' chained-assignment (SettingWithCopy) path.
    df = data.loc[:, data_config['target']].copy()
    df.loc[train.index, 'Group'] = 'Train'
    df.loc[validation.index, 'Group'] = 'Validation'
    df.loc[test.index, 'Group'] = 'Test'
    df['Prediction'] = loaded_model.predict(data)
    df['Error'] = df['Prediction'] - df[target_name]
    df.to_excel(f"{path_models}/{model_id}/df.xlsx")

    # Per-group metrics, stored as plain Python floats.
    df_metrics = pd.DataFrame(
        index=list(colors_groups.keys()),
        columns=['mean_absolute_error', 'pearson_corrcoef', 'bias']
    )
    for group in colors_groups:
        in_group = df['Group'] == group
        pred = torch.from_numpy(df.loc[in_group, 'Prediction'].values)
        real = torch.from_numpy(df.loc[in_group, target_name].values)
        df_metrics.at[group, 'mean_absolute_error'] = mean_absolute_error(pred, real).item()
        df_metrics.at[group, 'pearson_corrcoef'] = pearson_corrcoef(pred, real).item()
        df_metrics.at[group, 'bias'] = float(np.mean(df.loc[in_group, 'Error'].values))
    df_metrics.to_excel(f"{path_models}/{model_id}/metrics.xlsx", index_label="Metrics")

    # Scatter of prediction vs. target with a shared, 10%-padded square range.
    sns.set_theme(style='whitegrid')
    xy_min = df[[target_name, 'Prediction']].min().min()
    xy_max = df[[target_name, 'Prediction']].max().max()
    xy_pad = 0.1 * (xy_max - xy_min)
    xy_lim = [xy_min - xy_pad, xy_max + xy_pad]
    fig, ax = plt.subplots(figsize=(4.5, 4))
    sns.scatterplot(
        data=df,
        x=target_name,
        y="Prediction",
        hue="Group",
        palette=colors_groups,
        linewidth=0.2,
        alpha=0.75,
        edgecolor="k",
        s=20,
        hue_order=list(colors_groups.keys()),
        ax=ax
    )
    # Identity line: points on it are perfect predictions.
    sns.lineplot(
        x=xy_lim,
        y=xy_lim,
        linestyle='--',
        color='black',
        linewidth=1.0,
        ax=ax
    )
    ax.set_title(f"{df_sweeps.at[model_id, 'model']} ({df_sweeps.at[model_id, '# Params']} params, {df_sweeps.at[model_id, 'epochs']} epochs)")
    ax.set_xlim(*xy_lim)
    ax.set_ylim(*xy_lim)
    ax.set_aspect('equal', adjustable='box')
    fig.savefig(f"{path_models}/{model_id}/scatter.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_models}/{model_id}/scatter.pdf", bbox_inches='tight')
    plt.close(fig)

    # Violin plot of errors; each group label carries its metrics.
    df_fig = df.loc[:, ['Error', 'Group']].copy()
    groups_rename = {
        group: (
            f"{group}" + "\n" +
            fr"MAE: {df_metrics.at[group, 'mean_absolute_error']:0.2f}" + "\n" +
            fr"Pearson $\rho$: {df_metrics.at[group, 'pearson_corrcoef']:0.2f}" + "\n" +
            fr"$\langle$Error$\rangle$: {df_metrics.at[group, 'bias']:0.2f}"
        )
        for group in colors_groups
    }
    colors_groups_violin = {groups_rename[group]: colors_groups[group] for group in colors_groups}
    # Assign the result: a chained ``df_fig['Group'].replace(..., inplace=True)``
    # does not update df_fig under pandas Copy-on-Write, leaving labels that
    # match nothing in ``order=`` below (an empty plot).
    df_fig['Group'] = df_fig['Group'].replace(groups_rename)
    sns.set_theme(style='whitegrid')
    fig, ax = plt.subplots(figsize=(7, 4))
    sns.violinplot(
        data=df_fig,
        x='Group',
        y='Error',
        palette=colors_groups_violin,
        scale='width',
        order=list(colors_groups_violin.keys()),
        saturation=0.75,
        legend=False,
        ax=ax
    )
    ax.set_xlabel('')
    fig.savefig(f"{path_models}/{model_id}/violin.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{path_models}/{model_id}/violin.pdf", bbox_inches='tight')
    plt.close(fig)

    # Feature attributions. Only explain() is guarded: if it raises
    # NotImplementedError (presumably unsupported for this model type), skip
    # the explanation outputs for this candidate.
    try:
        explanation = loaded_model.explain(data, method=explain_method, baselines=explain_baselines)
    except NotImplementedError:
        continue
    explanation.index = data.index
    explanation.to_excel(f"{path_models}/{model_id}/explanation.xlsx")

    # NOTE(review): plot_type="violin" is saved under the "explain_beeswarm"
    # name; shap's beeswarm-style summary is plot_type="dot" — confirm intent.
    for plot_type, fn_plot in (("violin", "explain_beeswarm"), ("bar", "explain_bar")):
        sns.set_theme(style='whitegrid')
        shap.summary_plot(
            shap_values=explanation.loc[:, feats].values,
            features=data.loc[:, feats].values,
            feature_names=feats,
            max_display=explain_n_feats_to_plot,
            plot_type=plot_type,
            show=False,
        )
        # summary_plot draws on the current pyplot figure; save and close that
        # figure explicitly instead of relying on the call's return value.
        fig = plt.gcf()
        fig.savefig(f"{path_models}/{model_id}/{fn_plot}.png", bbox_inches='tight', dpi=200)
        fig.savefig(f"{path_models}/{model_id}/{fn_plot}.pdf", bbox_inches='tight')
        plt.close(fig)