Skip to content

Commit

Permalink
Update assembler for XGBoost to support newer versions (#185)
Browse files — browse the repository at this point in the history
  • Loading branch information
StrikerRUS committed Mar 23, 2020
1 parent d058b8f commit 7c3b589
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 46 deletions.
5 changes: 3 additions & 2 deletions m2cgen/assemblers/boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ def __init__(self, model):
# assembling (if applicable).
best_ntree_limit = getattr(model, "best_ntree_limit", None)

super().__init__(model, trees, base_score=model.base_score,
super().__init__(model, trees,
base_score=model.get_params()["base_score"],
tree_limit=best_ntree_limit)

def _assemble_tree(self, tree):
Expand Down Expand Up @@ -171,7 +172,7 @@ def __init__(self, model):
weights = json.loads(model_dump[0])["weight"]
self._bias = json.loads(model_dump[0])["bias"]
super().__init__(model, weights,
base_score=model.base_score)
base_score=model.get_params()["base_score"])

def _assemble_estimators(self, weights, split_idx):
coef = utils.to_1d_array(weights)
Expand Down
2 changes: 1 addition & 1 deletion requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
numpy==1.16.1
scipy==1.1.0
scikit-learn==0.20.2
xgboost==0.90
xgboost==1.0.2
lightgbm==2.2.3
flake8==3.6.0
pytest==5.3.2
Expand Down
88 changes: 45 additions & 43 deletions tests/assemblers/test_xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,16 @@ def test_binary_classification():
ast.FeatureRef(20),
ast.NumVal(16.7950001),
ast.CompOpType.GTE),
ast.NumVal(-0.173057005),
ast.NumVal(0.163440868)),
ast.NumVal(-0.519171),
ast.NumVal(0.49032259)),
ast.BinNumOpType.ADD),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(27),
ast.NumVal(0.142349988),
ast.CompOpType.GTE),
ast.NumVal(-0.161026895),
ast.NumVal(0.149405137)),
ast.NumVal(-0.443304211),
ast.NumVal(0.391988248)),
ast.BinNumOpType.ADD),
ast.BinNumOpType.SUB)),
ast.BinNumOpType.ADD),
Expand Down Expand Up @@ -96,16 +96,16 @@ def test_regression():
ast.FeatureRef(12),
ast.NumVal(9.72500038),
ast.CompOpType.GTE),
ast.NumVal(1.6614188),
ast.NumVal(2.91697121)),
ast.NumVal(4.98425627),
ast.NumVal(8.75091362)),
ast.BinNumOpType.ADD),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.94099998),
ast.CompOpType.GTE),
ast.NumVal(3.33810854),
ast.NumVal(1.71813202)),
ast.NumVal(8.34557438),
ast.NumVal(3.9141891)),
ast.BinNumOpType.ADD)

assert utils.cmp_exprs(actual, expected)
Expand All @@ -131,16 +131,16 @@ def test_regression_best_ntree_limit():
ast.FeatureRef(12),
ast.NumVal(9.72500038),
ast.CompOpType.GTE),
ast.NumVal(1.6614188),
ast.NumVal(2.91697121)),
ast.NumVal(4.98425627),
ast.NumVal(8.75091362)),
ast.BinNumOpType.ADD),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.94099998),
ast.CompOpType.GTE),
ast.NumVal(3.33810854),
ast.NumVal(1.71813202)),
ast.NumVal(8.34557438),
ast.NumVal(3.9141891)),
ast.BinNumOpType.ADD)

assert utils.cmp_exprs(actual, expected)
Expand All @@ -166,8 +166,8 @@ def test_multi_class_best_ntree_limit():
ast.FeatureRef(2),
ast.NumVal(2.45000005),
ast.CompOpType.GTE),
ast.NumVal(-0.0733167157),
ast.NumVal(0.143414631)),
ast.NumVal(-0.219950154),
ast.NumVal(0.430243909)),
ast.BinNumOpType.ADD),
to_reuse=True)

Expand All @@ -179,8 +179,8 @@ def test_multi_class_best_ntree_limit():
ast.FeatureRef(2),
ast.NumVal(2.45000005),
ast.CompOpType.GTE),
ast.NumVal(0.0344139598),
ast.NumVal(-0.0717073306)),
ast.NumVal(0.103241883),
ast.NumVal(-0.215121984)),
ast.BinNumOpType.ADD),
to_reuse=True)

Expand All @@ -192,8 +192,8 @@ def test_multi_class_best_ntree_limit():
ast.FeatureRef(3),
ast.NumVal(1.6500001),
ast.CompOpType.GTE),
ast.NumVal(0.13432835),
ast.NumVal(-0.0644444525)),
ast.NumVal(0.402985066),
ast.NumVal(-0.193333372)),
ast.BinNumOpType.ADD),
to_reuse=True)

