Skip to content

Commit

Permalink
Merge 9f4bf22 into ca8bf0e
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS committed May 23, 2020
2 parents ca8bf0e + 9f4bf22 commit 1ed7538
Show file tree
Hide file tree
Showing 33 changed files with 282 additions and 4 deletions.
8 changes: 8 additions & 0 deletions m2cgen/assemblers/fallback_expressions.py
Expand Up @@ -9,6 +9,14 @@
from m2cgen.assemblers import utils


def abs(expr):
expr = ast.IdExpr(expr, to_reuse=True)
return ast.IfExpr(
utils.lt(expr, ast.NumVal(0)),
utils.sub(ast.NumVal(0.0), expr),
expr)


def tanh(expr):
expr = ast.IdExpr(expr, to_reuse=True)
tanh_expr = utils.sub(
Expand Down
14 changes: 13 additions & 1 deletion m2cgen/ast.py
Expand Up @@ -52,6 +52,18 @@ def __str__(self):
return "NumVal(" + str(self.value) + ")"


class AbsExpr(NumExpr):
def __init__(self, expr, to_reuse=False):
assert expr.output_size == 1, "Only scalars are supported"

self.expr = expr
self.to_reuse = to_reuse

def __str__(self):
args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)])
return "AbsExpr(" + args + ")"


class ExpExpr(NumExpr):
def __init__(self, expr, to_reuse=False):
assert expr.output_size == 1, "Only scalars are supported"
Expand Down Expand Up @@ -251,7 +263,7 @@ def __str__(self):
(PowExpr, lambda e: [e.base_expr, e.exp_expr]),
(VectorVal, lambda e: e.exprs),
(IfExpr, lambda e: [e.test, e.body, e.orelse]),
((IdExpr, ExpExpr, SqrtExpr, TanhExpr), lambda e: [e.expr]),
((AbsExpr, IdExpr, ExpExpr, SqrtExpr, TanhExpr), lambda e: [e.expr]),
]


Expand Down
1 change: 1 addition & 0 deletions m2cgen/interpreters/c/interpreter.py
Expand Up @@ -17,6 +17,7 @@ class CInterpreter(ImperativeToCodeInterpreter,
ast.BinNumOpType.MUL: "mul_vector_number",
}

abs_function_name = "fabs"
exponent_function_name = "exp"
power_function_name = "pow"
sqrt_function_name = "sqrt"
Expand Down
2 changes: 1 addition & 1 deletion m2cgen/interpreters/c/linear_algebra.c
Expand Up @@ -5,4 +5,4 @@ void add_vectors(double *v1, double *v2, int size, double *result) {
void mul_vector_number(double *v1, double num, int size, double *result) {
for(int i = 0; i < size; ++i)
result[i] = v1[i] * num;
}
}
1 change: 1 addition & 0 deletions m2cgen/interpreters/c_sharp/interpreter.py
Expand Up @@ -18,6 +18,7 @@ class CSharpInterpreter(ImperativeToCodeInterpreter,
ast.BinNumOpType.MUL: "MulVectorNumber",
}

abs_function_name = "Abs"
exponent_function_name = "Exp"
power_function_name = "Pow"
sqrt_function_name = "Sqrt"
Expand Down
4 changes: 4 additions & 0 deletions m2cgen/interpreters/dart/code_generator.py
Expand Up @@ -27,6 +27,10 @@ def function_definition(self, name, args, is_vector_output):
yield
self.add_block_termination()

def method_invocation(self, method_name, obj, args):
return ("(" + str(obj) + ")." + method_name +
"(" + ", ".join(map(str, args)) + ")")

def vector_init(self, values):
return "[" + ", ".join(values) + "]"

