# **Hyperparameter Tuning with Optuna (Tabnet)**

## 1. **Defining the Search Space and Launching the Optimization Study**

In this section, we tune the TabNet model using **Optuna** to automatically search for the best hyperparameters that maximize validation F1-score.

We used:
- `optuna.create_study(direction="maximize")` to find hyperparameter combinations that boost F1
- A search space including TabNet dimensions (`n_d`, `n_a`), number of steps, `gamma`, `lambda_sparse`, learning rate, etc.
- `torch.optim.lr_scheduler.StepLR` as our learning rate scheduler
- A custom objective function that returns validation F1-score

We saved:
- The full Optuna study (`tabnet_study.pkl`)
- The best hyperparameter configuration (`tabnet_best_params.csv`)

This step follows best practices from **Chapter 7 – Advanced Deep Learning** in *François Chollet’s Deep Learning with Python*, especially around hyperparameter optimization and model generalization.


In [None]:
# tuning/tabnet_tuning.ipynb
import sys
import os

# Set path to root of the project (adjust if needed)
project_root = os.path.abspath("..")  # one level up
if project_root not in sys.path:
    sys.path.append(project_root)

In [15]:
# 1. Imports and Setup
import optuna
import joblib
import os
import numpy as np
import pandas as pd
import torch



from models.tabnet_model import build_tabnet_model
from training.tabnet_trainer import train_tabnet_model
from sklearn.metrics import f1_score, roc_auc_score
from pytorch_tabnet.tab_model import TabNetClassifier
from utils.style_utils import styled_print
from torch.optim import Adam, RMSprop
from torch.optim.lr_scheduler import StepLR  



# 2. Load the original non-SMOTE, scaled dataset
data_path = "../artifacts/tabnet/data_scaled_nosmote_for_tabnet.pkl"
X_train_scaled, y_train, X_val_scaled, y_val, X_test_scaled, y_test = joblib.load(data_path)

styled_print("🗂️ Loaded the original non-SMOTE, scaled dataset from: artifacts/tabnet/data_scaled_nosmote_for_tabnet.pkl")

#  3. Define Optuna Objective Function
def objective(trial):
    params = {
        "n_d": trial.suggest_categorical("n_d", [8, 16, 32, 64]),
        "n_a": trial.suggest_categorical("n_a", [8, 16, 32, 64]),
        "n_steps": trial.suggest_int("n_steps", 3, 7),
        "gamma": trial.suggest_float("gamma", 1.0, 2.0, step=0.5),
        "lambda_sparse": trial.suggest_float("lambda_sparse", 1e-5, 1e-2, log=True),
        "optimizer_params": dict(lr=trial.suggest_float("lr", 1e-4, 1e-2, log=True)),
        "scheduler_params": {"step_size": 10, "gamma": 0.95},
        "scheduler_fn": torch.optim.lr_scheduler.StepLR,
        "mask_type": "entmax",  # better for sparse data
        "device_name": "cuda" if torch.cuda.is_available() else "cpu",
    }

    model = TabNetClassifier(**params)
    model.fit(
        X_train=X_train_scaled.values,
        y_train=y_train.values,
        eval_set=[(X_val_scaled.values, y_val.values)],
        eval_metric=["auc", "f1_score"],
        patience=10,
        max_epochs=50,
        batch_size=256,
        virtual_batch_size=128
    )

    preds = model.predict(X_val_scaled.values)
    score = f1_score(y_val, preds)

    return score  # or ROC AUC


In [20]:
# 4. Run Tuning (the Optuna Study)

study = optuna.create_study(direction="maximize")  # Creates a new study
study.optimize(objective, n_trials=10, show_progress_bar=True)  # Run with progress bar


[I 2025-03-28 04:55:27,175] A new study created in memory with name: no-name-2afb2dc4-2a33-443a-97ee-0679dc2d69e6


  0%|          | 0/10 [00:00<?, ?it/s]



