Skip to content

Commit

Permalink
RandomForestClassifier for python (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
krinart authored and izeigerman committed Jan 30, 2019
1 parent 34aca6a commit 426c4b3
Show file tree
Hide file tree
Showing 19 changed files with 212 additions and 27 deletions.
5 changes: 2 additions & 3 deletions m2cgen/assemblers/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,15 @@


class RandomForestModelAssembler(ModelAssembler):
def __init__(self, model):
super().__init__(model)

def assemble(self):
coef = 1.0 / self.model.n_estimators
trees = self.model.estimators_

def assemble_tree_expr(t):
assembler = TreeModelAssembler(t)
return ast.BinNumExpr(

return utils.apply_bin_op(
ast.SubroutineExpr(assembler.assemble()),
ast.NumVal(coef),
ast.BinNumOpType.MUL)
Expand Down
2 changes: 1 addition & 1 deletion m2cgen/assemblers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def _build_ast(self):
for idx in range(coef.shape[0]):
exprs.append(ast.SubroutineExpr(
_linear_to_ast(coef[idx], intercept[idx])))
return ast.VectorExpr(exprs)
return ast.VectorVal(exprs)


def _linear_to_ast(coef, intercept):
Expand Down
2 changes: 1 addition & 1 deletion m2cgen/assemblers/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _assemble_leaf(self, node_id):
score_sum = scores.sum() or 1.0
for s in scores:
outputs.append(ast.NumVal(s / score_sum))
return ast.VectorExpr(outputs)
return ast.VectorVal(outputs)
else:
assert len(scores) == 1, "Unexpected number of outputs"
return ast.NumVal(scores[0])
Expand Down
29 changes: 26 additions & 3 deletions m2cgen/assemblers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,29 @@ def lte(l, r):
return ast.CompExpr(l, r, ast.CompOpType.LTE)


# Lookup table used by apply_bin_op. Keys are
# (left.is_vector_output, right.is_vector_output) pairs; values are the
# binary expression classes to instantiate for that combination.
# There is deliberately no (False, True) entry: apply_bin_op handles a
# missing key by swapping the operands and using ast.BinVectorNumExpr.
BIN_EXPR_CLASSES = {
    (False, False): ast.BinNumExpr,
    (True, True): ast.BinVectorExpr,
    (True, False): ast.BinVectorNumExpr,
}


def apply_bin_op(left, right, op):
    """
    Build the binary expression matching the operands' output kinds.

    The pair (left.is_vector_output, right.is_vector_output) is looked up
    in BIN_EXPR_CLASSES and the class found there is instantiated with
    (left, right, op). When no class is registered for the pair, the
    operands are exchanged and an ast.BinVectorNumExpr is built instead.

    NOTE(review): exchanging the operands keeps the meaning only for
    commutative operations - confirm that no non-commutative operation is
    ever passed with a scalar on the left and a vector on the right.
    """
    kinds = (left.is_vector_output, right.is_vector_output)
    if kinds in BIN_EXPR_CLASSES:
        return BIN_EXPR_CLASSES[kinds](left, right, op)

    # No class is registered for this combination: pass the operands in
    # reverse order to the vector-by-scalar expression.
    return ast.BinVectorNumExpr(right, left, op)


def apply_op_to_expressions(op, *exprs):
if len(exprs) < 2:
raise ValueError("At least two expressions are required")
Expand All @@ -18,10 +41,10 @@ def _inner(current_expr, *rest_exprs):
if not rest_exprs:
return current_expr

return _inner(ast.BinNumExpr(current_expr, rest_exprs[0], op),
*rest_exprs[1:])
return _inner(
apply_bin_op(current_expr, rest_exprs[0], op), *rest_exprs[1:])

return _inner(ast.BinNumExpr(exprs[0], exprs[1], op), *exprs[2:])
return _inner(apply_bin_op(exprs[0], exprs[1], op), *exprs[2:])


def to_1d_array(var):
Expand Down
41 changes: 38 additions & 3 deletions m2cgen/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,53 @@ def __str__(self):
return "BinNumExpr(" + args + ")"


class VectorExpr(NumExpr):
# Vector Expressions.

class VectorExpr(Expr):
is_vector_output = True


class VectorVal(VectorExpr):

def __init__(self, exprs):
assert all(map(lambda e: not e.is_vector_output, exprs)), (
"All expressions for VectorExpr must be scalar")
"All expressions for VectorVal must be scalar")

self.exprs = exprs

def __str__(self):
args = ",".join([str(e) for e in self.exprs])
return "VectorExpr([" + args + "])"
return "VectorVal([" + args + "])"


class BinVectorExpr(VectorExpr):
    """Binary operation applied to two vector-valued operands.

    Both operands must report a truthy ``is_vector_output``; this is
    enforced with assertions when the expression is constructed.
    """

    def __init__(self, left, right, op):
        # Check the operands in left-to-right order so the first offending
        # one triggers the assertion.
        for operand in (left, right):
            assert operand.is_vector_output, "Only vectors are supported"

        self.left, self.right, self.op = left, right, op

    def __str__(self):
        pieces = (str(self.left), str(self.right), self.op.name)
        return "BinVectorExpr(%s)" % ",".join(pieces)


class BinVectorNumExpr(VectorExpr):
    """Binary operation between a vector-valued and a scalar-valued operand.

    The left operand must report a truthy ``is_vector_output`` and the
    right operand a falsy one; both conditions are enforced with
    assertions when the expression is constructed.
    """

    def __init__(self, left, right, op):
        assert left.is_vector_output, "Only vectors are supported"
        assert not right.is_vector_output, "Only scalars are supported"

        self.left, self.right, self.op = left, right, op

    def __str__(self):
        pieces = (str(self.left), str(self.right), self.op.name)
        return "BinVectorNumExpr(%s)" % ",".join(pieces)


# Boolean Expressions.
Expand Down
1 change: 1 addition & 0 deletions m2cgen/exporters.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class BaseExporter:
"DecisionTreeRegressor": assemblers.TreeModelAssembler,
"DecisionTreeClassifier": assemblers.TreeModelAssembler,
"RandomForestRegressor": assemblers.RandomForestModelAssembler,
"RandomForestClassifier": assemblers.RandomForestModelAssembler,
}

