In [1]:
import os
import pandas as pd
import numpy as np
import json
import optuna
import xgboost as xgb
import ast
import time
from sklearn.model_selection import ParameterGrid
from typing import List, Dict, Union, Tuple

from constants import MODELS_DIR, RESULTS_DIR, PREDS_DIR, DEVICE
from data_handling import DataHandler # For type hinting
from metrics import softmax, weighted_softprob_obj, weighted_cross_entropy_eval # Import metrics
from optuna.integration import XGBoostPruningCallback

Using MPS device (Apple Silicon GPU)


In [2]:
def xgb_pruning_callback(trial, metric_name='weighted-CE'):
    """
    Creates a pruning callback function for XGBoost that reports mean validation 
    loss across all folds to Optuna at each boosting round. Returns a callback function compatible with xgb.cv()
    """
    def callback(env):
        current_round = env.iteration # Current boosting round
        mean_val_loss = env.evaluation_result_list[f'test-{metric_name}-mean'][current_round]
        
        # Report + pruning check
        trial.report(mean_val_loss, step=current_round)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()
            
    return callback

In [3]:
def objective_xgb(trial: optuna.trial.Trial, 
                  dtrain: xgb.DMatrix,
                  num_boost_rounds: int = 150,
                  early_stopping_rounds: int = 30):
    """
    Objective function for XGBoost hyperparameter optimization using Optuna. Uses weighted_softprob_obj and weighted_cross_entropy_eval as custom objective and evaluation metric. Returns the mean val loss for the trial.
    """
    # Get hyperparams for the current trial
    params = {
        'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.3, log=True),
        'gamma': trial.suggest_float('gamma', 0.0, 5.0, log=False),
        'subsample': trial.suggest_float('subsample', 0.5, 1.0, log=False),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.4, 1.0, log=False),
        'colsample_bylevel': trial.suggest_float('colsample_bylevel', 0.4, 1.0, log=False),
        'colsample_bynode': trial.suggest_float('colsample_bynode', 0.4, 1.0, log=False),
        'reg_alpha': trial.suggest_float('reg_alpha', 1e-8, 10.0, log=True),
        'reg_lambda': trial.suggest_float('reg_lambda', 1e-8, 10.0, log=True),
        'max_depth': trial.suggest_int('max_depth', 3, 12, step=1),
        'min_child_weight': trial.suggest_int('min_child_weight', 1, 20, step=1)
    }
    
    # Create pruning callback
    pruning_callback = XGBoostPruningCallback(trial, "test-weighted-CE")
    
    # Run cross-validation
    cv_results = xgb.cv(
        params=params,
        dtrain=dtrain,
        num_boost_round=num_boost_rounds,
        nfold=3,
        obj=weighted_softprob_obj, #custom objective
        custom_metric=weighted_cross_entropy_eval, #custom eval, name='weighted-CE'
        maximize=False,
        early_stopping_rounds=early_stopping_rounds,
        callbacks=[pruning_callback],
        verbose_eval=False,
        shuffle=False # Ensures folds match what's used by other models
    )
    
    # Return best validation score
    best_score = cv_results['test-weighted-CE-mean'].min()
    return best_score

In [4]:
dh = DataHandler()

DataHandler initialized - Using 114 features - Test year: 2020


In [5]:
asha_pruner = optuna.pruners.SuccessiveHalvingPruner(
    min_resource=30,        # Minimum number of steps before pruning
    reduction_factor=2,    # Reduction factor for successive halving
    min_early_stopping_rate=0
)

In [6]:
study_xgb = optuna.create_study(study_name='xgb',
                                direction='minimize',
                                pruner=asha_pruner)

[I 2025-05-10 18:55:34,282] A new study created in memory with name: xgb


In [None]:
study_xgb.optimize(
    lambda trial: objective_xgb(
        trial,
        dtrain=dh.get_xgb_data('cv'),
        num_boost_rounds=150,
        early_stopping_rounds=50
    ),
    n_trials=128,
    timeout=1200,  
    n_jobs=-1)

[I 2025-05-10 18:55:36,989] Trial 5 pruned. Trial was pruned at iteration 30.
[I 2025-05-10 18:55:37,353] Trial 3 pruned. Trial was pruned at iteration 30.
[I 2025-05-10 18:55:37,361] Trial 0 pruned. Trial was pruned at iteration 30.
[I 2025-05-10 18:55:38,197] Trial 4 pruned. Trial was pruned at iteration 30.
[I 2025-05-10 18:55:38,821] Trial 6 pruned. Trial was pruned at iteration 30.
[I 2025-05-10 18:55:39,014] Trial 7 pruned. Trial was pruned at iteration 60.
[I 2025-05-10 18:55:39,027] Trial 2 finished with value: 1.0532096666666666 and parameters: {'learning_rate': 0.28868708991649145, 'gamma': 1.998709717187389, 'subsample': 0.9684527339681692, 'colsample_bytree': 0.520288944698347, 'colsample_bylevel': 0.9503962537459164, 'colsample_bynode': 0.7284891893907427, 'reg_alpha': 0.017156924784588613, 'reg_lambda': 5.254791852074559e-06, 'max_depth': 6, 'min_child_weight': 18}. Best is trial 2 with value: 1.0532096666666666.
[I 2025-05-10 18:55:39,665] Trial 1 finished with value: 1.

