Skip to content

Commit

Permalink
[MRG+1] Fix semi_supervised (scikit-learn#9239)
Browse files Browse the repository at this point in the history
* Files for my dev environment with Docker

* Fixing label clamping (alpha=0 for hard clamping)

* Deprecating alpha, fixing its value to zero

* Correct way to deprecate alpha for LabelPropagation

The previous way was breaking the test
sklearn.tests.test_common.test_all_estimators

* Detailed info for LabelSpreading's alpha parameter

Based on the original paper.

* Minor changes in the deprecation message

* Improving "deprecated" doc string and raising DeprecationWarning

* Using a local "alpha" in "fit" to deprecate LabelPropagation's alpha

This solution isn't great, but it sets the correct value for alpha
without violating the restrictions imposed by the tests.

* Removal of my development files

* Using sphinx's "deprecated" tag (jnothman's suggestion)

* Deprecation warning: stating that the alpha's value will be ignored

* Use __init__ with alpha=None

* Update what's new

* Try fix RuntimeWarning in test_alpha_deprecation

* DOC Indent deprecation details

* DOC wording

* Update docs

* Change to the one true implementation.

* Add sanity-checked impl. of Label{Propagation,Spreading}

* Raise ValueError if alpha is invalid in LabelSpreading.

* Add a normalizing step before clamping to LabelPropagation.

* Fix flake8 errors.

* Remove duplicate imports.

* DOC Update What's New.

* Specify alpha's value in the error.

* Tidy up tests.

Add a test and add references, where needed.

* Add comment to non-regression test.

* Fix documentation.

* Move check for alpha into fit from __init__.

* Fix corner case of LabelSpreading with alpha=None.

* alpha -> self.variant

* Make Whats_new more explicit.

* Simplify impl. of Label{Propagation,Spreading}.

* variant -> _variant.
  • Loading branch information
musically-ut authored and NelleV committed Aug 11, 2017
1 parent ac1be3d commit a9b9e3f
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 19 deletions.
4 changes: 2 additions & 2 deletions doc/modules/label_propagation.rst
Expand Up @@ -52,8 +52,8 @@ differ in modifications to the similarity matrix that graph and the
clamping effect on the label distributions.
Clamping allows the algorithm to change the weight of the true ground labeled
data to some degree. The :class:`LabelPropagation` algorithm performs hard
clamping of input labels, which means :math:`\alpha=1`. This clamping factor
can be relaxed, to say :math:`\alpha=0.8`, which means that we will always
clamping of input labels, which means :math:`\alpha=0`. This clamping factor
can be relaxed, to say :math:`\alpha=0.2`, which means that we will always
retain 80 percent of our original label distribution, but the algorithm gets to
change its confidence of the distribution within 20 percent.

Expand Down
11 changes: 10 additions & 1 deletion doc/whats_new.rst
Expand Up @@ -448,7 +448,16 @@ Bug fixes
in :class:`decomposition.PCA`,
:class:`decomposition.RandomizedPCA` and
:class:`decomposition.IncrementalPCA`.
:issue:`9105` by `Hanmin Qin <https://github.com/qinhanmin2014>`_.
:issue:`9105` by `Hanmin Qin <https://github.com/qinhanmin2014>`_.

- Fix :class:`semi_supervised.BaseLabelPropagation` to correctly implement
``LabelPropagation`` and ``LabelSpreading`` as done in the referenced
papers. :class:`semi_supervised.LabelPropagation` now always does hard
clamping. Its ``alpha`` parameter has no effect and is
deprecated to be removed in 0.21. :issue:`6727` :issue:`3550` issue:`5770`
by :user:`Andre Ambrosio Boechat <boechat107>`, :user:`Utkarsh Upadhyay
<musically-ut>`, and `Joel Nothman`_.


API changes summary
-------------------
Expand Down
Expand Up @@ -30,7 +30,7 @@

# #############################################################################
# Learn with LabelSpreading
label_spread = label_propagation.LabelSpreading(kernel='knn', alpha=1.0)
label_spread = label_propagation.LabelSpreading(kernel='knn', alpha=0.2)
label_spread.fit(X, labels)

# #############################################################################
Expand Down
76 changes: 61 additions & 15 deletions sklearn/semi_supervised/label_propagation.py
Expand Up @@ -14,11 +14,12 @@
Model Features
--------------
Label clamping:
The algorithm tries to learn distributions of labels over the dataset. In the
"Hard Clamp" mode, the true ground labels are never allowed to change. They
are clamped into position. In the "Soft Clamp" mode, they are allowed some
wiggle room, but some alpha of their original value will always be retained.
Hard clamp is the same as soft clamping with alpha set to 1.
The algorithm tries to learn distributions of labels over the dataset given
label assignments over an initial subset. In one variant, the algorithm does
not allow for any errors in the initial assignment (hard-clamping) while
in another variant, the algorithm allows for some wiggle room for the initial
assignments, allowing them to change by a fraction alpha in each iteration
(soft-clamping).
Kernel:
A function which projects a vector into some higher dimensional space. This
Expand Down Expand Up @@ -55,6 +56,7 @@
# License: BSD
from abc import ABCMeta, abstractmethod

