Skip to content

Commit

Permalink
Drop the Subroutine expression from AST
Browse files Browse the repository at this point in the history
  • Loading branch information
izeigerman committed Mar 16, 2020
1 parent 8961fe8 commit baca974
Show file tree
Hide file tree
Showing 22 changed files with 428 additions and 515 deletions.
4 changes: 2 additions & 2 deletions m2cgen/assemblers/boosting.py
Expand Up @@ -48,7 +48,7 @@ def _assemble_single_output(self, estimator_params,

result_ast = self._final_transform(tmp_ast)

return ast.SubroutineExpr(result_ast)
return result_ast

def _assemble_multi_class_output(self, estimator_params):
# Multi-class output is calculated based on discussion in
Expand Down Expand Up @@ -100,7 +100,7 @@ def _assemble_estimators(self, trees, split_idx):
if self._tree_limit:
trees = trees[:self._tree_limit]

return [ast.SubroutineExpr(self._assemble_tree(t)) for t in trees]
return [self._assemble_tree(t) for t in trees]

def _assemble_tree(self, tree):
raise NotImplementedError
Expand Down
2 changes: 1 addition & 1 deletion m2cgen/assemblers/ensemble.py
Expand Up @@ -12,7 +12,7 @@ def assemble(self):
def assemble_tree_expr(t):
assembler = TreeModelAssembler(t)

return ast.SubroutineExpr(assembler.assemble())
return assembler.assemble()

assembled_trees = [assemble_tree_expr(t) for t in trees]
return utils.apply_bin_op(
Expand Down
3 changes: 1 addition & 2 deletions m2cgen/assemblers/linear.py
Expand Up @@ -19,8 +19,7 @@ def _build_ast(self):

exprs = []
for idx in range(coef.shape[0]):
exprs.append(ast.SubroutineExpr(
_linear_to_ast(coef[idx], intercept[idx])))
exprs.append(_linear_to_ast(coef[idx], intercept[idx]))
return ast.VectorVal(exprs)

def _get_intercept(self):
Expand Down
3 changes: 2 additions & 1 deletion m2cgen/assemblers/svm.py
Expand Up @@ -50,7 +50,8 @@ def _apply_kernel(self, support_vectors, to_reuse=False):
kernel_exprs = []
for v in support_vectors:
kernel = self._kernel_fun(v)
kernel_exprs.append(ast.SubroutineExpr(kernel, to_reuse=to_reuse))
kernel.to_reuse = to_reuse
kernel_exprs.append(kernel)
return kernel_exprs

def _get_supported_kernels(self):
Expand Down
21 changes: 2 additions & 19 deletions m2cgen/ast.py
Expand Up @@ -228,28 +228,12 @@ def __str__(self):
return "IfExpr(" + args + ")"


class TransparentExpr(CtrlExpr):
def __init__(self, expr):
self.expr = expr
self.output_size = expr.output_size


class SubroutineExpr(TransparentExpr):
def __init__(self, expr, to_reuse=False):
super().__init__(expr)
self.to_reuse = to_reuse

def __str__(self):
args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)])
return "SubroutineExpr(" + args + ")"


NESTED_EXPRS_MAPPINGS = [
((BinExpr, CompExpr), lambda e: [e.left, e.right]),
(PowExpr, lambda e: [e.base_expr, e.exp_expr]),
(VectorVal, lambda e: e.exprs),
(IfExpr, lambda e: [e.test, e.body, e.orelse]),
((ExpExpr, SqrtExpr, TanhExpr, TransparentExpr), lambda e: [e.expr]),
((ExpExpr, SqrtExpr, TanhExpr), lambda e: [e.expr]),
]


Expand All @@ -258,8 +242,7 @@ def count_exprs(expr, exclude_list=None):
excluded = tuple(exclude_list) if exclude_list else ()

init = 1
if issubclass(expr_type, excluded) or \
issubclass(expr_type, TransparentExpr):
if issubclass(expr_type, excluded):
init = 0

if isinstance(expr, (NumVal, FeatureRef)):
Expand Down
8 changes: 1 addition & 7 deletions m2cgen/interpreters/interpreter.py
Expand Up @@ -29,13 +29,7 @@ def _do_interpret(self, expr, to_reuse=None, **kwargs):
if result is not None:
return result

