Skip to content

Commit

Permalink
propagate classes_ in metaestimators
Browse files Browse the repository at this point in the history
Signed-off-by: Samuel Hoffman <hoffman.sc@gmail.com>
  • Loading branch information
hoffmansc committed Jul 1, 2022
1 parent 2effc0b commit 35f87a7
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
6 changes: 6 additions & 0 deletions aif360/sklearn/postprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class PostProcessingMeta(BaseEstimator, MetaEstimatorMixin):
Attributes:
estimator_: Fitted estimator.
postprocessor_: Fitted postprocessor.
classes_ (array, shape (n_classes,)): Class labels from `estimator_`.
"""

def __init__(self, estimator, postprocessor, *, prefit=False, val_size=0.25,
Expand Down Expand Up @@ -57,6 +58,11 @@ def __init__(self, estimator, postprocessor, *, prefit=False, val_size=0.25,
def _estimator_type(self):
return self.postprocessor._estimator_type

@property
def classes_(self):
"""Class labels from the base estimator."""
return self.estimator_.classes_

def fit(self, X, y, sample_weight=None, **fit_params):
"""Splits the training samples with
:func:`~sklearn.model_selection.train_test_split` and uses the resultant
Expand Down
6 changes: 6 additions & 0 deletions aif360/sklearn/preprocessing/reweighing.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ class ReweighingMeta(BaseEstimator, MetaEstimatorMixin):
Attributes:
estimator_ (sklearn.BaseEstimator): The fitted underlying estimator.
reweigher_: The fitted underlying reweigher.
classes_ (array, shape (n_classes,)): Class labels from `estimator_`.
"""
def __init__(self, estimator, reweigher=None):
"""
Expand All @@ -119,6 +120,11 @@ def __init__(self, estimator, reweigher=None):
def _estimator_type(self):
return self.estimator._estimator_type

@property
def classes_(self):
"""Class labels from the base estimator."""
return self.estimator_.classes_

def fit(self, X, y, sample_weight=None):
"""Performs ``self.reweigher_.fit_transform(X, y, sample_weight)`` and
then ``self.estimator_.fit(X, y, sample_weight)`` using the reweighed
Expand Down

0 comments on commit 35f87a7

Please sign in to comment.