Skip to content

Commit

Permalink
Merge 4d31bf4 into 3bcea02
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS committed Jun 5, 2020
2 parents 3bcea02 + 4d31bf4 commit 51e0bdc
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 42 deletions.
103 changes: 103 additions & 0 deletions m2cgen/ast.py
Expand Up @@ -23,6 +23,12 @@ def __str__(self):
args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)])
return "IdExpr(" + args + ")"

def __eq__(self, other):
return type(other) is IdExpr and self.expr == other.expr

def __hash__(self):
return hash(self.expr)


class FeatureRef(Expr):
def __init__(self, index):
Expand All @@ -31,6 +37,12 @@ def __init__(self, index):
def __str__(self):
return "FeatureRef(" + str(self.index) + ")"

def __eq__(self, other):
return type(other) is FeatureRef and self.index == other.index

def __hash__(self):
return hash(self.index)


class BinExpr(Expr):
pass
Expand All @@ -51,6 +63,12 @@ def __init__(self, value, dtype=None):
def __str__(self):
return "NumVal(" + str(self.value) + ")"

def __eq__(self, other):
return type(other) is NumVal and self.value == other.value

def __hash__(self):
return hash(self.value)


class AbsExpr(NumExpr):
def __init__(self, expr, to_reuse=False):
Expand All @@ -63,6 +81,12 @@ def __str__(self):
args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)])
return "AbsExpr(" + args + ")"

def __eq__(self, other):
return type(other) is AbsExpr and self.expr == other.expr

def __hash__(self):
return hash(self.expr)


class ExpExpr(NumExpr):
def __init__(self, expr, to_reuse=False):
Expand All @@ -75,6 +99,12 @@ def __str__(self):
args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)])
return "ExpExpr(" + args + ")"

def __eq__(self, other):
return type(other) is ExpExpr and self.expr == other.expr

def __hash__(self):
return hash(self.expr)


class SqrtExpr(NumExpr):
def __init__(self, expr, to_reuse=False):
Expand All @@ -87,6 +117,12 @@ def __str__(self):
args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)])
return "SqrtExpr(" + args + ")"

def __eq__(self, other):
return type(other) is SqrtExpr and self.expr == other.expr

def __hash__(self):
return hash(self.expr)


class TanhExpr(NumExpr):
def __init__(self, expr, to_reuse=False):
Expand All @@ -99,6 +135,12 @@ def __str__(self):
args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)])
return "TanhExpr(" + args + ")"

def __eq__(self, other):
return type(other) is TanhExpr and self.expr == other.expr

def __hash__(self):
return hash(self.expr)


class PowExpr(NumExpr):
def __init__(self, base_expr, exp_expr, to_reuse=False):
Expand All @@ -117,6 +159,14 @@ def __str__(self):
])
return "PowExpr(" + args + ")"

def __eq__(self, other):
return (type(other) is PowExpr and
self.base_expr == other.base_expr and
self.exp_expr == other.exp_expr)

def __hash__(self):
return hash((self.base_expr, self.exp_expr))


class BinNumOpType(Enum):
ADD = '+'
Expand Down Expand Up @@ -144,6 +194,15 @@ def __str__(self):
])
return "BinNumExpr(" + args + ")"

def __eq__(self, other):
return (type(other) is BinNumExpr and
self.left == other.left and
self.right == other.right and
self.op == other.op)

def __hash__(self):
return hash((self.left, self.right, self.op))


# Vector Expressions.

Expand All @@ -164,6 +223,14 @@ def __str__(self):
args = ",".join([str(e) for e in self.exprs])
return "VectorVal([" + args + "])"

def __eq__(self, other):
return (type(other) is VectorVal and
self.output_size == other.output_size and
all(i == j for i, j in zip(self.exprs, other.exprs)))

def __hash__(self):
return hash(tuple(e for e in self.exprs))


class BinVectorExpr(VectorExpr, BinExpr):

Expand All @@ -181,6 +248,15 @@ def __str__(self):
args = ",".join([str(self.left), str(self.right), self.op.name])
return "BinVectorExpr(" + args + ")"

def __eq__(self, other):
return (type(other) is BinVectorExpr and
self.left == other.left and
self.right == other.right and
self.op == other.op)

