Skip to content

Commit

Permalink
Merge 6a37a0f into 2d0105b
Browse files Browse the repository at this point in the history
  • Loading branch information
izeigerman committed Jan 31, 2019
2 parents 2d0105b + 6a37a0f commit 05fe3bc
Show file tree
Hide file tree
Showing 9 changed files with 67 additions and 87 deletions.
15 changes: 4 additions & 11 deletions examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.utils import shuffle

from m2cgen import exporters
import m2cgen as m2c


def train_model(estimator):
Expand All @@ -29,35 +29,28 @@ def train_model(estimator):
print("Variance score: %.2f" % r2_score(y_test, y_pred))


def print_model(defined_classes):
print(defined_classes[0][1])


def example_linear():
estimator = linear_model.LinearRegression()
train_model(estimator)

print("Coef", estimator.coef_)
print("Intercept", estimator.intercept_)

exporter = exporters.PythonExporter(estimator)
print_model(exporter.export())
print(m2c.export_to_python(estimator))


def example_tree():
estimator = tree.DecisionTreeRegressor()
train_model(estimator)

exporter = exporters.PythonExporter(estimator)
print_model(exporter.export())
print(m2c.export_to_python(estimator))


def example_random_forest():
estimator = ensemble.RandomForestRegressor(n_estimators=10)
train_model(estimator)

exporter = exporters.JavaExporter(estimator)
print_model(exporter.export())
print(m2c.export_to_java(estimator))


if __name__ == "__main__":
Expand Down
7 changes: 7 additions & 0 deletions m2cgen/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .exporters import export_to_java, export_to_python


__all__ = [
export_to_java,
export_to_python,
]
65 changes: 28 additions & 37 deletions m2cgen/exporters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,50 +2,41 @@
from m2cgen import interpreters


class BaseExporter:
SUPPORTED_MODELS = {
"LinearRegression": assemblers.LinearModelAssembler,
"LogisticRegression": assemblers.LinearModelAssembler,
"DecisionTreeRegressor": assemblers.TreeModelAssembler,
"DecisionTreeClassifier": assemblers.TreeModelAssembler,
"RandomForestRegressor": assemblers.RandomForestModelAssembler,
"RandomForestClassifier": assemblers.RandomForestModelAssembler,
}

interpreter = None

models_to_assemblers = {
"LinearRegression": assemblers.LinearModelAssembler,
"LogisticRegression": assemblers.LinearModelAssembler,
"DecisionTreeRegressor": assemblers.TreeModelAssembler,
"DecisionTreeClassifier": assemblers.TreeModelAssembler,
"RandomForestRegressor": assemblers.RandomForestModelAssembler,
"RandomForestClassifier": assemblers.RandomForestModelAssembler,
}
def export_to_java(model, package_name=None, model_name="Model", indent=4):
interpreter = interpreters.JavaInterpreter(
package_name=package_name,
model_name=model_name,
indent=indent)
return _export(model, interpreter)

def __init__(self, model):
self.model = model
self.assembler = self._get_assembler_cls(type(model).__name__)(model)
assert self.interpreter, "interpreter is required"

def _get_assembler_cls(self, model_name):
assembler_cls = self.models_to_assemblers.get(model_name)
def export_to_python(model, indent=4):
interpreter = interpreters.PythonInterpreter(indent=indent)
return _export(model, interpreter)

if not assembler_cls:
raise NotImplementedError(
"Model {} is not supported".format(model_name))

return assembler_cls
def _export(model, interpreter):
assembler_cls = _get_assembler_cls(model)
model_ast = assembler_cls(model).assemble()
return interpreter.interpret(model_ast)

def export(self):
model_ast = self.assembler.assemble()
return self.interpreter.interpret(model_ast)

def _get_assembler_cls(model):
model_name = type(model).__name__
assembler_cls = SUPPORTED_MODELS.get(model_name)

class JavaExporter(BaseExporter):
if not assembler_cls:
raise NotImplementedError(
"Model {} is not supported".format(model_name))

def __init__(self, model, package_name=None, model_name="Model", indent=4):
self.interpreter = interpreters.JavaInterpreter(
package_name=package_name,
model_name=model_name,
indent=indent)
super(JavaExporter, self).__init__(model)


class PythonExporter(BaseExporter):

def __init__(self, model, indent=4):
self.interpreter = interpreters.PythonInterpreter(indent=indent)
super(PythonExporter, self).__init__(model)
return assembler_cls
4 changes: 1 addition & 3 deletions m2cgen/interpreters/java/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@ def interpret(self, expr):
os.path.dirname(__file__), "linear_algebra.java")
top_cg.add_code_lines(utils.get_file_content(filename))

return [
(self.model_name, top_cg.code),
]
return top_cg.code

def _create_code_generator(self):
return JavaCodeGenerator(indent=self.indent)
Expand Down
4 changes: 1 addition & 3 deletions m2cgen/interpreters/python/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@ def interpret(self, expr):
if self.with_numpy:
self._cg.add_dependency("numpy", alias="np")

return [
("", self._cg.code),
]
return self._cg.code

def interpret_vector_val(self, expr, **kwargs):
self.with_numpy = True
Expand Down
19 changes: 8 additions & 11 deletions tests/e2e/executors/java.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
import subprocess
import shutil

from m2cgen import exporters
import m2cgen as m2c
from tests.e2e.executors import base


class JavaExecutor(base.BaseExecutor):

def __init__(self, model):
self.model = model
self.exporter = exporters.JavaExporter(model)
self.model_name = "Model"

java_home = os.environ.get("JAVA_HOME")
assert java_home, "JAVA_HOME is not specified"
Expand All @@ -33,22 +33,19 @@ def predict(self, X):

