Skip to content

Commit

Permalink
Do not set with_math_module flag for fallback expressions (#261)
Browse files Browse the repository at this point in the history
* don't set with_math_module flag for fallback exprs

* hotfix test
  • Loading branch information
StrikerRUS committed Jul 7, 2020
1 parent a8957d8 commit a1532d7
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 8 deletions.
2 changes: 1 addition & 1 deletion m2cgen/interpreters/c/code_generator.py
Expand Up @@ -52,7 +52,7 @@ def add_assign_array_statement(self, source_var, target_var, size):
f"{size} * sizeof(double));")

def add_dependency(self, dep):
super().prepend_code_line(f"#include {dep}")
self.prepend_code_line(f"#include {dep}")

def vector_init(self, values):
return f"(double[]){{{', '.join(values)}}}"
Expand Down
2 changes: 1 addition & 1 deletion m2cgen/interpreters/go/code_generator.py
Expand Up @@ -38,7 +38,7 @@ def _get_var_declare_type(self, is_vector):
return self.vector_type if is_vector else self.scalar_type

def add_dependency(self, dep):
super().prepend_code_line(f'import "{dep}"')
self.prepend_code_line(f'import "{dep}"')

def vector_init(self, values):
return f"[]float64{{{', '.join(values)}}}"
10 changes: 5 additions & 5 deletions m2cgen/interpreters/interpreter.py
Expand Up @@ -124,20 +124,20 @@ def interpret_vector_val(self, expr, **kwargs):
return self._cg.vector_init(nested)

def interpret_abs_expr(self, expr, **kwargs):
self.with_math_module = True
if self.abs_function_name is NotImplemented:
return self._do_interpret(
fallback_expressions.abs(expr.expr), **kwargs)
self.with_math_module = True
nested_result = self._do_interpret(expr.expr, **kwargs)
return self._cg.function_invocation(
self.abs_function_name, nested_result)

def interpret_exp_expr(self, expr, **kwargs):
self.with_math_module = True
if self.exponent_function_name is NotImplemented:
return self._do_interpret(
fallback_expressions.exp(expr.expr, to_reuse=expr.to_reuse),
**kwargs)
self.with_math_module = True
nested_result = self._do_interpret(expr.expr, **kwargs)
return self._cg.function_invocation(
self.exponent_function_name, nested_result)
Expand All @@ -151,29 +151,29 @@ def interpret_log_expr(self, expr, **kwargs):
self.logarithm_function_name, nested_result)

def interpret_log1p_expr(self, expr, **kwargs):
self.with_math_module = True
if self.log1p_function_name is NotImplemented:
return self._do_interpret(
fallback_expressions.log1p(expr.expr), **kwargs)
self.with_math_module = True
nested_result = self._do_interpret(expr.expr, **kwargs)
return self._cg.function_invocation(
self.log1p_function_name, nested_result)

def interpret_sqrt_expr(self, expr, **kwargs):
self.with_math_module = True
if self.sqrt_function_name is NotImplemented:
return self._do_interpret(
fallback_expressions.sqrt(expr.expr, to_reuse=expr.to_reuse),
**kwargs)
self.with_math_module = True
nested_result = self._do_interpret(expr.expr, **kwargs)
return self._cg.function_invocation(
self.sqrt_function_name, nested_result)

def interpret_tanh_expr(self, expr, **kwargs):
self.with_math_module = True
if self.tanh_function_name is NotImplemented:
return self._do_interpret(
fallback_expressions.tanh(expr.expr), **kwargs)
self.with_math_module = True
nested_result = self._do_interpret(expr.expr, **kwargs)
return self._cg.function_invocation(
self.tanh_function_name, nested_result)
Expand Down
1 change: 0 additions & 1 deletion tests/test_fallback_expressions.py
Expand Up @@ -11,7 +11,6 @@ def test_abs_fallback_expr():
interpreter.abs_function_name = NotImplemented

expected_code = """
#include <math.h>
double score(double * input) {
double var0;
double var1;
Expand Down

0 comments on commit a1532d7

Please sign in to comment.