Skip to content

Commit

Permalink
Add IdExpr (#208)
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS committed May 13, 2020
1 parent e4eccc5 commit 3db1cc2
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 10 deletions.
15 changes: 13 additions & 2 deletions m2cgen/ast.py
Expand Up @@ -13,6 +13,17 @@ class Expr:
to_reuse = False


class IdExpr(Expr):
def __init__(self, expr, to_reuse=False):
self.expr = expr
self.to_reuse = to_reuse
self.output_size = expr.output_size

def __str__(self):
args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)])
return "IdExpr(" + args + ")"


class FeatureRef(Expr):
def __init__(self, index):
self.index = index
Expand Down Expand Up @@ -131,7 +142,7 @@ class VectorExpr(Expr):
class VectorVal(VectorExpr):

def __init__(self, exprs):
assert all(map(lambda e: e.output_size == 1, exprs)), (
assert all(e.output_size == 1 for e in exprs), (
"All expressions for VectorVal must be scalar")

self.exprs = exprs
Expand Down Expand Up @@ -240,7 +251,7 @@ def __str__(self):
(PowExpr, lambda e: [e.base_expr, e.exp_expr]),
(VectorVal, lambda e: e.exprs),
(IfExpr, lambda e: [e.test, e.body, e.orelse]),
((ExpExpr, SqrtExpr, TanhExpr), lambda e: [e.expr]),
((IdExpr, ExpExpr, SqrtExpr, TanhExpr), lambda e: [e.expr]),
]


Expand Down
3 changes: 3 additions & 0 deletions m2cgen/interpreters/interpreter.py
Expand Up @@ -90,6 +90,9 @@ def __init__(self, cg, feature_array_name="input"):
self.with_vectors = False
self.with_math_module = False

def interpret_id_expr(self, expr, **kwargs):
return self._do_interpret(expr.expr, **kwargs)

def interpret_comp_expr(self, expr, **kwargs):
op = self._cg._comp_op_overwrite(expr.op)
return self._cg.infix_expression(
Expand Down
17 changes: 9 additions & 8 deletions tests/test_ast.py
Expand Up @@ -54,13 +54,14 @@ def test_count_all_exprs_types():
ast.FeatureRef(0),
ast.BinNumOpType.ADD)
]),
ast.VectorVal([
ast.NumVal(1),
ast.NumVal(2),
ast.NumVal(3),
ast.NumVal(4),
ast.FeatureRef(1)
]),
ast.IdExpr(
ast.VectorVal([
ast.NumVal(1),
ast.NumVal(2),
ast.NumVal(3),
ast.NumVal(4),
ast.FeatureRef(1)
])),
ast.BinNumOpType.SUB),
ast.IfExpr(
ast.CompExpr(ast.NumVal(2), ast.NumVal(0), ast.CompOpType.GT),
Expand All @@ -69,7 +70,7 @@ def test_count_all_exprs_types():
),
ast.BinNumOpType.MUL)

assert ast.count_exprs(expr) == 27
assert ast.count_exprs(expr) == 28


def test_num_val():
Expand Down

0 comments on commit 3db1cc2

Please sign in to comment.