# Debugging autoreload

In [None]:
# Enable IPython's autoreload extension so edits to imported modules are picked up while debugging.
%load_ext autoreload
# Mode 2 asks autoreload to re-import modules before each cell execution.
%autoreload 2

# Load packages

In [None]:
from pytorch_tabular.utils import load_covertype_dataset
from rich.pretty import pprint
import torch
from glob import glob
import ast
import matplotlib.pyplot as plt
import seaborn as sns
import copy
from sklearn.model_selection import train_test_split
import numpy as np
from pytorch_tabular.utils import make_mixed_dataset, print_metrics
from pytorch_tabular import available_models
from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig, GANDALFConfig, TabNetModelConfig, FTTransformerConfig, DANetConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig
from pytorch_tabular.tabular_model_tuner import TabularModelTuner
from torchmetrics.functional.regression import mean_absolute_error, pearson_corrcoef
from pytorch_tabular import MODEL_SWEEP_PRESETS
import pandas as pd
from pytorch_tabular import model_sweep
from src.pt.model_sweep import model_sweep_custom
import warnings
from src.utils.configs import read_parse_config
from src.utils.hash import dict_hash


# Load data

In [None]:
# Locations of the dataset spreadsheets and of the YAML configs used below.
path_data = "D:/YandexDisk/Work/pydnameth/datasets/GPL21145/GSEUNN/special/060_EpiSImAge/SImAge_repeat"
path_configs = "D:/Work/bbs/notebooks/immunology/001_pytorch_tabular_SImAge_repeat"

# Sample table, feature names and per-fold 'trn'/'val' labels.
# The fold table is re-ordered to follow the row order of `data`.
data = pd.read_excel(f"{path_data}/data.xlsx", index_col=1)
feats = pd.read_excel(f"{path_data}/feats_con10.xlsx", index_col=0).index.values.tolist()
cv_df = pd.read_excel(f"{path_data}/cv_ids.xlsx", index_col=0).loc[data.index, :]

# Columns handed to the models: the features plus the 'Age' column.
model_columns = feats + ['Age']


def _fold_members(fold_name, label):
    """Return the index entries of cv_df whose `fold_name` column equals `label`."""
    return cv_df.index[cv_df[fold_name] == label]


# Fixed train/validation split taken from the 'fold_0002' column.
train_only = data.loc[_fold_members('fold_0002', 'trn').values, model_columns]
validation_only = data.loc[_fold_members('fold_0002', 'val').values, model_columns]

# Subsets selected by the 'Dataset' column.
is_train_validation = data["Dataset"] == "Train/Validation"
is_test = data["Dataset"] == "Test Controls"
train_validation = data.loc[is_train_validation, model_columns]
test = data.loc[is_test, model_columns]

# For each of the five folds: positional (train, validation) row indices
# into `train_validation`.
cv_indexes = [
    tuple(
        np.where(train_validation.index.isin(_fold_members(f"fold_{i:04d}", label)))[0]
        for label in ('trn', 'val')
    )
    for i in range(5)
]

# Models Search Spaces

## GANDALF Search Space

In [None]:
# Candidate values per tuned GANDALF parameter; keys follow the
# "model_config__<name>" / "model_config.head_config__<name>" pattern
# used by every search space in this notebook.
dropout_grid = [0.0, 0.05, 0.1, 0.15, 0.2, 0.25]
search_space = {
    "model_config__gflu_stages": list(range(5, 36, 5)),
    "model_config__gflu_dropout": list(dropout_grid),
    "model_config__gflu_feature_init_sparsity": [0.1, 0.2, 0.3, 0.4, 0.5],
    "model_config.head_config__dropout": list(dropout_grid),
    "model_config__learning_rate": [0.001],
    "model_config__seed": [42, 1337, 666],
}
model_config = read_parse_config(f"{path_configs}/GANDALFConfig.yaml", GANDALFConfig)

# Size of the full grid: product of the number of candidates per parameter.
n_combinations = np.prod([len(candidates) for candidates in search_space.values()])
print(n_combinations)

## CategoryEmbeddingModel Search Space

