Skip to content

Commit

Permalink
refactor depth tracking of binary expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
krinart committed Feb 6, 2019
1 parent 1a621ac commit 7931b00
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 19 deletions.
2 changes: 1 addition & 1 deletion m2cgen/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class CompOpType(Enum):
NOT_EQ = '!='


class CompExpr(BoolExpr, BinExpr):
class CompExpr(BoolExpr):
def __init__(self, left, right, op):
assert left.output_size == 1, "Only scalars are supported"
assert right.output_size == 1, "Only scalars are supported"
Expand Down
25 changes: 15 additions & 10 deletions m2cgen/interpreters/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,32 @@
class BaseAstInterpreter:

# disabled by default
depth_threshold = sys.maxsize
bin_depth_threshold = sys.maxsize

def interpret(self, expr):
return self._do_interpret(expr)

# Private methods implementing visitor pattern

def _do_interpret(self, expr, depth=1, **kwargs):
def _do_interpret(self, expr, bin_depth=None, **kwargs):

# We track depth of the expression and if it exceeds specified limit,
# we will call hook.
if depth > self.depth_threshold and isinstance(expr, ast.BinExpr):
return self.depth_threshold_hook(expr, **kwargs)
# We track depth of the binary expressions and call a hook if it
# exceeds specified limit.
if isinstance(expr, ast.BinExpr):
bin_depth = bin_depth+1 if bin_depth is not None else 1

if bin_depth > self.bin_depth_threshold:
return self.bin_depth_threshold_hook(expr, **kwargs)
else:
bin_depth = 0

try:
handler = self._select_handler(expr)
except NotImplementedError:
if isinstance(expr, ast.TransparentExpr):
return self._do_interpret(expr.expr, depth=depth+1, **kwargs)
return self._do_interpret(expr.expr, bin_depth=bin_depth, **kwargs)
raise
return handler(expr, depth=depth+1, **kwargs)
return handler(expr, bin_depth=bin_depth, **kwargs)

def _select_handler(self, expr):
handler_name = self._handler_name(type(expr))
Expand All @@ -45,7 +50,7 @@ def _handler_name(expr_tpe):
def _normalize_expr_name(name):
return re.sub("(?!^)([A-Z]+)", r"_\1", name).lower()

def depth_threshold_hook(self, expr, **kwargs):
def bin_depth_threshold_hook(self, expr, **kwargs):
raise NotImplementedError


Expand Down Expand Up @@ -105,7 +110,7 @@ def interpret_vector_val(self, expr, **kwargs):
return self._cg.vector_init(nested)

# Default implementation. Simply adds new variable.
def depth_threshold_hook(self, expr, **kwargs):
def bin_depth_threshold_hook(self, expr, **kwargs):
var_name = self._cg.add_var_declaration(expr.output_size)
result = self._do_interpret(expr, **kwargs)
self._cg.add_var_assignment(var_name, result, expr.output_size)
Expand Down
2 changes: 1 addition & 1 deletion m2cgen/interpreters/python/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ class PythonInterpreter(interpreters.AstToCodeInterpreter):

# 93 may raise MemoryError, so use something close enough to it not to
# create unnecessary overhead.
depth_threshold = 80
bin_depth_threshold = 80

def __init__(self, indent=4, *args, **kwargs):
cg = PythonCodeGenerator(indent=indent)
Expand Down
11 changes: 4 additions & 7 deletions tests/interpreters/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def score(input):


class CustomPythonInterpreter(interpreters.PythonInterpreter):
depth_threshold = 2
bin_depth_threshold = 2


def test_depth_threshold_with_bin_expr():
Expand Down Expand Up @@ -205,16 +205,13 @@ def score(input):
if (1) == (1):
var0 = 1
else:
var1 = (1) == (1)
if var1:
if (1) == (1):
var0 = 1
else:
var2 = (1) == (1)
if var2:
if (1) == (1):
var0 = 1
else:
var3 = (1) == (1)
if var3:
if (1) == (1):
var0 = 1
else:
var0 = 1
Expand Down

0 comments on commit 7931b00

Please sign in to comment.