Skip to content

Commit

Permalink
handle floating point values more accurate
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS committed Jul 26, 2020
1 parent 591ce95 commit cfb0679
Show file tree
Hide file tree
Showing 26 changed files with 285 additions and 274 deletions.
22 changes: 14 additions & 8 deletions .github/workflows/main.yml
Expand Up @@ -3,7 +3,7 @@ name: GitHub Actions
on:
push:
branches:
- master
- floats_improvement
pull_request:
branches:
- master
Expand All @@ -17,8 +17,6 @@ jobs:
matrix:
python:
- 3.6
- 3.7
- 3.8
steps:
- name: Checkout repository
uses: actions/checkout@v1
Expand Down Expand Up @@ -46,12 +44,20 @@ jobs:
matrix:
python:
- 3.6
- 3.7
- 3.8
lang:
- "c_lang or python or java or go_lang or javascript or php or haskell or ruby"
- "c_sharp or visual_basic or f_sharp"
- "r_lang or dart"
- "c_lang"
- "python"
- "java"
- "go_lang"
- "javascript"
- "php"
- "haskell"
- "ruby"
- "c_sharp"
- "visual_basic"
- "f_sharp"
- "r_lang"
- "dart"
- "powershell"
steps:
- name: Checkout repository
Expand Down
4 changes: 4 additions & 0 deletions README.md
Expand Up @@ -132,3 +132,7 @@ A: If this error occurs while generating code using an ensemble model, try to re
**Q: Generation fails with `ImportError: No module named <module_name_here>` error while transpiling model from a serialized model object.**

A: This error indicates that pickle protocol cannot deserialize model object. For unpickling serialized model objects, it is required that their classes must be defined in the top level of an importable module in the unpickling environment. So installation of package which provided model's class definition should solve the problem.

**Q: Generated by m2cgen code provides different results for some inputs compared to original Python model from which the code were obtained.**

A: Some models force input data to be particular type during prediction phase in their native Python libraries. Currently, m2cgen works only with ``float64`` (``double``) data type. You can try to cast your input data to another type manually and check results again. Also, some small differences can happen due to specific implementation of floating-point arithmetic in a target language.
2 changes: 1 addition & 1 deletion m2cgen/assemblers/boosting.py
Expand Up @@ -151,7 +151,7 @@ def __init__(self, model):

def _assemble_tree(self, tree):
if "leaf" in tree:
return ast.NumVal(tree["leaf"])
return ast.NumVal(tree["leaf"], dtype=np.float32)

threshold = ast.NumVal(tree["split_condition"], dtype=np.float32)
split = tree["split"]
Expand Down
26 changes: 9 additions & 17 deletions m2cgen/assemblers/linear.py
@@ -1,7 +1,7 @@
import numpy as np

from m2cgen import ast
from m2cgen.assemblers import utils
from m2cgen.assemblers import fallback_expressions, utils
from m2cgen.assemblers.base import ModelAssembler


Expand Down Expand Up @@ -95,14 +95,7 @@ def _get_supported_inversed_funs(self):
raise NotImplementedError

def _logit_inversed(self, ast_to_transform):
return utils.div(
ast.NumVal(1.0),
utils.add(
ast.NumVal(1.0),
ast.ExpExpr(
utils.sub(
ast.NumVal(0.0),
ast_to_transform))))
return fallback_expressions.sigmoid(ast_to_transform)

def _power_inversed(self, ast_to_transform):
power = self._get_power()
Expand Down Expand Up @@ -146,16 +139,15 @@ def _cloglog_inversed(self, ast_to_transform):

def _negativebinomial_inversed(self, ast_to_transform):
alpha = self._get_alpha()
res = utils.sub(
ast.NumVal(1.0),
ast.ExpExpr(
utils.sub(
ast.NumVal(0.0),
ast_to_transform)))
return utils.div(
ast.NumVal(-1.0),
utils.mul(
ast.NumVal(alpha),
utils.sub(
ast.NumVal(1.0),
ast.ExpExpr(
utils.sub(
ast.NumVal(0.0),
ast_to_transform)))))
utils.mul(ast.NumVal(alpha), res) if alpha != 1.0 else res)

