diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 626497b4..f28b9aef 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -3,7 +3,7 @@ name: GitHub Actions on: push: branches: - - master + - floats_improvement pull_request: branches: - master @@ -17,8 +17,6 @@ jobs: matrix: python: - 3.6 - - 3.7 - - 3.8 steps: - name: Checkout repository uses: actions/checkout@v1 @@ -46,12 +44,20 @@ jobs: matrix: python: - 3.6 - - 3.7 - - 3.8 lang: - - "c_lang or python or java or go_lang or javascript or php or haskell or ruby" - - "c_sharp or visual_basic or f_sharp" - - "r_lang or dart" + - "c_lang" + - "python" + - "java" + - "go_lang" + - "javascript" + - "php" + - "haskell" + - "ruby" + - "c_sharp" + - "visual_basic" + - "f_sharp" + - "r_lang" + - "dart" - "powershell" steps: - name: Checkout repository diff --git a/README.md b/README.md index 5df268e9..8e11606d 100644 --- a/README.md +++ b/README.md @@ -132,3 +132,7 @@ A: If this error occurs while generating code using an ensemble model, try to re **Q: Generation fails with `ImportError: No module named ` error while transpiling model from a serialized model object.** A: This error indicates that pickle protocol cannot deserialize model object. For unpickling serialized model objects, it is required that their classes must be defined in the top level of an importable module in the unpickling environment. So installation of package which provided model's class definition should solve the problem. + +**Q: Code generated by m2cgen provides different results for some inputs compared to the original Python model from which the code was obtained.** + +A: Some models force input data to be a particular type during the prediction phase in their native Python libraries. Currently, m2cgen works only with the ``float64`` (``double``) data type. You can try to cast your input data to another type manually and check the results again. Also, small differences can arise from the specific implementation of floating-point arithmetic in the target language.
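A minimal sketch of the casting advice from the FAQ entry above, assuming a scikit-learn tree model (whose native `predict` casts inputs to `float32` internally, as the comments in this PR note) and m2cgen's `export_to_python` API with its default `score` entry point; the model and data here are illustrative only:

```python
import numpy as np
from sklearn.datasets import load_boston
from sklearn.tree import DecisionTreeRegressor

import m2cgen as m2c

X, y = load_boston(return_X_y=True)
model = DecisionTreeRegressor(max_depth=3).fit(X, y)

# sklearn trees cast inputs to float32 internally during predict(),
# so compare the generated code against predictions on float32-cast data.
native_pred = model.predict(X[:1].astype(np.float32))[0]

# Execute the generated Python code and call its `score` entry point.
scope = {}
exec(m2c.export_to_python(model), scope)
generated_pred = scope["score"](X[0].tolist())

print(native_pred, generated_pred)  # should agree up to float32/float64 rounding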
diff --git a/m2cgen/assemblers/boosting.py b/m2cgen/assemblers/boosting.py index 43893ac7..e9fd6039 100644 --- a/m2cgen/assemblers/boosting.py +++ b/m2cgen/assemblers/boosting.py @@ -151,7 +151,7 @@ def __init__(self, model): def _assemble_tree(self, tree): if "leaf" in tree: - return ast.NumVal(tree["leaf"]) + return ast.NumVal(tree["leaf"], dtype=np.float32) threshold = ast.NumVal(tree["split_condition"], dtype=np.float32) split = tree["split"] diff --git a/m2cgen/assemblers/linear.py b/m2cgen/assemblers/linear.py index faa8c770..90b244a5 100644 --- a/m2cgen/assemblers/linear.py +++ b/m2cgen/assemblers/linear.py @@ -1,7 +1,7 @@ import numpy as np from m2cgen import ast -from m2cgen.assemblers import utils +from m2cgen.assemblers import fallback_expressions, utils from m2cgen.assemblers.base import ModelAssembler @@ -95,14 +95,7 @@ def _get_supported_inversed_funs(self): raise NotImplementedError def _logit_inversed(self, ast_to_transform): - return utils.div( - ast.NumVal(1.0), - utils.add( - ast.NumVal(1.0), - ast.ExpExpr( - utils.sub( - ast.NumVal(0.0), - ast_to_transform)))) + return fallback_expressions.sigmoid(ast_to_transform) def _power_inversed(self, ast_to_transform): power = self._get_power() @@ -146,16 +139,15 @@ def _cloglog_inversed(self, ast_to_transform): def _negativebinomial_inversed(self, ast_to_transform): alpha = self._get_alpha() + res = utils.sub( + ast.NumVal(1.0), + ast.ExpExpr( + utils.sub( + ast.NumVal(0.0), + ast_to_transform))) return utils.div( ast.NumVal(-1.0), - utils.mul( - ast.NumVal(alpha), - utils.sub( - ast.NumVal(1.0), - ast.ExpExpr( - utils.sub( - ast.NumVal(0.0), - ast_to_transform))))) + utils.mul(ast.NumVal(alpha), res) if alpha != 1.0 else res) def _get_power(self): raise NotImplementedError diff --git a/m2cgen/assemblers/tree.py b/m2cgen/assemblers/tree.py index f25afd1e..9a40a94d 100644 --- a/m2cgen/assemblers/tree.py +++ b/m2cgen/assemblers/tree.py @@ -1,5 +1,3 @@ -import numpy as np - from m2cgen import ast from m2cgen.assemblers import utils from m2cgen.assemblers.base import ModelAssembler @@ -49,11 +47,5 @@ def _assemble_leaf(self, node_id): def _assemble_cond(self, node_id): feature_idx = self._tree.feature[node_id] - threshold = self._tree.threshold[node_id] - - # sklearn's trees internally work with float32 numbers, so in order - # to have consistent results across all supported languages, we convert - # all thresholds into float32. 
- threshold_num_val = ast.NumVal(threshold, dtype=np.float32) - + threshold_num_val = ast.NumVal(self._tree.threshold[node_id]) return utils.lte(ast.FeatureRef(feature_idx), threshold_num_val) diff --git a/m2cgen/interpreters/code_generator.py b/m2cgen/interpreters/code_generator.py index e222d546..de39c48d 100644 --- a/m2cgen/interpreters/code_generator.py +++ b/m2cgen/interpreters/code_generator.py @@ -1,6 +1,10 @@ from io import StringIO from weakref import finalize +import numpy as np + +from m2cgen.interpreters.utils import format_float + class CodeTemplate: @@ -11,12 +15,14 @@ def __str__(self): return self.str_template def __call__(self, *args, **kwargs): - # Force calling str() representation - # because without it numpy gives the same output - # for different float types + + def _is_float(value): + return isinstance(value, (float, np.floating)) + return self.str_template.format( - *[str(i) for i in args], - **{k: str(v) for k, v in kwargs.items()}) + *[format_float(i) if _is_float(i) else i for i in args], + **{k: format_float(v) if _is_float(v) else v + for k, v in kwargs.items()}) class BaseCodeGenerator: diff --git a/m2cgen/interpreters/utils.py b/m2cgen/interpreters/utils.py index f32cd68c..8fbb6adb 100644 --- a/m2cgen/interpreters/utils.py +++ b/m2cgen/interpreters/utils.py @@ -1,5 +1,7 @@ import re +import numpy as np + from collections import namedtuple from functools import lru_cache from math import ceil, log @@ -22,3 +24,7 @@ def _get_handler_name(expr_tpe): def _normalize_expr_name(name): return re.sub("(?!^)([A-Z]+)", r"_\1", name).lower() + + +def format_float(value): + return np.format_float_positional(value, unique=True, trim="0") diff --git a/tests/assemblers/test_lightgbm.py b/tests/assemblers/test_lightgbm.py index 3b6ad20d..7a09779d 100644 --- a/tests/assemblers/test_lightgbm.py +++ b/tests/assemblers/test_lightgbm.py @@ -26,15 +26,15 @@ def test_binary_classification(): ast.FeatureRef(23), ast.NumVal(868.2000000000002), ast.CompOpType.GT), - ast.NumVal(0.25986931215073095), - ast.NumVal(0.6237178414050242)), + ast.NumVal(0.26400127816506497), + ast.NumVal(0.633133056485969)), ast.IfExpr( ast.CompExpr( - ast.FeatureRef(7), - ast.NumVal(0.05142), + ast.FeatureRef(22), + ast.NumVal(105.95000000000002), ast.CompOpType.GT), - ast.NumVal(-0.1909605544006228), - ast.NumVal(0.1293965108676673)), + ast.NumVal(-0.18744882409486507), + ast.NumVal(0.13458899352064668)), ast.BinNumOpType.ADD), ast.BinNumOpType.SUB)), ast.BinNumOpType.ADD), @@ -85,17 +85,17 @@ def test_regression(): ast.IfExpr( ast.CompExpr( ast.FeatureRef(5), - ast.NumVal(6.918), + ast.NumVal(6.837500000000001), ast.CompOpType.GT), - ast.NumVal(24.011454621684155), - ast.NumVal(22.289277544391084)), + ast.NumVal(23.961356387224317), + ast.NumVal(22.32858336612959)), ast.IfExpr( ast.CompExpr( ast.FeatureRef(12), - ast.NumVal(9.63), + ast.NumVal(9.725000000000003), ast.CompOpType.GT), - ast.NumVal(-0.49461212269771115), - ast.NumVal(0.7174324413014594)), + ast.NumVal(-0.5031712645462916), + ast.NumVal(0.6885501354513913)), ast.BinNumOpType.ADD) assert utils.cmp_exprs(actual, expected) @@ -114,18 +114,18 @@ def test_regression_random_forest(): ast.BinNumExpr( ast.IfExpr( ast.CompExpr( - ast.FeatureRef(5), - ast.NumVal(6.954000000000001), + ast.FeatureRef(12), + ast.NumVal(5.200000000000001), ast.CompOpType.GT), - ast.NumVal(37.24347877367631), - ast.NumVal(19.936999995530854)), + ast.NumVal(20.206688945020474), + ast.NumVal(38.30000037757679)), ast.IfExpr( ast.CompExpr( ast.FeatureRef(5), - 
ast.NumVal(6.971500000000001), + ast.NumVal(6.837500000000001), ast.CompOpType.GT), - ast.NumVal(38.48600037864964), - ast.NumVal(20.183783757300255)), + ast.NumVal(36.40634951405711), + ast.NumVal(19.57067132709245)), ast.BinNumOpType.ADD), ast.NumVal(0.5), ast.BinNumOpType.MUL) @@ -154,15 +154,15 @@ def test_simple_sigmoid_output_transform(): ast.FeatureRef(12), ast.NumVal(19.23), ast.CompOpType.GT), - ast.NumVal(4.0026305187), - ast.NumVal(4.0880438137)), + ast.NumVal(4.0050691250), + ast.NumVal(4.0914737728)), ast.IfExpr( ast.CompExpr( ast.FeatureRef(12), - ast.NumVal(14.895), + ast.NumVal(15.065), ast.CompOpType.GT), - ast.NumVal(-0.0412703078), - ast.NumVal(0.0208393767)), + ast.NumVal(-0.0420531079), + ast.NumVal(0.0202891577)), ast.BinNumOpType.ADD), ast.BinNumOpType.SUB)), ast.BinNumOpType.ADD), @@ -188,15 +188,15 @@ def test_log1p_exp_output_transform(): ast.FeatureRef(12), ast.NumVal(19.23), ast.CompOpType.GT), - ast.NumVal(0.6623502468), - ast.NumVal(0.6683497987)), + ast.NumVal(0.6623996001), + ast.NumVal(0.6684477608)), ast.IfExpr( ast.CompExpr( ast.FeatureRef(12), - ast.NumVal(15.145), + ast.NumVal(15.065), ast.CompOpType.GT), - ast.NumVal(0.1405181490), - ast.NumVal(0.1453602134)), + ast.NumVal(0.1405782705), + ast.NumVal(0.1453764991)), ast.BinNumOpType.ADD))) assert utils.cmp_exprs(actual, expected) @@ -216,17 +216,17 @@ def test_maybe_sqr_output_transform(): ast.IfExpr( ast.CompExpr( ast.FeatureRef(12), - ast.NumVal(9.905), + ast.NumVal(11.655), ast.CompOpType.GT), - ast.NumVal(4.5658116817), - ast.NumVal(4.6620790482)), + ast.NumVal(4.5671830654), + ast.NumVal(4.6516575813)), ast.IfExpr( ast.CompExpr( ast.FeatureRef(12), - ast.NumVal(9.77), + ast.NumVal(9.725), ast.CompOpType.GT), - ast.NumVal(-0.0340889740), - ast.NumVal(0.0543687153)), + ast.NumVal(-0.0348178434), + ast.NumVal(0.0549301624)), ast.BinNumOpType.ADD), to_reuse=True) @@ -251,17 +251,17 @@ def test_exp_output_transform(): ast.IfExpr( ast.CompExpr( ast.FeatureRef(5), - ast.NumVal(6.918), + ast.NumVal(6.8375), ast.CompOpType.GT), - ast.NumVal(3.1480683932), - ast.NumVal(3.1101554907)), + ast.NumVal(3.1481886430), + ast.NumVal(3.1123367238)), ast.IfExpr( ast.CompExpr( ast.FeatureRef(12), - ast.NumVal(9.63), + ast.NumVal(9.725), ast.CompOpType.GT), - ast.NumVal(-0.0111969636), - ast.NumVal(0.0160298303)), + ast.NumVal(-0.0113689739), + ast.NumVal(0.0153551274)), ast.BinNumOpType.ADD)) assert utils.cmp_exprs(actual, expected) @@ -289,8 +289,8 @@ def test_bin_class_sigmoid_output_transform(): ast.FeatureRef(23), ast.NumVal(868.2), ast.CompOpType.GT), - ast.NumVal(0.5197386243), - ast.NumVal(1.2474356828)), + ast.NumVal(0.5280025563), + ast.NumVal(1.2662661130)), ast.BinNumOpType.MUL), ast.BinNumOpType.SUB)), ast.BinNumOpType.ADD), diff --git a/tests/assemblers/test_linear.py b/tests/assemblers/test_linear.py index b7344e55..fbe3a290 100644 --- a/tests/assemblers/test_linear.py +++ b/tests/assemblers/test_linear.py @@ -141,55 +141,55 @@ def test_statsmodels_wo_const(): feature_weight_mul = [ ast.BinNumExpr( ast.FeatureRef(0), - ast.NumVal(-0.0926871267), + ast.NumVal(-0.0940752519), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(1), - ast.NumVal(0.0482139967), + ast.NumVal(0.0461122112), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(2), - ast.NumVal(-0.0075524567), + ast.NumVal(-0.0034800646), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(3), - ast.NumVal(2.9965313383), + ast.NumVal(2.9669908485), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(4), - ast.NumVal(-3.0877925575), + 
ast.NumVal(-2.1264724710), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(5), - ast.NumVal(5.9546630146), + ast.NumVal(5.9738064897), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(6), - ast.NumVal(-0.0073548271), + ast.NumVal(-0.0062638276), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(7), - ast.NumVal(-0.9828206079), + ast.NumVal(-0.9385894841), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(8), - ast.NumVal(0.1727389546), + ast.NumVal(0.1568975632), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(9), - ast.NumVal(-0.0094218658), + ast.NumVal(-0.0091548228), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(10), - ast.NumVal(-0.3931071261), + ast.NumVal(-0.3949784315), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(11), - ast.NumVal(0.0149656744), + ast.NumVal(0.0135685532), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(12), - ast.NumVal(-0.4133835832), + ast.NumVal(-0.4392385223), ast.BinNumOpType.MUL), ] @@ -213,61 +213,61 @@ def test_statsmodels_w_const(): feature_weight_mul = [ ast.BinNumExpr( ast.FeatureRef(0), - ast.NumVal(-0.1085910250), + ast.NumVal(-0.1082106941), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(1), - ast.NumVal(0.0441988987), + ast.NumVal(0.0444969007), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(2), - ast.NumVal(0.0174669054), + ast.NumVal(0.0189847585), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(3), - ast.NumVal(2.8323210870), + ast.NumVal(2.7998640040), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(4), - ast.NumVal(-18.4837486980), + ast.NumVal(-16.7498366967), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(5), - ast.NumVal(3.8354955484), + ast.NumVal(3.9040863643), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(6), - ast.NumVal(0.0001409165), + ast.NumVal(0.0014333844), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(7), - ast.NumVal(-1.5040340047), + ast.NumVal(-1.4436181595), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(8), - ast.NumVal(0.3106174852), + ast.NumVal(0.2868165881), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(9), - ast.NumVal(-0.0123066500), + ast.NumVal(-0.0118539736), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(10), - ast.NumVal(-0.9736183985), + ast.NumVal(-0.9449930750), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(11), - ast.NumVal(0.0094039648), + ast.NumVal(0.0083181952), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(12), - ast.NumVal(-0.5203427347), + ast.NumVal(-0.5415938640), ast.BinNumOpType.MUL), ] expected = assemblers.utils.apply_op_to_expressions( ast.BinNumOpType.ADD, - ast.NumVal(37.1353468527), + ast.NumVal(35.5746356887), *feature_weight_mul) assert utils.cmp_exprs(actual, expected) @@ -308,55 +308,55 @@ def test_statsmodels_processmle(): feature_weight_mul = [ ast.BinNumExpr( ast.FeatureRef(0), - ast.NumVal(-0.0932673973), + ast.NumVal(-0.0915126856), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(1), - ast.NumVal(0.0480819091), + ast.NumVal(0.0455368812), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(2), - ast.NumVal(-0.0063734439), + ast.NumVal(-0.0092227692), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(3), - ast.NumVal(2.7510656855), + ast.NumVal(2.8566616798), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(4), - ast.NumVal(-3.0836268637), + ast.NumVal(-2.1208777964), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(5), - ast.NumVal(5.9605290000), + ast.NumVal(5.9725253309), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(6), 
- ast.NumVal(-0.0077880716), + ast.NumVal(-0.0061566965), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(7), - ast.NumVal(-0.9685365627), + ast.NumVal(-0.9414114075), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(8), - ast.NumVal(0.1688777882), + ast.NumVal(0.1522429507), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(9), - ast.NumVal(-0.0092446419), + ast.NumVal(-0.0092123938), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(10), - ast.NumVal(-0.3924930042), + ast.NumVal(-0.3928508764), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(11), - ast.NumVal(0.01506511708295605), + ast.NumVal(0.0134405151), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(12), - ast.NumVal(-0.4177000096), + ast.NumVal(-0.4364996490), ast.BinNumOpType.MUL), ] @@ -650,21 +650,18 @@ def test_statsmodels_glm_negativebinomial_link_func(): ast.NumVal(-1.0), ast.BinNumExpr( ast.NumVal(1.0), - ast.BinNumExpr( - ast.NumVal(1.0), - ast.ExpExpr( + ast.ExpExpr( + ast.BinNumExpr( + ast.NumVal(0.0), ast.BinNumExpr( ast.NumVal(0.0), ast.BinNumExpr( - ast.NumVal(0.0), - ast.BinNumExpr( - ast.FeatureRef(0), - ast.NumVal(-1.1079583217), - ast.BinNumOpType.MUL), - ast.BinNumOpType.ADD), - ast.BinNumOpType.SUB)), - ast.BinNumOpType.SUB), - ast.BinNumOpType.MUL), + ast.FeatureRef(0), + ast.NumVal(-1.1079583217), + ast.BinNumOpType.MUL), + ast.BinNumOpType.ADD), + ast.BinNumOpType.SUB)), + ast.BinNumOpType.SUB), ast.BinNumOpType.DIV) assert utils.cmp_exprs(actual, expected) @@ -746,55 +743,55 @@ def test_lightning_regression(): feature_weight_mul = [ ast.BinNumExpr( ast.FeatureRef(0), - ast.NumVal(-0.0961163452), + ast.NumVal(-0.0610645819), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(1), - ast.NumVal(0.1574398180), + ast.NumVal(0.0856563713), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(2), - ast.NumVal(-0.0251799219), + ast.NumVal(-0.0562044566), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(3), - ast.NumVal(0.1975142192), + ast.NumVal(0.2804204925), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(4), - ast.NumVal(0.1189621635), + ast.NumVal(0.1359261760), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(5), - ast.NumVal(1.2977018274), + ast.NumVal(1.6307305501), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(6), - ast.NumVal(0.1192977978), + ast.NumVal(0.0866147265), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(7), - ast.NumVal(0.0331955333), + ast.NumVal(-0.0726894150), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(8), - ast.NumVal(0.1433964513), + ast.NumVal(0.0435440193), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(9), - ast.NumVal(0.0014943531), + ast.NumVal(-0.0077364839), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(10), - ast.NumVal(0.3116036672), + ast.NumVal(0.2902775116), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(11), - ast.NumVal(0.0258421936), + ast.NumVal(0.0229879957), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(12), - ast.NumVal(-0.7386996349), + ast.NumVal(-0.7614706871), ast.BinNumOpType.MUL), ] @@ -816,123 +813,123 @@ def test_lightning_binary_class(): feature_weight_mul = [ ast.BinNumExpr( ast.FeatureRef(0), - ast.NumVal(0.1617602138), + ast.NumVal(0.1605265174), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(1), - ast.NumVal(0.0931034793), + ast.NumVal(0.1045225083), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(2), - ast.NumVal(0.6279180888), + ast.NumVal(0.6237391536), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(3), - ast.NumVal(0.1856722189), + 
ast.NumVal(0.1680225811), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(4), - ast.NumVal(0.0009999878), + ast.NumVal(0.0011013688), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(5), - ast.NumVal(-0.0028974470), + ast.NumVal(-0.0027528486), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(6), - ast.NumVal(-0.0059948515), + ast.NumVal(-0.0058878714), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(7), - ast.NumVal(-0.0024173728), + ast.NumVal(-0.0023719811), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(8), - ast.NumVal(0.0020429247), + ast.NumVal(0.0019944105), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(9), - ast.NumVal(0.0009604400), + ast.NumVal(0.0009924456), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(10), - ast.NumVal(0.0010933747), + ast.NumVal(0.0003994860), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(11), - ast.NumVal(0.0078588761), + ast.NumVal(0.0124697033), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(12), - ast.NumVal(-0.0069150246), + ast.NumVal(-0.0123674096), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(13), - ast.NumVal(-0.2583249885), + ast.NumVal(-0.2844204905), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(14), - ast.NumVal(0.0000097479), + ast.NumVal(0.0000273704), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(15), - ast.NumVal(-0.0007210600), + ast.NumVal(-0.0007498013), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(16), - ast.NumVal(-0.0011295195), + ast.NumVal(-0.0010784399), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(17), - ast.NumVal(-0.0001966115), + ast.NumVal(-0.0001848694), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(18), - ast.NumVal(0.0001358314), + ast.NumVal(0.0000632254), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(19), - ast.NumVal(-0.0000378118), + ast.NumVal(-0.0000369618), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(20), - ast.NumVal(0.1555921773), + ast.NumVal(0.1520223021), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(21), - ast.NumVal(0.0621307817), + ast.NumVal(0.0925348635), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(22), - ast.NumVal(0.5138354949), + ast.NumVal(0.4861047372), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(23), - ast.NumVal(-0.2418579612), + ast.NumVal(-0.2798670185), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(24), - ast.NumVal(0.0007953821), + ast.NumVal(0.0009925506), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(25), - ast.NumVal(-0.0110760214), + ast.NumVal(-0.0103414976), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(26), - ast.NumVal(-0.0162178044), + ast.NumVal(-0.0155024577), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(27), - ast.NumVal(-0.0040277699), + ast.NumVal(-0.0038881538), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(28), - ast.NumVal(0.0015067033), + ast.NumVal(0.0010126166), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(29), - ast.NumVal(0.0001536614), + ast.NumVal(0.0002312558), ast.BinNumOpType.MUL), ] @@ -959,22 +956,22 @@ def test_lightning_multi_class(): ast.NumVal(0.0), ast.BinNumExpr( ast.FeatureRef(0), - ast.NumVal(0.0935146297), + ast.NumVal(0.0895848274), ast.BinNumOpType.MUL), ast.BinNumOpType.ADD), ast.BinNumExpr( ast.FeatureRef(1), - ast.NumVal(0.3213921354), + ast.NumVal(0.3258329434), ast.BinNumOpType.MUL), ast.BinNumOpType.ADD), ast.BinNumExpr( ast.FeatureRef(2), - ast.NumVal(-0.4855914264), + ast.NumVal(-0.4900856238), ast.BinNumOpType.MUL), ast.BinNumOpType.ADD), 
ast.BinNumExpr( ast.FeatureRef(3), - ast.NumVal(-0.2214295302), + ast.NumVal(-0.2214482506), ast.BinNumOpType.MUL), ast.BinNumOpType.ADD), ast.BinNumExpr( @@ -984,22 +981,22 @@ def test_lightning_multi_class(): ast.NumVal(0.0), ast.BinNumExpr( ast.FeatureRef(0), - ast.NumVal(-0.1103262586), + ast.NumVal(-0.1074247041), ast.BinNumOpType.MUL), ast.BinNumOpType.ADD), ast.BinNumExpr( ast.FeatureRef(1), - ast.NumVal(-0.1662457692), + ast.NumVal(-0.1693225196), ast.BinNumOpType.MUL), ast.BinNumOpType.ADD), ast.BinNumExpr( ast.FeatureRef(2), - ast.NumVal(0.0379823341), + ast.NumVal(0.0357417324), ast.BinNumOpType.MUL), ast.BinNumOpType.ADD), ast.BinNumExpr( ast.FeatureRef(3), - ast.NumVal(-0.0128634938), + ast.NumVal(-0.0161614171), ast.BinNumOpType.MUL), ast.BinNumOpType.ADD), ast.BinNumExpr( @@ -1009,22 +1006,22 @@ def test_lightning_multi_class(): ast.NumVal(0.0), ast.BinNumExpr( ast.FeatureRef(0), - ast.NumVal(-0.1685751402), + ast.NumVal(-0.1825063678), ast.BinNumOpType.MUL), ast.BinNumOpType.ADD), ast.BinNumExpr( ast.FeatureRef(1), - ast.NumVal(-0.2045901693), + ast.NumVal(-0.2185655665), ast.BinNumOpType.MUL), ast.BinNumOpType.ADD), ast.BinNumExpr( ast.FeatureRef(2), - ast.NumVal(0.2932121798), + ast.NumVal(0.3053017646), ast.BinNumOpType.MUL), ast.BinNumOpType.ADD), ast.BinNumExpr( ast.FeatureRef(3), - ast.NumVal(0.2138148665), + ast.NumVal(0.2175198459), ast.BinNumOpType.MUL), ast.BinNumOpType.ADD)]) diff --git a/tests/assemblers/test_xgboost.py b/tests/assemblers/test_xgboost.py index ba666f63..6727da14 100644 --- a/tests/assemblers/test_xgboost.py +++ b/tests/assemblers/test_xgboost.py @@ -25,17 +25,17 @@ def test_binary_classification(): ast.IfExpr( ast.CompExpr( ast.FeatureRef(20), - ast.NumVal(16.7950001), + ast.NumVal(16.7950000763), ast.CompOpType.GTE), - ast.NumVal(-0.519171), - ast.NumVal(0.49032259)), + ast.NumVal(-0.5253885984), + ast.NumVal(0.4967741966)), ast.IfExpr( ast.CompExpr( ast.FeatureRef(27), - ast.NumVal(0.142349988), + ast.NumVal(0.1423499882), ast.CompOpType.GTE), - ast.NumVal(-0.443304211), - ast.NumVal(0.391988248)), + ast.NumVal(-0.4393392801), + ast.NumVal(0.3904181421)), ast.BinNumOpType.ADD), ast.BinNumOpType.SUB)), ast.BinNumOpType.ADD), @@ -92,17 +92,17 @@ def test_regression(): ast.IfExpr( ast.CompExpr( ast.FeatureRef(12), - ast.NumVal(9.725), + ast.NumVal(9.7250003815), ast.CompOpType.GTE), - ast.NumVal(4.98425627), - ast.NumVal(8.75091362)), + ast.NumVal(5.0069689751), + ast.NumVal(8.7252864838)), ast.IfExpr( ast.CompExpr( ast.FeatureRef(5), - ast.NumVal(6.941), + ast.NumVal(6.9409999847), ast.CompOpType.GTE), - ast.NumVal(8.34557438), - ast.NumVal(3.9141891)), + ast.NumVal(8.3520317078), + ast.NumVal(3.9274528027)), ast.BinNumOpType.ADD), ast.BinNumOpType.ADD) @@ -127,17 +127,17 @@ def test_regression_best_ntree_limit(): ast.IfExpr( ast.CompExpr( ast.FeatureRef(12), - ast.NumVal(9.72500038), + ast.NumVal(9.7250003815), ast.CompOpType.GTE), - ast.NumVal(4.98425627), - ast.NumVal(8.75091362)), + ast.NumVal(5.0069689751), + ast.NumVal(8.7252864838)), ast.IfExpr( ast.CompExpr( ast.FeatureRef(5), - ast.NumVal(6.94099998), + ast.NumVal(6.9409999847), ast.CompOpType.GTE), - ast.NumVal(8.34557438), - ast.NumVal(3.9141891)), + ast.NumVal(8.3520317078), + ast.NumVal(3.9274528027)), ast.BinNumOpType.ADD), ast.BinNumOpType.ADD) @@ -243,17 +243,17 @@ def test_regression_saved_without_feature_names(): ast.IfExpr( ast.CompExpr( ast.FeatureRef(12), - ast.NumVal(9.72500038), + ast.NumVal(9.7250003815), ast.CompOpType.GTE), - ast.NumVal(4.98425627), - 
ast.NumVal(8.75091362)), + ast.NumVal(5.0069689751), + ast.NumVal(8.7252864838)), ast.IfExpr( ast.CompExpr( ast.FeatureRef(5), - ast.NumVal(6.94099998), + ast.NumVal(6.9409999847), ast.CompOpType.GTE), - ast.NumVal(8.34557438), - ast.NumVal(3.9141891)), + ast.NumVal(8.3520317078), + ast.NumVal(3.9274528027)), ast.BinNumOpType.ADD), ast.BinNumOpType.ADD) @@ -274,55 +274,55 @@ def test_linear_model(): feature_weight_mul = [ ast.BinNumExpr( ast.FeatureRef(0), - ast.NumVal(-0.151436), + ast.NumVal(-0.152305), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(1), - ast.NumVal(0.084474), + ast.NumVal(0.0819002), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(2), - ast.NumVal(-0.10035), + ast.NumVal(-0.0993571), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(3), - ast.NumVal(4.71537), + ast.NumVal(4.76251), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(4), - ast.NumVal(1.39071), + ast.NumVal(1.4137), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(5), - ast.NumVal(0.330592), + ast.NumVal(0.329731), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(6), - ast.NumVal(0.0610453), + ast.NumVal(0.0616366), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(7), - ast.NumVal(0.476255), + ast.NumVal(0.462437), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(8), - ast.NumVal(-0.0677851), + ast.NumVal(-0.067064), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(9), - ast.NumVal(-0.000543615), + ast.NumVal(-0.000510475), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(10), - ast.NumVal(0.0717916), + ast.NumVal(0.0720296), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(11), - ast.NumVal(0.010832), + ast.NumVal(0.0108551), ast.BinNumOpType.MUL), ast.BinNumExpr( ast.FeatureRef(12), - ast.NumVal(-0.139375), + ast.NumVal(-0.140799), ast.BinNumOpType.MUL), ] @@ -330,7 +330,7 @@ def test_linear_model(): ast.NumVal(0.5), assemblers.utils.apply_op_to_expressions( ast.BinNumOpType.ADD, - ast.NumVal(11.1287), + ast.NumVal(11.1651), *feature_weight_mul), ast.BinNumOpType.ADD) @@ -352,17 +352,17 @@ def test_regression_random_forest(): ast.IfExpr( ast.CompExpr( ast.FeatureRef(5), - ast.NumVal(6.94099998), + ast.NumVal(6.8410000801), ast.CompOpType.GTE), - ast.NumVal(18.1008453), - ast.NumVal(9.60167599)), + ast.NumVal(17.4066162109), + ast.NumVal(9.6789960861)), ast.IfExpr( ast.CompExpr( - ast.FeatureRef(5), - ast.NumVal(6.79699993), + ast.FeatureRef(12), + ast.NumVal(7.5799999237), ast.CompOpType.GTE), - ast.NumVal(17.780262), - ast.NumVal(9.51712894)), + ast.NumVal(9.0286970139), + ast.NumVal(15.9452571869)), ast.BinNumOpType.ADD), ast.BinNumOpType.ADD) diff --git a/tests/e2e/executors/c.py b/tests/e2e/executors/c.py index 0bc656fd..de4611c0 100644 --- a/tests/e2e/executors/c.py +++ b/tests/e2e/executors/c.py @@ -54,7 +54,7 @@ def __init__(self, model): def predict(self, X): exec_args = [os.path.join(self._resource_tmp_dir, self.model_name)] - exec_args.extend(map(str, X)) + exec_args.extend(map(interpreters.utils.format_float, X)) return utils.predict_from_commandline(exec_args) def prepare(self): diff --git a/tests/e2e/executors/c_sharp.py b/tests/e2e/executors/c_sharp.py index 0b5db8cd..80e8915f 100644 --- a/tests/e2e/executors/c_sharp.py +++ b/tests/e2e/executors/c_sharp.py @@ -49,7 +49,7 @@ def __init__(self, model): def predict(self, X): exec_args = [os.path.join(self.target_exec_dir, self.project_name)] - exec_args.extend(map(str, X)) + exec_args.extend(map(interpreters.utils.format_float, X)) return utils.predict_from_commandline(exec_args) 
@classmethod diff --git a/tests/e2e/executors/dart.py b/tests/e2e/executors/dart.py index 64c489da..d7baf9ba 100644 --- a/tests/e2e/executors/dart.py +++ b/tests/e2e/executors/dart.py @@ -42,7 +42,7 @@ def predict(self, X): f"{self.executor_name}.dart") exec_args = [self._dart, file_name, - *map(str, X)] + *map(interpreters.utils.format_float, X)] return utils.predict_from_commandline(exec_args) def prepare(self): diff --git a/tests/e2e/executors/f_sharp.py b/tests/e2e/executors/f_sharp.py index 1089b33b..96f76c3a 100644 --- a/tests/e2e/executors/f_sharp.py +++ b/tests/e2e/executors/f_sharp.py @@ -35,7 +35,7 @@ def __init__(self, model): def predict(self, X): exec_args = [os.path.join(self.target_exec_dir, self.project_name)] - exec_args.extend(map(str, X)) + exec_args.extend(map(interpreters.utils.format_float, X)) return utils.predict_from_commandline(exec_args) @classmethod diff --git a/tests/e2e/executors/go.py b/tests/e2e/executors/go.py index 1d5b99c4..1eca5d1e 100644 --- a/tests/e2e/executors/go.py +++ b/tests/e2e/executors/go.py @@ -55,7 +55,7 @@ def __init__(self, model): def predict(self, X): exec_args = [os.path.join(self._resource_tmp_dir, self.model_name)] - exec_args.extend(map(str, X)) + exec_args.extend(map(interpreters.utils.format_float, X)) return utils.predict_from_commandline(exec_args) def prepare(self): diff --git a/tests/e2e/executors/haskell.py b/tests/e2e/executors/haskell.py index 022901cf..ba28de0d 100644 --- a/tests/e2e/executors/haskell.py +++ b/tests/e2e/executors/haskell.py @@ -39,7 +39,8 @@ def __init__(self, model): def predict(self, X): app_name = os.path.join(self._resource_tmp_dir, self.executor_name) - exec_args = [app_name, *map(str, X)] + exec_args = [app_name, + *map(interpreters.utils.format_float, X)] return utils.predict_from_commandline(exec_args) def prepare(self): diff --git a/tests/e2e/executors/java.py b/tests/e2e/executors/java.py index e1222599..2b735793 100644 --- a/tests/e2e/executors/java.py +++ b/tests/e2e/executors/java.py @@ -24,7 +24,7 @@ def predict(self, X): self._java_bin, "-cp", self._resource_tmp_dir, "Executor", "Model", "score" ] - exec_args.extend(map(str, X)) + exec_args.extend(map(m2c.interpreters.utils.format_float, X)) return utils.predict_from_commandline(exec_args) def prepare(self): diff --git a/tests/e2e/executors/javascript.py b/tests/e2e/executors/javascript.py index 5d387fc0..78b8b7c5 100644 --- a/tests/e2e/executors/javascript.py +++ b/tests/e2e/executors/javascript.py @@ -1,4 +1,5 @@ import os + from py_mini_racer import py_mini_racer import m2cgen as m2c @@ -16,10 +17,11 @@ def predict(self, X): with open(file_name, 'r') as myfile: code = myfile.read() - caller = f"score([{','.join(map(str, X))}]);\n" + args = ",".join(map(m2c.interpreters.utils.format_float, X)) + caller = f"score([{args}]);\n" ctx = py_mini_racer.MiniRacer() - result = ctx.eval(caller + code) + result = ctx.eval(f"{caller}{code}") return result diff --git a/tests/e2e/executors/php.py b/tests/e2e/executors/php.py index 9b56026e..655793f5 100644 --- a/tests/e2e/executors/php.py +++ b/tests/e2e/executors/php.py @@ -47,7 +47,7 @@ def predict(self, X): exec_args = [self._php, "-f", file_name, - *map(str, X)] + *map(interpreters.utils.format_float, X)] return utils.predict_from_commandline(exec_args) def prepare(self): diff --git a/tests/e2e/executors/powershell.py b/tests/e2e/executors/powershell.py index a3eafca7..b80f59ac 100644 --- a/tests/e2e/executors/powershell.py +++ b/tests/e2e/executors/powershell.py @@ -40,7 +40,7 @@ def predict(self, 
X): "-File", file_name, "-InputArray", - ",".join(map(str, X))] + ",".join(map(interpreters.utils.format_float, X))] return utils.predict_from_commandline(exec_args) def prepare(self): diff --git a/tests/e2e/executors/r.py b/tests/e2e/executors/r.py index f1494209..51fa8dd9 100644 --- a/tests/e2e/executors/r.py +++ b/tests/e2e/executors/r.py @@ -34,7 +34,7 @@ def predict(self, X): exec_args = [self._r, "--vanilla", file_name, - *map(str, X)] + *map(interpreters.utils.format_float, X)] return utils.predict_from_commandline(exec_args) def prepare(self): diff --git a/tests/e2e/executors/ruby.py b/tests/e2e/executors/ruby.py index 3615e2af..71b8329b 100644 --- a/tests/e2e/executors/ruby.py +++ b/tests/e2e/executors/ruby.py @@ -38,7 +38,9 @@ def __init__(self, model): def predict(self, X): file_name = os.path.join(self._resource_tmp_dir, f"{self.model_name}.rb") - exec_args = [self._ruby, file_name, *map(str, X)] + exec_args = [self._ruby, + file_name, + *map(interpreters.utils.format_float, X)] return utils.predict_from_commandline(exec_args) def prepare(self): diff --git a/tests/e2e/executors/visual_basic.py b/tests/e2e/executors/visual_basic.py index 3209a9a4..b4189d02 100644 --- a/tests/e2e/executors/visual_basic.py +++ b/tests/e2e/executors/visual_basic.py @@ -51,7 +51,7 @@ def __init__(self, model): def predict(self, X): exec_args = [os.path.join(self.target_exec_dir, self.project_name)] - exec_args.extend(map(str, X)) + exec_args.extend(map(interpreters.utils.format_float, X)) return utils.predict_from_commandline(exec_args) @classmethod diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index b2aacb5c..4dc5b595 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -568,6 +568,8 @@ def test_e2e(estimator, executor_cls, model_trainer, with executor.prepare_then_cleanup(): for idx in idxs_to_test: y_pred_executed = executor.predict(X_test[idx]) + y_pred_executed = np.array( + y_pred_executed, dtype=y_pred_true.dtype, copy=False) print(f"expected={y_pred_true[idx]}, actual={y_pred_executed}") res = np.isclose(y_pred_true[idx], y_pred_executed, atol=ATOL) assert res if isinstance(res, bool) else res.all() diff --git a/tests/test_cli.py b/tests/test_cli.py index e98ccb2a..81d01071 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -93,7 +93,7 @@ def test_generate_code(): utils.verify_python_model_is_expected( generated_code, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], - expected_output=-47.62913662138064) + expected_output=-39.924763953119275) def test_function_name(): @@ -167,4 +167,4 @@ def test_unsupported_args_are_ignored(): utils.verify_python_model_is_expected( generated_code, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], - expected_output=-47.62913662138064) + expected_output=-39.924763953119275) diff --git a/tests/utils.py b/tests/utils.py index 46bae2d0..a4ffc0de 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -13,11 +13,13 @@ from lightning.impl.base import BaseClassifier as LightBaseClassifier from sklearn import datasets from sklearn.base import BaseEstimator, RegressorMixin, clone -from sklearn.ensemble._forest import ForestClassifier -from sklearn.utils import shuffle +from sklearn.ensemble._forest import ForestClassifier, BaseForest +from sklearn.model_selection import train_test_split from sklearn.linear_model._base import LinearClassifierMixin from sklearn.tree import DecisionTreeClassifier +from sklearn.tree._classes import BaseDecisionTree from sklearn.svm import SVC, NuSVC +from sklearn.svm._base import BaseLibSVM from xgboost import 
XGBClassifier from m2cgen import ast @@ -70,25 +72,17 @@ def __init__(self, dataset_name, test_fraction): np.random.seed(seed=7) if dataset_name == "boston": self.name = "train_model_regression" - dataset = datasets.load_boston() - self.X, self.y = shuffle( - dataset.data, dataset.target, random_state=13) + self.X, self.y = datasets.load_boston(True) elif dataset_name == "boston_y_bounded": self.name = "train_model_regression_bounded" - dataset = datasets.load_boston() - self.X, self.y = shuffle( - dataset.data, dataset.target, random_state=13) + self.X, self.y = datasets.load_boston(True) self.y = np.arctan(self.y) / np.pi + 0.5 # (0; 1) elif dataset_name == "iris": self.name = "train_model_classification" - dataset = datasets.load_iris() - self.X, self.y = shuffle( - dataset.data, dataset.target, random_state=13) + self.X, self.y = datasets.load_iris(True) elif dataset_name == "breast_cancer": self.name = "train_model_classification_binary" - dataset = datasets.load_breast_cancer() - self.X, self.y = shuffle( - dataset.data, dataset.target, random_state=13) + self.X, self.y = datasets.load_breast_cancer(True) elif dataset_name == "regression_rnd": self.name = "train_model_regression_random_data" N = 1000 @@ -107,9 +101,9 @@ def __init__(self, dataset_name, test_fraction): else: raise ValueError("Unknown dataset name: {}".format(dataset_name)) - offset = int(self.X.shape[0] * (1 - test_fraction)) - self.X_train, self.y_train = self.X[:offset], self.y[:offset] - self.X_test, self.y_test = self.X[offset:], self.y[offset:] + (self.X_train, self.X_test, + self.y_train, self.y_test) = train_test_split( + self.X, self.y, test_size=test_fraction, random_state=13) @classmethod def get_instance(cls, dataset_name, test_fraction=0.02): @@ -125,15 +119,22 @@ def __call__(self, estimator): if isinstance(estimator, (LinearClassifierMixin, SVC, NuSVC, LightBaseClassifier)): y_pred = estimator.decision_function(self.X_test) - elif isinstance(estimator, DecisionTreeClassifier): - y_pred = estimator.predict_proba(self.X_test.astype(np.float32)) elif isinstance( estimator, - (ForestClassifier, XGBClassifier, LGBMClassifier)): + (ForestClassifier, DecisionTreeClassifier, + XGBClassifier, LGBMClassifier)): y_pred = estimator.predict_proba(self.X_test) else: y_pred = estimator.predict(self.X_test) + # Some models force input data to be particular type + # during prediction phase in their native Python libraries. + # For correct comparison of testing results we mimic the same behavior + if isinstance(estimator, (BaseDecisionTree, BaseForest)): + self.X_test = self.X_test.astype(np.float32, copy=False) + elif isinstance(estimator, BaseLibSVM): + self.X_test = self.X_test.astype(np.float64, copy=False) + return self.X_test, y_pred, fitted_estimator @@ -238,9 +239,9 @@ def predict_from_commandline(exec_args): items = stdout.decode("utf-8").strip().split(" ") if len(items) == 1: - return float(items[0]) + return np.float64(items[0]) else: - return [float(i) for i in items] + return [np.float64(i) for i in items] def cartesian_e2e_params(executors_with_marks, models_with_trainers_with_marks, @@ -284,4 +285,4 @@ def inner(*args, **kwarg): def _is_float(value): - return isinstance(value, (float, np.float16, np.float32, np.float64)) + return isinstance(value, (float, np.floating))
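A closing note on the `format_float` helper added in `m2cgen/interpreters/utils.py` above: a standalone sketch (plain numpy, no m2cgen imports) of what `np.format_float_positional(value, unique=True, trim="0")` buys over plain `str()` when emitting numeric literals into generated code:

```python
import numpy as np

def format_float(value):
    # unique=True: shortest digit string that round-trips to the same float;
    # trim="0": drop trailing zeros but keep one digit after the decimal point.
    return np.format_float_positional(value, unique=True, trim="0")

print(str(1e-08))                       # '1e-08'      -- scientific notation
print(format_float(1e-08))              # '0.00000001' -- always positional
print(format_float(1.0))                # '1.0'        -- stays a float literal, not '1'
print(format_float(np.float32(868.2)))  # '868.2'      -- shortest repr of the float32 value
```

Positional output matters because not every target language parses scientific notation the same way, and `unique=True` keeps literals as short as possible while still round-tripping to the exact stored value, including for `float32` thresholds.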
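And one on the `_negativebinomial_inversed` refactor in `m2cgen/assemblers/linear.py`: a plain-Python rendition (an illustrative sketch, not part of the library) of the inverse link the assembler now emits, including the `alpha == 1.0` shortcut that drops the redundant multiplication:

```python
import math

def negbinom_inverse_link(eta, alpha=1.0):
    # mu = -1 / (alpha * (1 - exp(-eta))); when alpha == 1.0 the
    # multiplication is skipped, mirroring the assembled AST.
    res = 1.0 - math.exp(-eta)
    return -1.0 / (alpha * res if alpha != 1.0 else res)
```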