In [9]:
"""
Optuna example that optimizes a classifier configuration for cancer dataset using
Catboost.

In this example, we optimize the validation accuracy of cancer detection using
Catboost. We optimize both the choice of booster model and their hyperparameters.

"""

import numpy as np
import optuna

import catboost as cb
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split

In [16]:
def objective(trial):
    data, target = load_breast_cancer(return_X_y=True)
    train_x, valid_x, train_y, valid_y = train_test_split(data, target, test_size=0.3)

    param = {
        "objective": trial.suggest_categorical("objective", ["Logloss", "CrossEntropy"]),
        "colsample_bylevel": trial.suggest_float("colsample_bylevel", 0.01, 0.1),
        "depth": trial.suggest_int("depth", 1, 12),
        "iterations": trial.suggest_int("iterations", 100, 1000),
        "boosting_type": trial.suggest_categorical("boosting_type", ["Ordered", "Plain"]),
        "bootstrap_type": trial.suggest_categorical(
        "bootstrap_type", ["Bayesian", "Bernoulli", "MVS"]
        ),
        "used_ram_limit": "3gb",
    }

    if param["bootstrap_type"] == "Bayesian":
        param["bagging_temperature"] = trial.suggest_float("bagging_temperature", 0, 10)
    elif param["bootstrap_type"] == "Bernoulli":
        param["subsample"] = trial.suggest_float("subsample", 0.1, 1)

    gbm = cb.CatBoostClassifier(**param)

    gbm.fit(train_x, train_y, eval_set=[(valid_x, valid_y)], verbose=0, early_stopping_rounds=100)

    preds = gbm.predict(valid_x)
    probas = gbm.predict_proba(valid_x)[:,1]
    pred_labels = np.rint(preds)
    accuracy = accuracy_score(valid_y, pred_labels)
    roc_score = roc_auc_score(valid_y, probas)
    return accuracy

In [17]:
if __name__ == "__main__":
    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=100, timeout=600)

    print("Number of finished trials: {}".format(len(study.trials)))

    print("Best trial:")
    trial = study.best_trial

    print("  Value: {}".format(trial.value))

    print("  Params: ")
    for key, value in trial.params.items():
        print("    {}: {}".format(key, value))