def __init__(self, model):
Expand Down
3 changes: 3 additions & 0 deletions m2cgen/interpreters/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ def add_code_lines(self, lines):
for l in lines:
self.add_code_line(l)

def prepend_code_line(self, line):
    """Put ``line``, followed by a newline, in front of the code so far.

    The result is stored back into ``self.code``.
    """
    self.code = "\n".join((line, self.code))

# The following statements compute expressions using templates AND add
# them to the result.

Expand Down
4 changes: 2 additions & 2 deletions m2cgen/interpreters/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ def interpret_feature_ref(self, expr, **kwargs):
array_name=self._feature_array_name,
index=expr.index)

def interpret_vector_expr(self, expr, **kwargs):
def interpret_vector_val(self, expr, **kwargs):
nested = [self._do_interpret(expr, **kwargs) for expr in expr.exprs]
return self._cg.array_init(nested)
return self._cg.vector_init(nested)

# Private methods implementing visitor pattern

Expand Down
2 changes: 1 addition & 1 deletion m2cgen/interpreters/java/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,6 @@ def method_definition(self, name, args, is_vector_output,
def method_invocation(self, method_name, *args):
return method_name + "(" + ", ".join(args) + ")"

def array_init(self, values):
def vector_init(self, values):
return "new " + self.vector_output_type + (
" {" + ", ".join(values) + "}")
10 changes: 8 additions & 2 deletions m2cgen/interpreters/python/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,11 @@ def function_definition(self, name, args):
self.add_function_def(name, args)
yield

def array_init(self, values):
return "[" + ", ".join(values) + "]"
def vector_init(self, values):
return "np.asarray([" + ", ".join(values) + "])"

def add_dependency(self, dep, alias=None):
    """Prepend an ``import`` statement for ``dep`` to the generated code.

    When ``alias`` is truthy the statement takes the form
    ``import <dep> as <alias>``; otherwise it is ``import <dep>``.
    The line is added through the parent class's ``prepend_code_line``.
    """
    words = ["import " + dep]
    if alias:
        words.append("as " + alias)
    super().prepend_code_line(" ".join(words))
23 changes: 23 additions & 0 deletions m2cgen/interpreters/python/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

class PythonInterpreter(BaseInterpreter):

with_vectors = False

def __init__(self, indent=4, *args, **kwargs):
cg = PythonCodeGenerator(indent=indent)
super(PythonInterpreter, self).__init__(cg, *args, **kwargs)
Expand All @@ -17,6 +19,27 @@ def interpret(self, expr):
last_result = self._do_interpret(expr)
self._cg.add_return_statement(last_result)

if self.with_vectors:
self._cg.add_dependency("numpy", alias="np")

return [
("", self._cg.code),
]

def interpret_vector_val(self, expr, **kwargs):
    """Interpret a vector literal and record that vectors were used.

    Sets ``self.with_vectors`` and then delegates to the parent class
    implementation, returning whatever it returns.
    """
    # The flag is raised before delegating, matching the order in which
    # the parent implementation is invoked afterwards.
    self.with_vectors = True
    rendered = super().interpret_vector_val(expr, **kwargs)
    return rendered

def interpret_bin_vector_expr(self, expr, **kwargs):
    """Render a vector-by-vector binary operation as an infix expression.

    Sets ``self.with_vectors`` so that the interpreter knows vector
    support was needed, then interprets the left operand, reads the
    operator symbol from ``expr.op.value`` and interprets the right
    operand.

    Args:
        expr: expression exposing ``left``, ``right`` and ``op``.
        **kwargs: extra interpretation options; accepted for parity with
            the other ``interpret_*`` methods and forwarded unchanged to
            ``_do_interpret`` for both operands.

    Returns:
        Whatever the code generator's ``infix_expression`` returns.
    """
    self.with_vectors = True
    return self._cg.infix_expression(
        left=self._do_interpret(expr.left, **kwargs),
        op=expr.op.value,
        right=self._do_interpret(expr.right, **kwargs))

def interpret_bin_vector_num_expr(self, expr, **kwargs):
    """Render a vector-by-scalar binary operation as an infix expression.

    Sets ``self.with_vectors`` so that the interpreter knows vector
    support was needed, then interprets the left operand, reads the
    operator symbol from ``expr.op.value`` and interprets the right
    operand.

    Args:
        expr: expression exposing ``left``, ``right`` and ``op``.
        **kwargs: extra interpretation options; accepted for parity with
            the other ``interpret_*`` methods and forwarded unchanged to
            ``_do_interpret`` for both operands.

    Returns:
        Whatever the code generator's ``infix_expression`` returns.
    """
    self.with_vectors = True
    return self._cg.infix_expression(
        left=self._do_interpret(expr.left, **kwargs),
        op=expr.op.value,
        right=self._do_interpret(expr.right, **kwargs))
45 changes: 45 additions & 0 deletions tests/assemblers/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,48 @@ def test_two_conditions():
ast.BinNumOpType.ADD)

