# Hyperparameter Tuning for NAM

In the following, we will tune the most important hyperparameters of the Neural Additive Model (NAM) using Optuna since we realised during training that the model is very sensitive to the choice of some hyperparameters. We will propose similar hyperparameter values as the ones used in the [NAM paper](https://arxiv.org/pdf/2004.13912.pdf).

## Packages and Presets

In [11]:
import numpy as np
import pandas as pd
from skimpy import clean_columns
from pickle import load

from sklearn.model_selection import train_test_split, StratifiedKFold
from torch.utils.data import DataLoader
from neural_additive_model import NAM

import torch
import torch.nn as nn
import optuna
import logging


# append path to parent folder to allow imports from utils folder
import sys

sys.path.append("../..")
from utils.utils import (
    set_all_seeds,
    HeartFailureDataset,
    train_and_validate_one_epoch,
    get_n_units,
    penalized_binary_cross_entropy,
)

import warnings
warnings.filterwarnings("ignore")

In [12]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42
N_TRIALS = 1_000
N_EPOCHS = 300

## Preprocess Data

In [13]:
train_df = pd.read_csv("../data/heart_failure/train_val_split.csv").pipe(
    clean_columns
)
X = train_df.drop(columns=["heart_disease"], axis=1)
outlier_idx = X.query("resting_bp == 0").index
X = X.drop(outlier_idx)
y = train_df["heart_disease"]
y = y[X.index]

# create categorical variable for cholesterol level
X["chol_level"] = pd.cut(
    X["cholesterol"],
    bins=[-1, 10, 200, 240, 1000],
    labels=["imputed", "normal", "borderline", "high"],
)

X.index = range(len(X))
y.index = range(len(y))


# load and apply preprocessor:
preprocessor = load(open("../models/preprocessor.pkl", "rb"))

## Hyperparameter Tuning

In [14]:
# create objective function for optuna
def objective(trial):
    NAM_params = {
        "out_size": 1,
        "hidden_profile": trial.suggest_categorical(
            "hidden_profile", [
                [64, 64, 32],
                [1024],
                [256],
                [128, 256, 128, 64],
                [256, 128, 128],
            ]
        ),
        "use_exu": False, 
        "use_relu_n": False, 
        "within_feature_dropout": trial.suggest_categorical(
            "within_feature_dropout",
            [0, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
        ),
        "feature_dropout": trial.suggest_categorical(
            "feature_dropout", [0, 0.05, 0.1, 0.2]
        )
    }
    
    pen_bce_params = {
        "output_regularization": trial.suggest_float("output_regularization", 1e-3, 1e-1, log=True),
        "l2_regularization": trial.suggest_float("l2_regularization", 1e-6, 1e-4, log=True),
    }
        
    adam_params = {"lr": trial.suggest_float("lr", 1e-4, 1e-1, log=True)}
   
    set_all_seeds(SEED)
    
    balanced_accuracies = {}
    
    skf = StratifiedKFold(n_splits=5, random_state=SEED, shuffle=True)
    for fold_num, (train_idx, val_idx) in enumerate(skf.split(X, y)):
        X_train, X_val = X.iloc[train_idx], X.iloc[val_idx]
        y_train, y_val = y[train_idx].to_numpy(), y[val_idx].to_numpy()
        
        X_train = preprocessor.fit_transform(X_train)
        X_val = preprocessor.transform(X_val)
        
        NAM_params["n_features"] = X_train.shape[1]
        NAM_params["in_size"] = get_n_units(X_train)
        
        set_all_seeds(SEED)
        model = NAM(**NAM_params).to(DEVICE)
        
         
        optimizer = torch.optim.Adam(model.parameters(), **adam_params)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)
        
        # Create the dataset and dataloader
        train_dataset = HeartFailureDataset(X_train, y_train)
        val_dataset = HeartFailureDataset(X_val, y_val)


        train_loader = DataLoader(
            train_dataset, batch_size=32, shuffle=True, pin_memory=True
        )

        val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, pin_memory=True)
        
        balanced_accuracies[fold_num] = []
        #f1_scores[fold_num] = []
        
        for epoch in range(N_EPOCHS):
            set_all_seeds(SEED)
            _,_, _, balanced_accuracy = train_and_validate_one_epoch(
                model=model,
                optimizer=optimizer,
                criterion=penalized_binary_cross_entropy,
                train_loader=train_loader,
                val_loader=val_loader,
                device=DEVICE,
                scheduler=scheduler,
                use_penalized_BCE=True,
                **pen_bce_params,
            )
            
            balanced_accuracies[fold_num].append(balanced_accuracy)
            #f1_scores[fold_num].append(f1_score)
            
            trial.report(balanced_accuracy, epoch) #(balanced_accuracy + f1_score) / 2, fold_num + epoch)
            if trial.should_prune():
                raise optuna.TrialPruned()
            
        
        return (np.mean([lst[-5:] for lst in balanced_accuracies.values()]))

In [15]:
# set up logger and log optuna trials to log file
logging.basicConfig(filename="hyperparameter_tuning.log", level=logging.INFO)

optuna.logging.enable_propagation()
optuna.logging.disable_default_handler()

# prune bad trials since NAM is computationally expensive
pruner = optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=20)

study = optuna.create_study(
    direction="maximize",
    study_name="NAM_hyperparameter_tuning",
    sampler=optuna.samplers.TPESampler(seed=SEED),
    pruner=pruner,
)

In [16]:
study.optimize(
    objective, 
    n_trials=N_TRIALS,
    timeout = 20 * 60 * 60, # timeout after 20 hours
    show_progress_bar=True
)

  0%|          | 0/1000 [00:00<?, ?it/s]

[I 2024-04-08 22:45:23,053] Trial 0 finished with value: 0.8421444527177959 and parameters: {'hidden_profile': [1024], 'within_feature_dropout': 0.6, 'feature_dropout': 0.1, 'output_regularization': 0.0038234752246751854, 'l2_regularization': 1.6738085788752145e-05, 'lr': 0.00026210878782654407}. Best is trial 0 with value: 0.8421444527177959.
[I 2024-04-08 22:47:26,265] Trial 1 finished with value: 0.6235107967237528 and parameters: {'hidden_profile': [128, 256, 128, 64, 32], 'within_feature_dropout': 0.7, 'feature_dropout': 0.05, 'output_regularization': 0.009780337016659405, 'l2_regularization': 1.1715937392307063e-06, 'lr': 0.053451661106468214}. Best is trial 0 with value: 0.8421444527177959.
[I 2024-04-08 22:48:13,898] Trial 2 finished with value: 0.8842144452717795 and parameters: {'hidden_profile': [1024], 'within_feature_dropout': 0.1, 'feature_dropout': 0.2, 'output_regularization': 0.005170191786366992, 'l2_regularization': 3.646439558980721e-06, 'lr': 0.004247058562261873}.

KeyboardInterrupt: 

In [None]:
import yaml

best_params = study.best_params

config_path = "config.yaml"

with open(config_path, "w") as file:
    yaml.dump(best_params, file)
    
print(best_params)

{'hidden_profile': [1024], 'use_exu': False, 'use_relu_n': True, 'within_feature_dropout': 0.2, 'feature_dropout': 0.2, 'output_regularization': 0.0044043429133187205, 'l2_regularization': 1.6634538799288264e-06, 'lr': 0.0016138573407084142}


In [None]:
p_importance = optuna.visualization.plot_param_importances(study)
p_importance.write_image("param_importance.png")
p_importance.show()

In [None]:
p_history = optuna.visualization.plot_optimization_history(study)
p_history.write_image("optimization_history.png")
p_history.show()