import warnings
import numpy as np
from scipy import sparse

Expand Down Expand Up @@ -239,38 +241,55 @@ def fit(self, X, y):

n_samples, n_classes = len(y), len(classes)

alpha = self.alpha
if self._variant == 'spreading' and \
(alpha is None or alpha <= 0.0 or alpha >= 1.0):
raise ValueError('alpha=%s is invalid: it must be inside '
'the open interval (0, 1)' % alpha)
y = np.asarray(y)
unlabeled = y == -1
clamp_weights = np.ones((n_samples, 1))
clamp_weights[unlabeled, 0] = self.alpha

# initialize distributions
self.label_distributions_ = np.zeros((n_samples, n_classes))
for label in classes:
self.label_distributions_[y == label, classes == label] = 1

y_static = np.copy(self.label_distributions_)
if self.alpha > 0.:
y_static *= 1 - self.alpha
y_static[unlabeled] = 0
if self._variant == 'propagation':
# LabelPropagation
y_static[unlabeled] = 0
else:
# LabelSpreading
y_static *= 1 - alpha

l_previous = np.zeros((self.X_.shape[0], n_classes))

remaining_iter = self.max_iter
unlabeled = unlabeled[:, np.newaxis]
if sparse.isspmatrix(graph_matrix):
graph_matrix = graph_matrix.tocsr()
while (_not_converged(self.label_distributions_, l_previous, self.tol)
and remaining_iter > 1):
l_previous = self.label_distributions_
self.label_distributions_ = safe_sparse_dot(
graph_matrix, self.label_distributions_)
# clamp
self.label_distributions_ = np.multiply(
clamp_weights, self.label_distributions_) + y_static

if self._variant == 'propagation':
normalizer = np.sum(
self.label_distributions_, axis=1)[:, np.newaxis]
self.label_distributions_ /= normalizer
self.label_distributions_ = np.where(unlabeled,
self.label_distributions_,
y_static)
else:
# clamp
self.label_distributions_ = np.multiply(
alpha, self.label_distributions_) + y_static
remaining_iter -= 1

normalizer = np.sum(self.label_distributions_, axis=1)[:, np.newaxis]
self.label_distributions_ /= normalizer

# set the transduction item
transduction = self.classes_[np.argmax(self.label_distributions_,
axis=1)]
Expand Down Expand Up @@ -299,7 +318,11 @@ class LabelPropagation(BaseLabelPropagation):
Parameter for knn kernel
alpha : float
Clamping factor
Clamping factor.
.. deprecated:: 0.19
This parameter will be removed in 0.21.
'alpha' is fixed to zero in 'LabelPropagation'.
max_iter : float
Change maximum number of iterations allowed
Expand Down Expand Up @@ -350,6 +373,14 @@ class LabelPropagation(BaseLabelPropagation):
LabelSpreading : Alternate label propagation strategy more robust to noise
"""

_variant = 'propagation'

def __init__(self, kernel='rbf', gamma=20, n_neighbors=7,
alpha=None, max_iter=30, tol=1e-3, n_jobs=1):
super(LabelPropagation, self).__init__(
kernel=kernel, gamma=gamma, n_neighbors=n_neighbors, alpha=alpha,
max_iter=max_iter, tol=tol, n_jobs=n_jobs)

def _build_graph(self):
"""Matrix representing a fully connected graph between each sample
Expand All @@ -366,6 +397,15 @@ class distributions will exceed 1 (normalization may be desired).
affinity_matrix /= normalizer[:, np.newaxis]
return affinity_matrix

def fit(self, X, y):
if self.alpha is not None:
warnings.warn(
"alpha is deprecated since 0.19 and will be removed in 0.21.",
DeprecationWarning
)
self.alpha = None
return super(LabelPropagation, self).fit(X, y)


class LabelSpreading(BaseLabelPropagation):
"""LabelSpreading model for semi-supervised learning
Expand All @@ -391,7 +431,11 @@ class LabelSpreading(BaseLabelPropagation):
parameter for knn kernel
alpha : float
clamping factor
Clamping factor. A value in [0, 1] that specifies the relative amount
that an instance should adopt the information from its neighbors as
opposed to its initial label.
alpha=0 means keeping the initial label information; alpha=1 means
replacing all initial information.
max_iter : float
maximum number of iterations allowed
Expand Down Expand Up @@ -446,6 +490,8 @@ class LabelSpreading(BaseLabelPropagation):
LabelPropagation : Unregularized graph based semi-supervised learning
"""

