# Enable autoreload (for debugging)

In [ ]:
# Load IPython's autoreload extension. Mode 2 reloads all modules before executing
# code, so edits to project modules (e.g. src.utils.*) are picked up without a
# kernel restart.
%load_ext autoreload
%autoreload 2

# Load packages

In [None]:
from pytorch_tabular.utils import load_covertype_dataset
from rich.pretty import pprint
from glob import glob
import copy
from sklearn.model_selection import train_test_split
import numpy as np
from pytorch_tabular.utils import make_mixed_dataset, print_metrics
from pytorch_tabular import available_models
from pytorch_tabular import TabularModel
from pytorch_tabular.models import CategoryEmbeddingModelConfig, GANDALFConfig, TabNetModelConfig, FTTransformerConfig
from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig
from pytorch_tabular.models.common.heads import LinearHeadConfig
from pytorch_tabular.tabular_model_tuner import TabularModelTuner
from pytorch_tabular import MODEL_SWEEP_PRESETS
import pandas as pd
from pytorch_tabular import model_sweep
import warnings
from src.utils.configs import read_parse_config
from src.utils.hash import dict_hash


# Load data

In [None]:
# Locations of the input tables and of the YAML configs used by the cells below.
path_data = "D:/YandexDisk/Work/pydnameth/datasets/GPL21145/GSEUNN/special/060_EpiSImAge/SImAge_repeat"
path_configs = "D:/Work/bbs/notebooks/immunology/001_pytorch_tabular_SImAge_log"

# Samples table (index taken from its second column), the feature list, and the
# per-sample fold assignments ('trn' / 'val' in columns fold_0000, fold_0001, ...).
data = pd.read_excel(f"{path_data}/data.xlsx", index_col=1)
feats = pd.read_excel(f"{path_data}/feats_con10.xlsx", index_col=0).index.values.tolist()
cv_df = pd.read_excel(f"{path_data}/cv_ids.xlsx", index_col=0)
# Re-align the fold table to the row order of `data`.
cv_df = cv_df.loc[data.index, :]

# Every frame keeps the selected features plus 'Age' (presumably the regression target).
_model_columns = feats + ['Age']

# Single hold-out split defined by fold 2.
_holdout_fold = cv_df['fold_0002']
train_only = data.loc[cv_df.index[_holdout_fold == 'trn'].values, _model_columns]
validation_only = data.loc[cv_df.index[_holdout_fold == 'val'].values, _model_columns]

# Full labelled set and the external test controls, selected via the "Dataset" column.
train_validation = data.loc[data["Dataset"] == "Train/Validation", _model_columns]
test = data.loc[data["Dataset"] == "Test Controls", _model_columns]


def _fold_positions(fold_number, split):
    """Return positional indices of `train_validation` rows labelled `split` in fold `fold_number`."""
    fold_ids = cv_df.index[cv_df[f"fold_{fold_number:04d}"] == split]
    return np.where(train_validation.index.isin(fold_ids))[0]


# (train, validation) positional index pairs for the 5 predefined CV folds.
cv_indexes = [(_fold_positions(k, 'trn'), _fold_positions(k, 'val')) for k in range(5)]

# GANDALF Search Space

In [None]:
# GANDALF hyper-parameter grid for the tuning cell. Keys follow the tuner's
# "model_config__<param>" convention; "model_config.head_config__dropout" addresses
# the nested head config. Entries are added in the original order because the
# trials-file name includes dict_hash(search_space), which may be order-sensitive.
search_space = {}
search_space["model_config__gflu_stages"] = [10, 15, 20, 25]
search_space["model_config__gflu_dropout"] = [0.1, 0.2]
search_space["model_config__gflu_feature_init_sparsity"] = [0.2, 0.3, 0.4]
search_space["model_config.head_config__dropout"] = [0.1, 0.2]
search_space["model_config__seed"] = [42, 1337, 666, 777, 16022008, 26111993, 28042020, 16, 7, 456456456]