try:
handler = self._select_handler(expr)
except NotImplementedError:
if isinstance(expr, ast.TransparentExpr):
reuse = True if expr.to_reuse else None
return self._do_interpret(expr.expr, to_reuse=reuse, **kwargs)
raise
handler = self._select_handler(expr)

# Note that the reuse flag passed in the arguments has a higher
# precedence than one specified in the expression. One use case for
Expand Down
94 changes: 44 additions & 50 deletions tests/assemblers/test_ensemble.py
Expand Up @@ -14,16 +14,14 @@ def test_single_condition():

expected = ast.BinNumExpr(
ast.BinNumExpr(
ast.SubroutineExpr(
ast.NumVal(1.0)),
ast.SubroutineExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(0),
ast.NumVal(1.5),
ast.CompOpType.LTE),
ast.NumVal(1.0),
ast.NumVal(2.0))),
ast.NumVal(1.0),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(0),
ast.NumVal(1.5),
ast.CompOpType.LTE),
ast.NumVal(1.0),
ast.NumVal(2.0)),
ast.BinNumOpType.ADD),
ast.NumVal(0.5),
ast.BinNumOpType.MUL)
Expand All @@ -41,22 +39,20 @@ def test_two_conditions():

expected = ast.BinNumExpr(
ast.BinNumExpr(
ast.SubroutineExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(0),
ast.NumVal(1.5),
ast.CompOpType.LTE),
ast.NumVal(1.0),
ast.NumVal(2.0))),
ast.SubroutineExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(0),
ast.NumVal(2.5),
ast.CompOpType.LTE),
ast.NumVal(2.0),
ast.NumVal(3.0))),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(0),
ast.NumVal(1.5),
ast.CompOpType.LTE),
ast.NumVal(1.0),
ast.NumVal(2.0)),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(0),
ast.NumVal(2.5),
ast.CompOpType.LTE),
ast.NumVal(2.0),
ast.NumVal(3.0)),
ast.BinNumOpType.ADD),
ast.NumVal(0.5),
ast.BinNumOpType.MUL)
Expand All @@ -75,30 +71,28 @@ def test_multi_class():

expected = ast.BinVectorNumExpr(
ast.BinVectorExpr(
ast.SubroutineExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(0),
ast.NumVal(1.5),
ast.CompOpType.LTE),
ast.VectorVal([
ast.NumVal(0.0),
ast.NumVal(1.0)]),
ast.VectorVal([
ast.NumVal(1.0),
ast.NumVal(0.0)]))),
ast.SubroutineExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(0),
ast.NumVal(2.5),
ast.CompOpType.LTE),
ast.VectorVal([
ast.NumVal(1.0),
ast.NumVal(0.0)]),
ast.VectorVal([
ast.NumVal(0.0),
ast.NumVal(1.0)]))),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(0),
ast.NumVal(1.5),
ast.CompOpType.LTE),
ast.VectorVal([
ast.NumVal(0.0),
ast.NumVal(1.0)]),
ast.VectorVal([
ast.NumVal(1.0),
ast.NumVal(0.0)])),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(0),
ast.NumVal(2.5),
ast.CompOpType.LTE),
ast.VectorVal([
ast.NumVal(1.0),
ast.NumVal(0.0)]),
ast.VectorVal([
ast.NumVal(0.0),
ast.NumVal(1.0)])),
ast.BinNumOpType.ADD),
ast.NumVal(0.5),
ast.BinNumOpType.MUL)
Expand Down
129 changes: 59 additions & 70 deletions tests/assemblers/test_lightgbm.py
Expand Up @@ -19,28 +19,25 @@ def test_binary_classification():
ast.ExpExpr(
ast.BinNumExpr(
ast.NumVal(0),
ast.SubroutineExpr(
ast.BinNumExpr(
ast.BinNumExpr(
ast.BinNumExpr(
ast.NumVal(0),
ast.SubroutineExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(23),
ast.NumVal(868.2000000000002),
ast.CompOpType.GT),
ast.NumVal(0.25986931215073095),
ast.NumVal(0.6237178414050242))),
ast.BinNumOpType.ADD),
ast.SubroutineExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(7),
ast.NumVal(0.05142),
ast.CompOpType.GT),
ast.NumVal(-0.1909605544006228),
ast.NumVal(0.1293965108676673))),
ast.BinNumOpType.ADD)),
ast.NumVal(0),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(23),
ast.NumVal(868.2000000000002),
ast.CompOpType.GT),
ast.NumVal(0.25986931215073095),
ast.NumVal(0.6237178414050242)),
ast.BinNumOpType.ADD),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(7),
ast.NumVal(0.05142),
ast.CompOpType.GT),
ast.NumVal(-0.1909605544006228),
ast.NumVal(0.1293965108676673)),
ast.BinNumOpType.ADD),
ast.BinNumOpType.SUB)),
ast.BinNumOpType.ADD),
ast.BinNumOpType.DIV,
Expand All @@ -62,12 +59,10 @@ def test_multi_class():
actual = assembler.assemble()

