From 011848174846c3d8d0d0f6d70d930e26a1005bff Mon Sep 17 00:00:00 2001 From: Iaroslav Zeigerman Date: Fri, 18 Sep 2020 12:12:11 -0700 Subject: [PATCH 1/4] Bump version to 0.9.1 (#309) --- m2cgen/VERSION.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/m2cgen/VERSION.txt b/m2cgen/VERSION.txt index 6f4eebdf..f374f666 100644 --- a/m2cgen/VERSION.txt +++ b/m2cgen/VERSION.txt @@ -1 +1 @@ -0.8.1 +0.9.1 From 21e0c1c198d77f0a6f633f4054800815dd25d153 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 7 Oct 2020 10:31:32 -0700 Subject: [PATCH 2/4] Bump pytest from 6.0.1 to 6.1.1 (#314) Bumps [pytest](https://github.com/pytest-dev/pytest) from 6.0.1 to 6.1.1. - [Release notes](https://github.com/pytest-dev/pytest/releases) - [Changelog](https://github.com/pytest-dev/pytest/blob/master/CHANGELOG.rst) - [Commits](https://github.com/pytest-dev/pytest/compare/6.0.1...6.1.1) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements-test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-test.txt b/requirements-test.txt index 74ccb8e2..641af450 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -7,7 +7,7 @@ git+git://github.com/scikit-learn-contrib/lightning.git@782c18c12961e509099ae84c # Testing tools flake8==3.8.3 -pytest==6.0.1 +pytest==6.1.1 pytest-mock==3.3.1 coveralls==2.1.2 pytest-cov==2.10.1 From 87840c4f5a0fd8fe653fbd8fc6cec501b412b965 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 7 Oct 2020 10:32:11 -0700 Subject: [PATCH 3/4] Bump py-mini-racer from 0.3.0 to 0.4.0 (#310) Bumps [py-mini-racer](https://github.com/sqreen/PyMiniRacer) from 0.3.0 to 0.4.0. - [Release notes](https://github.com/sqreen/PyMiniRacer/releases) - [Changelog](https://github.com/sqreen/PyMiniRacer/blob/master/HISTORY.rst) - [Commits](https://github.com/sqreen/PyMiniRacer/compare/v0.3.0...v0.4.0) Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- requirements-test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-test.txt b/requirements-test.txt index 641af450..d9a6450f 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -15,4 +15,4 @@ pytest-cov==2.10.1 # Other stuff numpy==1.18.5 scipy==1.5.2 -py-mini-racer==0.3.0 +py-mini-racer==0.4.0 From 6b0ede152ac35cdafe8c91b0e54144dfd09efc06 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 7 Oct 2020 10:33:12 -0700 Subject: [PATCH 4/4] Bump numpy from 1.18.5 to 1.19.2 (#307) * Bump numpy from 1.18.5 to 1.19.2 Bumps [numpy](https://github.com/numpy/numpy) from 1.18.5 to 1.19.2. - [Release notes](https://github.com/numpy/numpy/releases) - [Changelog](https://github.com/numpy/numpy/blob/master/doc/HOWTO_RELEASE.rst.txt) - [Commits](https://github.com/numpy/numpy/compare/v1.18.5...v1.19.2) Signed-off-by: dependabot[bot] * Revert "Simplify train/test splitting routine for tests (#291)" This reverts commit fd86e056b3aa460d76c140a1b65656b60f5a827d. * update tests * hotfix Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: StrikerRUS Co-authored-by: Nikita Titov --- requirements-test.txt | 2 +- tests/assemblers/test_lightgbm.py | 112 +++++++++--------- tests/assemblers/test_linear.py | 190 +++++++++++++++--------------- tests/assemblers/test_xgboost.py | 88 +++++++------- tests/test_cli.py | 4 +- tests/utils.py | 2 +- 6 files changed, 199 insertions(+), 199 deletions(-) diff --git a/requirements-test.txt b/requirements-test.txt index d9a6450f..67946d96 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -13,6 +13,6 @@ coveralls==2.1.2 pytest-cov==2.10.1 # Other stuff -numpy==1.18.5 +numpy==1.19.2 scipy==1.5.2 py-mini-racer==0.4.0 diff --git a/tests/assemblers/test_lightgbm.py b/tests/assemblers/test_lightgbm.py index a910ed42..7ed5496f 100644 --- a/tests/assemblers/test_lightgbm.py +++ b/tests/assemblers/test_lightgbm.py @@ -23,18 +23,18 @@ def test_binary_classification(): ast.BinNumExpr( ast.IfExpr( ast.CompExpr( - ast.FeatureRef(23), - ast.NumVal(868.2000000000002), + ast.FeatureRef(20), + ast.NumVal(16.795), ast.CompOpType.GT), - ast.NumVal(0.26400127816506497), - ast.NumVal(0.633133056485969)), + ast.NumVal(0.27502096830384837), + ast.NumVal(0.6391171126839048)), ast.IfExpr( ast.CompExpr( - ast.FeatureRef(22), - ast.NumVal(105.95000000000002), + ast.FeatureRef(27), + ast.NumVal(0.14205), ast.CompOpType.GT), - ast.NumVal(-0.18744882409486507), - ast.NumVal(0.13458899352064668)), + ast.NumVal(-0.21340153096570616), + ast.NumVal(0.11583109256834748)), ast.BinNumOpType.ADD), ast.BinNumOpType.SUB)), ast.BinNumOpType.ADD), @@ -84,18 +84,18 @@ def test_regression(): expected = ast.BinNumExpr( ast.IfExpr( ast.CompExpr( - ast.FeatureRef(5), - ast.NumVal(6.837500000000001), + ast.FeatureRef(12), + ast.NumVal(9.725), ast.CompOpType.GT), - ast.NumVal(23.961356387224317), - ast.NumVal(22.32858336612959)), + ast.NumVal(22.030283219508686), + ast.NumVal(23.27840740210207)), ast.IfExpr( ast.CompExpr( - ast.FeatureRef(12), - ast.NumVal(9.725000000000003), + ast.FeatureRef(5), + ast.NumVal(6.8375), ast.CompOpType.GT), - ast.NumVal(-0.5031712645462916), - ast.NumVal(0.6885501354513913)), + ast.NumVal(1.2777791671888081), + ast.NumVal(-0.2686772850549309)), ast.BinNumOpType.ADD) assert utils.cmp_exprs(actual, expected) @@ -115,17 +115,17 @@ def test_regression_random_forest(): ast.IfExpr( ast.CompExpr( ast.FeatureRef(12), - ast.NumVal(5.200000000000001), + ast.NumVal(9.605), ast.CompOpType.GT), - ast.NumVal(20.195681040256623), - ast.NumVal(38.30000037757679)), + ast.NumVal(17.398543657369768), + ast.NumVal(29.851408659650296)), ast.IfExpr( ast.CompExpr( ast.FeatureRef(5), - ast.NumVal(6.853000000000001), + ast.NumVal(6.888), ast.CompOpType.GT), - ast.NumVal(36.14745794219976), - ast.NumVal(19.778245570310993)), + ast.NumVal(37.2235298136268), + ast.NumVal(19.948122884684025)), ast.BinNumOpType.ADD), ast.NumVal(0.5), ast.BinNumOpType.MUL) @@ -148,23 +148,23 @@ def test_regression_with_negative_values(): ast.FeatureRef(8), ast.NumVal(0.0), ast.CompOpType.GT), - ast.NumVal(155.96889994777868), - ast.NumVal(147.72971715548434)), + ast.NumVal(156.64462853604854), + ast.NumVal(148.40956590509697)), ast.IfExpr( ast.CompExpr( ast.FeatureRef(2), ast.NumVal(0.00780560282464346), ast.CompOpType.GT), - ast.NumVal(4.982244683562974), - ast.NumVal(-2.978315963345233)), + ast.NumVal(4.996373375352607), + ast.NumVal(-3.1063596100284814)), ast.BinNumOpType.ADD), ast.IfExpr( ast.CompExpr( ast.FeatureRef(8), ast.NumVal(-0.0010539205031971832), ast.CompOpType.LTE), - ast.NumVal(-3.488666332734598), - ast.NumVal(3.670539900363904)), + ast.NumVal(-3.5131100858883424), + ast.NumVal(3.6285643795846214)), ast.BinNumOpType.ADD) assert utils.cmp_exprs(actual, expected) @@ -191,15 +191,15 @@ def test_simple_sigmoid_output_transform(): ast.FeatureRef(12), ast.NumVal(19.23), ast.CompOpType.GT), - ast.NumVal(4.0050691250), - ast.NumVal(4.0914737728)), + ast.NumVal(4.002437528537838), + ast.NumVal(4.090096709787509)), ast.IfExpr( ast.CompExpr( ast.FeatureRef(12), - ast.NumVal(15.065), + ast.NumVal(14.895), ast.CompOpType.GT), - ast.NumVal(-0.0420531079), - ast.NumVal(0.0202891577)), + ast.NumVal(-0.0417499606641773), + ast.NumVal(0.02069953712454655)), ast.BinNumOpType.ADD), ast.BinNumOpType.SUB)), ast.BinNumOpType.ADD), @@ -225,15 +225,15 @@ def test_log1p_exp_output_transform(): ast.FeatureRef(12), ast.NumVal(19.23), ast.CompOpType.GT), - ast.NumVal(0.6623996001), - ast.NumVal(0.6684477608)), + ast.NumVal(0.6622623010380544), + ast.NumVal(0.6684065452877841)), ast.IfExpr( ast.CompExpr( ast.FeatureRef(12), - ast.NumVal(15.065), + ast.NumVal(15.145), ast.CompOpType.GT), - ast.NumVal(0.1405782705), - ast.NumVal(0.1453764991)), + ast.NumVal(0.1404975120475147), + ast.NumVal(0.14535916856709272)), ast.BinNumOpType.ADD))) assert utils.cmp_exprs(actual, expected) @@ -253,17 +253,17 @@ def test_maybe_sqr_output_transform(): ast.IfExpr( ast.CompExpr( ast.FeatureRef(12), - ast.NumVal(11.655), + ast.NumVal(9.725), ast.CompOpType.GT), - ast.NumVal(4.5671830654), - ast.NumVal(4.6516575813)), + ast.NumVal(4.569350528717041), + ast.NumVal(4.663526439666748)), ast.IfExpr( ast.CompExpr( ast.FeatureRef(12), - ast.NumVal(9.725), + ast.NumVal(11.655), ast.CompOpType.GT), - ast.NumVal(-0.0348178434), - ast.NumVal(0.0549301624)), + ast.NumVal(-0.04462450027465819), + ast.NumVal(0.033305134773254384)), ast.BinNumOpType.ADD), to_reuse=True) @@ -287,18 +287,18 @@ def test_exp_output_transform(): ast.BinNumExpr( ast.IfExpr( ast.CompExpr( - ast.FeatureRef(5), - ast.NumVal(6.8375), + ast.FeatureRef(12), + ast.NumVal(9.725), ast.CompOpType.GT), - ast.NumVal(3.1481886430), - ast.NumVal(3.1123367238)), + ast.NumVal(3.1043985065105892), + ast.NumVal(3.1318783133960197)), ast.IfExpr( ast.CompExpr( - ast.FeatureRef(12), - ast.NumVal(9.725), + ast.FeatureRef(5), + ast.NumVal(6.8375), ast.CompOpType.GT), - ast.NumVal(-0.0113689739), - ast.NumVal(0.0153551274)), + ast.NumVal(0.028409619436010138), + ast.NumVal(-0.0060740730485278754)), ast.BinNumOpType.ADD)) assert utils.cmp_exprs(actual, expected) @@ -323,11 +323,11 @@ def test_bin_class_sigmoid_output_transform(): ast.NumVal(0.5), ast.IfExpr( ast.CompExpr( - ast.FeatureRef(23), - ast.NumVal(868.2), + ast.FeatureRef(20), + ast.NumVal(16.795), ast.CompOpType.GT), - ast.NumVal(0.5280025563), - ast.NumVal(1.2662661130)), + ast.NumVal(0.5500419366076967), + ast.NumVal(1.2782342253678096)), ast.BinNumOpType.MUL), ast.BinNumOpType.SUB)), ast.BinNumOpType.ADD), diff --git a/tests/assemblers/test_linear.py b/tests/assemblers/test_linear.py index 23337c2d..dc0781f9 100644 --- a/tests/assemblers/test_linear.py +++ b/tests/assemblers/test_linear.py @@ -141,55 +141,55 @@ def test_statsmodels_wo_const(): feature_weight_mul = [ ast.BinNumExpr( ast.FeatureRef(0), - ast.NumVal(-0.0940752519), + ast.NumVal(-0.09519078450227643), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(1), - ast.NumVal(0.0461122112), + ast.NumVal(0.048952926782237956), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(2), - ast.NumVal(-0.0034800646), + ast.NumVal(0.007485539189808044), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(3), - ast.NumVal(2.9669908485), + ast.NumVal(2.7302631809978273), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(4), - ast.NumVal(-2.1264724710), + ast.NumVal(-2.5078200782168034), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(5), - ast.NumVal(5.9738064897), + ast.NumVal(5.891794660307579), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(6), - ast.NumVal(-0.0062638276), + ast.NumVal(-0.008663096157185936), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(7), - ast.NumVal(-0.9385894841), + ast.NumVal(-0.9742684875268565), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(8), - ast.NumVal(0.1568975632), + ast.NumVal(0.1591703441858682), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(9), - ast.NumVal(-0.0091548228), + ast.NumVal(-0.009351831548409096), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(10), - ast.NumVal(-0.3949784315), + ast.NumVal(-0.36395034626096245), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(11), - ast.NumVal(0.0135685532), + ast.NumVal(0.014529018124980565), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(12), - ast.NumVal(-0.4392385223), + ast.NumVal(-0.437443877026267), ast.BinNumOpType.MUL), ] @@ -213,61 +213,61 @@ def test_statsmodels_w_const(): feature_weight_mul = [ ast.BinNumExpr( ast.FeatureRef(0), - ast.NumVal(-0.1082106941), + ast.NumVal(-0.1086131135490779), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(1), - ast.NumVal(0.0444969007), + ast.NumVal(0.046461486329934965), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(2), - ast.NumVal(0.0189847585), + ast.NumVal(0.027432259970185422), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(3), - ast.NumVal(2.7998640040), + ast.NumVal(2.6160671309537693), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(4), - ast.NumVal(-16.7498366967), + ast.NumVal(-17.51793656329748), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(5), - ast.NumVal(3.9040863643), + ast.NumVal(3.7674418196771957), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(6), - ast.NumVal(0.0014333844), + ast.NumVal(-2.1581753172923886e-05), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(7), - ast.NumVal(-1.4436181595), + ast.NumVal(-1.4711768622633619), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(8), - ast.NumVal(0.2868165881), + ast.NumVal(0.29567671400629103), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(9), - ast.NumVal(-0.0118539736), + ast.NumVal(-0.012233831527258853), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(10), - ast.NumVal(-0.9449930750), + ast.NumVal(-0.9220356453705244), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(11), - ast.NumVal(0.0083181952), + ast.NumVal(0.009038220462695548), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(12), - ast.NumVal(-0.5415938640), + ast.NumVal(-0.5425830337142312), ast.BinNumOpType.MUL), ] expected = assemblers.utils.apply_op_to_expressions( ast.BinNumOpType.ADD, - ast.NumVal(35.5746356887), + ast.NumVal(36.36708074657767), *feature_weight_mul) assert utils.cmp_exprs(actual, expected) @@ -308,55 +308,55 @@ def test_statsmodels_processmle(): feature_weight_mul = [ ast.BinNumExpr( ast.FeatureRef(0), - ast.NumVal(-0.0915126856), + ast.NumVal(-0.0980302102110356), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(1), - ast.NumVal(0.0455368812), + ast.NumVal(0.04863869398287732), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(2), - ast.NumVal(-0.0092227692), + ast.NumVal(0.009514054355147874), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(3), - ast.NumVal(2.8566616798), + ast.NumVal(2.977113829322681), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(4), - ast.NumVal(-2.1208777964), + ast.NumVal(-2.6048073854474705), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(5), - ast.NumVal(5.9725253309), + ast.NumVal(5.887987153279099), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(6), - ast.NumVal(-0.0061566965), + ast.NumVal(-0.008183580358672775), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(7), - ast.NumVal(-0.9414114075), + ast.NumVal(-0.996428929917054), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(8), - ast.NumVal(0.1522429507), + ast.NumVal(0.1618353156581333), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(9), - ast.NumVal(-0.0092123938), + ast.NumVal(-0.009213049690188308), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(10), - ast.NumVal(-0.3928508764), + ast.NumVal(-0.3634816838591863), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(11), - ast.NumVal(0.0134405151), + ast.NumVal(0.014700492832969888), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(12), - ast.NumVal(-0.4364996490), + ast.NumVal(-0.4384298738156768), ast.BinNumOpType.MUL), ] @@ -773,55 +773,55 @@ def test_lightning_regression(): feature_weight_mul = [ ast.BinNumExpr( ast.FeatureRef(0), - ast.NumVal(-0.0610645819), + ast.NumVal(-0.08558826944690746), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(1), - ast.NumVal(0.0856563713), + ast.NumVal(0.0803724696787377), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(2), - ast.NumVal(-0.0562044566), + ast.NumVal(-0.03516743076774846), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(3), - ast.NumVal(0.2804204925), + ast.NumVal(0.26469178947134087), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(4), - ast.NumVal(0.1359261760), + ast.NumVal(0.15651985221012488), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(5), - ast.NumVal(1.6307305501), + ast.NumVal(1.5186399078028587), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(6), - ast.NumVal(0.0866147265), + ast.NumVal(0.10089874009193693), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(7), - ast.NumVal(-0.0726894150), + ast.NumVal(-0.011426237067026246), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(8), - ast.NumVal(0.0435440193), + ast.NumVal(0.0861987777487293), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(9), - ast.NumVal(-0.0077364839), + ast.NumVal(-0.0057791506839322574), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(10), - ast.NumVal(0.2902775116), + ast.NumVal(0.3357752757550913), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(11), - ast.NumVal(0.0229879957), + ast.NumVal(0.020189965076849486), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(12), - ast.NumVal(-0.7614706871), + ast.NumVal(-0.7390647599317609), ast.BinNumOpType.MUL), ] @@ -843,123 +843,123 @@ def test_lightning_binary_class(): feature_weight_mul = [ ast.BinNumExpr( ast.FeatureRef(0), - ast.NumVal(0.1605265174), + ast.NumVal(0.16218889967390476), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(1), - ast.NumVal(0.1045225083), + ast.NumVal(0.10012761963766906), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(2), - ast.NumVal(0.6237391536), + ast.NumVal(0.6289276652681673), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(3), - ast.NumVal(0.1680225811), + ast.NumVal(0.17618420156072845), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(4), - ast.NumVal(0.0011013688), + ast.NumVal(0.0010492096607182045), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(5), - ast.NumVal(-0.0027528486), + ast.NumVal(-0.0029135563693806913), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(6), - ast.NumVal(-0.0058878714), + ast.NumVal(-0.005923882409142498), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(7), - ast.NumVal(-0.0023719811), + ast.NumVal(-0.0023293599172479755), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(8), - ast.NumVal(0.0019944105), + ast.NumVal(0.0020808828960210517), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(9), - ast.NumVal(0.0009924456), + ast.NumVal(0.0009846430705550103), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(10), - ast.NumVal(0.0003994860), + ast.NumVal(0.0010399810925427265), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(11), - ast.NumVal(0.0124697033), + ast.NumVal(0.011203056917272093), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(12), - ast.NumVal(-0.0123674096), + ast.NumVal(-0.007271351370867731), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(13), - ast.NumVal(-0.2844204905), + ast.NumVal(-0.26333437096804224), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(14), - ast.NumVal(0.0000273704), + ast.NumVal(1.8533543368532444e-05), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(15), - ast.NumVal(-0.0007498013), + ast.NumVal(-0.0008266341686278445), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(16), - ast.NumVal(-0.0010784399), + ast.NumVal(-0.0011090316301215724), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(17), - ast.NumVal(-0.0001848694), + ast.NumVal(-0.0001910857095336291), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(18), - ast.NumVal(0.0000632254), + ast.NumVal(0.00010735116208006556), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(19), - ast.NumVal(-0.0000369618), + ast.NumVal(-4.076097659514017e-05), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(20), - ast.NumVal(0.1520223021), + ast.NumVal(0.15300712110146406), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(21), - ast.NumVal(0.0925348635), + ast.NumVal(0.06316277258339074), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(22), - ast.NumVal(0.4861047372), + ast.NumVal(0.495291178977687), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(23), - ast.NumVal(-0.2798670185), + ast.NumVal(-0.29589136204657845), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(24), - ast.NumVal(0.0009925506), + ast.NumVal(0.000771932729567487), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(25), - ast.NumVal(-0.0103414976), + ast.NumVal(-0.011877978242492428), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(26), - ast.NumVal(-0.0155024577), + ast.NumVal(-0.01678004536869617), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(27), - ast.NumVal(-0.0038881538), + ast.NumVal(-0.004070431062579625), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(28), - ast.NumVal(0.0010126166), + ast.NumVal(0.001158641497209262), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(29), - ast.NumVal(0.0002312558), + ast.NumVal(0.00010737287732588742), ast.BinNumOpType.MUL), ] @@ -986,22 +986,22 @@ def test_lightning_multi_class(): ast.NumVal(0.0), ast.BinNumExpr( ast.FeatureRef(0), - ast.NumVal(0.0895848274), + ast.NumVal(0.09686334892116512), ast.BinNumOpType.MUL), ast.BinNumOpType.ADD), ast.BinNumExpr( ast.FeatureRef(1), - ast.NumVal(0.3258329434), + ast.NumVal(0.32572202133211947), ast.BinNumOpType.MUL), ast.BinNumOpType.ADD), ast.BinNumExpr( ast.FeatureRef(2), - ast.NumVal(-0.4900856238), + ast.NumVal(-0.48444233646554424), ast.BinNumOpType.MUL), ast.BinNumOpType.ADD), ast.BinNumExpr( ast.FeatureRef(3), - ast.NumVal(-0.2214482506), + ast.NumVal(-0.219719145605816), ast.BinNumOpType.MUL), ast.BinNumOpType.ADD), ast.BinNumExpr( @@ -1011,22 +1011,22 @@ def test_lightning_multi_class(): ast.NumVal(0.0), ast.BinNumExpr( ast.FeatureRef(0), - ast.NumVal(-0.1074247041), + ast.NumVal(-0.1089136473832088), ast.BinNumOpType.MUL), ast.BinNumOpType.ADD), ast.BinNumExpr( ast.FeatureRef(1), - ast.NumVal(-0.1693225196), + ast.NumVal(-0.16956003333433572), ast.BinNumOpType.MUL), ast.BinNumOpType.ADD), ast.BinNumExpr( ast.FeatureRef(2), - ast.NumVal(0.0357417324), + ast.NumVal(0.0365531256007199), ast.BinNumOpType.MUL), ast.BinNumOpType.ADD), ast.BinNumExpr( ast.FeatureRef(3), - ast.NumVal(-0.0161614171), + ast.NumVal(-0.01016100116780896), ast.BinNumOpType.MUL), ast.BinNumOpType.ADD), ast.BinNumExpr( @@ -1036,22 +1036,22 @@ def test_lightning_multi_class(): ast.NumVal(0.0), ast.BinNumExpr( ast.FeatureRef(0), - ast.NumVal(-0.1825063678), + ast.NumVal(-0.16690339219780817), ast.BinNumOpType.MUL), ast.BinNumOpType.ADD), ast.BinNumExpr( ast.FeatureRef(1), - ast.NumVal(-0.2185655665), + ast.NumVal(-0.19466284646233858), ast.BinNumOpType.MUL), ast.BinNumOpType.ADD), ast.BinNumExpr( ast.FeatureRef(2), - ast.NumVal(0.3053017646), + ast.NumVal(0.2953585236360389), ast.BinNumOpType.MUL), ast.BinNumOpType.ADD), ast.BinNumExpr( ast.FeatureRef(3), - ast.NumVal(0.2175198459), + ast.NumVal(0.21288203082531384), ast.BinNumOpType.MUL), ast.BinNumOpType.ADD)]) diff --git a/tests/assemblers/test_xgboost.py b/tests/assemblers/test_xgboost.py index 6727da14..81bc6e3c 100644 --- a/tests/assemblers/test_xgboost.py +++ b/tests/assemblers/test_xgboost.py @@ -25,17 +25,17 @@ def test_binary_classification(): ast.IfExpr( ast.CompExpr( ast.FeatureRef(20), - ast.NumVal(16.7950000763), + ast.NumVal(16.795), ast.CompOpType.GTE), - ast.NumVal(-0.5253885984), - ast.NumVal(0.4967741966)), + ast.NumVal(-0.5178947448730469), + ast.NumVal(0.4880000054836273)), ast.IfExpr( ast.CompExpr( ast.FeatureRef(27), - ast.NumVal(0.1423499882), + ast.NumVal(0.142349988), ast.CompOpType.GTE), - ast.NumVal(-0.4393392801), - ast.NumVal(0.3904181421)), + ast.NumVal(-0.4447747468948364), + ast.NumVal(0.39517202973365784)), ast.BinNumOpType.ADD), ast.BinNumOpType.SUB)), ast.BinNumOpType.ADD), @@ -92,17 +92,17 @@ def test_regression(): ast.IfExpr( ast.CompExpr( ast.FeatureRef(12), - ast.NumVal(9.7250003815), + ast.NumVal(9.725000381469727), ast.CompOpType.GTE), - ast.NumVal(5.0069689751), - ast.NumVal(8.7252864838)), + ast.NumVal(4.995625019073486), + ast.NumVal(8.715502738952637)), ast.IfExpr( ast.CompExpr( ast.FeatureRef(5), - ast.NumVal(6.9409999847), + ast.NumVal(6.941), ast.CompOpType.GTE), - ast.NumVal(8.3520317078), - ast.NumVal(3.9274528027)), + ast.NumVal(8.309040069580078), + ast.NumVal(3.930694580078125)), ast.BinNumOpType.ADD), ast.BinNumOpType.ADD) @@ -127,17 +127,17 @@ def test_regression_best_ntree_limit(): ast.IfExpr( ast.CompExpr( ast.FeatureRef(12), - ast.NumVal(9.7250003815), + ast.NumVal(9.72500038), ast.CompOpType.GTE), - ast.NumVal(5.0069689751), - ast.NumVal(8.7252864838)), + ast.NumVal(4.995625019073486), + ast.NumVal(8.715502738952637)), ast.IfExpr( ast.CompExpr( ast.FeatureRef(5), - ast.NumVal(6.9409999847), + ast.NumVal(6.94099998), ast.CompOpType.GTE), - ast.NumVal(8.3520317078), - ast.NumVal(3.9274528027)), + ast.NumVal(8.309040069580078), + ast.NumVal(3.930694580078125)), ast.BinNumOpType.ADD), ast.BinNumOpType.ADD) @@ -243,17 +243,17 @@ def test_regression_saved_without_feature_names(): ast.IfExpr( ast.CompExpr( ast.FeatureRef(12), - ast.NumVal(9.7250003815), + ast.NumVal(9.72500038), ast.CompOpType.GTE), - ast.NumVal(5.0069689751), - ast.NumVal(8.7252864838)), + ast.NumVal(4.995625019073486), + ast.NumVal(8.715502738952637)), ast.IfExpr( ast.CompExpr( ast.FeatureRef(5), - ast.NumVal(6.9409999847), + ast.NumVal(6.94099998), ast.CompOpType.GTE), - ast.NumVal(8.3520317078), - ast.NumVal(3.9274528027)), + ast.NumVal(8.309040069580078), + ast.NumVal(3.930694580078125)), ast.BinNumOpType.ADD), ast.BinNumOpType.ADD) @@ -274,55 +274,55 @@ def test_linear_model(): feature_weight_mul = [ ast.BinNumExpr( ast.FeatureRef(0), - ast.NumVal(-0.152305), + ast.NumVal(-0.154567), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(1), - ast.NumVal(0.0819002), + ast.NumVal(0.0815865), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(2), - ast.NumVal(-0.0993571), + ast.NumVal(-0.0979713), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(3), - ast.NumVal(4.76251), + ast.NumVal(4.80472), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(4), - ast.NumVal(1.4137), + ast.NumVal(1.35478), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(5), - ast.NumVal(0.329731), + ast.NumVal(0.327222), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(6), - ast.NumVal(0.0616366), + ast.NumVal(0.0610654), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(7), - ast.NumVal(0.462437), + ast.NumVal(0.46989), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(8), - ast.NumVal(-0.067064), + ast.NumVal(-0.0674318), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(9), - ast.NumVal(-0.000510475), + ast.NumVal(-0.000506212), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(10), - ast.NumVal(0.0720296), + ast.NumVal(0.0732867), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(11), - ast.NumVal(0.0108551), + ast.NumVal(0.0108842), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(12), - ast.NumVal(-0.140799), + ast.NumVal(-0.140096), ast.BinNumOpType.MUL), ] @@ -330,7 +330,7 @@ def test_linear_model(): ast.NumVal(0.5), assemblers.utils.apply_op_to_expressions( ast.BinNumOpType.ADD, - ast.NumVal(11.1651), + ast.NumVal(11.138), *feature_weight_mul), ast.BinNumOpType.ADD) @@ -352,17 +352,17 @@ def test_regression_random_forest(): ast.IfExpr( ast.CompExpr( ast.FeatureRef(5), - ast.NumVal(6.8410000801), + ast.NumVal(6.94099998), ast.CompOpType.GTE), - ast.NumVal(17.4066162109), - ast.NumVal(9.6789960861)), + ast.NumVal(18.38124656677246), + ast.NumVal(9.772658348083496)), ast.IfExpr( ast.CompExpr( ast.FeatureRef(12), - ast.NumVal(7.5799999237), + ast.NumVal(9.539999961853027), ast.CompOpType.GTE), - ast.NumVal(9.0286970139), - ast.NumVal(15.9452571869)), + ast.NumVal(8.342499732971191), + ast.NumVal(15.027499198913574)), ast.BinNumOpType.ADD), ast.BinNumOpType.ADD) diff --git a/tests/test_cli.py b/tests/test_cli.py index 81d01071..6058e4fb 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -93,7 +93,7 @@ def test_generate_code(): utils.verify_python_model_is_expected( generated_code, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], - expected_output=-39.924763953119275) + expected_output=-44.40540274041321) def test_function_name(): @@ -167,4 +167,4 @@ def test_unsupported_args_are_ignored(): utils.verify_python_model_is_expected( generated_code, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], - expected_output=-39.924763953119275) + expected_output=-44.40540274041321) diff --git a/tests/utils.py b/tests/utils.py index 38741c9c..f4d3a803 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -128,7 +128,7 @@ def __init__(self, dataset_name, test_fraction): (self.X_train, self.X_test, self.y_train, _) = train_test_split( - self.X, self.y, test_size=test_fraction, random_state=13) + self.X, self.y, test_size=test_fraction, random_state=15) if additional_test_data is not None: self.X_test = np.vstack((additional_test_data, self.X_test))