# Base GANDALF config handed to the tuner.
model_config = read_parse_config(f"{path_configs}/GANDALFConfig.yaml", GANDALFConfig)

# CategoryEmbeddingModel Search Space

In [None]:
# CategoryEmbeddingModel hyper-parameter grid for the tuning cell. "layers" is a
# dash-separated list of hidden-layer widths. Entries are added in the original
# order because the trials-file name includes dict_hash(search_space), which may
# be order-sensitive.
search_space = {}
search_space["model_config__layers"] = ["256-128-64", "512-256-256", "32-16", "32-32-16", "16-8", "32-16-8", "128-64", "128-128", "10-10"]
search_space["model_config.head_config__dropout"] = [0.05, 0.1, 0.15, 0.2, 0.25]
search_space["model_config__seed"] = [42, 1337, 666, 777, 16022008, 26111993, 28042020, 16, 7, 456456456]

# Base CategoryEmbeddingModel config handed to the tuner.
model_config = read_parse_config(f"{path_configs}/CategoryEmbeddingModelConfig.yaml", CategoryEmbeddingModelConfig)

# TabNetModel Search Space

In [None]:
# TabNet hyper-parameter grid for the tuning cell. Entries are added in the
# original order because the trials-file name includes dict_hash(search_space),
# which may be order-sensitive.
search_space = {}
search_space["model_config__n_d"] = [16, 32]
search_space["model_config__n_a"] = [16, 32]
search_space["model_config__n_steps"] = [3, 6]
search_space["model_config__gamma"] = [1.3, 1.5]
search_space["model_config__n_independent"] = [2, 3]
search_space["model_config__n_shared"] = [2, 3]
search_space["model_config__mask_type"] = ["sparsemax", "entmax"]
search_space["model_config.head_config__dropout"] = [0.1, 0.2]
search_space["model_config__seed"] = [42, 1337, 666, 777, 16022008, 26111993, 28042020, 16, 7, 456456456]

# Base TabNet config handed to the tuner.
model_config = read_parse_config(f"{path_configs}/TabNetModelConfig.yaml", TabNetModelConfig)

# Grid Search and Random Search

In [None]:
# Tune `model_config` over `search_space`, taken from whichever search-space cell
# above was run last, and save every trial's results to an Excel file.
from pathlib import Path

strategy = 'grid_search'
seed = 456456456
n_random_trials = 100  # passed as n_trials; the tuner uses it for random search
is_cross_validation = False

data_config = read_parse_config(f"{path_configs}/DataConfig.yaml", DataConfig)
trainer_config = read_parse_config(f"{path_configs}/TrainerConfig.yaml", TrainerConfig)
# Disable checkpoint saving and best-checkpoint reloading for the tuning trials.
trainer_config['checkpoints'] = None
trainer_config['load_best'] = False
optimizer_config = read_parse_config(f"{path_configs}/OptimizerConfig.yaml", OptimizerConfig)

tuner = TabularModelTuner(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    suppress_lightning_logger=True,
)

# CV mode: tune on the whole Train/Validation set using the 5 predefined folds.
# Hold-out mode: train on fold 2's 'trn' rows and score on its 'val' rows.
if is_cross_validation:
    split_kwargs = dict(train=train_validation, validation=None, cv=cv_indexes)
else:
    split_kwargs = dict(train=train_only, validation=validation_only, cv=None)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    result = tuner.tune(
        **split_kwargs,
        search_space=search_space,
        metric="mean_absolute_error",
        mode="min",
        strategy=strategy,
        n_trials=n_random_trials,
        return_best_model=True,
        verbose=False,
        progress_bar=False,
        random_state=seed,
    )

