Skip to content

Commit

Permalink
Merge 9d0daf3 into 9072ba0
Browse files Browse the repository at this point in the history
  • Loading branch information
izeigerman committed Jan 12, 2020
2 parents 9072ba0 + 9d0daf3 commit 5793414
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 57 deletions.
34 changes: 17 additions & 17 deletions tests/assemblers/test_lightgbm.py
Expand Up @@ -28,16 +28,16 @@ def test_binary_classification():
ast.FeatureRef(23),
ast.NumVal(868.2000000000002),
ast.CompOpType.GT),
ast.NumVal(0.2762557140263451),
ast.NumVal(0.6399134166614473)),
ast.NumVal(0.25986931215073095),
ast.NumVal(0.6237178414050242)),
ast.BinNumOpType.ADD),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(27),
ast.NumVal(0.14205000000000004),
ast.FeatureRef(7),
ast.NumVal(0.05142),
ast.CompOpType.GT),
ast.NumVal(-0.2139321843285849),
ast.NumVal(0.1151466338793227)),
ast.NumVal(-0.1909605544006228),
ast.NumVal(0.1293965108676673)),
ast.BinNumOpType.ADD)),
ast.BinNumOpType.SUB)),
ast.BinNumOpType.ADD),
Expand Down Expand Up @@ -95,18 +95,18 @@ def test_regression():
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.8455),
ast.NumVal(6.918),
ast.CompOpType.GT),
ast.NumVal(24.007392728914056),
ast.NumVal(22.35695742616179)),
ast.NumVal(24.011454621684155),
ast.NumVal(22.289277544391084)),
ast.BinNumOpType.ADD),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(9.63),
ast.CompOpType.GT),
ast.NumVal(-0.4903836928981587),
ast.NumVal(0.7222498915097475)),
ast.NumVal(-0.49461212269771115),
ast.NumVal(0.7174324413014594)),
ast.BinNumOpType.ADD))

assert utils.cmp_exprs(actual, expected)
Expand Down Expand Up @@ -138,17 +138,17 @@ def test_leaves_cutoff_threshold():
ast.FeatureRef(23),
ast.NumVal(868.2000000000002),
ast.CompOpType.GT),
ast.NumVal(0.2762557140263451),
ast.NumVal(0.6399134166614473))),
ast.NumVal(0.25986931215073095),
ast.NumVal(0.6237178414050242))),
ast.BinNumOpType.ADD),
ast.SubroutineExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(27),
ast.NumVal(0.14205000000000004),
ast.FeatureRef(7),
ast.NumVal(0.05142),
ast.CompOpType.GT),
ast.NumVal(-0.2139321843285849),
ast.NumVal(0.1151466338793227))),
ast.NumVal(-0.1909605544006228),
ast.NumVal(0.1293965108676673))),
ast.BinNumOpType.ADD)),
ast.BinNumOpType.SUB)),
ast.BinNumOpType.ADD),
Expand Down
60 changes: 30 additions & 30 deletions tests/assemblers/test_xgboost.py
Expand Up @@ -29,16 +29,16 @@ def test_binary_classification():
ast.FeatureRef(20),
ast.NumVal(16.7950001),
ast.CompOpType.GTE),
ast.NumVal(-0.17062147),
ast.NumVal(0.1638484)),
ast.NumVal(-0.173057005),
ast.NumVal(0.163440868)),
ast.BinNumOpType.ADD),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(27),
ast.NumVal(0.142349988),
ast.CompOpType.GTE),
ast.NumVal(-0.16087772),
ast.NumVal(0.149866998)),
ast.NumVal(-0.161026895),
ast.NumVal(0.149405137)),
ast.BinNumOpType.ADD)),
ast.BinNumOpType.SUB)),
ast.BinNumOpType.ADD),
Expand Down Expand Up @@ -99,16 +99,16 @@ def test_regression():
ast.FeatureRef(12),
ast.NumVal(9.72500038),
ast.CompOpType.GTE),
ast.NumVal(1.67318344),
ast.NumVal(2.92757893)),
ast.NumVal(1.6614188),
ast.NumVal(2.91697121)),
ast.BinNumOpType.ADD),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.94099998),
ast.CompOpType.GTE),
ast.NumVal(3.3400948),
ast.NumVal(1.72118247)),
ast.NumVal(3.33810854),
ast.NumVal(1.71813202)),
ast.BinNumOpType.ADD))

assert utils.cmp_exprs(actual, expected)
Expand All @@ -135,16 +135,16 @@ def test_regression_best_ntree_limit():
ast.FeatureRef(12),
ast.NumVal(9.72500038),
ast.CompOpType.GTE),
ast.NumVal(1.67318344),
ast.NumVal(2.92757893)),
ast.NumVal(1.6614188),
ast.NumVal(2.91697121)),
ast.BinNumOpType.ADD),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.94099998),
ast.CompOpType.GTE),
ast.NumVal(3.3400948),
ast.NumVal(1.72118247)),
ast.NumVal(3.33810854),
ast.NumVal(1.71813202)),
ast.BinNumOpType.ADD))

assert utils.cmp_exprs(actual, expected)
Expand All @@ -169,10 +169,10 @@ def test_multi_class_best_ntree_limit():
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(2),
ast.NumVal(2.5999999),
ast.NumVal(2.45000005),
ast.CompOpType.GTE),
ast.NumVal(-0.0731707439),
ast.NumVal(0.142857149)),
ast.NumVal(-0.0733167157),
ast.NumVal(0.143414631)),
ast.BinNumOpType.ADD)),
to_reuse=True)