In [8]:
study_xgb.trials_dataframe()

Unnamed: 0,number,value,datetime_start,datetime_complete,duration,params_colsample_bylevel,params_colsample_bynode,params_colsample_bytree,params_gamma,params_learning_rate,params_max_depth,params_min_child_weight,params_reg_alpha,params_reg_lambda,params_subsample,system_attrs_completed_rung_0,system_attrs_completed_rung_1,system_attrs_completed_rung_2,state
0,0,1.060148,2025-05-10 18:55:34.294601,2025-05-10 18:55:37.360718,0 days 00:00:03.066117,0.619453,0.744504,0.455250,3.155444,0.052927,9,17,1.400882e-02,7.573176e-06,0.599214,1.060148,,,PRUNED
1,1,1.056798,2025-05-10 18:55:34.295790,2025-05-10 18:55:39.665239,0 days 00:00:05.369449,0.419097,0.494009,0.585941,3.947809,0.280337,7,1,4.729810e-05,1.233926e-06,0.569612,1.056935,1.056881,,COMPLETE
2,2,1.053210,2025-05-10 18:55:34.298171,2025-05-10 18:55:39.026599,0 days 00:00:04.728428,0.950396,0.728489,0.520289,1.998710,0.288687,6,18,1.715692e-02,5.254792e-06,0.968453,1.053545,1.053550,,COMPLETE
3,3,1.099273,2025-05-10 18:55:34.298809,2025-05-10 18:55:37.353226,0 days 00:00:03.054417,0.432110,0.687432,0.871089,2.236638,0.025216,10,16,8.079612e+00,2.074956e-05,0.835226,1.099273,,,PRUNED
4,4,1.061909,2025-05-10 18:55:34.301058,2025-05-10 18:55:38.197869,0 days 00:00:03.896811,0.978205,0.627944,0.753282,1.990733,0.043332,10,2,7.333269e-07,1.219932e-07,0.942530,1.061909,,,PRUNED
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
123,123,1.043835,2025-05-10 18:57:43.986001,2025-05-10 18:58:04.089535,0 days 00:00:20.103534,0.835394,0.541491,0.740982,0.013396,0.200993,6,11,1.699528e-02,1.407544e-05,0.988436,1.045306,1.044264,1.043978,COMPLETE
124,124,1.043538,2025-05-10 18:57:46.132792,2025-05-10 18:58:06.354715,0 days 00:00:20.221923,0.705945,0.775716,0.744554,0.008419,0.162091,6,11,6.228771e-03,5.429656e-05,0.984194,1.045736,1.044333,1.043757,COMPLETE
125,125,1.048523,2025-05-10 18:57:47.836633,2025-05-10 18:57:53.497755,0 days 00:00:05.661122,0.835576,0.546092,0.737406,0.271598,0.275938,6,12,1.662542e-03,7.838037e-05,0.502259,1.048852,1.048523,,PRUNED
126,126,1.044345,2025-05-10 18:57:53.501050,2025-05-10 18:58:05.261409,0 days 00:00:11.760359,0.832252,0.400140,0.721940,0.026444,0.200305,5,11,6.930063e-03,2.489261e-03,0.981526,1.045386,1.044538,1.044421,COMPLETE


In [9]:
temp = study_xgb.best_trial
temp.intermediate_values

{0: 1.173268,
 1: 1.1251286666666667,
 2: 1.0965896666666668,
 3: 1.0790056666666665,
 4: 1.0673276666666667,
 5: 1.0598423333333333,
 6: 1.054681,
 7: 1.051691,
 8: 1.0499313333333333,
 9: 1.0486379999999997,
 10: 1.0478406666666669,
 11: 1.0473876666666666,
 12: 1.047119666666667,
 13: 1.0469823333333332,
 14: 1.046796,
 15: 1.0466893333333334,
 16: 1.0466866666666668,
 17: 1.0464719999999998,
 18: 1.046333,
 19: 1.0463496666666667,
 20: 1.0464393333333335,
 21: 1.0463366666666667,
 22: 1.0461793333333331,
 23: 1.0459903333333334,
 24: 1.0460026666666666,
 25: 1.045864,
 26: 1.0457053333333333,
 27: 1.0456176666666668,
 28: 1.0455686666666668,
 29: 1.0455173333333334,
 30: 1.045447,
 31: 1.045417,
 32: 1.0452823333333334,
 33: 1.0451773333333332,
 34: 1.0452476666666668,
 35: 1.045169,
 36: 1.0450906666666666,
 37: 1.04509,
 38: 1.0450766666666667,
 39: 1.0449873333333333,
 40: 1.0449353333333333,
 41: 1.0448973333333333,
 42: 1.0448450000000002,
 43: 1.0448033333333333,
 44: 1.04478