Skip to content

Commit

Permalink
Merge e8b673d into 2090a0a
Browse files Browse the repository at this point in the history
  • Loading branch information
izeigerman committed Jan 29, 2020
2 parents 2090a0a + e8b673d commit 666ca5d
Show file tree
Hide file tree
Showing 11 changed files with 177 additions and 221 deletions.
5 changes: 3 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ python:

env:
- TEST=API
- TEST=E2E LANG="c or python or java or go or javascript or r_lang"
- TEST=E2E LANG="c_sharp or visual_basic or powershell or php"
- TEST=E2E LANG="c or python or java or go or javascript or php"
- TEST=E2E LANG="c_sharp or visual_basic or powershell"
- TEST=E2E LANG="r_lang"

before_install:
- bash .travis/setup.sh
Expand Down
70 changes: 1 addition & 69 deletions m2cgen/assemblers/boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,51 +104,11 @@ def _assemble_estimators(self, trees, split_idx):
if self._tree_limit:
trees = trees[:self._tree_limit]

trees_ast = [ast.SubroutineExpr(self._assemble_tree(t)) for t in trees]

# For a large ensemble we need to generate multiple subroutines to avoid
# Java limitations, see https://github.com/BayesWitnesses/m2cgen/issues/103.
trees_num_leaves = [self._count_leaves(t) for t in trees]
if sum(trees_num_leaves) > self._leaves_cutoff_threshold:
return self._split_into_subroutines(trees_ast, trees_num_leaves)
else:
return trees_ast

def _split_into_subroutines(self, trees_ast, trees_num_leaves):
    """Group consecutive tree expressions into summed subroutines.

    Trees are taken in order and packed greedily: a batch is closed as
    soon as adding the next tree would push its total leaf count above
    ``self._leaves_cutoff_threshold``. A batch always holds at least one
    tree, even when that single tree exceeds the threshold on its own.

    :param trees_ast: assembled expressions, one per tree.
    :param trees_num_leaves: leaf counts, parallel to ``trees_ast``.
    :return: list of ``ast.SubroutineExpr``, each wrapping the ADD-chain
        of the trees in one batch.
    """
    batches = []
    current_batch = []
    leaves_in_batch = 0

    for tree_expr, leaf_count in zip(trees_ast, trees_num_leaves):
        # The threshold is only consulted for a non-empty batch, so an
        # oversized tree still gets a batch of its own.
        if (current_batch and
                leaves_in_batch + leaf_count >
                self._leaves_cutoff_threshold):
            batches.append(current_batch)
            current_batch = []
            leaves_in_batch = 0

        current_batch.append(tree_expr)
        leaves_in_batch += leaf_count

    # Flush whatever is left after the last tree.
    if current_batch:
        batches.append(current_batch)

    return [
        ast.SubroutineExpr(
            utils.apply_op_to_expressions(ast.BinNumOpType.ADD, *batch))
        for batch in batches
    ]
return [ast.SubroutineExpr(self._assemble_tree(t)) for t in trees]

# Abstract hook: concrete assemblers override this to turn a single tree of
# the ensemble into an expression. The base class provides no implementation.
def _assemble_tree(self, tree):
raise NotImplementedError

# Abstract hook: concrete assemblers override this to return the number of
# leaves in one tree. The base class provides no implementation.
# NOTE(review): the parameter is called `trees` here, but it is invoked once
# per tree and the overrides name it `tree` - consider renaming.
@staticmethod
def _count_leaves(trees):
raise NotImplementedError


class XGBoostTreeModelAssembler(BaseTreeBoostingAssembler):

Expand Down Expand Up @@ -204,20 +164,6 @@ def _assemble_child_tree(self, tree, child_id):
return self._assemble_tree(child)
assert False, "Unexpected child ID {}".format(child_id)

@staticmethod
def _count_leaves(tree):
queue = [tree]
num_leaves = 0

while queue:
tree = queue.pop()
if "leaf" in tree:
num_leaves += 1
elif "children" in tree:
for child in tree["children"]:
queue.append(child)
return num_leaves


class XGBoostLinearModelAssembler(BaseBoostingAssembler):

Expand Down Expand Up @@ -299,20 +245,6 @@ def _assemble_tree(self, tree):
self._assemble_tree(true_child),
self._assemble_tree(false_child))

@staticmethod
def _count_leaves(tree):
queue = [tree]
num_leaves = 0

while queue:
tree = queue.pop()
if "leaf_value" in tree:
num_leaves += 1
else:
queue.append(tree["left_child"])
queue.append(tree["right_child"])
return num_leaves


