Skip to content

Commit

Permalink
Add fallback expressions for Sqrt and Exp functions (#220)
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS committed May 22, 2020
1 parent 4ecd8b9 commit ca8bf0e
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 11 deletions.
14 changes: 14 additions & 0 deletions m2cgen/assemblers/fallback_expressions.py
Expand Up @@ -30,6 +30,20 @@ def tanh(expr):
tanh_expr))


def sqrt(expr, to_reuse=False):
return ast.PowExpr(
base_expr=expr,
exp_expr=ast.NumVal(0.5),
to_reuse=to_reuse)


def exp(expr, to_reuse=False):
return ast.PowExpr(
base_expr=ast.NumVal(2.71828182845904523536028747135),
exp_expr=expr,
to_reuse=to_reuse)


def sigmoid(expr, to_reuse=False):
neg_expr = ast.BinNumExpr(ast.NumVal(0), expr, ast.BinNumOpType.SUB)
exp_expr = ast.ExpExpr(neg_expr)
Expand Down
12 changes: 8 additions & 4 deletions m2cgen/interpreters/interpreter.py
Expand Up @@ -121,17 +121,21 @@ def interpret_vector_val(self, expr, **kwargs):
return self._cg.vector_init(nested)

def interpret_exp_expr(self, expr, **kwargs):
if self.exponent_function_name is NotImplemented:
raise NotImplementedError("Exponent function is not provided")
self.with_math_module = True
if self.exponent_function_name is NotImplemented:
return self._do_interpret(
fallback_expressions.exp(expr.expr, to_reuse=expr.to_reuse),
**kwargs)
nested_result = self._do_interpret(expr.expr, **kwargs)
return self._cg.function_invocation(
self.exponent_function_name, nested_result)

def interpret_sqrt_expr(self, expr, **kwargs):
if self.sqrt_function_name is NotImplemented:
raise NotImplementedError("Sqrt function is not provided")
self.with_math_module = True
if self.sqrt_function_name is NotImplemented:
return self._do_interpret(
fallback_expressions.sqrt(expr.expr, to_reuse=expr.to_reuse),
**kwargs)
nested_result = self._do_interpret(expr.expr, **kwargs)
return self._cg.function_invocation(
self.sqrt_function_name, nested_result)
Expand Down
7 changes: 0 additions & 7 deletions m2cgen/interpreters/visual_basic/interpreter.py
Expand Up @@ -68,13 +68,6 @@ def interpret_pow_expr(self, expr, **kwargs):
return self._cg.infix_expression(
left=base_result, right=exp_result, op="^")

def interpret_sqrt_expr(self, expr, **kwargs):
return self.interpret_pow_expr(
ast.PowExpr(base_expr=expr.expr,
exp_expr=ast.NumVal(0.5),
to_reuse=expr.to_reuse),
**kwargs)

def interpret_tanh_expr(self, expr, **kwargs):
self.with_tanh_expr = True
return super(
Expand Down
30 changes: 30 additions & 0 deletions tests/test_fallback_expressions.py
Expand Up @@ -25,3 +25,33 @@ def score(input):
"""

assert_code_equal(interpreter.interpret(expr), expected_code)


def test_sqrt_fallback_expr():
expr = ast.SqrtExpr(ast.NumVal(2.0))

interpreter = PythonInterpreter()
interpreter.sqrt_function_name = NotImplemented

expected_code = """
import math
def score(input):
return math.pow(2.0, 0.5)
"""

assert_code_equal(interpreter.interpret(expr), expected_code)


def test_exp_fallback_expr():
expr = ast.ExpExpr(ast.NumVal(2.0))

interpreter = PythonInterpreter()
interpreter.exponent_function_name = NotImplemented

expected_code = """
import math
def score(input):
return math.pow(2.718281828459045, 2.0)
"""

assert_code_equal(interpreter.interpret(expr), expected_code)

0 comments on commit ca8bf0e

Please sign in to comment.