Expand Down
7 changes: 7 additions & 0 deletions m2cgen/interpreters/dart/interpreter.py
Expand Up @@ -21,6 +21,7 @@ class DartInterpreter(ImperativeToCodeInterpreter,

bin_depth_threshold = 465

abs_function_name = "abs"
exponent_function_name = "exp"
power_function_name = "pow"
sqrt_function_name = "sqrt"
Expand Down Expand Up @@ -64,6 +65,12 @@ def interpret(self, expr):

return self._cg.finalize_and_get_generated_code()

def interpret_abs_expr(self, expr, **kwargs):
return self._cg.method_invocation(
method_name=self.abs_function_name,
obj=self._do_interpret(expr.expr, **kwargs),
args=[])

def interpret_tanh_expr(self, expr, **kwargs):
self.with_tanh_expr = True
return super(
Expand Down
1 change: 1 addition & 0 deletions m2cgen/interpreters/go/interpreter.py
Expand Up @@ -16,6 +16,7 @@ class GoInterpreter(ImperativeToCodeInterpreter,
ast.BinNumOpType.MUL: "mulVectorNumber",
}

abs_function_name = "math.Abs"
exponent_function_name = "math.Exp"
power_function_name = "math.Pow"
sqrt_function_name = "math.Sqrt"
Expand Down
1 change: 1 addition & 0 deletions m2cgen/interpreters/haskell/interpreter.py
Expand Up @@ -16,6 +16,7 @@ class HaskellInterpreter(ToCodeInterpreter,
ast.BinNumOpType.MUL: "mulVectorNumber",
}

abs_function_name = "abs"
exponent_function_name = "exp"
sqrt_function_name = "sqrt"
tanh_function_name = "tanh"
Expand Down
10 changes: 10 additions & 0 deletions m2cgen/interpreters/interpreter.py
Expand Up @@ -81,6 +81,7 @@ class ToCodeInterpreter(BaseToCodeInterpreter):
about AST.
"""

abs_function_name = NotImplemented
exponent_function_name = NotImplemented
power_function_name = NotImplemented
sqrt_function_name = NotImplemented
Expand Down Expand Up @@ -120,6 +121,15 @@ def interpret_vector_val(self, expr, **kwargs):
nested = [self._do_interpret(expr, **kwargs) for expr in expr.exprs]
return self._cg.vector_init(nested)

def interpret_abs_expr(self, expr, **kwargs):
self.with_math_module = True
if self.abs_function_name is NotImplemented:
return self._do_interpret(
fallback_expressions.abs(expr.expr), **kwargs)
nested_result = self._do_interpret(expr.expr, **kwargs)
return self._cg.function_invocation(
self.abs_function_name, nested_result)

def interpret_exp_expr(self, expr, **kwargs):
self.with_math_module = True
if self.exponent_function_name is NotImplemented:
Expand Down
1 change: 1 addition & 0 deletions m2cgen/interpreters/java/interpreter.py
Expand Up @@ -24,6 +24,7 @@ class JavaInterpreter(ImperativeToCodeInterpreter,
ast.BinNumOpType.MUL: "mulVectorNumber",
}

abs_function_name = "Math.abs"
exponent_function_name = "Math.exp"
power_function_name = "Math.pow"
sqrt_function_name = "Math.sqrt"
Expand Down
1 change: 1 addition & 0 deletions m2cgen/interpreters/javascript/interpreter.py
Expand Up @@ -19,6 +19,7 @@ class JavascriptInterpreter(ImperativeToCodeInterpreter,
ast.BinNumOpType.MUL: "mulVectorNumber",
}

abs_function_name = "Math.abs"
exponent_function_name = "Math.exp"
power_function_name = "Math.pow"
sqrt_function_name = "Math.sqrt"
Expand Down
1 change: 1 addition & 0 deletions m2cgen/interpreters/php/interpreter.py
Expand Up @@ -17,6 +17,7 @@ class PhpInterpreter(ImperativeToCodeInterpreter,
ast.BinNumOpType.MUL: "mulVectorNumber",
}

abs_function_name = "abs"
exponent_function_name = "exp"
power_function_name = "pow"
sqrt_function_name = "sqrt"
Expand Down
6 changes: 6 additions & 0 deletions m2cgen/interpreters/powershell/interpreter.py
Expand Up @@ -18,6 +18,7 @@ class PowershellInterpreter(ImperativeToCodeInterpreter,
ast.BinNumOpType.MUL: "Mul-Vector-Number",
}

abs_function_name = "[math]::Abs"
exponent_function_name = "[math]::Exp"
power_function_name = "[math]::Pow"
sqrt_function_name = "[math]::Sqrt"
Expand Down Expand Up @@ -48,6 +49,11 @@ def interpret(self, expr):

return self._cg.finalize_and_get_generated_code()

def interpret_abs_expr(self, expr, **kwargs):
nested_result = self._do_interpret(expr.expr, **kwargs)
return self._cg.math_function_invocation(
self.abs_function_name, nested_result)

def interpret_exp_expr(self, expr, **kwargs):
nested_result = self._do_interpret(expr.expr, **kwargs)
return self._cg.math_function_invocation(
Expand Down
6 changes: 6 additions & 0 deletions m2cgen/interpreters/python/interpreter.py
Expand Up @@ -21,6 +21,7 @@ class PythonInterpreter(ImperativeToCodeInterpreter,
ast.BinNumOpType.MUL: "mul_vector_number",
}

abs_function_name = "abs"
exponent_function_name = "math.exp"
power_function_name = "math.pow"
sqrt_function_name = "math.sqrt"
Expand Down Expand Up @@ -51,3 +52,8 @@ def interpret(self, expr):
self._cg.add_dependency("math")

return self._cg.finalize_and_get_generated_code()

def interpret_abs_expr(self, expr, **kwargs):
nested_result = self._do_interpret(expr.expr, **kwargs)
return self._cg.function_invocation(
self.abs_function_name, nested_result)
1 change: 1 addition & 0 deletions m2cgen/interpreters/r/interpreter.py
Expand Up @@ -22,6 +22,7 @@ class RInterpreter(ImperativeToCodeInterpreter,
ast_size_check_frequency = 2
ast_size_per_subroutine_threshold = 200

abs_function_name = "abs"
exponent_function_name = "exp"
sqrt_function_name = "sqrt"
tanh_function_name = "tanh"
Expand Down
7 changes: 7 additions & 0 deletions m2cgen/interpreters/ruby/interpreter.py
Expand Up @@ -17,6 +17,7 @@ class RubyInterpreter(ImperativeToCodeInterpreter,
ast.BinNumOpType.MUL: "mul_vector_number",
}

abs_function_name = "abs"
exponent_function_name = "Math.exp"
sqrt_function_name = "Math.sqrt"
tanh_function_name = "Math.tanh"
Expand Down Expand Up @@ -54,6 +55,12 @@ def interpret_bin_num_expr(self, expr, **kwargs):
else:
return super().interpret_bin_num_expr(expr, **kwargs)

def interpret_abs_expr(self, expr, **kwargs):
return self._cg.method_invocation(
method_name=self.abs_function_name,
obj=self._do_interpret(expr.expr, **kwargs),
args=[])

def interpret_pow_expr(self, expr, **kwargs):
base_result = self._do_interpret(expr.base_expr, **kwargs)
exp_result = self._do_interpret(expr.exp_expr, **kwargs)
Expand Down
1 change: 1 addition & 0 deletions m2cgen/interpreters/visual_basic/interpreter.py
Expand Up @@ -17,6 +17,7 @@ class VisualBasicInterpreter(ImperativeToCodeInterpreter,
ast.BinNumOpType.MUL: "MulVectorNumber",
}

abs_function_name = "Math.Abs"
exponent_function_name = "Math.Exp"
tanh_function_name = "Tanh"

Expand Down
14 changes: 14 additions & 0 deletions tests/interpreters/test_c.py
Expand Up @@ -212,6 +212,20 @@ def test_bin_vector_num_expr():
utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_abs_expr():
expr = ast.AbsExpr(ast.NumVal(-1.0))

interpreter = interpreters.CInterpreter()

expected_code = """
#include <math.h>
double score(double * input) {
return fabs(-1.0);
}"""

utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_exp_expr():
expr = ast.ExpExpr(ast.NumVal(1.0))

Expand Down
18 changes: 18 additions & 0 deletions tests/interpreters/test_c_sharp.py
Expand Up @@ -293,6 +293,24 @@ def test_bin_vector_num_expr():
utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_abs_expr():
expr = ast.AbsExpr(ast.NumVal(-1.0))

expected_code = """
using static System.Math;
namespace ML {
public static class Model {
public static double Score(double[] input) {
return Abs(-1.0);
}
}
}
"""

interpreter = CSharpInterpreter()
utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_exp_expr():
expr = ast.ExpExpr(ast.NumVal(1.0))

Expand Down
13 changes: 13 additions & 0 deletions tests/interpreters/test_dart.py
Expand Up @@ -379,6 +379,19 @@ def test_deep_mixed_exprs_exceeding_threshold():
utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_abs_expr():
expr = ast.AbsExpr(ast.NumVal(-1.0))

expected_code = """
double score(List<double> input) {
return (-1.0).abs();
}
"""

interpreter = DartInterpreter()
utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_exp_expr():
expr = ast.ExpExpr(ast.NumVal(1.0))

Expand Down
14 changes: 14 additions & 0 deletions tests/interpreters/test_go.py
Expand Up @@ -215,6 +215,20 @@ def test_bin_vector_num_expr():
utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_abs_expr():
expr = ast.AbsExpr(ast.NumVal(-1.0))

interpreter = interpreters.GoInterpreter()

expected_code = """
import "math"
func score(input []float64) float64 {
return math.Abs(-1.0)
}"""

utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_exp_expr():
expr = ast.ExpExpr(ast.NumVal(1.0))

Expand Down
14 changes: 14 additions & 0 deletions tests/interpreters/test_haskell.py
Expand Up @@ -261,6 +261,20 @@ def test_bin_vector_num_expr():
utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_abs_expr():
expr = ast.AbsExpr(ast.NumVal(-1.0))

expected_code = """
module Model where
score :: [Double] -> Double
score input =
abs (-1.0)
"""

interpreter = HaskellInterpreter()
utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_exp_expr():
expr = ast.ExpExpr(ast.NumVal(1.0))

Expand Down
15 changes: 15 additions & 0 deletions tests/interpreters/test_java.py
Expand Up @@ -283,6 +283,21 @@ def test_bin_vector_num_expr():
utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_abs_expr():
expr = ast.AbsExpr(ast.NumVal(-1.0))

interpreter = interpreters.JavaInterpreter()

expected_code = """
public class Model {
public static double score(double[] input) {
return Math.abs(-1.0);
}
}"""

utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_exp_expr():
expr = ast.ExpExpr(ast.NumVal(1.0))

Expand Down
14 changes: 14 additions & 0 deletions tests/interpreters/test_javascript.py
Expand Up @@ -230,6 +230,20 @@ def test_bin_vector_num_expr():
utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_abs_expr():
expr = ast.AbsExpr(ast.NumVal(-1.0))

interpreter = interpreters.JavascriptInterpreter()

expected_code = """
function score(input) {
return Math.abs(-1.0);
}
"""

utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_exp_expr():
expr = ast.ExpExpr(ast.NumVal(1.0))

Expand Down

0 comments on commit 1ed7538

Please sign in to comment.