diff --git a/doc/modules/label_propagation.rst b/doc/modules/label_propagation.rst index eddc34b7a8c7c..1aba742723f01 100644 --- a/doc/modules/label_propagation.rst +++ b/doc/modules/label_propagation.rst @@ -52,8 +52,8 @@ differ in modifications to the similarity matrix that graph and the clamping effect on the label distributions. Clamping allows the algorithm to change the weight of the true ground labeled data to some degree. The :class:`LabelPropagation` algorithm performs hard -clamping of input labels, which means :math:`\alpha=1`. This clamping factor -can be relaxed, to say :math:`\alpha=0.8`, which means that we will always +clamping of input labels, which means :math:`\alpha=0`. This clamping factor +can be relaxed, to say :math:`\alpha=0.2`, which means that we will always retain 80 percent of our original label distribution, but the algorithm gets to change its confidence of the distribution within 20 percent. diff --git a/doc/whats_new.rst b/doc/whats_new.rst index a9601419c9edd..73fa6dcee8b06 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -448,7 +448,16 @@ Bug fixes in :class:`decomposition.PCA`, :class:`decomposition.RandomizedPCA` and :class:`decomposition.IncrementalPCA`. - :issue:`9105` by `Hanmin Qin `_. + :issue:`9105` by `Hanmin Qin `_. + + - Fix :class:`semi_supervised.BaseLabelPropagation` to correctly implement + ``LabelPropagation`` and ``LabelSpreading`` as done in the referenced + papers. :class:`semi_supervised.LabelPropagation` now always does hard + clamping. Its ``alpha`` parameter has no effect and is + deprecated to be removed in 0.21. :issue:`6727` :issue:`3550` issue:`5770` + by :user:`Andre Ambrosio Boechat `, :user:`Utkarsh Upadhyay + `, and `Joel Nothman`_. + API changes summary ------------------- diff --git a/examples/semi_supervised/plot_label_propagation_structure.py b/examples/semi_supervised/plot_label_propagation_structure.py index 7cc15d73f1b89..95f19ec108e82 100644 --- a/examples/semi_supervised/plot_label_propagation_structure.py +++ b/examples/semi_supervised/plot_label_propagation_structure.py @@ -30,7 +30,7 @@ # ############################################################################# # Learn with LabelSpreading -label_spread = label_propagation.LabelSpreading(kernel='knn', alpha=1.0) +label_spread = label_propagation.LabelSpreading(kernel='knn', alpha=0.2) label_spread.fit(X, labels) # ############################################################################# diff --git a/sklearn/semi_supervised/label_propagation.py b/sklearn/semi_supervised/label_propagation.py index 1759b2c1d7572..ab0dd64bf81ea 100644 --- a/sklearn/semi_supervised/label_propagation.py +++ b/sklearn/semi_supervised/label_propagation.py @@ -14,11 +14,12 @@ Model Features -------------- Label clamping: - The algorithm tries to learn distributions of labels over the dataset. In the - "Hard Clamp" mode, the true ground labels are never allowed to change. They - are clamped into position. In the "Soft Clamp" mode, they are allowed some - wiggle room, but some alpha of their original value will always be retained. - Hard clamp is the same as soft clamping with alpha set to 1. + The algorithm tries to learn distributions of labels over the dataset given + label assignments over an initial subset. In one variant, the algorithm does + not allow for any errors in the initial assignment (hard-clamping) while + in another variant, the algorithm allows for some wiggle room for the initial + assignments, allowing them to change by a fraction alpha in each iteration + (soft-clamping). Kernel: A function which projects a vector into some higher dimensional space. This @@ -55,6 +56,7 @@ # License: BSD from abc import ABCMeta, abstractmethod +import warnings import numpy as np from scipy import sparse @@ -239,10 +241,13 @@ def fit(self, X, y): n_samples, n_classes = len(y), len(classes) + alpha = self.alpha + if self._variant == 'spreading' and \ + (alpha is None or alpha <= 0.0 or alpha >= 1.0): + raise ValueError('alpha=%s is invalid: it must be inside ' + 'the open interval (0, 1)' % alpha) y = np.asarray(y) unlabeled = y == -1 - clamp_weights = np.ones((n_samples, 1)) - clamp_weights[unlabeled, 0] = self.alpha # initialize distributions self.label_distributions_ = np.zeros((n_samples, n_classes)) @@ -250,13 +255,17 @@ def fit(self, X, y): self.label_distributions_[y == label, classes == label] = 1 y_static = np.copy(self.label_distributions_) - if self.alpha > 0.: - y_static *= 1 - self.alpha - y_static[unlabeled] = 0 + if self._variant == 'propagation': + # LabelPropagation + y_static[unlabeled] = 0 + else: + # LabelSpreading + y_static *= 1 - alpha l_previous = np.zeros((self.X_.shape[0], n_classes)) remaining_iter = self.max_iter + unlabeled = unlabeled[:, np.newaxis] if sparse.isspmatrix(graph_matrix): graph_matrix = graph_matrix.tocsr() while (_not_converged(self.label_distributions_, l_previous, self.tol) @@ -264,13 +273,23 @@ def fit(self, X, y): l_previous = self.label_distributions_ self.label_distributions_ = safe_sparse_dot( graph_matrix, self.label_distributions_) - # clamp - self.label_distributions_ = np.multiply( - clamp_weights, self.label_distributions_) + y_static + + if self._variant == 'propagation': + normalizer = np.sum( + self.label_distributions_, axis=1)[:, np.newaxis] + self.label_distributions_ /= normalizer + self.label_distributions_ = np.where(unlabeled, + self.label_distributions_, + y_static) + else: + # clamp + self.label_distributions_ = np.multiply( + alpha, self.label_distributions_) + y_static remaining_iter -= 1 normalizer = np.sum(self.label_distributions_, axis=1)[:, np.newaxis] self.label_distributions_ /= normalizer + # set the transduction item transduction = self.classes_[np.argmax(self.label_distributions_, axis=1)] @@ -299,7 +318,11 @@ class LabelPropagation(BaseLabelPropagation): Parameter for knn kernel alpha : float - Clamping factor + Clamping factor. + + .. deprecated:: 0.19 + This parameter will be removed in 0.21. + 'alpha' is fixed to zero in 'LabelPropagation'. max_iter : float Change maximum number of iterations allowed @@ -350,6 +373,14 @@ class LabelPropagation(BaseLabelPropagation): LabelSpreading : Alternate label propagation strategy more robust to noise """ + _variant = 'propagation' + + def __init__(self, kernel='rbf', gamma=20, n_neighbors=7, + alpha=None, max_iter=30, tol=1e-3, n_jobs=1): + super(LabelPropagation, self).__init__( + kernel=kernel, gamma=gamma, n_neighbors=n_neighbors, alpha=alpha, + max_iter=max_iter, tol=tol, n_jobs=n_jobs) + def _build_graph(self): """Matrix representing a fully connected graph between each sample @@ -366,6 +397,15 @@ class distributions will exceed 1 (normalization may be desired). affinity_matrix /= normalizer[:, np.newaxis] return affinity_matrix + def fit(self, X, y): + if self.alpha is not None: + warnings.warn( + "alpha is deprecated since 0.19 and will be removed in 0.21.", + DeprecationWarning + ) + self.alpha = None + return super(LabelPropagation, self).fit(X, y) + class LabelSpreading(BaseLabelPropagation): """LabelSpreading model for semi-supervised learning @@ -391,7 +431,11 @@ class LabelSpreading(BaseLabelPropagation): parameter for knn kernel alpha : float - clamping factor + Clamping factor. A value in [0, 1] that specifies the relative amount + that an instance should adopt the information from its neighbors as + opposed to its initial label. + alpha=0 means keeping the initial label information; alpha=1 means + replacing all initial information. max_iter : float maximum number of iterations allowed @@ -446,6 +490,8 @@ class LabelSpreading(BaseLabelPropagation): LabelPropagation : Unregularized graph based semi-supervised learning """ + _variant = 'spreading' + def __init__(self, kernel='rbf', gamma=20, n_neighbors=7, alpha=0.2, max_iter=30, tol=1e-3, n_jobs=1): diff --git a/sklearn/semi_supervised/tests/test_label_propagation.py b/sklearn/semi_supervised/tests/test_label_propagation.py index 81e7dd028bf5d..3d5bd21a89110 100644 --- a/sklearn/semi_supervised/tests/test_label_propagation.py +++ b/sklearn/semi_supervised/tests/test_label_propagation.py @@ -3,8 +3,12 @@ import numpy as np from sklearn.utils.testing import assert_equal +from sklearn.utils.testing import assert_warns +from sklearn.utils.testing import assert_raises +from sklearn.utils.testing import assert_no_warnings from sklearn.semi_supervised import label_propagation from sklearn.metrics.pairwise import rbf_kernel +from sklearn.datasets import make_classification from numpy.testing import assert_array_almost_equal from numpy.testing import assert_array_equal @@ -59,3 +63,85 @@ def test_predict_proba(): clf = estimator(**parameters).fit(samples, labels) assert_array_almost_equal(clf.predict_proba([[1., 1.]]), np.array([[0.5, 0.5]])) + + +def test_alpha_deprecation(): + X, y = make_classification(n_samples=100) + y[::3] = -1 + + lp_default = label_propagation.LabelPropagation(kernel='rbf', gamma=0.1) + lp_default_y = assert_no_warnings(lp_default.fit, X, y).transduction_ + + lp_0 = label_propagation.LabelPropagation(alpha=0, kernel='rbf', gamma=0.1) + lp_0_y = assert_warns(DeprecationWarning, lp_0.fit, X, y).transduction_ + + assert_array_equal(lp_default_y, lp_0_y) + + +def test_label_spreading_closed_form(): + n_classes = 2 + X, y = make_classification(n_classes=n_classes, n_samples=200, + random_state=0) + y[::3] = -1 + clf = label_propagation.LabelSpreading().fit(X, y) + # adopting notation from Zhou et al (2004): + S = clf._build_graph() + Y = np.zeros((len(y), n_classes + 1)) + Y[np.arange(len(y)), y] = 1 + Y = Y[:, :-1] + for alpha in [0.1, 0.3, 0.5, 0.7, 0.9]: + expected = np.dot(np.linalg.inv(np.eye(len(S)) - alpha * S), Y) + expected /= expected.sum(axis=1)[:, np.newaxis] + clf = label_propagation.LabelSpreading(max_iter=10000, alpha=alpha) + clf.fit(X, y) + assert_array_almost_equal(expected, clf.label_distributions_, 4) + + +def test_label_propagation_closed_form(): + n_classes = 2 + X, y = make_classification(n_classes=n_classes, n_samples=200, + random_state=0) + y[::3] = -1 + Y = np.zeros((len(y), n_classes + 1)) + Y[np.arange(len(y)), y] = 1 + unlabelled_idx = Y[:, (-1,)].nonzero()[0] + labelled_idx = (Y[:, (-1,)] == 0).nonzero()[0] + + clf = label_propagation.LabelPropagation(max_iter=10000, + gamma=0.1).fit(X, y) + # adopting notation from Zhu et al 2002 + T_bar = clf._build_graph() + Tuu = T_bar[np.meshgrid(unlabelled_idx, unlabelled_idx, indexing='ij')] + Tul = T_bar[np.meshgrid(unlabelled_idx, labelled_idx, indexing='ij')] + Y = Y[:, :-1] + Y_l = Y[labelled_idx, :] + Y_u = np.dot(np.dot(np.linalg.inv(np.eye(Tuu.shape[0]) - Tuu), Tul), Y_l) + + expected = Y.copy() + expected[unlabelled_idx, :] = Y_u + expected /= expected.sum(axis=1)[:, np.newaxis] + + assert_array_almost_equal(expected, clf.label_distributions_, 4) + + +def test_valid_alpha(): + n_classes = 2 + X, y = make_classification(n_classes=n_classes, n_samples=200, + random_state=0) + for alpha in [-0.1, 0, 1, 1.1, None]: + assert_raises(ValueError, + lambda **kwargs: + label_propagation.LabelSpreading(**kwargs).fit(X, y), + alpha=alpha) + + +def test_convergence_speed(): + # This is a non-regression test for #5774 + X = np.array([[1., 0.], [0., 1.], [1., 2.5]]) + y = np.array([0, 1, -1]) + mdl = label_propagation.LabelSpreading(kernel='rbf', max_iter=5000) + mdl.fit(X, y) + + # this should converge quickly: + assert mdl.n_iter_ < 10 + assert_array_equal(mdl.predict(X), [0, 1, 1])