<a href="https://colab.research.google.com/github/JaehyunAhn/AI_for_Education/blob/master/VotingClassifier_ensemble_with_earlystopping_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from sklearn.base import BaseEstimator
from sklearn.model_selection import train_test_split
from xgboost import XGBRegressor, XGBClassifier

class XGBoostWithEarlyStop(BaseEstimator):
    """Wrap an XGBoost scikit-learn estimator so ``fit`` builds its own eval set.

    ``fit`` holds out a ``test_size`` share of the data it receives and hands
    both the remaining training part and the held-out part to the wrapped
    estimator through ``eval_set``. This lets the wrapper be fitted by callers
    that only ever invoke ``fit(X, y)`` (for example a voting ensemble or a
    pipeline) while the wrapped estimator still receives evaluation data.

    Subclasses are expected to assign ``self.estimator`` *before* delegating
    to this initializer.

    Parameters
    ----------
    test_size : float or int, default=0.2
        Passed unchanged to ``train_test_split`` inside ``fit``.
    **estimator_params
        Forwarded to the wrapped estimator's ``set_params``.

    Raises
    ------
    TypeError
        If estimator parameters are supplied but no wrapped estimator exists.
    """

    def __init__(self, test_size=0.2, **estimator_params):
        self.test_size = test_size
        # Subclasses set ``self.estimator`` before calling this initializer.
        # Reading it through getattr keeps a direct instantiation of the base
        # class from failing with an AttributeError on a missing attribute.
        self.estimator = getattr(self, "estimator", None)
        if self.estimator is not None:
            self.set_params(**estimator_params)
        elif estimator_params:
            raise TypeError(
                "estimator parameters were given but no wrapped estimator is set: "
                + ", ".join(sorted(estimator_params))
            )

    def set_params(self, **params):
        """Update wrapper and/or wrapped-estimator parameters.

        ``test_size`` is stored on the wrapper itself; every other key is
        forwarded to the wrapped estimator.

        Returns
        -------
        self
            The wrapper (not the wrapped estimator), so chained calls such as
            ``model.set_params(...).fit(X, y)`` keep using the wrapper's ``fit``.
        """
        if "test_size" in params:
            self.test_size = params.pop("test_size")
        if params:
            self.estimator.set_params(**params)
        return self

    def get_params(self, deep=True):
        """Return the wrapped estimator's parameters plus ``test_size``.

        Including ``test_size`` matters because ``sklearn.base.clone`` rebuilds
        an estimator from ``get_params(deep=False)``; without it a clone would
        silently fall back to the default split size.

        Parameters
        ----------
        deep : bool, default=True
            Forwarded to the wrapped estimator's ``get_params``.
        """
        params = {} if self.estimator is None else self.estimator.get_params(deep=deep)
        params["test_size"] = self.test_size
        return params

    def fit(self, X, y):
        """Split ``X``/``y`` and fit the wrapped estimator with an eval set.

        The split is unseeded (no ``random_state`` is passed), so the held-out
        part differs between calls.

        Returns
        -------
        self
        """
        x_train, x_val, y_train, y_val = train_test_split(
            X, y, test_size=self.test_size
        )
        # The held-out part is listed last in eval_set.
        self.estimator.fit(
            x_train,
            y_train,
            eval_set=[(x_train, y_train), (x_val, y_val)],
        )
        return self

    def predict(self, X):
        """Return the wrapped estimator's predictions for ``X``."""
        return self.estimator.predict(X)

    def predict_proba(self, X):
        """Return the wrapped estimator's ``predict_proba`` output for ``X``."""
        return self.estimator.predict_proba(X)

class XGBoostRegressorWithEarlyStop(XGBoostWithEarlyStop):
    """Early-stopping wrapper whose wrapped model is a default ``XGBRegressor``."""

    def __init__(self, *args, **kwargs):
        # Tag the wrapper as a regressor.
        self._estimator_type = 'regressor'
        # The estimator has to exist before the parent initializer runs,
        # because the parent forwards the keyword arguments to it.
        self.estimator = XGBRegressor()
        super().__init__(*args, **kwargs)

class XGBoostClassifierWithEarlyStop(XGBoostWithEarlyStop):
    """Early-stopping wrapper whose wrapped model is a default ``XGBClassifier``."""

    def __init__(self, *args, **kwargs):
        # Tag the wrapper as a classifier.
        self._estimator_type = 'classifier'
        # The estimator has to exist before the parent initializer runs,
        # because the parent forwards the keyword arguments to it.
        self.estimator = XGBClassifier()
        super().__init__(*args, **kwargs)