epoch 0  | loss: 0.07094 | val_0_auc: 0.63635 | val_0_f1_score: 0.40426 |  0:00:23s
epoch 1  | loss: 0.02515 | val_0_auc: 0.66184 | val_0_f1_score: 0.35714 |  0:00:45s
epoch 2  | loss: 0.01868 | val_0_auc: 0.84247 | val_0_f1_score: 0.47619 |  0:01:08s
epoch 3  | loss: 0.01647 | val_0_auc: 0.83031 | val_0_f1_score: 0.672   |  0:01:31s
epoch 4  | loss: 0.01501 | val_0_auc: 0.84646 | val_0_f1_score: 0.70492 |  0:01:56s
epoch 5  | loss: 0.01433 | val_0_auc: 0.77067 | val_0_f1_score: 0.624   |  0:02:18s
epoch 6  | loss: 0.01348 | val_0_auc: 0.84022 | val_0_f1_score: 0.56075 |  0:02:43s
epoch 7  | loss: 0.0124  | val_0_auc: 0.91392 | val_0_f1_score: 0.73381 |  0:03:06s
epoch 8  | loss: 0.01147 | val_0_auc: 0.54064 | val_0_f1_score: 0.27184 |  0:03:27s
epoch 9  | loss: 0.01191 | val_0_auc: 0.77175 | val_0_f1_score: 0.35849 |  0:03:49s
epoch 10 | loss: 0.01019 | val_0_auc: 0.90572 | val_0_f1_score: 0.75385 |  0:04:11s
epoch 11 | loss: 0.00983 | val_0_auc: 0.9129  | val_0_f1_score: 0.69182 |  0



[I 2025-03-28 05:03:32,144] Trial 0 finished with value: 0.7538461538461538 and parameters: {'n_d': 64, 'n_a': 32, 'n_steps': 6, 'gamma': 1.0, 'lambda_sparse': 0.0018731207587726801, 'lr': 0.0006340178596401655}. Best is trial 0 with value: 0.7538461538461538.




epoch 0  | loss: 0.02383 | val_0_auc: 0.93012 | val_0_f1_score: 0.47059 |  0:00:12s
epoch 1  | loss: 0.01056 | val_0_auc: 0.93348 | val_0_f1_score: 0.5     |  0:00:24s
epoch 2  | loss: 0.00771 | val_0_auc: 0.94552 | val_0_f1_score: 0.72131 |  0:00:37s
epoch 3  | loss: 0.00641 | val_0_auc: 0.96844 | val_0_f1_score: 0.77778 |  0:00:51s
epoch 4  | loss: 0.0058  | val_0_auc: 0.95929 | val_0_f1_score: 0.7438  |  0:01:03s
epoch 5  | loss: 0.00516 | val_0_auc: 0.96412 | val_0_f1_score: 0.80315 |  0:01:16s
epoch 6  | loss: 0.00481 | val_0_auc: 0.9649  | val_0_f1_score: 0.76423 |  0:01:28s
epoch 7  | loss: 0.00431 | val_0_auc: 0.95612 | val_0_f1_score: 0.81818 |  0:01:40s
epoch 8  | loss: 0.00464 | val_0_auc: 0.96733 | val_0_f1_score: 0.68966 |  0:01:53s
epoch 9  | loss: 0.00425 | val_0_auc: 0.97065 | val_0_f1_score: 0.7377  |  0:02:05s
epoch 10 | loss: 0.00399 | val_0_auc: 0.97655 | val_0_f1_score: 0.81818 |  0:02:17s
epoch 11 | loss: 0.00377 | val_0_auc: 0.96888 | val_0_f1_score: 0.74797 |  0



[I 2025-03-28 05:10:26,671] Trial 1 finished with value: 0.8333333333333334 and parameters: {'n_d': 32, 'n_a': 16, 'n_steps': 4, 'gamma': 2.0, 'lambda_sparse': 0.006276234052405445, 'lr': 0.0056400114420726115}. Best is trial 1 with value: 0.8333333333333334.




epoch 0  | loss: 0.0501  | val_0_auc: 0.80959 | val_0_f1_score: 0.57143 |  0:00:30s
epoch 1  | loss: 0.01287 | val_0_auc: 0.869   | val_0_f1_score: 0.66667 |  0:00:55s
epoch 2  | loss: 0.0101  | val_0_auc: 0.92312 | val_0_f1_score: 0.64384 |  0:01:21s
epoch 3  | loss: 0.00887 | val_0_auc: 0.93863 | val_0_f1_score: 0.74419 |  0:01:47s
epoch 4  | loss: 0.00782 | val_0_auc: 0.95913 | val_0_f1_score: 0.69504 |  0:02:11s
epoch 5  | loss: 0.00706 | val_0_auc: 0.89452 | val_0_f1_score: 0.69421 |  0:02:35s
epoch 6  | loss: 0.00661 | val_0_auc: 0.93842 | val_0_f1_score: 0.66242 |  0:03:00s
epoch 7  | loss: 0.00613 | val_0_auc: 0.95088 | val_0_f1_score: 0.67626 |  0:03:26s
epoch 8  | loss: 0.0058  | val_0_auc: 0.92104 | val_0_f1_score: 0.75    |  0:03:51s
epoch 9  | loss: 0.00538 | val_0_auc: 0.93066 | val_0_f1_score: 0.76423 |  0:04:15s
epoch 10 | loss: 0.00528 | val_0_auc: 0.95356 | val_0_f1_score: 0.75758 |  0:04:42s
epoch 11 | loss: 0.00493 | val_0_auc: 0.94498 | val_0_f1_score: 0.78462 |  0



