Skip to content

Commit

Permalink
Merge 51a7713 into 5d9c3a2
Browse files Browse the repository at this point in the history
  • Loading branch information
krinart committed Jan 29, 2019
2 parents 5d9c3a2 + 51a7713 commit 19fee6f
Show file tree
Hide file tree
Showing 14 changed files with 162 additions and 123 deletions.
27 changes: 12 additions & 15 deletions m2cgen/assemblers/linear.py
@@ -1,4 +1,3 @@
import numpy as np
from m2cgen import ast
from m2cgen.assemblers import utils
from m2cgen.assemblers.base import ModelAssembler
Expand All @@ -10,25 +9,23 @@ def assemble(self):
return self._build_ast()

def _build_ast(self):
coef = self.model.coef_
intercept = self.model.intercept_
if isinstance(coef, np.ndarray) and len(coef.shape) == 2:
if coef.shape[0] == 1:
return _linear_to_ast(coef[0], intercept[0])
else:
exprs = []
for idx in range(coef.shape[0]):
exprs.append(ast.SubroutineExpr(
_linear_to_ast(coef[idx], intercept[idx])))
return ast.ArrayExpr(exprs)
else:
return _linear_to_ast(coef, intercept)
coef = utils.to_2d_array(self.model.coef_)
intercept = utils.to_1d_array(self.model.intercept_)

if coef.shape[0] == 1:
return _linear_to_ast(coef[0], intercept[0])

exprs = []
for idx in range(coef.shape[0]):
exprs.append(ast.SubroutineExpr(
_linear_to_ast(coef[idx], intercept[idx])))
return ast.VectorExpr(exprs)


def _linear_to_ast(coef, intercept):
feature_weight_mul_ops = []

for (index, value) in enumerate(coef):
for index, value in enumerate(coef):
feature_weight_mul_ops.append(
utils.mul(ast.FeatureRef(index), ast.NumVal(value)))

Expand Down
8 changes: 4 additions & 4 deletions m2cgen/assemblers/tree.py
Expand Up @@ -12,9 +12,9 @@ class TreeModelAssembler(ModelAssembler):
def __init__(self, model):
super().__init__(model)
self._tree = model.tree_
self._is_multi_output = False
self._is_vector_output = False
if isinstance(self.model, tree.DecisionTreeClassifier):
self._is_multi_output = self.model.n_classes_ > 1
self._is_vector_output = self.model.n_classes_ > 1

def assemble(self):
return self._assemble_node(0)
Expand All @@ -35,11 +35,11 @@ def _assemble_branch(self, node_id):

def _assemble_leaf(self, node_id):
scores = self._tree.value[node_id][0]
if self._is_multi_output:
if self._is_vector_output:
outputs = []
for s in scores:
outputs.append(ast.NumVal(s))
return ast.ArrayExpr(outputs)
return ast.VectorExpr(outputs)
else:
assert len(scores) == 1, "Unexpected number of outputs"
return ast.NumVal(scores[0])
Expand Down
13 changes: 13 additions & 0 deletions m2cgen/assemblers/utils.py
@@ -1,3 +1,4 @@
import numpy as np
from m2cgen import ast


Expand All @@ -21,3 +22,15 @@ def _inner(current_expr, *rest_exprs):
*rest_exprs[1:])

return _inner(ast.BinNumExpr(exprs[0], exprs[1], op), *exprs[2:])


def to_1d_array(var):
    """Return ``var`` flattened into a one-dimensional NumPy array.

    ``var`` may be a scalar, a (nested) sequence or an existing array; a
    scalar becomes a single-element array.
    """
    flat_source = np.asarray(var)
    # Reshape to the total element count so every input rank collapses to 1-D.
    return np.reshape(flat_source, flat_source.size)