def __hash__(self):
return hash((self.left, self.right, self.op))


class BinVectorNumExpr(VectorExpr, BinExpr):

Expand All @@ -197,6 +273,15 @@ def __str__(self):
args = ",".join([str(self.left), str(self.right), self.op.name])
return "BinVectorNumExpr(" + args + ")"

def __eq__(self, other):
return (type(other) is BinVectorNumExpr and
self.left == other.left and
self.right == other.right and
self.op == other.op)

def __hash__(self):
return hash((self.left, self.right, self.op))


# Boolean Expressions.

Expand Down Expand Up @@ -233,6 +318,15 @@ def __str__(self):
args = ",".join([str(self.left), str(self.right), self.op.name])
return "CompExpr(" + args + ")"

def __eq__(self, other):
return (type(other) is CompExpr and
self.left == other.left and
self.right == other.right and
self.op == other.op)

def __hash__(self):
return hash((self.left, self.right, self.op))


# Control Expressions.

Expand All @@ -254,6 +348,15 @@ def __str__(self):
args = ",".join([str(self.test), str(self.body), str(self.orelse)])
return "IfExpr(" + args + ")"

def __eq__(self, other):
return (type(other) is IfExpr and
self.test == other.test and
self.body == other.body and
self.orelse == other.orelse)

def __hash__(self):
return hash((self.test, self.body, self.orelse))


TOTAL_NUMBER_OF_EXPRESSIONS = len(getmembers(modules[__name__], isclass))

Expand Down
6 changes: 3 additions & 3 deletions m2cgen/interpreters/mixins.py
Expand Up @@ -36,10 +36,10 @@ def _pre_interpret_hook(self, expr, bin_depth=0, **kwargs):

# Default implementation. Simply adds new variable.
def bin_depth_threshold_hook(self, expr, **kwargs):
var_name = self._cg.add_var_declaration(expr.output_size)
if expr in self._cached_expr_results:
return self._cached_expr_results[expr].var_name
result = self._do_interpret(expr, **kwargs)
self._cg.add_var_assignment(var_name, result, expr.output_size)
return var_name
return self._cache_reused_expr(expr, result)


class LinearAlgebraMixin(BaseToCodeInterpreter):
Expand Down
20 changes: 10 additions & 10 deletions tests/interpreters/test_dart.py
Expand Up @@ -333,12 +333,12 @@ def test_deep_mixed_exprs_exceeding_threshold():
expr = ast.NumVal(1)
for i in range(4):
inner = ast.NumVal(1)
for i in range(4):
inner = ast.BinNumExpr(ast.NumVal(1), inner, ast.BinNumOpType.ADD)
for j in range(4):
inner = ast.BinNumExpr(ast.NumVal(i), inner, ast.BinNumOpType.ADD)

expr = ast.IfExpr(
ast.CompExpr(
inner, ast.NumVal(1), ast.CompOpType.EQ),
inner, ast.NumVal(j), ast.CompOpType.EQ),
ast.NumVal(1),
expr)

