Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

handle floating point values more accurate #277

Merged
merged 3 commits into from Aug 5, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Expand Up @@ -132,3 +132,7 @@ A: If this error occurs while generating code using an ensemble model, try to re
**Q: Generation fails with `ImportError: No module named <module_name_here>` error while transpiling model from a serialized model object.**

A: This error indicates that pickle protocol cannot deserialize model object. For unpickling serialized model objects, it is required that their classes must be defined in the top level of an importable module in the unpickling environment. So installation of package which provided model's class definition should solve the problem.

**Q: Generated by m2cgen code provides different results for some inputs compared to original Python model from which the code were obtained.**

A: Some models force input data to be particular type during prediction phase in their native Python libraries. Currently, m2cgen works only with ``float64`` (``double``) data type. You can try to cast your input data to another type manually and check results again. Also, some small differences can happen due to specific implementation of floating-point arithmetic in a target language.
2 changes: 1 addition & 1 deletion m2cgen/assemblers/boosting.py
Expand Up @@ -151,7 +151,7 @@ def __init__(self, model):

def _assemble_tree(self, tree):
if "leaf" in tree:
return ast.NumVal(tree["leaf"])
return ast.NumVal(tree["leaf"], dtype=np.float32)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


threshold = ast.NumVal(tree["split_condition"], dtype=np.float32)
split = tree["split"]
Expand Down
26 changes: 9 additions & 17 deletions m2cgen/assemblers/linear.py
@@ -1,7 +1,7 @@
import numpy as np

from m2cgen import ast
from m2cgen.assemblers import utils
from m2cgen.assemblers import fallback_expressions, utils
from m2cgen.assemblers.base import ModelAssembler


Expand Down Expand Up @@ -95,14 +95,7 @@ def _get_supported_inversed_funs(self):
raise NotImplementedError

def _logit_inversed(self, ast_to_transform):
return utils.div(
ast.NumVal(1.0),
utils.add(
ast.NumVal(1.0),
ast.ExpExpr(
utils.sub(
ast.NumVal(0.0),
ast_to_transform))))
return fallback_expressions.sigmoid(ast_to_transform)

def _power_inversed(self, ast_to_transform):
power = self._get_power()
Expand Down Expand Up @@ -146,16 +139,15 @@ def _cloglog_inversed(self, ast_to_transform):

def _negativebinomial_inversed(self, ast_to_transform):
alpha = self._get_alpha()
res = utils.sub(
ast.NumVal(1.0),
ast.ExpExpr(
utils.sub(
ast.NumVal(0.0),
ast_to_transform)))
return utils.div(
ast.NumVal(-1.0),
utils.mul(
ast.NumVal(alpha),
utils.sub(
ast.NumVal(1.0),
ast.ExpExpr(
utils.sub(
ast.NumVal(0.0),
ast_to_transform)))))
utils.mul(ast.NumVal(alpha), res) if alpha != 1.0 else res)

def _get_power(self):
raise NotImplementedError
Expand Down
10 changes: 1 addition & 9 deletions m2cgen/assemblers/tree.py
@@ -1,5 +1,3 @@
import numpy as np

from m2cgen import ast
from m2cgen.assemblers import utils
from m2cgen.assemblers.base import ModelAssembler
Expand Down Expand Up @@ -49,11 +47,5 @@ def _assemble_leaf(self, node_id):

def _assemble_cond(self, node_id):
feature_idx = self._tree.feature[node_id]
threshold = self._tree.threshold[node_id]

# sklearn's trees internally work with float32 numbers, so in order
# to have consistent results across all supported languages, we convert
# all thresholds into float32.
threshold_num_val = ast.NumVal(threshold, dtype=np.float32)

threshold_num_val = ast.NumVal(self._tree.threshold[node_id])
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refer to #190 (review).

Now threshold matches original type in scikit-learn (double).

return utils.lte(ast.FeatureRef(feature_idx), threshold_num_val)
16 changes: 11 additions & 5 deletions m2cgen/interpreters/code_generator.py
@@ -1,6 +1,10 @@
from io import StringIO
from weakref import finalize

import numpy as np

from m2cgen.interpreters.utils import format_float


class CodeTemplate:

Expand All @@ -11,12 +15,14 @@ def __str__(self):
return self.str_template

def __call__(self, *args, **kwargs):
# Force calling str() representation
# because without it numpy gives the same output
# for different float types

def _is_float(value):
return isinstance(value, (float, np.floating))

return self.str_template.format(
*[str(i) for i in args],
**{k: str(v) for k, v in kwargs.items()})
*[format_float(i) if _is_float(i) else i for i in args],
**{k: format_float(v) if _is_float(v) else v
for k, v in kwargs.items()})


class BaseCodeGenerator:
Expand Down
6 changes: 6 additions & 0 deletions m2cgen/interpreters/utils.py
@@ -1,5 +1,7 @@
import re

import numpy as np

from collections import namedtuple
from functools import lru_cache
from math import ceil, log
Expand All @@ -22,3 +24,7 @@ def _get_handler_name(expr_tpe):

def _normalize_expr_name(name):
return re.sub("(?!^)([A-Z]+)", r"_\1", name).lower()


def format_float(value):
return np.format_float_positional(value, unique=True, trim="0")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe format_float_scientific will be better: https://numpy.org/doc/stable/reference/generated/numpy.format_float_scientific.html. But I'm not sure how many languages support scientific notation.