# Tag the file with the validation scheme so CV and hold-out runs sharing the same
# strategy, seed and search space do not overwrite each other. The name still
# starts with "<model_name>_", so the sweep cell's glob picks it up.
validation_tag = 'cv' if is_cross_validation else 'holdout'
trials_dir = Path(trainer_config['checkpoints_path']) / 'trials'
trials_dir.mkdir(parents=True, exist_ok=True)
result.trials_df.to_excel(
    trials_dir / f"{model_config['_model_name']}_{strategy}_{validation_tag}_{seed}_{dict_hash(search_space)[:10]}.xlsx"
)

# FTTransformer Random Search

In [None]:
# Random search over FTTransformer hyper-parameters with 5-fold CV on the full
# labelled set; all trial results are saved to an Excel file.
from pathlib import Path

seed = 42
n_trials = 100

# Keys use the tuner's "model_config__<param>" addressing;
# "model_config.head_config__dropout" targets the nested head config.
search_space = {
    "model_config__num_heads": [2, 4, 8, 16, 32],
    "model_config__num_attn_blocks": [4, 6, 8, 10, 12],
    "model_config__attn_dropout": [0.05, 0.1, 0.15, 0.2],
    "model_config__add_norm_dropout": [0.05, 0.1, 0.15, 0.2],
    "model_config__ff_dropout": [0.05, 0.1, 0.15, 0.2],
    "model_config.head_config__dropout": [0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3],
}

data_config = read_parse_config(f"{path_configs}/DataConfig.yaml", DataConfig)
model_config = read_parse_config(f"{path_configs}/FTTransformerConfig.yaml", FTTransformerConfig)
trainer_config = read_parse_config(f"{path_configs}/TrainerConfig.yaml", TrainerConfig)
# Disable checkpoint saving and best-checkpoint reloading for the tuning trials.
trainer_config['checkpoints'] = None
trainer_config['load_best'] = False
optimizer_config = read_parse_config(f"{path_configs}/OptimizerConfig.yaml", OptimizerConfig)

tuner = TabularModelTuner(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    suppress_lightning_logger=True,
)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    result = tuner.tune(
        # Full labelled Train/Validation set; cv=5 splits it into folds internally.
        train=train_validation,
        validation=None,
        search_space=search_space,
        metric="mean_absolute_error",
        mode="min",
        strategy="random_search",
        n_trials=n_trials,
        cv=5,
        return_best_model=True,
        verbose=False,
        progress_bar=False,
        random_state=seed,
    )

trials_dir = Path(trainer_config['checkpoints_path']) / 'trials'
trials_dir.mkdir(parents=True, exist_ok=True)
result.trials_df.to_excel(trials_dir / f"FTTransformer_{seed}_{n_trials}_{dict_hash(search_space)[:10]}.xlsx")

# TabNetModel Random Search

In [None]:
# Random search over TabNet hyper-parameters with 5-fold CV on the full labelled
# set; all trial results are saved to an Excel file.
from pathlib import Path

seed = 42
n_trials = 100

# Broader TabNet search space, currently disabled in favour of the narrower one below.
# search_space = {
#     "model_config__n_d": [4, 8, 12, 16, 24, 32, 48],
#     "model_config__n_a": [4, 8, 12, 16, 24, 32, 48],
#     "model_config__n_steps": [3, 5, 7],
#     "model_config__gamma": [1.1, 1.3, 1.5, 1.7, 1.9],
#     "model_config__n_independent": [1, 2, 3],
#     "model_config__n_shared": [1, 2, 3],
#     "model_config__mask_type": ["sparsemax", "entmax"],
#     "model_config.head_config__dropout": [0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3],
# }

search_space = {
    "model_config__n_d": [16, 32],
    "model_config__n_a": [16, 32],
    "model_config__n_steps": [3, 5],
    "model_config__gamma": [1.3, 1.9],
    "model_config__n_independent": [1, 2, 3],
    "model_config__n_shared": [1, 2, 3],
    "model_config__mask_type": ["sparsemax"],
    "model_config.head_config__dropout": [0.1, 0.2],
}