Expand All @@ -183,10 +183,10 @@ def test_multi_class_best_ntree_limit():
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(2),
ast.NumVal(2.5999999),
ast.NumVal(2.45000005),
ast.CompOpType.GTE),
ast.NumVal(0.0341463387),
ast.NumVal(-0.0714285821)),
ast.NumVal(0.0344139598),
ast.NumVal(-0.0717073306)),
ast.BinNumOpType.ADD)),
to_reuse=True)

Expand All @@ -196,11 +196,11 @@ def test_multi_class_best_ntree_limit():
ast.NumVal(0.5),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(2),
ast.NumVal(4.85000038),
ast.FeatureRef(3),
ast.NumVal(1.6500001),
ast.CompOpType.GTE),
ast.NumVal(0.129441619),
ast.NumVal(-0.0681440532)),
ast.NumVal(0.13432835),
ast.NumVal(-0.0644444525)),
ast.BinNumOpType.ADD)),
to_reuse=True)

Expand Down Expand Up @@ -255,16 +255,16 @@ def test_regression_saved_without_feature_names():
ast.FeatureRef(12),
ast.NumVal(9.72500038),
ast.CompOpType.GTE),
ast.NumVal(1.67318344),
ast.NumVal(2.92757893)),
ast.NumVal(1.6614188),
ast.NumVal(2.91697121)),
ast.BinNumOpType.ADD),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.94099998),
ast.CompOpType.GTE),
ast.NumVal(3.3400948),
ast.NumVal(1.72118247)),
ast.NumVal(3.33810854),
ast.NumVal(1.71813202)),
ast.BinNumOpType.ADD))

assert utils.cmp_exprs(actual, expected)
Expand Down Expand Up @@ -296,17 +296,17 @@ def test_leaves_cutoff_threshold():
ast.FeatureRef(20),
ast.NumVal(16.7950001),
ast.CompOpType.GTE),
ast.NumVal(-0.17062147),
ast.NumVal(0.1638484))),
ast.NumVal(-0.173057005),
ast.NumVal(0.163440868))),
ast.BinNumOpType.ADD),
ast.SubroutineExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(27),
ast.NumVal(0.142349988),
ast.CompOpType.GTE),
ast.NumVal(-0.16087772),
ast.NumVal(0.149866998))),
ast.NumVal(-0.161026895),
ast.NumVal(0.149405137))),
ast.BinNumOpType.ADD)),
ast.BinNumOpType.SUB)),
ast.BinNumOpType.ADD),
Expand Down
5 changes: 4 additions & 1 deletion tests/e2e/test_cli.py
@@ -1,3 +1,4 @@
import pytest
import pickle
import platform
import subprocess
Expand All @@ -15,7 +16,7 @@ def execute_test(exec_args):
utils.verify_python_model_is_expected(
generated_code,
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
expected_output=-41.89077994476439)
expected_output=-47.62913662138064)


def _prepare_pickled_model(tmp_path):
Expand Down Expand Up @@ -50,6 +51,8 @@ def test_piped(tmp_path):
execute_test(exec_args)


@pytest.mark.skip(reason="utils.verify_python_model_is_expected "
"doesn't support modules")
def test_dash_m(tmp_path):
pickled_model_path = _prepare_pickled_model(tmp_path)
exec_args = ["python", "-m", "m2cgen", "--language", "python",
Expand Down
14 changes: 7 additions & 7 deletions tests/interpreters/test_visual_basic.py
Expand Up @@ -13,7 +13,7 @@ def test_if_expr():
Module Model
Function score(ByRef input_vector() As Double) As Double
Dim var0 As Double
If (1) == (input_vector(0)) Then
If (1) = (input_vector(0)) Then
var0 = 2
Else
var0 = 3
Expand Down Expand Up @@ -67,7 +67,7 @@ def test_dependable_condition():
Function score(ByRef input_vector() As Double) As Double
Dim var0 As Double
Dim var1 As Double
If (1) == (1) Then
If (1) = (1) Then
var1 = 1
Else
var1 = 2
Expand Down Expand Up @@ -108,19 +108,19 @@ def test_nested_condition():
Function score(ByRef input_vector() As Double) As Double
Dim var0 As Double
Dim var1 As Double
If (1) == (1) Then
If (1) = (1) Then
var1 = 1
Else
var1 = 2
End If
If (1) == ((var1) + (2)) Then
If (1) = ((var1) + (2)) Then
Dim var2 As Double
If (1) == (1) Then
If (1) = (1) Then
var2 = 1
Else
var2 = 2
End If
If (1) == ((var2) + (2)) Then
If (1) = ((var2) + (2)) Then
var0 = input_vector(2)
Else
var0 = 2
Expand Down Expand Up @@ -184,7 +184,7 @@ def test_multi_output():
Module Model
Function score(ByRef input_vector() As Double) As Double()
Dim var0() As Double
If (1) == (1) Then
If (1) = (1) Then
Dim var1(1) As Double
var1(0) = 1
var1(1) = 2
Expand Down
4 changes: 2 additions & 2 deletions tests/test_cli.py
Expand Up @@ -83,7 +83,7 @@ def test_generate_code():
utils.verify_python_model_is_expected(
generated_code,
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
expected_output=-41.89077994476439)
expected_output=-47.62913662138064)


def test_class_name():
Expand Down Expand Up @@ -137,4 +137,4 @@ def test_unsupported_args_are_ignored():
utils.verify_python_model_is_expected(
generated_code,
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],
expected_output=-41.89077994476439)
expected_output=-47.62913662138064)

0 comments on commit 5793414

Please sign in to comment.