Expand Down Expand Up @@ -247,23 +247,25 @@ def test_regression_saved_without_feature_names():
ast.FeatureRef(12),
ast.NumVal(9.72500038),
ast.CompOpType.GTE),
ast.NumVal(1.6614188),
ast.NumVal(2.91697121)),
ast.NumVal(4.98425627),
ast.NumVal(8.75091362)),
ast.BinNumOpType.ADD),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.94099998),
ast.CompOpType.GTE),
ast.NumVal(3.33810854),
ast.NumVal(1.71813202)),
ast.NumVal(8.34557438),
ast.NumVal(3.9141891)),
ast.BinNumOpType.ADD)

assert utils.cmp_exprs(actual, expected)


def test_linear_model():
# Default updater ("shotgun") is nondeterministic
estimator = xgboost.XGBRegressor(n_estimators=2, random_state=1,
updater="coord_descent",
feature_selector="shuffle",
booster="gblinear")
utils.get_regression_model_trainer()(estimator)
Expand All @@ -274,63 +276,63 @@ def test_linear_model():
feature_weight_mul = [
ast.BinNumExpr(
ast.FeatureRef(0),
ast.NumVal(-0.00999326),
ast.NumVal(-0.151436),
ast.BinNumOpType.MUL),
ast.BinNumExpr(
ast.FeatureRef(1),
ast.NumVal(0.0520094),
ast.NumVal(0.084474),
ast.BinNumOpType.MUL),
ast.BinNumExpr(
ast.FeatureRef(2),
ast.NumVal(0.10447),
ast.NumVal(-0.10035),
ast.BinNumOpType.MUL),
ast.BinNumExpr(
ast.FeatureRef(3),
ast.NumVal(0.17387),
ast.NumVal(4.71537),
ast.BinNumOpType.MUL),
ast.BinNumExpr(
ast.FeatureRef(4),
ast.NumVal(0.691745),
ast.NumVal(1.39071),
ast.BinNumOpType.MUL),
ast.BinNumExpr(
ast.FeatureRef(5),
ast.NumVal(0.296357),
ast.NumVal(0.330592),
ast.BinNumOpType.MUL),
ast.BinNumExpr(
ast.FeatureRef(6),
ast.NumVal(0.0288206),
ast.NumVal(0.0610453),
ast.BinNumOpType.MUL),
ast.BinNumExpr(
ast.FeatureRef(7),
ast.NumVal(0.417822),
ast.NumVal(0.476255),
ast.BinNumOpType.MUL),
ast.BinNumExpr(
ast.FeatureRef(8),
ast.NumVal(0.0551116),
ast.NumVal(-0.0677851),
ast.BinNumOpType.MUL),
ast.BinNumExpr(
ast.FeatureRef(9),
ast.NumVal(0.00242449),
ast.NumVal(-0.000543615),
ast.BinNumOpType.MUL),
ast.BinNumExpr(
ast.FeatureRef(10),
ast.NumVal(0.109585),
ast.NumVal(0.0717916),
ast.BinNumOpType.MUL),
ast.BinNumExpr(
ast.FeatureRef(11),
ast.NumVal(0.00744202),
ast.NumVal(0.010832),
ast.BinNumOpType.MUL),
ast.BinNumExpr(
ast.FeatureRef(12),
ast.NumVal(0.0731089),
ast.NumVal(-0.139375),
ast.BinNumOpType.MUL),
]

expected = ast.BinNumExpr(
ast.NumVal(0.5),
assemblers.utils.apply_op_to_expressions(
ast.BinNumOpType.ADD,
ast.NumVal(3.13109),
ast.NumVal(11.1287),
*feature_weight_mul),
ast.BinNumOpType.ADD)

Expand All @@ -352,18 +354,18 @@ def test_regression_random_forest():
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.8375001),
ast.NumVal(6.94099998),
ast.CompOpType.GTE),
ast.NumVal(17.3671646),
ast.NumVal(9.48354053)),
ast.NumVal(18.1008453),
ast.NumVal(9.60167599)),
ast.BinNumOpType.ADD),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(9.72500038),
ast.FeatureRef(5),
ast.NumVal(6.79699993),
ast.CompOpType.GTE),
ast.NumVal(8.31587982),
ast.NumVal(14.7766275)),
ast.NumVal(17.780262),
ast.NumVal(9.51712894)),
ast.BinNumOpType.ADD)

assert utils.cmp_exprs(actual, expected)

0 comments on commit 7c3b589

Please sign in to comment.