Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix java export for huge models #306

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions m2cgen/interpreters/java/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,16 @@ def class_definition(self, class_name):
yield
self.add_block_termination()

@contextmanager
def module_definition(self, module_name):
    """Open a nested ``private static`` class that serves as a subroutine module.

    Emits the class header, yields so the caller can generate the body,
    then writes the closing brace.
    """
    header_modifier = "private static"
    self.add_class_def(module_name, modifier=header_modifier)
    yield
    self.add_block_termination()

def module_function_invocation(self, module_name, function_name, *args):
    """Return an invocation of *function_name* qualified by *module_name*."""
    return f"{module_name}.{self.function_invocation(function_name, *args)}"

@contextmanager
def method_definition(self, name, args, is_vector_output,
modifier="public"):
Expand Down
3 changes: 2 additions & 1 deletion m2cgen/interpreters/java/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class JavaInterpreter(ImperativeToCodeInterpreter,
# to adjustments in future.
ast_size_check_frequency = 100
ast_size_per_subroutine_threshold = 4600
subroutine_per_group_threshold = 15

supported_bin_vector_ops = {
BinNumOpType.ADD: "addVectors",
Expand Down Expand Up @@ -62,7 +63,7 @@ def interpret(self, expr):
# Since we use SubroutinesMixin, we already have logic
# of adding methods. We create first subroutine for incoming
# expression and call `process_subroutine_queue` method.
self.enqueue_subroutine(self.function_name, expr)
self.enqueue_subroutine(self.function_name, 0, expr)
self.process_subroutine_queue(top_cg)

current_dir = Path(__file__).absolute().parent
Expand Down
47 changes: 41 additions & 6 deletions m2cgen/interpreters/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from m2cgen import ast
from m2cgen.interpreters.interpreter import BaseToCodeInterpreter
from m2cgen.interpreters.utils import chunks


class BinExpressionDepthTrackingMixin(BaseToCodeInterpreter):
Expand Down Expand Up @@ -89,7 +90,7 @@ def interpret_bin_vector_num_expr(self, expr, extra_func_args=(), **kwargs):
*extra_func_args)


Subroutine = namedtuple('Subroutine', ['name', 'expr'])
# A queued unit of code generation: the subroutine's name, its position in
# the enqueue order (subroutines are later sorted by it), and its AST expression.
Subroutine = namedtuple("Subroutine", ("name", "idx", "expr"))


class SubroutinesMixin(BaseToCodeInterpreter):
Expand All @@ -99,9 +100,11 @@ class SubroutinesMixin(BaseToCodeInterpreter):

Subclasses only need to implement `create_code_generator` method.

Their code generators should implement 3 methods:
Their code generators should implement 5 methods:
- function_definition;
- function_invocation;
- module_definition;
- module_function_invocation;
- add_return_statement.

Interpreter should prepare at least one subroutine using method
Expand All @@ -112,6 +115,7 @@ class SubroutinesMixin(BaseToCodeInterpreter):
# disabled by default
ast_size_check_frequency = sys.maxsize
ast_size_per_subroutine_threshold = sys.maxsize
subroutine_per_group_threshold = sys.maxsize
Copy link
Member

@izeigerman izeigerman Jun 29, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like `subroutine_per_module_threshold` would be a more accurate name.


def __init__(self, *args, **kwargs):
self._subroutine_idx = 0
Expand All @@ -124,15 +128,33 @@ def process_subroutine_queue(self, top_code_generator):
subroutine queue.
"""
self._subroutine_idx = 0
subroutines = []

while len(self.subroutine_expr_queue):
while self.subroutine_expr_queue:
self._reset_reused_expr_cache()
subroutine = self.subroutine_expr_queue.pop(0)
subroutine_code = self._process_subroutine(subroutine)
subroutines.append((subroutine, subroutine_code))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If my assumption in the previous comment is correct I don't think we actually need subroutine here and having only subroutine_code will suffice.


subroutines.sort(key=lambda subroutine: subroutine[0].idx)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this sort given that subroutines are added in a specific order already (the subroutine_expr_queue)? And a follow-up: why do we need index as part of the Subroutine data structure?


groups = chunks(subroutines, self.subroutine_per_group_threshold)
for _, subroutine_code in next(groups):
top_code_generator.add_code_lines(subroutine_code)

def enqueue_subroutine(self, name, expr):
self.subroutine_expr_queue.append(Subroutine(name, expr))
for index, subroutine_group in enumerate(groups):
cg = self.create_code_generator()

with cg.module_definition(
module_name=self._format_group_name(index + 1)):
for _, subroutine_code in subroutine_group:
cg.add_code_lines(subroutine_code)

top_code_generator.add_code_lines(
cg.finalize_and_get_generated_code())

def enqueue_subroutine(self, name, idx, expr):
    """Schedule *expr* to be rendered later as a subroutine called *name*.

    *idx* records the subroutine's position so the processing step can keep
    a stable order when grouping the generated code.
    """
    self.subroutine_expr_queue.append(Subroutine(name=name, idx=idx, expr=expr))

def _pre_interpret_hook(self, expr, ast_size_check_counter=0, **kwargs):
if isinstance(expr, ast.BinExpr) and not expr.to_reuse:
Expand All @@ -145,7 +167,16 @@ def _pre_interpret_hook(self, expr, ast_size_check_counter=0, **kwargs):
ast_size = ast.count_exprs(expr)
if ast_size > self.ast_size_per_subroutine_threshold:
function_name = self._get_subroutine_name()
self.enqueue_subroutine(function_name, expr)

self.enqueue_subroutine(function_name, self._subroutine_idx, expr)

group_idx = self._subroutine_idx // self.subroutine_per_group_threshold
if group_idx != 0:
return self._cg.module_function_invocation(
self._format_group_name(group_idx),
function_name,
self._feature_array_name), kwargs

return self._cg.function_invocation(function_name, self._feature_array_name), kwargs

kwargs['ast_size_check_counter'] = ast_size_check_counter
Expand Down Expand Up @@ -191,6 +222,10 @@ def _get_subroutine_name(self):
self._subroutine_idx += 1
return subroutine_name

@staticmethod
def _format_group_name(group_idx):
    """Return the module/class name used for subroutine group *group_idx*."""
    return "SubroutineGroup{}".format(group_idx)

# Methods to be implemented by subclasses.

def create_code_generator(self):
Expand Down
6 changes: 6 additions & 0 deletions m2cgen/interpreters/r/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,9 @@ def array_index_access(self, array_name, index):

def vector_init(self, values):
    """Render an R vector literal, e.g. ``c(1, 2, 3)``, from string elements."""
    joined = ", ".join(values)
    return f"c({joined})"

def module_definition(self, module_name):
    """Unsupported for the R target: the generator has no module/group concept.

    Raises:
        NotImplementedError: always.
    """
    # Grammar fix (review feedback): "Modules" is plural and "R" is capitalized.
    raise NotImplementedError("Modules in R are not supported")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since "Modules" is plural, I suggest you use "are not supported".


def module_function_invocation(self, module_name, function_name, *args):
    """Unsupported for the R target: the generator has no module/group concept.

    Raises:
        NotImplementedError: always.
    """
    # Grammar fix (review feedback): "Modules" is plural and "R" is capitalized.
    raise NotImplementedError("Modules in R are not supported")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto

2 changes: 1 addition & 1 deletion m2cgen/interpreters/r/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(self, indent=4, function_name="score", *args, **kwargs):
def interpret(self, expr):
top_cg = self.create_code_generator()

self.enqueue_subroutine(self.function_name, expr)
self.enqueue_subroutine(self.function_name, 0, expr)
self.process_subroutine_queue(top_cg)

current_dir = Path(__file__).absolute().parent
Expand Down
5 changes: 5 additions & 0 deletions m2cgen/interpreters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,8 @@ def _normalize_expr_name(name):

def format_float(value):
    """Format *value* as the shortest positional decimal that round-trips.

    Trailing zeros are trimmed, but one digit is always kept after the
    decimal point (e.g. ``1.0`` -> ``"1.0"``).
    """
    return np.format_float_positional(value, trim="0", unique=True)


def chunks(arr, n):
    """Yield consecutive length-*n* slices of *arr*; the final one may be shorter.

    A non-positive *n* raises ``ValueError`` (from ``range`` with step 0) or
    yields nothing, matching plain ``range(0, len(arr), n)`` semantics.
    """
    for start in range(0, len(arr), n):
        yield arr[start:start + n]
12 changes: 12 additions & 0 deletions tests/e2e/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def classification_binary_random_w_missing_values(model, test_fraction=0.02):
XGBOOST_PARAMS_BOOSTED_RF = dict(base_score=0.6, n_estimators=5, num_parallel_tree=3, subsample=0.8,
colsample_bynode=0.8, learning_rate=1.0, random_state=RANDOM_SEED)
XGBOOST_PARAMS_LARGE = dict(base_score=0.6, n_estimators=100, max_depth=12, random_state=RANDOM_SEED)
XGBOOST_PARAMS_HUGE = dict(base_score=0.6, n_estimators=500, max_depth=12, random_state=RANDOM_SEED)
LIGHTGBM_PARAMS = dict(n_estimators=10, random_state=RANDOM_SEED)
LIGHTGBM_PARAMS_DART = dict(n_estimators=10, boosting_type='dart', max_drop=30, random_state=RANDOM_SEED)
LIGHTGBM_PARAMS_GOSS = dict(n_estimators=10, boosting_type='goss',
Expand All @@ -142,6 +143,7 @@ def classification_binary_random_w_missing_values(model, test_fraction=0.02):
subsample=0.7, subsample_freq=1, random_state=RANDOM_SEED)
LIGHTGBM_PARAMS_EXTRA_TREES = dict(n_estimators=10, extra_trees=True, random_state=RANDOM_SEED)
LIGHTGBM_PARAMS_LARGE = dict(n_estimators=100, num_leaves=100, max_depth=64, random_state=RANDOM_SEED)
LIGHTGBM_PARAMS_HUGE = dict(n_estimators=500, num_leaves=100, max_depth=64, random_state=RANDOM_SEED)
SVC_PARAMS = dict(random_state=RANDOM_SEED, decision_function_shape="ovo")
STATSMODELS_LINEAR_REGULARIZED_PARAMS = dict(method="elastic_net", alpha=7, L1_wt=0.2)

Expand Down Expand Up @@ -204,6 +206,11 @@ def classification_binary_random_w_missing_values(model, test_fraction=0.02):
classification_random_w_missing_values(lgb.LGBMClassifier(**LIGHTGBM_PARAMS)),
classification_binary_random_w_missing_values(lgb.LGBMClassifier(**LIGHTGBM_PARAMS)),

# LightGBM (Huge Trees)
regression_random(lgb.LGBMRegressor(**LIGHTGBM_PARAMS_HUGE)),
classification_random(lgb.LGBMClassifier(**LIGHTGBM_PARAMS_HUGE)),
classification_binary_random(lgb.LGBMClassifier(**LIGHTGBM_PARAMS_HUGE)),

# LightGBM (Different Objectives)
regression(lgb.LGBMRegressor(**LIGHTGBM_PARAMS, objective="mse", reg_sqrt=True)),
regression(lgb.LGBMRegressor(**LIGHTGBM_PARAMS, objective="mae")),
Expand Down Expand Up @@ -254,6 +261,11 @@ def classification_binary_random_w_missing_values(model, test_fraction=0.02):
classification_random(xgb.XGBClassifier(**XGBOOST_PARAMS_LARGE)),
classification_binary_random(xgb.XGBClassifier(**XGBOOST_PARAMS_LARGE)),

# XGBoost (Huge Trees)
regression_random(xgb.XGBRegressor(**XGBOOST_PARAMS_HUGE)),
classification_random(xgb.XGBClassifier(**XGBOOST_PARAMS_HUGE)),
classification_binary_random(xgb.XGBClassifier(**XGBOOST_PARAMS_HUGE)),

# Sklearn Linear SVM
regression(svm.LinearSVR(random_state=RANDOM_SEED)),
classification(svm.LinearSVC(random_state=RANDOM_SEED)),
Expand Down