[I 2025-03-28 05:20:26,091] Trial 2 finished with value: 0.7846153846153846 and parameters: {'n_d': 64, 'n_a': 64, 'n_steps': 6, 'gamma': 2.0, 'lambda_sparse': 0.00665999788846571, 'lr': 0.0018485713748730651}. Best is trial 1 with value: 0.8333333333333334.




epoch 0  | loss: 0.32257 | val_0_auc: 0.28813 | val_0_f1_score: 0.02778 |  0:00:18s
epoch 1  | loss: 0.02314 | val_0_auc: 0.43393 | val_0_f1_score: 0.14815 |  0:00:37s
epoch 2  | loss: 0.01809 | val_0_auc: 0.49042 | val_0_f1_score: 0.12346 |  0:00:54s
epoch 3  | loss: 0.01416 | val_0_auc: 0.58987 | val_0_f1_score: 0.18182 |  0:01:13s
epoch 4  | loss: 0.01213 | val_0_auc: 0.70482 | val_0_f1_score: 0.31111 |  0:01:30s
epoch 5  | loss: 0.01041 | val_0_auc: 0.77503 | val_0_f1_score: 0.29907 |  0:01:46s
epoch 6  | loss: 0.00989 | val_0_auc: 0.8098  | val_0_f1_score: 0.41818 |  0:02:02s
epoch 7  | loss: 0.00909 | val_0_auc: 0.80947 | val_0_f1_score: 0.45045 |  0:02:18s
epoch 8  | loss: 0.00907 | val_0_auc: 0.8457  | val_0_f1_score: 0.55652 |  0:02:34s
epoch 9  | loss: 0.00867 | val_0_auc: 0.83892 | val_0_f1_score: 0.61017 |  0:02:50s
epoch 10 | loss: 0.00803 | val_0_auc: 0.86756 | val_0_f1_score: 0.59016 |  0:03:06s
epoch 11 | loss: 0.00764 | val_0_auc: 0.84094 | val_0_f1_score: 0.63415 |  0



[I 2025-03-28 05:29:10,304] Trial 3 finished with value: 0.72 and parameters: {'n_d': 64, 'n_a': 32, 'n_steps': 4, 'gamma': 1.5, 'lambda_sparse': 0.00011342536650419821, 'lr': 0.00015500344931567613}. Best is trial 1 with value: 0.8333333333333334.




epoch 0  | loss: 0.17099 | val_0_auc: 0.27258 | val_0_f1_score: 0.0     |  0:00:09s
epoch 1  | loss: 0.02605 | val_0_auc: 0.50443 | val_0_f1_score: 0.15    |  0:00:19s
epoch 2  | loss: 0.01751 | val_0_auc: 0.71247 | val_0_f1_score: 0.4     |  0:00:29s
epoch 3  | loss: 0.01372 | val_0_auc: 0.81846 | val_0_f1_score: 0.52427 |  0:00:38s
epoch 4  | loss: 0.01207 | val_0_auc: 0.87013 | val_0_f1_score: 0.74576 |  0:00:48s
epoch 5  | loss: 0.0106  | val_0_auc: 0.89178 | val_0_f1_score: 0.63063 |  0:00:57s
epoch 6  | loss: 0.00996 | val_0_auc: 0.893   | val_0_f1_score: 0.75806 |  0:01:07s
epoch 7  | loss: 0.00931 | val_0_auc: 0.89367 | val_0_f1_score: 0.78462 |  0:01:17s
epoch 8  | loss: 0.00872 | val_0_auc: 0.87333 | val_0_f1_score: 0.77165 |  0:01:26s
epoch 9  | loss: 0.00832 | val_0_auc: 0.882   | val_0_f1_score: 0.73684 |  0:01:36s
epoch 10 | loss: 0.00836 | val_0_auc: 0.88253 | val_0_f1_score: 0.68852 |  0:01:45s
epoch 11 | loss: 0.00786 | val_0_auc: 0.88947 | val_0_f1_score: 0.7619  |  0