def _get_power(self):
raise NotImplementedError
Expand Down
10 changes: 1 addition & 9 deletions m2cgen/assemblers/tree.py
@@ -1,5 +1,3 @@
import numpy as np

from m2cgen import ast
from m2cgen.assemblers import utils
from m2cgen.assemblers.base import ModelAssembler
Expand Down Expand Up @@ -49,11 +47,5 @@ def _assemble_leaf(self, node_id):

def _assemble_cond(self, node_id):
feature_idx = self._tree.feature[node_id]
threshold = self._tree.threshold[node_id]

# sklearn's trees internally work with float32 numbers, so in order
# to have consistent results across all supported languages, we convert
# all thresholds into float32.
threshold_num_val = ast.NumVal(threshold, dtype=np.float32)

threshold_num_val = ast.NumVal(self._tree.threshold[node_id])
return utils.lte(ast.FeatureRef(feature_idx), threshold_num_val)
16 changes: 11 additions & 5 deletions m2cgen/interpreters/code_generator.py
@@ -1,6 +1,10 @@
from io import StringIO
from weakref import finalize

import numpy as np

from m2cgen.interpreters.utils import format_float


class CodeTemplate:

Expand All @@ -11,12 +15,14 @@ def __str__(self):
return self.str_template

def __call__(self, *args, **kwargs):
# Force calling str() representation
# because without it numpy gives the same output
# for different float types

def _is_float(value):
return isinstance(value, (float, np.floating))

return self.str_template.format(
*[str(i) for i in args],
**{k: str(v) for k, v in kwargs.items()})
*[format_float(i) if _is_float(i) else i for i in args],
**{k: format_float(v) if _is_float(v) else v
for k, v in kwargs.items()})


class BaseCodeGenerator:
Expand Down
6 changes: 6 additions & 0 deletions m2cgen/interpreters/utils.py
@@ -1,5 +1,7 @@
import re

import numpy as np

from collections import namedtuple
from functools import lru_cache
from math import ceil, log
Expand All @@ -22,3 +24,7 @@ def _get_handler_name(expr_tpe):

def _normalize_expr_name(name):
return re.sub("(?!^)([A-Z]+)", r"_\1", name).lower()


def format_float(value):
return np.format_float_positional(value, unique=True, trim="0")
86 changes: 43 additions & 43 deletions tests/assemblers/test_lightgbm.py
Expand Up @@ -26,15 +26,15 @@ def test_binary_classification():
ast.FeatureRef(23),
ast.NumVal(868.2000000000002),
ast.CompOpType.GT),
ast.NumVal(0.25986931215073095),
ast.NumVal(0.6237178414050242)),
ast.NumVal(0.26400127816506497),
ast.NumVal(0.633133056485969)),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(7),
ast.NumVal(0.05142),
ast.FeatureRef(22),
ast.NumVal(105.95000000000002),
ast.CompOpType.GT),
ast.NumVal(-0.1909605544006228),
ast.NumVal(0.1293965108676673)),
ast.NumVal(-0.18744882409486507),
ast.NumVal(0.13458899352064668)),
ast.BinNumOpType.ADD),
ast.BinNumOpType.SUB)),
ast.BinNumOpType.ADD),
Expand Down Expand Up @@ -85,17 +85,17 @@ def test_regression():
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.918),
ast.NumVal(6.837500000000001),
ast.CompOpType.GT),
ast.NumVal(24.011454621684155),
ast.NumVal(22.289277544391084)),
ast.NumVal(23.961356387224317),
ast.NumVal(22.32858336612959)),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(9.63),
ast.NumVal(9.725000000000003),
ast.CompOpType.GT),
ast.NumVal(-0.49461212269771115),
ast.NumVal(0.7174324413014594)),
ast.NumVal(-0.5031712645462916),
ast.NumVal(0.6885501354513913)),
ast.BinNumOpType.ADD)

