Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create binary and multiclass objective classes (#504)
* creating new binary / multiclass variants of pipelines, duplicating code for now * moving common fxns back to pipeline base * more moving around fxns in pipelines classes * capping xgboost * fixing typo * more cleanup, making predict_proba standard regardless of binary/multiclass * renaming other_objectives to objectives * updating score's objective parameter to calculate all objectives, not just additional * removing self.objective for scoring * removing objectives from pipeline initialization, adding objective as predict param * remove xgboost cap from branch * changelog * capping xgboost on local branch since tests timing out * cleaning up * more cleanup * reverting requirements file * adding classification pipeline subclass, cleaning up via PR comments * more cleanup for docstrings * more cleanup of changelog and comments * putting tests in subfolders and adding few more tests * Update dependencies (#412) * Update latest dependencies * Hide features with zero importance in plot by default (#413) * adding functionality and test * changelog and adding boolean param * Update dependencies check: package whitelist (#417) * Add a whitelist for update_deps check * Remove from expected * Update deps * Changelog * adding skeleton for subclasses * fixing test and linting * updating change from master * fixing fixture * cherry picked wip remove ROC and confusion matrix * fixing merge * fixing merge * cleaning up * make test use static attribute instead of instance' * deleting needs_fitting * updating code to use new objective classes, still broken * updating threshold, still need to clean up tests * comment out for now * more cleanup * cleaning up * more cleanup * still more cleanup * fixing plot unsuccessful merging * more cleanup but still some things to work out * cleaning up using multiclass objectives for binary classification problems * fixing typo with recall and cleanup * cleaning up * adding default * some more cleaning up * removing irrelevant test * forgot to add attribute, breaking things again * cleanup and change objective of test * removing objective from predict * more cleanup :d * remove unused attribute * cleaning up via comments * more comments * changelog * order of decorators changed * fixing copy and paste err * update random state for binary class pipelines * updating objective * typo * fixing? * fixing imports * fixing tests * adding objective as parameter for predict, removing for fit * cleaning up test * more fixing test * minor linting, need more to go * more cleanup * forgot to fix test * more merging :x * starting to add tuning logic to automl * changelog * cleaning up * change conditional for objective split * cleaning up docstrings * forgot to use classificationobjective class... * add additional cond * adding tests * cleanup * removing decision function for multiclass * updating via comments * removing classification_objective file * add test + more updates * use cls instead for pep8 standards * updating can_optimize to property * update score * fix tests * minor cleanup from comments * updating predict behavior * add separate objective check * fixing some merge conflicts cont * add fraud test * patching * remove old test * updating for now * add another test * add more tests * adding test structure, still need to fix * adding test * fix iloc * fix tests * fix import * fix test? * removing can_optimize_threshold * linting * update docs a little * remove accuracy * add more doc fixes * move binary and multi pipelines in api ref * revert components notebook * updating from comments * oops, fix none set * update docstring * update api ref? * addressing comments * revert and update * update docstring * updating docstrings and lint * updating unnecessary call to constructor * pushing empty commit to refresh Co-authored-by: Jeremy Shih <jeremyliweishih@gmail.com> Co-authored-by: Dylan Sherry <sharshofski@gmail.com>
- Loading branch information
1 parent
06bc5a8
commit 55d737a
Showing
41 changed files
with
788 additions
and
526 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import pandas as pd | ||
from scipy.optimize import minimize_scalar | ||
|
||
from .objective_base import ObjectiveBase | ||
|
||
from evalml.problem_types import ProblemTypes | ||
|
||
|
||
class BinaryClassificationObjective(ObjectiveBase): | ||
""" | ||
Base class for all binary classification objectives. | ||
problem_type (ProblemTypes): Specifies the type of problem this objective is defined for (binary classification) | ||
can_optimize_threshold (bool): Determines if threshold used by objective can be optimized or not. | ||
""" | ||
problem_type = ProblemTypes.BINARY | ||
|
||
@property | ||
def can_optimize_threshold(cls): | ||
"""Returns a boolean determining if we can optimize the binary classification objective threshold. This will be false for any objective that works directly with predicted probabilities, like log loss and AUC. Otherwise, it will be true.""" | ||
return not cls.score_needs_proba | ||
|
||
def optimize_threshold(self, ypred_proba, y_true, X=None): | ||
"""Learn a binary classification threshold which optimizes the current objective. | ||
Arguments: | ||
ypred_proba (list): The classifier's predicted probabilities | ||
y_true (list): The ground truth for the predictions. | ||
X (pd.DataFrame, optional): Any extra columns that are needed from training data. | ||
Returns: | ||
Optimal threshold for this objective | ||
""" | ||
if not self.can_optimize_threshold: | ||
raise RuntimeError("Trying to optimize objective that can't be optimized!") | ||
|
||
def cost(threshold): | ||
predictions = self.decision_function(ypred_proba=ypred_proba, threshold=threshold, X=X) | ||
cost = self.objective_function(predictions, y_true, X=X) | ||
return -cost if self.greater_is_better else cost | ||
|
||
optimal = minimize_scalar(cost, method='Golden', options={"maxiter": 100}) | ||
return optimal.x | ||
|
||
def decision_function(self, ypred_proba, threshold=0.5, X=None): | ||
"""Apply a learned threshold to predicted probabilities to get predicted classes. | ||
Arguments: | ||
ypred_proba (list): The classifier's predicted probabilities | ||
threshold (float, optional): Threshold used to make a prediction. Defaults to 0.5. | ||
X (pd.DataFrame, optional): Any extra columns that are needed from training data. | ||
Returns: | ||
predictions | ||
""" | ||
if not isinstance(ypred_proba, pd.Series): | ||
ypred_proba = pd.Series(ypred_proba) | ||
return ypred_proba > threshold |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.