Expand All @@ -348,23 +348,23 @@ def test_deep_mixed_exprs_exceeding_threshold():
double score(List<double> input) {
double var0;
double var1;
var1 = (1) + ((1) + (1));
if (((1) + ((1) + (var1))) == (1)) {
var1 = (3) + ((3) + (1));
if (((3) + ((3) + (var1))) == (3)) {
var0 = 1;
} else {
double var2;
var2 = (1) + ((1) + (1));
if (((1) + ((1) + (var2))) == (1)) {
var2 = (2) + ((2) + (1));
if (((2) + ((2) + (var2))) == (3)) {
var0 = 1;
} else {
double var3;
var3 = (1) + ((1) + (1));
if (((1) + ((1) + (var3))) == (1)) {
if (((1) + ((1) + (var3))) == (3)) {
var0 = 1;
} else {
double var4;
var4 = (1) + ((1) + (1));
if (((1) + ((1) + (var4))) == (1)) {
var4 = (0) + ((0) + (1));
if (((0) + ((0) + (var4))) == (3)) {
var0 = 1;
} else {
var0 = 1;
Expand Down
20 changes: 10 additions & 10 deletions tests/interpreters/test_python.py
Expand Up @@ -281,33 +281,33 @@ def test_deep_mixed_exprs_exceeding_threshold():
expr = ast.NumVal(1)
for i in range(4):
inner = ast.NumVal(1)
for i in range(4):
inner = ast.BinNumExpr(ast.NumVal(1), inner, ast.BinNumOpType.ADD)
for j in range(4):
inner = ast.BinNumExpr(ast.NumVal(i), inner, ast.BinNumOpType.ADD)

expr = ast.IfExpr(
ast.CompExpr(
inner, ast.NumVal(1), ast.CompOpType.EQ),
inner, ast.NumVal(j), ast.CompOpType.EQ),
ast.NumVal(1),
expr)

interpreter = CustomPythonInterpreter()

expected_code = """
def score(input):
var1 = (1) + ((1) + (1))
if ((1) + ((1) + (var1))) == (1):
var1 = (3) + ((3) + (1))
if ((3) + ((3) + (var1))) == (3):
var0 = 1
else:
var2 = (1) + ((1) + (1))
if ((1) + ((1) + (var2))) == (1):
var2 = (2) + ((2) + (1))
if ((2) + ((2) + (var2))) == (3):
var0 = 1
else:
var3 = (1) + ((1) + (1))
if ((1) + ((1) + (var3))) == (1):
if ((1) + ((1) + (var3))) == (3):
var0 = 1
else:
var4 = (1) + ((1) + (1))
if ((1) + ((1) + (var4))) == (1):
var4 = (0) + ((0) + (1))
if ((0) + ((0) + (var4))) == (3):
var0 = 1
else:
var0 = 1
Expand Down
38 changes: 19 additions & 19 deletions tests/interpreters/test_r.py
Expand Up @@ -295,12 +295,12 @@ def test_deep_mixed_exprs_exceeding_threshold():
expr = ast.NumVal(1)
for i in range(4):
inner = ast.NumVal(1)
for i in range(4):
inner = ast.BinNumExpr(ast.NumVal(1), inner, ast.BinNumOpType.ADD)
for j in range(4):
inner = ast.BinNumExpr(ast.NumVal(i), inner, ast.BinNumOpType.ADD)

expr = ast.IfExpr(
ast.CompExpr(
inner, ast.NumVal(1), ast.CompOpType.EQ),
inner, ast.NumVal(j), ast.CompOpType.EQ),
ast.NumVal(1),
expr)

Expand All @@ -312,19 +312,19 @@ def test_deep_mixed_exprs_exceeding_threshold():
expected_code = """
score <- function(input) {
var1 <- subroutine0(input)
if (((1) + (var1)) == (1)) {
if (((3) + (var1)) == (3)) {
var0 <- 1
} else {
var2 <- subroutine1(input)
if (((1) + (var2)) == (1)) {
if (((2) + (var2)) == (3)) {
var0 <- 1
} else {
var3 <- subroutine2(input)
if (((1) + (var3)) == (1)) {
if (((1) + (var3)) == (3)) {
var0 <- 1
} else {
var4 <- subroutine3(input)
if (((1) + (var4)) == (1)) {
if (((0) + (var4)) == (3)) {
var0 <- 1
} else {
var0 <- 1
Expand All @@ -335,24 +335,24 @@ def test_deep_mixed_exprs_exceeding_threshold():
return(var0)
}
subroutine0 <- function(input) {
var1 <- (1) + (1)
var0 <- (1) + (var1)
return((1) + (var0))
var0 <- (3) + (1)
var1 <- (3) + (var0)
return((3) + (var1))
}
subroutine1 <- function(input) {
var1 <- (1) + (1)
var0 <- (1) + (var1)
return((1) + (var0))
var0 <- (2) + (1)
var1 <- (2) + (var0)
return((2) + (var1))
}
subroutine2 <- function(input) {
var1 <- (1) + (1)
var0 <- (1) + (var1)
return((1) + (var0))
var0 <- (1) + (1)
var1 <- (1) + (var0)
return((1) + (var1))
}
subroutine3 <- function(input) {
var1 <- (1) + (1)
var0 <- (1) + (var1)
return((1) + (var0))
var0 <- (0) + (1)
var1 <- (0) + (var0)
return((0) + (var1))
}
"""

Expand Down

0 comments on commit 51e0bdc

Please sign in to comment.