[I 2023-07-11 23:24:07,306] A new study created in memory with name: no-name-f1b1a4a2-b76d-48d4-b419-bd42d9c231de
[I 2023-07-11 23:24:07,648] Trial 0 finished with value: 0.9824561403508771 and parameters: {'objective': 'Logloss', 'colsample_bylevel': 0.05308735331373208, 'depth': 5, 'iterations': 833, 'boosting_type': 'Plain', 'bootstrap_type': 'MVS'}. Best is trial 0 with value: 0.9824561403508771.
[I 2023-07-11 23:24:08,828] Trial 1 finished with value: 0.9766081871345029 and parameters: {'objective': 'Logloss', 'colsample_bylevel': 0.07234153489450942, 'depth': 5, 'iterations': 613, 'boosting_type': 'Ordered', 'bootstrap_type': 'Bayesian', 'bagging_temperature': 5.314246967107663}. Best is trial 0 with value: 0.9824561403508771.
[I 2023-07-11 23:24:12,761] Trial 2 finished with value: 0.9766081871345029 and parameters: {'objective': 'CrossEntropy', 'colsample_bylevel': 0.07469047013289508, 'depth': 7, 'iterations': 896, 'boosting_type': 'Ordered', 'bootstrap_type': 'Bayesian', 'bag

[I 2023-07-11 23:24:50,657] Trial 27 finished with value: 0.9766081871345029 and parameters: {'objective': 'CrossEntropy', 'colsample_bylevel': 0.039118274921630655, 'depth': 8, 'iterations': 665, 'boosting_type': 'Plain', 'bootstrap_type': 'MVS'}. Best is trial 26 with value: 0.9941520467836257.
[I 2023-07-11 23:24:57,545] Trial 28 finished with value: 0.9532163742690059 and parameters: {'objective': 'CrossEntropy', 'colsample_bylevel': 0.07584575218287912, 'depth': 9, 'iterations': 565, 'boosting_type': 'Ordered', 'bootstrap_type': 'MVS'}. Best is trial 26 with value: 0.9941520467836257.
[I 2023-07-11 23:24:59,258] Trial 29 finished with value: 0.9649122807017544 and parameters: {'objective': 'CrossEntropy', 'colsample_bylevel': 0.07146185864168826, 'depth': 5, 'iterations': 789, 'boosting_type': 'Ordered', 'bootstrap_type': 'Bernoulli', 'subsample': 0.11424938838413579}. Best is trial 26 with value: 0.9941520467836257.
[I 2023-07-11 23:24:59,574] Trial 30 finished with value: 0.9473

[I 2023-07-11 23:25:15,672] Trial 54 finished with value: 0.9473684210526315 and parameters: {'objective': 'Logloss', 'colsample_bylevel': 0.06258420148527107, 'depth': 5, 'iterations': 658, 'boosting_type': 'Plain', 'bootstrap_type': 'MVS'}. Best is trial 26 with value: 0.9941520467836257.
[I 2023-07-11 23:25:15,836] Trial 55 finished with value: 0.9766081871345029 and parameters: {'objective': 'Logloss', 'colsample_bylevel': 0.07166573412897898, 'depth': 5, 'iterations': 292, 'boosting_type': 'Plain', 'bootstrap_type': 'MVS'}. Best is trial 26 with value: 0.9941520467836257.
[I 2023-07-11 23:25:16,350] Trial 56 finished with value: 0.9707602339181286 and parameters: {'objective': 'Logloss', 'colsample_bylevel': 0.06393391438226982, 'depth': 7, 'iterations': 820, 'boosting_type': 'Plain', 'bootstrap_type': 'Bernoulli', 'subsample': 0.44864520326142865}. Best is trial 26 with value: 0.9941520467836257.
[I 2023-07-11 23:25:16,745] Trial 57 finished with value: 0.9649122807017544 and par

[I 2023-07-11 23:25:31,923] Trial 82 finished with value: 0.9883040935672515 and parameters: {'objective': 'Logloss', 'colsample_bylevel': 0.03850878628667513, 'depth': 11, 'iterations': 479, 'boosting_type': 'Ordered', 'bootstrap_type': 'MVS'}. Best is trial 26 with value: 0.9941520467836257.
[I 2023-07-11 23:25:33,701] Trial 83 finished with value: 0.9707602339181286 and parameters: {'objective': 'Logloss', 'colsample_bylevel': 0.03808845308313217, 'depth': 11, 'iterations': 473, 'boosting_type': 'Ordered', 'bootstrap_type': 'MVS'}. Best is trial 26 with value: 0.9941520467836257.
[I 2023-07-11 23:25:37,330] Trial 84 finished with value: 0.9590643274853801 and parameters: {'objective': 'Logloss', 'colsample_bylevel': 0.04984269770988129, 'depth': 11, 'iterations': 421, 'boosting_type': 'Ordered', 'bootstrap_type': 'MVS'}. Best is trial 26 with value: 0.9941520467836257.
[I 2023-07-11 23:25:40,458] Trial 85 finished with value: 0.9649122807017544 and parameters: {'objective': 'Logloss

Number of finished trials: 100
Best trial:
  Value: 1.0
  Params: 
    objective: Logloss
    colsample_bylevel: 0.04195581634874234
    depth: 12
    iterations: 430
    boosting_type: Ordered
    bootstrap_type: MVS


In [13]:
params={"loss_function":"CrossEntropy", 
        "colsample_bylevel":.06, 
        "depth":7, 
        "boosting_type":"Ordered", 
        "bootstrap_type":"Bayesian",
        "bagging_temperature":7,
        "custom_metric":"AUC", 
        "iterations":1000
       }
model=cb.CatBoostClassifier(**params, )

In [14]:
data_, target_ = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(data_, target_, test_size=0.3, random_state=42)

In [15]:
model.fit(X_train, y_train, eval_set=(X_test, y_test), plot=True)

MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))