[I 2025-03-28 05:36:08,124] Trial 4 finished with value: 0.8421052631578947 and parameters: {'n_d': 8, 'n_a': 32, 'n_steps': 3, 'gamma': 2.0, 'lambda_sparse': 0.005221911135507731, 'lr': 0.00032460570676150257}. Best is trial 4 with value: 0.8421052631578947.




epoch 0  | loss: 0.05316 | val_0_auc: 0.84499 | val_0_f1_score: 0.4902  |  0:00:16s
epoch 1  | loss: 0.01948 | val_0_auc: 0.88228 | val_0_f1_score: 0.32609 |  0:00:33s
epoch 2  | loss: 0.01549 | val_0_auc: 0.90677 | val_0_f1_score: 0.54545 |  0:00:50s
epoch 3  | loss: 0.01258 | val_0_auc: 0.89621 | val_0_f1_score: 0.60504 |  0:01:07s
epoch 4  | loss: 0.01088 | val_0_auc: 0.94916 | val_0_f1_score: 0.672   |  0:01:24s
epoch 5  | loss: 0.01013 | val_0_auc: 0.90816 | val_0_f1_score: 0.672   |  0:01:41s
epoch 6  | loss: 0.00941 | val_0_auc: 0.92837 | val_0_f1_score: 0.62903 |  0:01:57s
epoch 7  | loss: 0.0078  | val_0_auc: 0.95715 | val_0_f1_score: 0.63492 |  0:02:15s
epoch 8  | loss: 0.00707 | val_0_auc: 0.95227 | val_0_f1_score: 0.70588 |  0:02:31s
epoch 9  | loss: 0.00617 | val_0_auc: 0.9438  | val_0_f1_score: 0.72308 |  0:02:48s
epoch 10 | loss: 0.0056  | val_0_auc: 0.93171 | val_0_f1_score: 0.68966 |  0:03:05s
epoch 11 | loss: 0.00518 | val_0_auc: 0.92943 | val_0_f1_score: 0.69492 |  0



[I 2025-03-28 05:42:10,030] Trial 5 finished with value: 0.7230769230769231 and parameters: {'n_d': 32, 'n_a': 8, 'n_steps': 6, 'gamma': 1.0, 'lambda_sparse': 0.004354358263594677, 'lr': 0.002952142859579131}. Best is trial 4 with value: 0.8421052631578947.




epoch 0  | loss: 0.11109 | val_0_auc: 0.36989 | val_0_f1_score: 0.32609 |  0:00:12s
epoch 1  | loss: 0.02073 | val_0_auc: 0.59435 | val_0_f1_score: 0.32609 |  0:00:24s
epoch 2  | loss: 0.01566 | val_0_auc: 0.6312  | val_0_f1_score: 0.35789 |  0:00:37s
epoch 3  | loss: 0.01429 | val_0_auc: 0.66343 | val_0_f1_score: 0.31915 |  0:00:50s
epoch 4  | loss: 0.0106  | val_0_auc: 0.7913  | val_0_f1_score: 0.50467 |  0:01:04s
epoch 5  | loss: 0.00945 | val_0_auc: 0.77053 | val_0_f1_score: 0.53913 |  0:01:18s
epoch 6  | loss: 0.00952 | val_0_auc: 0.88736 | val_0_f1_score: 0.61538 |  0:01:31s
epoch 7  | loss: 0.00742 | val_0_auc: 0.85735 | val_0_f1_score: 0.61789 |  0:01:45s
epoch 8  | loss: 0.0071  | val_0_auc: 0.8757  | val_0_f1_score: 0.61157 |  0:01:59s
epoch 9  | loss: 0.00736 | val_0_auc: 0.93229 | val_0_f1_score: 0.61157 |  0:02:13s
epoch 10 | loss: 0.00634 | val_0_auc: 0.91694 | val_0_f1_score: 0.68908 |  0:02:26s
epoch 11 | loss: 0.00654 | val_0_auc: 0.90261 | val_0_f1_score: 0.55046 |  0



