Skip to content

Commit

Permalink
Merge 482a4c9 into 2090a0a
Browse files Browse the repository at this point in the history
  • Loading branch information
izeigerman committed Jan 28, 2020
2 parents 2090a0a + 482a4c9 commit 7652955
Show file tree
Hide file tree
Showing 33 changed files with 375 additions and 224 deletions.
5 changes: 3 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ python:

env:
- TEST=API
- TEST=E2E LANG="c or python or java or go or javascript or r_lang"
- TEST=E2E LANG="c or python or java or go or javascript"
- TEST=E2E LANG="c_sharp or visual_basic or powershell or php"
- TEST=E2E LANG="r_lang"

before_install:
- bash .travis/setup.sh
Expand All @@ -18,4 +19,4 @@ install:
- pip install -r requirements-test.txt

script:
- bash .travis/test.sh
- travis_wait 30 bash .travis/test.sh
70 changes: 1 addition & 69 deletions m2cgen/assemblers/boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,51 +104,11 @@ def _assemble_estimators(self, trees, split_idx):
if self._tree_limit:
trees = trees[:self._tree_limit]

trees_ast = [ast.SubroutineExpr(self._assemble_tree(t)) for t in trees]

# In a large tree we need to generate multiple subroutines to avoid
# java limitations https://github.com/BayesWitnesses/m2cgen/issues/103.
trees_num_leaves = [self._count_leaves(t) for t in trees]
if sum(trees_num_leaves) > self._leaves_cutoff_threshold:
return self._split_into_subroutines(trees_ast, trees_num_leaves)
else:
return trees_ast

def _split_into_subroutines(self, trees_ast, trees_num_leaves):
    """Group tree expressions into summed subroutines of bounded size.

    Trees are taken in order together with their leaf counts. Consecutive
    trees are accumulated into a group until adding the next tree would
    push the group's leaf total above ``self._leaves_cutoff_threshold``;
    each finished group is summed with ``ADD`` and wrapped into an
    ``ast.SubroutineExpr``. Returns the list of those subroutines.
    """
    def close_group(group):
        # Sum every tree of the group and wrap the sum into one subroutine.
        summed = utils.apply_op_to_expressions(
            ast.BinNumOpType.ADD,
            *group)
        return ast.SubroutineExpr(summed)

    subroutines = []
    pending_trees = []
    pending_leaves = 0
    for tree_expr, leaf_count in zip(trees_ast, trees_num_leaves):
        projected = pending_leaves + leaf_count
        # A group is only closed when it already holds at least one tree,
        # so a single oversized tree still gets a subroutine of its own.
        if pending_trees and projected > self._leaves_cutoff_threshold:
            subroutines.append(close_group(pending_trees))
            pending_trees = []
            pending_leaves = 0

        pending_trees.append(tree_expr)
        pending_leaves += leaf_count

    # Flush whatever is left after the last tree.
    if pending_trees:
        subroutines.append(close_group(pending_trees))
    return subroutines
return [ast.SubroutineExpr(self._assemble_tree(t)) for t in trees]

def _assemble_tree(self, tree):
raise NotImplementedError

@staticmethod
def _count_leaves(trees):
raise NotImplementedError


class XGBoostTreeModelAssembler(BaseTreeBoostingAssembler):

Expand Down Expand Up @@ -204,20 +164,6 @@ def _assemble_child_tree(self, tree, child_id):
return self._assemble_tree(child)
assert False, "Unexpected child ID {}".format(child_id)

@staticmethod
def _count_leaves(tree):
queue = [tree]
num_leaves = 0

while queue:
tree = queue.pop()
if "leaf" in tree:
num_leaves += 1
elif "children" in tree:
for child in tree["children"]:
queue.append(child)
return num_leaves


class XGBoostLinearModelAssembler(BaseBoostingAssembler):

Expand Down Expand Up @@ -299,20 +245,6 @@ def _assemble_tree(self, tree):
self._assemble_tree(true_child),
self._assemble_tree(false_child))

@staticmethod
def _count_leaves(tree):
queue = [tree]
num_leaves = 0

while queue:
tree = queue.pop()
if "leaf_value" in tree:
num_leaves += 1
else:
queue.append(tree["left_child"])
queue.append(tree["right_child"])
return num_leaves


def _split_estimator_params_by_classes(values, n_classes):
# Splits are computed based on a comment
Expand Down
13 changes: 6 additions & 7 deletions m2cgen/assemblers/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,16 @@
class RandomForestModelAssembler(ModelAssembler):

