Skip to content

Commit

Permalink
Merge 496245b into 5f5d99e
Browse files Browse the repository at this point in the history
  • Loading branch information
izeigerman committed Feb 6, 2019
2 parents 5f5d99e + 496245b commit 498f17e
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 68 deletions.
10 changes: 7 additions & 3 deletions m2cgen/ast.py
Expand Up @@ -13,6 +13,10 @@ def __str__(self):
return "FeatureRef(" + str(self.index) + ")"


class BinExpr(Expr):
pass


# Numeric Expressions.

class NumExpr(Expr):
Expand All @@ -34,7 +38,7 @@ class BinNumOpType(Enum):
DIV = '/'


class BinNumExpr(NumExpr):
class BinNumExpr(NumExpr, BinExpr):
def __init__(self, left, right, op):
assert left.output_size == 1, "Only scalars are supported"
assert right.output_size == 1, "Only scalars are supported"
Expand Down Expand Up @@ -68,7 +72,7 @@ def __str__(self):
return "VectorVal([" + args + "])"


class BinVectorExpr(VectorExpr):
class BinVectorExpr(VectorExpr, BinExpr):

def __init__(self, left, right, op):
assert left.output_size > 1, "Only vectors are supported"
Expand All @@ -85,7 +89,7 @@ def __str__(self):
return "BinVectorExpr(" + args + ")"


class BinVectorNumExpr(VectorExpr):
class BinVectorNumExpr(VectorExpr, BinExpr):

def __init__(self, left, right, op):
assert left.output_size > 1, "Only vectors are supported"
Expand Down
6 changes: 6 additions & 0 deletions m2cgen/interpreters/__init__.py
@@ -1,8 +1,14 @@
from .interpreter import AstToCodeInterpreter
from .interpreter import (
AstToCodeInterpreterWithLinearAlgebra)
from .java.interpreter import JavaInterpreter
from .python.interpreter import PythonInterpreter
from .c.interpreter import CInterpreter


__all__ = [
AstToCodeInterpreter,
AstToCodeInterpreterWithLinearAlgebra,
JavaInterpreter,
PythonInterpreter,
CInterpreter,
Expand Down
10 changes: 2 additions & 8 deletions m2cgen/interpreters/c/interpreter.py
Expand Up @@ -2,11 +2,11 @@

from m2cgen import ast
from m2cgen.interpreters import utils
from m2cgen.interpreters.interpreter import InterpreterWithLinearAlgebra
from m2cgen import interpreters
from m2cgen.interpreters.c.code_generator import CCodeGenerator


class CInterpreter(InterpreterWithLinearAlgebra):
class CInterpreter(interpreters.AstToCodeInterpreterWithLinearAlgebra):

supported_bin_vector_ops = {
ast.BinNumOpType.ADD: "add_vectors",
Expand All @@ -16,8 +16,6 @@ class CInterpreter(InterpreterWithLinearAlgebra):
ast.BinNumOpType.MUL: "mul_vector_number",
}

with_vectors = False

def __init__(self, indent=4, *args, **kwargs):
cg = CCodeGenerator(indent=indent)
super(CInterpreter, self).__init__(cg, *args, **kwargs)
Expand Down Expand Up @@ -57,10 +55,6 @@ def interpret(self, expr):

return self._cg.code

def interpret_vector_val(self, expr, **kwargs):
self.with_vectors = True
return super().interpret_vector_val(expr, **kwargs)

# Both methods supporting linear algebra do several things:
#
# 1. Call super method with extra parameters. Super method will return a
Expand Down
111 changes: 70 additions & 41 deletions m2cgen/interpreters/interpreter.py
@@ -1,18 +1,67 @@
import re
import sys

from m2cgen import ast


class BaseInterpreter:
class BaseAstInterpreter:

def __init__(self, cg, feature_array_name="input"):
self._cg = cg
self._feature_array_name = feature_array_name
# disabled by default
bin_depth_threshold = sys.maxsize

def interpret(self, expr):
return self._do_interpret(expr)

# Default method implementations
# Private methods implementing visitor pattern

def _do_interpret(self, expr, bin_depth=None, **kwargs):

# We track depth of the binary expressions and call a hook if it
# exceeds specified limit.
if isinstance(expr, ast.BinExpr):
bin_depth = bin_depth+1 if bin_depth is not None else 1

if bin_depth > self.bin_depth_threshold:
return self.bin_depth_threshold_hook(expr, **kwargs)
else:
bin_depth = 0

try:
handler = self._select_handler(expr)
except NotImplementedError:
if isinstance(expr, ast.TransparentExpr):
return self._do_interpret(expr.expr, bin_depth=bin_depth,
**kwargs)
raise
return handler(expr, bin_depth=bin_depth, **kwargs)

def _select_handler(self, expr):
handler_name = self._handler_name(type(expr))
if hasattr(self, handler_name):
return getattr(self, handler_name)
raise NotImplementedError(
"No handler found for {}".format(type(expr).__name__))

@staticmethod
def _handler_name(expr_tpe):
expr_name = BaseAstInterpreter._normalize_expr_name(expr_tpe.__name__)
return "interpret_" + expr_name

@staticmethod
def _normalize_expr_name(name):
return re.sub("(?!^)([A-Z]+)", r"_\1", name).lower()

def bin_depth_threshold_hook(self, expr, **kwargs):
raise NotImplementedError


class AstToCodeInterpreter(BaseAstInterpreter):

with_vectors = False

def __init__(self, cg, feature_array_name="input"):
self._cg = cg
self._feature_array_name = feature_array_name

def interpret_if_expr(self, expr, if_var_name=None, **kwargs):
if if_var_name is not None:
Expand All @@ -24,7 +73,7 @@ def handle_nested_expr(nested):
if isinstance(nested, ast.IfExpr):
self._do_interpret(nested, if_var_name=var_name, **kwargs)
else:
nested_result = self._do_interpret(nested)
nested_result = self._do_interpret(nested, **kwargs)
self._cg.add_var_assignment(var_name, nested_result,
nested.output_size)

Expand All @@ -38,15 +87,15 @@ def handle_nested_expr(nested):

def interpret_comp_expr(self, expr, **kwargs):
return self._cg.infix_expression(
left=self._do_interpret(expr.left),
left=self._do_interpret(expr.left, **kwargs),
op=expr.op.value,
right=self._do_interpret(expr.right))
right=self._do_interpret(expr.right, **kwargs))

