Skip to content

Commit

Permalink
Add some XGBoost assembler tests
Browse files Browse the repository at this point in the history
  • Loading branch information
izeigerman committed Feb 10, 2019
1 parent f340465 commit 7e3d944
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 5 deletions.
7 changes: 5 additions & 2 deletions tests/assemblers/test_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_two_conditions():
def test_multi_class():
estimator = tree.DecisionTreeClassifier()

estimator.fit([[1], [2], [3]], [1, -1, 1])
estimator.fit([[1], [2], [3]], [0, 1, 2])

assembler = assemblers.TreeModelAssembler(estimator)
actual = assembler.assemble()
Expand All @@ -62,17 +62,20 @@ def test_multi_class():
ast.NumVal(1.5),
ast.CompOpType.LTE),
ast.VectorVal([
ast.NumVal(1.0),
ast.NumVal(0.0),
ast.NumVal(1.0)]),
ast.NumVal(0.0)]),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(0),
ast.NumVal(2.5),
ast.CompOpType.LTE),
ast.VectorVal([
ast.NumVal(0.0),
ast.NumVal(1.0),
ast.NumVal(0.0)]),
ast.VectorVal([
ast.NumVal(0.0),
ast.NumVal(0.0),
ast.NumVal(1.0)])))

Expand Down
83 changes: 81 additions & 2 deletions tests/assemblers/test_xgboost.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,83 @@
import xgboost
from tests import utils
from m2cgen import assemblers, ast


def test_xgboost():
pass
def test_binary_classification():
estimator = xgboost.XGBClassifier(n_estimators=2, random_state=1,
max_depth=1)
utils.train_model_classification_binary(estimator)

assembler = assemblers.XGBoostModelAssembler(estimator)
actual = assembler.assemble()

sigmoid = ast.BinNumExpr(
ast.NumVal(1),
ast.BinNumExpr(
ast.NumVal(1),
ast.ExpExpr(
ast.BinNumExpr(
ast.NumVal(0),
ast.SubroutineExpr(
ast.BinNumExpr(
ast.BinNumExpr(
ast.NumVal(-0.0),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(20),
ast.NumVal(16.7950001),
ast.CompOpType.GTE),
ast.NumVal(-0.17062147),
ast.NumVal(0.1638484)),
ast.BinNumOpType.ADD),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(27),
ast.NumVal(0.142349988),
ast.CompOpType.GTE),
ast.NumVal(-0.16087772),
ast.NumVal(0.149866998)),
ast.BinNumOpType.ADD)),
ast.BinNumOpType.SUB)),
ast.BinNumOpType.ADD),
ast.BinNumOpType.DIV,
to_reuse=True)

expected = ast.VectorVal([
ast.BinNumExpr(ast.NumVal(1), sigmoid, ast.BinNumOpType.SUB),
sigmoid])

assert utils.cmp_exprs(actual, expected)


def test_regression():
base_score = 0.6
estimator = xgboost.XGBRegressor(n_estimators=2, random_state=1,
max_depth=1, base_score=base_score)
utils.train_model_regression(estimator)

assembler = assemblers.XGBoostModelAssembler(estimator)
actual = assembler.assemble()

expected = ast.SubroutineExpr(
ast.BinNumExpr(
ast.BinNumExpr(
ast.NumVal(base_score),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(9.72500038),
ast.CompOpType.GTE),
ast.NumVal(1.67318344),
ast.NumVal(2.92757893)),
ast.BinNumOpType.ADD),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.94099998),
ast.CompOpType.GTE),
ast.NumVal(3.3400948),
ast.NumVal(1.72118247)),
ast.BinNumOpType.ADD))

assert utils.cmp_exprs(actual, expected)
10 changes: 9 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ def cmp_exprs(left, right):
return True

if not isinstance(left, ast.Expr) and not isinstance(right, ast.Expr):
assert left == right, str(left) + " != " + str(right)
if _is_float(left) and _is_float(right):
comp_res = np.isclose(left, right)
else:
comp_res = left == right
assert comp_res, str(left) + " != " + str(right)
return True

if isinstance(left, ast.Expr) and isinstance(right, ast.Expr):
Expand Down Expand Up @@ -158,3 +162,7 @@ def inner(*args, **kwarg):
return inner

return wrap


def _is_float(value):
return isinstance(value, (float, np.float16, np.float32, np.float64))

0 comments on commit 7e3d944

Please sign in to comment.