-
Notifications
You must be signed in to change notification settings - Fork 236
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add the fallback expression for Tanh function (#199)
- Loading branch information
1 parent
b6f4d54
commit ecf9444
Showing
7 changed files
with
148 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
"""This module provides an implementation for a variety of functions | ||
expressed in library's AST. | ||
These AST-based implementations are used as fallbacks in case | ||
when the target language lacks native support for respective functions | ||
provided in this module. | ||
""" | ||
from m2cgen import ast | ||
from m2cgen.assemblers import utils | ||
|
||
|
||
def tanh(expr): | ||
expr = ast.IdExpr(expr, to_reuse=True) | ||
tanh_expr = utils.sub( | ||
ast.NumVal(1.0), | ||
utils.div( | ||
ast.NumVal(2.0), | ||
utils.add( | ||
ast.ExpExpr( | ||
utils.mul( | ||
ast.NumVal(2.0), | ||
expr)), | ||
ast.NumVal(1.0)))) | ||
return ast.IfExpr( | ||
utils.gt(expr, ast.NumVal(44.0)), # exp(2*x) <= 2^127 | ||
ast.NumVal(1.0), | ||
ast.IfExpr( | ||
utils.lt(expr, ast.NumVal(-44.0)), | ||
ast.NumVal(-1.0), | ||
tanh_expr)) | ||
|
||
|
||
def sigmoid(expr, to_reuse=False): | ||
neg_expr = ast.BinNumExpr(ast.NumVal(0), expr, ast.BinNumOpType.SUB) | ||
exp_expr = ast.ExpExpr(neg_expr) | ||
return ast.BinNumExpr( | ||
ast.NumVal(1), | ||
ast.BinNumExpr(ast.NumVal(1), exp_expr, ast.BinNumOpType.ADD), | ||
ast.BinNumOpType.DIV, | ||
to_reuse=to_reuse) | ||
|
||
|
||
def softmax(exprs): | ||
exp_exprs = [ast.ExpExpr(e, to_reuse=True) for e in exprs] | ||
exp_sum_expr = utils.apply_op_to_expressions( | ||
ast.BinNumOpType.ADD, *exp_exprs, to_reuse=True) | ||
return [ | ||
ast.BinNumExpr(e, exp_sum_expr, ast.BinNumOpType.DIV) | ||
for e in exp_exprs | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,28 @@ | ||
double tanh(double x) { | ||
if (x > 22.0) | ||
// Implementation is taken from | ||
// https://github.com/golang/go/blob/master/src/math/tanh.go | ||
double z; | ||
z = x.abs(); | ||
if (z > 44.0148459655565271479942397125) { | ||
if (x < 0) { | ||
return -1.0; | ||
} | ||
return 1.0; | ||
if (x < -22.0) | ||
return -1.0; | ||
return ((exp(2*x) - 1)/(exp(2*x) + 1)); | ||
} | ||
if (z >= 0.625) { | ||
z = 1 - 2 / (exp(2 * z) + 1); | ||
if (x < 0) { | ||
z = -z; | ||
} | ||
return z; | ||
} | ||
if (x == 0) { | ||
return 0.0; | ||
} | ||
double s; | ||
s = x * x; | ||
z = x + x * s | ||
* ((-0.964399179425052238628 * s + -99.2877231001918586564) * s + -1614.68768441708447952) | ||
/ (((s + 112.811678491632931402) * s + 2235.48839060100448583) * s + 4844.06305325125486048); | ||
return z; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,32 @@ | ||
Function Tanh(ByVal number As Double) As Double | ||
If number > 44.0 Then ' exp(2*x) <= 2^127 | ||
' Implementation is taken from | ||
' https://github.com/golang/go/blob/master/src/math/tanh.go | ||
Dim z As Double | ||
z = Math.Abs(number) | ||
If z > 44.0148459655565271479942397125 Then | ||
If number < 0 Then | ||
Tanh = -1.0 | ||
Exit Function | ||
End If | ||
Tanh = 1.0 | ||
Exit Function | ||
End If | ||
If number < -44.0 Then | ||
Tanh = -1.0 | ||
If z >= 0.625 Then | ||
z = 1 - 2 / (Math.Exp(2 * z) + 1) | ||
If number < 0 Then | ||
z = -z | ||
End If | ||
Tanh = z | ||
Exit Function | ||
End If | ||
Tanh = (Math.Exp(2 * number) - 1) / (Math.Exp(2 * number) + 1) | ||
If number = 0 Then | ||
Tanh = 0.0 | ||
Exit Function | ||
End If | ||
Dim s As Double | ||
s = number * number | ||
z = number + number * s _ | ||
* ((-0.964399179425052238628 * s + -99.2877231001918586564) * s + -1614.68768441708447952) _ | ||
/ (((s + 112.811678491632931402) * s + 2235.48839060100448583) * s + 4844.06305325125486048) | ||
Tanh = z | ||
End Function |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
from m2cgen import ast | ||
from m2cgen.interpreters import PythonInterpreter | ||
|
||
from tests.utils import assert_code_equal | ||
|
||
|
||
def test_tanh_fallback_expr(): | ||
expr = ast.TanhExpr(ast.NumVal(2.0)) | ||
|
||
interpreter = PythonInterpreter() | ||
interpreter.tanh_function_name = NotImplemented | ||
|
||
expected_code = """ | ||
import math | ||
def score(input): | ||
var1 = 2.0 | ||
if (var1) > (44.0): | ||
var0 = 1.0 | ||
else: | ||
if (var1) < (-44.0): | ||
var0 = -1.0 | ||
else: | ||
var0 = (1.0) - ((2.0) / ((math.exp((2.0) * (var1))) + (1.0))) | ||
return var0 | ||
""" | ||
|
||
assert_code_equal(interpreter.interpret(expr), expected_code) |