def assemble(self):
coef = 1.0 / self.model.n_estimators
trees = self.model.estimators_

def assemble_tree_expr(t):
assembler = TreeModelAssembler(t)

return utils.apply_bin_op(
ast.SubroutineExpr(assembler.assemble()),
ast.NumVal(coef),
ast.BinNumOpType.MUL)
return ast.SubroutineExpr(assembler.assemble())

assembled_trees = [assemble_tree_expr(t) for t in trees]
return utils.apply_op_to_expressions(
ast.BinNumOpType.ADD, *assembled_trees)
return utils.apply_bin_op(
utils.apply_op_to_expressions(ast.BinNumOpType.ADD,
*assembled_trees),
ast.NumVal(self.model.n_estimators),
ast.BinNumOpType.DIV)
42 changes: 41 additions & 1 deletion m2cgen/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __str__(self):
return "TanhExpr(" + args + ")"


class PowExpr(NumExpr, BinExpr):
class PowExpr(NumExpr):
def __init__(self, base_expr, exp_expr, to_reuse=False):
assert base_expr.output_size == 1, "Only scalars are supported"
assert exp_expr.output_size == 1, "Only scalars are supported"
Expand Down Expand Up @@ -230,3 +230,43 @@ def __init__(self, expr, to_reuse=False):
def __str__(self):
args = ",".join([str(self.expr), "to_reuse=" + str(self.to_reuse)])
return "SubroutineExpr(" + args + ")"


def count_exprs(expr, exclude_list=None):
    """Return the number of expressions found in the ``expr`` subtree.

    Each visited node adds one to the total unless its type is a subclass
    of an entry in ``exclude_list``; the children of an excluded node are
    still visited. A ``TransparentExpr`` never adds to the total itself,
    only its wrapped expression is counted.

    Raises ``ValueError`` when ``expr`` is of a type not handled here.
    """
    excluded_types = exclude_list or {}
    is_excluded = any(issubclass(type(expr), t) for t in excluded_types)
    own_count = 0 if is_excluded else 1

    def count_children(*children):
        # The original exclude_list is forwarded unchanged to every child.
        return sum(count_exprs(child, exclude_list) for child in children)

    if isinstance(expr, (NumVal, FeatureRef)):
        return own_count

    if isinstance(expr, (ExpExpr, TanhExpr)):
        return count_children(expr.expr) + own_count

    if isinstance(expr, PowExpr):
        return count_children(expr.base_expr, expr.exp_expr) + own_count

    two_operand_types = (
        BinNumExpr, BinVectorExpr, BinVectorNumExpr, CompExpr)
    if isinstance(expr, two_operand_types):
        return count_children(expr.left, expr.right) + own_count

    if isinstance(expr, VectorVal):
        return count_children(*expr.exprs) + own_count

    if isinstance(expr, IfExpr):
        return count_children(
            expr.test, expr.body, expr.orelse) + own_count

    if isinstance(expr, TransparentExpr):
        return count_children(expr.expr)

    raise ValueError(
        "Unexpected expression type {}".format(type(expr).__name__))
1 change: 1 addition & 0 deletions m2cgen/interpreters/c/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class CInterpreter(ToCodeInterpreter,

supported_bin_vector_num_ops = {
ast.BinNumOpType.MUL: "mul_vector_number",
ast.BinNumOpType.DIV: "div_vector_number",
}

exponent_function_name = "exp"
Expand Down
4 changes: 4 additions & 0 deletions m2cgen/interpreters/c/linear_algebra.c
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,8 @@ void add_vectors(double *v1, double *v2, int size, double *result) {
/* Writes v1[i] * num into result[i] for every i in [0, size). */
void mul_vector_number(double *v1, double num, int size, double *result) {
    int idx = 0;
    while (idx < size) {
        result[idx] = v1[idx] * num;
        ++idx;
    }
}
/* Writes v1[i] / num into result[i] for every i in [0, size). */
void div_vector_number(double *v1, double num, int size, double *result) {
    int idx = 0;
    while (idx < size) {
        result[idx] = v1[idx] / num;
        ++idx;
    }
}
1 change: 1 addition & 0 deletions m2cgen/interpreters/c_sharp/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class CSharpInterpreter(ToCodeInterpreter, mixins.LinearAlgebraMixin):

supported_bin_vector_num_ops = {
ast.BinNumOpType.MUL: "MulVectorNumber",
ast.BinNumOpType.DIV: "DivVectorNumber",
}

exponent_function_name = "Exp"
Expand Down
7 changes: 7 additions & 0 deletions m2cgen/interpreters/c_sharp/linear_algebra.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,10 @@
}
return result;
}
// Returns a new array whose i-th element is v1[i] / num.
private static double[] DivVectorNumber(double[] v1, double num) {
    int length = v1.Length;
    double[] quotients = new double[length];
    int idx = 0;
    while (idx < length) {
        quotients[idx] = v1[idx] / num;
        ++idx;
    }
    return quotients;
}
1 change: 1 addition & 0 deletions m2cgen/interpreters/go/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class GoInterpreter(ToCodeInterpreter,

supported_bin_vector_num_ops = {
ast.BinNumOpType.MUL: "mulVectorNumber",
ast.BinNumOpType.DIV: "divVectorNumber",
}

exponent_function_name = "math.Exp"
Expand Down
9 changes: 8 additions & 1 deletion m2cgen/interpreters/go/linear_algebra.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,11 @@ func mulVectorNumber(v1 []float64, num float64) []float64 {
result[i] = v1[i] * num
}
return result
}
}
// divVectorNumber returns a new slice whose i-th element is v1[i] / num.
func divVectorNumber(v1 []float64, num float64) []float64 {
	quotients := make([]float64, len(v1))
	for idx, value := range v1 {
		quotients[idx] = value / num
	}
	return quotients
}
38 changes: 37 additions & 1 deletion m2cgen/interpreters/java/interpreter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import math

