From 65ba51ae79c89c5fb0e1f2163a6999d3cfcd8d62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Zadro=C5=BCny?= Date: Mon, 17 Jan 2022 22:19:02 +0100 Subject: [PATCH 1/3] [MRG] Implement Sinkhorn in log-domain for WDA * for small values of the regularization parameter (reg) the current implementation runs into numerical issues (nans and infs) * this can be resolved by using log-domain implementation of the sinkhorn algorithm --- ot/dr.py | 23 ++++++++++++++++------- test/test_dr.py | 22 ++++++++++++++++++++++ 2 files changed, 38 insertions(+), 7 deletions(-) diff --git a/ot/dr.py b/ot/dr.py index 1671ca0f8..399e14387 100644 --- a/ot/dr.py +++ b/ot/dr.py @@ -30,16 +30,25 @@ def dist(x1, x2): return x1p2.reshape((-1, 1)) + x2p2.reshape((1, -1)) - 2 * np.dot(x1, x2.T) +def logsumexp(M, axis): + r"""Log-sum-exp reduction compatible with autograd (no numpy implementation) + """ + amax = np.amax(M, axis=axis, keepdims=True) + return np.log(np.sum(np.exp(M - amax), axis=axis)) + np.squeeze(amax, axis=axis) + + def sinkhorn(w1, w2, M, reg, k): - r"""Sinkhorn algorithm with fixed number of iteration (autograd) + r"""Sinkhorn algorithm in log-domain with fixed number of iteration (autograd) """ - K = np.exp(-M / reg) - ui = np.ones((M.shape[0],)) - vi = np.ones((M.shape[1],)) + Mr = -M / reg + ui = np.zeros((M.shape[0],)) + vi = np.zeros((M.shape[1],)) + log_w1 = np.log(w1) + log_w2 = np.log(w2) for i in range(k): - vi = w2 / (np.dot(K.T, ui)) - ui = w1 / (np.dot(K, vi)) - G = ui.reshape((M.shape[0], 1)) * K * vi.reshape((1, M.shape[1])) + vi = log_w2 - logsumexp(Mr + ui[:, None], 0) + ui = log_w1 - logsumexp(Mr + vi[None, :], 1) + G = np.exp(ui[:, None] + Mr + vi[None, :]) return G diff --git a/test/test_dr.py b/test/test_dr.py index 741f2add1..2c380c1b8 100644 --- a/test/test_dr.py +++ b/test/test_dr.py @@ -60,6 +60,28 @@ def test_wda(): np.testing.assert_allclose(np.sum(Pwda**2, 0), np.ones(p)) +@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)") +def test_wda_low_reg(): + + n_samples = 100 # nb samples in source and target datasets + np.random.seed(0) + + # generate gaussian dataset + xs, ys = ot.datasets.make_data_classif('gaussrot', n_samples) + + n_features_noise = 8 + + xs = np.hstack((xs, np.random.randn(n_samples, n_features_noise))) + + p = 2 + + Pwda, projwda = ot.dr.wda(xs, ys, p, reg=0.01, maxiter=10) + + projwda(xs) + + np.testing.assert_allclose(np.sum(Pwda**2, 0), np.ones(p)) + + @pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)") def test_wda_normalized(): From 76bdbfd03389e3dcb606363ed7f516833c042472 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Zadro=C5=BCny?= Date: Tue, 18 Jan 2022 13:23:44 +0100 Subject: [PATCH 2/3] Add feature to RELEASES and contributor name --- RELEASES.md | 2 ++ ot/dr.py | 1 + 2 files changed, 3 insertions(+) diff --git a/RELEASES.md b/RELEASES.md index 9b92d971a..09e4fca63 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -5,6 +5,8 @@ #### New features - Better list of related examples in quick start guide with `minigallery` (PR #334) +- Use log-domain Sinkhorn implementation in WDA to support smaller values + of the regularization parameter (PR #336) ## 0.8.1.0 *December 2021* diff --git a/ot/dr.py b/ot/dr.py index 399e14387..e1bcb4bad 100644 --- a/ot/dr.py +++ b/ot/dr.py @@ -11,6 +11,7 @@ # Author: Remi Flamary # Minhui Huang +# Jakub Zadrozny # # License: MIT License From 1c7256021f2c2f3c44cb03f4d10686235f53b934 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Zadro=C5=BCny?= Date: Thu, 20 Jan 2022 21:38:35 +0100 Subject: [PATCH 3/3] Add 'sinkhorn_method' parameter to WDA * use the standard Sinkhorn solver by default (faster) * use log-domain Sinkhorn if asked by the user --- RELEASES.md | 2 +- ot/dr.py | 36 +++++++++++++++++++++++++++++++++--- test/test_dr.py | 2 +- 3 files changed, 35 insertions(+), 5 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index b87c53bb7..a5fcbe15c 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -5,7 +5,7 @@ #### New features - Better list of related examples in quick start guide with `minigallery` (PR #334) -- Use log-domain Sinkhorn implementation in WDA to support smaller values +- Add optional log-domain Sinkhorn implementation in WDA to support smaller values of the regularization parameter (PR #336) #### Closed issues diff --git a/ot/dr.py b/ot/dr.py index e1bcb4bad..0955c5516 100644 --- a/ot/dr.py +++ b/ot/dr.py @@ -31,6 +31,19 @@ def dist(x1, x2): return x1p2.reshape((-1, 1)) + x2p2.reshape((1, -1)) - 2 * np.dot(x1, x2.T) +def sinkhorn(w1, w2, M, reg, k): + r"""Sinkhorn algorithm with fixed number of iteration (autograd) + """ + K = np.exp(-M / reg) + ui = np.ones((M.shape[0],)) + vi = np.ones((M.shape[1],)) + for i in range(k): + vi = w2 / (np.dot(K.T, ui)) + ui = w1 / (np.dot(K, vi)) + G = ui.reshape((M.shape[0], 1)) * K * vi.reshape((1, M.shape[1])) + return G + + def logsumexp(M, axis): r"""Log-sum-exp reduction compatible with autograd (no numpy implementation) """ @@ -38,7 +51,7 @@ def logsumexp(M, axis): return np.log(np.sum(np.exp(M - amax), axis=axis)) + np.squeeze(amax, axis=axis) -def sinkhorn(w1, w2, M, reg, k): +def sinkhorn_log(w1, w2, M, reg, k): r"""Sinkhorn algorithm in log-domain with fixed number of iteration (autograd) """ Mr = -M / reg @@ -120,7 +133,7 @@ def proj(X): return Popt, proj -def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, normalize=False): +def wda(X, y, p=2, reg=1, k=10, solver=None, sinkhorn_method='sinkhorn', maxiter=100, verbose=0, P0=None, normalize=False): r""" Wasserstein Discriminant Analysis :ref:`[11] ` @@ -136,6 +149,14 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no - :math:`W` is entropic regularized Wasserstein distances - :math:`\mathbf{X}^i` are samples in the dataset corresponding to class i + **Choosing a Sinkhorn solver** + + By default and when using a regularization parameter that is not too small + the default sinkhorn solver should be enough. If you need to use a small + regularization to get sparse cost matrices, you should use the + :py:func:`ot.dr.sinkhorn_log` solver that will avoid numerical + errors, but can be slow in practice. + Parameters ---------- X : ndarray, shape (n, d) @@ -149,6 +170,8 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no solver : None | str, optional None for steepest descent or 'TrustRegions' for trust regions algorithm else should be a pymanopt.solvers + sinkhorn_method : str + method used for the Sinkhorn solver, either 'sinkhorn' or 'sinkhorn_log' P0 : ndarray, shape (d, p) Initial starting point for projection. normalize : bool, optional @@ -171,6 +194,13 @@ def wda(X, y, p=2, reg=1, k=10, solver=None, maxiter=100, verbose=0, P0=None, no Wasserstein Discriminant Analysis. arXiv preprint arXiv:1608.08063. """ # noqa + if sinkhorn_method.lower() == 'sinkhorn': + sinkhorn_solver = sinkhorn + elif sinkhorn_method.lower() == 'sinkhorn_log': + sinkhorn_solver = sinkhorn_log + else: + raise ValueError("Unknown Sinkhorn method '%s'." % sinkhorn_method) + mx = np.mean(X) X -= mx.reshape((1, -1)) @@ -203,7 +233,7 @@ def cost(P): for j, xj in enumerate(xc[i:]): xj = np.dot(xj, P) M = dist(xi, xj) - G = sinkhorn(wc[i], wc[j + i], M, reg * regmean[i, j], k) + G = sinkhorn_solver(wc[i], wc[j + i], M, reg * regmean[i, j], k) if j == 0: loss_w += np.sum(G * M) else: diff --git a/test/test_dr.py b/test/test_dr.py index 2c380c1b8..6d7fc9aa6 100644 --- a/test/test_dr.py +++ b/test/test_dr.py @@ -75,7 +75,7 @@ def test_wda_low_reg(): p = 2 - Pwda, projwda = ot.dr.wda(xs, ys, p, reg=0.01, maxiter=10) + Pwda, projwda = ot.dr.wda(xs, ys, p, reg=0.01, maxiter=10, sinkhorn_method='sinkhorn_log') projwda(xs)