def to_2d_array(var):
    """Return ``var`` as a two-dimensional NumPy array.

    Input that is already two-dimensional keeps its shape. Any other input
    (scalar, 1-D sequence, higher-rank array) is laid out as a single row
    containing all of its elements.

    The input is converted with ``np.asarray`` before its shape is read, so
    nested lists and tuples are handled the same way as ``ndarray`` inputs
    (plain sequences have no ``.shape`` attribute).
    """
    arr = np.asarray(var)
    if arr.ndim == 2:
        rows, cols = arr.shape
    else:
        rows, cols = 1, arr.size
    return np.reshape(arr, (rows, cols))
33 changes: 14 additions & 19 deletions m2cgen/ast.py
Expand Up @@ -2,7 +2,7 @@


class Expr:
is_multi_output = False
is_vector_output = False


class FeatureRef(Expr):
Expand Down Expand Up @@ -36,8 +36,8 @@ class BinNumOpType(Enum):

class BinNumExpr(NumExpr):
def __init__(self, left, right, op):
assert not left.is_multi_output, "Only scalars are supported"
assert not right.is_multi_output, "Only scalars are supported"
assert not left.is_vector_output, "Only scalars are supported"
assert not right.is_vector_output, "Only scalars are supported"

self.left = left
self.right = right
Expand All @@ -48,18 +48,18 @@ def __str__(self):
return "BinNumExpr(" + args + ")"


class ArrayExpr(NumExpr):
is_multi_output = True
class VectorExpr(NumExpr):
is_vector_output = True

def __init__(self, exprs):
assert all(map(lambda e: not e.is_multi_output, exprs)), (
"All expressions for ArrayExpr must be scalar")
assert all(map(lambda e: not e.is_vector_output, exprs)), (
"All expressions for VectorExpr must be scalar")

self.exprs = exprs

def __str__(self):
args = ",".join([str(e) for e in self.exprs])
return "ArrayExpr([" + args + "])"
return "VectorExpr([" + args + "])"


# Boolean Expressions.
Expand All @@ -79,8 +79,8 @@ class CompOpType(Enum):

class CompExpr(BoolExpr):
def __init__(self, left, right, op):
assert not left.is_multi_output, "Only scalars are supported"
assert not right.is_multi_output, "Only scalars are supported"
assert not left.is_vector_output, "Only scalars are supported"
assert not right.is_vector_output, "Only scalars are supported"

self.left = left
self.right = right
Expand All @@ -99,14 +99,14 @@ class CtrlExpr(Expr):

class IfExpr(CtrlExpr):
def __init__(self, test, body, orelse):
assert not (body.is_multi_output ^ orelse.is_multi_output), (
"body and orelse expressions should have same is_multi_output")
assert not (body.is_vector_output ^ orelse.is_vector_output), (
"body and orelse expressions should have same is_vector_output")

self.test = test
self.body = body
self.orelse = orelse

self.is_multi_output = body.is_multi_output
self.is_vector_output = body.is_vector_output

def __str__(self):
args = ",".join([str(self.test), str(self.body), str(self.orelse)])
Expand All @@ -116,15 +116,10 @@ def __str__(self):
class TransparentExpr(CtrlExpr):
def __init__(self, expr):
self.expr = expr
self.is_multi_output = expr.is_multi_output
self.is_vector_output = expr.is_vector_output


class SubroutineExpr(TransparentExpr):

def __str__(self):
return "SubroutineExpr(" + str(self.expr) + ")"


class MainExpr(SubroutineExpr):
def __str__(self):
return "MainExpr(" + str(self.expr) + ")"
16 changes: 15 additions & 1 deletion m2cgen/interpreters/code_generator.py
Expand Up @@ -26,6 +26,9 @@ class BaseCodeGenerator:
tpl_block_termination = NotImplemented
tpl_var_assignment = NotImplemented

scalar_output_type = NotImplemented
vector_output_type = NotImplemented

def __init__(self, indent=4):
self._indent = indent
self.reset_state()
Expand Down Expand Up @@ -66,8 +69,9 @@ def add_code_lines(self, lines):
def add_return_statement(self, value):
self.add_code_line(self.tpl_return_statement(value=value))