86 changes: 43 additions & 43 deletions tests/assemblers/test_lightgbm.py
Expand Up @@ -26,15 +26,15 @@ def test_binary_classification():
ast.FeatureRef(23),
ast.NumVal(868.2000000000002),
ast.CompOpType.GT),
ast.NumVal(0.25986931215073095),
ast.NumVal(0.6237178414050242)),
ast.NumVal(0.26400127816506497),
ast.NumVal(0.633133056485969)),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(7),
ast.NumVal(0.05142),
ast.FeatureRef(22),
ast.NumVal(105.95000000000002),
ast.CompOpType.GT),
ast.NumVal(-0.1909605544006228),
ast.NumVal(0.1293965108676673)),
ast.NumVal(-0.18744882409486507),
ast.NumVal(0.13458899352064668)),
ast.BinNumOpType.ADD),
ast.BinNumOpType.SUB)),
ast.BinNumOpType.ADD),
Expand Down Expand Up @@ -85,17 +85,17 @@ def test_regression():
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.918),
ast.NumVal(6.837500000000001),
ast.CompOpType.GT),
ast.NumVal(24.011454621684155),
ast.NumVal(22.289277544391084)),
ast.NumVal(23.961356387224317),
ast.NumVal(22.32858336612959)),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(9.63),
ast.NumVal(9.725000000000003),
ast.CompOpType.GT),
ast.NumVal(-0.49461212269771115),
ast.NumVal(0.7174324413014594)),
ast.NumVal(-0.5031712645462916),
ast.NumVal(0.6885501354513913)),
ast.BinNumOpType.ADD)

assert utils.cmp_exprs(actual, expected)
Expand All @@ -114,18 +114,18 @@ def test_regression_random_forest():
ast.BinNumExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.954000000000001),
ast.FeatureRef(12),
ast.NumVal(5.200000000000001),
ast.CompOpType.GT),
ast.NumVal(37.24347877367631),
ast.NumVal(19.936999995530854)),
ast.NumVal(20.206688945020474),
ast.NumVal(38.30000037757679)),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.971500000000001),
ast.NumVal(6.837500000000001),
ast.CompOpType.GT),
ast.NumVal(38.48600037864964),
ast.NumVal(20.183783757300255)),
ast.NumVal(36.40634951405711),
ast.NumVal(19.57067132709245)),
ast.BinNumOpType.ADD),
ast.NumVal(0.5),
ast.BinNumOpType.MUL)
Expand Down Expand Up @@ -154,15 +154,15 @@ def test_simple_sigmoid_output_transform():
ast.FeatureRef(12),
ast.NumVal(19.23),
ast.CompOpType.GT),
ast.NumVal(4.0026305187),
ast.NumVal(4.0880438137)),
ast.NumVal(4.0050691250),
ast.NumVal(4.0914737728)),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(14.895),
ast.NumVal(15.065),
ast.CompOpType.GT),
ast.NumVal(-0.0412703078),
ast.NumVal(0.0208393767)),
ast.NumVal(-0.0420531079),
ast.NumVal(0.0202891577)),
ast.BinNumOpType.ADD),
ast.BinNumOpType.SUB)),
ast.BinNumOpType.ADD),
Expand All @@ -188,15 +188,15 @@ def test_log1p_exp_output_transform():
ast.FeatureRef(12),
ast.NumVal(19.23),
ast.CompOpType.GT),
ast.NumVal(0.6623502468),
ast.NumVal(0.6683497987)),
ast.NumVal(0.6623996001),
ast.NumVal(0.6684477608)),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(15.145),
ast.NumVal(15.065),
ast.CompOpType.GT),
ast.NumVal(0.1405181490),
ast.NumVal(0.1453602134)),
ast.NumVal(0.1405782705),
ast.NumVal(0.1453764991)),
ast.BinNumOpType.ADD)))

assert utils.cmp_exprs(actual, expected)
Expand All @@ -216,17 +216,17 @@ def test_maybe_sqr_output_transform():
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(9.905),
ast.NumVal(11.655),
ast.CompOpType.GT),
ast.NumVal(4.5658116817),
ast.NumVal(4.6620790482)),
ast.NumVal(4.5671830654),
ast.NumVal(4.6516575813)),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(9.77),
ast.NumVal(9.725),
ast.CompOpType.GT),
ast.NumVal(-0.0340889740),
ast.NumVal(0.0543687153)),
ast.NumVal(-0.0348178434),
ast.NumVal(0.0549301624)),
ast.BinNumOpType.ADD),
to_reuse=True)

Expand All @@ -251,17 +251,17 @@ def test_exp_output_transform():
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.918),
ast.NumVal(6.8375),
ast.CompOpType.GT),
ast.NumVal(3.1480683932),
ast.NumVal(3.1101554907)),
ast.NumVal(3.1481886430),
ast.NumVal(3.1123367238)),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(9.63),
ast.NumVal(9.725),
ast.CompOpType.GT),
ast.NumVal(-0.0111969636),
ast.NumVal(0.0160298303)),
ast.NumVal(-0.0113689739),
ast.NumVal(0.0153551274)),
ast.BinNumOpType.ADD))

assert utils.cmp_exprs(actual, expected)
Expand Down Expand Up @@ -289,8 +289,8 @@ def test_bin_class_sigmoid_output_transform():
ast.FeatureRef(23),
ast.NumVal(868.2),
ast.CompOpType.GT),
ast.NumVal(0.5197386243),
ast.NumVal(1.2474356828)),
ast.NumVal(0.5280025563),
ast.NumVal(1.2662661130)),
ast.BinNumOpType.MUL),
ast.BinNumOpType.SUB)),
ast.BinNumOpType.ADD),
Expand Down