def _split_estimator_params_by_classes(values, n_classes):
# Splits are computed based on a comment
Expand Down
13 changes: 6 additions & 7 deletions m2cgen/assemblers/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,16 @@
class RandomForestModelAssembler(ModelAssembler):

def assemble(self):
coef = 1.0 / self.model.n_estimators
trees = self.model.estimators_

def assemble_tree_expr(t):
assembler = TreeModelAssembler(t)

return utils.apply_bin_op(
ast.SubroutineExpr(assembler.assemble()),
ast.NumVal(coef),
ast.BinNumOpType.MUL)
return ast.SubroutineExpr(assembler.assemble())

assembled_trees = [assemble_tree_expr(t) for t in trees]
return utils.apply_op_to_expressions(
ast.BinNumOpType.ADD, *assembled_trees)
return utils.apply_bin_op(
utils.apply_op_to_expressions(ast.BinNumOpType.ADD,
*assembled_trees),
ast.NumVal(1 / self.model.n_estimators),
ast.BinNumOpType.MUL)
32 changes: 31 additions & 1 deletion m2cgen/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __str__(self):
return "TanhExpr(" + args + ")"


class PowExpr(NumExpr, BinExpr):
class PowExpr(NumExpr):
def __init__(self, base_expr, exp_expr, to_reuse=False):
assert base_expr.output_size == 1, "Only scalars are supported"
assert exp_expr.output_size == 1, "Only scalars are supported"
Expand Down Expand Up @@ -230,3 +230,33 @@ def __init__(self, expr, to_reuse=False):
def __str__(self):
args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)])
return "SubroutineExpr(" + args + ")"


# Dispatch table used to walk an expression tree: each entry pairs an
# expression class (or tuple of classes) with a function returning the direct
# sub-expressions of an instance of that class. Entries are tried in order
# and the first matching one is used, so order matters for classes related
# by inheritance.
NESTED_EXPRS_MAPPINGS = [
((BinExpr, CompExpr), lambda e: [e.left, e.right]),
(PowExpr, lambda e: [e.base_expr, e.exp_expr]),
(VectorVal, lambda e: e.exprs),
(IfExpr, lambda e: [e.test, e.body, e.orelse]),
((ExpExpr, TanhExpr, TransparentExpr), lambda e: [e.expr]),
]


def count_exprs(expr, exclude_list=None):
    """Count the expressions contained in ``expr``, including itself.

    The tree is walked with an explicit stack instead of recursion, so
    very deep (e.g. long left-leaning binary) expressions cannot hit the
    interpreter's recursion limit. The exclusion tuple is built once
    rather than at every node.

    :param expr: root expression to measure.
    :param exclude_list: optional iterable of expression classes whose
        instances are not counted themselves; their sub-expressions are
        still visited. ``TransparentExpr`` instances are never counted.
    :return: number of counted expressions.
    :raises ValueError: if a node is neither a ``NumVal``/``FeatureRef``
        terminal nor covered by ``NESTED_EXPRS_MAPPINGS``.
    """
    excluded = tuple(exclude_list) if exclude_list else ()
    # Transparent wrappers never contribute to the size of the tree.
    uncounted = excluded + (TransparentExpr,)

    total = 0
    pending = [expr]
    while pending:
        node = pending.pop()

        if not isinstance(node, uncounted):
            total += 1

        # Terminals have no sub-expressions to visit.
        if isinstance(node, (NumVal, FeatureRef)):
            continue

        for tpes, nested_f in NESTED_EXPRS_MAPPINGS:
            if isinstance(node, tpes):
                pending.extend(nested_f(node))
                break
        else:
            expr_tpe_name = type(node).__name__
            raise ValueError(
                "Unexpected expression type {}".format(expr_tpe_name))

    return total
43 changes: 42 additions & 1 deletion m2cgen/interpreters/java/interpreter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import math

from m2cgen import ast
from m2cgen.interpreters import mixins
Expand All @@ -9,7 +10,13 @@

class JavaInterpreter(ToCodeInterpreter,
mixins.LinearAlgebraMixin,
mixins.SubroutinesAsFunctionsMixin):
mixins.SubroutinesAsFunctionsMixin,
mixins.BinExpressionDepthTrackingMixin):

# The numbers below have been determined experimentally and are subject
# to adjustment in the future.
bin_depth_threshold = 100
ast_size_per_subroutine_threshold = 4600