In [None]:
# Candidate values per tuned CategoryEmbeddingModel parameter; keys follow the
# "model_config__<name>" / "model_config.head_config__<name>" pattern
# used by every search space in this notebook.
layer_layouts = [
    "256-128-64",
    "512-256-128",
    "32-16",
    "32-32-16",
    "16-8",
    "32-16-8",
    "128-64",
    "128-128",
    "16-16",
]
search_space = {
    "model_config__layers": layer_layouts,
    "model_config.head_config__dropout": [0.0, 0.05, 0.1, 0.15, 0.2, 0.25],
    "model_config__learning_rate": [0.001],
    "model_config__seed": [42, 1337, 666],
}
model_config = read_parse_config(
    f"{path_configs}/CategoryEmbeddingModelConfig.yaml",
    CategoryEmbeddingModelConfig,
)

# Size of the full grid: product of the number of candidates per parameter.
n_combinations = np.prod([len(candidates) for candidates in search_space.values()])
print(n_combinations)

## TabNetModel Search Space

In [None]:
# Candidate values per tuned TabNetModel parameter; keys follow the
# "model_config__<name>" / "model_config.head_config__<name>" pattern
# used by every search space in this notebook.
width_grid = list(range(8, 49, 8))  # 8, 16, ..., 48
block_count_grid = list(range(1, 6))  # 1 .. 5
search_space = {
    "model_config__n_d": list(width_grid),
    "model_config__n_a": list(width_grid),
    "model_config__n_steps": [3, 5, 7],
    "model_config__gamma": [1.3, 1.4, 1.5, 1.6, 1.7, 1.8],
    "model_config__n_independent": list(block_count_grid),
    "model_config__n_shared": list(block_count_grid),
    "model_config__mask_type": ["sparsemax", "entmax"],
    "model_config.head_config__dropout": [0.0, 0.05, 0.1, 0.15, 0.2, 0.25],
    "model_config__learning_rate": [0.001],
    "model_config__seed": [42, 1337, 666],
}
model_config = read_parse_config(f"{path_configs}/TabNetModelConfig.yaml", TabNetModelConfig)

# Size of the full grid: product of the number of candidates per parameter.
n_combinations = np.prod([len(candidates) for candidates in search_space.values()])
print(n_combinations)

## FTTransformer Search Space

In [None]:
# Candidate values per tuned FTTransformer parameter; keys follow the
# "model_config__<name>" / "model_config.head_config__<name>" pattern
# used by every search space in this notebook.
dropout_grid = [0.0, 0.05, 0.1, 0.15, 0.2, 0.25]
search_space = {
    "model_config__num_heads": [2 ** power for power in range(1, 6)],  # 2 .. 32
    "model_config__num_attn_blocks": list(range(4, 13, 2)),  # 4, 6, ..., 12
    "model_config__attn_dropout": list(dropout_grid),
    "model_config__add_norm_dropout": list(dropout_grid),
    "model_config__ff_dropout": list(dropout_grid),
    "model_config.head_config__dropout": list(dropout_grid),
    "model_config__learning_rate": [0.001],
    "model_config__seed": [42, 1337, 666],
}
model_config = read_parse_config(f"{path_configs}/FTTransformerConfig.yaml", FTTransformerConfig)

# Size of the full grid: product of the number of candidates per parameter.
n_combinations = np.prod([len(candidates) for candidates in search_space.values()])
print(n_combinations)

## DANet Search Space

In [None]:
# Candidate values per tuned DANet parameter; keys follow the
# "model_config__<name>" / "model_config.head_config__<name>" pattern
# used by every search space in this notebook.
dropout_grid = [0.0, 0.05, 0.1, 0.15, 0.2, 0.25]
search_space = {
    "model_config__n_layers": [4, 8, 16, 20, 32],
    "model_config__abstlay_dim_1": [2 ** power for power in range(3, 7)],  # 8 .. 64
    "model_config__k": list(range(3, 8)),  # 3 .. 7
    "model_config__dropout_rate": list(dropout_grid),
    "model_config.head_config__dropout": list(dropout_grid),
    "model_config__learning_rate": [0.001],
    "model_config__seed": [42, 1337, 666],
}
model_config = read_parse_config(f"{path_configs}/DANetConfig.yaml", DANetConfig)

