# Setup

In [1]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split, cross_val_score
from catboost import CatBoostClassifier
import optuna
from optuna.integration import CatBoostPruningCallback

In [7]:
import matplotlib
%matplotlib inline

# Optuna

In [2]:
def load_data():
    data = pd.read_csv(
        'https://raw.githubusercontent.com/antbartash/australian_rain/main/data/data_transformed.csv',
        index_col=0
    )
    X, y = data.drop(columns=['RainTomorrow', 'RainToday']), data['RainTomorrow']
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    for column in ['Location', 'WindGustDir', 'WindDir9am', 'WindDir3pm']:
        X_train[column] = X_train[column].astype(np.float32).fillna(-1).apply(lambda x: str(x))
        X_test[column] = X_test[column].astype(np.float32).fillna(-1).apply(lambda x: str(x))
    return X_train, X_test, y_train, y_test

In [3]:
def objective(trial):
    PARAMS = {
        'n_estimators': trial.suggest_int('n_estimators', 50, 1000),
        'learning_rate': trial.suggest_float('learning_rate', 1e-6, 0.5),
        'depth': trial.suggest_int('depth', 1, 10),
        'l2_leaf_reg': trial.suggest_float('l2_leaf_reg', 0.0, 100.0),
        'random_strength': trial.suggest_float('random_strength', 0.0, 100.0), # CPU only
        'bagging_temperature': trial.suggest_float('bagging_temperature', 0.0, 100.0),
        'grow_policy': trial.suggest_categorical('grow_policy', ['SymmetricTree', 'Depthwise']),
        'scale_pos_weight': trial.suggest_float('scale_pos_weight', 1, 4)
    }
    X_train, X_valid, y_train, y_valid = load_data()
    pruning_callback = CatBoostPruningCallback(trial, 'AUC')
    model = CatBoostClassifier(
        cat_features=['Location', 'WindGustDir', 'WindDir9am', 'WindDir3pm'],
        custom_metric='AUC',
        random_state=42, verbose=False, task_type='CPU'
    )
    model.set_params(**PARAMS)
    model.fit(X_train, y_train, 
              eval_set=(X_valid, y_valid),
              callbacks=[pruning_callback]
    )
    score = np.mean(cross_val_score(model, X_train, y_train, cv=3, scoring='roc_auc'))
    return score

In [4]:
sampler = optuna.samplers.TPESampler(
    n_startup_trials = 5, # the random sampling is used instead of the TPE algorithm until the given number of trials finish in the same study
    n_ei_candidates = 24, # number of candidate samples used to calculate the expected improvement.
    multivariate = True, # multivariate TPE when suggesting candidates; default: False
    seed = 42
)
study = optuna.create_study(
    pruner=optuna.pruners.MedianPruner(n_startup_trials=10, n_warmup_steps=100, interval_steps=20),
    direction='maximize',
    sampler=sampler,
    storage='sqlite:///db.sqlite3'
)
study.optimize(
    objective, n_trials = 100,
    timeout = 3600, # in seconds
    n_jobs = 1,
    show_progress_bar = True
)

[I 2023-12-29 09:53:04,726] A new study created in RDB with name: no-name-568d9569-4f33-4c33-a3b3-53cde451dcff


  0%|          | 0/100 [00:00<?, ?it/s]

  pruning_callback = CatBoostPruningCallback(trial, 'AUC')


[I 2023-12-29 09:55:08,071] Trial 0 finished with value: 0.8797334698575231 and parameters: {'n_estimators': 406, 'learning_rate': 0.47535720249065166, 'depth': 8, 'l2_leaf_reg': 59.86584841970366, 'random_strength': 15.601864044243651, 'bagging_temperature': 15.599452033620265, 'grow_policy': 'Depthwise', 'scale_pos_weight': 2.8033450352296265}. Best is trial 0 with value: 0.8797334698575231.


  pruning_callback = CatBoostPruningCallback(trial, 'AUC')


[I 2023-12-29 09:59:48,810] Trial 1 finished with value: 0.8704546881053975 and parameters: {'n_estimators': 723, 'learning_rate': 0.010293226563406928, 'depth': 10, 'l2_leaf_reg': 83.24426408004217, 'random_strength': 21.233911067827616, 'bagging_temperature': 18.182496720710063, 'grow_policy': 'Depthwise', 'scale_pos_weight': 2.5742692948967134}. Best is trial 0 with value: 0.8797334698575231.


  pruning_callback = CatBoostPruningCallback(trial, 'AUC')


[I 2023-12-29 10:01:44,511] Trial 2 finished with value: 0.8935199672021055 and parameters: {'n_estimators': 460, 'learning_rate': 0.14561527886988077, 'depth': 7, 'l2_leaf_reg': 13.949386065204184, 'random_strength': 29.214464853521815, 'bagging_temperature': 36.63618432936917, 'grow_policy': 'Depthwise', 'scale_pos_weight': 1.5990213464750793}. Best is trial 2 with value: 0.8935199672021055.


  pruning_callback = CatBoostPruningCallback(trial, 'AUC')