supported_bin_vector_ops = {
ast.BinNumOpType.ADD: "addVectors",
Expand Down Expand Up @@ -55,7 +62,41 @@ def interpret(self, expr):

return top_cg.code

def interpret_subroutine_expr(self, expr, **kwargs):
    """Interpret a subroutine node by interpreting the expression it wraps.

    The wrapped ``expr.expr`` is handed straight to ``_do_interpret``
    together with all keyword arguments; the result is returned unchanged.
    """
    wrapped = expr.expr
    return self._do_interpret(wrapped, **kwargs)

# Required by SubroutinesAsFunctionsMixin to create a new code generator for
# each subroutine. Every generator is constructed with this interpreter's
# indent setting.
def create_code_generator(self):
return JavaCodeGenerator(indent=self.indent)

def bin_depth_threshold_hook(self, expr, **kwargs):
    """Handle an expression that sits at the binary-depth threshold.

    Expressions larger than ``ast_size_per_subroutine_threshold`` are
    queued as a new subroutine and replaced by a call to it; anything
    smaller is interpreted in place.

    :return: the code produced for ``expr`` (either a function invocation
        or the directly interpreted expression).
    """
    # Sanity check: a small expression is not worth a separate subroutine.
    if ast.count_exprs(expr) <= self.ast_size_per_subroutine_threshold:
        return self._do_interpret(expr, **kwargs)

    subroutine_name = self._get_subroutine_name()
    self.enqueue_subroutine(subroutine_name, expr)
    return self._cg.function_invocation(
        subroutine_name, self._feature_array_name)

def _pre_interpret_hook(self, expr, **kwargs):
    """Tighten the binary-depth threshold before delegating to the mixins.

    For binary expressions the threshold suggested by
    ``_calc_bin_depth_threshold`` is compared with the current one and the
    smaller value is stored, so the threshold can only decrease.
    """
    if isinstance(expr, ast.BinExpr):
        proposed = self._calc_bin_depth_threshold(expr)
        current = self.bin_depth_threshold
        # NOTE(review): this writes to the instance, so a lowered threshold
        # presumably stays in effect for later expressions handled by the
        # same interpreter object - confirm that is intended.
        self.bin_depth_threshold = min(proposed, current)
    return super()._pre_interpret_hook(expr, **kwargs)

def _calc_bin_depth_threshold(self, expr):
    """Suggest a binary-depth threshold for the binary expression ``expr``.

    The first operand (left, then right) that is not itself a binary
    expression is measured with ``ast.count_exprs`` (binary expressions
    excluded from the count). This accounts for large tree-like models:
    when that operand's size is non-zero and below
    ``ast_size_per_subroutine_threshold``, the suggestion is the ceiling
    of the size budget divided by the operand size. In every other case
    the current ``bin_depth_threshold`` is returned unchanged.
    """
    branch = expr.left
    if isinstance(branch, ast.BinExpr):
        branch = expr.right
        if isinstance(branch, ast.BinExpr):
            # Both operands are binary expressions: nothing to measure.
            return self.bin_depth_threshold

    size = ast.count_exprs(branch, exclude_list={ast.BinExpr})
    if not size or size >= self.ast_size_per_subroutine_threshold:
        return self.bin_depth_threshold
    return math.ceil(self.ast_size_per_subroutine_threshold / size)
2 changes: 1 addition & 1 deletion m2cgen/interpreters/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def _pre_interpret_hook(self, expr, bin_depth=0, **kwargs):

# We track the depth of the binary expressions and call a hook once it
# reaches the specified threshold.
if bin_depth == self.bin_depth_threshold:
if bin_depth >= self.bin_depth_threshold:
return self.bin_depth_threshold_hook(expr, **kwargs), kwargs

kwargs["bin_depth"] = bin_depth + 1
Expand Down
31 changes: 11 additions & 20 deletions tests/assemblers/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@ def test_single_condition():
ast.BinNumExpr(
ast.SubroutineExpr(
ast.NumVal(1.0)),
ast.NumVal(0.5),
ast.BinNumOpType.MUL),
ast.BinNumExpr(
ast.SubroutineExpr(
ast.IfExpr(
ast.CompExpr(
Expand All @@ -27,9 +24,9 @@ def test_single_condition():
ast.CompOpType.LTE),
ast.NumVal(1.0),
ast.NumVal(2.0))),
ast.NumVal(0.5),
ast.BinNumOpType.MUL),
ast.BinNumOpType.ADD)
ast.BinNumOpType.ADD),
ast.NumVal(0.5),
ast.BinNumOpType.MUL)