# Size of the full grid: product of the number of candidates per parameter.
n_combinations = np.prod([len(candidates) for candidates in search_space.values()])
print(n_combinations)

# Grid Search and Random Search

In [None]:
%%capture

import os

# Search settings. `search_space` and `model_config` come from whichever
# "Search Space" cell above was executed last.
strategy = 'random_search'  # alternative: 'grid_search'
seed = 456456456
n_random_trials = 500
is_cross_validation = False

# Configs are read from YAML and then overridden for the tuning run.
data_config = read_parse_config(f"{path_configs}/DataConfig.yaml", DataConfig)
data_config['continuous_feature_transform'] = 'quantile_normal'
data_config['normalize_continuous_features'] = True
trainer_config = read_parse_config(f"{path_configs}/TrainerConfig.yaml", TrainerConfig)
trainer_config['checkpoints'] = None
trainer_config['load_best'] = False
trainer_config['auto_lr_find'] = True
optimizer_config = read_parse_config(f"{path_configs}/OptimizerConfig.yaml", OptimizerConfig)

tuner = TabularModelTuner(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    suppress_lightning_logger=True,
)

# The two modes differ only in the data they tune on and the progress bar:
# cross-validation uses the precomputed folds over train+validation, while the
# plain mode uses the fixed train/validation split.
if is_cross_validation:
    tune_data = {
        'train': train_validation,
        'validation': None,
        'cv': cv_indexes,
        'progress_bar': True,
    }
else:
    tune_data = {
        'train': train_only,
        'validation': validation_only,
        'cv': None,
        'progress_bar': False,
    }

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    result = tuner.tune(
        search_space=search_space,
        metric="mean_absolute_error",
        mode="min",
        strategy=strategy,
        n_trials=n_random_trials,
        return_best_model=True,
        verbose=False,
        random_state=seed,
        **tune_data,
    )

# Make sure the output directory exists before saving: `to_excel` does not
# create missing parent directories, and failing here would come only after
# the whole search has already run.
trials_dir = f"{trainer_config['checkpoints_path']}/trials"
os.makedirs(trials_dir, exist_ok=True)
result.trials_df.to_excel(f"{trials_dir}/{model_config['_model_name']}_{strategy}_{seed}_{optimizer_config['lr_scheduler']}.xlsx")

# Model Sweep Training

## Generate models' configs from trials files

In [None]:
# Number of best trials (lowest mean_absolute_error) taken from every trials file.
n_top_trials = 50

# Model families to include in the sweep; comment entries in or out as needed.
target_models_types = [
    # 'CategoryEmbeddingModel',
    'GANDALF',
    'TabNetModel',
    # 'FTTransformer',
    'DANet'
]

data_config = read_parse_config(f"{path_configs}/DataConfig.yaml", DataConfig)
optimizer_config = read_parse_config(f"{path_configs}/OptimizerConfig.yaml", OptimizerConfig)
trainer_config = read_parse_config(f"{path_configs}/TrainerConfig.yaml", TrainerConfig)

common_params = {
    "task": "regression",
}

# Template head settings; `dropout` is overwritten per trial below.
head_config = LinearHeadConfig(
    layers="",
    activation='ReLU',
    dropout=0.1,
    use_batch_norm=False,
    initialization="kaiming"
).__dict__

