Skip to content

Commit

Permalink
Merge d409632 into 024b61e
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS committed May 26, 2020
2 parents 024b61e + d409632 commit 5f49dbb
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 15 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ pip install m2cgen
| **SVM** | <ul><li>scikit-learn<ul><li>LinearSVC</li><li>NuSVC</li><li>SVC</li></ul></li><li>lightning<ul><li>KernelSVC</li><li>LinearSVC</li></ul></li></ul> | <ul><li>scikit-learn<ul><li>LinearSVR</li><li>NuSVR</li><li>SVR</li></ul></li><li>lightning<ul><li>LinearSVR</li></ul></li></ul> |
| **Tree** | <ul><li>DecisionTreeClassifier</li><li>ExtraTreeClassifier</li></ul> | <ul><li>DecisionTreeRegressor</li><li>ExtraTreeRegressor</li></ul> |
| **Random Forest** | <ul><li>ExtraTreesClassifier</li><li>LGBMClassifier(rf booster only)</li><li>RandomForestClassifier</li><li>XGBRFClassifier</li></ul> | <ul><li>ExtraTreesRegressor</li><li>LGBMRegressor(rf booster only)</li><li>RandomForestRegressor</li><li>XGBRFRegressor</li></ul> |
| **Boosting** | <ul><li>LGBMClassifier(gbdt/dart/goss booster only)</li><li>XGBClassifier(gbtree/gblinear booster only)</li><ul> | <ul><li>LGBMRegressor(gbdt/dart/goss booster only)</li><li>XGBRegressor(gbtree/gblinear booster only)</li></ul> |
| **Boosting** | <ul><li>LGBMClassifier(gbdt/dart/goss booster only)</li><li>XGBClassifier(gbtree(including boosted forests)/gblinear booster only)</li><ul> | <ul><li>LGBMRegressor(gbdt/dart/goss booster only)</li><li>XGBRegressor(gbtree(including boosted forests)/gblinear booster only)</li></ul> |

You can find versions of packages with which compatibility is guaranteed by CI tests [here](https://github.com/BayesWitnesses/m2cgen/blob/master/requirements-test.txt#L1).
Other versions may also work, but they are untested.
Expand Down
33 changes: 19 additions & 14 deletions m2cgen/assemblers/boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
class BaseBoostingAssembler(ModelAssembler):

classifier_names = {}
strided_layout_for_multiclass = True
multiclass_params_seq_len = 1

def __init__(self, model, estimator_params, base_score=0):
super().__init__(model)
Expand Down Expand Up @@ -54,9 +54,10 @@ def _assemble_single_output(self, estimator_params,
def _assemble_multi_class_output(self, estimator_params):
# Multi-class output is calculated based on discussion in
# https://github.com/dmlc/xgboost/issues/1746#issuecomment-295962863
# and the enhancement to support boosted forests in XGBoost.
splits = _split_estimator_params_by_classes(
estimator_params, self._output_size,
self.strided_layout_for_multiclass)
self.multiclass_params_seq_len)

base_score = self._base_score
exprs = [
Expand Down Expand Up @@ -114,8 +115,8 @@ class XGBoostTreeModelAssembler(BaseTreeBoostingAssembler):
classifier_names = {"XGBClassifier", "XGBRFClassifier"}

def __init__(self, model):
if type(model).__name__ == "XGBRFClassifier":
self.strided_layout_for_multiclass = False
self.multiclass_params_seq_len = model.get_params().get(
"num_parallel_tree", 1)
feature_names = model.get_booster().feature_names
self._feature_name_to_idx = {
name: idx for idx, name in enumerate(feature_names or [])
Expand Down Expand Up @@ -244,13 +245,17 @@ def _assemble_tree(self, tree):
self._assemble_tree(false_child))


def _split_estimator_params_by_classes(values, n_classes, strided):
if strided:
# Splits are computed based on a comment
# https://github.com/dmlc/xgboost/issues/1746#issuecomment-267400592.
return [values[class_idx::n_classes] for class_idx in range(n_classes)]
else:
values_len = len(values)
block_len = values_len // n_classes
return [values[start_block_idx:start_block_idx + block_len]
for start_block_idx in range(0, values_len, block_len)]
def _split_estimator_params_by_classes(values, n_classes, params_seq_len):
# Splits are computed based on a comment
# https://github.com/dmlc/xgboost/issues/1746#issuecomment-267400592
# and the enhancement to support boosted forests in XGBoost.
values_len = len(values)
block_len = n_classes * params_seq_len
indices = list(range(values_len))
indices_by_class = np.array(
[[indices[i:i + params_seq_len]
for i in range(j, values_len, block_len)]
for j in range(0, block_len, params_seq_len)]
).reshape(n_classes, -1)
return [[values[idx] for idx in class_idxs]
for class_idxs in indices_by_class]
10 changes: 10 additions & 0 deletions tests/e2e/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ def regression_bounded(model, test_fraction=0.02):
random_state=RANDOM_SEED)
XGBOOST_PARAMS_RF = dict(base_score=0.6, n_estimators=10,
random_state=RANDOM_SEED)
XGBOOST_PARAMS_BOOSTED_RF = dict(base_score=0.6, n_estimators=5,
num_parallel_tree=3, subsample=0.8,
colsample_bynode=0.8, learning_rate=1.0,
random_state=RANDOM_SEED)
XGBOOST_PARAMS_LARGE = dict(base_score=0.6, n_estimators=100, max_depth=12,
random_state=RANDOM_SEED)
LIGHTGBM_PARAMS = dict(n_estimators=10, random_state=RANDOM_SEED)
Expand Down Expand Up @@ -203,6 +207,12 @@ def regression_bounded(model, test_fraction=0.02):
classification(xgboost.XGBRFClassifier(**XGBOOST_PARAMS_RF)),
classification_binary(xgboost.XGBRFClassifier(**XGBOOST_PARAMS_RF)),
# XGBoost (Boosted Random Forests)
regression(xgboost.XGBRegressor(**XGBOOST_PARAMS_BOOSTED_RF)),
classification(xgboost.XGBClassifier(**XGBOOST_PARAMS_BOOSTED_RF)),
classification_binary(
xgboost.XGBClassifier(**XGBOOST_PARAMS_BOOSTED_RF)),
# XGBoost (Large Trees)
regression_random(
xgboost.XGBRegressor(**XGBOOST_PARAMS_LARGE)),
Expand Down

0 comments on commit 5f49dbb

Please sign in to comment.