assert utils.cmp_exprs(actual, expected)
Expand All @@ -114,18 +114,18 @@ def test_regression_random_forest():
ast.BinNumExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.954000000000001),
ast.FeatureRef(12),
ast.NumVal(5.200000000000001),
ast.CompOpType.GT),
ast.NumVal(37.24347877367631),
ast.NumVal(19.936999995530854)),
ast.NumVal(20.206688945020474),
ast.NumVal(38.30000037757679)),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.971500000000001),
ast.NumVal(6.837500000000001),
ast.CompOpType.GT),
ast.NumVal(38.48600037864964),
ast.NumVal(20.183783757300255)),
ast.NumVal(36.40634951405711),
ast.NumVal(19.57067132709245)),
ast.BinNumOpType.ADD),
ast.NumVal(0.5),
ast.BinNumOpType.MUL)
Expand Down Expand Up @@ -154,15 +154,15 @@ def test_simple_sigmoid_output_transform():
ast.FeatureRef(12),
ast.NumVal(19.23),
ast.CompOpType.GT),
ast.NumVal(4.0026305187),
ast.NumVal(4.0880438137)),
ast.NumVal(4.0050691250),
ast.NumVal(4.0914737728)),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(14.895),
ast.NumVal(15.065),
ast.CompOpType.GT),
ast.NumVal(-0.0412703078),
ast.NumVal(0.0208393767)),
ast.NumVal(-0.0420531079),
ast.NumVal(0.0202891577)),
ast.BinNumOpType.ADD),
ast.BinNumOpType.SUB)),
ast.BinNumOpType.ADD),
Expand All @@ -188,15 +188,15 @@ def test_log1p_exp_output_transform():
ast.FeatureRef(12),
ast.NumVal(19.23),
ast.CompOpType.GT),
ast.NumVal(0.6623502468),
ast.NumVal(0.6683497987)),
ast.NumVal(0.6623996001),
ast.NumVal(0.6684477608)),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(15.145),
ast.NumVal(15.065),
ast.CompOpType.GT),
ast.NumVal(0.1405181490),
ast.NumVal(0.1453602134)),
ast.NumVal(0.1405782705),
ast.NumVal(0.1453764991)),
ast.BinNumOpType.ADD)))

assert utils.cmp_exprs(actual, expected)
Expand All @@ -216,17 +216,17 @@ def test_maybe_sqr_output_transform():
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(9.905),
ast.NumVal(11.655),
ast.CompOpType.GT),
ast.NumVal(4.5658116817),
ast.NumVal(4.6620790482)),
ast.NumVal(4.5671830654),
ast.NumVal(4.6516575813)),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(9.77),
ast.NumVal(9.725),
ast.CompOpType.GT),
ast.NumVal(-0.0340889740),
ast.NumVal(0.0543687153)),
ast.NumVal(-0.0348178434),
ast.NumVal(0.0549301624)),
ast.BinNumOpType.ADD),
to_reuse=True)

Expand All @@ -251,17 +251,17 @@ def test_exp_output_transform():
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.918),
ast.NumVal(6.8375),
ast.CompOpType.GT),
ast.NumVal(3.1480683932),
ast.NumVal(3.1101554907)),
ast.NumVal(3.1481886430),
ast.NumVal(3.1123367238)),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(9.63),
ast.NumVal(9.725),
ast.CompOpType.GT),
ast.NumVal(-0.0111969636),
ast.NumVal(0.0160298303)),
ast.NumVal(-0.0113689739),
ast.NumVal(0.0153551274)),
ast.BinNumOpType.ADD))

assert utils.cmp_exprs(actual, expected)
Expand Down Expand Up @@ -289,8 +289,8 @@ def test_bin_class_sigmoid_output_transform():
ast.FeatureRef(23),
ast.NumVal(868.2),
ast.CompOpType.GT),
ast.NumVal(0.5197386243),
ast.NumVal(1.2474356828)),
ast.NumVal(0.5280025563),
ast.NumVal(1.2662661130)),
ast.BinNumOpType.MUL),
ast.BinNumOpType.SUB)),
ast.BinNumOpType.ADD),
Expand Down

0 comments on commit cfb0679

Please sign in to comment.