# Per model type: the config class and the tuned parameters to copy from a
# trial row, each with the builtin type it is cast to. Values read back from
# Excel arrive as pandas/NumPy scalars (ints may even come back as floats), so
# every model type is cast explicitly and in the same way.
model_specs = {
    'CategoryEmbeddingModel': (
        CategoryEmbeddingModelConfig,
        {
            'layers': str,
            'learning_rate': float,
            'seed': int,
        },
    ),
    'GANDALF': (
        GANDALFConfig,
        {
            'gflu_stages': int,
            'gflu_feature_init_sparsity': float,
            'gflu_dropout': float,
            'learning_rate': float,
            'seed': int,
        },
    ),
    'TabNetModel': (
        TabNetModelConfig,
        {
            'n_steps': int,
            'n_shared': int,
            'n_independent': int,
            'n_d': int,
            'n_a': int,
            'mask_type': str,
            'gamma': float,
            'learning_rate': float,
            'seed': int,
        },
    ),
    'FTTransformer': (
        FTTransformerConfig,
        {
            'num_heads': int,
            'num_attn_blocks': int,
            'attn_dropout': float,
            'add_norm_dropout': float,
            'ff_dropout': float,
            'learning_rate': float,
            'seed': int,
        },
    ),
    'DANet': (
        DANetConfig,
        {
            'n_layers': int,
            'abstlay_dim_1': int,
            'k': int,
            'dropout_rate': float,
            'learning_rate': float,
            'seed': int,
        },
    ),
}

model_list = []
for model_type in target_models_types:
    if model_type not in model_specs:
        # Unknown model types contribute no configs (same outcome as before).
        continue
    config_class, tuned_params = model_specs[model_type]
    trials_files = glob(f"{trainer_config['checkpoints_path']}/trials/{model_type}*.xlsx")
    for trials_file in trials_files:
        df_trials = pd.read_excel(trials_file, index_col=0)
        df_trials = df_trials.sort_values(['mean_absolute_error'], ascending=[True]).head(n_top_trials)
        for _, row in df_trials.iterrows():
            head_config_tmp = copy.deepcopy(head_config)
            head_config_tmp['dropout'] = float(row['model_config.head_config__dropout'])
            # Start from the YAML defaults, then apply this trial's values.
            model_config = read_parse_config(f"{path_configs}/{model_type}Config.yaml", config_class)
            for param_name, cast in tuned_params.items():
                model_config[param_name] = cast(row[f"model_config__{param_name}"])
            model_config['head_config'] = head_config_tmp
            model_list.append(config_class(**model_config))
print(len(model_list))

## Perform model sweep

In [None]:
%%capture

seed = 666

# Trainer overrides for the sweep.
trainer_overrides = {
    'seed': seed,
    'checkpoints': 'valid_loss',
    'load_best': True,
    'auto_lr_find': True,
}
for option, value in trainer_overrides.items():
    trainer_config[option] = value

# Data overrides; other transform names tried here: 'box-cox', 'quantile_normal'.
data_overrides = {
    'continuous_feature_transform': 'yeo-johnson',
    'normalize_continuous_features': True,
}
for option, value in data_overrides.items():
    data_config[option] = value

# Arguments for the custom sweep. It trains on the fixed train/validation
# split; switch to train=train_validation, validation=None to train on the
# union instead.
sweep_kwargs = dict(
    task="regression",
    train=train_only,
    validation=validation_only,
    test=test,
    data_config=data_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    model_list=model_list,
    common_model_args=common_params,
    metrics=["mean_absolute_error", "pearson_corrcoef"],
    metrics_params=[{}, {}],
    metrics_prob_input=[False, False],
    rank_metric=("mean_absolute_error", "lower_is_better"),
    return_best_model=True,
    seed=seed,
    progress_bar=False,
    verbose=False,
    suppress_lightning_logger=True,
)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    sweep_df, best_model = model_sweep_custom(**sweep_kwargs)

# Save the sweep table with a colour gradient on the loss and timing columns.
fn_suffix = f"{seed}_{best_model.config['lr_scheduler']}_{best_model.config['continuous_feature_transform']}"
gradient_columns = [
    "train_loss",
    "validation_loss",
    "test_loss",
    "time_taken",
    "time_taken_per_epoch",
]
styled_sweep = sweep_df.style.background_gradient(subset=gradient_columns, cmap="RdYlGn_r")
styled_sweep.to_excel(f"{trainer_config['checkpoints_path']}/sweep_{fn_suffix}.xlsx")

## Save best models

In [None]:
%%capture

# Sweep progress table; its 'params', 'checkpoint', 'model', '# Params' and
# 'epochs' columns are read below for each selected row.
sweep_df = pd.read_excel(f"{trainer_config['checkpoints_path']}/progress.xlsx", index_col=0)