In [None]:
from sklearn.pipeline import Pipeline
from sklearn.ensemble import VotingClassifier
from xgboost import XGBClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.linear_model import LogisticRegression


# Load the breast-cancer dataset as pandas objects (features and target)
X, y = load_breast_cancer(return_X_y=True, as_frame=True)

# Split the data into a training set and a held-out test set.
# NOTE(review): test_size=0.8 trains on only 20% of the rows and evaluates on
# the other 80% -- confirm this is intended rather than test_size=0.2.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.8, random_state=42)

# Define your preprocessing steps in the pipeline
scaler = StandardScaler().set_output(transform='pandas')  # You can replace this with any preprocessing steps you need

# `verbose_eval` is intentionally not passed: XGBoost reported it as an unused
# parameter, and the per-round evaluation log was printed without it.
xgb_model = XGBoostClassifierWithEarlyStop(
    objective="binary:logistic",
    eval_metric="mae",
    n_estimators=600,  # Set a large number; early stopping decides the real count
    max_depth=6,
    learning_rate=0.1,
    early_stopping_rounds=3
)

logistic_model = LogisticRegression()

# Create a soft-voting ensemble of the XGBoost model and logistic regression
voting_classifier = VotingClassifier(
    estimators=[
        ('xgb', xgb_model),
        ('lgr', logistic_model)
        ],
    voting='soft'
)

# Chain the scaler and the ensemble so both see identically preprocessed data
pipe = Pipeline(
    [
        ('scaler', scaler),
        ('clf', voting_classifier)
    ]
)

# Fit the whole pipeline (scaler + VotingClassifier) on the training data
pipe.fit(X_train, y_train)

# Use the fitted pipeline to make predictions on the held-out test data
y_pred = pipe.predict(X_test)

Parameters: { "verbose_eval" } are not used.



[0]	validation_0-mae:0.45472	validation_1-mae:0.43781
[1]	validation_0-mae:0.42186	validation_1-mae:0.40401
[2]	validation_0-mae:0.39221	validation_1-mae:0.37021
[3]	validation_0-mae:0.36516	validation_1-mae:0.34255
[4]	validation_0-mae:0.34072	validation_1-mae:0.31463
[5]	validation_0-mae:0.31838	validation_1-mae:0.29199
[6]	validation_0-mae:0.29746	validation_1-mae:0.27150
[7]	validation_0-mae:0.27753	validation_1-mae:0.25391
[8]	validation_0-mae:0.25995	validation_1-mae:0.23554
[9]	validation_0-mae:0.24321	validation_1-mae:0.22091
[10]	validation_0-mae:0.22852	validation_1-mae:0.20634
[11]	validation_0-mae:0.21396	validation_1-mae:0.19614
[12]	validation_0-mae:0.20074	validation_1-mae:0.18707
[13]	validation_0-mae:0.18877	validation_1-mae:0.17464
[14]	validation_0-mae:0.17687	validation_1-mae:0.16548
[15]	validation_0-mae:0.16603	validation_1-mae:0.15730
[16]	validation_0-mae:0.15727	validation_1-mae:0.14827
[17]	validation_0-mae:0.14798	validation_1-mae:0.14130
[18]	validation_0-ma

In [None]:
# xgb_model.set_params(early_stopping_rounds=7)
# xgb_model.fit(X_train, y_train)

In [None]:
# xgb_model.get_params()

In [None]:
from sklearn.metrics import f1_score

# F1 score of the pipeline's predictions against the held-out test labels
f1_score(y_test, y_pred)

0.9879518072289157

In [None]:
# Run only the steps before the final classifier to inspect the scaled features
pipe[:-1].transform(X=X_test)

