Skip to content

Commit

Permalink
refactor interpreters
Browse files Browse the repository at this point in the history
  • Loading branch information
krinart committed Feb 6, 2019
1 parent 145cd42 commit 24770bb
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 35 deletions.
2 changes: 1 addition & 1 deletion m2cgen/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .exporters import export_to_java, export_to_python, export_to_c
from m2cgen.exporters import export_to_java, export_to_python, export_to_c


__all__ = [
Expand Down
6 changes: 3 additions & 3 deletions m2cgen/assemblers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .linear import LinearModelAssembler
from .tree import TreeModelAssembler
from .ensemble import RandomForestModelAssembler
from m2cgen.assemblers.linear import LinearModelAssembler
from m2cgen.assemblers.tree import TreeModelAssembler
from m2cgen.assemblers.ensemble import RandomForestModelAssembler

__all__ = [
LinearModelAssembler,
Expand Down
12 changes: 9 additions & 3 deletions m2cgen/interpreters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
from .java.interpreter import JavaInterpreter
from .python.interpreter import PythonInterpreter
from .c.interpreter import CInterpreter
from m2cgen.interpreters.interpreter import AstToCodeInterpreter
from m2cgen.interpreters.interpreter import (
AstToCodeInterpreterWithLinearAlgebra)
from m2cgen.interpreters.java.interpreter import JavaInterpreter
from m2cgen.interpreters.python.interpreter import PythonInterpreter
from m2cgen.interpreters.c.interpreter import CInterpreter


__all__ = [
AstToCodeInterpreter,
AstToCodeInterpreterWithLinearAlgebra,
JavaInterpreter,
PythonInterpreter,
CInterpreter,
Expand Down
4 changes: 2 additions & 2 deletions m2cgen/interpreters/c/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

from m2cgen import ast
from m2cgen.interpreters import utils
from m2cgen.interpreters.interpreter import InterpreterWithLinearAlgebra
from m2cgen import interpreters
from m2cgen.interpreters.c.code_generator import CCodeGenerator


class CInterpreter(InterpreterWithLinearAlgebra):
class CInterpreter(interpreters.AstToCodeInterpreterWithLinearAlgebra):

supported_bin_vector_ops = {
ast.BinNumOpType.ADD: "add_vectors",
Expand Down
38 changes: 16 additions & 22 deletions m2cgen/interpreters/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,11 @@
from m2cgen import ast


class BaseInterpreter:
class BaseAstInterpreter:

# disabled by default
depth_threshold = sys.maxsize

def __init__(self, cg):
self._cg = cg

def interpret(self, expr):
return self._do_interpret(expr)

Expand All @@ -20,11 +17,9 @@ def interpret(self, expr):
def _do_interpret(self, expr, depth=1, **kwargs):

# We track depth of the expression and if it exceeds specified limit,
# we will call hook. By default it will create a variable and store
# result of the expression in this variable. Sub-interpreters may
# override this behaviour.
# we will call hook.
if depth > self.depth_threshold and isinstance(expr, ast.BinExpr):
return self._depth_threshold_hook(expr, **kwargs)
return self.depth_threshold_hook(expr, **kwargs)

try:
handler = self._select_handler(expr)
Expand All @@ -43,28 +38,24 @@ def _select_handler(self, expr):

@staticmethod
def _handler_name(expr_tpe):
expr_name = BaseInterpreter._normalize_expr_name(expr_tpe.__name__)
expr_name = BaseAstInterpreter._normalize_expr_name(expr_tpe.__name__)
return "interpret_" + expr_name

@staticmethod
def _normalize_expr_name(name):
return re.sub("(?!^)([A-Z]+)", r"_\1", name).lower()

# Default implementation. Simply adds new variable.
def _depth_threshold_hook(self, expr, **kwargs):
var_name = self._cg.add_var_declaration(expr.output_size)
result = self._do_interpret(expr, **kwargs)
self._cg.add_var_assignment(var_name, result, expr.output_size)
return var_name
def depth_threshold_hook(self, expr, **kwargs):
raise NotImplementedError


class Interpreter(BaseInterpreter):
class AstToCodeInterpreter(BaseAstInterpreter):

with_vectors = False

def __init__(self, cg, feature_array_name="input"):
self._cg = cg
self._feature_array_name = feature_array_name
super().__init__(cg)

def interpret_if_expr(self, expr, if_var_name=None, **kwargs):
if if_var_name is not None:
Expand Down Expand Up @@ -113,8 +104,15 @@ def interpret_vector_val(self, expr, **kwargs):
nested = [self._do_interpret(expr, **kwargs) for expr in expr.exprs]
return self._cg.vector_init(nested)

# Default implementation. Simply adds new variable.
def depth_threshold_hook(self, expr, **kwargs):
var_name = self._cg.add_var_declaration(expr.output_size)
result = self._do_interpret(expr, **kwargs)
self._cg.add_var_assignment(var_name, result, expr.output_size)
return var_name


class LinearAlgebraMixin(BaseInterpreter):
class AstToCodeInterpreterWithLinearAlgebra(AstToCodeInterpreter):

with_linear_algebra = False

Expand Down Expand Up @@ -151,7 +149,3 @@ def interpret_bin_vector_num_expr(self, expr, extra_func_args=(),
self._do_interpret(expr.left, **kwargs),
self._do_interpret(expr.right, **kwargs),
*extra_func_args)


class InterpreterWithLinearAlgebra(LinearAlgebraMixin, Interpreter):
pass
4 changes: 2 additions & 2 deletions m2cgen/interpreters/java/interpreter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import os

from m2cgen import ast
from m2cgen import interpreters
from m2cgen.interpreters import utils
from m2cgen.interpreters.interpreter import InterpreterWithLinearAlgebra
from m2cgen.interpreters.java.code_generator import JavaCodeGenerator

from collections import namedtuple
Expand All @@ -11,7 +11,7 @@
Subroutine = namedtuple('Subroutine', ['name', 'expr'])


class JavaInterpreter(InterpreterWithLinearAlgebra):
class JavaInterpreter(interpreters.AstToCodeInterpreterWithLinearAlgebra):

supported_bin_vector_ops = {
ast.BinNumOpType.ADD: "addVectors",
Expand Down
4 changes: 2 additions & 2 deletions m2cgen/interpreters/python/interpreter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from m2cgen.interpreters.interpreter import Interpreter
from m2cgen import interpreters
from m2cgen.interpreters.python.code_generator import PythonCodeGenerator


class PythonInterpreter(Interpreter):
class PythonInterpreter(interpreters.AstToCodeInterpreter):

# 93 may raise MemoryError, so use something close enough to it not to
# create unnecessary overhead.
Expand Down

0 comments on commit 24770bb

Please sign in to comment.