assert utils.cmp_exprs(actual, expected)


def test_multi_class():
    """A two-tree classifier forest assembles into a sum of weighted trees.

    Each tree is expected as a subroutine holding a single split on
    feature 0 with two-element vector leaves, multiplied by 0.5; the two
    weighted trees are then added together.
    """
    forest = ensemble.RandomForestClassifier(
        n_estimators=2, random_state=13)
    forest.fit([[1], [2], [3]], [1, -1, 1])

    actual = assemblers.RandomForestModelAssembler(forest).assemble()

    def leaf(first, second):
        # Two-element vector leaf.
        return ast.VectorVal([ast.NumVal(first), ast.NumVal(second)])

    def weighted_tree(threshold, if_true, if_false):
        # One tree: "feature 0 <= threshold" split, scaled by 0.5.
        split = ast.CompExpr(
            ast.FeatureRef(0),
            ast.NumVal(threshold),
            ast.CompOpType.LTE)
        return ast.BinVectorNumExpr(
            ast.SubroutineExpr(ast.IfExpr(split, if_true, if_false)),
            ast.NumVal(0.5),
            ast.BinNumOpType.MUL)

    expected = ast.BinVectorExpr(
        weighted_tree(1.5, leaf(0.0, 1.0), leaf(1.0, 0.0)),
        weighted_tree(2.5, leaf(1.0, 0.0), leaf(0.0, 1.0)),
        ast.BinNumOpType.ADD)

    assert utils.cmp_exprs(actual, expected)
2 changes: 1 addition & 1 deletion tests/assemblers/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_multi_class():
assembler = assemblers.LinearModelAssembler(estimator)
actual = assembler.assemble()

