Skip to content

Commit

Permalink
[MRG+1] Fix estimators to work if sample_weight parameter is pandas S…
Browse files Browse the repository at this point in the history
…eries type (scikit-learn#7825)

* addressed comments in the PR about parameters in check_array

* update the test case for the evaluation of estimators with pandas series

* bug fix, need to check for *not* None explicitly

* added an isinstance check where the documentation says floats are accepted

* ran pep8 linter on modified files

* moving the test case to estimators_check

* add a predict function into the testing pandas.Series class

* avoid running anything beyond the newly added meta checks

* check if pandas is installed before running the specific test

* changed the order of the try/except to check for the sample_weight param beforehand

* pass on import error rather than printing something to stdout

* improve test case naming and pd.Series check in the bad estimator class

* address a pep8 linter error with unused import

* pep8 warning disabled for potential unused import

* throw a warning when SkipTest is raised

* add a SkipTestWarning

* updated the whats_new.rst with this issue

* rebase and fix a spacing issue
  • Loading branch information
kathyxchen authored and Sundrique committed Jun 14, 2017
1 parent 6d05b1a commit 383bf8e
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 4 deletions.
6 changes: 6 additions & 0 deletions doc/whats_new.rst
Expand Up @@ -104,6 +104,10 @@ Bug fixes
sparse array X and initial centroids, where X's means were unnecessarily
being subtracted from the centroids. :issue:`7872` by `Josh Karnofsky <https://github.com/jkarno>`_.

- Fix estimators to accept a ``sample_weight`` parameter of type
``pandas.Series`` in their ``fit`` function. :issue:`7825` by
`Kathleen Chen`_.

.. _changes_0_18_1:

Version 0.18.1
Expand Down Expand Up @@ -4824,3 +4828,5 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson.
.. _Srivatsan Ramesh: https://github.com/srivatsan-ramesh

.. _Ron Weiss: http://www.ee.columbia.edu/~ronw

.. _Kathleen Chen: https://github.com/kchen17
1 change: 1 addition & 0 deletions sklearn/ensemble/weight_boosting.py
Expand Up @@ -116,6 +116,7 @@ def fit(self, X, y, sample_weight=None):
sample_weight = np.empty(X.shape[0], dtype=np.float64)
sample_weight[:] = 1. / X.shape[0]
else:
sample_weight = check_array(sample_weight, ensure_2d=False)
# Normalize existing weights
sample_weight = sample_weight / sample_weight.sum(dtype=np.float64)

Expand Down
10 changes: 10 additions & 0 deletions sklearn/exceptions.py
Expand Up @@ -11,6 +11,7 @@
'EfficiencyWarning',
'FitFailedWarning',
'NonBLASDotWarning',
'SkipTestWarning',
'UndefinedMetricWarning']


Expand Down Expand Up @@ -138,6 +139,15 @@ class NonBLASDotWarning(EfficiencyWarning):
"""


class SkipTestWarning(UserWarning):
    """Warning emitted to tell the user that a test was skipped.

    Some estimator checks depend on an optional package; one of them, for
    instance, needs to import pandas. When that import fails, the check is
    skipped instead of being registered as a failure.
    """


class UndefinedMetricWarning(UserWarning):
"""Warning used when the metric is invalid
Expand Down
6 changes: 4 additions & 2 deletions sklearn/kernel_ridge.py
Expand Up @@ -9,7 +9,7 @@
from .base import BaseEstimator, RegressorMixin
from .metrics.pairwise import pairwise_kernels
from .linear_model.ridge import _solve_cholesky_kernel
from .utils import check_X_y
from .utils import check_array, check_X_y
from .utils.validation import check_is_fitted


Expand Down Expand Up @@ -135,7 +135,7 @@ def fit(self, X, y=None, sample_weight=None):
y : array-like, shape = [n_samples] or [n_samples, n_targets]
Target values
sample_weight : float or numpy array of shape [n_samples]
sample_weight : float or array-like of shape [n_samples]
Individual weights for each sample, ignored if None is passed.
Returns
Expand All @@ -145,6 +145,8 @@ def fit(self, X, y=None, sample_weight=None):
# Convert data
X, y = check_X_y(X, y, accept_sparse=("csr", "csc"), multi_output=True,
y_numeric=True)
if sample_weight is not None and not isinstance(sample_weight, float):
sample_weight = check_array(sample_weight, ensure_2d=False)

