Skip to content

Commit

Permalink
Refactor LinearAlgebraMixin; revert unneeded changes
Browse files Browse the repository at this point in the history
  • Loading branch information
krinart committed Feb 6, 2019
1 parent 6bf2ade commit 145cd42
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 74 deletions.
104 changes: 56 additions & 48 deletions m2cgen/interpreters/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,65 @@

class BaseInterpreter:

with_vectors = False

# disabled by default
depth_threshold = sys.maxsize

def __init__(self, cg, feature_array_name="input"):
def __init__(self, cg):
self._cg = cg
self._feature_array_name = feature_array_name

def interpret(self, expr):
    """Public entry point: interpret the given AST expression.

    Delegates to the depth-tracking visitor ``_do_interpret`` with the
    default starting depth.
    """
    return self._do_interpret(expr)

# Default method implementations
# Private methods implementing visitor pattern

def _do_interpret(self, expr, depth=1, **kwargs):
    """Visitor-pattern dispatcher: route *expr* to its ``interpret_*`` handler.

    ``depth`` counts how deeply we have recursed into the expression
    tree; ``kwargs`` are threaded through to handlers unchanged.
    """

    # We track the depth of the expression, and once it exceeds
    # ``self.depth_threshold`` (disabled by default — sys.maxsize) for a
    # binary expression, we invoke the hook. The default hook stores the
    # intermediate result in a fresh variable. Sub-interpreters may
    # override this behaviour.
    if depth > self.depth_threshold and isinstance(expr, ast.BinExpr):
        return self._depth_threshold_hook(expr, **kwargs)

    try:
        handler = self._select_handler(expr)
    except NotImplementedError:
        # TransparentExpr nodes have no handler of their own: unwrap and
        # interpret the inner expression instead. Anything else without
        # a handler is a genuine error — re-raise.
        if isinstance(expr, ast.TransparentExpr):
            return self._do_interpret(expr.expr, depth=depth+1, **kwargs)
        raise
    # Handlers receive the incremented depth so nested sub-expressions
    # they interpret keep counting from here.
    return handler(expr, depth=depth+1, **kwargs)

def _select_handler(self, expr):
    """Look up the ``interpret_*`` method matching *expr*'s concrete type.

    Raises ``NotImplementedError`` when this interpreter defines no
    handler for the expression type; ``_do_interpret`` catches that to
    unwrap transparent expressions.
    """
    method_name = self._handler_name(type(expr))
    missing = object()  # sentinel: distinguishes "absent" from any real value
    handler = getattr(self, method_name, missing)
    if handler is missing:
        raise NotImplementedError(
            "No handler found for {}".format(type(expr).__name__))
    return handler

@staticmethod
def _handler_name(expr_tpe):
    """Derive the handler method name for an expression class.

    E.g. ``IfExpr`` -> ``interpret_if_expr``.
    """
    normalized = BaseInterpreter._normalize_expr_name(expr_tpe.__name__)
    return "interpret_{}".format(normalized)

@staticmethod
def _normalize_expr_name(name):
return re.sub("(?!^)([A-Z]+)", r"_\1", name).lower()

# Default implementation of the depth-threshold hook: interpret the
# expression as usual, then store the result in a freshly declared
# variable and return that variable's name. This breaks very deep
# expression trees into separate assignments (avoiding, e.g., target
# languages' limits on expression nesting). Sub-interpreters may
# override for different behaviour.
def _depth_threshold_hook(self, expr, **kwargs):
    # Declaration must precede interpretation so the generated code
    # has the variable in scope before the assignment is emitted.
    var_name = self._cg.add_var_declaration(expr.output_size)
    result = self._do_interpret(expr, **kwargs)
    self._cg.add_var_assignment(var_name, result, expr.output_size)
    return var_name


class Interpreter(BaseInterpreter):

with_vectors = False

def __init__(self, cg, feature_array_name="input"):
    """Create an interpreter emitting code via *cg*.

    :param cg: code generator instance (passed to BaseInterpreter).
    :param feature_array_name: name of the generated-code variable that
        holds the model's input feature vector.
    """
    super().__init__(cg)
    self._feature_array_name = feature_array_name

def interpret_if_expr(self, expr, if_var_name=None, **kwargs):
if if_var_name is not None:
Expand Down Expand Up @@ -67,50 +113,8 @@ def interpret_vector_val(self, expr, **kwargs):
nested = [self._do_interpret(expr, **kwargs) for expr in expr.exprs]
return self._cg.vector_init(nested)

# Private methods implementing visitor pattern

def _do_interpret(self, expr, depth=1, **kwargs):

# We track depth of the expression and if it exceeds specified limit,
# we will call hook. By default it will create a variable and store
# result of the expression in this variable. Sub-interpreters may
# override this behaviour.
if depth > self.depth_threshold and isinstance(expr, ast.BinExpr):
return self._depth_threshold_hook(expr, **kwargs)

try:
handler = self._select_handler(expr)
except NotImplementedError:
if isinstance(expr, ast.TransparentExpr):
return self._do_interpret(expr.expr, depth=depth+1, **kwargs)
raise
return handler(expr, depth=depth+1, **kwargs)

