In [1]:
## XGBoost练习2：继续训练（training continuation）与早停（early stopping）

In [2]:
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
from sklearn.datasets import load_breast_cancer

import xgboost

In [3]:
def training_continuation(model_dir="model", verbose=True):
    """
    Basic training continuation.

    First trains 128 boosting rounds in a single session as a reference. Then
    reaches the same total in two sessions: train 32 rounds, checkpoint the
    model to disk, reload it, and continue boosting for the remaining 96 rounds.

    :param model_dir: directory for the intermediate JSON checkpoint; it is
        created if it does not exist yet.
    :param verbose: forwarded to ``XGBClassifier.fit``; pass False to suppress
        the per-round evaluation log.
    :return: None (results are printed and checked with assertions).
    """
    total_rounds = 128
    first_session_rounds = 32

    X, y = load_breast_cancer(return_X_y=True)

    # Reference: train all rounds in one continuous session.
    model = xgboost.XGBClassifier(n_estimators=total_rounds, eval_metric="logloss")
    model.fit(X=X, y=y, eval_set=[(X, y)], verbose=verbose)
    print("Total boosted rounds: ", model.get_booster().num_boosted_rounds())

    # Session 1: train only the first part of the rounds.
    model_2 = xgboost.XGBClassifier(n_estimators=first_session_rounds, eval_metric="logloss")
    model_2.fit(X=X, y=y, eval_set=[(X, y)], verbose=verbose)
    assert model_2.get_booster().num_boosted_rounds() == first_session_rounds

    # Checkpoint and reload. Create the directory first so a fresh checkout
    # (no pre-existing model/ folder) does not fail when saving.
    os.makedirs(model_dir, exist_ok=True)
    model_path = os.path.join(model_dir, "model.json")
    model_2.save_model(model_path)
    loaded_model = xgboost.XGBClassifier()
    loaded_model.load_model(model_path)

    # Session 2: keep boosting on top of the loaded model for the remaining rounds.
    model_3 = xgboost.XGBClassifier(
        n_estimators=total_rounds - first_session_rounds, eval_metric="logloss"
    )
    model_3.fit(X=X, y=y, eval_set=[(X, y)], xgb_model=loaded_model, verbose=verbose)
    print("Total boosted rounds:", model_3.get_booster().num_boosted_rounds())
    # The new rounds are appended to the loaded ones: 32 + 96 == 128.
    assert model_3.get_booster().num_boosted_rounds() == total_rounds


In [4]:
def training_continuation_early_stop(model_dir="model", verbose=True):
    """
    Training continuation combined with early stopping.

    First trains up to 512 rounds in one session with an EarlyStopping callback
    and records ``best_iteration``. Then trains 128 rounds, checkpoints the
    model to disk, reloads it, and continues (up to 512 rounds in total) with a
    new EarlyStopping callback. Finally it asserts that the continued run finds
    the same ``best_iteration`` as the single-session run.

    :param model_dir: directory for the intermediate JSON checkpoint; it is
        created if it does not exist yet.
    :param verbose: forwarded to ``XGBClassifier.fit``; pass False to suppress
        the per-round evaluation log.
    :return: None (results are printed and checked with assertions).
    """
    # Early-stopping patience and round budgets.
    early_stopping_rounds = 5
    max_rounds = 512
    first_session_rounds = 128

    X, y = load_breast_cancer(return_X_y=True)

    # Reference: a single session that may stop early.
    early_stop = xgboost.callback.EarlyStopping(rounds=early_stopping_rounds, save_best=True)
    model = xgboost.XGBClassifier(
        n_estimators=max_rounds, eval_metric="logloss", callbacks=[early_stop]
    )
    model.fit(X, y, eval_set=[(X, y)], verbose=verbose)
    print("Total boosted rounds:", model.get_booster().num_boosted_rounds())
    # NOTE(review): the author's run reported stopping after 196 rounds with
    # best_iteration == 195; the exact values may differ between XGBoost versions.
    best = model.best_iteration

    # Session 1: train the first chunk of rounds. Each session gets a fresh
    # EarlyStopping callback so no state carries over from a previous fit.
    early_stop_2 = xgboost.callback.EarlyStopping(rounds=early_stopping_rounds, save_best=True)
    model_2 = xgboost.XGBClassifier(
        n_estimators=first_session_rounds, eval_metric="logloss", callbacks=[early_stop_2]
    )
    model_2.fit(X, y, eval_set=[(X, y)], verbose=verbose)
    # Assumes early stopping does not trigger within the first session.
    assert model_2.get_booster().num_boosted_rounds() == first_session_rounds

    # Checkpoint and reload. Create the directory first so a fresh checkout
    # (no pre-existing model/ folder) does not fail when saving.
    os.makedirs(model_dir, exist_ok=True)
    model_path = os.path.join(model_dir, "model.json")
    model_2.save_model(model_path)
    loaded = xgboost.XGBClassifier()
    loaded.load_model(model_path)

    # Session 2: continue from the checkpoint with the remaining round budget.
    early_stop_3 = xgboost.callback.EarlyStopping(rounds=early_stopping_rounds, save_best=True)
    model_3 = xgboost.XGBClassifier(
        n_estimators=max_rounds - first_session_rounds,
        eval_metric="logloss",
        callbacks=[early_stop_3],
    )
    model_3.fit(X, y, eval_set=[(X, y)], xgb_model=loaded, verbose=verbose)

    print("Total boosted rounds:", model_3.get_booster().num_boosted_rounds())
    # The two-session run must find the same optimum as the single session.
    assert model_3.best_iteration == best

