From 9e5af793e8b389d0c55469df938e8ac3b3ccf90d Mon Sep 17 00:00:00 2001 From: StrikerRUS Date: Fri, 8 May 2020 05:37:45 +0300 Subject: [PATCH] added IdExpr --- m2cgen/ast.py | 15 +++++++++++++-- m2cgen/interpreters/interpreter.py | 3 +++ tests/test_ast.py | 17 +++++++++-------- 3 files changed, 25 insertions(+), 10 deletions(-) diff --git a/m2cgen/ast.py b/m2cgen/ast.py index a7eea37c..7cfc9d85 100644 --- a/m2cgen/ast.py +++ b/m2cgen/ast.py @@ -11,6 +11,17 @@ class Expr: to_reuse = False +class IdExpr(Expr): + def __init__(self, expr, to_reuse=False): + self.expr = expr + self.to_reuse = to_reuse + self.output_size = expr.output_size + + def __str__(self): + args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)]) + return "IdExpr(" + args + ")" + + class FeatureRef(Expr): def __init__(self, index): self.index = index @@ -129,7 +140,7 @@ class VectorExpr(Expr): class VectorVal(VectorExpr): def __init__(self, exprs): - assert all(map(lambda e: e.output_size == 1, exprs)), ( + assert all(e.output_size == 1 for e in exprs), ( "All expressions for VectorVal must be scalar") self.exprs = exprs @@ -235,7 +246,7 @@ def __str__(self): (PowExpr, lambda e: [e.base_expr, e.exp_expr]), (VectorVal, lambda e: e.exprs), (IfExpr, lambda e: [e.test, e.body, e.orelse]), - ((ExpExpr, SqrtExpr, TanhExpr), lambda e: [e.expr]), + ((IdExpr, ExpExpr, SqrtExpr, TanhExpr), lambda e: [e.expr]), ] diff --git a/m2cgen/interpreters/interpreter.py b/m2cgen/interpreters/interpreter.py index 4e08f74a..69b7bed7 100644 --- a/m2cgen/interpreters/interpreter.py +++ b/m2cgen/interpreters/interpreter.py @@ -101,6 +101,9 @@ def __init__(self, cg, feature_array_name="input"): self.with_vectors = False self.with_math_module = False + def interpret_id_expr(self, expr, **kwargs): + return self._do_interpret(expr.expr, **kwargs) + def interpret_comp_expr(self, expr, **kwargs): op = self._cg._comp_op_overwrite(expr.op) return self._cg.infix_expression( diff --git a/tests/test_ast.py b/tests/test_ast.py index 32c4b4f4..d8498e12 100644 --- a/tests/test_ast.py +++ b/tests/test_ast.py @@ -54,13 +54,14 @@ def test_count_all_exprs_types(): ast.FeatureRef(0), ast.BinNumOpType.ADD) ]), - ast.VectorVal([ - ast.NumVal(1), - ast.NumVal(2), - ast.NumVal(3), - ast.NumVal(4), - ast.FeatureRef(1) - ]), + ast.IdExpr( + ast.VectorVal([ + ast.NumVal(1), + ast.NumVal(2), + ast.NumVal(3), + ast.NumVal(4), + ast.FeatureRef(1) + ])), ast.BinNumOpType.SUB), ast.IfExpr( ast.CompExpr(ast.NumVal(2), ast.NumVal(0), ast.CompOpType.GT), @@ -69,7 +70,7 @@ def test_count_all_exprs_types(): ), ast.BinNumOpType.MUL) - assert ast.count_exprs(expr) == 27 + assert ast.count_exprs(expr) == 28 def test_num_val():