exponent = ast.ExpExpr(
ast.SubroutineExpr(
ast.BinNumExpr(
ast.NumVal(0.0),
ast.SubroutineExpr(
ast.NumVal(-1.0986122886681098)),
ast.BinNumOpType.ADD)),
ast.BinNumExpr(
ast.NumVal(0.0),
ast.NumVal(-1.0986122886681098),
ast.BinNumOpType.ADD),
to_reuse=True)

exponent_sum = ast.BinNumExpr(
Expand All @@ -91,28 +86,25 @@ def test_regression():
assembler = assemblers.LightGBMModelAssembler(estimator)
actual = assembler.assemble()

expected = ast.SubroutineExpr(
expected = ast.BinNumExpr(
ast.BinNumExpr(
ast.BinNumExpr(
ast.NumVal(0),
ast.SubroutineExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.918),
ast.CompOpType.GT),
ast.NumVal(24.011454621684155),
ast.NumVal(22.289277544391084))),
ast.BinNumOpType.ADD),
ast.SubroutineExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(9.63),
ast.CompOpType.GT),
ast.NumVal(-0.49461212269771115),
ast.NumVal(0.7174324413014594))),
ast.BinNumOpType.ADD))
ast.NumVal(0),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.918),
ast.CompOpType.GT),
ast.NumVal(24.011454621684155),
ast.NumVal(22.289277544391084)),
ast.BinNumOpType.ADD),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(9.63),
ast.CompOpType.GT),
ast.NumVal(-0.49461212269771115),
ast.NumVal(0.7174324413014594)),
ast.BinNumOpType.ADD)

assert utils.cmp_exprs(actual, expected)

Expand All @@ -126,30 +118,27 @@ def test_regression_random_forest():
assembler = assemblers.LightGBMModelAssembler(estimator)
actual = assembler.assemble()

expected = ast.SubroutineExpr(
expected = ast.BinNumExpr(
ast.BinNumExpr(
ast.BinNumExpr(
ast.BinNumExpr(
ast.NumVal(0),
ast.SubroutineExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.954000000000001),
ast.CompOpType.GT),
ast.NumVal(37.24347877367631),
ast.NumVal(19.936999995530854))),
ast.BinNumOpType.ADD),
ast.SubroutineExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.971500000000001),
ast.CompOpType.GT),
ast.NumVal(38.48600037864964),
ast.NumVal(20.183783757300255))),
ast.NumVal(0),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.954000000000001),
ast.CompOpType.GT),
ast.NumVal(37.24347877367631),
ast.NumVal(19.936999995530854)),
ast.BinNumOpType.ADD),
ast.NumVal(0.5),
ast.BinNumOpType.MUL))
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.971500000000001),
ast.CompOpType.GT),
ast.NumVal(38.48600037864964),
ast.NumVal(20.183783757300255)),
ast.BinNumOpType.ADD),
ast.NumVal(0.5),
ast.BinNumOpType.MUL)

assert utils.cmp_exprs(actual, expected)

0 comments on commit baca974

Please sign in to comment.