# Row ids (index of the progress table) of the candidate models to export.
models_ids = [318, 129, 417, 402, 142, 346]

# Plot colour per sample group; the key order also fixes the plotting order.
colors_groups = {
    'Train': 'chartreuse',
    'Validation': 'lightskyblue',
    'Test': 'dodgerblue',
    'ESRD': 'crimson'
}

for model_id in models_ids:

    # --- Rebuild the model from its stored parameters and load its weights ---
    tabular_model = TabularModel(
        data_config=data_config,
        model_config=ast.literal_eval(sweep_df.at[model_id, 'params']),
        optimizer_config=optimizer_config,
        trainer_config=trainer_config,
        verbose=True,
        suppress_lightning_logger=False
    )
    datamodule = tabular_model.prepare_dataloader(
        train=train_only,
        validation=validation_only,
        seed=seed,
    )
    model = tabular_model.prepare_model(
        datamodule
    )
    tabular_model._prepare_for_training(
        model,
        datamodule
    )
    tabular_model.load_weights(sweep_df.at[model_id, 'checkpoint'])
    tabular_model.evaluate(test, verbose=False)

    # --- Save the candidate and read it back, so the saved artefact itself is used ---
    candidate_dir = f"{tabular_model.config['checkpoints_path']}/candidates/{model_id}"
    tabular_model.save_model(candidate_dir)
    loaded_model = TabularModel.load_model(candidate_dir)
    output_dir = f"{loaded_model.config['checkpoints_path']}/candidates/{model_id}"

    # --- Per-sample predictions and errors ---
    # Work on an explicit copy so the column assignments below never write into
    # (or warn about) a view of `data`.
    df = data.loc[:, ['Age', 'Sex', 'Status']].copy()
    df['Group'] = df['Status']
    df.loc[train_only.index, 'Group'] = 'Train'
    df.loc[validation_only.index, 'Group'] = 'Validation'
    df.loc[test.index, 'Group'] = 'Test'
    df['Prediction'] = loaded_model.predict(data)
    df['Error'] = df['Prediction'] - df['Age']
    # The target is named df.xlsx, so write a real Excel workbook (previously
    # CSV text was written under the .xlsx name).
    df.to_excel(f"{output_dir}/df.xlsx")

    # --- Per-group metrics ---
    df_metrics = pd.DataFrame(
        index=list(colors_groups.keys()),
        columns=['mean_absolute_error', 'pearson_corrcoef', 'mean_age_acc']
    )
    for group in colors_groups:
        in_group = df['Group'] == group
        pred = torch.from_numpy(df.loc[in_group, 'Prediction'].values)
        real = torch.from_numpy(df.loc[in_group, 'Age'].values)
        # .item() stores plain Python floats instead of 0-d arrays.
        df_metrics.at[group, 'mean_absolute_error'] = mean_absolute_error(pred, real).item()
        df_metrics.at[group, 'pearson_corrcoef'] = pearson_corrcoef(pred, real).item()
        df_metrics.at[group, 'mean_age_acc'] = np.mean(df.loc[in_group, 'Error'].values)
    df_metrics.to_excel(f"{output_dir}/metrics.xlsx", index_label="Metrics")

    # --- Scatter plot: prediction vs. age, with the identity line ---
    sns.set_theme(style='whitegrid')
    fig, ax = plt.subplots(figsize=(4.5, 4))
    sns.scatterplot(
        data=df,
        x="Age",
        y="Prediction",
        hue="Group",
        palette=colors_groups,
        linewidth=0.2,
        alpha=0.75,
        edgecolor="k",
        s=20,
        hue_order=list(colors_groups.keys()),
        ax=ax
    )
    sns.lineplot(
        x=[0, 120],
        y=[0, 120],
        linestyle='--',
        color='black',
        linewidth=1.0,
        ax=ax
    )
    ax.set_title(f"{sweep_df.at[model_id, 'model']} ({sweep_df.at[model_id, '# Params']} params, {sweep_df.at[model_id, 'epochs']} epochs)")
    ax.set_xlim(0, 120)
    ax.set_ylim(0, 120)
    ax.set_aspect('equal', adjustable='box')
    fig.savefig(f"{output_dir}/scatter.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{output_dir}/scatter.pdf", bbox_inches='tight')
    plt.close(fig)

    # --- Violin plot of the errors, one violin per group, metrics in the labels ---
    df_fig = df.loc[:, ['Error', 'Group']].copy()
    groups_rename = {
        group: "\n".join([
            f"{group}",
            fr"MAE: {df_metrics.at[group, 'mean_absolute_error']:0.2f}",
            fr"Pearson $\rho$: {df_metrics.at[group, 'pearson_corrcoef']:0.2f}",
            fr"$\langle$Error$\rangle$: {df_metrics.at[group, 'mean_age_acc']:0.2f}",
        ])
        for group in colors_groups
    }
    colors_groups_violin = {groups_rename[group]: colors_groups[group] for group in colors_groups}
    violin_order = list(colors_groups_violin.keys())
    # Assign the result instead of a chained in-place replace, which is not
    # guaranteed to modify df_fig (and does nothing under pandas copy-on-write).
    df_fig['Group'] = df_fig['Group'].replace(groups_rename)
    fig, ax = plt.subplots(figsize=(7, 4))
    # `hue` is set explicitly and `density_norm` replaces the deprecated
    # `scale`; `legend=` is already a seaborn >= 0.13 argument.
    sns.violinplot(
        data=df_fig,
        x='Group',
        y='Error',
        hue='Group',
        palette=colors_groups_violin,
        density_norm='width',
        order=violin_order,
        hue_order=violin_order,
        saturation=0.75,
        legend=False,
        ax=ax
    )
    ax.set_xlabel('')
    fig.savefig(f"{output_dir}/violin.png", bbox_inches='tight', dpi=200)
    fig.savefig(f"{output_dir}/violin.pdf", bbox_inches='tight')
    plt.close(fig)

