From b71fcadb436dcb22b65e0f73243b51456e4c7c11 Mon Sep 17 00:00:00 2001 From: Iaroslav Zeigerman Date: Tue, 10 Mar 2020 06:32:03 -0700 Subject: [PATCH 1/2] Generalize the Subroutine approach implemented in the Java interpreter and use it in the R interpreter --- Dockerfile | 7 +- Makefile | 2 +- m2cgen/interpreters/java/code_generator.py | 2 +- m2cgen/interpreters/java/interpreter.py | 44 ++----------- m2cgen/interpreters/mixins.py | 57 ++++++++++++---- m2cgen/interpreters/r/interpreter.py | 5 +- tests/interpreters/test_java.py | 8 +-- tests/interpreters/test_r.py | 76 ++++++++++------------ 8 files changed, 97 insertions(+), 104 deletions(-) diff --git a/Dockerfile b/Dockerfile index 9340d502..7ff63c18 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,12 +11,15 @@ RUN apt-get update && \ wget -qO- https://storage.googleapis.com/download.dartlang.org/linux/debian/dart_stable.list > /etc/apt/sources.list.d/dart_stable.list && \ apt-get update && \ apt-get install --no-install-recommends -y \ + git \ gcc \ + g++ \ libc-dev \ libgomp1 \ python3.7 \ python3-setuptools \ python3-pip \ + python3.7-dev \ openjdk-8-jdk \ golang-go \ dotnet-sdk-3.0 \ @@ -29,7 +32,9 @@ RUN apt-get update && \ WORKDIR /m2cgen COPY requirements-test.txt ./ -RUN pip3 install --no-cache-dir Cython && \ +RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.7 1 && \ + pip3 install --upgrade pip && \ + pip3 install --no-cache-dir Cython numpy && \ pip3 install --no-cache-dir -r requirements-test.txt CMD python3 setup.py develop && pytest -v -x --fast diff --git a/Makefile b/Makefile index bc4233ce..99a838df 100644 --- a/Makefile +++ b/Makefile @@ -15,7 +15,7 @@ docker-generate-examples: $(DOCKER_RUN_ARGS) bash -c "python3 setup.py develop && python3 tools/generate_code_examples.py generated_code_examples" docker-flake8: - $(DOCKER_RUN_ARGS) bash -c "flake8 ." + $(DOCKER_RUN_ARGS) bash -c "flake8 --exclude tmp,generated_code_examples ." docker-shell: $(DOCKER_RUN_ARGS) bash diff --git a/m2cgen/interpreters/java/code_generator.py b/m2cgen/interpreters/java/code_generator.py index a302d139..598c4694 100644 --- a/m2cgen/interpreters/java/code_generator.py +++ b/m2cgen/interpreters/java/code_generator.py @@ -53,6 +53,6 @@ def _get_var_declare_type(self, is_vector): self.vector_output_type if is_vector else self.scalar_output_type) - # Method `function_definition` is required by SubroutinesAsFunctionsMixin. + # Method `function_definition` is required by SubroutinesMixin. # We already have this functionality in `method_definition` method. function_definition = method_definition diff --git a/m2cgen/interpreters/java/interpreter.py b/m2cgen/interpreters/java/interpreter.py index 175e21b8..b49779dd 100644 --- a/m2cgen/interpreters/java/interpreter.py +++ b/m2cgen/interpreters/java/interpreter.py @@ -1,5 +1,4 @@ import os -import math from m2cgen import ast from m2cgen.interpreters import mixins @@ -10,12 +9,11 @@ class JavaInterpreter(ToCodeInterpreter, mixins.LinearAlgebraMixin, - mixins.SubroutinesAsFunctionsMixin, - mixins.BinExpressionDepthTrackingMixin): + mixins.SubroutinesMixin): # The below numbers have been determined experimentally and are subject # to adjustments in future. - bin_depth_threshold = 100 + ast_size_check_frequency = 100 ast_size_per_subroutine_threshold = 4600 supported_bin_vector_ops = { @@ -50,7 +48,7 @@ def interpret(self, expr): with top_cg.class_definition(self.class_name): - # Since we use SubroutinesAsFunctionsMixin, we already have logic + # Since we use SubroutinesMixin, we already have logic # of adding methods. We create first subroutine for incoming # expression and call `process_subroutine_queue` method. self.enqueue_subroutine(self.function_name, expr) @@ -63,41 +61,7 @@ def interpret(self, expr): return top_cg.code - def interpret_subroutine_expr(self, expr, **kwargs): - return self._do_interpret(expr.expr, **kwargs) - - # Required by SubroutinesAsFunctionsMixin to create new code generator for + # Required by SubroutinesMixin to create new code generator for # each subroutine. def create_code_generator(self): return JavaCodeGenerator(indent=self.indent) - - def bin_depth_threshold_hook(self, expr, **kwargs): - # The condition below is a sanity check to ensure that the expression - # is actually worth moving into a separate subroutine. - if ast.count_exprs(expr) > self.ast_size_per_subroutine_threshold: - function_name = self._get_subroutine_name() - self.enqueue_subroutine(function_name, expr) - return self._cg.function_invocation( - function_name, self._feature_array_name) - else: - return self._do_interpret(expr, **kwargs) - - def _pre_interpret_hook(self, expr, **kwargs): - if isinstance(expr, ast.BinExpr): - threshold = self._calc_bin_depth_threshold(expr) - self.bin_depth_threshold = min(threshold, self.bin_depth_threshold) - return super()._pre_interpret_hook(expr, **kwargs) - - def _calc_bin_depth_threshold(self, expr): - # The logic below counts the number of non-binary expressions - # in a non-recursive branch of a binary expression to account - # for large tree-like models and adjust the bin depth threshold - # if necessary. - cnt = None - if not isinstance(expr.left, ast.BinExpr): - cnt = ast.count_exprs(expr.left, exclude_list={ast.BinExpr}) - elif not isinstance(expr.right, ast.BinExpr): - cnt = ast.count_exprs(expr.right, exclude_list={ast.BinExpr}) - if cnt and cnt < self.ast_size_per_subroutine_threshold: - return math.ceil(self.ast_size_per_subroutine_threshold / cnt) - return self.bin_depth_threshold diff --git a/m2cgen/interpreters/mixins.py b/m2cgen/interpreters/mixins.py index f2ea4aa8..46268018 100644 --- a/m2cgen/interpreters/mixins.py +++ b/m2cgen/interpreters/mixins.py @@ -1,4 +1,5 @@ import sys +import math from collections import namedtuple from m2cgen import ast @@ -92,9 +93,10 @@ def interpret_bin_vector_num_expr(self, expr, extra_func_args=(), Subroutine = namedtuple('Subroutine', ['name', 'expr']) -class SubroutinesAsFunctionsMixin(BaseToCodeInterpreter): +class SubroutinesMixin(BinExpressionDepthTrackingMixin): """ - This mixin provides ability to interpret each SubroutineExpr as a function. + This mixin provides ability to split the code into subroutines based on + the size of the AST. Subclasses only need to implement `create_code_generator` method. @@ -108,6 +110,10 @@ class SubroutinesAsFunctionsMixin(BaseToCodeInterpreter): instance of code generator, which will be populated with the result code. """ + # disabled by default + ast_size_check_frequency = sys.maxsize + ast_size_per_subroutine_threshold = sys.maxsize + def __init__(self, *args, **kwargs): self._subroutine_idx = 0 self.subroutine_expr_queue = [] @@ -123,19 +129,45 @@ def process_subroutine_queue(self, top_code_generator): while len(self.subroutine_expr_queue): self._reset_reused_expr_cache() subroutine = self.subroutine_expr_queue.pop(0) - subroutine_code = self.process_subroutine(subroutine) + subroutine_code = self._process_subroutine(subroutine) top_code_generator.add_code_lines(subroutine_code) - def interpret_subroutine_expr(self, expr, **kwargs): + def enqueue_subroutine(self, name, expr): + self.subroutine_expr_queue.append(Subroutine(name, expr)) + + def _pre_interpret_hook(self, expr, bin_depth=0, **kwargs): + if isinstance(expr, ast.BinExpr): + frequency = self._adjust_ast_check_frequency(expr) + self.ast_size_check_frequency = min( + frequency, self.ast_size_check_frequency) + + if bin_depth >= self.ast_size_check_frequency: + ast_size = ast.count_exprs(expr) + if ast_size > self.ast_size_per_subroutine_threshold: + function_name = self._get_subroutine_name() + self.enqueue_subroutine(function_name, expr) + return self._cg.function_invocation( + function_name, self._feature_array_name), kwargs + + return super()._pre_interpret_hook(expr, bin_depth=bin_depth, **kwargs) + + def _adjust_ast_check_frequency(self, expr): """ - This method will be called whenever new subroutine is encountered. + The logic below counts the number of non-binary expressions + in a non-recursive branch of a binary expression to account + for large tree-like models and adjust the size check frequency + if necessary. """ - function_name = self._get_subroutine_name() - self.enqueue_subroutine(function_name, expr.expr) - return self._cg.function_invocation( - function_name, self._feature_array_name) - - def process_subroutine(self, subroutine): + cnt = None + if not isinstance(expr.left, ast.BinExpr): + cnt = ast.count_exprs(expr.left, exclude_list={ast.BinExpr}) + elif not isinstance(expr.right, ast.BinExpr): + cnt = ast.count_exprs(expr.right, exclude_list={ast.BinExpr}) + if cnt and cnt < self.ast_size_per_subroutine_threshold: + return math.ceil(self.ast_size_per_subroutine_threshold / cnt) + return self.ast_size_check_frequency + + def _process_subroutine(self, subroutine): """ Handles single subroutine. Creates new code generator and defines a function for a given subroutine. @@ -153,9 +185,6 @@ def process_subroutine(self, subroutine): return self._cg.code - def enqueue_subroutine(self, name, expr): - self.subroutine_expr_queue.append(Subroutine(name, expr)) - def _get_subroutine_name(self): subroutine_name = "subroutine" + str(self._subroutine_idx) self._subroutine_idx += 1 diff --git a/m2cgen/interpreters/r/interpreter.py b/m2cgen/interpreters/r/interpreter.py index 10571558..887bffe2 100644 --- a/m2cgen/interpreters/r/interpreter.py +++ b/m2cgen/interpreters/r/interpreter.py @@ -5,8 +5,7 @@ class RInterpreter(ToCodeInterpreter, mixins.LinearAlgebraMixin, - mixins.BinExpressionDepthTrackingMixin, - mixins.SubroutinesAsFunctionsMixin): + mixins.SubroutinesMixin): # R doesn't allow to have more than 50 nested if, [, [[, {, ( calls. # It raises contextstack overflow error not only for explicitly nested @@ -18,6 +17,8 @@ class RInterpreter(ToCodeInterpreter, # This value is just a heuristic and is subject to change in the future # based on the users' feedback. bin_depth_threshold = 25 + ast_size_check_frequency = 2 + ast_size_per_subroutine_threshold = 200 exponent_function_name = "exp" tanh_function_name = "tanh" diff --git a/tests/interpreters/test_java.py b/tests/interpreters/test_java.py index a45d66eb..3cac4095 100644 --- a/tests/interpreters/test_java.py +++ b/tests/interpreters/test_java.py @@ -358,7 +358,7 @@ def test_depth_threshold_with_bin_expr(): expr = ast.BinNumExpr(ast.NumVal(1), expr, ast.BinNumOpType.ADD) interpreter = interpreters.JavaInterpreter() - interpreter.bin_depth_threshold = 2 + interpreter.ast_size_check_frequency = 2 interpreter.ast_size_per_subroutine_threshold = 1 expected_code = """ @@ -385,7 +385,7 @@ def test_depth_threshold_without_bin_expr(): expr) interpreter = interpreters.JavaInterpreter() - interpreter.bin_depth_threshold = 2 + interpreter.ast_size_check_frequency = 2 interpreter.ast_size_per_subroutine_threshold = 1 expected_code = """ @@ -431,7 +431,7 @@ def test_deep_mixed_exprs_not_reaching_threshold(): expr) interpreter = interpreters.JavaInterpreter() - interpreter.bin_depth_threshold = 2 + interpreter.ast_size_check_frequency = 2 interpreter.ast_size_per_subroutine_threshold = 1 expected_code = """ @@ -477,7 +477,7 @@ def test_deep_mixed_exprs_exceeding_threshold(): expr) interpreter = interpreters.JavaInterpreter() - interpreter.bin_depth_threshold = 2 + interpreter.ast_size_check_frequency = 2 interpreter.ast_size_per_subroutine_threshold = 1 expected_code = """ diff --git a/tests/interpreters/test_r.py b/tests/interpreters/test_r.py index b4a181e0..ae79f263 100644 --- a/tests/interpreters/test_r.py +++ b/tests/interpreters/test_r.py @@ -123,27 +123,6 @@ def test_nested_condition(): utils.assert_code_equal(interpreter.interpret(expr), expected_code) -def test_subroutine(): - expr = ast.BinNumExpr( - ast.FeatureRef(0), - ast.SubroutineExpr( - ast.BinNumExpr( - ast.NumVal(1), ast.NumVal(2), ast.BinNumOpType.ADD)), - ast.BinNumOpType.MUL) - - expected_code = """ -score <- function(input) { - return((input[1]) * (subroutine0(input))) -} -subroutine0 <- function(input) { - return((1) + (2)) -} -""" - - interpreter = RInterpreter() - utils.assert_code_equal(interpreter.interpret(expr), expected_code) - - def test_raw_array(): expr = ast.VectorVal([ast.NumVal(3), ast.NumVal(4)]) @@ -158,20 +137,16 @@ def test_raw_array(): def test_multi_output(): - expr = ast.SubroutineExpr( - ast.IfExpr( - ast.CompExpr( - ast.NumVal(1), - ast.NumVal(1), - ast.CompOpType.EQ), - ast.VectorVal([ast.NumVal(1), ast.NumVal(2)]), - ast.VectorVal([ast.NumVal(3), ast.NumVal(4)]))) + expr = ast.IfExpr( + ast.CompExpr( + ast.NumVal(1), + ast.NumVal(1), + ast.CompOpType.EQ), + ast.VectorVal([ast.NumVal(1), ast.NumVal(2)]), + ast.VectorVal([ast.NumVal(3), ast.NumVal(4)])) expected_code = """ score <- function(input) { - return(subroutine0(input)) -} -subroutine0 <- function(input) { if ((1) == (1)) { var0 <- c(1, 2) } else { @@ -329,24 +304,23 @@ def test_deep_mixed_exprs_exceeding_threshold(): ast.NumVal(1), expr) - interpreter = CustomRInterpreter() + interpreter = RInterpreter() + interpreter.bin_depth_threshold = 1 + interpreter.ast_size_check_frequency = 1 + interpreter.ast_size_per_subroutine_threshold = 6 expected_code = """ score <- function(input) { - var1 <- (1) + ((1) + (1)) - if (((1) + ((1) + (var1))) == (1)) { + if (((1) + (subroutine0(input))) == (1)) { var0 <- 1 } else { - var2 <- (1) + ((1) + (1)) - if (((1) + ((1) + (var2))) == (1)) { + if (((1) + (subroutine1(input))) == (1)) { var0 <- 1 } else { - var3 <- (1) + ((1) + (1)) - if (((1) + ((1) + (var3))) == (1)) { + if (((1) + (subroutine2(input))) == (1)) { var0 <- 1 } else { - var4 <- (1) + ((1) + (1)) - if (((1) + ((1) + (var4))) == (1)) { + if (((1) + (subroutine3(input))) == (1)) { var0 <- 1 } else { var0 <- 1 @@ -356,6 +330,26 @@ def test_deep_mixed_exprs_exceeding_threshold(): } return(var0) } +subroutine0 <- function(input) { + var1 <- (1) + (1) + var0 <- (1) + (var1) + return((1) + (var0)) +} +subroutine1 <- function(input) { + var1 <- (1) + (1) + var0 <- (1) + (var1) + return((1) + (var0)) +} +subroutine2 <- function(input) { + var1 <- (1) + (1) + var0 <- (1) + (var1) + return((1) + (var0)) +} +subroutine3 <- function(input) { + var1 <- (1) + (1) + var0 <- (1) + (var1) + return((1) + (var0)) +} """ utils.assert_code_equal(interpreter.interpret(expr), expected_code) From ae2fc65b55289ba1a415a6b94b2f4a316defa339 Mon Sep 17 00:00:00 2001 From: Iaroslav Zeigerman Date: Tue, 10 Mar 2020 13:20:47 -0700 Subject: [PATCH 2/2] Ignore reusable expressions in the SubroutinesMixin pre-interpret hook --- m2cgen/interpreters/mixins.py | 18 +++++++++++------- m2cgen/interpreters/r/interpreter.py | 2 ++ tests/interpreters/test_java.py | 6 +++--- tests/interpreters/test_r.py | 14 +++++++++----- 4 files changed, 25 insertions(+), 15 deletions(-) diff --git a/m2cgen/interpreters/mixins.py b/m2cgen/interpreters/mixins.py index 46268018..d49a92d6 100644 --- a/m2cgen/interpreters/mixins.py +++ b/m2cgen/interpreters/mixins.py @@ -24,7 +24,7 @@ class BinExpressionDepthTrackingMixin(BaseToCodeInterpreter): def _pre_interpret_hook(self, expr, bin_depth=0, **kwargs): if not isinstance(expr, ast.BinExpr): - return None, kwargs + return super()._pre_interpret_hook(expr, **kwargs) # We track depth of the binary expressions and call a hook if it # reaches specified threshold . @@ -32,7 +32,7 @@ def _pre_interpret_hook(self, expr, bin_depth=0, **kwargs): return self.bin_depth_threshold_hook(expr, **kwargs), kwargs kwargs["bin_depth"] = bin_depth + 1 - return None, kwargs + return super()._pre_interpret_hook(expr, **kwargs) # Default implementation. Simply adds new variable. def bin_depth_threshold_hook(self, expr, **kwargs): @@ -93,7 +93,7 @@ def interpret_bin_vector_num_expr(self, expr, extra_func_args=(), Subroutine = namedtuple('Subroutine', ['name', 'expr']) -class SubroutinesMixin(BinExpressionDepthTrackingMixin): +class SubroutinesMixin(BaseToCodeInterpreter): """ This mixin provides ability to split the code into subroutines based on the size of the AST. @@ -135,13 +135,15 @@ def process_subroutine_queue(self, top_code_generator): def enqueue_subroutine(self, name, expr): self.subroutine_expr_queue.append(Subroutine(name, expr)) - def _pre_interpret_hook(self, expr, bin_depth=0, **kwargs): - if isinstance(expr, ast.BinExpr): + def _pre_interpret_hook(self, expr, ast_size_check_counter=0, **kwargs): + if isinstance(expr, ast.BinExpr) and not expr.to_reuse: frequency = self._adjust_ast_check_frequency(expr) self.ast_size_check_frequency = min( frequency, self.ast_size_check_frequency) - if bin_depth >= self.ast_size_check_frequency: + ast_size_check_counter += 1 + if ast_size_check_counter >= self.ast_size_check_frequency: + ast_size_check_counter = 0 ast_size = ast.count_exprs(expr) if ast_size > self.ast_size_per_subroutine_threshold: function_name = self._get_subroutine_name() @@ -149,7 +151,9 @@ def _pre_interpret_hook(self, expr, bin_depth=0, **kwargs): return self._cg.function_invocation( function_name, self._feature_array_name), kwargs - return super()._pre_interpret_hook(expr, bin_depth=bin_depth, **kwargs) + kwargs['ast_size_check_counter'] = ast_size_check_counter + + return super()._pre_interpret_hook(expr, **kwargs) def _adjust_ast_check_frequency(self, expr): """ diff --git a/m2cgen/interpreters/r/interpreter.py b/m2cgen/interpreters/r/interpreter.py index 887bffe2..a2a2e948 100644 --- a/m2cgen/interpreters/r/interpreter.py +++ b/m2cgen/interpreters/r/interpreter.py @@ -5,6 +5,7 @@ class RInterpreter(ToCodeInterpreter, mixins.LinearAlgebraMixin, + mixins.BinExpressionDepthTrackingMixin, mixins.SubroutinesMixin): # R doesn't allow to have more than 50 nested if, [, [[, {, ( calls. @@ -17,6 +18,7 @@ class RInterpreter(ToCodeInterpreter, # This value is just a heuristic and is subject to change in the future # based on the users' feedback. bin_depth_threshold = 25 + ast_size_check_frequency = 2 ast_size_per_subroutine_threshold = 200 diff --git a/tests/interpreters/test_java.py b/tests/interpreters/test_java.py index 3cac4095..7bc719a0 100644 --- a/tests/interpreters/test_java.py +++ b/tests/interpreters/test_java.py @@ -358,7 +358,7 @@ def test_depth_threshold_with_bin_expr(): expr = ast.BinNumExpr(ast.NumVal(1), expr, ast.BinNumOpType.ADD) interpreter = interpreters.JavaInterpreter() - interpreter.ast_size_check_frequency = 2 + interpreter.ast_size_check_frequency = 3 interpreter.ast_size_per_subroutine_threshold = 1 expected_code = """ @@ -431,7 +431,7 @@ def test_deep_mixed_exprs_not_reaching_threshold(): expr) interpreter = interpreters.JavaInterpreter() - interpreter.ast_size_check_frequency = 2 + interpreter.ast_size_check_frequency = 3 interpreter.ast_size_per_subroutine_threshold = 1 expected_code = """ @@ -477,7 +477,7 @@ def test_deep_mixed_exprs_exceeding_threshold(): expr) interpreter = interpreters.JavaInterpreter() - interpreter.ast_size_check_frequency = 2 + interpreter.ast_size_check_frequency = 3 interpreter.ast_size_per_subroutine_threshold = 1 expected_code = """ diff --git a/tests/interpreters/test_r.py b/tests/interpreters/test_r.py index ae79f263..48960926 100644 --- a/tests/interpreters/test_r.py +++ b/tests/interpreters/test_r.py @@ -306,21 +306,25 @@ def test_deep_mixed_exprs_exceeding_threshold(): interpreter = RInterpreter() interpreter.bin_depth_threshold = 1 - interpreter.ast_size_check_frequency = 1 + interpreter.ast_size_check_frequency = 2 interpreter.ast_size_per_subroutine_threshold = 6 expected_code = """ score <- function(input) { - if (((1) + (subroutine0(input))) == (1)) { + var1 <- subroutine0(input) + if (((1) + (var1)) == (1)) { var0 <- 1 } else { - if (((1) + (subroutine1(input))) == (1)) { + var2 <- subroutine1(input) + if (((1) + (var2)) == (1)) { var0 <- 1 } else { - if (((1) + (subroutine2(input))) == (1)) { + var3 <- subroutine2(input) + if (((1) + (var3)) == (1)) { var0 <- 1 } else { - if (((1) + (subroutine3(input))) == (1)) { + var4 <- subroutine3(input) + if (((1) + (var4)) == (1)) { var0 <- 1 } else { var0 <- 1