def prepare(self):
# Create files generated by exporter in the temp dir.
files_to_compile = []
for model_name, code in self.exporter.export():
file_name = os.path.join(self._resource_tmp_dir,
"{}.java".format(model_name))
code = m2c.export_to_java(self.model, model_name=self.model_name)
code_file_name = os.path.join(self._resource_tmp_dir,
"{}.java".format(self.model_name))

with open(file_name, "w") as f:
f.write(code)

files_to_compile.append(file_name)
with open(code_file_name, "w") as f:
f.write(code)

# Move Executor.java to the same temp dir.
module_path = os.path.dirname(__file__)
shutil.copy(os.path.join(module_path, "Executor.java"),
self._resource_tmp_dir)

# Compile all files together.
exec_args = [self._javac_bin] + files_to_compile + (
exec_args = [self._javac_bin, code_file_name] + (
[os.path.join(self._resource_tmp_dir, "Executor.java")])
subprocess.call(exec_args)
8 changes: 2 additions & 6 deletions tests/e2e/executors/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
import os
import sys

from m2cgen import exporters
import m2cgen as m2c
from tests.e2e.executors import base


class PythonExecutor(base.BaseExecutor):

def __init__(self, model):
self.model = model
self.exporter = exporters.PythonExporter(model)

def predict(self, X):
# Hacky way to dynamically import generated function
Expand All @@ -29,10 +28,7 @@ def predict(self, X):
return score(X.tolist())

def prepare(self):
exported_models = self.exporter.export()
assert len(exported_models) == 1

_, code = exported_models[0]
code = m2c.export_to_python(self.model)

file_name = os.path.join(self._resource_tmp_dir, "model.py")

Expand Down
18 changes: 9 additions & 9 deletions tests/interpreters/test_java.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_if_expr():
}
"""

utils.assert_code_equal(interpreter.interpret(expr)[0][1], expected_code)
utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_bin_num_expr():
Expand All @@ -46,7 +46,7 @@ def test_bin_num_expr():
}
}"""

utils.assert_code_equal(interpreter.interpret(expr)[0][1], expected_code)
utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_dependable_condition():
Expand Down Expand Up @@ -87,7 +87,7 @@ def test_dependable_condition():

interpreter = interpreters.JavaInterpreter()

utils.assert_code_equal(interpreter.interpret(expr)[0][1], expected_code)
utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_nested_condition():
Expand Down Expand Up @@ -138,7 +138,7 @@ def test_nested_condition():
}"""

interpreter = interpreters.JavaInterpreter()
utils.assert_code_equal(interpreter.interpret(expr)[0][1], expected_code)
utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_package_name():
Expand All @@ -156,7 +156,7 @@ def test_package_name():
}
}"""

utils.assert_code_equal(interpreter.interpret(expr)[0][1], expected_code)
utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_subroutine():
Expand All @@ -179,7 +179,7 @@ def test_subroutine():
}"""

interpreter = interpreters.JavaInterpreter()
utils.assert_code_equal(interpreter.interpret(expr)[0][1], expected_code)
utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_multi_output():
Expand Down Expand Up @@ -210,7 +210,7 @@ def test_multi_output():
}"""

interpreter = interpreters.JavaInterpreter()
utils.assert_code_equal(interpreter.interpret(expr)[0][1], expected_code)
utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_bin_vector_expr():
Expand Down Expand Up @@ -242,7 +242,7 @@ def test_bin_vector_expr():
return result;
}
}"""
utils.assert_code_equal(interpreter.interpret(expr)[0][1], expected_code)
utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_bin_vector_num_expr():
Expand Down Expand Up @@ -274,4 +274,4 @@ def test_bin_vector_num_expr():
return result;
}
}"""
utils.assert_code_equal(interpreter.interpret(expr)[0][1], expected_code)
utils.assert_code_equal(interpreter.interpret(expr), expected_code)
14 changes: 7 additions & 7 deletions tests/interpreters/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def score(input):
return var0
"""

utils.assert_code_equal(interpreter.interpret(expr)[0][1], expected_code)
utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_bin_num_expr():
Expand All @@ -37,7 +37,7 @@ def score(input):
return ((input[0]) / (-2)) * (2)
"""

utils.assert_code_equal(interpreter.interpret(expr)[0][1], expected_code)
utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_dependable_condition():
Expand Down Expand Up @@ -71,7 +71,7 @@ def score(input):

interpreter = interpreters.PythonInterpreter()

utils.assert_code_equal(interpreter.interpret(expr)[0][1], expected_code)
utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_nested_condition():
Expand Down Expand Up @@ -112,7 +112,7 @@ def score(input):
"""

interpreter = interpreters.PythonInterpreter()
utils.assert_code_equal(interpreter.interpret(expr)[0][1], expected_code)
utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_multi_output():
Expand All @@ -136,7 +136,7 @@ def score(input):
"""

interpreter = interpreters.PythonInterpreter()
utils.assert_code_equal(interpreter.interpret(expr)[0][1], expected_code)
utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_bin_vector_expr():
Expand All @@ -152,7 +152,7 @@ def test_bin_vector_expr():
def score(input):
return (np.asarray([1, 2])) * (np.asarray([3, 4]))
"""
utils.assert_code_equal(interpreter.interpret(expr)[0][1], expected_code)
utils.assert_code_equal(interpreter.interpret(expr), expected_code)


def test_bin_vector_num_expr():
Expand All @@ -168,4 +168,4 @@ def test_bin_vector_num_expr():
def score(input):
return (np.asarray([1, 2])) * (1)
"""
utils.assert_code_equal(interpreter.interpret(expr)[0][1], expected_code)
utils.assert_code_equal(interpreter.interpret(expr), expected_code)

0 comments on commit 05fe3bc

Please sign in to comment.