from m2cgen import ast
from m2cgen.interpreters import mixins
Expand All @@ -9,14 +10,19 @@

class JavaInterpreter(ToCodeInterpreter,
mixins.LinearAlgebraMixin,
mixins.SubroutinesAsFunctionsMixin):
mixins.SubroutinesAsFunctionsMixin,
mixins.BinExpressionDepthTrackingMixin):

bin_depth_threshold = 100
ast_size_per_subroutine_threshold = 4600

supported_bin_vector_ops = {
ast.BinNumOpType.ADD: "addVectors",
}

supported_bin_vector_num_ops = {
ast.BinNumOpType.MUL: "mulVectorNumber",
ast.BinNumOpType.DIV: "divVectorNumber",
}

exponent_function_name = "Math.exp"
Expand Down Expand Up @@ -55,7 +61,37 @@ def interpret(self, expr):

return top_cg.code

def interpret_subroutine_expr(self, expr, **kwargs):
    """Interpret the expression wrapped by the subroutine node ``expr``.

    The wrapper itself is not translated: its inner ``expr.expr`` is
    passed to ``_do_interpret`` along with the received keyword arguments.
    """
    wrapped = expr.expr
    return self._do_interpret(wrapped, **kwargs)

# Required by SubroutinesAsFunctionsMixin to create a new code generator
# for each subroutine.
def create_code_generator(self):
    """Return a fresh ``JavaCodeGenerator`` using this interpreter's indent."""
    generator = JavaCodeGenerator(indent=self.indent)
    return generator

def bin_depth_threshold_hook(self, expr, **kwargs):
    """Interpret ``expr`` inline or move it into a separate subroutine.

    When ``ast.count_exprs(expr)`` does not exceed
    ``ast_size_per_subroutine_threshold`` the expression is interpreted in
    place. Otherwise it is enqueued as a new subroutine and the code that
    invokes that subroutine with the feature array is returned.
    """
    # Sanity check: small expressions are not worth a subroutine.
    if ast.count_exprs(expr) <= self.ast_size_per_subroutine_threshold:
        return self._do_interpret(expr, **kwargs)

    subroutine_name = self._get_subroutine_name()
    self.enqueue_subroutine(subroutine_name, expr)
    return self._cg.function_invocation(
        subroutine_name, self._feature_array_name)

def _pre_interpret_hook(self, expr, **kwargs):
    """Tighten ``bin_depth_threshold`` for binary expressions, then delegate.

    For an ``ast.BinExpr`` the threshold is replaced by the smaller of its
    current value and the one computed by ``_calc_bin_depth_threshold``.
    The parent class hook is invoked in every case and its result returned.
    """
    if not isinstance(expr, ast.BinExpr):
        return super()._pre_interpret_hook(expr, **kwargs)

    # The stored threshold is only ever kept or lowered here.
    self.bin_depth_threshold = min(
        self._calc_bin_depth_threshold(expr), self.bin_depth_threshold)
    return super()._pre_interpret_hook(expr, **kwargs)

