# Hyperparameter Tuning for NAM

In the following, we will tune the most important hyperparameters of the Neural Additive Model (NAM) using Optuna since we realised during training that the model is very sensitive to the choice of some hyperparameters. We will use a very similar hyperparameter search space as the one used in appendix A.6 of the [NAM paper](https://arxiv.org/pdf/2004.13912.pdf).

## Packages and Presets

In [17]:
import numpy as np
import pandas as pd
from skimpy import clean_columns
from pickle import load

from sklearn.model_selection import StratifiedKFold
from torch.utils.data import DataLoader
from neural_additive_model import NAM

import torch
import torch.nn as nn
import optuna
import logging


# append path to parent folder to allow imports from utils folder
import sys

sys.path.append("../..")
from utils.utils import (
    set_all_seeds,
    HeartFailureDataset,
    train_and_validate_one_epoch,
    get_n_units,
    penalized_binary_cross_entropy,
)

import warnings
warnings.filterwarnings("ignore")

In [18]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42
N_TRIALS = 1_000
N_EPOCHS = 300

## Preprocess Data

In [19]:
train_df = pd.read_csv("../data/heart_failure/train_val_split.csv").pipe(
    clean_columns
)
X = train_df.drop(columns=["heart_disease"], axis=1)
outlier_idx = X.query("resting_bp == 0").index
X = X.drop(outlier_idx)
y = train_df["heart_disease"]
y = y[X.index]

# create categorical variable for cholesterol level
X["chol_level"] = pd.cut(
    X["cholesterol"],
    bins=[-1, 10, 200, 240, 1000],
    labels=["imputed", "normal", "borderline", "high"],
)

X.index = range(len(X))
y.index = range(len(y))


# load and apply preprocessor:
preprocessor = load(open("../models/preprocessor.pkl", "rb"))

## Hyperparameter Tuning

In [20]:
# create objective function for optuna
def objective(trial):
    NAM_params = {
        "out_size": 1,
        "hidden_profile": trial.suggest_categorical(
            "hidden_profile", [
                [64, 64, 32],
                [1024],
                [256],
                [128, 256, 128, 64],
                [256, 128, 128],
            ]
        ),
        # Don't use either exu nor relu_n as we had problems otherwise with using 
        # SHAP on our model whilst leading similar performance as linear layers
        # with ReLU activation functions
        "use_exu": False, 
        "use_relu_n": False, 
        "within_feature_dropout": trial.suggest_categorical(
            "within_feature_dropout",
            [0, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
        ),
        "feature_dropout": trial.suggest_categorical(
            "feature_dropout", [0, 0.05, 0.1, 0.2]
        )
    }
    
    pen_bce_params = {
        "output_regularization": trial.suggest_float("output_regularization", 1e-3, 1e-1, log=True),
        "l2_regularization": trial.suggest_float("l2_regularization", 1e-6, 1e-4, log=True),
    }
        
    adam_params = {"lr": trial.suggest_float("lr", 1e-4, 1e-1, log=True)}
   
    set_all_seeds(SEED)
    
    balanced_accuracies = {}
    
    skf = StratifiedKFold(n_splits=5, random_state=SEED, shuffle=True)
    for fold_num, (train_idx, val_idx) in enumerate(skf.split(X, y)):
        X_train, X_val = X.iloc[train_idx], X.iloc[val_idx]
        y_train, y_val = y[train_idx].to_numpy(), y[val_idx].to_numpy()
        
        X_train = preprocessor.fit_transform(X_train)
        X_val = preprocessor.transform(X_val)
        
        NAM_params["n_features"] = X_train.shape[1]
        NAM_params["in_size"] = get_n_units(X_train)
        
        set_all_seeds(SEED)
        model = NAM(**NAM_params).to(DEVICE)
        
        
        optimizer = torch.optim.Adam(model.parameters(), **adam_params)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)
        
        # Create the dataset and dataloader
        train_dataset = HeartFailureDataset(X_train, y_train)
        val_dataset = HeartFailureDataset(X_val, y_val)


        train_loader = DataLoader(
            train_dataset, batch_size=32, shuffle=True, pin_memory=True
        )

        val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, pin_memory=True)
        
        balanced_accuracies[fold_num] = []
        
        for epoch in range(N_EPOCHS):
            set_all_seeds(SEED)
            _,_, _, balanced_accuracy = train_and_validate_one_epoch(
                model=model,
                optimizer=optimizer,
                criterion=penalized_binary_cross_entropy,
                train_loader=train_loader,
                val_loader=val_loader,
                device=DEVICE,
                scheduler=scheduler,
                use_penalized_BCE=True,
                **pen_bce_params,
            )
            
            balanced_accuracies[fold_num].append(balanced_accuracy)
            
            trial.report(balanced_accuracy, epoch)
            if trial.should_prune():
                raise optuna.TrialPruned()
            
        
        return (np.mean([lst[-5:] for lst in balanced_accuracies.values()]))

In [21]:
# set up logger and log optuna trials to log file
logging.basicConfig(filename="hyperparameter_tuning.log", level=logging.INFO)

optuna.logging.enable_propagation()
optuna.logging.disable_default_handler()

# prune bad trials since NAM is computationally expensive
pruner = optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=20)

study = optuna.create_study(
    direction="maximize",
    study_name="NAM_hyperparameter_tuning",
    sampler=optuna.samplers.TPESampler(seed=SEED),
    pruner=pruner,
)

In [22]:
study.optimize(
    objective, 
    n_trials=N_TRIALS,
    timeout = 20 * 60 * 60, # timeout after 20 hours
    show_progress_bar=True
)

  0%|          | 0/1000 [00:00<?, ?it/s]

[I 2024-04-08 23:21:41,970] Trial 0 finished with value: 0.8494973938942666 and parameters: {'hidden_profile': [1024], 'within_feature_dropout': 0.5, 'feature_dropout': 0.05, 'output_regularization': 0.01673808578875214, 'l2_regularization': 1.9010245319870378e-06, 'lr': 0.0007523742884534858}. Best is trial 0 with value: 0.8494973938942666.
[I 2024-04-08 23:22:25,152] Trial 1 finished with value: 0.8484735666418466 and parameters: {'hidden_profile': [256], 'within_feature_dropout': 0.5, 'feature_dropout': 0.1, 'output_regularization': 0.06586289317583113, 'l2_regularization': 3.292759134423615e-06, 'lr': 0.009717775305059635}. Best is trial 0 with value: 0.8494973938942666.
[I 2024-04-08 23:23:42,187] Trial 2 finished with value: 0.8705323901712584 and parameters: {'hidden_profile': [256, 128, 128], 'within_feature_dropout': 0.05, 'feature_dropout': 0, 'output_regularization': 0.00191358804876923, 'l2_regularization': 4.0215545266902885e-05, 'lr': 0.00016736010167825804}. Best is tria

In [23]:
best_params = study.best_params
    
print(best_params)

{'hidden_profile': [1024], 'within_feature_dropout': 0.4, 'feature_dropout': 0, 'output_regularization': 0.018909985590848253, 'l2_regularization': 3.86778344591064e-05, 'lr': 0.0058300321150487515}


In [26]:
p_importance = optuna.visualization.plot_param_importances(study)
p_importance.write_image("param_importance.png")
p_importance.show()

In [27]:
p_history = optuna.visualization.plot_optimization_history(study)
p_history.write_image("optimization_history.png")
p_history.show()