diff --git a/m2cgen/assemblers/boosting.py b/m2cgen/assemblers/boosting.py index 72eabbca..2265335a 100644 --- a/m2cgen/assemblers/boosting.py +++ b/m2cgen/assemblers/boosting.py @@ -2,7 +2,7 @@ import numpy as np from m2cgen import ast -from m2cgen.assemblers import utils +from m2cgen.assemblers import fallback_expressions, utils from m2cgen.assemblers.base import ModelAssembler from m2cgen.assemblers.linear import _linear_to_ast @@ -62,7 +62,7 @@ def _assemble_multi_class_output(self, estimator_params): for i, e in enumerate(splits) ] - proba_exprs = utils.softmax_exprs(exprs) + proba_exprs = fallback_expressions.softmax(exprs) return ast.VectorVal(proba_exprs) def _assemble_bin_class_output(self, estimator_params): @@ -76,7 +76,7 @@ def _assemble_bin_class_output(self, estimator_params): expr = self._assemble_single_output( estimator_params, base_score=base_score) - proba_expr = utils.sigmoid_expr(expr, to_reuse=True) + proba_expr = fallback_expressions.sigmoid(expr, to_reuse=True) return ast.VectorVal([ ast.BinNumExpr(ast.NumVal(1), proba_expr, ast.BinNumOpType.SUB), diff --git a/m2cgen/assemblers/fallback_expressions.py b/m2cgen/assemblers/fallback_expressions.py new file mode 100644 index 00000000..8b8d1b25 --- /dev/null +++ b/m2cgen/assemblers/fallback_expressions.py @@ -0,0 +1,50 @@ +"""This module provides an implementation for a variety of functions +expressed in library's AST. + +These AST-based implementations are used as fallbacks in case +when the target language lacks native support for respective functions +provided in this module. +""" +from m2cgen import ast +from m2cgen.assemblers import utils + + +def tanh(expr): + expr = ast.IdExpr(expr, to_reuse=True) + tanh_expr = utils.sub( + ast.NumVal(1.0), + utils.div( + ast.NumVal(2.0), + utils.add( + ast.ExpExpr( + utils.mul( + ast.NumVal(2.0), + expr)), + ast.NumVal(1.0)))) + return ast.IfExpr( + utils.gt(expr, ast.NumVal(44.0)), # exp(2*x) <= 2^127 + ast.NumVal(1.0), + ast.IfExpr( + utils.lt(expr, ast.NumVal(-44.0)), + ast.NumVal(-1.0), + tanh_expr)) + + +def sigmoid(expr, to_reuse=False): + neg_expr = ast.BinNumExpr(ast.NumVal(0), expr, ast.BinNumOpType.SUB) + exp_expr = ast.ExpExpr(neg_expr) + return ast.BinNumExpr( + ast.NumVal(1), + ast.BinNumExpr(ast.NumVal(1), exp_expr, ast.BinNumOpType.ADD), + ast.BinNumOpType.DIV, + to_reuse=to_reuse) + + +def softmax(exprs): + exp_exprs = [ast.ExpExpr(e, to_reuse=True) for e in exprs] + exp_sum_expr = utils.apply_op_to_expressions( + ast.BinNumOpType.ADD, *exp_exprs, to_reuse=True) + return [ + ast.BinNumExpr(e, exp_sum_expr, ast.BinNumOpType.DIV) + for e in exp_exprs + ] diff --git a/m2cgen/assemblers/utils.py b/m2cgen/assemblers/utils.py index 0be3b348..ebbbdaf2 100644 --- a/m2cgen/assemblers/utils.py +++ b/m2cgen/assemblers/utils.py @@ -19,10 +19,18 @@ def sub(l, r, to_reuse=False): return ast.BinNumExpr(l, r, ast.BinNumOpType.SUB, to_reuse=to_reuse) +def lt(l, r): + return ast.CompExpr(l, r, ast.CompOpType.LT) + + def lte(l, r): return ast.CompExpr(l, r, ast.CompOpType.LTE) +def gt(l, r): + return ast.CompExpr(l, r, ast.CompOpType.GT) + + def eq(l, r): return ast.CompExpr(l, r, ast.CompOpType.EQ) @@ -79,23 +87,3 @@ def to_2d_array(var): else: x, y = 1, np.size(var) return np.reshape(np.asarray(var), (x, y)) - - -def sigmoid_expr(expr, to_reuse=False): - neg_expr = ast.BinNumExpr(ast.NumVal(0), expr, ast.BinNumOpType.SUB) - exp_expr = ast.ExpExpr(neg_expr) - return ast.BinNumExpr( - ast.NumVal(1), - ast.BinNumExpr(ast.NumVal(1), exp_expr, ast.BinNumOpType.ADD), - ast.BinNumOpType.DIV, - to_reuse=to_reuse) - - -def softmax_exprs(exprs): - exp_exprs = [ast.ExpExpr(e, to_reuse=True) for e in exprs] - exp_sum_expr = apply_op_to_expressions(ast.BinNumOpType.ADD, *exp_exprs, - to_reuse=True) - return [ - ast.BinNumExpr(e, exp_sum_expr, ast.BinNumOpType.DIV) - for e in exp_exprs - ] diff --git a/m2cgen/interpreters/dart/tanh.dart b/m2cgen/interpreters/dart/tanh.dart index f6955237..36d2ae0b 100644 --- a/m2cgen/interpreters/dart/tanh.dart +++ b/m2cgen/interpreters/dart/tanh.dart @@ -1,7 +1,28 @@ double tanh(double x) { - if (x > 22.0) + // Implementation is taken from + // https://github.com/golang/go/blob/master/src/math/tanh.go + double z; + z = x.abs(); + if (z > 44.0148459655565271479942397125) { + if (x < 0) { + return -1.0; + } return 1.0; - if (x < -22.0) - return -1.0; - return ((exp(2*x) - 1)/(exp(2*x) + 1)); + } + if (z >= 0.625) { + z = 1 - 2 / (exp(2 * z) + 1); + if (x < 0) { + z = -z; + } + return z; + } + if (x == 0) { + return 0.0; + } + double s; + s = x * x; + z = x + x * s + * ((-0.964399179425052238628 * s + -99.2877231001918586564) * s + -1614.68768441708447952) + / (((s + 112.811678491632931402) * s + 2235.48839060100448583) * s + 4844.06305325125486048); + return z; } diff --git a/m2cgen/interpreters/interpreter.py b/m2cgen/interpreters/interpreter.py index 36f06d6d..4c75ef19 100644 --- a/m2cgen/interpreters/interpreter.py +++ b/m2cgen/interpreters/interpreter.py @@ -1,4 +1,5 @@ from m2cgen import ast +from m2cgen.assemblers import fallback_expressions from m2cgen.interpreters.utils import CachedResult, _get_handler_name @@ -120,28 +121,33 @@ def interpret_vector_val(self, expr, **kwargs): return self._cg.vector_init(nested) def interpret_exp_expr(self, expr, **kwargs): - assert self.exponent_function_name, "Exponent function is not provided" + if self.exponent_function_name is NotImplemented: + raise NotImplementedError("Exponent function is not provided") self.with_math_module = True nested_result = self._do_interpret(expr.expr, **kwargs) return self._cg.function_invocation( self.exponent_function_name, nested_result) def interpret_sqrt_expr(self, expr, **kwargs): - assert self.sqrt_function_name, "Sqrt function is not provided" + if self.sqrt_function_name is NotImplemented: + raise NotImplementedError("Sqrt function is not provided") self.with_math_module = True nested_result = self._do_interpret(expr.expr, **kwargs) return self._cg.function_invocation( self.sqrt_function_name, nested_result) def interpret_tanh_expr(self, expr, **kwargs): - assert self.tanh_function_name, "Tanh function is not provided" self.with_math_module = True + if self.tanh_function_name is NotImplemented: + return self._do_interpret( + fallback_expressions.tanh(expr.expr), **kwargs) nested_result = self._do_interpret(expr.expr, **kwargs) return self._cg.function_invocation( self.tanh_function_name, nested_result) def interpret_pow_expr(self, expr, **kwargs): - assert self.power_function_name, "Power function is not provided" + if self.power_function_name is NotImplemented: + raise NotImplementedError("Power function is not provided") self.with_math_module = True base_result = self._do_interpret(expr.base_expr, **kwargs) exp_result = self._do_interpret(expr.exp_expr, **kwargs) diff --git a/m2cgen/interpreters/visual_basic/tanh.bas b/m2cgen/interpreters/visual_basic/tanh.bas index f9bdbab1..5cd2450d 100644 --- a/m2cgen/interpreters/visual_basic/tanh.bas +++ b/m2cgen/interpreters/visual_basic/tanh.bas @@ -1,11 +1,32 @@ Function Tanh(ByVal number As Double) As Double - If number > 44.0 Then ' exp(2*x) <= 2^127 + ' Implementation is taken from + ' https://github.com/golang/go/blob/master/src/math/tanh.go + Dim z As Double + z = Math.Abs(number) + If z > 44.0148459655565271479942397125 Then + If number < 0 Then + Tanh = -1.0 + Exit Function + End If Tanh = 1.0 Exit Function End If - If number < -44.0 Then - Tanh = -1.0 + If z >= 0.625 Then + z = 1 - 2 / (Math.Exp(2 * z) + 1) + If number < 0 Then + z = -z + End If + Tanh = z Exit Function End If - Tanh = (Math.Exp(2 * number) - 1) / (Math.Exp(2 * number) + 1) + If number = 0 Then + Tanh = 0.0 + Exit Function + End If + Dim s As Double + s = number * number + z = number + number * s _ + * ((-0.964399179425052238628 * s + -99.2877231001918586564) * s + -1614.68768441708447952) _ + / (((s + 112.811678491632931402) * s + 2235.48839060100448583) * s + 4844.06305325125486048) + Tanh = z End Function diff --git a/tests/test_fallback_expressions.py b/tests/test_fallback_expressions.py new file mode 100644 index 00000000..312f19c6 --- /dev/null +++ b/tests/test_fallback_expressions.py @@ -0,0 +1,27 @@ +from m2cgen import ast +from m2cgen.interpreters import PythonInterpreter + +from tests.utils import assert_code_equal + + +def test_tanh_fallback_expr(): + expr = ast.TanhExpr(ast.NumVal(2.0)) + + interpreter = PythonInterpreter() + interpreter.tanh_function_name = NotImplemented + + expected_code = """ +import math +def score(input): + var1 = 2.0 + if (var1) > (44.0): + var0 = 1.0 + else: + if (var1) < (-44.0): + var0 = -1.0 + else: + var0 = (1.0) - ((2.0) / ((math.exp((2.0) * (var1))) + (1.0))) + return var0 +""" + + assert_code_equal(interpreter.interpret(expr), expected_code)