Skip to content

Commit

Permalink
added clean representation of objective
Browse files Browse the repository at this point in the history
  • Loading branch information
CDonnerer committed Sep 12, 2021
1 parent d5c38a4 commit adade2c
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/xgboost_distribution/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def fit(
self._distribution.check_target(y)

params = self.get_xgb_params()
params["objective"] = None
params["disable_default_eval_metric"] = True
params["num_class"] = len(self._distribution.params)

Expand Down Expand Up @@ -136,7 +137,7 @@ def fit(
evals_result = {}
model, _, params = self._configure_fit(xgb_model, None, params)

# hack to suppress warnings from the extra distribution parameter
# Suppress warnings from unexpected distribution & natural_gradient params
with config_context(verbosity=0):
self._Booster = train(
params,
Expand Down
11 changes: 11 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@ def test_XGBDistribution_early_stopping_predict(small_train_test_data):
np.testing.assert_array_equal(var, var_iter)


def test_objective_overwrite_by_repeated_fit(small_X_y_data):
    """Repeated fits must succeed despite the custom objective.

    Each fit sets the objective to a distribution-specific value (e.g.
    ``distribution:normal``), which is not a standard xgboost objective,
    so it has to be reset on every call to ``fit``.
    """
    X, y = small_X_y_data
    estimator = XGBDistribution()
    for _ in range(2):
        estimator.fit(X, y)


def test_distribution_set_param(small_X_y_data):
"""Check that updating the distribution params works"""
X, y = small_X_y_data
Expand Down

0 comments on commit adade2c

Please sign in to comment.