Skip to content

Commit

Permalink
fixed sample_weight=None bug and classes_ typo
Browse files Browse the repository at this point in the history
  • Loading branch information
hoffmansc committed Feb 5, 2020
1 parent 042bb12 commit 5ded679
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions aif360/sklearn/postprocessing/calibrated_equalized_odds.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ def __init__(self, prot_attr=None, cost_constraint='weighted',
self.cost_constraint = cost_constraint
self.random_state = random_state

def _weighted_cost(self, y_true, probas_pred, pos_label, sample_weight):
def _weighted_cost(self, y_true, probas_pred, sample_weight=None):
"""Evaluates the cost function specified by ``self.cost_constraint``."""
fpr = generalized_fpr(y_true, probas_pred, pos_label, sample_weight)
fnr = generalized_fnr(y_true, probas_pred, pos_label, sample_weight)
br = base_rate(y_true, probas_pred, pos_label, sample_weight)
fpr = generalized_fpr(y_true, probas_pred, self.pos_label_, sample_weight)
fnr = generalized_fnr(y_true, probas_pred, self.pos_label_, sample_weight)
br = base_rate(y_true, probas_pred, self.pos_label_, sample_weight)
if self.cost_constraint == 'fpr':
return fpr
elif self.cost_constraint == 'fnr':
Expand Down Expand Up @@ -117,7 +117,7 @@ def _args(grp_idx, triv=False):
idx = (groups == self.groups_[grp_idx])
pred = (np.full_like(y_pred, self.base_rates_[grp_idx]) if triv else
y_pred)
return [y_true[idx], pred[idx], pos_label, sample_weight[idx]]
return [y_true[idx], pred[idx], sample_weight[idx]]

self.base_rates_ = [base_rate(*_args(i)) for i in range(2)]

Expand Down Expand Up @@ -178,7 +178,7 @@ def predict(self, y_pred):
numpy.ndarray: Predicted class label per sample.
"""
scores = self.predict_proba(y_pred)
return self.classes[scores.argmax(axis=1)]
return self.classes_[scores.argmax(axis=1)]

def score(self, y_pred, y_true, sample_weight=None):
"""Score the predictions according to the cost constraint specified.
Expand Down

0 comments on commit 5ded679

Please sign in to comment.