Unnamed: 0,mean radius,mean texture,mean perimeter,mean area,mean smoothness,mean compactness,mean concavity,mean concave points,mean symmetry,mean fractal dimension,...,worst radius,worst texture,worst perimeter,worst area,worst smoothness,worst compactness,worst concavity,worst concave points,worst symmetry,worst fractal dimension
204,-0.522890,-0.122794,-0.497700,-0.506027,0.212038,0.055471,-0.105254,-0.306818,0.484503,0.227268,...,-0.307205,-0.109254,-0.361904,-0.366112,0.413852,-0.083749,-0.005465,-0.252428,0.269132,0.213717
70,1.241654,0.451826,1.190610,1.128734,-0.443207,-0.001867,0.258175,0.767271,-0.932875,-1.062454,...,1.708867,0.175433,1.709603,1.548952,-0.560564,-0.108812,0.002394,0.925044,-0.534975,-0.940430
131,0.292564,0.063799,0.320839,0.167451,0.866598,0.381705,0.760084,0.802641,0.509296,-0.588708,...,0.567310,0.090320,0.493686,0.404524,0.915696,-0.074202,0.544642,0.506692,-0.038270,-0.176695
431,-0.541981,-0.317868,-0.482608,-0.541593,0.606145,0.565582,-0.139582,-0.572610,0.013421,1.258198,...,-0.733250,-0.363125,-0.552892,-0.627396,0.514221,0.066027,-0.137098,-0.675344,-0.526291,0.538972
540,-0.776527,-1.004867,-0.753468,-0.705296,0.225061,0.178056,-0.270130,-0.625924,0.042347,0.805664,...,-0.859637,-0.837115,-0.874072,-0.720884,0.075106,-0.238897,-0.434745,-0.744106,-0.920529,-0.115276
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
417,0.303473,0.403058,0.368497,0.304164,1.058511,1.069761,0.832900,0.905108,1.145670,0.921626,...,1.364361,0.332451,1.448626,1.358752,0.794417,0.885325,0.750932,1.449887,0.250028,1.137676
267,-0.217436,0.564206,-0.256626,-0.306506,-1.164936,-0.403431,-0.616656,-0.743476,-0.713863,-0.499615,...,-0.341859,0.683175,-0.314157,-0.392546,-1.346787,-0.470424,-0.603707,-0.855008,-0.717332,-0.708105
327,-0.642890,-0.264859,-0.696278,-0.596581,-1.352051,-1.266863,-1.126029,-1.155114,-1.759334,-0.201225,...,-0.694519,-0.459978,-0.756632,-0.615146,-1.313331,-1.061775,-1.279395,-1.371179,-1.194933,-0.701162
200,-0.588345,0.080762,-0.598974,-0.558745,-0.047730,-0.437438,-0.601703,-0.232438,0.707647,-0.281833,...,-0.415245,0.436641,-0.477564,-0.429780,0.426398,-0.284248,-0.641035,-0.153544,-0.331777,-0.093913


In [None]:
# Display the raw (untransformed) test features for comparison with the scaled output
X_test

Unnamed: 0,mean radius,mean texture,mean perimeter,mean area,mean smoothness,mean compactness,mean concavity,mean concave points,mean symmetry,mean fractal dimension,...,worst radius,worst texture,worst perimeter,worst area,worst smoothness,worst compactness,worst concavity,worst concave points,worst symmetry,worst fractal dimension
204,12.47,18.60,81.09,481.9,0.09965,0.10580,0.080050,0.038210,0.1925,0.06373,...,14.97,24.64,96.05,677.9,0.1426,0.2378,0.267100,0.10150,0.3014,0.08750
70,18.94,21.31,123.60,1130.0,0.09009,0.10290,0.108000,0.079510,0.1582,0.05461,...,24.86,26.58,165.90,1866.0,0.1193,0.2336,0.268700,0.17890,0.2551,0.06589
131,15.46,19.48,101.70,748.9,0.10920,0.12230,0.146600,0.080870,0.1931,0.05796,...,19.26,26.00,124.90,1156.0,0.1546,0.2394,0.379100,0.15140,0.2837,0.08019
431,12.40,17.68,81.47,467.8,0.10540,0.13160,0.077410,0.027990,0.1811,0.07102,...,12.88,22.91,89.61,515.8,0.1450,0.2629,0.240300,0.07370,0.2556,0.09359
540,11.54,14.44,74.65,402.9,0.09984,0.11200,0.067370,0.025940,0.1818,0.06782,...,12.26,19.68,78.78,457.8,0.1345,0.2118,0.179700,0.06918,0.2329,0.08134
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
417,15.50,21.08,102.90,803.1,0.11200,0.15710,0.152200,0.084810,0.2085,0.06864,...,23.17,27.65,157.10,1748.0,0.1517,0.4002,0.421100,0.21340,0.3003,0.10480
267,13.59,21.84,87.16,561.0,0.07956,0.08259,0.040720,0.021420,0.1635,0.05859,...,14.80,30.04,97.66,661.5,0.1005,0.1730,0.145300,0.06189,0.2446,0.07024
327,12.03,17.93,76.09,446.0,0.07683,0.03892,0.001546,0.005592,0.1382,0.06070,...,13.07,22.25,82.74,523.4,0.1013,0.0739,0.007732,0.02796,0.2171,0.07037
200,12.23,19.56,78.54,461.0,0.09586,0.08087,0.041870,0.041070,0.1979,0.06013,...,14.44,28.36,92.15,638.4,0.1429,0.2042,0.137700,0.10800,0.2668,0.08174
