-
Notifications
You must be signed in to change notification settings - Fork 241
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introduce Dart language support (#167)
- Loading branch information
1 parent
3086e9a
commit 3c431f2
Showing
18 changed files
with
687 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
include LICENSE | ||
recursive-include m2cgen VERSION.txt | ||
recursive-include m2cgen linear_algebra.* | ||
recursive-include m2cgen tanh.bas | ||
recursive-include m2cgen tanh.* | ||
global-exclude *.py[cod] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import contextlib | ||
|
||
from m2cgen.interpreters.code_generator import CLikeCodeGenerator | ||
|
||
|
||
class DartCodeGenerator(CLikeCodeGenerator): | ||
|
||
scalar_type = "double" | ||
vector_type = "List<double>" | ||
|
||
def __init__(self, *args, **kwargs): | ||
super(DartCodeGenerator, self).__init__(*args, **kwargs) | ||
|
||
def add_function_def(self, name, args, is_vector_output): | ||
return_type = self._get_var_declare_type(is_vector_output) | ||
function_def = return_type + " " + name + "(" | ||
function_def += ",".join([ | ||
self._get_var_declare_type(is_vector) + " " + n | ||
for is_vector, n in args]) | ||
function_def += ") {" | ||
self.add_code_line(function_def) | ||
self.increase_indent() | ||
|
||
@contextlib.contextmanager | ||
def function_definition(self, name, args, is_vector_output): | ||
self.add_function_def(name, args, is_vector_output) | ||
yield | ||
self.add_block_termination() | ||
|
||
def vector_init(self, values): | ||
return "[" + ", ".join(values) + "]" | ||
|
||
def _get_var_declare_type(self, is_vector): | ||
return ( | ||
self.vector_type if is_vector | ||
else self.scalar_type) | ||
|
||
def add_dependency(self, dep): | ||
dep_str = "import '" + dep + "';" | ||
self.prepend_code_line(dep_str) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import os | ||
|
||
from m2cgen import ast | ||
from m2cgen.interpreters import mixins | ||
from m2cgen.interpreters import utils | ||
from m2cgen.interpreters.interpreter import ToCodeInterpreter | ||
from m2cgen.interpreters.dart.code_generator import DartCodeGenerator | ||
|
||
|
||
class DartInterpreter(ToCodeInterpreter, | ||
mixins.LinearAlgebraMixin, | ||
mixins.BinExpressionDepthTrackingMixin): | ||
|
||
supported_bin_vector_ops = { | ||
ast.BinNumOpType.ADD: "addVectors", | ||
} | ||
|
||
supported_bin_vector_num_ops = { | ||
ast.BinNumOpType.MUL: "mulVectorNumber", | ||
} | ||
|
||
bin_depth_threshold = 465 | ||
|
||
exponent_function_name = "exp" | ||
power_function_name = "pow" | ||
tanh_function_name = "tanh" | ||
|
||
with_tanh_expr = False | ||
|
||
def __init__(self, indent=4, function_name="score", *args, **kwargs): | ||
self.indent = indent | ||
self.function_name = function_name | ||
|
||
cg = DartCodeGenerator(indent=indent) | ||
super(DartInterpreter, self).__init__(cg, *args, **kwargs) | ||
|
||
def interpret(self, expr): | ||
self._cg.reset_state() | ||
self._reset_reused_expr_cache() | ||
|
||
args = [(True, self._feature_array_name)] | ||
|
||
with self._cg.function_definition( | ||
name=self.function_name, | ||
args=args, | ||
is_vector_output=expr.output_size > 1): | ||
last_result = self._do_interpret(expr) | ||
self._cg.add_return_statement(last_result) | ||
|
||
if self.with_linear_algebra: | ||
filename = os.path.join( | ||
os.path.dirname(__file__), "linear_algebra.dart") | ||
self._cg.add_code_lines(utils.get_file_content(filename)) | ||
|
||
# Use own tanh function in order to be compatible with Dart | ||
if self.with_tanh_expr: | ||
filename = os.path.join( | ||
os.path.dirname(__file__), "tanh.dart") | ||
self._cg.add_code_lines(utils.get_file_content(filename)) | ||
|
||
if self.with_math_module: | ||
self._cg.add_dependency("dart:math") | ||
|
||
return self._cg.code | ||
|
||
def interpret_tanh_expr(self, expr, **kwargs): | ||
self.with_tanh_expr = True | ||
return super( | ||
DartInterpreter, self).interpret_tanh_expr(expr, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
List<double> addVectors(List<double> v1, List<double> v2) { | ||
List<double> result = new List<double>(v1.length); | ||
for (int i = 0; i < v1.length; i++) { | ||
result[i] = v1[i] + v2[i]; | ||
} | ||
return result; | ||
} | ||
List<double> mulVectorNumber(List<double> v1, double num) { | ||
List<double> result = new List<double>(v1.length); | ||
for (int i = 0; i < v1.length; i++) { | ||
result[i] = v1[i] * num; | ||
} | ||
return result; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
double tanh(double x) { | ||
if (x > 22.0) | ||
return 1.0; | ||
if (x < -22.0) | ||
return -1.0; | ||
return ((exp(2*x) - 1)/(exp(2*x) + 1)); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import os | ||
import string | ||
|
||
from m2cgen import assemblers, interpreters | ||
from tests import utils | ||
from tests.e2e.executors import base | ||
|
||
EXECUTOR_CODE_TPL = """ | ||
${model_code} | ||
void main(List<String> args) { | ||
List<double> input_ = args.map((x) => double.parse(x)).toList(); | ||
${print_code} | ||
} | ||
""" | ||
|
||
EXECUTE_AND_PRINT_SCALAR = """ | ||
double res = score(input_); | ||
print(res); | ||
""" | ||
|
||
EXECUTE_AND_PRINT_VECTOR = """ | ||
List<double> res = score(input_); | ||
print(res.join(" ")); | ||
""" | ||
|
||
|
||
class DartExecutor(base.BaseExecutor): | ||
|
||
executor_name = "score" | ||
|
||
def __init__(self, model): | ||
self.model = model | ||
self.interpreter = interpreters.DartInterpreter() | ||
|
||
assembler_cls = assemblers.get_assembler_cls(model) | ||
self.model_ast = assembler_cls(model).assemble() | ||
|
||
self._dart = "dart" | ||
|
||
def predict(self, X): | ||
file_name = os.path.join(self._resource_tmp_dir, | ||
"{}.dart".format(self.executor_name)) | ||
exec_args = [self._dart, | ||
file_name, | ||
*map(str, X)] | ||
return utils.predict_from_commandline(exec_args) | ||
|
||
def prepare(self): | ||
if self.model_ast.output_size > 1: | ||
print_code = EXECUTE_AND_PRINT_VECTOR | ||
else: | ||
print_code = EXECUTE_AND_PRINT_SCALAR | ||
|
||
model_code = self.interpreter.interpret(self.model_ast) | ||
executor_code = string.Template(EXECUTOR_CODE_TPL).substitute( | ||
model_code=model_code, | ||
print_code=print_code) | ||
|
||
executor_file_name = os.path.join( | ||
self._resource_tmp_dir, "{}.dart".format(self.executor_name)) | ||
with open(executor_file_name, "w") as f: | ||
f.write(executor_code) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.