assert utils.cmp_exprs(actual, expected)

Expand All @@ -52,9 +49,6 @@ def test_two_conditions():
ast.CompOpType.LTE),
ast.NumVal(1.0),
ast.NumVal(2.0))),
ast.NumVal(0.5),
ast.BinNumOpType.MUL),
ast.BinNumExpr(
ast.SubroutineExpr(
ast.IfExpr(
ast.CompExpr(
Expand All @@ -63,9 +57,9 @@ def test_two_conditions():
ast.CompOpType.LTE),
ast.NumVal(2.0),
ast.NumVal(3.0))),
ast.NumVal(0.5),
ast.BinNumOpType.MUL),
ast.BinNumOpType.ADD)
ast.BinNumOpType.ADD),
ast.NumVal(0.5),
ast.BinNumOpType.MUL)

assert utils.cmp_exprs(actual, expected)

Expand All @@ -79,8 +73,8 @@ def test_multi_class():
assembler = assemblers.RandomForestModelAssembler(estimator)
actual = assembler.assemble()

expected = ast.BinVectorExpr(
ast.BinVectorNumExpr(
expected = ast.BinVectorNumExpr(
ast.BinVectorExpr(
ast.SubroutineExpr(
ast.IfExpr(
ast.CompExpr(
Expand All @@ -93,9 +87,6 @@ def test_multi_class():
ast.VectorVal([
ast.NumVal(1.0),
ast.NumVal(0.0)]))),
ast.NumVal(0.5),
ast.BinNumOpType.MUL),
ast.BinVectorNumExpr(
ast.SubroutineExpr(
ast.IfExpr(
ast.CompExpr(
Expand All @@ -108,8 +99,8 @@ def test_multi_class():
ast.VectorVal([
ast.NumVal(0.0),
ast.NumVal(1.0)]))),
ast.NumVal(0.5),
ast.BinNumOpType.MUL),
ast.BinNumOpType.ADD)
ast.BinNumOpType.ADD),
ast.NumVal(0.5),
ast.BinNumOpType.MUL)

assert utils.cmp_exprs(actual, expected)
52 changes: 0 additions & 52 deletions tests/assemblers/test_lightgbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,58 +117,6 @@ def test_regression():
assert utils.cmp_exprs(actual, expected)


# Trains a two-estimator, depth-1 LightGBM binary classifier and assembles it
# with leaves_cutoff_threshold=1, then compares the assembled AST with the
# hand-built expected expression below.
def test_leaves_cutoff_threshold():
estimator = lightgbm.LGBMClassifier(n_estimators=2, random_state=1,
max_depth=1)
utils.train_model_classification_binary(estimator)

assembler = assemblers.LightGBMModelAssembler(estimator,
leaves_cutoff_threshold=1)
actual = assembler.assemble()

# Expected sigmoid: 1 / (1 + exp(0 - <sum of the two trees>)), where each
# tree is wrapped in two nested SubroutineExpr nodes and the whole sum in
# another one. The expression is marked for reuse.
sigmoid = ast.BinNumExpr(
ast.NumVal(1),
ast.BinNumExpr(
ast.NumVal(1),
ast.ExpExpr(
ast.BinNumExpr(
ast.NumVal(0),
ast.SubroutineExpr(
ast.BinNumExpr(
ast.BinNumExpr(
ast.NumVal(0),
ast.SubroutineExpr(
ast.SubroutineExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(23),
ast.NumVal(868.2000000000002),
ast.CompOpType.GT),
ast.NumVal(0.25986931215073095),
ast.NumVal(0.6237178414050242)))),
ast.BinNumOpType.ADD),
ast.SubroutineExpr(
ast.SubroutineExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(7),
ast.NumVal(0.05142),
ast.CompOpType.GT),
ast.NumVal(-0.1909605544006228),
ast.NumVal(0.1293965108676673)))),
ast.BinNumOpType.ADD)),
ast.BinNumOpType.SUB)),
ast.BinNumOpType.ADD),
ast.BinNumOpType.DIV,
to_reuse=True)

# Expected output vector: [1 - sigmoid, sigmoid].
expected = ast.VectorVal([
ast.BinNumExpr(ast.NumVal(1), sigmoid, ast.BinNumOpType.SUB),
sigmoid])

assert utils.cmp_exprs(actual, expected)


def test_regression_random_forest():
estimator = lightgbm.LGBMRegressor(boosting_type="rf", n_estimators=2,
random_state=1, max_depth=1,
Expand Down

0 comments on commit 666ca5d

Please sign in to comment.