From 4d31bf4ba95b6651d927f2b009689e5d735a26c5 Mon Sep 17 00:00:00 2001 From: StrikerRUS Date: Thu, 4 Jun 2020 04:17:44 +0300 Subject: [PATCH] improve caching mechanism --- m2cgen/ast.py | 103 ++++++++++++++++++++++++++++++ m2cgen/interpreters/mixins.py | 6 +- tests/interpreters/test_dart.py | 20 +++--- tests/interpreters/test_python.py | 20 +++--- tests/interpreters/test_r.py | 38 +++++------ 5 files changed, 145 insertions(+), 42 deletions(-) diff --git a/m2cgen/ast.py b/m2cgen/ast.py index 452706ab..ed10bee6 100644 --- a/m2cgen/ast.py +++ b/m2cgen/ast.py @@ -23,6 +23,12 @@ def __str__(self): args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)]) return "IdExpr(" + args + ")" + def __eq__(self, other): + return type(other) is IdExpr and self.expr == other.expr + + def __hash__(self): + return hash(self.expr) + class FeatureRef(Expr): def __init__(self, index): @@ -31,6 +37,12 @@ def __init__(self, index): def __str__(self): return "FeatureRef(" + str(self.index) + ")" + def __eq__(self, other): + return type(other) is FeatureRef and self.index == other.index + + def __hash__(self): + return hash(self.index) + class BinExpr(Expr): pass @@ -51,6 +63,12 @@ def __init__(self, value, dtype=None): def __str__(self): return "NumVal(" + str(self.value) + ")" + def __eq__(self, other): + return type(other) is NumVal and self.value == other.value + + def __hash__(self): + return hash(self.value) + class AbsExpr(NumExpr): def __init__(self, expr, to_reuse=False): @@ -63,6 +81,12 @@ def __str__(self): args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)]) return "AbsExpr(" + args + ")" + def __eq__(self, other): + return type(other) is AbsExpr and self.expr == other.expr + + def __hash__(self): + return hash(self.expr) + class ExpExpr(NumExpr): def __init__(self, expr, to_reuse=False): @@ -75,6 +99,12 @@ def __str__(self): args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)]) return "ExpExpr(" + args + ")" + def __eq__(self, other): + return type(other) is ExpExpr and self.expr == other.expr + + def __hash__(self): + return hash(self.expr) + class SqrtExpr(NumExpr): def __init__(self, expr, to_reuse=False): @@ -87,6 +117,12 @@ def __str__(self): args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)]) return "SqrtExpr(" + args + ")" + def __eq__(self, other): + return type(other) is SqrtExpr and self.expr == other.expr + + def __hash__(self): + return hash(self.expr) + class TanhExpr(NumExpr): def __init__(self, expr, to_reuse=False): @@ -99,6 +135,12 @@ def __str__(self): args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)]) return "TanhExpr(" + args + ")" + def __eq__(self, other): + return type(other) is TanhExpr and self.expr == other.expr + + def __hash__(self): + return hash(self.expr) + class PowExpr(NumExpr): def __init__(self, base_expr, exp_expr, to_reuse=False): @@ -117,6 +159,14 @@ def __str__(self): ]) return "PowExpr(" + args + ")" + def __eq__(self, other): + return (type(other) is PowExpr and + self.base_expr == other.base_expr and + self.exp_expr == other.exp_expr) + + def __hash__(self): + return hash((self.base_expr, self.exp_expr)) + class BinNumOpType(Enum): ADD = '+' @@ -144,6 +194,15 @@ def __str__(self): ]) return "BinNumExpr(" + args + ")" + def __eq__(self, other): + return (type(other) is BinNumExpr and + self.left == other.left and + self.right == other.right and + self.op == other.op) + + def __hash__(self): + return hash((self.left, self.right, self.op)) + # Vector Expressions. @@ -164,6 +223,14 @@ def __str__(self): args = ",".join([str(e) for e in self.exprs]) return "VectorVal([" + args + "])" + def __eq__(self, other): + return (type(other) is VectorVal and + self.output_size == other.output_size and + all(i == j for i, j in zip(self.exprs, other.exprs))) + + def __hash__(self): + return hash(tuple(e for e in self.exprs)) + class BinVectorExpr(VectorExpr, BinExpr): @@ -181,6 +248,15 @@ def __str__(self): args = ",".join([str(self.left), str(self.right), self.op.name]) return "BinVectorExpr(" + args + ")" + def __eq__(self, other): + return (type(other) is BinVectorExpr and + self.left == other.left and + self.right == other.right and + self.op == other.op) + + def __hash__(self): + return hash((self.left, self.right, self.op)) + class BinVectorNumExpr(VectorExpr, BinExpr): @@ -197,6 +273,15 @@ def __str__(self): args = ",".join([str(self.left), str(self.right), self.op.name]) return "BinVectorNumExpr(" + args + ")" + def __eq__(self, other): + return (type(other) is BinVectorNumExpr and + self.left == other.left and + self.right == other.right and + self.op == other.op) + + def __hash__(self): + return hash((self.left, self.right, self.op)) + # Boolean Expressions. @@ -233,6 +318,15 @@ def __str__(self): args = ",".join([str(self.left), str(self.right), self.op.name]) return "CompExpr(" + args + ")" + def __eq__(self, other): + return (type(other) is CompExpr and + self.left == other.left and + self.right == other.right and + self.op == other.op) + + def __hash__(self): + return hash((self.left, self.right, self.op)) + # Control Expressions. @@ -254,6 +348,15 @@ def __str__(self): args = ",".join([str(self.test), str(self.body), str(self.orelse)]) return "IfExpr(" + args + ")" + def __eq__(self, other): + return (type(other) is IfExpr and + self.test == other.test and + self.body == other.body and + self.orelse == other.orelse) + + def __hash__(self): + return hash((self.test, self.body, self.orelse)) + TOTAL_NUMBER_OF_EXPRESSIONS = len(getmembers(modules[__name__], isclass)) diff --git a/m2cgen/interpreters/mixins.py b/m2cgen/interpreters/mixins.py index fd4bc454..5cb7fb99 100644 --- a/m2cgen/interpreters/mixins.py +++ b/m2cgen/interpreters/mixins.py @@ -36,10 +36,10 @@ def _pre_interpret_hook(self, expr, bin_depth=0, **kwargs): # Default implementation. Simply adds new variable. def bin_depth_threshold_hook(self, expr, **kwargs): - var_name = self._cg.add_var_declaration(expr.output_size) + if expr in self._cached_expr_results: + return self._cached_expr_results[expr].var_name result = self._do_interpret(expr, **kwargs) - self._cg.add_var_assignment(var_name, result, expr.output_size) - return var_name + return self._cache_reused_expr(expr, result) class LinearAlgebraMixin(BaseToCodeInterpreter): diff --git a/tests/interpreters/test_dart.py b/tests/interpreters/test_dart.py index 14ca9e6e..493c059a 100644 --- a/tests/interpreters/test_dart.py +++ b/tests/interpreters/test_dart.py @@ -333,12 +333,12 @@ def test_deep_mixed_exprs_exceeding_threshold(): expr = ast.NumVal(1) for i in range(4): inner = ast.NumVal(1) - for i in range(4): - inner = ast.BinNumExpr(ast.NumVal(1), inner, ast.BinNumOpType.ADD) + for j in range(4): + inner = ast.BinNumExpr(ast.NumVal(i), inner, ast.BinNumOpType.ADD) expr = ast.IfExpr( ast.CompExpr( - inner, ast.NumVal(1), ast.CompOpType.EQ), + inner, ast.NumVal(j), ast.CompOpType.EQ), ast.NumVal(1), expr) @@ -348,23 +348,23 @@ def test_deep_mixed_exprs_exceeding_threshold(): double score(List input) { double var0; double var1; - var1 = (1) + ((1) + (1)); - if (((1) + ((1) + (var1))) == (1)) { + var1 = (3) + ((3) + (1)); + if (((3) + ((3) + (var1))) == (3)) { var0 = 1; } else { double var2; - var2 = (1) + ((1) + (1)); - if (((1) + ((1) + (var2))) == (1)) { + var2 = (2) + ((2) + (1)); + if (((2) + ((2) + (var2))) == (3)) { var0 = 1; } else { double var3; var3 = (1) + ((1) + (1)); - if (((1) + ((1) + (var3))) == (1)) { + if (((1) + ((1) + (var3))) == (3)) { var0 = 1; } else { double var4; - var4 = (1) + ((1) + (1)); - if (((1) + ((1) + (var4))) == (1)) { + var4 = (0) + ((0) + (1)); + if (((0) + ((0) + (var4))) == (3)) { var0 = 1; } else { var0 = 1; diff --git a/tests/interpreters/test_python.py b/tests/interpreters/test_python.py index 3eea24e4..fac1418a 100644 --- a/tests/interpreters/test_python.py +++ b/tests/interpreters/test_python.py @@ -281,12 +281,12 @@ def test_deep_mixed_exprs_exceeding_threshold(): expr = ast.NumVal(1) for i in range(4): inner = ast.NumVal(1) - for i in range(4): - inner = ast.BinNumExpr(ast.NumVal(1), inner, ast.BinNumOpType.ADD) + for j in range(4): + inner = ast.BinNumExpr(ast.NumVal(i), inner, ast.BinNumOpType.ADD) expr = ast.IfExpr( ast.CompExpr( - inner, ast.NumVal(1), ast.CompOpType.EQ), + inner, ast.NumVal(j), ast.CompOpType.EQ), ast.NumVal(1), expr) @@ -294,20 +294,20 @@ def test_deep_mixed_exprs_exceeding_threshold(): expected_code = """ def score(input): - var1 = (1) + ((1) + (1)) - if ((1) + ((1) + (var1))) == (1): + var1 = (3) + ((3) + (1)) + if ((3) + ((3) + (var1))) == (3): var0 = 1 else: - var2 = (1) + ((1) + (1)) - if ((1) + ((1) + (var2))) == (1): + var2 = (2) + ((2) + (1)) + if ((2) + ((2) + (var2))) == (3): var0 = 1 else: var3 = (1) + ((1) + (1)) - if ((1) + ((1) + (var3))) == (1): + if ((1) + ((1) + (var3))) == (3): var0 = 1 else: - var4 = (1) + ((1) + (1)) - if ((1) + ((1) + (var4))) == (1): + var4 = (0) + ((0) + (1)) + if ((0) + ((0) + (var4))) == (3): var0 = 1 else: var0 = 1 diff --git a/tests/interpreters/test_r.py b/tests/interpreters/test_r.py index 714a3beb..006f2f75 100644 --- a/tests/interpreters/test_r.py +++ b/tests/interpreters/test_r.py @@ -295,12 +295,12 @@ def test_deep_mixed_exprs_exceeding_threshold(): expr = ast.NumVal(1) for i in range(4): inner = ast.NumVal(1) - for i in range(4): - inner = ast.BinNumExpr(ast.NumVal(1), inner, ast.BinNumOpType.ADD) + for j in range(4): + inner = ast.BinNumExpr(ast.NumVal(i), inner, ast.BinNumOpType.ADD) expr = ast.IfExpr( ast.CompExpr( - inner, ast.NumVal(1), ast.CompOpType.EQ), + inner, ast.NumVal(j), ast.CompOpType.EQ), ast.NumVal(1), expr) @@ -312,19 +312,19 @@ def test_deep_mixed_exprs_exceeding_threshold(): expected_code = """ score <- function(input) { var1 <- subroutine0(input) - if (((1) + (var1)) == (1)) { + if (((3) + (var1)) == (3)) { var0 <- 1 } else { var2 <- subroutine1(input) - if (((1) + (var2)) == (1)) { + if (((2) + (var2)) == (3)) { var0 <- 1 } else { var3 <- subroutine2(input) - if (((1) + (var3)) == (1)) { + if (((1) + (var3)) == (3)) { var0 <- 1 } else { var4 <- subroutine3(input) - if (((1) + (var4)) == (1)) { + if (((0) + (var4)) == (3)) { var0 <- 1 } else { var0 <- 1 @@ -335,24 +335,24 @@ def test_deep_mixed_exprs_exceeding_threshold(): return(var0) } subroutine0 <- function(input) { - var1 <- (1) + (1) - var0 <- (1) + (var1) - return((1) + (var0)) + var0 <- (3) + (1) + var1 <- (3) + (var0) + return((3) + (var1)) } subroutine1 <- function(input) { - var1 <- (1) + (1) - var0 <- (1) + (var1) - return((1) + (var0)) + var0 <- (2) + (1) + var1 <- (2) + (var0) + return((2) + (var1)) } subroutine2 <- function(input) { - var1 <- (1) + (1) - var0 <- (1) + (var1) - return((1) + (var0)) + var0 <- (1) + (1) + var1 <- (1) + (var0) + return((1) + (var1)) } subroutine3 <- function(input) { - var1 <- (1) + (1) - var0 <- (1) + (var1) - return((1) + (var0)) + var0 <- (0) + (1) + var1 <- (0) + (var0) + return((0) + (var1)) } """