Skip to content

Commit

Permalink
Fix application of best_ntree_limit to the entire list of estimators.…
Browse files Browse the repository at this point in the history
… Instead the limit is applied to per-class estimators split (#83)
  • Loading branch information
izeigerman authored and krinart committed Apr 3, 2019
1 parent 8610a18 commit d6ec7ec
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 9 deletions.
22 changes: 13 additions & 9 deletions m2cgen/assemblers/boosting.py
Expand Up @@ -9,14 +9,17 @@ class BaseBoostingAssembler(ModelAssembler):

classifier_name = None

def __init__(self, model, trees, base_score=0):
def __init__(self, model, trees, base_score=0, tree_limit=None):
super().__init__(model)
self.all_trees = trees
self._base_score = base_score

self._output_size = 1
self._is_classification = False

assert tree_limit is None or tree_limit > 0, "Unexpected tree limit"
self._tree_limit = tree_limit

model_class_name = type(model).__name__
if model_class_name == self.classifier_name:
self._is_classification = True
Expand All @@ -34,6 +37,9 @@ def assemble(self):
self.all_trees, self._base_score)

def _assemble_single_output(self, trees, base_score=0):
if self._tree_limit:
trees = trees[:self._tree_limit]

trees_ast = [self._assemble_tree(t) for t in trees]
result_ast = utils.apply_op_to_expressions(
ast.BinNumOpType.ADD,
Expand Down Expand Up @@ -83,16 +89,14 @@ def __init__(self, model):
}

model_dump = model.get_booster().get_dump(dump_format="json")

# Respect XGBoost ntree_limit
ntree_limit = getattr(model, "best_ntree_limit", 0)

if ntree_limit > 0:
model_dump = model_dump[:ntree_limit]

trees = [json.loads(d) for d in model_dump]

super().__init__(model, trees, base_score=model.base_score)
# Limit the number of trees that should be used for
# assembling (if applicable).
best_ntree_limit = getattr(model, "best_ntree_limit", None)

super().__init__(model, trees, base_score=model.base_score,
tree_limit=best_ntree_limit)

def _assemble_tree(self, tree):
if "leaf" in tree:
Expand Down
81 changes: 81 additions & 0 deletions tests/assemblers/test_xgboost.py
Expand Up @@ -147,3 +147,84 @@ def test_regression_best_ntree_limit():
ast.BinNumOpType.ADD))

assert utils.cmp_exprs(actual, expected)


def test_multi_class_best_ntree_limit():
base_score = 0.5
estimator = xgboost.XGBClassifier(n_estimators=100, random_state=1,
max_depth=1, base_score=base_score)

estimator.best_ntree_limit = 1

utils.train_model_classification(estimator)

assembler = assemblers.XGBoostModelAssembler(estimator)
actual = assembler.assemble()

estimator_exp_class1 = ast.ExpExpr(
ast.SubroutineExpr(
ast.BinNumExpr(
ast.NumVal(0.5),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(2),
ast.NumVal(2.5999999),
ast.CompOpType.GTE),
ast.NumVal(-0.0731707439),
ast.NumVal(0.142857149)),
ast.BinNumOpType.ADD)),
to_reuse=True)

estimator_exp_class2 = ast.ExpExpr(
ast.SubroutineExpr(
ast.BinNumExpr(
ast.NumVal(0.5),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(2),
ast.NumVal(2.5999999),
ast.CompOpType.GTE),
ast.NumVal(0.0341463387),
ast.NumVal(-0.0714285821)),
ast.BinNumOpType.ADD)),
to_reuse=True)

estimator_exp_class3 = ast.ExpExpr(
ast.SubroutineExpr(
ast.BinNumExpr(
ast.NumVal(0.5),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(2),
ast.NumVal(4.85000038),
ast.CompOpType.GTE),
ast.NumVal(0.129441619),
ast.NumVal(-0.0681440532)),
ast.BinNumOpType.ADD)),
to_reuse=True)

exp_sum = ast.BinNumExpr(
ast.BinNumExpr(
estimator_exp_class1,
estimator_exp_class2,
ast.BinNumOpType.ADD),
estimator_exp_class3,
ast.BinNumOpType.ADD,
to_reuse=True)

expected = ast.VectorVal([
ast.BinNumExpr(
estimator_exp_class1,
exp_sum,
ast.BinNumOpType.DIV),
ast.BinNumExpr(
estimator_exp_class2,
exp_sum,
ast.BinNumOpType.DIV),
ast.BinNumExpr(
estimator_exp_class3,
exp_sum,
ast.BinNumOpType.DIV)
])

assert utils.cmp_exprs(actual, expected)

0 comments on commit d6ec7ec

Please sign in to comment.