[I 2025-03-28 05:51:15,124] Trial 6 finished with value: 0.7903225806451613 and parameters: {'n_d': 16, 'n_a': 8, 'n_steps': 5, 'gamma': 1.0, 'lambda_sparse': 0.00036685753910227595, 'lr': 0.0007537588768137618}. Best is trial 4 with value: 0.8421052631578947.




epoch 0  | loss: 0.0211  | val_0_auc: 0.90155 | val_0_f1_score: 0.70492 |  0:00:15s
epoch 1  | loss: 0.00649 | val_0_auc: 0.95511 | val_0_f1_score: 0.74603 |  0:00:30s
epoch 2  | loss: 0.00546 | val_0_auc: 0.96531 | val_0_f1_score: 0.76423 |  0:00:45s
epoch 3  | loss: 0.00508 | val_0_auc: 0.95593 | val_0_f1_score: 0.8     |  0:01:00s
epoch 4  | loss: 0.00492 | val_0_auc: 0.9806  | val_0_f1_score: 0.75188 |  0:01:15s
epoch 5  | loss: 0.00456 | val_0_auc: 0.96623 | val_0_f1_score: 0.80916 |  0:01:30s
epoch 6  | loss: 0.00427 | val_0_auc: 0.96247 | val_0_f1_score: 0.8062  |  0:01:45s
epoch 7  | loss: 0.00442 | val_0_auc: 0.94871 | val_0_f1_score: 0.76923 |  0:01:59s
epoch 8  | loss: 0.00452 | val_0_auc: 0.95921 | val_0_f1_score: 0.7619  |  0:02:15s
epoch 9  | loss: 0.00402 | val_0_auc: 0.95821 | val_0_f1_score: 0.81538 |  0:02:29s
epoch 10 | loss: 0.00382 | val_0_auc: 0.95954 | val_0_f1_score: 0.8062  |  0:02:45s
epoch 11 | loss: 0.00403 | val_0_auc: 0.95313 | val_0_f1_score: 0.58621 |  0



[I 2025-03-28 05:57:54,781] Trial 7 finished with value: 0.8307692307692308 and parameters: {'n_d': 8, 'n_a': 16, 'n_steps': 6, 'gamma': 2.0, 'lambda_sparse': 0.0031452972956356328, 'lr': 0.008870714624775414}. Best is trial 4 with value: 0.8421052631578947.




epoch 0  | loss: 0.06299 | val_0_auc: 0.10785 | val_0_f1_score: 0.0     |  0:00:22s
epoch 1  | loss: 0.0346  | val_0_auc: 0.4705  | val_0_f1_score: 0.2     |  0:00:45s
epoch 2  | loss: 0.02249 | val_0_auc: 0.41403 | val_0_f1_score: 0.24444 |  0:01:07s
epoch 3  | loss: 0.01807 | val_0_auc: 0.57935 | val_0_f1_score: 0.2766  |  0:01:29s
epoch 4  | loss: 0.01559 | val_0_auc: 0.75199 | val_0_f1_score: 0.55932 |  0:01:52s
epoch 5  | loss: 0.01238 | val_0_auc: 0.72583 | val_0_f1_score: 0.592   |  0:02:14s
epoch 6  | loss: 0.01068 | val_0_auc: 0.74394 | val_0_f1_score: 0.51667 |  0:02:36s
epoch 7  | loss: 0.00967 | val_0_auc: 0.75267 | val_0_f1_score: 0.50847 |  0:02:58s
epoch 8  | loss: 0.01011 | val_0_auc: 0.75329 | val_0_f1_score: 0.48387 |  0:03:21s
epoch 9  | loss: 0.00923 | val_0_auc: 0.79983 | val_0_f1_score: 0.49123 |  0:03:44s
epoch 10 | loss: 0.00912 | val_0_auc: 0.79018 | val_0_f1_score: 0.51282 |  0:04:07s
epoch 11 | loss: 0.0089  | val_0_auc: 0.77859 | val_0_f1_score: 0.49573 |  0



[I 2025-03-28 06:13:10,718] Trial 8 finished with value: 0.7230769230769231 and parameters: {'n_d': 64, 'n_a': 64, 'n_steps': 5, 'gamma': 1.5, 'lambda_sparse': 1.64947322619163e-05, 'lr': 0.00013093000549781525}. Best is trial 4 with value: 0.8421052631578947.