[I 2023-12-29 10:02:45,677] Trial 3 finished with value: 0.8712304584991265 and parameters: {'n_estimators': 539, 'learning_rate': 0.29620769201645236, 'depth': 1, 'l2_leaf_reg': 60.75448519014384, 'random_strength': 17.052412368729154, 'bagging_temperature': 6.505159298527952, 'grow_policy': 'Depthwise', 'scale_pos_weight': 3.4251920443493833}. Best is trial 2 with value: 0.8935199672021055.


  pruning_callback = CatBoostPruningCallback(trial, 'AUC')


[I 2023-12-29 10:04:09,776] Trial 4 finished with value: 0.8880178942391005 and parameters: {'n_estimators': 339, 'learning_rate': 0.048836959331077935, 'depth': 7, 'l2_leaf_reg': 44.01524937396013, 'random_strength': 12.203823484477883, 'bagging_temperature': 49.51769101112702, 'grow_policy': 'Depthwise', 'scale_pos_weight': 1.7763399448000508}. Best is trial 2 with value: 0.8935199672021055.


  pruning_callback = CatBoostPruningCallback(trial, 'AUC')


[I 2023-12-29 10:06:19,355] Trial 5 finished with value: 0.8930161462570849 and parameters: {'n_estimators': 675, 'learning_rate': 0.12827792148317357, 'depth': 5, 'l2_leaf_reg': 4.499247574509964, 'random_strength': 9.821377188217419, 'bagging_temperature': 10.610448900710598, 'grow_policy': 'Depthwise', 'scale_pos_weight': 1.01966941141605}. Best is trial 2 with value: 0.8935199672021055.


  pruning_callback = CatBoostPruningCallback(trial, 'AUC')


[I 2023-12-29 10:11:51,715] Trial 6 finished with value: 0.890796254508866 and parameters: {'n_estimators': 824, 'learning_rate': 0.1601745181551702, 'depth': 9, 'l2_leaf_reg': 14.13974149941982, 'random_strength': 30.677102273465298, 'bagging_temperature': 89.45995238443746, 'grow_policy': 'Depthwise', 'scale_pos_weight': 1.8426083349112128}. Best is trial 2 with value: 0.8935199672021055.


  pruning_callback = CatBoostPruningCallback(trial, 'AUC')


[I 2023-12-29 10:13:33,630] Trial 7 finished with value: 0.8943794431860274 and parameters: {'n_estimators': 285, 'learning_rate': 0.11045440456696545, 'depth': 9, 'l2_leaf_reg': 25.506477011446293, 'random_strength': 68.37798681868179, 'bagging_temperature': 37.960424342022336, 'grow_policy': 'Depthwise', 'scale_pos_weight': 2.592487076994037}. Best is trial 7 with value: 0.8943794431860274.


  pruning_callback = CatBoostPruningCallback(trial, 'AUC')


[I 2023-12-29 10:14:35,473] Trial 8 finished with value: 0.8888958337807894 and parameters: {'n_estimators': 140, 'learning_rate': 0.15011135293329173, 'depth': 10, 'l2_leaf_reg': 2.7855516382359227, 'random_strength': 91.89539888709731, 'bagging_temperature': 42.521483697754455, 'grow_policy': 'Depthwise', 'scale_pos_weight': 3.1940141809393214}. Best is trial 7 with value: 0.8943794431860274.


  pruning_callback = CatBoostPruningCallback(trial, 'AUC')


[I 2023-12-29 10:19:50,931] Trial 9 finished with value: 0.8885005549041874 and parameters: {'n_estimators': 642, 'learning_rate': 0.2466363022911385, 'depth': 10, 'l2_leaf_reg': 40.98339260564337, 'random_strength': 64.00017511019783, 'bagging_temperature': 30.798972982811634, 'grow_policy': 'Depthwise', 'scale_pos_weight': 2.404755750102874}. Best is trial 7 with value: 0.8943794431860274.


  pruning_callback = CatBoostPruningCallback(trial, 'AUC')


[I 2023-12-29 10:21:05,140] Trial 10 finished with value: 0.8541029101619331 and parameters: {'n_estimators': 309, 'learning_rate': 0.008343868327088888, 'depth': 9, 'l2_leaf_reg': 54.090549664426625, 'random_strength': 82.50121028007065, 'bagging_temperature': 34.673883258091195, 'grow_policy': 'Depthwise', 'scale_pos_weight': 2.520092287049116}. Best is trial 7 with value: 0.8943794431860274.


  pruning_callback = CatBoostPruningCallback(trial, 'AUC')


