Skip to content

Commit

Permalink
Merge ae2fc65 into c342ca1
Browse files Browse the repository at this point in the history
  • Loading branch information
izeigerman committed Mar 10, 2020
2 parents c342ca1 + ae2fc65 commit a823ede
Show file tree
Hide file tree
Showing 8 changed files with 108 additions and 105 deletions.
7 changes: 6 additions & 1 deletion Dockerfile
Expand Up @@ -11,12 +11,15 @@ RUN apt-get update && \
wget -qO- https://storage.googleapis.com/download.dartlang.org/linux/debian/dart_stable.list > /etc/apt/sources.list.d/dart_stable.list && \
apt-get update && \
apt-get install --no-install-recommends -y \
git \
gcc \
g++ \
libc-dev \
libgomp1 \
python3.7 \
python3-setuptools \
python3-pip \
python3.7-dev \
openjdk-8-jdk \
golang-go \
dotnet-sdk-3.0 \
Expand All @@ -29,7 +32,9 @@ RUN apt-get update && \
WORKDIR /m2cgen

COPY requirements-test.txt ./
RUN pip3 install --no-cache-dir Cython && \
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.7 1 && \
pip3 install --upgrade pip && \
pip3 install --no-cache-dir Cython numpy && \
pip3 install --no-cache-dir -r requirements-test.txt

CMD python3 setup.py develop && pytest -v -x --fast
2 changes: 1 addition & 1 deletion Makefile
Expand Up @@ -15,7 +15,7 @@ docker-generate-examples:
$(DOCKER_RUN_ARGS) bash -c "python3 setup.py develop && python3 tools/generate_code_examples.py generated_code_examples"

docker-flake8:
$(DOCKER_RUN_ARGS) bash -c "flake8 ."
$(DOCKER_RUN_ARGS) bash -c "flake8 --exclude tmp,generated_code_examples ."

docker-shell:
$(DOCKER_RUN_ARGS) bash
Expand Down
2 changes: 1 addition & 1 deletion m2cgen/interpreters/java/code_generator.py
Expand Up @@ -53,6 +53,6 @@ def _get_var_declare_type(self, is_vector):
self.vector_output_type if is_vector
else self.scalar_output_type)

# Method `function_definition` is required by SubroutinesAsFunctionsMixin.
# Method `function_definition` is required by SubroutinesMixin.
# We already have this functionality in `method_definition` method.
function_definition = method_definition
44 changes: 4 additions & 40 deletions m2cgen/interpreters/java/interpreter.py
@@ -1,5 +1,4 @@
import os
import math

from m2cgen import ast
from m2cgen.interpreters import mixins
Expand All @@ -10,12 +9,11 @@

class JavaInterpreter(ToCodeInterpreter,
mixins.LinearAlgebraMixin,
mixins.SubroutinesAsFunctionsMixin,
mixins.BinExpressionDepthTrackingMixin):
mixins.SubroutinesMixin):

# The below numbers have been determined experimentally and are subject
# to adjustments in future.
bin_depth_threshold = 100
ast_size_check_frequency = 100
ast_size_per_subroutine_threshold = 4600