def interpret_bin_num_expr(self, expr, **kwargs):
return self._cg.infix_expression(
left=self._do_interpret(expr.left),
left=self._do_interpret(expr.left, **kwargs),
op=expr.op.value,
right=self._do_interpret(expr.right))
right=self._do_interpret(expr.right, **kwargs))

def interpret_num_val(self, expr, **kwargs):
return self._cg.num_value(value=expr.value)
Expand All @@ -57,45 +106,26 @@ def interpret_feature_ref(self, expr, **kwargs):
index=expr.index)

def interpret_vector_val(self, expr, **kwargs):
self.with_vectors = True
nested = [self._do_interpret(expr, **kwargs) for expr in expr.exprs]
return self._cg.vector_init(nested)

# Private methods implementing visitor pattern

def _do_interpret(self, expr, **kwargs):
try:
handler = self._select_handler(expr)
except NotImplementedError:
if isinstance(expr, ast.TransparentExpr):
return self._do_interpret(expr.expr, **kwargs)
raise
return handler(expr, **kwargs)

def _select_handler(self, expr):
handler_name = self._handler_name(type(expr))
if hasattr(self, handler_name):
return getattr(self, handler_name)
raise NotImplementedError(
"No handler found for {}".format(type(expr).__name__))

@staticmethod
def _handler_name(expr_tpe):
expr_name = BaseInterpreter._normalize_expr_name(expr_tpe.__name__)
return "interpret_" + expr_name

@staticmethod
def _normalize_expr_name(name):
return re.sub("(?!^)([A-Z]+)", r"_\1", name).lower()
# Default implementation. Simply adds new variable.
def bin_depth_threshold_hook(self, expr, **kwargs):
var_name = self._cg.add_var_declaration(expr.output_size)
result = self._do_interpret(expr, **kwargs)
self._cg.add_var_assignment(var_name, result, expr.output_size)
return var_name


class InterpreterWithLinearAlgebra(BaseInterpreter):
class AstToCodeInterpreterWithLinearAlgebra(AstToCodeInterpreter):