expected = ast.VectorExpr([
expected = ast.VectorVal([
ast.SubroutineExpr(
ast.BinNumExpr(
ast.BinNumExpr(
Expand Down
6 changes: 3 additions & 3 deletions tests/assemblers/test_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,18 @@ def test_multi_class():
ast.FeatureRef(0),
ast.NumVal(1.5),
ast.CompOpType.LTE),
ast.VectorExpr([
ast.VectorVal([
ast.NumVal(0.0),
ast.NumVal(1.0)]),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(0),
ast.NumVal(2.5),
ast.CompOpType.LTE),
ast.VectorExpr([
ast.VectorVal([
ast.NumVal(1.0),
ast.NumVal(0.0)]),
ast.VectorExpr([
ast.VectorVal([
ast.NumVal(0.0),
ast.NumVal(1.0)])))

Expand Down
File renamed without changes.
14 changes: 14 additions & 0 deletions tests/e2e/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,20 @@ def exec_e2e_test(estimator, executor_cls, model_trainer, is_fast):
utils.train_model_classification_binary,
marks=[PYTHON, CLASSIFICATION],
),
pytest.param(
ensemble.RandomForestClassifier(n_estimators=10,
random_state=RANDOM_SEED),
executors.PythonExecutor,
utils.train_model_classification_binary,
marks=[PYTHON, CLASSIFICATION],
),
pytest.param(
ensemble.RandomForestClassifier(n_estimators=10,
random_state=RANDOM_SEED),
executors.PythonExecutor,
utils.train_model_classification,
marks=[PYTHON, CLASSIFICATION],
),
])
def test_e2e(estimator, executor_cls, model_trainer, is_fast):
exec_e2e_test(estimator, executor_cls, model_trainer, is_fast)
4 changes: 2 additions & 2 deletions tests/interpreters/test_java.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,8 @@ def test_multi_output():
ast.NumVal(1),
ast.NumVal(1),
ast.CompOpType.EQ),
ast.VectorExpr([ast.NumVal(1), ast.NumVal(2)]),
ast.VectorExpr([ast.NumVal(3), ast.NumVal(4)])))
ast.VectorVal([ast.NumVal(1), ast.NumVal(2)]),
ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])))

expected_code = """
public class Model {
Expand Down
41 changes: 37 additions & 4 deletions tests/interpreters/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,17 +122,50 @@ def test_multi_output():
ast.NumVal(1),
ast.NumVal(1),
ast.CompOpType.EQ),
ast.VectorExpr([ast.NumVal(1), ast.NumVal(2)]),
ast.VectorExpr([ast.NumVal(3), ast.NumVal(4)])))
ast.VectorVal([ast.NumVal(1), ast.NumVal(2)]),
ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])))

expected_code = """
import numpy as np
def score(input):
if (1) == (1):
var0 = [1, 2]
var0 = np.asarray([1, 2])
else:
var0 = [3, 4]
var0 = np.asarray([3, 4])
return var0
"""

interpreter = interpreters.PythonInterpreter()
utils.assert_code_equal(interpreter.interpret(expr)[0][1], expected_code)


def test_bin_vector_expr():
    """Multiplying two vector literals is rendered as an infix expression.

    The expected output also starts with the numpy import line.
    """
    expr = ast.BinVectorExpr(
        ast.VectorVal([ast.NumVal(1), ast.NumVal(2)]),
        ast.VectorVal([ast.NumVal(3), ast.NumVal(4)]),
        ast.BinNumOpType.MUL)

    interpreter = interpreters.PythonInterpreter()

    # NOTE(review): the literal below is reproduced exactly as it appears
    # in this view; leading whitespace inside it may have been lost in
    # extraction - confirm against the repository before relying on it.
    expected_code = """
import numpy as np
def score(input):
return (np.asarray([1, 2])) * (np.asarray([3, 4]))
"""
    # interpret() returns a list of tuples; the generated code is the
    # second item of the first tuple.
    utils.assert_code_equal(interpreter.interpret(expr)[0][1], expected_code)


def test_bin_vector_num_expr():
    """Multiplying a vector literal by a scalar is rendered as infix code.

    The expected output also starts with the numpy import line.
    """
    expr = ast.BinVectorNumExpr(
        ast.VectorVal([ast.NumVal(1), ast.NumVal(2)]),
        ast.NumVal(1),
        ast.BinNumOpType.MUL)

    interpreter = interpreters.PythonInterpreter()

    # NOTE(review): the literal below is reproduced exactly as it appears
    # in this view; leading whitespace inside it may have been lost in
    # extraction - confirm against the repository before relying on it.
    expected_code = """
import numpy as np
def score(input):
return (np.asarray([1, 2])) * (1)
"""
    # interpret() returns a list of tuples; the generated code is the
    # second item of the first tuple.
    utils.assert_code_equal(interpreter.interpret(expr)[0][1], expected_code)

0 comments on commit 426c4b3

Please sign in to comment.