def _select_handler(self, expr):
handler_name = self._handler_name(type(expr))
if hasattr(self, handler_name):
return getattr(self, handler_name)
raise NotImplementedError(
"No handler found for {}".format(type(expr).__name__))

@staticmethod
def _handler_name(expr_tpe):
expr_name = BaseInterpreter._normalize_expr_name(expr_tpe.__name__)
return "interpret_" + expr_name

@staticmethod
def _normalize_expr_name(name):
return re.sub("(?!^)([A-Z]+)", r"_\1", name).lower()

# Default implementation. Simply adds new variable.
def _depth_threshold_hook(self, expr, **kwargs):
var_name = self._cg.add_var_declaration(expr.output_size)
result = self._do_interpret(expr, **kwargs)
self._cg.add_var_assignment(var_name, result, expr.output_size)
return var_name


class InterpreterWithLinearAlgebra(BaseInterpreter):
class LinearAlgebraMixin(BaseInterpreter):

with_linear_algebra = False

Expand Down Expand Up @@ -147,3 +151,7 @@ def interpret_bin_vector_num_expr(self, expr, extra_func_args=(),
self._do_interpret(expr.left, **kwargs),
self._do_interpret(expr.right, **kwargs),
*extra_func_args)


class InterpreterWithLinearAlgebra(LinearAlgebraMixin, Interpreter):
    """Interpreter composed with linear-algebra expression handlers.

    Pure composition class: MRO places LinearAlgebraMixin's handlers
    ahead of Interpreter's, so no body of its own is needed.
    """
    pass
4 changes: 2 additions & 2 deletions m2cgen/interpreters/python/interpreter.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from m2cgen.interpreters.interpreter import BaseInterpreter
from m2cgen.interpreters.interpreter import Interpreter
from m2cgen.interpreters.python.code_generator import PythonCodeGenerator


class PythonInterpreter(BaseInterpreter):
class PythonInterpreter(Interpreter):

# 93 may raise MemoryError, so use something close enough to it not to
# create unnecessary overhead.
Expand Down
21 changes: 4 additions & 17 deletions tests/e2e/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,27 @@


# Set of helper functions to make parametrization less verbose.
def regression(model):
    """Parametrization helper: pair *model* with the regression trainer.

    Returns the (estimator, trainer, problem-kind) triple consumed by
    cartesian_e2e_params; the former ``is_fast`` flag was dropped in
    this refactor.
    """
    return (
        model,
        utils.train_model_regression,
        REGRESSION,
    )


def classification(model):
    """Parametrization helper: pair *model* with the multiclass trainer.

    Returns the (estimator, trainer, problem-kind) triple consumed by
    cartesian_e2e_params; the former ``is_fast`` flag was dropped in
    this refactor.
    """
    return (
        model,
        utils.train_model_classification,
        CLASSIFICATION,
    )


def classification_binary(model):
    """Parametrization helper: pair *model* with the binary-class trainer.

    Returns the (estimator, trainer, problem-kind) triple consumed by
    cartesian_e2e_params; the former ``is_fast`` flag was dropped in
    this refactor.
    """
    return (
        model,
        utils.train_model_classification_binary,
        CLASSIFICATION,
    )


Expand Down Expand Up @@ -133,27 +130,17 @@ def classification_binary(model, is_fast=False):
classification_binary(
ensemble.RandomForestClassifier(**FOREST_PARAMS)),
classification_binary(ensemble.ExtraTreesClassifier(**FOREST_PARAMS)),
# This a special case of the HUGE model. We want to verify that
# even in such cases we generate code which works.
# regression(ensemble.RandomForestRegressor(n_estimators=100),
# is_fast=True),
],
# Following is the list of extra tests for languages/models which are
# not fully supported yet.
# <empty>
)
def test_e2e(estimator, executor_cls, model_trainer, is_fast_model, is_fast):
def test_e2e(estimator, executor_cls, model_trainer, is_fast):
X_test, y_pred_true = model_trainer(estimator)
executor = executor_cls(estimator)

# is_fast means that user used --flag.
# is_fast_model means that this model is explicitly specified to run fast.
is_fast = is_fast_model or is_fast

idxs_to_test = [0] if is_fast else range(len(X_test))

with executor.prepare_then_cleanup():
Expand Down
10 changes: 3 additions & 7 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,7 @@ def cartesian_e2e_params(executors_with_marks, models_with_trainers_with_marks,
prod = itertools.product(
executors_with_marks, models_with_trainers_with_marks)

for a1, a2 in prod:
executor, executor_mark = a1
model, trainer, trainer_mark, is_fast = a2

for (executor, executor_mark), (model, trainer, trainer_mark) in prod:
# Since we reuse the same model across multiple tests we want it
# to be clean.
model = clone(model)
Expand All @@ -145,11 +142,10 @@ def cartesian_e2e_params(executors_with_marks, models_with_trainers_with_marks,
type(model).__name__, executor_mark.name, trainer.__name__))

result_params.append(pytest.param(
model, executor, trainer, is_fast,
marks=[executor_mark, trainer_mark],
model, executor, trainer, marks=[executor_mark, trainer_mark],
))

param_names = "estimator,executor_cls,model_trainer,is_fast_model"
param_names = "estimator,executor_cls,model_trainer"

def wrap(func):

Expand Down

0 comments on commit 145cd42

Please sign in to comment.