From 2dbeeda9308029a8e8db56bed07d48f4d5718efb Mon Sep 17 00:00:00 2001
From: Huy Tran
Date: Mon, 14 Jun 2021 13:06:40 +0200
Subject: [PATCH] [MRG] Batch/Lazy Log Sinkhorn Knopp on samples (#259)

* Add batch implementation of Sinkhorn
* Reformat to pep8 and modify parameter
* Fix error in batch size
* Code review and add test
* Fix accidental typo in test_empirical_sinkhorn
* Remove whitespace
* Edit config.yml
---
 .circleci/config.yml |   1 +
 ot/bregman.py        | 134 ++++++++++++++++++++++++++++++++++++-------
 test/test_bregman.py |  44 ++++++++++++++
 3 files changed, 158 insertions(+), 21 deletions(-)

diff --git a/.circleci/config.yml b/.circleci/config.yml
index 29c9a0716..e4c71dde1 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -73,6 +73,7 @@ jobs:
           command: |
             cd docs;
             make html;
+          no_output_timeout: 30m

       # Save the outputs
       - store_artifacts:

diff --git a/ot/bregman.py b/ot/bregman.py
index b10effd91..105b38be6 100644
--- a/ot/bregman.py
+++ b/ot/bregman.py
@@ -11,6 +11,7 @@
 # Mokhtar Z. Alaya
 # Alexander Tong
 # Ievgen Redko
+# Quang Huy Tran
 #
 # License: MIT License

@@ -18,6 +19,7 @@
 import numpy as np
 from scipy.optimize import fmin_l_bfgs_b
+from scipy.special import logsumexp

 from ot.utils import unif, dist, list_to_array
 from .backend import get_backend

@@ -1684,7 +1686,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,

 def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
-                       numIterMax=10000, stopThr=1e-9, verbose=False,
+                       numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False,
                        log=False, **kwargs):
     r'''
     Solve the entropic regularization optimal transport problem and return the

@@ -1723,6 +1725,12 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
         Max number of iterations
     stopThr : float, optional
         Stop threshold on error (>0)
+    isLazy: boolean, optional
+        If True, compute the cost matrix by blocks and return only the dual potentials (to save memory).
+        If False, compute the full cost matrix and return the outputs of the sinkhorn function.
+    batchSize: int or tuple of 2 int, optional
+        Size of the batches used to compute the Sinkhorn updates without memory overhead.
+        When a tuple is provided it sets the size of the left/right batches.
     verbose : bool, optional
        Print information along iterations
     log : bool, optional
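As a usage sketch (not itself part of the diff), the new isLazy / batchSize
options added above can be exercised as follows; the point clouds, batch
sizes and regularization value are illustrative, and POT is assumed to be
importable as `ot`:

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    X_s = rng.randn(500, 2)   # source samples
    X_t = rng.randn(400, 2)   # target samples

    # isLazy=True: the cost matrix is built block by block and only the
    # dual potentials f (length 500) and g (length 400) are returned.
    f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1.0, isLazy=True, batchSize=100)

    # A 2-tuple sets different left/right block sizes.
    f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1.0, isLazy=True, batchSize=(128, 64))

    # With log=True the potentials also come back in a log dict.
    f, g, log = ot.bregman.empirical_sinkhorn(X_s, X_t, 1.0, isLazy=True, batchSize=100, log=True)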
@@ -1758,24 +1766,78 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
     .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
     '''
-
+    ns, nt = X_s.shape[0], X_t.shape[0]
     if a is None:
-        a = unif(np.shape(X_s)[0])
+        a = unif(ns)
     if b is None:
-        b = unif(np.shape(X_t)[0])
+        b = unif(nt)
+
+    if isLazy:
+        if log:
+            dict_log = {"err": []}

-    M = dist(X_s, X_t, metric=metric)
+        log_a, log_b = np.log(a), np.log(b)
+        f, g = np.zeros(ns), np.zeros(nt)
+
+        if isinstance(batchSize, int):
+            bs, bt = batchSize, batchSize
+        elif isinstance(batchSize, tuple) and len(batchSize) == 2:
+            bs, bt = batchSize[0], batchSize[1]
+        else:
+            raise ValueError("Batch size must be an integer or a tuple of two integers")
+
+        range_s, range_t = range(0, ns, bs), range(0, nt, bt)
+
+        lse_f = np.zeros(ns)
+        lse_g = np.zeros(nt)
+
+        for i_ot in range(numIterMax):
+
+            for i in range_s:
+                M = dist(X_s[i:i + bs, :], X_t, metric=metric)
+                lse_f[i:i + bs] = logsumexp(g[None, :] - M / reg, axis=1)
+            f = log_a - lse_f
+
+            for j in range_t:
+                M = dist(X_s, X_t[j:j + bt, :], metric=metric)
+                lse_g[j:j + bt] = logsumexp(f[:, None] - M / reg, axis=0)
+            g = log_b - lse_g
+
+            if (i_ot + 1) % 10 == 0:
+                m1 = np.zeros_like(a)
+                for i in range_s:
+                    M = dist(X_s[i:i + bs, :], X_t, metric=metric)
+                    m1[i:i + bs] = np.exp(f[i:i + bs, None] + g[None, :] - M / reg).sum(1)
+                err = np.abs(m1 - a).sum()
+                if log:
+                    dict_log["err"].append(err)
+
+                if verbose and (i_ot + 1) % 100 == 0:
+                    print("Error in marginal at iteration {} = {}".format(i_ot + 1, err))
+
+                if err <= stopThr:
+                    break
+
+        if log:
+            dict_log["u"] = f
+            dict_log["v"] = g
+            return (f, g, dict_log)
+        else:
+            return (f, g)

-    if log:
-        pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
-        return pi, log
-    else:
-        pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs)
-        return pi
+    else:
+        M = dist(X_s, X_t, metric=metric)
+
+        if log:
+            pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
+            return pi, log
+        else:
+            pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs)
+            return pi

 def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9,
-                        verbose=False, log=False, **kwargs):
+                        isLazy=False, batchSize=100, verbose=False, log=False, **kwargs):
     r'''
     Solve the entropic regularization optimal transport problem from empirical
     data and return the OT loss
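The lazy branch above runs Sinkhorn entirely in the log domain while
streaming the cost matrix in blocks, so at most a (bs x nt) or (ns x bt)
slice is ever materialized. A self-contained sketch of the same scheme (not
the patch itself: it uses scipy's cdist in place of ot.dist, fixes uniform
marginals, and the function and parameter names are illustrative):

    import numpy as np
    from scipy.spatial.distance import cdist
    from scipy.special import logsumexp

    def lazy_log_sinkhorn(X_s, X_t, reg, n_iter=100, bs=64, bt=64):
        # Uniform marginals, as in the patch when a and b are None.
        ns, nt = len(X_s), len(X_t)
        log_a = np.log(np.full(ns, 1 / ns))
        log_b = np.log(np.full(nt, 1 / nt))
        f, g = np.zeros(ns), np.zeros(nt)
        lse_f, lse_g = np.zeros(ns), np.zeros(nt)
        for _ in range(n_iter):
            # Update f row-block by row-block: only a (bs x nt) cost
            # slice exists at any time.
            for i in range(0, ns, bs):
                M = cdist(X_s[i:i + bs], X_t, metric='sqeuclidean')
                lse_f[i:i + bs] = logsumexp(g[None, :] - M / reg, axis=1)
            f = log_a - lse_f
            # Symmetric update for g, column-block by column-block.
            for j in range(0, nt, bt):
                M = cdist(X_s, X_t[j:j + bt], metric='sqeuclidean')
                lse_g[j:j + bt] = logsumexp(f[:, None] - M / reg, axis=0)
            g = log_b - lse_g
        return f, g

The design point is that the dual potentials f, g need only O(ns + nt)
memory, whereas the full kernel matrix would need O(ns * nt).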
@@ -1814,6 +1876,12 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
         Max number of iterations
     stopThr : float, optional
         Stop threshold on error (>0)
+    isLazy: boolean, optional
+        If True, compute the cost matrix by blocks and return only the dual potentials (to save memory).
+        If False, compute the full cost matrix and return the outputs of the sinkhorn function.
+    batchSize: int or tuple of 2 int, optional
+        Size of the batches used to compute the Sinkhorn updates without memory overhead.
+        When a tuple is provided it sets the size of the left/right batches.
     verbose : bool, optional
         Print information along iterations
     log : bool, optional

@@ -1850,21 +1918,45 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
     .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
     '''
+    ns, nt = X_s.shape[0], X_t.shape[0]
     if a is None:
-        a = unif(np.shape(X_s)[0])
+        a = unif(ns)
     if b is None:
-        b = unif(np.shape(X_t)[0])
+        b = unif(nt)

-    M = dist(X_s, X_t, metric=metric)
+    if isLazy:
+        if log:
+            f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr,
+                                                isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log)
+        else:
+            f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr,
+                                      isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log)
+
+        bs = batchSize if isinstance(batchSize, int) else batchSize[0]
+        range_s = range(0, ns, bs)
+
+        loss = 0
+        for i in range_s:
+            M_block = dist(X_s[i:i + bs, :], X_t, metric=metric)
+            pi_block = np.exp(f[i:i + bs, None] + g[None, :] - M_block / reg)
+            loss += np.sum(M_block * pi_block)
+
+        if log:
+            return loss, dict_log
+        else:
+            return loss

-    if log:
-        sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
-                                       **kwargs)
-        return sinkhorn_loss, log
-    else:
-        sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
-                                  **kwargs)
-        return sinkhorn_loss
+    else:
+        M = dist(X_s, X_t, metric=metric)
+
+        if log:
+            sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose,
+                                           log=log, **kwargs)
+            return sinkhorn_loss, log
+        else:
+            sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose,
+                                      log=log, **kwargs)
+            return sinkhorn_loss

 def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
                                   numIterMax=10000, stopThr=1e-9,
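With the dual potentials in hand, the lazy loss in empirical_sinkhorn2 is
just <M, pi> accumulated over row blocks, reconstructing
pi_ij = exp(f_i + g_j - M_ij / reg) one block at a time rather than storing
the full plan. A standalone sketch, under the same illustrative assumptions
as the previous snippet (cdist for ot.dist, hypothetical helper name):

    import numpy as np
    from scipy.spatial.distance import cdist

    def lazy_sinkhorn_loss(X_s, X_t, f, g, reg, bs=64):
        # Accumulate <M, pi> over row blocks; pi is rebuilt from the dual
        # potentials block by block and never stored as a full matrix.
        loss = 0.0
        for i in range(0, len(X_s), bs):
            M_block = cdist(X_s[i:i + bs], X_t, metric='sqeuclidean')
            pi_block = np.exp(f[i:i + bs, None] + g[None, :] - M_block / reg)
            loss += np.sum(M_block * pi_block)
        return loss

    # e.g. f, g = lazy_log_sinkhorn(X_s, X_t, reg=1.0)
    #      loss = lazy_sinkhorn_loss(X_s, X_t, f, g, reg=1.0)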
diff --git a/test/test_bregman.py b/test/test_bregman.py
index 7c5162a9b..9665229b7 100644
--- a/test/test_bregman.py
+++ b/test/test_bregman.py
@@ -2,6 +2,7 @@
 # Author: Remi Flamary
 #         Kilian Fatras
+#         Quang Huy Tran
 #
 # License: MIT License

@@ -329,6 +330,49 @@ def test_empirical_sinkhorn():
     np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)

+def test_lazy_empirical_sinkhorn():
+    # test sinkhorn
+    n = 100
+    a = ot.unif(n)
+    b = ot.unif(n)
+    numIterMax = 1000
+
+    X_s = np.reshape(np.arange(n), (n, 1))
+    X_t = np.reshape(np.arange(0, n), (n, 1))
+    M = ot.dist(X_s, X_t)
+    M_m = ot.dist(X_s, X_t, metric='minkowski')
+
+    f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 1), verbose=True)
+    G_sqe = np.exp(f[:, None] + g[None, :] - M / 1)
+    sinkhorn_sqe = ot.sinkhorn(a, b, M, 1)
+
+    f, g, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+    G_log = np.exp(f[:, None] + g[None, :] - M / 0.1)
+    sinkhorn_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True)
+
+    f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='minkowski', numIterMax=numIterMax, isLazy=True, batchSize=1)
+    G_m = np.exp(f[:, None] + g[None, :] - M_m / 1)
+    sinkhorn_m = ot.sinkhorn(a, b, M_m, 1)
+
+    loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+    loss_sinkhorn = ot.sinkhorn2(a, b, M, 1)
+
+    # check constraints
+    np.testing.assert_allclose(
+        sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05)  # metric sqeuclidean
+    np.testing.assert_allclose(
+        sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05)  # metric sqeuclidean
+    np.testing.assert_allclose(
+        sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05)  # log
+    np.testing.assert_allclose(
+        sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05)  # log
+    np.testing.assert_allclose(
+        sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05)  # metric euclidean
+    np.testing.assert_allclose(
+        sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05)  # metric euclidean
+    np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
+
+
 def test_empirical_sinkhorn_divergence():
     # Test sinkhorn divergence
     n = 10