In [5]:
if __name__ == "__main__":
    # Run both continuation demos in order; each prints its boosted-round totals.
    for demo in (training_continuation, training_continuation_early_stop):
        demo()


[0]	validation_0-logloss:0.46043
[1]	validation_0-logloss:0.32756
[2]	validation_0-logloss:0.24233
[3]	validation_0-logloss:0.18487
[4]	validation_0-logloss:0.14270
[5]	validation_0-logloss:0.11199
[6]	validation_0-logloss:0.08949
[7]	validation_0-logloss:0.07410
[8]	validation_0-logloss:0.06163
[9]	validation_0-logloss:0.05216
[10]	validation_0-logloss:0.04463
[11]	validation_0-logloss:0.03818
[12]	validation_0-logloss:0.03270
[13]	validation_0-logloss:0.02887
[14]	validation_0-logloss:0.02565
[15]	validation_0-logloss:0.02322
[16]	validation_0-logloss:0.02088
[17]	validation_0-logloss:0.01884
[18]	validation_0-logloss:0.01717
[19]	validation_0-logloss:0.01599
[20]	validation_0-logloss:0.01496
[21]	validation_0-logloss:0.01401
[22]	validation_0-logloss:0.01327
[23]	validation_0-logloss:0.01243
[24]	validation_0-logloss:0.01178
[25]	validation_0-logloss:0.01122
[26]	validation_0-logloss:0.01072
[27]	validation_0-logloss:0.01025
[28]	validation_0-logloss:0.00994
[29]	validation_0-loglos



[31]	validation_0-logloss:0.00385
[32]	validation_0-logloss:0.00385
[33]	validation_0-logloss:0.00384
[34]	validation_0-logloss:0.00383
[35]	validation_0-logloss:0.00383
[36]	validation_0-logloss:0.00382
[37]	validation_0-logloss:0.00381
[38]	validation_0-logloss:0.00381
[39]	validation_0-logloss:0.00380
[40]	validation_0-logloss:0.00379
[41]	validation_0-logloss:0.00379
[42]	validation_0-logloss:0.00378
[43]	validation_0-logloss:0.00378
[44]	validation_0-logloss:0.00377
[45]	validation_0-logloss:0.00377
[46]	validation_0-logloss:0.00376
[47]	validation_0-logloss:0.00376
[48]	validation_0-logloss:0.00375
[49]	validation_0-logloss:0.00375
[50]	validation_0-logloss:0.00374
[51]	validation_0-logloss:0.00374
[52]	validation_0-logloss:0.00373
[53]	validation_0-logloss:0.00373
[54]	validation_0-logloss:0.00372
[55]	validation_0-logloss:0.00372
[56]	validation_0-logloss:0.00372
[57]	validation_0-logloss:0.00371
[58]	validation_0-logloss:0.00371
[59]	validation_0-logloss:0.00371
[60]	validatio