[I 2023-12-29 10:22:24,838] Trial 11 finished with value: 0.8484577504524845 and parameters: {'n_estimators': 522, 'learning_rate': 0.0026838979662154316, 'depth': 5, 'l2_leaf_reg': 3.9503501308004516, 'random_strength': 33.3207567891878, 'bagging_temperature': 55.316576437476186, 'grow_policy': 'SymmetricTree', 'scale_pos_weight': 2.153065231169769}. Best is trial 7 with value: 0.8943794431860274.


  pruning_callback = CatBoostPruningCallback(trial, 'AUC')


[I 2023-12-29 10:24:07,389] Trial 12 finished with value: 0.8646542444750248 and parameters: {'n_estimators': 652, 'learning_rate': 0.018024402262590222, 'depth': 7, 'l2_leaf_reg': 9.625984240500463, 'random_strength': 74.62826899353799, 'bagging_temperature': 22.206797745965492, 'grow_policy': 'Depthwise', 'scale_pos_weight': 1.506548912617304}. Best is trial 7 with value: 0.8943794431860274.


  pruning_callback = CatBoostPruningCallback(trial, 'AUC')


[I 2023-12-29 10:25:15,675] Trial 13 finished with value: 0.8905311973493295 and parameters: {'n_estimators': 361, 'learning_rate': 0.24074759066797208, 'depth': 5, 'l2_leaf_reg': 9.147610225809355, 'random_strength': 58.954716212976805, 'bagging_temperature': 58.601095396488034, 'grow_policy': 'Depthwise', 'scale_pos_weight': 3.1337538714438673}. Best is trial 7 with value: 0.8943794431860274.


  pruning_callback = CatBoostPruningCallback(trial, 'AUC')


[I 2023-12-29 10:27:24,398] Trial 14 finished with value: 0.8978171906452436 and parameters: {'n_estimators': 220, 'learning_rate': 0.15033022742840282, 'depth': 10, 'l2_leaf_reg': 5.032799991220095, 'random_strength': 68.59716312901651, 'bagging_temperature': 36.096982905396814, 'grow_policy': 'SymmetricTree', 'scale_pos_weight': 1.8220212547009424}. Best is trial 14 with value: 0.8978171906452436.


  pruning_callback = CatBoostPruningCallback(trial, 'AUC')


[I 2023-12-29 10:27:51,269] Trial 15 finished with value: 0.8766052728101625 and parameters: {'n_estimators': 97, 'learning_rate': 0.1602752654675233, 'depth': 10, 'l2_leaf_reg': 11.623144703741932, 'random_strength': 50.36648761429636, 'bagging_temperature': 64.99881789347448, 'grow_policy': 'SymmetricTree', 'scale_pos_weight': 2.3525352347620903}. Best is trial 14 with value: 0.8978171906452436.


  pruning_callback = CatBoostPruningCallback(trial, 'AUC')


[I 2023-12-29 10:31:15,005] Trial 16 finished with value: 0.900019763444437 and parameters: {'n_estimators': 421, 'learning_rate': 0.1535779005240276, 'depth': 9, 'l2_leaf_reg': 22.81240481274744, 'random_strength': 82.58415186620924, 'bagging_temperature': 31.55900421424277, 'grow_policy': 'SymmetricTree', 'scale_pos_weight': 2.093530609725283}. Best is trial 16 with value: 0.900019763444437.


  pruning_callback = CatBoostPruningCallback(trial, 'AUC')


[I 2023-12-29 10:33:21,236] Trial 17 finished with value: 0.9010773253239965 and parameters: {'n_estimators': 472, 'learning_rate': 0.13964642869291113, 'depth': 8, 'l2_leaf_reg': 19.78413735015228, 'random_strength': 83.4060283987825, 'bagging_temperature': 37.90671864084201, 'grow_policy': 'SymmetricTree', 'scale_pos_weight': 1.1894703562782643}. Best is trial 17 with value: 0.9010773253239965.


  pruning_callback = CatBoostPruningCallback(trial, 'AUC')


[I 2023-12-29 10:36:26,991] Trial 18 finished with value: 0.9005179239363371 and parameters: {'n_estimators': 608, 'learning_rate': 0.23837725656971387, 'depth': 6, 'l2_leaf_reg': 27.71841825059886, 'random_strength': 91.46734029149297, 'bagging_temperature': 14.521399853988633, 'grow_policy': 'SymmetricTree', 'scale_pos_weight': 1.830526138560691}. Best is trial 17 with value: 0.9010773253239965.


  pruning_callback = CatBoostPruningCallback(trial, 'AUC')