def _calc_bin_depth_threshold(self, expr):
    """Derive a depth threshold from the size of a non-binary operand.

    The first operand of ``expr`` (left, then right) that is not an
    ``ast.BinExpr`` is measured with ``ast.count_exprs`` while excluding
    ``ast.BinExpr`` nodes from the count. If that size is non-zero and
    below ``ast_size_per_subroutine_threshold``, the result is the ceiling
    of the threshold divided by the size; otherwise the current
    ``bin_depth_threshold`` is returned unchanged.
    """
    if not isinstance(expr.left, ast.BinExpr):
        operand_size = ast.count_exprs(
            expr.left, exclude_list={ast.BinExpr})
    elif not isinstance(expr.right, ast.BinExpr):
        operand_size = ast.count_exprs(
            expr.right, exclude_list={ast.BinExpr})
    else:
        # Both operands are binary expressions: nothing to measure.
        operand_size = None

    if not operand_size:
        return self.bin_depth_threshold
    if operand_size >= self.ast_size_per_subroutine_threshold:
        return self.bin_depth_threshold
    return math.ceil(self.ast_size_per_subroutine_threshold / operand_size)
10 changes: 9 additions & 1 deletion m2cgen/interpreters/java/linear_algebra.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ public static double[] addVectors(double[] v1, double[] v2) {

return result;
}

public static double[] mulVectorNumber(double[] v1, double num) {
double[] result = new double[v1.length];

Expand All @@ -17,3 +16,12 @@ public static double[] mulVectorNumber(double[] v1, double num) {

return result;
}
/** Returns a new array whose i-th element is {@code v1[i] / num}. */
public static double[] divVectorNumber(double[] v1, double num) {
    final int length = v1.length;
    double[] quotients = new double[length];

    int idx = 0;
    while (idx < length) {
        quotients[idx] = v1[idx] / num;
        idx++;
    }

    return quotients;
}
1 change: 1 addition & 0 deletions m2cgen/interpreters/javascript/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class JavascriptInterpreter(ToCodeInterpreter,

supported_bin_vector_num_ops = {
ast.BinNumOpType.MUL: "mulVectorNumber",
ast.BinNumOpType.DIV: "divVectorNumber",
}

exponent_function_name = "Math.exp"
Expand Down
10 changes: 10 additions & 0 deletions m2cgen/interpreters/javascript/linear_algebra.js
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,13 @@ function mulVectorNumber(v1, num) {

return result;
}

/** Returns a new array whose i-th element is v1[i] / num. */
function divVectorNumber(v1, num) {
    const size = v1.length;
    const quotients = new Array(size);

    let idx = 0;
    while (idx < size) {
        quotients[idx] = v1[idx] / num;
        idx += 1;
    }

    return quotients;
}
2 changes: 1 addition & 1 deletion m2cgen/interpreters/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def _pre_interpret_hook(self, expr, bin_depth=0, **kwargs):

# We track the depth of binary expressions and call a hook once it
# reaches the specified threshold.
if bin_depth == self.bin_depth_threshold:
if bin_depth >= self.bin_depth_threshold:
return self.bin_depth_threshold_hook(expr, **kwargs), kwargs

kwargs["bin_depth"] = bin_depth + 1
Expand Down
1 change: 1 addition & 0 deletions m2cgen/interpreters/php/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ class PhpInterpreter(ToCodeInterpreter, mixins.LinearAlgebraMixin):

supported_bin_vector_num_ops = {
ast.BinNumOpType.MUL: "mul_vector_number",
ast.BinNumOpType.DIV: "div_vector_number",
}

exponent_function_name = "exp"
Expand Down
7 changes: 7 additions & 0 deletions m2cgen/interpreters/php/linear_algebra.php
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,10 @@ function mul_vector_number(array $v1, $num) {
}
return $result;
}
// Returns a new array holding $v1[$i] / $num for each index $i from 0 to count($v1) - 1.
function div_vector_number(array $v1, $num) {
    $quotients = array();
    $size = count($v1);
    $i = 0;
    while ($i < $size) {
        $quotients[] = $v1[$i] / $num;
        ++$i;
    }
    return $quotients;
}
1 change: 1 addition & 0 deletions m2cgen/interpreters/powershell/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class PowershellInterpreter(ToCodeInterpreter,

supported_bin_vector_num_ops = {
ast.BinNumOpType.MUL: "Mul-Vector-Number",
ast.BinNumOpType.DIV: "Div-Vector-Number",
}

exponent_function_name = "[math]::Exp"
Expand Down

0 comments on commit 7652955

Please sign in to comment.