data_config = read_parse_config(f"{path_configs}/DataConfig.yaml", DataConfig)
model_config = read_parse_config(f"{path_configs}/TabNetModelConfig.yaml", TabNetModelConfig)
trainer_config = read_parse_config(f"{path_configs}/TrainerConfig.yaml", TrainerConfig)
# Disable checkpoint saving and best-checkpoint reloading for the tuning trials.
trainer_config['checkpoints'] = None
trainer_config['load_best'] = False
optimizer_config = read_parse_config(f"{path_configs}/OptimizerConfig.yaml", OptimizerConfig)

tuner = TabularModelTuner(
    data_config=data_config,
    model_config=model_config,
    optimizer_config=optimizer_config,
    trainer_config=trainer_config,
    suppress_lightning_logger=True,
)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    result = tuner.tune(
        # Full labelled Train/Validation set; cv=5 splits it into folds internally.
        train=train_validation,
        validation=None,
        search_space=search_space,
        metric="mean_absolute_error",
        mode="min",
        strategy="random_search",
        n_trials=n_trials,
        cv=5,
        return_best_model=True,
        verbose=False,
        progress_bar=False,
        random_state=seed,
    )

trials_dir = Path(trainer_config['checkpoints_path']) / 'trials'
trials_dir.mkdir(parents=True, exist_ok=True)
result.trials_df.to_excel(trials_dir / f"TabNetModel_{seed}_{n_trials}_{dict_hash(search_space)[:10]}.xlsx")

# Model Sweep training

In [None]:
# Show the available sweep preset names, then the models in the "standard" preset.
for preset_view in (MODEL_SWEEP_PRESETS.keys(), MODEL_SWEEP_PRESETS["standard"]):
    print(list(preset_view))

In [None]:
# Rebuild model configs from the best tuning trials of each model type; the list
# feeds the model_sweep cell below.
seed = 22222
n_top_trials = 10

target_models_types = [
    'CategoryEmbeddingModel',
    'GANDALF',
    'TabNetModel'
]

# Config class used to rebuild each model type. Its YAML base config is read from
# f"{path_configs}/{model_type}Config.yaml".
model_config_classes = {
    'CategoryEmbeddingModel': CategoryEmbeddingModelConfig,
    'GANDALF': GANDALFConfig,
    'TabNetModel': TabNetModelConfig,
}
# Tuned model hyper-parameters appear in the trials files as "model_config__<param>"
# columns (the head dropout uses the separate "model_config.head_config__dropout").
model_param_prefix = 'model_config__'

data_config = read_parse_config(f"{path_configs}/DataConfig.yaml", DataConfig)
optimizer_config = read_parse_config(f"{path_configs}/OptimizerConfig.yaml", OptimizerConfig)
trainer_config = read_parse_config(f"{path_configs}/TrainerConfig.yaml", TrainerConfig)
trainer_config['seed'] = seed

common_params = {
    "task": "regression",
}

head_config = LinearHeadConfig(
    layers="",
    activation='ReLU',
    dropout=0.1,
    use_batch_norm=False,
    initialization="kaiming"
).__dict__

model_list = []
for model_type in target_models_types:
    config_class = model_config_classes[model_type]
    # Sorted so model_list order (and thus the sweep) is deterministic across runs.
    trials_files = sorted(glob(f"{trainer_config['checkpoints_path']}/trials/{model_type}_*.xlsx"))
    for trials_file in trials_files:
        df_trials = pd.read_excel(trials_file, index_col=0)
        df_top = df_trials.sort_values(['mean_absolute_error'], ascending=[True]).head(n_top_trials)
        param_columns = [column for column in df_top.columns if str(column).startswith(model_param_prefix)]
        for _, row in df_top.iterrows():
            model_config = read_parse_config(f"{path_configs}/{model_type}Config.yaml", config_class)
            # Copy every tuned parameter, including model_config__seed when the search
            # varied it, so trials differing only by seed are not collapsed into
            # identical configs.
            # NOTE(review): trainer_config['seed'] is also set above; confirm which
            # seed pytorch_tabular applies when both are present.
            for column in param_columns:
                model_config[column[len(model_param_prefix):]] = row[column]
            head_config_tmp = copy.deepcopy(head_config)
            head_config_tmp['dropout'] = row['model_config.head_config__dropout']
            model_config['head_config'] = head_config_tmp
            model_list.append(config_class(**model_config))