0:	learn: 0.6574026	test: 0.6544608	best: 0.6544608 (0)	total: 1.56ms	remaining: 1.55s
1:	learn: 0.6359846	test: 0.6307700	best: 0.6307700 (1)	total: 4.09ms	remaining: 2.04s
2:	learn: 0.6258802	test: 0.6201740	best: 0.6201740 (2)	total: 5.31ms	remaining: 1.76s
3:	learn: 0.6214590	test: 0.6152132	best: 0.6152132 (3)	total: 6.69ms	remaining: 1.67s
4:	learn: 0.6114832	test: 0.6039461	best: 0.6039461 (4)	total: 7.98ms	remaining: 1.59s
5:	learn: 0.5716921	test: 0.5658409	best: 0.5658409 (5)	total: 15ms	remaining: 2.48s
6:	learn: 0.5688864	test: 0.5635638	best: 0.5635638 (6)	total: 16.1ms	remaining: 2.29s
7:	learn: 0.5432266	test: 0.5392821	best: 0.5392821 (7)	total: 17.7ms	remaining: 2.19s
8:	learn: 0.5126781	test: 0.5095718	best: 0.5095718 (8)	total: 25.8ms	remaining: 2.84s
9:	learn: 0.4852392	test: 0.4815779	best: 0.4815779 (9)	total: 36.3ms	remaining: 3.59s
10:	learn: 0.4679319	test: 0.4656052	best: 0.4656052 (10)	total: 41ms	remaining: 3.69s
11:	learn: 0.4573780	test: 0.4537265	best: 0.

134:	learn: 0.1003997	test: 0.0996323	best: 0.0996323 (134)	total: 395ms	remaining: 2.53s
135:	learn: 0.1003834	test: 0.0995592	best: 0.0995592 (135)	total: 400ms	remaining: 2.54s
136:	learn: 0.1003831	test: 0.0995605	best: 0.0995592 (135)	total: 401ms	remaining: 2.53s
137:	learn: 0.1002697	test: 0.0995567	best: 0.0995567 (137)	total: 402ms	remaining: 2.51s
138:	learn: 0.1002697	test: 0.0995567	best: 0.0995567 (137)	total: 403ms	remaining: 2.5s
139:	learn: 0.0985921	test: 0.0989280	best: 0.0989280 (139)	total: 409ms	remaining: 2.51s
140:	learn: 0.0983373	test: 0.0985739	best: 0.0985739 (140)	total: 410ms	remaining: 2.5s
141:	learn: 0.0983373	test: 0.0985739	best: 0.0985739 (140)	total: 411ms	remaining: 2.48s
142:	learn: 0.0974930	test: 0.0981406	best: 0.0981406 (142)	total: 424ms	remaining: 2.54s
143:	learn: 0.0971644	test: 0.0979719	best: 0.0979719 (143)	total: 425ms	remaining: 2.52s
144:	learn: 0.0967031	test: 0.0974226	best: 0.0974226 (144)	total: 427ms	remaining: 2.52s
145:	learn: 

271:	learn: 0.0601068	test: 0.0742251	best: 0.0742251 (271)	total: 795ms	remaining: 2.13s
272:	learn: 0.0598147	test: 0.0741375	best: 0.0741375 (272)	total: 802ms	remaining: 2.13s
273:	learn: 0.0598147	test: 0.0741376	best: 0.0741375 (272)	total: 803ms	remaining: 2.13s
274:	learn: 0.0591990	test: 0.0738801	best: 0.0738801 (274)	total: 812ms	remaining: 2.14s
275:	learn: 0.0589647	test: 0.0739602	best: 0.0738801 (274)	total: 818ms	remaining: 2.15s
276:	learn: 0.0589647	test: 0.0739602	best: 0.0738801 (274)	total: 819ms	remaining: 2.14s
277:	learn: 0.0588338	test: 0.0738059	best: 0.0738059 (277)	total: 825ms	remaining: 2.14s
278:	learn: 0.0587657	test: 0.0738488	best: 0.0738059 (277)	total: 830ms	remaining: 2.15s
279:	learn: 0.0586271	test: 0.0739664	best: 0.0738059 (277)	total: 831ms	remaining: 2.14s
280:	learn: 0.0585345	test: 0.0740789	best: 0.0738059 (277)	total: 832ms	remaining: 2.13s
281:	learn: 0.0579027	test: 0.0739083	best: 0.0738059 (277)	total: 839ms	remaining: 2.13s
282:	learn