epoch 0  | loss: 0.03955 | val_0_auc: 0.65386 | val_0_f1_score: 0.20455 |  0:00:13s
epoch 1  | loss: 0.01255 | val_0_auc: 0.59192 | val_0_f1_score: 0.26415 |  0:00:25s
epoch 2  | loss: 0.0095  | val_0_auc: 0.8388  | val_0_f1_score: 0.47273 |  0:00:38s
epoch 3  | loss: 0.00779 | val_0_auc: 0.89188 | val_0_f1_score: 0.47863 |  0:00:51s
epoch 4  | loss: 0.00651 | val_0_auc: 0.91396 | val_0_f1_score: 0.66102 |  0:01:04s
epoch 5  | loss: 0.00537 | val_0_auc: 0.87722 | val_0_f1_score: 0.59504 |  0:01:17s
epoch 6  | loss: 0.00594 | val_0_auc: 0.95052 | val_0_f1_score: 0.71429 |  0:01:30s
epoch 7  | loss: 0.00503 | val_0_auc: 0.9571  | val_0_f1_score: 0.69919 |  0:01:43s
epoch 8  | loss: 0.00437 | val_0_auc: 0.93685 | val_0_f1_score: 0.72581 |  0:01:57s
epoch 9  | loss: 0.00418 | val_0_auc: 0.94165 | val_0_f1_score: 0.7377  |  0:02:10s
epoch 10 | loss: 0.00411 | val_0_auc: 0.94447 | val_0_f1_score: 0.71429 |  0:02:23s
epoch 11 | loss: 0.0039  | val_0_auc: 0.96144 | val_0_f1_score: 0.752   |  0



[I 2025-03-28 06:20:21,340] Trial 9 finished with value: 0.7874015748031497 and parameters: {'n_d': 32, 'n_a': 16, 'n_steps': 4, 'gamma': 2.0, 'lambda_sparse': 1.7746607930209065e-05, 'lr': 0.0008677580582246937}. Best is trial 4 with value: 0.8421052631578947.


In [26]:
import joblib
import pandas as pd
from utils.style_utils import styled_print 

# Save study object
joblib.dump(study, "tuning_results/tabnet_study.pkl")
styled_print("📂Saved Optuna study to: tuning_results/tabnet_study.pkl")

# Save best params
best_params = study.best_params
pd.DataFrame([best_params]).to_csv("tuning_results/tabnet_best_params.csv", index=False)
styled_print("📂Saved best params to: tuning_results/tabnet_best_params.csv")


## **2. Re-train TabNet on Train+Val using Best Params**

After finding the optimal configuration, we merged the training and validation sets to retrain the TabNet model from scratch with the best hyperparameters.

Key details:

- Model: TabNetClassifier with best parameters from Optuna

- Dataset: Full Train + Val (70% + 15% of original)

- Training setup: Same as before (50 epochs max, early stopping, batch size = 256)

**⚠️ Note: Test data was not used during tuning or retraining. It is held out for final evaluation only.**

In [None]:
from pytorch_tabnet.tab_model import TabNetClassifier
import joblib
import numpy as np

styled_print(" Re-training Tabnet Model Using Best Parameters (NO test evaluation yet!)")
# Get best parameters from Optuna
best_params = study.best_params
styled_print(f" Best Params from Optuna: {best_params}")

# Merge train + val
X_trainval = np.vstack((X_train_scaled, X_val_scaled))
y_trainval = np.hstack((y_train, y_val))

# Rebuild model with best hyperparameters
tabnet_best = TabNetClassifier(
    n_d=best_params["n_d"],
    n_a=best_params["n_a"],
    n_steps=best_params["n_steps"],
    gamma=best_params["gamma"],
    lambda_sparse=best_params["lambda_sparse"],
    optimizer_params=dict(lr=best_params["lr"]),
    device_name="cpu",  
    verbose=1
)

# Retrain on full train+val
tabnet_best.fit(
    X_trainval, y_trainval,
    max_epochs=50,
    patience=10,
    batch_size=256,
    virtual_batch_size=128
)




