Commit

ENH Add random_state parameter to AffinityPropagation (scikit-learn#16801)

* Added value checks and random state parameter to method

* Changed default parameter to None instead of 0

* Added numpy RandomState to the check

* Replaced inline validation with check_random_state from utils and pointed at the glossary

* Used a different default parameter to preserve the default behavior of previous versions

* Updated to conform with flake8 standards

* Add random_state to AffinityPropagation class.

* Add test.

* Add what's new entry and versionadded directive.

* Add PR number.

* Fix lint error due to this PR.

* Use np.array_equal in test.

* Update sklearn/cluster/_affinity_propagation.py

Co-Authored-By: Adrin Jalali <adrin.jalali@gmail.com>

* Homogenize parameter descriptions, default random_state to None.

* Update sklearn/cluster/_affinity_propagation.py

Co-Authored-By: Nicolas Hug <contact@nicolas-hug.com>

* Update sklearn/cluster/_affinity_propagation.py

Co-Authored-By: Nicolas Hug <contact@nicolas-hug.com>

* Update sklearn/cluster/_affinity_propagation.py

Co-Authored-By: Nicolas Hug <contact@nicolas-hug.com>

* Update doc/whats_new/v0.23.rst

Co-Authored-By: Nicolas Hug <contact@nicolas-hug.com>

* Change test name.

* Modify check in test.

* Fix lint errors.

* Address comment.

* Address comment.

* Add 'deprecation' and its corresponding test.

* Fix lint errors.

* Add random_state parameter to tests, to avoid FutureWarnings.

* Move warning in fit. Modify tests.

* Modify example.

* Tentative fix for failures.

* Document default value to 0. Revert docstring.

* Explicit link to Glossary.

* Fix default value.

* Remove some warnings from tests.

* Validate and test docstring.

* Tentative fix.

* Tentative fix.

* Ignore FutureWarning in fit attribute test.

* Set random_state to avoid FutureWarning in test_fit_docstring_attributes.

* [doc build] Force documentation build.

* Clarify warning message.

Co-authored-by: rwoolston.admin <rwoolston.admin@LXO-DS-DEV.afcucorp.local>
Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
4 people committed May 5, 2020
1 parent fa8f498 commit 22a7d5b
Showing 4 changed files with 105 additions and 35 deletions.
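
For context, the change boils down to one new keyword argument on both affinity_propagation and the AffinityPropagation class. A minimal usage sketch of the resulting API, mirroring the docstring example added in the diff below (the expected labels and centers are quoted from that docstring, not independently re-verified here):

import numpy as np
from sklearn.cluster import AffinityPropagation

X = np.array([[1, 2], [1, 4], [1, 0],
              [4, 2], [4, 4], [4, 0]])

# An int seed makes the message-passing initialisation reproducible across calls;
# leaving random_state unset raises a FutureWarning during the 0.23 deprecation window.
clustering = AffinityPropagation(random_state=5).fit(X)
print(clustering.labels_)           # array([0, 0, 0, 1, 1, 1]) per the docstring
print(clustering.cluster_centers_)  # [[1, 2], [4, 2]] per the docstring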
4 changes: 4 additions & 0 deletions doc/whats_new/v0.23.rst
@@ -143,6 +143,10 @@ Changelog
deprecated. It has no effect. :pr:`11950` by
:user:`Jeremie du Boisberranger <jeremiedbb>`.

- |API| The ``random_state`` parameter has been added to
:class:`cluster.AffinityPropagation`. :pr:`16801` by :user:`rcwoolston`
and :user:`Chiara Marmo <cmarmo>`.

:mod:`sklearn.compose`
......................

70 changes: 48 additions & 22 deletions sklearn/cluster/_affinity_propagation.py
@@ -10,7 +10,7 @@

from ..exceptions import ConvergenceWarning
from ..base import BaseEstimator, ClusterMixin
from ..utils import as_float_array, check_array
from ..utils import as_float_array, check_array, check_random_state
from ..utils.validation import check_is_fitted, _deprecate_positional_args
from ..metrics import euclidean_distances
from ..metrics import pairwise_distances_argmin
@@ -32,7 +32,7 @@ def all_equal_similarities():

def affinity_propagation(S, preference=None, convergence_iter=15, max_iter=200,
damping=0.5, copy=True, verbose=False,
return_n_iter=False):
return_n_iter=False, random_state='warn'):
"""Perform Affinity Propagation Clustering of data
Read more in the :ref:`User Guide <affinity_propagation>`.
@@ -72,6 +72,14 @@ def affinity_propagation(S, preference=None, convergence_iter=15, max_iter=200,
return_n_iter : bool, default False
Whether or not to return the number of iterations.
random_state : int or np.random.RandomStateInstance, default: 0
Pseudo-random number generator to control the starting state.
Use an int for reproducible results across function calls.
See the :term:`Glossary <random_state>`.
.. versionadded:: 0.23
this parameter was previously hardcoded as 0.
Returns
-------
@@ -133,7 +141,16 @@ def affinity_propagation(S, preference=None, convergence_iter=15, max_iter=200,
if return_n_iter
else (np.array([0]), np.array([0] * n_samples)))

random_state = np.random.RandomState(0)
if random_state == 'warn':
warnings.warn(("'random_state' has been introduced in 0.23. "
"It will be set to None starting from 0.25 which "
"means that results will differ at every function "
"call. Set 'random_state' to None to silence this "
"warning, or to 0 to keep the behavior of versions "
"<0.23."),
FutureWarning)
random_state = 0
random_state = check_random_state(random_state)

# Place preference on the diagonal of S
S.flat[::(n_samples + 1)] = preference
@@ -274,6 +291,13 @@ class AffinityPropagation(ClusterMixin, BaseEstimator):
verbose : bool, default=False
Whether to be verbose.
random_state : int or np.random.RandomStateInstance, default: 0
Pseudo-random number generator to control the starting state.
Use an int for reproducible results across function calls.
See the :term:`Glossary <random_state>`.
.. versionadded:: 0.23
this parameter was previously hardcoded as 0.
Attributes
----------
@@ -292,23 +316,6 @@ class AffinityPropagation(ClusterMixin, BaseEstimator):
n_iter_ : int
Number of iterations taken to converge.
Examples
--------
>>> from sklearn.cluster import AffinityPropagation
>>> import numpy as np
>>> X = np.array([[1, 2], [1, 4], [1, 0],
... [4, 2], [4, 4], [4, 0]])
>>> clustering = AffinityPropagation().fit(X)
>>> clustering
AffinityPropagation()
>>> clustering.labels_
array([0, 0, 0, 1, 1, 1])
>>> clustering.predict([[0, 0], [4, 4]])
array([0, 1])
>>> clustering.cluster_centers_
array([[1, 2],
[4, 2]])
Notes
-----
For an example, see :ref:`examples/cluster/plot_affinity_propagation.py
@@ -333,11 +340,28 @@ class AffinityPropagation(ClusterMixin, BaseEstimator):
Brendan J. Frey and Delbert Dueck, "Clustering by Passing Messages
Between Data Points", Science Feb. 2007
Examples
--------
>>> from sklearn.cluster import AffinityPropagation
>>> import numpy as np
>>> X = np.array([[1, 2], [1, 4], [1, 0],
... [4, 2], [4, 4], [4, 0]])
>>> clustering = AffinityPropagation(random_state=5).fit(X)
>>> clustering
AffinityPropagation(random_state=5)
>>> clustering.labels_
array([0, 0, 0, 1, 1, 1])
>>> clustering.predict([[0, 0], [4, 4]])
array([0, 1])
>>> clustering.cluster_centers_
array([[1, 2],
[4, 2]])
"""
@_deprecate_positional_args
def __init__(self, *, damping=.5, max_iter=200, convergence_iter=15,
copy=True, preference=None, affinity='euclidean',
verbose=False):
verbose=False, random_state='warn'):

self.damping = damping
self.max_iter = max_iter
@@ -346,6 +370,7 @@ def __init__(self, *, damping=.5, max_iter=200, convergence_iter=15,
self.verbose = verbose
self.preference = preference
self.affinity = affinity
self.random_state = random_state

@property
def _pairwise(self):
@@ -388,7 +413,8 @@ def fit(self, X, y=None):
affinity_propagation(
self.affinity_matrix_, self.preference, max_iter=self.max_iter,
convergence_iter=self.convergence_iter, damping=self.damping,
copy=self.copy, verbose=self.verbose, return_n_iter=True)
copy=self.copy, verbose=self.verbose, return_n_iter=True,
random_state=self.random_state)

if self.affinity != "precomputed":
self.cluster_centers_ = X[self.cluster_centers_indices_].copy()
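The core behavioural change in this file is replacing the hard-coded np.random.RandomState(0) with sklearn.utils.check_random_state, so callers can choose between a fixed seed, the global generator, or their own RandomState. A short illustrative sketch of how the three accepted values resolve (check_random_state is the real utility imported in the diff; everything else here is example code only):

import numpy as np
from sklearn.utils import check_random_state

rng_seeded = check_random_state(0)     # int -> fresh RandomState seeded with 0 (the pre-0.23 behaviour)
rng_global = check_random_state(None)  # None -> the global numpy RandomState singleton
rng_custom = np.random.RandomState(42)
assert check_random_state(rng_custom) is rng_custom  # an existing RandomState passes through unchanged

# Whichever value was passed, the solver ends up with a RandomState it can draw from,
# e.g. for the tiny noise added to S to avoid degenerate cases.
noise = rng_seeded.randn(5)
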
62 changes: 49 additions & 13 deletions sklearn/cluster/tests/test_affinity_propagation.py
@@ -33,16 +33,18 @@ def test_affinity_propagation():
preference = np.median(S) * 10
# Compute Affinity Propagation
cluster_centers_indices, labels = affinity_propagation(
S, preference=preference)
S, preference=preference, random_state=39)

n_clusters_ = len(cluster_centers_indices)

assert n_clusters == n_clusters_

af = AffinityPropagation(preference=preference, affinity="precomputed")
af = AffinityPropagation(preference=preference, affinity="precomputed",
random_state=28)
labels_precomputed = af.fit(S).labels_

af = AffinityPropagation(preference=preference, verbose=True)
af = AffinityPropagation(preference=preference, verbose=True,
random_state=37)
labels = af.fit(X).labels_

assert_array_equal(labels, labels_precomputed)
@@ -55,24 +57,24 @@ def test_affinity_propagation():

# Test also with no copy
_, labels_no_copy = affinity_propagation(S, preference=preference,
copy=False)
copy=False, random_state=74)
assert_array_equal(labels, labels_no_copy)

# Test input validation
with pytest.raises(ValueError):
affinity_propagation(S[:, :-1])
with pytest.raises(ValueError):
affinity_propagation(S, damping=0)
af = AffinityPropagation(affinity="unknown")
af = AffinityPropagation(affinity="unknown", random_state=78)
with pytest.raises(ValueError):
af.fit(X)
af_2 = AffinityPropagation(affinity='precomputed')
af_2 = AffinityPropagation(affinity='precomputed', random_state=21)
with pytest.raises(TypeError):
af_2.fit(csr_matrix((3, 3)))

def test_affinity_propagation_predict():
# Test AffinityPropagation.predict
af = AffinityPropagation(affinity="euclidean")
af = AffinityPropagation(affinity="euclidean", random_state=63)
labels = af.fit_predict(X)
labels2 = af.predict(X)
assert_array_equal(labels, labels2)
@@ -87,7 +89,7 @@ def test_affinity_propagation_predict_error():

# Predict not supported when affinity="precomputed".
S = np.dot(X, X.T)
af = AffinityPropagation(affinity="precomputed")
af = AffinityPropagation(affinity="precomputed", random_state=57)
af.fit(S)
with pytest.raises(ValueError):
af.predict(X)
@@ -100,7 +102,7 @@ def test_affinity_propagation_fit_non_convergence():
X = np.array([[0, 0], [1, 1], [-2, -2]])

# Force non-convergence by allowing only a single iteration
af = AffinityPropagation(preference=-10, max_iter=1)
af = AffinityPropagation(preference=-10, max_iter=1, random_state=82)

assert_warns(ConvergenceWarning, af.fit, X)
assert_array_equal(np.empty((0, 2)), af.cluster_centers_)
@@ -129,7 +131,7 @@ def test_affinity_propagation_equal_mutual_similarities():

# setting different preferences
cluster_center_indices, labels = assert_no_warnings(
affinity_propagation, S, preference=[-20, -10])
affinity_propagation, S, preference=[-20, -10], random_state=37)

# expect one cluster, with highest-preference sample as exemplar
assert_array_equal([1], cluster_center_indices)
Expand All @@ -143,7 +145,8 @@ def test_affinity_propagation_predict_non_convergence():

# Force non-convergence by allowing only a single iteration
af = assert_warns(ConvergenceWarning,
AffinityPropagation(preference=-10, max_iter=1).fit, X)
AffinityPropagation(preference=-10,
max_iter=1, random_state=75).fit, X)

# At prediction time, consider new samples as noise since there are no
# clusters
@@ -156,7 +159,8 @@ def test_affinity_propagation_non_convergence_regressiontest():
X = np.array([[1, 0, 0, 0, 0, 0],
[0, 1, 1, 1, 0, 0],
[0, 0, 1, 0, 0, 1]])
af = AffinityPropagation(affinity='euclidean', max_iter=2).fit(X)
af = AffinityPropagation(affinity='euclidean',
max_iter=2, random_state=34).fit(X)
assert_array_equal(np.array([-1, -1, -1]), af.labels_)


@@ -181,14 +185,46 @@ def test_equal_similarities_and_preferences():
assert _equal_similarities_and_preferences(S, np.array(0))


def test_affinity_propagation_random_state():
# Significance of random_state parameter
# Generate sample data
centers = [[1, 1], [-1, -1], [1, -1]]
X, labels_true = make_blobs(n_samples=300, centers=centers,
cluster_std=0.5, random_state=0)
# random_state = 0
ap = AffinityPropagation(convergence_iter=1, max_iter=2, random_state=0)
ap.fit(X)
centers0 = ap.cluster_centers_

# random_state = 76
ap = AffinityPropagation(convergence_iter=1, max_iter=2, random_state=76)
ap.fit(X)
centers76 = ap.cluster_centers_

assert np.mean((centers0 - centers76) ** 2) > 1


# FIXME: to be removed in 0.25
def test_affinity_propagation_random_state_warning():
# test that a warning is raised when random_state is not defined.
X = np.array([[0, 0], [1, 1], [-2, -2]])
match = ("'random_state' has been introduced in 0.23. "
"It will be set to None starting from 0.25 which "
"means that results will differ at every function "
"call. Set 'random_state' to None to silence this "
"warning, or to 0 to keep the behavior of versions "
"<0.23.")
with pytest.warns(FutureWarning, match=match):
AffinityPropagation().fit(X)

@pytest.mark.parametrize('centers', [csr_matrix(np.zeros((1, 10))),
np.zeros((1, 10))])
def test_affinity_propagation_convergence_warning_dense_sparse(centers):
"""Non-regression, see #13334"""
rng = np.random.RandomState(42)
X = rng.rand(40, 10)
y = (4 * rng.rand(40)).astype(np.int)
ap = AffinityPropagation()
ap = AffinityPropagation(random_state=46)
ap.fit(X, y)
ap.cluster_centers_ = centers
with pytest.warns(None) as record:
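The new test_affinity_propagation_random_state above checks that different seeds can lead to different exemplars; the flip side, that a fixed integer seed is reproducible across fits, is what the parameter exists for. A quick illustrative check, not part of the commit, reusing the toy data from the docstring example:

import numpy as np
from sklearn.cluster import AffinityPropagation

X = np.array([[1, 2], [1, 4], [1, 0],
              [4, 2], [4, 4], [4, 0]])

# Two fits with the same integer seed give identical labelings.
labels_a = AffinityPropagation(random_state=5).fit(X).labels_
labels_b = AffinityPropagation(random_state=5).fit(X).labels_
assert np.array_equal(labels_a, labels_b)
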
4 changes: 4 additions & 0 deletions sklearn/tests/test_docstring_parameters.py
@@ -198,6 +198,10 @@ def test_fit_docstring_attributes(name, Estimator):
if Estimator.__name__ == 'DummyClassifier':
est.strategy = "stratified"

# TO BE REMOVED for v0.25 (avoid FutureWarning)
if Estimator.__name__ == 'AffinityPropagation':
est.random_state = 63

X, y = make_classification(n_samples=20, n_features=3,
n_redundant=0, n_classes=2,
random_state=2)
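As the warning message and the workaround in test_fit_docstring_attributes suggest, downstream code can silence the transition-period FutureWarning simply by passing random_state explicitly: 0 reproduces the pre-0.23 hard-coded seed, while None opts into non-deterministic behaviour ahead of 0.25. A one-line sketch:

from sklearn.cluster import AffinityPropagation

ap = AffinityPropagation(random_state=0)  # explicit seed: legacy behaviour, no FutureWarning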