398:	learn: 0.0424857	test: 0.0680325	best: 0.0680325 (398)	total: 1.2s	remaining: 1.8s
399:	learn: 0.0424857	test: 0.0680325	best: 0.0680325 (398)	total: 1.2s	remaining: 1.8s
400:	learn: 0.0420755	test: 0.0680552	best: 0.0680325 (398)	total: 1.2s	remaining: 1.79s
401:	learn: 0.0420755	test: 0.0680553	best: 0.0680325 (398)	total: 1.2s	remaining: 1.79s
402:	learn: 0.0420755	test: 0.0680553	best: 0.0680325 (398)	total: 1.2s	remaining: 1.78s
403:	learn: 0.0418059	test: 0.0676769	best: 0.0676769 (403)	total: 1.2s	remaining: 1.78s
404:	learn: 0.0417925	test: 0.0676614	best: 0.0676614 (404)	total: 1.21s	remaining: 1.77s
405:	learn: 0.0417925	test: 0.0676614	best: 0.0676614 (404)	total: 1.21s	remaining: 1.76s
406:	learn: 0.0413888	test: 0.0676323	best: 0.0676323 (406)	total: 1.21s	remaining: 1.77s
407:	learn: 0.0412031	test: 0.0674776	best: 0.0674776 (407)	total: 1.22s	remaining: 1.77s
408:	learn: 0.0412030	test: 0.0674778	best: 0.0674776 (407)	total: 1.22s	remaining: 1.76s
409:	learn: 0.0410

512:	learn: 0.0316123	test: 0.0622358	best: 0.0621246 (510)	total: 1.59s	remaining: 1.51s
513:	learn: 0.0314406	test: 0.0621042	best: 0.0621042 (513)	total: 1.6s	remaining: 1.51s
514:	learn: 0.0313285	test: 0.0619879	best: 0.0619879 (514)	total: 1.6s	remaining: 1.51s
515:	learn: 0.0312282	test: 0.0617608	best: 0.0617608 (515)	total: 1.61s	remaining: 1.51s
516:	learn: 0.0311855	test: 0.0617312	best: 0.0617312 (516)	total: 1.61s	remaining: 1.5s
517:	learn: 0.0310547	test: 0.0616649	best: 0.0616649 (517)	total: 1.61s	remaining: 1.5s
518:	learn: 0.0310547	test: 0.0616648	best: 0.0616648 (518)	total: 1.61s	remaining: 1.49s
519:	learn: 0.0310342	test: 0.0616426	best: 0.0616426 (519)	total: 1.61s	remaining: 1.49s
520:	learn: 0.0309873	test: 0.0615797	best: 0.0615797 (520)	total: 1.62s	remaining: 1.49s
521:	learn: 0.0307465	test: 0.0614299	best: 0.0614299 (521)	total: 1.63s	remaining: 1.49s
522:	learn: 0.0307368	test: 0.0614245	best: 0.0614245 (522)	total: 1.63s	remaining: 1.49s
523:	learn: 0.