epoch 0  | loss: 0.14228 |  0:00:11s
epoch 1  | loss: 0.01925 |  0:00:23s
epoch 2  | loss: 0.01201 |  0:00:36s
epoch 3  | loss: 0.00975 |  0:00:47s
epoch 4  | loss: 0.00906 |  0:00:58s
epoch 5  | loss: 0.00835 |  0:01:09s
epoch 6  | loss: 0.00773 |  0:01:20s
epoch 7  | loss: 0.00762 |  0:01:32s
epoch 8  | loss: 0.00694 |  0:01:41s
epoch 9  | loss: 0.00671 |  0:01:51s
epoch 10 | loss: 0.00644 |  0:02:01s
epoch 11 | loss: 0.0063  |  0:02:12s
epoch 12 | loss: 0.00605 |  0:02:24s
epoch 13 | loss: 0.00575 |  0:02:35s
epoch 14 | loss: 0.00568 |  0:02:46s
epoch 15 | loss: 0.00551 |  0:02:56s
epoch 16 | loss: 0.00529 |  0:03:07s
epoch 17 | loss: 0.00553 |  0:03:18s
epoch 18 | loss: 0.00534 |  0:03:30s
epoch 19 | loss: 0.00529 |  0:03:40s
epoch 20 | loss: 0.00516 |  0:03:50s
epoch 21 | loss: 0.00497 |  0:04:00s
epoch 22 | loss: 0.00493 |  0:04:09s
epoch 23 | loss: 0.00478 |  0:04:19s
epoch 24 | loss: 0.0046  |  0:04:28s
epoch 25 | loss: 0.00475 |  0:04:38s
epoch 26 | loss: 0.00483 |  0:04:48s
e

In [33]:
# Save the final retrained TabNet model after tuning
tabnet_best.save_model("tuning_results/tabnet_best_model_after_tuning")
styled_print("📂 Saved final TabNet model after tuning to: tuning_results/tabnet_best_model_after_tuning.zip")


Successfully saved model at tuning_results/tabnet_best_model_after_tuning.zip


In [34]:
#  Save the training history dictionary if needed later
joblib.dump(tabnet_best.history, "tuning_results/tabnet_best_history_after_tuning.pkl")
styled_print("📂 Training history saved to tuning_results/tabnet_best_history_after_tuning.pkl")


## **3.Unit Testing**

In [39]:
import unittest
import numpy as np
import joblib
import torch
from sklearn.metrics import f1_score
from pytorch_tabnet.tab_model import TabNetClassifier

# Load dataset once for all tests
data_path = "../artifacts/tabnet/data_scaled_nosmote_for_tabnet.pkl"
X_train_scaled, y_train, X_val_scaled, y_val, X_test_scaled, y_test = joblib.load(data_path)

# Dummy trial class for objective test
class DummyTrial:
    def suggest_categorical(self, name, choices):
        return choices[0]
    def suggest_int(self, name, low, high):
        return low
    def suggest_float(self, name, low, high, step=None, log=False):
        return low

class TestTabNetTuning(unittest.TestCase):

    def test_objective_function_returns_float(self):
        """Test the Optuna objective function returns a float F1-score"""
        from tuning.tabnet_tuning import objective  # Adjust if needed
        trial = DummyTrial()
        result = objective(trial)
        self.assertIsInstance(result, float)
        self.assertGreaterEqual(result, 0)
        self.assertLessEqual(result, 1)

    def test_best_params_keys(self):
        """Test that the best params dictionary contains required keys"""
        from tuning_results.tabnet_study import study  # Assuming loaded
        best_params = study.best_params
        required_keys = {"n_d", "n_a", "n_steps", "gamma", "lambda_sparse", "lr"}
        self.assertTrue(required_keys.issubset(set(best_params.keys())))

    def test_model_trains_on_best_params(self):
        """Test that TabNetClassifier trains on full train+val without errors"""
        from tuning_results.tabnet_study import study
        best_params = study.best_params

        model = TabNetClassifier(
            n_d=best_params["n_d"],
            n_a=best_params["n_a"],
            n_steps=best_params["n_steps"],
            gamma=best_params["gamma"],
            lambda_sparse=best_params["lambda_sparse"],
            optimizer_params=dict(lr=best_params["lr"]),
            device_name="cpu"
        )

        # Merge train + val
        X_trainval = np.vstack((X_train_scaled, X_val_scaled))
        y_trainval = np.hstack((y_train, y_val))

        try:
            model.fit(
                X_trainval, y_trainval,
                max_epochs=1,  # Keep fast for unit test
                patience=2,
                batch_size=256,
                virtual_batch_size=128
            )
        except Exception as e:
            self.fail(f"Training crashed with error: {e}")