# Simple TabularModel training

In [None]:
# Trainer settings: start from the YAML file, then apply the overrides below.
trainer_config = read_parse_config(f"{path_configs}/TrainerConfig.yaml", TrainerConfig)
for option, value in (
    ('checkpoints', 'valid_loss'),
    ('load_best', True),
    ('auto_lr_find', True),
):
    trainer_config[option] = value

# The remaining configs are passed as paths to their YAML files.
yaml_configs = {
    'data_config': f"{path_configs}/DataConfig.yaml",
    'model_config': f"{path_configs}/CategoryEmbeddingModelConfig.yaml",
    'optimizer_config': f"{path_configs}/OptimizerConfig.yaml",
}

tabular_model = TabularModel(
    trainer_config=trainer_config,
    verbose=True,
    suppress_lightning_logger=False,
    **yaml_configs,
)

# Optional extras left disabled: target_transform=[np.log, np.exp],
# callbacks=[DeviceStatsMonitor()].
tabular_model.fit(train=train_only, validation=validation_only)

## Play with trained model

In [None]:
# Predictions of the trained model for the test subset.
tabular_model.predict(test, progress_bar='rich')

In [None]:
# Evaluate on the test subset; ckpt_path="best" presumably selects the best saved checkpoint — confirm in the library docs.
tabular_model.evaluate(test, verbose=True, ckpt_path="best")

In [None]:
# Show the checkpoint directory configured for this model.
tabular_model.config['checkpoints_path']

In [None]:
# Path of the best checkpoint as recorded by the trainer's checkpoint callback.
print(tabular_model.trainer.checkpoint_callback.best_model_path)

In [None]:
# Print the model summary.
tabular_model.summary()

In [None]:
# Save the model into the checkpoint directory.
tabular_model.save_model(tabular_model.config['checkpoints_path'])

In [None]:
# Save the config into the same directory.
tabular_model.save_config(tabular_model.config['checkpoints_path'])

In [None]:
# Reload the model from the directory it was saved to above.
tabular_model = TabularModel.load_model(tabular_model.config['checkpoints_path'])