K = self._get_kernel(X)
alpha = np.atleast_1d(self.alpha)
Expand Down
2 changes: 2 additions & 0 deletions sklearn/linear_model/ridge.py
Expand Up @@ -957,6 +957,8 @@ def fit(self, X, y, sample_weight=None):
"""
X, y = check_X_y(X, y, ['csr', 'csc', 'coo'], dtype=np.float64,
multi_output=True, y_numeric=True)
if sample_weight is not None and not isinstance(sample_weight, float):
sample_weight = check_array(sample_weight, ensure_2d=False)
n_samples, n_features = X.shape

X, y, X_offset, y_offset, X_scale = LinearModel._preprocess_data(
Expand Down
33 changes: 31 additions & 2 deletions sklearn/utils/estimator_checks.py
Expand Up @@ -45,10 +45,12 @@
from sklearn.decomposition import NMF, ProjectedGradientNMF
from sklearn.exceptions import ConvergenceWarning
from sklearn.exceptions import DataConversionWarning
from sklearn.exceptions import SkipTestWarning
from sklearn.model_selection import train_test_split

from sklearn.utils import shuffle
from sklearn.utils.fixes import signature
from sklearn.utils.validation import has_fit_parameter
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import load_iris, load_boston, make_blobs

Expand Down Expand Up @@ -80,6 +82,7 @@ def _yield_non_meta_checks(name, Estimator):
yield check_estimators_dtypes
yield check_fit_score_takes_y
yield check_dtype_object
yield check_sample_weights_pandas_series
yield check_estimators_fit_returns_self

# Check that all estimator yield informative messages when
Expand Down Expand Up @@ -198,7 +201,6 @@ def _yield_transformer_checks(name, Transformer):
yield check_transformer_n_iter



def _yield_clustering_checks(name, Clusterer):
yield check_clusterer_compute_labels_predict
if name not in ('WardAgglomeration', "FeatureAgglomeration"):
Expand Down Expand Up @@ -252,7 +254,12 @@ def check_estimator(Estimator):
name = Estimator.__name__
check_parameters_default_constructible(name, Estimator)
for check in _yield_all_checks(name, Estimator):
check(name, Estimator)
try:
check(name, Estimator)
except SkipTest as message:
# the only SkipTest thrown currently results from not
# being able to import pandas.
warnings.warn(message, SkipTestWarning)


def _boston_subset(n_samples=200):
Expand Down Expand Up @@ -381,6 +388,28 @@ def check_estimator_sparse_data(name, Estimator):
raise


@ignore_warnings(category=DeprecationWarning)
def check_sample_weights_pandas_series(name, Estimator):
    """Check that ``fit`` accepts ``sample_weight`` given as a pandas.Series.

    Estimators whose ``fit`` has no ``sample_weight`` parameter are not
    tested and the check returns immediately.

    Parameters
    ----------
    name : str
        Name of the estimator, used only in the error message.

    Estimator : class
        Estimator class; it is instantiated with no arguments.

    Raises
    ------
    SkipTest
        If pandas cannot be imported.

    ValueError
        If ``fit`` raises a ValueError when given a pandas.Series as
        ``sample_weight``.
    """
    estimator = Estimator()
    if not has_fit_parameter(estimator, "sample_weight"):
        return

    # Keep the try body to the import alone, so that an ImportError raised
    # from inside ``fit`` is not mistaken for pandas being unavailable.
    try:
        import pandas as pd
    except ImportError:
        raise SkipTest("pandas is not installed: not testing for "
                       "input of type pandas.Series to sample_weight.")

    X = pd.DataFrame([[1, 1], [1, 2], [1, 3], [2, 1], [2, 2], [2, 3]])
    y = pd.Series([1, 1, 1, 2, 2, 2])
    weights = pd.Series([1] * 6)
    try:
        estimator.fit(X, y, sample_weight=weights)
    except ValueError:
        # Re-raise with a message naming the estimator; the estimator
        # checks' own tests match on this exact text.
        raise ValueError("Estimator {0} raises error if "
                         "'sample_weight' parameter is of "
                         "type pandas.Series".format(name))


@ignore_warnings(category=(DeprecationWarning, UserWarning))
def check_dtype_object(name, Estimator):
# check that estimators treat dtype object as numeric if possible
Expand Down
28 changes: 28 additions & 0 deletions sklearn/utils/tests/test_estimator_checks.py
Expand Up @@ -73,6 +73,25 @@ def predict(self, X):
return np.ones(X.shape[0])


class NoSampleWeightPandasSeriesType(BaseEstimator):
    """Deliberately "bad" estimator that rejects a pandas.Series sample_weight.

    ``test_check_estimator`` passes this class to ``check_estimator`` and
    expects a ValueError to be raised for it.
    """

    def fit(self, X, y, sample_weight=None):
        """Validate X and y, then refuse ``sample_weight`` if it is a Series.

        Raises
        ------
        ValueError
            If ``sample_weight`` is a pandas.Series.
        """
        # Convert data
        X, y = check_X_y(X, y,
                         accept_sparse=("csr", "csc"),
                         multi_output=True,
                         y_numeric=True)
        # Function is only called after we verify that pandas is installed
        from pandas import Series
        if isinstance(sample_weight, Series):
            # Trailing space keeps the two adjacent literals from fusing
            # into "'sample_weight'of type".
            raise ValueError("Estimator does not accept 'sample_weight' "
                             "of type pandas.Series")
        return self

    def predict(self, X):
        """Validate X and return an array of ones, one per row of X."""
        X = check_array(X)
        return np.ones(X.shape[0])


def test_check_estimator():
# tests that the estimator actually fails on "bad" estimators.
# not a complete test of all checks, which are very extensive.
Expand All @@ -86,6 +105,15 @@ def test_check_estimator():
# check that fit does input validation
msg = "TypeError not raised"
assert_raises_regex(AssertionError, msg, check_estimator, BaseBadClassifier)
# check that sample_weights in fit accepts pandas.Series type
try:
from pandas import Series # noqa
msg = ("Estimator NoSampleWeightPandasSeriesType raises error if "
"'sample_weight' parameter is of type pandas.Series")
assert_raises_regex(
ValueError, msg, check_estimator, NoSampleWeightPandasSeriesType)
except ImportError:
pass
# check that predict does input validation (doesn't accept dicts in input)
msg = "Estimator doesn't check for NaN and inf in predict"
assert_raises_regex(AssertionError, msg, check_estimator, NoCheckinPredict)
Expand Down

0 comments on commit 383bf8e

Please sign in to comment.