From 7e3d944f841fcb50b33295f6a975b320a3031e8b Mon Sep 17 00:00:00 2001 From: Iaroslav Zeigerman Date: Sat, 9 Feb 2019 20:47:45 -0800 Subject: [PATCH] Add some XGBoost assembler tests --- tests/assemblers/test_tree.py | 7 ++- tests/assemblers/test_xgboost.py | 83 +++++++++++++++++++++++++++++++- tests/utils.py | 10 +++- 3 files changed, 95 insertions(+), 5 deletions(-) diff --git a/tests/assemblers/test_tree.py b/tests/assemblers/test_tree.py index 37b3098b..360d3daa 100644 --- a/tests/assemblers/test_tree.py +++ b/tests/assemblers/test_tree.py @@ -51,7 +51,7 @@ def test_two_conditions(): def test_multi_class(): estimator = tree.DecisionTreeClassifier() - estimator.fit([[1], [2], [3]], [1, -1, 1]) + estimator.fit([[1], [2], [3]], [0, 1, 2]) assembler = assemblers.TreeModelAssembler(estimator) actual = assembler.assemble() @@ -62,17 +62,20 @@ def test_multi_class(): ast.NumVal(1.5), ast.CompOpType.LTE), ast.VectorVal([ + ast.NumVal(1.0), ast.NumVal(0.0), - ast.NumVal(1.0)]), + ast.NumVal(0.0)]), ast.IfExpr( ast.CompExpr( ast.FeatureRef(0), ast.NumVal(2.5), ast.CompOpType.LTE), ast.VectorVal([ + ast.NumVal(0.0), ast.NumVal(1.0), ast.NumVal(0.0)]), ast.VectorVal([ + ast.NumVal(0.0), ast.NumVal(0.0), ast.NumVal(1.0)]))) diff --git a/tests/assemblers/test_xgboost.py b/tests/assemblers/test_xgboost.py index 955040ac..29db9ba2 100644 --- a/tests/assemblers/test_xgboost.py +++ b/tests/assemblers/test_xgboost.py @@ -1,4 +1,83 @@ +import xgboost +from tests import utils +from m2cgen import assemblers, ast -def test_xgboost(): - pass +def test_binary_classification(): + estimator = xgboost.XGBClassifier(n_estimators=2, random_state=1, + max_depth=1) + utils.train_model_classification_binary(estimator) + + assembler = assemblers.XGBoostModelAssembler(estimator) + actual = assembler.assemble() + + sigmoid = ast.BinNumExpr( + ast.NumVal(1), + ast.BinNumExpr( + ast.NumVal(1), + ast.ExpExpr( + ast.BinNumExpr( + ast.NumVal(0), + ast.SubroutineExpr( + ast.BinNumExpr( + ast.BinNumExpr( + ast.NumVal(-0.0), + ast.IfExpr( + ast.CompExpr( + ast.FeatureRef(20), + ast.NumVal(16.7950001), + ast.CompOpType.GTE), + ast.NumVal(-0.17062147), + ast.NumVal(0.1638484)), + ast.BinNumOpType.ADD), + ast.IfExpr( + ast.CompExpr( + ast.FeatureRef(27), + ast.NumVal(0.142349988), + ast.CompOpType.GTE), + ast.NumVal(-0.16087772), + ast.NumVal(0.149866998)), + ast.BinNumOpType.ADD)), + ast.BinNumOpType.SUB)), + ast.BinNumOpType.ADD), + ast.BinNumOpType.DIV, + to_reuse=True) + + expected = ast.VectorVal([ + ast.BinNumExpr(ast.NumVal(1), sigmoid, ast.BinNumOpType.SUB), + sigmoid]) + + assert utils.cmp_exprs(actual, expected) + + +def test_regression(): + base_score = 0.6 + estimator = xgboost.XGBRegressor(n_estimators=2, random_state=1, + max_depth=1, base_score=base_score) + utils.train_model_regression(estimator) + + assembler = assemblers.XGBoostModelAssembler(estimator) + actual = assembler.assemble() + + expected = ast.SubroutineExpr( + ast.BinNumExpr( + ast.BinNumExpr( + ast.NumVal(base_score), + ast.IfExpr( + ast.CompExpr( + ast.FeatureRef(12), + ast.NumVal(9.72500038), + ast.CompOpType.GTE), + ast.NumVal(1.67318344), + ast.NumVal(2.92757893)), + ast.BinNumOpType.ADD), + ast.IfExpr( + ast.CompExpr( + ast.FeatureRef(5), + ast.NumVal(6.94099998), + ast.CompOpType.GTE), + ast.NumVal(3.3400948), + ast.NumVal(1.72118247)), + ast.BinNumOpType.ADD)) + + assert utils.cmp_exprs(actual, expected) diff --git a/tests/utils.py b/tests/utils.py index 15020f54..ae3d0e8f 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -31,7 +31,11 @@ def cmp_exprs(left, right): return True if not isinstance(left, ast.Expr) and not isinstance(right, ast.Expr): - assert left == right, str(left) + " != " + str(right) + if _is_float(left) and _is_float(right): + comp_res = np.isclose(left, right) + else: + comp_res = left == right + assert comp_res, str(left) + " != " + str(right) return True if isinstance(left, ast.Expr) and isinstance(right, ast.Expr): @@ -158,3 +162,7 @@ def inner(*args, **kwarg): return inner return wrap + + +def _is_float(value): + return isinstance(value, (float, np.float16, np.float32, np.float64))