def add_var_declaration(self, var_type="double"):
def add_var_declaration(self, is_vector_type=False):
var_name = self.get_var_name()
var_type = self._get_var_type(is_vector_type)
self.add_code_line(
self.tpl_var_declaration(
var_type=var_type, var_name=var_name))
Expand Down Expand Up @@ -103,6 +107,13 @@ def array_index_access(self, array_name, index):
return self.tpl_array_index_access(
array_name=array_name, index=index)

# Helpers

def _get_var_type(self, is_vector):
return (
self.vector_output_type if is_vector
else self.scalar_output_type)


class CLikeCodeGenerator(BaseCodeGenerator):
"""
Expand All @@ -119,3 +130,6 @@ class CLikeCodeGenerator(BaseCodeGenerator):
tpl_else_statement = CodeTemplate("} else {")
tpl_block_termination = CodeTemplate("}")
tpl_var_assignment = CodeTemplate("${var_name} = ${value};")

scalar_output_type = "double"
vector_output_type = "double[]"
18 changes: 6 additions & 12 deletions m2cgen/interpreters/interpreter.py
Expand Up @@ -14,27 +14,21 @@ def interpret(self, expr):

# Default method implementations

def interpret_if_expr(self, expr, if_var_name=None,
is_multi_output=False, **kwargs):
def interpret_if_expr(self, expr, if_var_name=None, **kwargs):
if if_var_name is not None:
var_name = if_var_name
else:
var_type = "double[]" if is_multi_output else "double"
var_name = self._cg.add_var_declaration(var_type=var_type)

if_def = self._do_interpret(expr.test, **kwargs)
self._cg.add_if_statement(if_def)
var_name = self._cg.add_var_declaration(
is_vector_type=expr.is_vector_output)

def handle_nested_expr(nested):
if isinstance(nested, ast.IfExpr):
self._do_interpret(nested, if_var_name=var_name,
is_multi_output=is_multi_output,
**kwargs)
self._do_interpret(nested, if_var_name=var_name, **kwargs)
else:
nested_result = self._do_interpret(
nested, is_multi_output=is_multi_output)
nested_result = self._do_interpret(nested)
self._cg.add_var_assignment(var_name, nested_result)

self._cg.add_if_statement(self._do_interpret(expr.test, **kwargs))
handle_nested_expr(expr.body)
self._cg.add_else_statement()
handle_nested_expr(expr.orelse)
Expand Down
18 changes: 12 additions & 6 deletions m2cgen/interpreters/java/code_generator.py
Expand Up @@ -13,9 +13,13 @@ def add_class_def(self, class_name, modifier="public"):
self.add_code_line(class_def)
self.increase_indent()

def add_method_def(self, name, args, return_type, modifier="public"):
def add_method_def(self, name, args, is_vector_output, modifier="public"):
return_type = self._get_var_type(is_vector_output)

method_def = modifier + " static " + return_type + " " + name + "("
method_def += ",".join([t + " " + n for t, n in args])
method_def += ",".join([
self._get_var_type(is_vector) + " " + n
for is_vector, n in args])
method_def += ") {"
self.add_code_line(method_def)
self.increase_indent()
Expand All @@ -31,13 +35,15 @@ def class_definition(self, model_name):
self.add_block_termination()

@contextlib.contextmanager
def method_definition(self, name, args, return_type, modifier="public"):
self.add_method_def(name, args, return_type, modifier=modifier)
def method_definition(self, name, args, is_vector_output,
modifier="public"):
self.add_method_def(name, args, is_vector_output, modifier=modifier)
yield
self.add_block_termination()

def method_invocation(self, method_name, *args):
    """Build the source text of a call to ``method_name``.

    The positional ``args`` are already-rendered argument expressions
    (strings); they are separated by ", " inside the parentheses.
    """
    rendered_args = ", ".join(args)
    return "".join([method_name, "(", rendered_args, ")"])

def array_init(self, values, arr_type="double"):
return "new " + arr_type + "[] " + "{" + ", ".join(values) + "}"
def array_init(self, values):
return "new " + self.vector_output_type + (
" {" + ", ".join(values) + "}")
52 changes: 24 additions & 28 deletions m2cgen/interpreters/java/interpreter.py
@@ -1,6 +1,6 @@
from m2cgen.interpreters.interpreter import BaseInterpreter
from m2cgen.interpreters.java.code_generator import JavaCodeGenerator
from m2cgen import ast

from collections import namedtuple


Expand All @@ -20,6 +20,8 @@ def __init__(self, package_name=None, model_name="Model", indent=4,
super(JavaInterpreter, self).__init__(cg, *args, **kwargs)

def interpret(self, expr):
self._subroutine_expr_queue = [Subroutine("score", expr)]

self._subroutine_idx = 0

top_cg = self._create_code_generator()
Expand All @@ -28,58 +30,52 @@ def interpret(self, expr):
top_cg.add_package_name(self.package_name)

with top_cg.class_definition(self.model_name):
if isinstance(expr, ast.SubroutineExpr):
self._subroutine_expr_queue = []
self._do_interpret(expr)
else:
self._subroutine_expr_queue = [
Subroutine("score", expr)
]

while len(self._subroutine_expr_queue) > 0:
self._process_next_subroutine()
top_cg.add_code_lines(self._cg.code)
while len(self._subroutine_expr_queue):
subroutine_code = self._process_next_subroutine()
top_cg.add_code_lines(subroutine_code)

return [
(self.model_name, top_cg.code),
]

def interpret_subroutine_expr(self, expr, **kwargs):
method_name = self._get_subroutine_name()
return self._enqueue_subroutine(method_name, expr)

def interpret_main_expr(self, expr, **kwargs):
return self._enqueue_subroutine("score", expr)

def interpret_array_expr(self, expr, **kwargs):
def interpret_vector_expr(self, expr, **kwargs):
nested = []
for e in expr.exprs:
nested.append(self._do_interpret(e, **kwargs))
return self._cg.array_init(nested)

def _create_code_generator(self):
return JavaCodeGenerator(indent=self.indent)

# Methods to support ast.SubroutineExpr

def interpret_subroutine_expr(self, expr, **kwargs):
method_name = self._get_subroutine_name()
return self._enqueue_subroutine(method_name, expr)

def _enqueue_subroutine(self, name, expr):
self._subroutine_expr_queue.append(Subroutine(name, expr.expr))
return self._cg.method_invocation(name, self._feature_array_name)

def _process_next_subroutine(self):
subroutine = self._subroutine_expr_queue.pop(0)
is_multi_output = subroutine.expr.is_multi_output
return_type = "double[]" if is_multi_output else "double"
is_vector_output = subroutine.expr.is_vector_output

self._cg = self._create_code_generator()

with self._cg.method_definition(
name=subroutine.name,
args=[("double[]", self._feature_array_name)],
return_type=return_type):
args=[
(True, self._feature_array_name)],
is_vector_output=is_vector_output):
last_result = self._do_interpret(
subroutine.expr,
is_multi_output=is_multi_output)
is_vector_output=is_vector_output)
self._cg.add_return_statement(last_result)

return self._cg.code

def _get_subroutine_name(self):
subroutine_name = "subroutine" + str(self._subroutine_idx)
self._subroutine_idx += 1
return subroutine_name

def _create_code_generator(self):
return JavaCodeGenerator(indent=self.indent)
2 changes: 1 addition & 1 deletion tests/assemblers/test_linear.py
Expand Up @@ -57,7 +57,7 @@ def test_multi_class():
assembler = assemblers.LinearModelAssembler(estimator)
actual = assembler.assemble()

expected = ast.ArrayExpr([
expected = ast.VectorExpr([
ast.SubroutineExpr(
ast.BinNumExpr(
ast.BinNumExpr(
Expand Down

0 comments on commit 19fee6f

Please sign in to comment.