supported_bin_vector_ops = {
Expand Down Expand Up @@ -50,7 +48,7 @@ def interpret(self, expr):

with top_cg.class_definition(self.class_name):

# Since we use SubroutinesAsFunctionsMixin, we already have logic
# Since we use SubroutinesMixin, we already have logic
# of adding methods. We create first subroutine for incoming
# expression and call `process_subroutine_queue` method.
self.enqueue_subroutine(self.function_name, expr)
Expand All @@ -63,41 +61,7 @@ def interpret(self, expr):

return top_cg.code

def interpret_subroutine_expr(self, expr, **kwargs):
return self._do_interpret(expr.expr, **kwargs)

# Required by SubroutinesAsFunctionsMixin to create new code generator for
# Required by SubroutinesMixin to create new code generator for
# each subroutine.
def create_code_generator(self):
return JavaCodeGenerator(indent=self.indent)

def bin_depth_threshold_hook(self, expr, **kwargs):
# The condition below is a sanity check to ensure that the expression
# is actually worth moving into a separate subroutine.
if ast.count_exprs(expr) > self.ast_size_per_subroutine_threshold:
function_name = self._get_subroutine_name()
self.enqueue_subroutine(function_name, expr)
return self._cg.function_invocation(
function_name, self._feature_array_name)
else:
return self._do_interpret(expr, **kwargs)

def _pre_interpret_hook(self, expr, **kwargs):
if isinstance(expr, ast.BinExpr):
threshold = self._calc_bin_depth_threshold(expr)
self.bin_depth_threshold = min(threshold, self.bin_depth_threshold)
return super()._pre_interpret_hook(expr, **kwargs)

def _calc_bin_depth_threshold(self, expr):
# The logic below counts the number of non-binary expressions
# in a non-recursive branch of a binary expression to account
# for large tree-like models and adjust the bin depth threshold
# if necessary.
cnt = None
if not isinstance(expr.left, ast.BinExpr):
cnt = ast.count_exprs(expr.left, exclude_list={ast.BinExpr})
elif not isinstance(expr.right, ast.BinExpr):
cnt = ast.count_exprs(expr.right, exclude_list={ast.BinExpr})
if cnt and cnt < self.ast_size_per_subroutine_threshold:
return math.ceil(self.ast_size_per_subroutine_threshold / cnt)
return self.bin_depth_threshold
65 changes: 49 additions & 16 deletions m2cgen/interpreters/mixins.py
@@ -1,4 +1,5 @@
import sys
import math
from collections import namedtuple

from m2cgen import ast
Expand All @@ -23,15 +24,15 @@ class BinExpressionDepthTrackingMixin(BaseToCodeInterpreter):

def _pre_interpret_hook(self, expr, bin_depth=0, **kwargs):
if not isinstance(expr, ast.BinExpr):
return None, kwargs
return super()._pre_interpret_hook(expr, **kwargs)

# We track depth of the binary expressions and call a hook if it
# reaches specified threshold .
if bin_depth >= self.bin_depth_threshold:
return self.bin_depth_threshold_hook(expr, **kwargs), kwargs

kwargs["bin_depth"] = bin_depth + 1
return None, kwargs
return super()._pre_interpret_hook(expr, **kwargs)

# Default implementation. Simply adds new variable.
def bin_depth_threshold_hook(self, expr, **kwargs):
Expand Down Expand Up @@ -92,9 +93,10 @@ def interpret_bin_vector_num_expr(self, expr, extra_func_args=(),
Subroutine = namedtuple('Subroutine', ['name', 'expr'])


class SubroutinesAsFunctionsMixin(BaseToCodeInterpreter):
class SubroutinesMixin(BaseToCodeInterpreter):
"""
This mixin provides ability to interpret each SubroutineExpr as a function.
This mixin provides ability to split the code into subroutines based on
the size of the AST.
Subclasses only need to implement `create_code_generator` method.
Expand All @@ -108,6 +110,10 @@ class SubroutinesAsFunctionsMixin(BaseToCodeInterpreter):
instance of code generator, which will be populated with the result code.
"""

# disabled by default
ast_size_check_frequency = sys.maxsize
ast_size_per_subroutine_threshold = sys.maxsize

def __init__(self, *args, **kwargs):
self._subroutine_idx = 0
self.subroutine_expr_queue = []
Expand All @@ -123,19 +129,49 @@ def process_subroutine_queue(self, top_code_generator):
while len(self.subroutine_expr_queue):
self._reset_reused_expr_cache()
subroutine = self.subroutine_expr_queue.pop(0)
subroutine_code = self.process_subroutine(subroutine)
subroutine_code = self._process_subroutine(subroutine)
top_code_generator.add_code_lines(subroutine_code)

def interpret_subroutine_expr(self, expr, **kwargs):
def enqueue_subroutine(self, name, expr):
self.subroutine_expr_queue.append(Subroutine(name, expr))

def _pre_interpret_hook(self, expr, ast_size_check_counter=0, **kwargs):
if isinstance(expr, ast.BinExpr) and not expr.to_reuse:
frequency = self._adjust_ast_check_frequency(expr)
self.ast_size_check_frequency = min(
frequency, self.ast_size_check_frequency)

ast_size_check_counter += 1
if ast_size_check_counter >= self.ast_size_check_frequency:
ast_size_check_counter = 0
ast_size = ast.count_exprs(expr)
if ast_size > self.ast_size_per_subroutine_threshold:
function_name = self._get_subroutine_name()
self.enqueue_subroutine(function_name, expr)
return self._cg.function_invocation(
function_name, self._feature_array_name), kwargs

kwargs['ast_size_check_counter'] = ast_size_check_counter

return super()._pre_interpret_hook(expr, **kwargs)

def _adjust_ast_check_frequency(self, expr):
"""
This method will be called whenever new subroutine is encountered.
The logic below counts the number of non-binary expressions
in a non-recursive branch of a binary expression to account
for large tree-like models and adjust the size check frequency
if necessary.
"""
function_name = self._get_subroutine_name()
self.enqueue_subroutine(function_name, expr.expr)
return self._cg.function_invocation(
function_name, self._feature_array_name)

def process_subroutine(self, subroutine):
cnt = None
if not isinstance(expr.left, ast.BinExpr):
cnt = ast.count_exprs(expr.left, exclude_list={ast.BinExpr})
elif not isinstance(expr.right, ast.BinExpr):
cnt = ast.count_exprs(expr.right, exclude_list={ast.BinExpr})
if cnt and cnt < self.ast_size_per_subroutine_threshold:
return math.ceil(self.ast_size_per_subroutine_threshold / cnt)
return self.ast_size_check_frequency

def _process_subroutine(self, subroutine):
"""
Handles single subroutine. Creates new code generator and defines a
function for a given subroutine.
Expand All @@ -153,9 +189,6 @@ def process_subroutine(self, subroutine):

return self._cg.code

def enqueue_subroutine(self, name, expr):
self.subroutine_expr_queue.append(Subroutine(name, expr))

def _get_subroutine_name(self):
subroutine_name = "subroutine" + str(self._subroutine_idx)
self._subroutine_idx += 1
Expand Down
5 changes: 4 additions & 1 deletion m2cgen/interpreters/r/interpreter.py
Expand Up @@ -6,7 +6,7 @@
class RInterpreter(ToCodeInterpreter,
mixins.LinearAlgebraMixin,
mixins.BinExpressionDepthTrackingMixin,
mixins.SubroutinesAsFunctionsMixin):
mixins.SubroutinesMixin):

# R doesn't allow to have more than 50 nested if, [, [[, {, ( calls.
# It raises contextstack overflow error not only for explicitly nested
Expand All @@ -19,6 +19,9 @@ class RInterpreter(ToCodeInterpreter,
# based on the users' feedback.
bin_depth_threshold = 25

ast_size_check_frequency = 2
ast_size_per_subroutine_threshold = 200

exponent_function_name = "exp"
tanh_function_name = "tanh"

Expand Down
8 changes: 4 additions & 4 deletions tests/interpreters/test_java.py
Expand Up @@ -358,7 +358,7 @@ def test_depth_threshold_with_bin_expr():
expr = ast.BinNumExpr(ast.NumVal(1), expr, ast.BinNumOpType.ADD)

interpreter = interpreters.JavaInterpreter()
interpreter.bin_depth_threshold = 2
interpreter.ast_size_check_frequency = 3
interpreter.ast_size_per_subroutine_threshold = 1

expected_code = """
Expand All @@ -385,7 +385,7 @@ def test_depth_threshold_without_bin_expr():
expr)

interpreter = interpreters.JavaInterpreter()
interpreter.bin_depth_threshold = 2
interpreter.ast_size_check_frequency = 2
interpreter.ast_size_per_subroutine_threshold = 1

expected_code = """
Expand Down Expand Up @@ -431,7 +431,7 @@ def test_deep_mixed_exprs_not_reaching_threshold():
expr)

interpreter = interpreters.JavaInterpreter()
interpreter.bin_depth_threshold = 2
interpreter.ast_size_check_frequency = 3
interpreter.ast_size_per_subroutine_threshold = 1

expected_code = """
Expand Down Expand Up @@ -477,7 +477,7 @@ def test_deep_mixed_exprs_exceeding_threshold():
expr)

interpreter = interpreters.JavaInterpreter()
interpreter.bin_depth_threshold = 2
interpreter.ast_size_check_frequency = 3
interpreter.ast_size_per_subroutine_threshold = 1

expected_code = """
Expand Down

0 comments on commit a823ede

Please sign in to comment.