diff --git a/.travis.yml b/.travis.yml index e1a1c455..1a095fb0 100644 --- a/.travis.yml +++ b/.travis.yml @@ -10,7 +10,7 @@ env: - TEST=API - TEST=E2E LANG="c_lang or python or java or go_lang or javascript or php" - TEST=E2E LANG="c_sharp or visual_basic or powershell" - - TEST=E2E LANG="r_lang" + - TEST=E2E LANG="r_lang or dart" before_install: - bash .travis/setup.sh diff --git a/.travis/setup.sh b/.travis/setup.sh index efc9a4fe..b2db0808 100644 --- a/.travis/setup.sh +++ b/.travis/setup.sh @@ -29,3 +29,11 @@ if [[ $LANG == *"php"* ]]; then sudo apt-get update sudo apt-get install --no-install-recommends -y php fi + +# Install Dart. (https://dart.dev/get-dart) +if [[ $LANG == *"dart"* ]]; then + sudo sh -c 'wget -qO- https://dl-ssl.google.com/linux/linux_signing_key.pub | apt-key add -' + sudo sh -c 'wget -qO- https://storage.googleapis.com/download.dartlang.org/linux/debian/dart_stable.list > /etc/apt/sources.list.d/dart_stable.list' + sudo apt-get update + sudo apt-get install --no-install-recommends -y dart +fi diff --git a/Dockerfile b/Dockerfile index c4bc3f31..b4b350b2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,6 +7,8 @@ RUN apt-get update && \ add-apt-repository ppa:deadsnakes/ppa && \ wget -q https://packages.microsoft.com/config/ubuntu/$(lsb_release -rs)/packages-microsoft-prod.deb -O packages-microsoft-prod.deb && \ dpkg -i packages-microsoft-prod.deb && \ + wget -qO- https://dl-ssl.google.com/linux/linux_signing_key.pub | apt-key add - && \ + wget -qO- https://storage.googleapis.com/download.dartlang.org/linux/debian/dart_stable.list > /etc/apt/sources.list.d/dart_stable.list && \ apt-get update && \ apt-get install --no-install-recommends -y \ gcc \ @@ -20,7 +22,8 @@ RUN apt-get update && \ dotnet-sdk-3.0 \ powershell \ r-base \ - php && \ + php \ + dart && \ rm -rf /var/lib/apt/lists/* WORKDIR /m2cgen diff --git a/MANIFEST.in b/MANIFEST.in index 2cadfbab..710be031 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,5 @@ include LICENSE recursive-include m2cgen VERSION.txt recursive-include m2cgen linear_algebra.* -recursive-include m2cgen tanh.bas +recursive-include m2cgen tanh.* global-exclude *.py[cod] diff --git a/README.md b/README.md index 2513f95d..b923e711 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ [![PyPI Version](https://img.shields.io/pypi/v/m2cgen.svg?logo=pypi&logoColor=white)](https://pypi.org/project/m2cgen) [![Downloads](https://pepy.tech/badge/m2cgen)](https://pepy.tech/project/m2cgen) -**m2cgen** (Model 2 Code Generator) - is a lightweight library which provides an easy way to transpile trained statistical models into a native code (Python, C, Java, Go, JavaScript, Visual Basic, C#, PowerShell, R, PHP). +**m2cgen** (Model 2 Code Generator) - is a lightweight library which provides an easy way to transpile trained statistical models into a native code (Python, C, Java, Go, JavaScript, Visual Basic, C#, PowerShell, R, PHP, Dart). * [Installation](#installation) * [Supported Languages](#supported-languages) @@ -28,6 +28,7 @@ pip install m2cgen - C - C# +- Dart - Go - Java - JavaScript diff --git a/m2cgen/__init__.py b/m2cgen/__init__.py index 42d9dbc4..316d2150 100644 --- a/m2cgen/__init__.py +++ b/m2cgen/__init__.py @@ -11,6 +11,7 @@ export_to_powershell, export_to_r, export_to_php, + export_to_dart, ) __all__ = [ @@ -24,6 +25,7 @@ export_to_powershell, export_to_r, export_to_php, + export_to_dart, ] with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), diff --git a/m2cgen/cli.py b/m2cgen/cli.py index c156085a..84cb6ef9 100644 --- a/m2cgen/cli.py +++ b/m2cgen/cli.py @@ -31,6 +31,7 @@ "powershell": (m2cgen.export_to_powershell, ["indent", "function_name"]), "r": (m2cgen.export_to_r, ["indent", "function_name"]), "php": (m2cgen.export_to_php, ["indent", "function_name"]), + "dart": (m2cgen.export_to_dart, ["indent", "function_name"]), } diff --git a/m2cgen/exporters.py b/m2cgen/exporters.py index b69a9947..98cec7b9 100644 --- a/m2cgen/exporters.py +++ b/m2cgen/exporters.py @@ -302,6 +302,30 @@ def export_to_php(model, indent=4, function_name="score"): return _export(model, interpreter) +def export_to_dart(model, indent=4, function_name="score"): + """ + Generates a Dart code representation of the given model. + + Parameters + ---------- + model : object + The model object that should be transpiled into code. + indent : int, optional + The size of indents in the generated code. + function_name : string, optional + Name of the function in the generated code. + + Returns + ------- + code : string + """ + interpreter = interpreters.DartInterpreter( + indent=indent, + function_name=function_name, + ) + return _export(model, interpreter) + + def _export(model, interpreter): assembler_cls = assemblers.get_assembler_cls(model) model_ast = assembler_cls(model).assemble() diff --git a/m2cgen/interpreters/__init__.py b/m2cgen/interpreters/__init__.py index 6964c276..87042690 100644 --- a/m2cgen/interpreters/__init__.py +++ b/m2cgen/interpreters/__init__.py @@ -8,6 +8,7 @@ from .powershell.interpreter import PowershellInterpreter from .r.interpreter import RInterpreter from .php.interpreter import PhpInterpreter +from .dart.interpreter import DartInterpreter __all__ = [ JavaInterpreter, @@ -20,4 +21,5 @@ PowershellInterpreter, RInterpreter, PhpInterpreter, + DartInterpreter, ] diff --git a/m2cgen/interpreters/dart/__init__.py b/m2cgen/interpreters/dart/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/m2cgen/interpreters/dart/code_generator.py b/m2cgen/interpreters/dart/code_generator.py new file mode 100644 index 00000000..02ed661e --- /dev/null +++ b/m2cgen/interpreters/dart/code_generator.py @@ -0,0 +1,40 @@ +import contextlib + +from m2cgen.interpreters.code_generator import CLikeCodeGenerator + + +class DartCodeGenerator(CLikeCodeGenerator): + + scalar_type = "double" + vector_type = "List" + + def __init__(self, *args, **kwargs): + super(DartCodeGenerator, self).__init__(*args, **kwargs) + + def add_function_def(self, name, args, is_vector_output): + return_type = self._get_var_declare_type(is_vector_output) + function_def = return_type + " " + name + "(" + function_def += ",".join([ + self._get_var_declare_type(is_vector) + " " + n + for is_vector, n in args]) + function_def += ") {" + self.add_code_line(function_def) + self.increase_indent() + + @contextlib.contextmanager + def function_definition(self, name, args, is_vector_output): + self.add_function_def(name, args, is_vector_output) + yield + self.add_block_termination() + + def vector_init(self, values): + return "[" + ", ".join(values) + "]" + + def _get_var_declare_type(self, is_vector): + return ( + self.vector_type if is_vector + else self.scalar_type) + + def add_dependency(self, dep): + dep_str = "import '" + dep + "';" + self.prepend_code_line(dep_str) diff --git a/m2cgen/interpreters/dart/interpreter.py b/m2cgen/interpreters/dart/interpreter.py new file mode 100644 index 00000000..bbe0b396 --- /dev/null +++ b/m2cgen/interpreters/dart/interpreter.py @@ -0,0 +1,69 @@ +import os + +from m2cgen import ast +from m2cgen.interpreters import mixins +from m2cgen.interpreters import utils +from m2cgen.interpreters.interpreter import ToCodeInterpreter +from m2cgen.interpreters.dart.code_generator import DartCodeGenerator + + +class DartInterpreter(ToCodeInterpreter, + mixins.LinearAlgebraMixin, + mixins.BinExpressionDepthTrackingMixin): + + supported_bin_vector_ops = { + ast.BinNumOpType.ADD: "addVectors", + } + + supported_bin_vector_num_ops = { + ast.BinNumOpType.MUL: "mulVectorNumber", + } + + bin_depth_threshold = 465 + + exponent_function_name = "exp" + power_function_name = "pow" + tanh_function_name = "tanh" + + with_tanh_expr = False + + def __init__(self, indent=4, function_name="score", *args, **kwargs): + self.indent = indent + self.function_name = function_name + + cg = DartCodeGenerator(indent=indent) + super(DartInterpreter, self).__init__(cg, *args, **kwargs) + + def interpret(self, expr): + self._cg.reset_state() + self._reset_reused_expr_cache() + + args = [(True, self._feature_array_name)] + + with self._cg.function_definition( + name=self.function_name, + args=args, + is_vector_output=expr.output_size > 1): + last_result = self._do_interpret(expr) + self._cg.add_return_statement(last_result) + + if self.with_linear_algebra: + filename = os.path.join( + os.path.dirname(__file__), "linear_algebra.dart") + self._cg.add_code_lines(utils.get_file_content(filename)) + + # Use own tanh function in order to be compatible with Dart + if self.with_tanh_expr: + filename = os.path.join( + os.path.dirname(__file__), "tanh.dart") + self._cg.add_code_lines(utils.get_file_content(filename)) + + if self.with_math_module: + self._cg.add_dependency("dart:math") + + return self._cg.code + + def interpret_tanh_expr(self, expr, **kwargs): + self.with_tanh_expr = True + return super( + DartInterpreter, self).interpret_tanh_expr(expr, **kwargs) diff --git a/m2cgen/interpreters/dart/linear_algebra.dart b/m2cgen/interpreters/dart/linear_algebra.dart new file mode 100644 index 00000000..316be74e --- /dev/null +++ b/m2cgen/interpreters/dart/linear_algebra.dart @@ -0,0 +1,14 @@ +List addVectors(List v1, List v2) { + List result = new List(v1.length); + for (int i = 0; i < v1.length; i++) { + result[i] = v1[i] + v2[i]; + } + return result; +} +List mulVectorNumber(List v1, double num) { + List result = new List(v1.length); + for (int i = 0; i < v1.length; i++) { + result[i] = v1[i] * num; + } + return result; +} diff --git a/m2cgen/interpreters/dart/tanh.dart b/m2cgen/interpreters/dart/tanh.dart new file mode 100644 index 00000000..f6955237 --- /dev/null +++ b/m2cgen/interpreters/dart/tanh.dart @@ -0,0 +1,7 @@ +double tanh(double x) { + if (x > 22.0) + return 1.0; + if (x < -22.0) + return -1.0; + return ((exp(2*x) - 1)/(exp(2*x) + 1)); +} diff --git a/tests/e2e/executors/__init__.py b/tests/e2e/executors/__init__.py index 38934495..9547adb6 100644 --- a/tests/e2e/executors/__init__.py +++ b/tests/e2e/executors/__init__.py @@ -8,6 +8,7 @@ from tests.e2e.executors.powershell import PowershellExecutor from tests.e2e.executors.r import RExecutor from tests.e2e.executors.php import PhpExecutor +from tests.e2e.executors.dart import DartExecutor __all__ = [ JavaExecutor, @@ -20,4 +21,5 @@ PowershellExecutor, RExecutor, PhpExecutor, + DartExecutor, ] diff --git a/tests/e2e/executors/dart.py b/tests/e2e/executors/dart.py new file mode 100644 index 00000000..e17e8baf --- /dev/null +++ b/tests/e2e/executors/dart.py @@ -0,0 +1,63 @@ +import os +import string + +from m2cgen import assemblers, interpreters +from tests import utils +from tests.e2e.executors import base + +EXECUTOR_CODE_TPL = """ +${model_code} + +void main(List args) { + List input_ = args.map((x) => double.parse(x)).toList(); + ${print_code} +} +""" + +EXECUTE_AND_PRINT_SCALAR = """ + double res = score(input_); + print(res); +""" + +EXECUTE_AND_PRINT_VECTOR = """ + List res = score(input_); + print(res.join(" ")); +""" + + +class DartExecutor(base.BaseExecutor): + + executor_name = "score" + + def __init__(self, model): + self.model = model + self.interpreter = interpreters.DartInterpreter() + + assembler_cls = assemblers.get_assembler_cls(model) + self.model_ast = assembler_cls(model).assemble() + + self._dart = "dart" + + def predict(self, X): + file_name = os.path.join(self._resource_tmp_dir, + "{}.dart".format(self.executor_name)) + exec_args = [self._dart, + file_name, + *map(str, X)] + return utils.predict_from_commandline(exec_args) + + def prepare(self): + if self.model_ast.output_size > 1: + print_code = EXECUTE_AND_PRINT_VECTOR + else: + print_code = EXECUTE_AND_PRINT_SCALAR + + model_code = self.interpreter.interpret(self.model_ast) + executor_code = string.Template(EXECUTOR_CODE_TPL).substitute( + model_code=model_code, + print_code=print_code) + + executor_file_name = os.path.join( + self._resource_tmp_dir, "{}.dart".format(self.executor_name)) + with open(executor_file_name, "w") as f: + f.write(executor_code) diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index 30c18875..2be0a800 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -26,6 +26,7 @@ POWERSHELL = pytest.mark.powershell R = pytest.mark.r_lang PHP = pytest.mark.php +DART = pytest.mark.dart REGRESSION = pytest.mark.regr CLASSIFICATION = pytest.mark.clf @@ -125,6 +126,7 @@ def classification_binary_random(model): (executors.PowershellExecutor, POWERSHELL), (executors.RExecutor, R), (executors.PhpExecutor, PHP), + (executors.DartExecutor, DART), ], # These models will be tested against each language specified in the diff --git a/tests/interpreters/test_dart.py b/tests/interpreters/test_dart.py new file mode 100644 index 00000000..d473d0d3 --- /dev/null +++ b/tests/interpreters/test_dart.py @@ -0,0 +1,445 @@ +from m2cgen import ast +from m2cgen.interpreters import DartInterpreter +from tests import utils + + +def test_if_expr(): + expr = ast.IfExpr( + ast.CompExpr(ast.NumVal(1), ast.FeatureRef(0), ast.CompOpType.EQ), + ast.NumVal(2), + ast.NumVal(3)) + + expected_code = """ +double score(List input) { + double var0; + if ((1) == (input[0])) { + var0 = 2; + } else { + var0 = 3; + } + return var0; +} +""" + + interpreter = DartInterpreter() + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + +def test_bin_num_expr(): + expr = ast.BinNumExpr( + ast.BinNumExpr( + ast.FeatureRef(0), ast.NumVal(-2), ast.BinNumOpType.DIV), + ast.NumVal(2), + ast.BinNumOpType.MUL) + + expected_code = """ +double score(List input) { + return ((input[0]) / (-2)) * (2); +} +""" + + interpreter = DartInterpreter() + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + +def test_dependable_condition(): + left = ast.BinNumExpr( + ast.IfExpr( + ast.CompExpr(ast.NumVal(1), + ast.NumVal(1), + ast.CompOpType.EQ), + ast.NumVal(1), + ast.NumVal(2)), + ast.NumVal(2), + ast.BinNumOpType.ADD) + + right = ast.BinNumExpr(ast.NumVal(1), ast.NumVal(2), ast.BinNumOpType.DIV) + bool_test = ast.CompExpr(left, right, ast.CompOpType.GTE) + + expr = ast.IfExpr(bool_test, ast.NumVal(1), ast.FeatureRef(0)) + + expected_code = """ +double score(List input) { + double var0; + double var1; + if ((1) == (1)) { + var1 = 1; + } else { + var1 = 2; + } + if (((var1) + (2)) >= ((1) / (2))) { + var0 = 1; + } else { + var0 = input[0]; + } + return var0; +} +""" + + interpreter = DartInterpreter() + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + +def test_nested_condition(): + left = ast.BinNumExpr( + ast.IfExpr( + ast.CompExpr(ast.NumVal(1), + ast.NumVal(1), + ast.CompOpType.EQ), + ast.NumVal(1), + ast.NumVal(2)), + ast.NumVal(2), + ast.BinNumOpType.ADD) + + bool_test = ast.CompExpr(ast.NumVal(1), left, ast.CompOpType.EQ) + + expr_nested = ast.IfExpr(bool_test, ast.FeatureRef(2), ast.NumVal(2)) + + expr = ast.IfExpr(bool_test, expr_nested, ast.NumVal(2)) + + expected_code = """ +double score(List input) { + double var0; + double var1; + if ((1) == (1)) { + var1 = 1; + } else { + var1 = 2; + } + if ((1) == ((var1) + (2))) { + double var2; + if ((1) == (1)) { + var2 = 1; + } else { + var2 = 2; + } + if ((1) == ((var2) + (2))) { + var0 = input[2]; + } else { + var0 = 2; + } + } else { + var0 = 2; + } + return var0; +} +""" + + interpreter = DartInterpreter() + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + +def test_raw_array(): + expr = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)]) + + expected_code = """ +List score(List input) { + return [3, 4]; +} +""" + + interpreter = DartInterpreter() + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + +def test_multi_output(): + expr = ast.SubroutineExpr( + ast.IfExpr( + ast.CompExpr( + ast.NumVal(1), + ast.NumVal(1), + ast.CompOpType.EQ), + ast.VectorVal([ast.NumVal(1), ast.NumVal(2)]), + ast.VectorVal([ast.NumVal(3), ast.NumVal(4)]))) + + expected_code = """ +List score(List input) { + List var0; + if ((1) == (1)) { + var0 = [1, 2]; + } else { + var0 = [3, 4]; + } + return var0; +} +""" + + interpreter = DartInterpreter() + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + +def test_bin_vector_expr(): + expr = ast.BinVectorExpr( + ast.VectorVal([ast.NumVal(1), ast.NumVal(2)]), + ast.VectorVal([ast.NumVal(3), ast.NumVal(4)]), + ast.BinNumOpType.ADD) + + expected_code = """ +List score(List input) { + return addVectors([1, 2], [3, 4]); +} +List addVectors(List v1, List v2) { + List result = new List(v1.length); + for (int i = 0; i < v1.length; i++) { + result[i] = v1[i] + v2[i]; + } + return result; +} +List mulVectorNumber(List v1, double num) { + List result = new List(v1.length); + for (int i = 0; i < v1.length; i++) { + result[i] = v1[i] * num; + } + return result; +} +""" + + interpreter = DartInterpreter() + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + +def test_bin_vector_num_expr(): + expr = ast.BinVectorNumExpr( + ast.VectorVal([ast.NumVal(1), ast.NumVal(2)]), + ast.NumVal(1), + ast.BinNumOpType.MUL) + + expected_code = """ +List score(List input) { + return mulVectorNumber([1, 2], 1); +} +List addVectors(List v1, List v2) { + List result = new List(v1.length); + for (int i = 0; i < v1.length; i++) { + result[i] = v1[i] + v2[i]; + } + return result; +} +List mulVectorNumber(List v1, double num) { + List result = new List(v1.length); + for (int i = 0; i < v1.length; i++) { + result[i] = v1[i] * num; + } + return result; +} +""" + + interpreter = DartInterpreter() + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + +class CustomDartInterpreter(DartInterpreter): + bin_depth_threshold = 2 + + +def test_depth_threshold_with_bin_expr(): + expr = ast.NumVal(1) + for i in range(4): + expr = ast.BinNumExpr(ast.NumVal(1), expr, ast.BinNumOpType.ADD) + + interpreter = CustomDartInterpreter() + + expected_code = """ +double score(List input) { + double var0; + var0 = (1) + ((1) + (1)); + return (1) + ((1) + (var0)); +} +""" + + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + +def test_depth_threshold_without_bin_expr(): + expr = ast.NumVal(1) + for i in range(4): + expr = ast.IfExpr( + ast.CompExpr( + ast.NumVal(1), ast.NumVal(1), ast.CompOpType.EQ), + ast.NumVal(1), + expr) + + interpreter = CustomDartInterpreter() + + expected_code = """ +double score(List input) { + double var0; + if ((1) == (1)) { + var0 = 1; + } else { + if ((1) == (1)) { + var0 = 1; + } else { + if ((1) == (1)) { + var0 = 1; + } else { + if ((1) == (1)) { + var0 = 1; + } else { + var0 = 1; + } + } + } + } + return var0; +} +""" + + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + +def test_deep_mixed_exprs_not_reaching_threshold(): + expr = ast.NumVal(1) + for i in range(4): + inner = ast.NumVal(1) + for i in range(2): + inner = ast.BinNumExpr(ast.NumVal(1), inner, ast.BinNumOpType.ADD) + + expr = ast.IfExpr( + ast.CompExpr( + inner, ast.NumVal(1), ast.CompOpType.EQ), + ast.NumVal(1), + expr) + + interpreter = CustomDartInterpreter() + + expected_code = """ +double score(List input) { + double var0; + if (((1) + ((1) + (1))) == (1)) { + var0 = 1; + } else { + if (((1) + ((1) + (1))) == (1)) { + var0 = 1; + } else { + if (((1) + ((1) + (1))) == (1)) { + var0 = 1; + } else { + if (((1) + ((1) + (1))) == (1)) { + var0 = 1; + } else { + var0 = 1; + } + } + } + } + return var0; +} +""" + + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + +def test_deep_mixed_exprs_exceeding_threshold(): + expr = ast.NumVal(1) + for i in range(4): + inner = ast.NumVal(1) + for i in range(4): + inner = ast.BinNumExpr(ast.NumVal(1), inner, ast.BinNumOpType.ADD) + + expr = ast.IfExpr( + ast.CompExpr( + inner, ast.NumVal(1), ast.CompOpType.EQ), + ast.NumVal(1), + expr) + + interpreter = CustomDartInterpreter() + + expected_code = """ +double score(List input) { + double var0; + double var1; + var1 = (1) + ((1) + (1)); + if (((1) + ((1) + (var1))) == (1)) { + var0 = 1; + } else { + double var2; + var2 = (1) + ((1) + (1)); + if (((1) + ((1) + (var2))) == (1)) { + var0 = 1; + } else { + double var3; + var3 = (1) + ((1) + (1)); + if (((1) + ((1) + (var3))) == (1)) { + var0 = 1; + } else { + double var4; + var4 = (1) + ((1) + (1)); + if (((1) + ((1) + (var4))) == (1)) { + var0 = 1; + } else { + var0 = 1; + } + } + } + } + return var0; +} +""" + + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + +def test_exp_expr(): + expr = ast.ExpExpr(ast.NumVal(1.0)) + + expected_code = """ +import 'dart:math'; +double score(List input) { + return exp(1.0); +} +""" + + interpreter = DartInterpreter() + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + +def test_pow_expr(): + expr = ast.PowExpr(ast.NumVal(2.0), ast.NumVal(3.0)) + + expected_code = """ +import 'dart:math'; +double score(List input) { + return pow(2.0, 3.0); +} +""" + + interpreter = DartInterpreter() + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + +def test_tanh_expr(): + expr = ast.TanhExpr(ast.NumVal(2.0)) + + expected_code = """ +import 'dart:math'; +double score(List input) { + return tanh(2.0); +} +double tanh(double x) { + if (x > 22.0) + return 1.0; + if (x < -22.0) + return -1.0; + return ((exp(2*x) - 1)/(exp(2*x) + 1)); +} +""" + interpreter = DartInterpreter() + utils.assert_code_equal(interpreter.interpret(expr), expected_code) + + +def test_reused_expr(): + reused_expr = ast.ExpExpr(ast.NumVal(1.0), to_reuse=True) + expr = ast.BinNumExpr(reused_expr, reused_expr, ast.BinNumOpType.DIV) + + expected_code = """ +import 'dart:math'; +double score(List input) { + double var0; + var0 = exp(1.0); + return (var0) / (var0); +} +""" + + interpreter = DartInterpreter() + utils.assert_code_equal(interpreter.interpret(expr), expected_code)