Skip to content

Commit

Permalink
Fix java export for huge models
Browse files Browse the repository at this point in the history
  • Loading branch information
Aulust committed Jun 24, 2022
1 parent c3664db commit c78a10d
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 8 deletions.
10 changes: 10 additions & 0 deletions m2cgen/interpreters/java/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,16 @@ def class_definition(self, class_name):
yield
self.add_block_termination()

@contextmanager
def module_definition(self, module_name):
    """Emit a nested ``private static`` class that groups generated subroutines.

    Opens the class declaration named *module_name*, yields so the caller can
    emit the grouped method bodies inside the ``with`` block, then closes the
    block on exit.

    NOTE(review): emitted as a nested class because Java has no module
    construct; callers invoke members via ``module_function_invocation``.
    """
    self.add_class_def(module_name, modifier="private static")
    yield
    self.add_block_termination()

def module_function_invocation(self, module_name, function_name, *args):
    """Return the invocation of *function_name* qualified by *module_name*."""
    call = self.function_invocation(function_name, *args)
    return module_name + "." + call

@contextmanager
def method_definition(self, name, args, is_vector_output,
modifier="public"):
Expand Down
3 changes: 2 additions & 1 deletion m2cgen/interpreters/java/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class JavaInterpreter(ImperativeToCodeInterpreter,
# to adjustments in future.
ast_size_check_frequency = 100
ast_size_per_subroutine_threshold = 4600
subroutine_per_group_threshold = 15

supported_bin_vector_ops = {
BinNumOpType.ADD: "addVectors",
Expand Down Expand Up @@ -62,7 +63,7 @@ def interpret(self, expr):
# Since we use SubroutinesMixin, we already have logic
# of adding methods. We create first subroutine for incoming
# expression and call `process_subroutine_queue` method.
self.enqueue_subroutine(self.function_name, expr)
self.enqueue_subroutine(self.function_name, 0, expr)
self.process_subroutine_queue(top_cg)

current_dir = Path(__file__).absolute().parent
Expand Down
47 changes: 41 additions & 6 deletions m2cgen/interpreters/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from m2cgen import ast
from m2cgen.interpreters.interpreter import BaseToCodeInterpreter
from m2cgen.interpreters.utils import chunks


class BinExpressionDepthTrackingMixin(BaseToCodeInterpreter):
Expand Down Expand Up @@ -89,7 +90,7 @@ def interpret_bin_vector_num_expr(self, expr, extra_func_args=(), **kwargs):
*extra_func_args)


Subroutine = namedtuple('Subroutine', ['name', 'expr'])
Subroutine = namedtuple('Subroutine', ['name', 'idx', 'expr'])


class SubroutinesMixin(BaseToCodeInterpreter):
Expand All @@ -99,9 +100,11 @@ class SubroutinesMixin(BaseToCodeInterpreter):
Subclasses only need to implement `create_code_generator` method.
Their code generators should implement 3 methods:
Their code generators should implement 5 methods:
- function_definition;
- function_invocation;
- module_definition;
- module_function_invocation;
- add_return_statement.
Interpreter should prepare at least one subroutine using method
Expand All @@ -112,6 +115,7 @@ class SubroutinesMixin(BaseToCodeInterpreter):
# disabled by default
ast_size_check_frequency = sys.maxsize
ast_size_per_subroutine_threshold = sys.maxsize
subroutine_per_group_threshold = sys.maxsize

def __init__(self, *args, **kwargs):
self._subroutine_idx = 0
Expand All @@ -124,15 +128,33 @@ def process_subroutine_queue(self, top_code_generator):
subroutine queue.
"""
self._subroutine_idx = 0
subroutines = []

while len(self.subroutine_expr_queue):
while self.subroutine_expr_queue:
self._reset_reused_expr_cache()
subroutine = self.subroutine_expr_queue.pop(0)
subroutine_code = self._process_subroutine(subroutine)
subroutines.append((subroutine, subroutine_code))

subroutines.sort(key=lambda subroutine: subroutine[0].idx)

groups = chunks(subroutines, self.subroutine_per_group_threshold)
for _, subroutine_code in next(groups):
top_code_generator.add_code_lines(subroutine_code)

def enqueue_subroutine(self, name, expr):
self.subroutine_expr_queue.append(Subroutine(name, expr))
for index, subroutine_group in enumerate(groups):
cg = self.create_code_generator()

with cg.module_definition(
module_name=self._format_group_name(index + 1)):
for _, subroutine_code in subroutine_group:
cg.add_code_lines(subroutine_code)

top_code_generator.add_code_lines(
cg.finalize_and_get_generated_code())

def enqueue_subroutine(self, name, idx, expr):
    """Queue *expr* to be emitted later as a subroutine called *name*.

    *idx* records creation order so subroutines can be sorted into stable
    groups when the queue is processed.
    """
    entry = Subroutine(name=name, idx=idx, expr=expr)
    self.subroutine_expr_queue.append(entry)

def _pre_interpret_hook(self, expr, ast_size_check_counter=0, **kwargs):
if isinstance(expr, ast.BinExpr) and not expr.to_reuse:
Expand All @@ -145,7 +167,16 @@ def _pre_interpret_hook(self, expr, ast_size_check_counter=0, **kwargs):
ast_size = ast.count_exprs(expr)
if ast_size > self.ast_size_per_subroutine_threshold:
function_name = self._get_subroutine_name()
self.enqueue_subroutine(function_name, expr)

self.enqueue_subroutine(function_name, self._subroutine_idx, expr)

group_idx = self._subroutine_idx // self.subroutine_per_group_threshold
if group_idx != 0:
return self._cg.module_function_invocation(
self._format_group_name(group_idx),
function_name,
self._feature_array_name), kwargs

return self._cg.function_invocation(function_name, self._feature_array_name), kwargs

kwargs['ast_size_check_counter'] = ast_size_check_counter
Expand Down Expand Up @@ -191,6 +222,10 @@ def _get_subroutine_name(self):
self._subroutine_idx += 1
return subroutine_name

@staticmethod
def _format_group_name(group_idx):
return f"SubroutineGroup{group_idx}"

# Methods to be implemented by subclasses.

def create_code_generator(self):
Expand Down
6 changes: 6 additions & 0 deletions m2cgen/interpreters/r/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,9 @@ def array_index_access(self, array_name, index):

def vector_init(self, values):
return f"c({', '.join(values)})"

def module_definition(self, module_name):
    """Grouping subroutines into modules is not supported for R output.

    Raises:
        NotImplementedError: always; the R interpreter must not enqueue
        more subroutine groups than fit in a single file.
    """
    # Fixed ungrammatical error message ("in r is" -> "in R are").
    raise NotImplementedError("Modules in R are not supported")

def module_function_invocation(self, module_name, function_name, *args):
    """Module-qualified invocation is not supported for R output.

    Raises:
        NotImplementedError: always; see ``module_definition``.
    """
    # Fixed ungrammatical error message ("in r is" -> "in R are").
    raise NotImplementedError("Modules in R are not supported")
2 changes: 1 addition & 1 deletion m2cgen/interpreters/r/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(self, indent=4, function_name="score", *args, **kwargs):
def interpret(self, expr):
top_cg = self.create_code_generator()

self.enqueue_subroutine(self.function_name, expr)
self.enqueue_subroutine(self.function_name, 0, expr)
self.process_subroutine_queue(top_cg)

current_dir = Path(__file__).absolute().parent
Expand Down
5 changes: 5 additions & 0 deletions m2cgen/interpreters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,8 @@ def _normalize_expr_name(name):

def format_float(value):
return np.format_float_positional(value, unique=True, trim="0")


def chunks(arr, n):
    """Yield consecutive length-*n* slices of *arr*; the final slice may be shorter."""
    for start in range(0, len(arr), n):
        yield arr[start:start + n]
18 changes: 18 additions & 0 deletions tests/e2e/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def classification_binary_random_w_missing_values(model, test_fraction=0.02):
XGBOOST_PARAMS_BOOSTED_RF = dict(base_score=0.6, n_estimators=5, num_parallel_tree=3, subsample=0.8,
colsample_bynode=0.8, learning_rate=1.0, random_state=RANDOM_SEED)
XGBOOST_PARAMS_LARGE = dict(base_score=0.6, n_estimators=100, max_depth=12, random_state=RANDOM_SEED)
XGBOOST_PARAMS_HUGE = dict(base_score=0.6, n_estimators=500, max_depth=12, random_state=RANDOM_SEED)
LIGHTGBM_PARAMS = dict(n_estimators=10, random_state=RANDOM_SEED)
LIGHTGBM_PARAMS_DART = dict(n_estimators=10, boosting_type='dart', max_drop=30, random_state=RANDOM_SEED)
LIGHTGBM_PARAMS_GOSS = dict(n_estimators=10, boosting_type='goss',
Expand All @@ -142,6 +143,7 @@ def classification_binary_random_w_missing_values(model, test_fraction=0.02):
subsample=0.7, subsample_freq=1, random_state=RANDOM_SEED)
LIGHTGBM_PARAMS_EXTRA_TREES = dict(n_estimators=10, extra_trees=True, random_state=RANDOM_SEED)
LIGHTGBM_PARAMS_LARGE = dict(n_estimators=100, num_leaves=100, max_depth=64, random_state=RANDOM_SEED)
LIGHTGBM_PARAMS_HUGE = dict(n_estimators=500, num_leaves=100, max_depth=64, random_state=RANDOM_SEED)
SVC_PARAMS = dict(random_state=RANDOM_SEED, decision_function_shape="ovo")
STATSMODELS_LINEAR_REGULARIZED_PARAMS = dict(method="elastic_net", alpha=7, L1_wt=0.2)

Expand Down Expand Up @@ -204,6 +206,14 @@ def classification_binary_random_w_missing_values(model, test_fraction=0.02):
classification_random_w_missing_values(lgb.LGBMClassifier(**LIGHTGBM_PARAMS)),
classification_binary_random_w_missing_values(lgb.LGBMClassifier(**LIGHTGBM_PARAMS)),
# LightGBM (Huge Trees)
regression_random(
lightgbm.LGBMRegressor(**LIGHTGBM_PARAMS_HUGE)),
classification_random(
lightgbm.LGBMClassifier(**LIGHTGBM_PARAMS_HUGE)),
classification_binary_random(
lightgbm.LGBMClassifier(**LIGHTGBM_PARAMS_HUGE)),
# LightGBM (Different Objectives)
regression(lgb.LGBMRegressor(**LIGHTGBM_PARAMS, objective="mse", reg_sqrt=True)),
regression(lgb.LGBMRegressor(**LIGHTGBM_PARAMS, objective="mae")),
Expand Down Expand Up @@ -254,6 +264,14 @@ def classification_binary_random_w_missing_values(model, test_fraction=0.02):
classification_random(xgb.XGBClassifier(**XGBOOST_PARAMS_LARGE)),
classification_binary_random(xgb.XGBClassifier(**XGBOOST_PARAMS_LARGE)),
# XGBoost (Huge Trees)
regression_random(
xgboost.XGBRegressor(**XGBOOST_PARAMS_HUGE)),
classification_random(
xgboost.XGBClassifier(**XGBOOST_PARAMS_HUGE)),
classification_binary_random(
xgboost.XGBClassifier(**XGBOOST_PARAMS_HUGE)),
# Sklearn Linear SVM
regression(svm.LinearSVR(random_state=RANDOM_SEED)),
classification(svm.LinearSVC(random_state=RANDOM_SEED)),
Expand Down

0 comments on commit c78a10d

Please sign in to comment.