Skip to content

Commit

Permalink
Add the fallback expression for Tanh function (#199)
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS committed May 18, 2020
1 parent b6f4d54 commit ecf9444
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 35 deletions.
6 changes: 3 additions & 3 deletions m2cgen/assemblers/boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np

from m2cgen import ast
from m2cgen.assemblers import utils
from m2cgen.assemblers import fallback_expressions, utils
from m2cgen.assemblers.base import ModelAssembler
from m2cgen.assemblers.linear import _linear_to_ast

Expand Down Expand Up @@ -62,7 +62,7 @@ def _assemble_multi_class_output(self, estimator_params):
for i, e in enumerate(splits)
]

proba_exprs = utils.softmax_exprs(exprs)
proba_exprs = fallback_expressions.softmax(exprs)
return ast.VectorVal(proba_exprs)

def _assemble_bin_class_output(self, estimator_params):
Expand All @@ -76,7 +76,7 @@ def _assemble_bin_class_output(self, estimator_params):
expr = self._assemble_single_output(
estimator_params, base_score=base_score)

proba_expr = utils.sigmoid_expr(expr, to_reuse=True)
proba_expr = fallback_expressions.sigmoid(expr, to_reuse=True)

return ast.VectorVal([
ast.BinNumExpr(ast.NumVal(1), proba_expr, ast.BinNumOpType.SUB),
Expand Down
50 changes: 50 additions & 0 deletions m2cgen/assemblers/fallback_expressions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""This module provides an implementation for a variety of functions
expressed in library's AST.
These AST-based implementations are used as fallbacks in case
when the target language lacks native support for respective functions
provided in this module.
"""
from m2cgen import ast
from m2cgen.assemblers import utils


def tanh(expr):
expr = ast.IdExpr(expr, to_reuse=True)
tanh_expr = utils.sub(
ast.NumVal(1.0),
utils.div(
ast.NumVal(2.0),
utils.add(
ast.ExpExpr(
utils.mul(
ast.NumVal(2.0),
expr)),
ast.NumVal(1.0))))
return ast.IfExpr(
utils.gt(expr, ast.NumVal(44.0)), # exp(2*x) <= 2^127
ast.NumVal(1.0),
ast.IfExpr(
utils.lt(expr, ast.NumVal(-44.0)),
ast.NumVal(-1.0),
tanh_expr))


def sigmoid(expr, to_reuse=False):
neg_expr = ast.BinNumExpr(ast.NumVal(0), expr, ast.BinNumOpType.SUB)
exp_expr = ast.ExpExpr(neg_expr)
return ast.BinNumExpr(
ast.NumVal(1),
ast.BinNumExpr(ast.NumVal(1), exp_expr, ast.BinNumOpType.ADD),
ast.BinNumOpType.DIV,
to_reuse=to_reuse)


def softmax(exprs):
exp_exprs = [ast.ExpExpr(e, to_reuse=True) for e in exprs]
exp_sum_expr = utils.apply_op_to_expressions(
ast.BinNumOpType.ADD, *exp_exprs, to_reuse=True)
return [
ast.BinNumExpr(e, exp_sum_expr, ast.BinNumOpType.DIV)
for e in exp_exprs
]
28 changes: 8 additions & 20 deletions m2cgen/assemblers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,18 @@ def sub(l, r, to_reuse=False):
return ast.BinNumExpr(l, r, ast.BinNumOpType.SUB, to_reuse=to_reuse)


def lt(l, r):
return ast.CompExpr(l, r, ast.CompOpType.LT)


def lte(l, r):
return ast.CompExpr(l, r, ast.CompOpType.LTE)


def gt(l, r):
return ast.CompExpr(l, r, ast.CompOpType.GT)


def eq(l, r):
return ast.CompExpr(l, r, ast.CompOpType.EQ)

Expand Down Expand Up @@ -79,23 +87,3 @@ def to_2d_array(var):
else:
x, y = 1, np.size(var)
return np.reshape(np.asarray(var), (x, y))


def sigmoid_expr(expr, to_reuse=False):
neg_expr = ast.BinNumExpr(ast.NumVal(0), expr, ast.BinNumOpType.SUB)
exp_expr = ast.ExpExpr(neg_expr)
return ast.BinNumExpr(
ast.NumVal(1),
ast.BinNumExpr(ast.NumVal(1), exp_expr, ast.BinNumOpType.ADD),
ast.BinNumOpType.DIV,
to_reuse=to_reuse)


def softmax_exprs(exprs):
exp_exprs = [ast.ExpExpr(e, to_reuse=True) for e in exprs]
exp_sum_expr = apply_op_to_expressions(ast.BinNumOpType.ADD, *exp_exprs,
to_reuse=True)
return [
ast.BinNumExpr(e, exp_sum_expr, ast.BinNumOpType.DIV)
for e in exp_exprs
]
29 changes: 25 additions & 4 deletions m2cgen/interpreters/dart/tanh.dart
Original file line number Diff line number Diff line change
@@ -1,7 +1,28 @@
double tanh(double x) {
if (x > 22.0)
// Implementation is taken from
// https://github.com/golang/go/blob/master/src/math/tanh.go
double z;
z = x.abs();
if (z > 44.0148459655565271479942397125) {
if (x < 0) {
return -1.0;
}
return 1.0;
if (x < -22.0)
return -1.0;
return ((exp(2*x) - 1)/(exp(2*x) + 1));
}
if (z >= 0.625) {
z = 1 - 2 / (exp(2 * z) + 1);
if (x < 0) {
z = -z;
}
return z;
}
if (x == 0) {
return 0.0;
}
double s;
s = x * x;
z = x + x * s
* ((-0.964399179425052238628 * s + -99.2877231001918586564) * s + -1614.68768441708447952)
/ (((s + 112.811678491632931402) * s + 2235.48839060100448583) * s + 4844.06305325125486048);
return z;
}
14 changes: 10 additions & 4 deletions m2cgen/interpreters/interpreter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from m2cgen import ast
from m2cgen.assemblers import fallback_expressions
from m2cgen.interpreters.utils import CachedResult, _get_handler_name


Expand Down Expand Up @@ -120,28 +121,33 @@ def interpret_vector_val(self, expr, **kwargs):
return self._cg.vector_init(nested)

def interpret_exp_expr(self, expr, **kwargs):
assert self.exponent_function_name, "Exponent function is not provided"
if self.exponent_function_name is NotImplemented:
raise NotImplementedError("Exponent function is not provided")
self.with_math_module = True
nested_result = self._do_interpret(expr.expr, **kwargs)
return self._cg.function_invocation(
self.exponent_function_name, nested_result)

def interpret_sqrt_expr(self, expr, **kwargs):
assert self.sqrt_function_name, "Sqrt function is not provided"
if self.sqrt_function_name is NotImplemented:
raise NotImplementedError("Sqrt function is not provided")
self.with_math_module = True
nested_result = self._do_interpret(expr.expr, **kwargs)
return self._cg.function_invocation(
self.sqrt_function_name, nested_result)

def interpret_tanh_expr(self, expr, **kwargs):
assert self.tanh_function_name, "Tanh function is not provided"
self.with_math_module = True
if self.tanh_function_name is NotImplemented:
return self._do_interpret(
fallback_expressions.tanh(expr.expr), **kwargs)
nested_result = self._do_interpret(expr.expr, **kwargs)
return self._cg.function_invocation(
self.tanh_function_name, nested_result)

def interpret_pow_expr(self, expr, **kwargs):
assert self.power_function_name, "Power function is not provided"
if self.power_function_name is NotImplemented:
raise NotImplementedError("Power function is not provided")
self.with_math_module = True
base_result = self._do_interpret(expr.base_expr, **kwargs)
exp_result = self._do_interpret(expr.exp_expr, **kwargs)
Expand Down
29 changes: 25 additions & 4 deletions m2cgen/interpreters/visual_basic/tanh.bas
Original file line number Diff line number Diff line change
@@ -1,11 +1,32 @@
Function Tanh(ByVal number As Double) As Double
If number > 44.0 Then ' exp(2*x) <= 2^127
' Implementation is taken from
' https://github.com/golang/go/blob/master/src/math/tanh.go
Dim z As Double
z = Math.Abs(number)
If z > 44.0148459655565271479942397125 Then
If number < 0 Then
Tanh = -1.0
Exit Function
End If
Tanh = 1.0
Exit Function
End If
If number < -44.0 Then
Tanh = -1.0
If z >= 0.625 Then
z = 1 - 2 / (Math.Exp(2 * z) + 1)
If number < 0 Then
z = -z
End If
Tanh = z
Exit Function
End If
Tanh = (Math.Exp(2 * number) - 1) / (Math.Exp(2 * number) + 1)
If number = 0 Then
Tanh = 0.0
Exit Function
End If
Dim s As Double
s = number * number
z = number + number * s _
* ((-0.964399179425052238628 * s + -99.2877231001918586564) * s + -1614.68768441708447952) _
/ (((s + 112.811678491632931402) * s + 2235.48839060100448583) * s + 4844.06305325125486048)
Tanh = z
End Function
27 changes: 27 additions & 0 deletions tests/test_fallback_expressions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from m2cgen import ast
from m2cgen.interpreters import PythonInterpreter

from tests.utils import assert_code_equal


def test_tanh_fallback_expr():
expr = ast.TanhExpr(ast.NumVal(2.0))

interpreter = PythonInterpreter()
interpreter.tanh_function_name = NotImplemented

expected_code = """
import math
def score(input):
var1 = 2.0
if (var1) > (44.0):
var0 = 1.0
else:
if (var1) < (-44.0):
var0 = -1.0
else:
var0 = (1.0) - ((2.0) / ((math.exp((2.0) * (var1))) + (1.0)))
return var0
"""

assert_code_equal(interpreter.interpret(expr), expected_code)

0 comments on commit ecf9444

Please sign in to comment.