Skip to content

Commit

Permalink
Define leaves_cutoff_threshold as constant (#148)
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS authored and izeigerman committed Jan 19, 2020
1 parent 7375b15 commit 43a189c
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions m2cgen/assemblers/boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@
from m2cgen.assemblers.base import ModelAssembler


LEAVES_CUTOFF_THRESHOLD = 3000


class BaseBoostingAssembler(ModelAssembler):

classifier_name = None

def __init__(self, model, trees, base_score=0, tree_limit=None,
leaves_cutoff_threshold=3000):
leaves_cutoff_threshold=LEAVES_CUTOFF_THRESHOLD):
super().__init__(model)
self.all_trees = trees
self._base_score = base_score
Expand Down Expand Up @@ -130,7 +133,8 @@ class XGBoostModelAssembler(BaseBoostingAssembler):

classifier_name = "XGBClassifier"

def __init__(self, model, leaves_cutoff_threshold=3000):
def __init__(self, model,
leaves_cutoff_threshold=LEAVES_CUTOFF_THRESHOLD):
feature_names = model.get_booster().feature_names
self._feature_name_to_idx = {
name: idx for idx, name in enumerate(feature_names or [])
Expand Down Expand Up @@ -198,7 +202,8 @@ class LightGBMModelAssembler(BaseBoostingAssembler):

classifier_name = "LGBMClassifier"

def __init__(self, model, leaves_cutoff_threshold=3000):
def __init__(self, model,
leaves_cutoff_threshold=LEAVES_CUTOFF_THRESHOLD):
model_dump = model.booster_.dump_model()
trees = [m["tree_structure"] for m in model_dump["tree_info"]]
self.n_iter = len(trees) // model_dump["num_tree_per_iteration"]
Expand Down

0 comments on commit 43a189c

Please sign in to comment.