_variant = 'spreading'

def __init__(self, kernel='rbf', gamma=20, n_neighbors=7, alpha=0.2,
max_iter=30, tol=1e-3, n_jobs=1):

Expand Down
86 changes: 86 additions & 0 deletions sklearn/semi_supervised/tests/test_label_propagation.py
Expand Up @@ -3,8 +3,12 @@
import numpy as np

from sklearn.utils.testing import assert_equal
from sklearn.utils.testing import assert_warns
from sklearn.utils.testing import assert_raises
from sklearn.utils.testing import assert_no_warnings
from sklearn.semi_supervised import label_propagation
from sklearn.metrics.pairwise import rbf_kernel
from sklearn.datasets import make_classification
from numpy.testing import assert_array_almost_equal
from numpy.testing import assert_array_equal

Expand Down Expand Up @@ -59,3 +63,85 @@ def test_predict_proba():
clf = estimator(**parameters).fit(samples, labels)
assert_array_almost_equal(clf.predict_proba([[1., 1.]]),
np.array([[0.5, 0.5]]))


def test_alpha_deprecation():
X, y = make_classification(n_samples=100)
y[::3] = -1

lp_default = label_propagation.LabelPropagation(kernel='rbf', gamma=0.1)
lp_default_y = assert_no_warnings(lp_default.fit, X, y).transduction_

lp_0 = label_propagation.LabelPropagation(alpha=0, kernel='rbf', gamma=0.1)
lp_0_y = assert_warns(DeprecationWarning, lp_0.fit, X, y).transduction_

assert_array_equal(lp_default_y, lp_0_y)


def test_label_spreading_closed_form():
n_classes = 2
X, y = make_classification(n_classes=n_classes, n_samples=200,
random_state=0)
y[::3] = -1
clf = label_propagation.LabelSpreading().fit(X, y)
# adopting notation from Zhou et al (2004):
S = clf._build_graph()
Y = np.zeros((len(y), n_classes + 1))
Y[np.arange(len(y)), y] = 1
Y = Y[:, :-1]
for alpha in [0.1, 0.3, 0.5, 0.7, 0.9]:
expected = np.dot(np.linalg.inv(np.eye(len(S)) - alpha * S), Y)
expected /= expected.sum(axis=1)[:, np.newaxis]
clf = label_propagation.LabelSpreading(max_iter=10000, alpha=alpha)
clf.fit(X, y)
assert_array_almost_equal(expected, clf.label_distributions_, 4)


def test_label_propagation_closed_form():
n_classes = 2
X, y = make_classification(n_classes=n_classes, n_samples=200,
random_state=0)
y[::3] = -1
Y = np.zeros((len(y), n_classes + 1))
Y[np.arange(len(y)), y] = 1
unlabelled_idx = Y[:, (-1,)].nonzero()[0]
labelled_idx = (Y[:, (-1,)] == 0).nonzero()[0]

clf = label_propagation.LabelPropagation(max_iter=10000,
gamma=0.1).fit(X, y)
# adopting notation from Zhu et al 2002
T_bar = clf._build_graph()
Tuu = T_bar[np.meshgrid(unlabelled_idx, unlabelled_idx, indexing='ij')]
Tul = T_bar[np.meshgrid(unlabelled_idx, labelled_idx, indexing='ij')]
Y = Y[:, :-1]
Y_l = Y[labelled_idx, :]
Y_u = np.dot(np.dot(np.linalg.inv(np.eye(Tuu.shape[0]) - Tuu), Tul), Y_l)

expected = Y.copy()
expected[unlabelled_idx, :] = Y_u
expected /= expected.sum(axis=1)[:, np.newaxis]

assert_array_almost_equal(expected, clf.label_distributions_, 4)


def test_valid_alpha():
n_classes = 2
X, y = make_classification(n_classes=n_classes, n_samples=200,
random_state=0)
for alpha in [-0.1, 0, 1, 1.1, None]:
assert_raises(ValueError,
lambda **kwargs:
label_propagation.LabelSpreading(**kwargs).fit(X, y),
alpha=alpha)


def test_convergence_speed():
# This is a non-regression test for #5774
X = np.array([[1., 0.], [0., 1.], [1., 2.5]])
y = np.array([0, 1, -1])
mdl = label_propagation.LabelSpreading(kernel='rbf', max_iter=5000)
mdl.fit(X, y)

# this should converge quickly:
assert mdl.n_iter_ < 10
assert_array_equal(mdl.predict(X), [0, 1, 1])

0 comments on commit a9b9e3f

Please sign in to comment.