639:	learn: 0.0239131	test: 0.0583596	best: 0.0583586 (638)	total: 1.99s	remaining: 1.12s
640:	learn: 0.0238269	test: 0.0582077	best: 0.0582077 (640)	total: 1.99s	remaining: 1.11s
641:	learn: 0.0237824	test: 0.0581851	best: 0.0581851 (641)	total: 2s	remaining: 1.11s
642:	learn: 0.0237824	test: 0.0581850	best: 0.0581850 (642)	total: 2s	remaining: 1.11s
643:	learn: 0.0237233	test: 0.0579487	best: 0.0579487 (643)	total: 2s	remaining: 1.1s
644:	learn: 0.0235933	test: 0.0577431	best: 0.0577431 (644)	total: 2s	remaining: 1.1s
645:	learn: 0.0235933	test: 0.0577430	best: 0.0577430 (645)	total: 2s	remaining: 1.1s
646:	learn: 0.0234679	test: 0.0576534	best: 0.0576534 (646)	total: 2.01s	remaining: 1.09s
647:	learn: 0.0234555	test: 0.0576279	best: 0.0576279 (647)	total: 2.01s	remaining: 1.09s
648:	learn: 0.0234502	test: 0.0576332	best: 0.0576279 (647)	total: 2.01s	remaining: 1.09s
649:	learn: 0.0233188	test: 0.0576725	best: 0.0576279 (647)	total: 2.02s	remaining: 1.09s
650:	learn: 0.0231730	test: 

775:	learn: 0.0181889	test: 0.0550153	best: 0.0548458 (770)	total: 2.39s	remaining: 691ms
776:	learn: 0.0181822	test: 0.0550084	best: 0.0548458 (770)	total: 2.4s	remaining: 689ms
777:	learn: 0.0181325	test: 0.0549964	best: 0.0548458 (770)	total: 2.4s	remaining: 686ms
778:	learn: 0.0181325	test: 0.0549964	best: 0.0548458 (770)	total: 2.4s	remaining: 682ms
779:	learn: 0.0181325	test: 0.0549964	best: 0.0548458 (770)	total: 2.4s	remaining: 678ms
780:	learn: 0.0181067	test: 0.0549483	best: 0.0548458 (770)	total: 2.41s	remaining: 677ms
781:	learn: 0.0181066	test: 0.0549483	best: 0.0548458 (770)	total: 2.42s	remaining: 673ms
782:	learn: 0.0180890	test: 0.0549082	best: 0.0548458 (770)	total: 2.42s	remaining: 670ms
783:	learn: 0.0180603	test: 0.0548857	best: 0.0548458 (770)	total: 2.42s	remaining: 666ms
784:	learn: 0.0179815	test: 0.0547606	best: 0.0547606 (784)	total: 2.42s	remaining: 664ms
785:	learn: 0.0178938	test: 0.0546330	best: 0.0546330 (785)	total: 2.43s	remaining: 662ms
786:	learn: 0.

903:	learn: 0.0145647	test: 0.0520390	best: 0.0520390 (903)	total: 2.79s	remaining: 297ms
904:	learn: 0.0145647	test: 0.0520390	best: 0.0520390 (904)	total: 2.8s	remaining: 294ms
905:	learn: 0.0145647	test: 0.0520390	best: 0.0520390 (905)	total: 2.8s	remaining: 290ms
906:	learn: 0.0145582	test: 0.0519618	best: 0.0519618 (906)	total: 2.8s	remaining: 287ms
907:	learn: 0.0144822	test: 0.0518725	best: 0.0518725 (907)	total: 2.81s	remaining: 284ms
908:	learn: 0.0144779	test: 0.0518747	best: 0.0518725 (907)	total: 2.81s	remaining: 281ms
909:	learn: 0.0144709	test: 0.0518389	best: 0.0518389 (909)	total: 2.82s	remaining: 279ms
910:	learn: 0.0144611	test: 0.0518422	best: 0.0518389 (909)	total: 2.82s	remaining: 276ms
911:	learn: 0.0144036	test: 0.0518573	best: 0.0518389 (909)	total: 2.82s	remaining: 272ms
912:	learn: 0.0143846	test: 0.0518839	best: 0.0518389 (909)	total: 2.82s	remaining: 269ms
913:	learn: 0.0143846	test: 0.0518839	best: 0.0518389 (909)	total: 2.83s	remaining: 266ms
914:	learn: 0

<catboost.core.CatBoostClassifier at 0x7f847e0ee160>