print(len(model_list))

In [None]:
              
# Train every config in `model_list` on the full Train/Validation set, score it on
# the test controls, and rank by mean absolute error (lower is better).
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    sweep_df, best_model = model_sweep(
        task="regression",
        train=train_validation,
        validation=None,
        test=test,
        data_config=data_config,
        optimizer_config=optimizer_config,
        trainer_config=trainer_config,
        model_list=model_list,
        common_model_args=common_params,
        metrics=["mean_absolute_error", "pearson_corrcoef"],
        metrics_params=[{}, {}],
        metrics_prob_input=[False, False],
        rank_metric=("mean_absolute_error", "lower_is_better"),
        return_best_model=True,
        seed=seed,
        progress_bar=False,
        verbose=False,
        suppress_lightning_logger=True,
    )

# Colour-code the leaderboard (green = better: low error/time, high correlation)
# and save it next to the checkpoints.
sweep_df.style.background_gradient(
    subset=["test_mean_absolute_error", "time_taken", "time_taken_per_epoch"], cmap="RdYlGn_r"
).background_gradient(
    subset=["test_pearson_corrcoef"], cmap="RdYlGn"
).to_excel(f"{trainer_config['checkpoints_path']}/sweep_{seed}.xlsx")

In [None]:
# Only a cell's last expression is echoed, so the checkpoint path must be printed
# explicitly; otherwise its value is silently discarded.
print(best_model.trainer.checkpoint_callback.best_model_path)
best_model.summary()

# TabularModel training

In [None]:
# Train a single CategoryEmbedding model from the YAML configs on the full
# Train/Validation set (no separate validation frame is passed).
tabular_model = TabularModel(
    data_config=f"{path_configs}/DataConfig.yaml",
    model_config=f"{path_configs}/CategoryEmbeddingModelConfig.yaml",
    optimizer_config=f"{path_configs}/OptimizerConfig.yaml",
    trainer_config=f"{path_configs}/TrainerConfig.yaml",
    verbose=True,
    suppress_lightning_logger=False
)

tabular_model.fit(
    train=train_validation,
    validation=None,
    # Optional extras; DeviceStatsMonitor would need to be imported from Lightning.
    # target_transform=[np.log, np.exp],
    # callbacks=[DeviceStatsMonitor()],
)

In [None]:
# Predict on the held-out "Test Controls" samples, with a rich progress bar.
prediction = tabular_model.predict(test, progress_bar='rich')

In [None]:
# Evaluate on the test controls, requesting the best checkpoint (ckpt_path="best").
tabular_model.evaluate(test, verbose=True, ckpt_path="best")

In [None]:
# Path of the best checkpoint recorded by the trainer's checkpoint callback (echoed as cell output).
tabular_model.trainer.checkpoint_callback.best_model_path

In [None]:
# Show a summary of the fitted model.
tabular_model.summary()

In [None]:
# Save the fitted model into the configured checkpoints directory.
tabular_model.save_model(tabular_model.config['checkpoints_path'])

In [None]:
# Write the model's config into the same checkpoints directory.
tabular_model.save_config(tabular_model.config['checkpoints_path'])

In [None]:
# Load a previously saved model.
# NOTE(review): the cells above save to tabular_model.config['checkpoints_path'],
# but this loads from f"{path_data}/pytorch_tabular". Confirm the saved model was
# copied there, or point this at the save directory.
model = TabularModel.load_model(f"{path_data}/pytorch_tabular")