with_linear_algebra = False

supported_bin_vector_ops = {}
supported_bin_vector_num_ops = {}

def interpret_bin_vector_expr(self, expr, **kwargs):
def interpret_bin_vector_expr(self, expr, extra_func_args=(), **kwargs):
if expr.op not in self.supported_bin_vector_ops:
raise NotImplementedError(
"Op {} is unsupported".format(expr.op.name))
Expand All @@ -104,14 +134,14 @@ def interpret_bin_vector_expr(self, expr, **kwargs):

function_name = self.supported_bin_vector_ops[expr.op]

extra_func_args = kwargs.pop("extra_func_args", [])
return self._cg.function_invocation(
function_name,
self._do_interpret(expr.left, **kwargs),
self._do_interpret(expr.right, **kwargs),
*extra_func_args)

def interpret_bin_vector_num_expr(self, expr, **kwargs):
def interpret_bin_vector_num_expr(self, expr, extra_func_args=(),
**kwargs):
if expr.op not in self.supported_bin_vector_num_ops:
raise NotImplementedError(
"Op {} is unsupported".format(expr.op.name))
Expand All @@ -120,7 +150,6 @@ def interpret_bin_vector_num_expr(self, expr, **kwargs):

function_name = self.supported_bin_vector_num_ops[expr.op]

extra_func_args = kwargs.pop("extra_func_args", [])
return self._cg.function_invocation(
function_name,
self._do_interpret(expr.left, **kwargs),
Expand Down
4 changes: 2 additions & 2 deletions m2cgen/interpreters/java/interpreter.py
@@ -1,8 +1,8 @@
import os

from m2cgen import ast
from m2cgen import interpreters
from m2cgen.interpreters import utils
from m2cgen.interpreters.interpreter import InterpreterWithLinearAlgebra
from m2cgen.interpreters.java.code_generator import JavaCodeGenerator

from collections import namedtuple
Expand All @@ -11,7 +11,7 @@
Subroutine = namedtuple('Subroutine', ['name', 'expr'])


class JavaInterpreter(InterpreterWithLinearAlgebra):
class JavaInterpreter(interpreters.AstToCodeInterpreterWithLinearAlgebra):

supported_bin_vector_ops = {
ast.BinNumOpType.ADD: "addVectors",
Expand Down
26 changes: 12 additions & 14 deletions m2cgen/interpreters/python/interpreter.py
@@ -1,10 +1,12 @@
from m2cgen.interpreters.interpreter import BaseInterpreter
from m2cgen import interpreters
from m2cgen.interpreters.python.code_generator import PythonCodeGenerator


class PythonInterpreter(BaseInterpreter):
class PythonInterpreter(interpreters.AstToCodeInterpreter):

with_numpy = False
# 93 may raise MemoryError, so use something close enough to it not to
# create unnecessary overhead.
bin_depth_threshold = 80

def __init__(self, indent=4, *args, **kwargs):
cg = PythonCodeGenerator(indent=indent)
Expand All @@ -19,23 +21,19 @@ def interpret(self, expr):
last_result = self._do_interpret(expr)
self._cg.add_return_statement(last_result)

if self.with_numpy:
if self.with_vectors:
self._cg.add_dependency("numpy", alias="np")

return self._cg.code

def interpret_vector_val(self, expr, **kwargs):
self.with_numpy = True
return super().interpret_vector_val(expr, **kwargs)

def interpret_bin_vector_expr(self, expr):
def interpret_bin_vector_expr(self, expr, **kwargs):
return self._cg.infix_expression(
left=self._do_interpret(expr.left),
left=self._do_interpret(expr.left, **kwargs),
op=expr.op.value,
right=self._do_interpret(expr.right))
right=self._do_interpret(expr.right, **kwargs))

def interpret_bin_vector_num_expr(self, expr):
def interpret_bin_vector_num_expr(self, expr, **kwargs):
return self._cg.infix_expression(
left=self._do_interpret(expr.left),
left=self._do_interpret(expr.left, **kwargs),
op=expr.op.value,
right=self._do_interpret(expr.right))
right=self._do_interpret(expr.right, **kwargs))

0 comments on commit 498f17e

Please sign in to comment.