Skip to content

Commit

Permalink
Merge pull request #54 from antoinedemathelin/master
Browse files Browse the repository at this point in the history
fix: change prediction in TrAdaBoost
  • Loading branch information
antoinedemathelin committed Apr 29, 2022
2 parents 8fd2bbd + b496e89 commit 065f6a2
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions adapt/instance_based/_tradaboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,12 +304,16 @@ def _boost(self, iboost, Xs, ys, Xt, yt,
warm_start=False,
**fit_params)

if hasattr(estimator, "predict_proba"):
ys_pred = estimator.predict_proba(Xs)
yt_pred = estimator.predict_proba(Xt)
elif hasattr(estimator, "_predict_proba_lr"):
ys_pred = estimator._predict_proba_lr(Xs)
yt_pred = estimator._predict_proba_lr(Xt)
if not isinstance(self, TrAdaBoostR2) and isinstance(estimator, BaseEstimator):
if hasattr(estimator, "predict_proba"):
ys_pred = estimator.predict_proba(Xs)
yt_pred = estimator.predict_proba(Xt)
elif hasattr(estimator, "_predict_proba_lr"):
ys_pred = estimator._predict_proba_lr(Xs)
yt_pred = estimator._predict_proba_lr(Xt)
else:
ys_pred = estimator.predict(Xs)
yt_pred = estimator.predict(Xt)
else:
ys_pred = estimator.predict(Xs)
yt_pred = estimator.predict(Xt)
Expand Down

0 comments on commit 065f6a2

Please sign in to comment.