[I 2023-12-29 10:37:14,172] Trial 19 finished with value: 0.8949585506212226 and parameters: {'n_estimators': 372, 'learning_rate': 0.22629769931869645, 'depth': 3, 'l2_leaf_reg': 37.905670149827635, 'random_strength': 94.93884665892139, 'bagging_temperature': 13.729086467051836, 'grow_policy': 'SymmetricTree', 'scale_pos_weight': 1.3788039399339103}. Best is trial 17 with value: 0.9010773253239965.


  pruning_callback = CatBoostPruningCallback(trial, 'AUC')


[I 2023-12-29 10:41:16,061] Trial 20 finished with value: 0.9004429896846875 and parameters: {'n_estimators': 783, 'learning_rate': 0.21967402959625598, 'depth': 6, 'l2_leaf_reg': 25.49892885728805, 'random_strength': 83.29762732176201, 'bagging_temperature': 3.874435138162335, 'grow_policy': 'SymmetricTree', 'scale_pos_weight': 1.748497124874134}. Best is trial 17 with value: 0.9010773253239965.


  pruning_callback = CatBoostPruningCallback(trial, 'AUC')


[I 2023-12-29 10:46:06,010] Trial 21 finished with value: 0.8968643221787637 and parameters: {'n_estimators': 811, 'learning_rate': 0.28006734725736676, 'depth': 7, 'l2_leaf_reg': 23.859851537190742, 'random_strength': 80.16620038341391, 'bagging_temperature': 1.4627951699830355, 'grow_policy': 'SymmetricTree', 'scale_pos_weight': 1.1507877884277962}. Best is trial 17 with value: 0.9010773253239965.


  pruning_callback = CatBoostPruningCallback(trial, 'AUC')


[I 2023-12-29 10:49:19,798] Trial 22 finished with value: 0.8941830538948542 and parameters: {'n_estimators': 618, 'learning_rate': 0.40446875109301417, 'depth': 6, 'l2_leaf_reg': 20.165808276983356, 'random_strength': 96.72916966351906, 'bagging_temperature': 44.49162606757371, 'grow_policy': 'SymmetricTree', 'scale_pos_weight': 2.491023083706337}. Best is trial 17 with value: 0.9010773253239965.


  pruning_callback = CatBoostPruningCallback(trial, 'AUC')


[I 2023-12-29 10:50:56,221] Trial 23 finished with value: 0.90060082050988 and parameters: {'n_estimators': 570, 'learning_rate': 0.16965462869310532, 'depth': 5, 'l2_leaf_reg': 24.93684743624295, 'random_strength': 94.65568845231142, 'bagging_temperature': 1.2598103281196344, 'grow_policy': 'SymmetricTree', 'scale_pos_weight': 2.737266650319726}. Best is trial 17 with value: 0.9010773253239965.


  pruning_callback = CatBoostPruningCallback(trial, 'AUC')


[I 2023-12-29 10:51:59,078] Trial 24 finished with value: 0.8956000362364884 and parameters: {'n_estimators': 541, 'learning_rate': 0.15019789086678176, 'depth': 3, 'l2_leaf_reg': 27.520778004197876, 'random_strength': 93.9588142253623, 'bagging_temperature': 19.271766709381996, 'grow_policy': 'SymmetricTree', 'scale_pos_weight': 3.3852177536046892}. Best is trial 17 with value: 0.9010773253239965.


  pruning_callback = CatBoostPruningCallback(trial, 'AUC')


[I 2023-12-29 10:55:02,222] Trial 25 finished with value: 0.9018302574773327 and parameters: {'n_estimators': 722, 'learning_rate': 0.07139217176610062, 'depth': 8, 'l2_leaf_reg': 28.24988726549349, 'random_strength': 87.72871371560576, 'bagging_temperature': 54.029297770604856, 'grow_policy': 'SymmetricTree', 'scale_pos_weight': 1.1872960291692018}. Best is trial 25 with value: 0.9018302574773327.


In [5]:
print(f"Best Gini: {study.best_trial.value * 2 - 1}")
print(f"Best params: {study.best_trial.params}")

Best Gini: 0.8036605149546654
Best params: {'n_estimators': 722, 'learning_rate': 0.07139217176610062, 'depth': 8, 'l2_leaf_reg': 28.24988726549349, 'random_strength': 87.72871371560576, 'bagging_temperature': 54.029297770604856, 'grow_policy': 'SymmetricTree', 'scale_pos_weight': 1.1872960291692018}


# Plots

In [8]:
optuna.visualization.plot_edf(study)

In [9]:
optuna.visualization.plot_optimization_history(study)

In [10]:
optuna.visualization.plot_contour(study, params=['n_estimators', 'learning_rate'])

In [12]:
optuna.visualization.plot_parallel_coordinate(study)

In [13]:
optuna.visualization.plot_param_importances(study)

In [14]:
optuna.visualization.plot_slice(study)

In [15]:
optuna.visualization.plot_timeline(study)


plot_timeline is experimental (supported from v3.2.0). The interface can change in the future.

