From 5f8d386c7c88c6d5a93cd6c38168f14acd642f1b Mon Sep 17 00:00:00 2001 From: Clement Date: Wed, 7 Aug 2024 22:42:35 +0200 Subject: [PATCH 1/5] translation invariant sinkhorn --- README.md | 4 +- .../unbalanced-partial/plot_unbalanced_OT.py | 2 +- ot/unbalanced.py | 253 +++++++++++++++++- 3 files changed, 255 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index b7493074c..715ae2b3f 100644 --- a/README.md +++ b/README.md @@ -370,4 +370,6 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil [68] Chowdhury, S., Miller, D., & Needham, T. (2021). [Quantized gromov-wasserstein](https://link.springer.com/chapter/10.1007/978-3-030-86523-8_49). ECML PKDD 2021. Springer International Publishing. -[69] Delon, J., & Desolneux, A. (2020). [A Wasserstein-type distance in the space of Gaussian mixture models](https://epubs.siam.org/doi/abs/10.1137/19M1301047). SIAM Journal on Imaging Sciences, 13(2), 936-970. \ No newline at end of file +[69] Delon, J., & Desolneux, A. (2020). [A Wasserstein-type distance in the space of Gaussian mixture models](https://epubs.siam.org/doi/abs/10.1137/19M1301047). SIAM Journal on Imaging Sciences, 13(2), 936-970. + +[70] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). [Faster unbalanced optimal transport: Translation invariant sinkhorn and 1-d frank-wolfe](https://proceedings.mlr.press/v151/sejourne22a.html). In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR. 
diff --git a/examples/unbalanced-partial/plot_unbalanced_OT.py b/examples/unbalanced-partial/plot_unbalanced_OT.py index 03487e7e2..6f545c1f7 100644 --- a/examples/unbalanced-partial/plot_unbalanced_OT.py +++ b/examples/unbalanced-partial/plot_unbalanced_OT.py @@ -73,7 +73,7 @@ reg_m_l2 = 5 mass = 0.7 -entropic_kl_uot = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_kl) +entropic_kl_uot = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_kl, reg_type="kl") kl_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_kl, div='kl') l2_uot = ot.unbalanced.mm_unbalanced(a, b, M, reg_m_l2, div='l2') partial_ot = ot.partial.partial_wasserstein(a, b, M, m=mass) diff --git a/ot/unbalanced.py b/ot/unbalanced.py index c39888a31..a6ea88309 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -131,6 +131,10 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015 + .. [70] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). + Faster unbalanced optimal transport: Translation invariant sinkhorn and 1-d frank-wolfe. + In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR. 
+ See Also -------- @@ -139,7 +143,8 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', Unbalanced Stabilized sinkhorn :ref:`[9, 10] ` ot.unbalanced.sinkhorn_reg_scaling_unbalanced: Unbalanced Sinkhorn with epsilon scaling :ref:`[9, 10] ` - + ot.unbalanced.sinkhorn_unbalanced_translation_invariant + Translation invariant Unbalanced Sinkhorn :ref:`[70] ` """ if method.lower() == 'sinkhorn': @@ -154,12 +159,21 @@ def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn', stopThr=stopThr, verbose=verbose, log=log, **kwargs) + + elif method.lower() == 'sinkhorn_translation_invariant': + return sinkhorn_unbalanced_translation_invariant(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, + stopThr=stopThr, + verbose=verbose, + log=log, **kwargs) + elif method.lower() in ['sinkhorn_reg_scaling']: warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp') return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) + else: raise ValueError("Unknown method '%s'." % method) @@ -269,12 +283,16 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', Learning with a Wasserstein Loss, Advances in Neural Information Processing Systems (NIPS) 2015 + .. [70] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). + Faster unbalanced optimal transport: Translation invariant sinkhorn and 1-d frank-wolfe. + In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR. 
+ See Also -------- ot.unbalanced.sinkhorn_knopp : Unbalanced Classic Sinkhorn :ref:`[10] ` ot.unbalanced.sinkhorn_stabilized: Unbalanced Stabilized sinkhorn :ref:`[9, 10] ` ot.unbalanced.sinkhorn_reg_scaling: Unbalanced Sinkhorn with epsilon scaling :ref:`[9, 10] ` - + ot.unbalanced.sinkhorn_unbalanced_translation_invariant: :ref:`[70] ` """ M, a, b = list_to_array(M, a, b) nx = get_backend(M, a, b) @@ -291,6 +309,13 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) + + elif method.lower() == 'sinkhorn_translation_invariant': + res = sinkhorn_unbalanced_translation_invariant(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + elif method.lower() in ['sinkhorn_reg_scaling']: warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp') res = sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, @@ -317,6 +342,13 @@ def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn', warmstart, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, **kwargs) + + elif method.lower() == 'sinkhorn_translation_invariant': + return sinkhorn_unbalanced_translation_invariant(a, b, M, reg, reg_m, reg_type, + warmstart, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, **kwargs) + elif method.lower() in ['sinkhorn_reg_scaling']: warnings.warn('Method not implemented yet. 
Using classic Sinkhorn-Knopp') return sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type, @@ -779,6 +811,223 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", return ot_matrix +def rescale_potentials(f, g, a, b, rho1, rho2): + """ + TODO + """ + nx = get_backend(f, g, a, b) + + tau = (rho1 * rho2) / (rho1 + rho2) + transl = tau * (nx.logsumexp(nx.log(a) - f / rho1) + - nx.logsumexp(nx.log(b) - g / rho2)) + return transl + + +def softmin(a, f, rho): + nx = get_backend(a, f) + return - rho * nx.logsumexp(nx.log(a) - f / rho) + + +def sinkhorn_unbalanced_translation_invariant(a, b, M, reg, reg_m, reg_type="kl", + warmstart=None, numItermax=1000, stopThr=1e-6, + verbose=False, log=False, **kwargs): + """ + Solve the entropic regularization unbalanced optimal transport problem and + return the OT plan + + The function solves the following optimization problem: + + .. math:: + W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot \Omega(\gamma) + + \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) + + \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b}) + + s.t. 
+ \gamma \geq 0 + + where : + + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\Omega` is the entropic regularization term,KL divergence + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions + - KL is the Kullback-Leibler divergence + + The algorithm used for solving the problem is the translation invariant Sinkhorn algorithm as proposed in :ref:`[70] ` + + + Parameters + ---------- + a : array-like (dim_a,) + Unnormalized histogram of dimension `dim_a` + b : array-like (dim_b,) or array-like (dim_b, n_hists) + One or multiple unnormalized histograms of dimension `dim_b` + If many, compute all the OT distances (a, b_i) + M : array-like (dim_a, dim_b) + loss matrix + reg : float + Entropy regularization term > 0 + reg_m: float or indexable object of length 1 or 2 + Marginal relaxation term. + If reg_m is a scalar or an indexable object of length 1, + then the same reg_m is applied to both marginal relaxations. + The entropic balanced OT can be recovered using `reg_m=float("inf")`. + For semi-relaxed case, use either + `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. + If reg_m is an array, it must have the same backend as input arrays (a, b, M). + reg_type : string, optional TODO: CHECKE ENTROPY? + Regularizer term. Can take two values: + 'entropy' (negative entropy) + :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`, or + 'kl' (Kullback-Leibler) + :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors). 
+ numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (> 0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + if n_hists == 1: + - gamma : (dim_a, dim_b) array-like + Optimal transportation matrix for the given parameters + - log : dict + log dictionary returned only if `log` is `True` + else: + - ot_distance : (n_hists,) array-like + the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i` + - log : dict + log dictionary returned only if `log` is `True` + + Examples + -------- + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[0., 1.],[1., 0.]] + >>> ot.unbalanced.sinkhorn_unbalanced_translation_invariant(a, b, M, 1., 1.) + array([[0.32205361, 0.1184769], + [0.1184769, 0.32205361]]) + + .. _references-sinkhorn-unbalanced-translation-invariant: + References + ---------- + .. [70] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). + Faster unbalanced optimal transport: Translation invariant sinkhorn and 1-d frank-wolfe. + In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR. 
+ """ + + M, a, b = list_to_array(M, a, b) + nx = get_backend(M, a, b) + + dim_a, dim_b = M.shape + + if len(a) == 0: + a = nx.ones(dim_a, type_as=M) / dim_a + if len(b) == 0: + b = nx.ones(dim_b, type_as=M) / dim_b + + if len(b.shape) > 1: + n_hists = b.shape[1] + else: + n_hists = 0 + + reg_m1, reg_m2 = get_parameter_pair(reg_m) + + if log: + log = {'err': []} + + # we assume that no distances are null except those of the diagonal of + # distances + if warmstart is None: + if n_hists: + u = nx.ones((dim_a, 1), type_as=M) + v = nx.ones((dim_b, n_hists), type_as=M) + a = a.reshape(dim_a, 1) + else: + u = nx.ones(dim_a, type_as=M) + v = nx.ones(dim_b, type_as=M) + else: + u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) + + # potentials + f, g = reg * nx.log(u), reg * nx.log(v) + + fi_1 = reg_m1 / (reg_m1 + reg) if reg_m1 != float("inf") else 1 + fi_2 = reg_m2 / (reg_m2 + reg) if reg_m2 != float("inf") else 1 + + k1 = reg * reg_m1 / ((reg + reg_m1) * (reg_m1 + reg_m2)) if reg_m1 != float("inf") else 0 + k2 = reg * reg_m2 / ((reg + reg_m2) * (reg_m1 + reg_m2)) if reg_m2 != float("inf") else 0 + + xi1 = (reg_m2 * reg) / (reg_m1 * (reg + reg_m1 + reg_m2)) + xi2 = (reg_m1 * reg) / (reg_m2 * (reg + reg_m1 + reg_m2)) + + err = 1. + + for i in range(numItermax): + fprev = f + gprev = g + + f_hat = -reg * fi_1 * nx.logsumexp(nx.log(b)[None, :] + (g[None, :] - M) / reg, axis=1) - k2 * softmin(b, g, reg_m2) + f = f_hat + xi1 * softmin(a, f_hat, reg_m1) + + g_hat = -reg * fi_2 * nx.logsumexp(nx.log(a)[:, None] + (f[:, None] - M) / reg, axis=0) - k1 * softmin(a, f, reg_m1) + g = g_hat + xi2 * softmin(b, g_hat, reg_m2) + + tprev = rescale_potentials(fprev, gprev, a, b, reg_m1, reg_m2) + uprev, vprev = nx.exp((fprev + tprev) / reg), nx.exp((gprev - tprev) / reg) + + t = rescale_potentials(f, g, a, b, reg_m1, reg_m2) + u, v = nx.exp((f + t) / reg), nx.exp((g - t) / reg) + + err_u = nx.max(nx.abs(u - uprev)) / max( + nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1. 
+ ) + err_v = nx.max(nx.abs(v - vprev)) / max( + nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1. + ) + err = 0.5 * (err_u + err_v) + if log: + log['err'].append(err) + if verbose: + if i % 50 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(i, err)) + + if err < stopThr: + break + + if log: + log['logu'] = nx.log(u + 1e-300) + log['logv'] = nx.log(v + 1e-300) + + K = nx.exp(-M / reg) * a.reshape(-1)[:, None] * b.reshape(-1)[None, :] + + if n_hists: # return only loss + res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M) + if log: + return res, log + else: + return res + + else: # return OT matrix + + if log: + return u[:, None] * K * v[None, :], log + else: + return u[:, None] * K * v[None, :] + + def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3, numItermax=1000, stopThr=1e-6, verbose=False, log=False): From 5b617534ab21aca2dcac9cb1757f8ea0195f2d82 Mon Sep 17 00:00:00 2001 From: Clement Date: Wed, 7 Aug 2024 23:12:22 +0200 Subject: [PATCH 2/5] translation invariant sinkhorn --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 715ae2b3f..4db0a672b 100644 --- a/README.md +++ b/README.md @@ -372,4 +372,4 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil [69] Delon, J., & Desolneux, A. (2020). [A Wasserstein-type distance in the space of Gaussian mixture models](https://epubs.siam.org/doi/abs/10.1137/19M1301047). SIAM Journal on Imaging Sciences, 13(2), 936-970. -[70] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). [Faster unbalanced optimal transport: Translation invariant sinkhorn and 1-d frank-wolfe](https://proceedings.mlr.press/v151/sejourne22a.html). In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR. +[70] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). 
[Faster Unbalanced Optimal Transport: Translation Invariant Sinkhorn and 1-D Frank-Wolfe](https://proceedings.mlr.press/v151/sejourne22a.html). In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR. From 28d27c27bc0ebdf12d64c47b1bcb8a5461b9d4eb Mon Sep 17 00:00:00 2001 From: Clement Date: Wed, 7 Aug 2024 23:13:36 +0200 Subject: [PATCH 3/5] translation invariant sinkhorn --- ot/unbalanced.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ot/unbalanced.py b/ot/unbalanced.py index a6ea88309..10b502fe8 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -915,8 +915,8 @@ def sinkhorn_unbalanced_translation_invariant(a, b, M, reg, reg_m, reg_type="kl" >>> b=[.5, .5] >>> M=[[0., 1.],[1., 0.]] >>> ot.unbalanced.sinkhorn_unbalanced_translation_invariant(a, b, M, 1., 1.) - array([[0.32205361, 0.1184769], - [0.1184769, 0.32205361]]) + array([[0.32205357, 0.11847689], + [0.11847689, 0.32205357]]) .. _references-sinkhorn-unbalanced-translation-invariant: References From 834fde5b62f66a7213ba7b8e638bcd17b04cef95 Mon Sep 17 00:00:00 2001 From: Clement Date: Thu, 8 Aug 2024 09:38:30 +0200 Subject: [PATCH 4/5] example cv --- .../plot_conv_sinkhorn_ti.py | 101 ++++++++++++++++++ 1 file changed, 101 insertions(+) create mode 100644 examples/unbalanced-partial/plot_conv_sinkhorn_ti.py diff --git a/examples/unbalanced-partial/plot_conv_sinkhorn_ti.py b/examples/unbalanced-partial/plot_conv_sinkhorn_ti.py new file mode 100644 index 000000000..887b39bcb --- /dev/null +++ b/examples/unbalanced-partial/plot_conv_sinkhorn_ti.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- +""" +=============================================================== +Translation Invariant Sinkhorn for Unbalanced Optimal Transport +=============================================================== + +This example illustrates the better convergence of the translation +invariant Sinkhorn algorithm proposed in [70] compared to the classical +Sinkhorn 
algorithm. + +[70] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). +Faster unbalanced optimal transport: Translation invariant sinkhorn and 1-d frank-wolfe. +In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR. + +""" + +# Author: Clément Bonet +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import ot + +############################################################################## +# Setting parameters +# ------------- + +# %% parameters + +n_iter = 50 # nb iters +n = 40 # nb samples + +num_iter_max = 100 +n_noise = 10 + +reg = 0.005 +reg_m_kl = 0.05 + +mu_s = np.array([-1, -1]) +cov_s = np.array([[1, 0], [0, 1]]) + +mu_t = np.array([4, 4]) +cov_t = np.array([[1, -.8], [-.8, 1]]) + + +############################################################################## +# Compute entropic kl-regularized UOT with Sinkhorn and Translation Invariant Sinkhorn +# ----------- + +err_sinkhorn_uot = np.empty((n_iter, num_iter_max)) +err_sinkhorn_uot_ti = np.empty((n_iter, num_iter_max)) + + +for seed in range(n_iter): + np.random.seed(seed) + xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s) + xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t) + + xs = np.concatenate((xs, ((np.random.rand(n_noise, 2) - 4))), axis=0) + xt = np.concatenate((xt, ((np.random.rand(n_noise, 2) + 6))), axis=0) + + n_tot = n + n_noise # total samples for this seed; n itself must stay fixed so every run has the same size + + a, b = np.ones((n_tot,)) / n_tot, np.ones((n_tot,)) / n_tot # uniform distribution on samples + + # loss matrix + M = ot.dist(xs, xt) + M /= M.max() + + entropic_kl_uot, log_uot = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_kl, reg_type="kl", log=True, numItermax=num_iter_max, stopThr=0) + entropic_kl_uot_ti, log_uot_ti = ot.unbalanced.sinkhorn_unbalanced(a, b, M, reg, reg_m_kl, reg_type="kl", + method="sinkhorn_translation_invariant", log=True, + numItermax=num_iter_max, stopThr=0) + + err_sinkhorn_uot[seed] = log_uot["err"] + err_sinkhorn_uot_ti[seed] = log_uot_ti["err"] + 
+############################################################################## +# Plot the results +# ---------------- + +mean_sinkh = np.mean(err_sinkhorn_uot, axis=0) +std_sinkh = np.std(err_sinkhorn_uot, axis=0) + +mean_sinkh_ti = np.mean(err_sinkhorn_uot_ti, axis=0) +std_sinkh_ti = np.std(err_sinkhorn_uot_ti, axis=0) + +absc = list(range(num_iter_max)) + +pl.plot(absc, mean_sinkh, label="Sinkhorn") +pl.fill_between(absc, mean_sinkh - 2 * std_sinkh, mean_sinkh + 2 * std_sinkh, alpha=0.5) + +pl.plot(absc, mean_sinkh_ti, label="Translation Invariant Sinkhorn") +pl.fill_between(absc, mean_sinkh_ti - 2 * std_sinkh_ti, mean_sinkh_ti + 2 * std_sinkh_ti, alpha=0.5) + +pl.yscale("log") +pl.legend() +pl.xlabel("Number of Iterations") +pl.ylabel(r"$\|u-v\|_\infty$") +pl.grid(True) +pl.show() From 11f4e66a3a74291781bac74ee5ccfe6676942c28 Mon Sep 17 00:00:00 2001 From: Clement Date: Thu, 8 Aug 2024 14:34:08 +0200 Subject: [PATCH 5/5] test ti sinkhorn uot --- ot/unbalanced.py | 76 +++++++++++++++++++++++------------------ test/test_unbalanced.py | 8 ++--- 2 files changed, 46 insertions(+), 38 deletions(-) diff --git a/ot/unbalanced.py b/ot/unbalanced.py index 10b502fe8..04e575740 100644 --- a/ot/unbalanced.py +++ b/ot/unbalanced.py @@ -811,23 +811,6 @@ def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="entropy", return ot_matrix -def rescale_potentials(f, g, a, b, rho1, rho2): - """ - TODO - """ - nx = get_backend(f, g, a, b) - - tau = (rho1 * rho2) / (rho1 + rho2) - transl = tau * (nx.logsumexp(nx.log(a) - f / rho1) - - nx.logsumexp(nx.log(b) - g / rho2)) - return transl - - -def softmin(a, f, rho): - nx = get_backend(a, f) - return - rho * nx.logsumexp(nx.log(a) - f / rho) - - def sinkhorn_unbalanced_translation_invariant(a, b, M, reg, reg_m, reg_type="kl", warmstart=None, numItermax=1000, stopThr=1e-6, verbose=False, log=False, **kwargs): @@ -875,7 +858,7 @@ def sinkhorn_unbalanced_translation_invariant(a, b, M, reg, reg_m, reg_type="kl" For semi-relaxed 
case, use either `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`. If reg_m is an array, it must have the same backend as input arrays (a, b, M). - reg_type : string, optional TODO: CHECKE ENTROPY? + reg_type : string, optional Regularizer term. Can take two values: 'entropy' (negative entropy) :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`, or @@ -959,8 +942,12 @@ def sinkhorn_unbalanced_translation_invariant(a, b, M, reg, reg_m, reg_type="kl" else: u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) - # potentials - f, g = reg * nx.log(u), reg * nx.log(v) + u_, v_ = u, v + + if reg_type == "kl": + K = nx.exp(-M / reg) * a.reshape(-1)[:, None] * b.reshape(-1)[None, :] + elif reg_type == "entropy": + K = nx.exp(-M / reg) fi_1 = reg_m1 / (reg_m1 + reg) if reg_m1 != float("inf") else 1 fi_2 = reg_m2 / (reg_m2 + reg) if reg_m2 != float("inf") else 1 @@ -968,26 +955,49 @@ def sinkhorn_unbalanced_translation_invariant(a, b, M, reg, reg_m, reg_type="kl" k1 = reg * reg_m1 / ((reg + reg_m1) * (reg_m1 + reg_m2)) if reg_m1 != float("inf") else 0 k2 = reg * reg_m2 / ((reg + reg_m2) * (reg_m1 + reg_m2)) if reg_m2 != float("inf") else 0 - xi1 = (reg_m2 * reg) / (reg_m1 * (reg + reg_m1 + reg_m2)) - xi2 = (reg_m1 * reg) / (reg_m2 * (reg + reg_m1 + reg_m2)) + k_rho1 = k1 * reg_m1 / reg if reg_m1 != float("inf") else 0 + k_rho2 = k2 * reg_m2 / reg if reg_m2 != float("inf") else 0 + + if reg_m1 == float("inf") and reg_m2 == float("inf"): + xi1, xi2 = 0, 0 + fi_12 = 1 + elif reg_m1 == float("inf"): + xi1 = 0 + xi2 = reg / reg_m2 + fi_12 = reg_m2 + elif reg_m2 == float("inf"): + xi1 = reg / reg_m1 + xi2 = 0 + fi_12 = reg_m1 + else: + xi1 = (reg_m2 * reg) / (reg_m1 * (reg + reg_m1 + reg_m2)) + xi2 = (reg_m1 * reg) / (reg_m2 * (reg + reg_m1 + reg_m2)) + fi_12 = reg_m1 * reg_m2 / (reg_m1 + reg_m2) + + xi_rho1 = xi1 * reg_m1 / reg if reg_m1 != float("inf") else 0 + xi_rho2 = xi2 * reg_m2 / reg if reg_m2 != float("inf") else 0 + 
+ reg_ratio1 = reg / reg_m1 if reg_m1 != float("inf") else 0 + reg_ratio2 = reg / reg_m2 if reg_m2 != float("inf") else 0 err = 1. for i in range(numItermax): - fprev = f - gprev = g + uprev = u + vprev = v - f_hat = -reg * fi_1 * nx.logsumexp(nx.log(b)[None, :] + (g[None, :] - M) / reg, axis=1) - k2 * softmin(b, g, reg_m2) - f = f_hat + xi1 * softmin(a, f_hat, reg_m1) + Kv = nx.dot(K, v_) + u_hat = (a / Kv) ** fi_1 * nx.sum(b * v_**reg_ratio2)**k_rho2 + u_ = u_hat * nx.sum(a * u_hat**(-reg_ratio1))**(-xi_rho1) - g_hat = -reg * fi_2 * nx.logsumexp(nx.log(a)[:, None] + (f[:, None] - M) / reg, axis=0) - k1 * softmin(a, f, reg_m1) - g = g_hat + xi2 * softmin(b, g_hat, reg_m2) + Ktu = nx.dot(K.T, u_) + v_hat = (b / Ktu) ** fi_2 * nx.sum(a * u_**(-reg_ratio1))**k_rho1 + v_ = v_hat * nx.sum(b * v_hat**(-reg_ratio2))**(-xi_rho2) - tprev = rescale_potentials(fprev, gprev, a, b, reg_m1, reg_m2) - uprev, vprev = nx.exp((fprev + tprev) / reg), nx.exp((gprev - tprev) / reg) + t = (nx.sum(a * u_**(-reg_ratio1)) / nx.sum(b * v_**(-reg_ratio2)))**(fi_12 / reg) - t = rescale_potentials(f, g, a, b, reg_m1, reg_m2) - u, v = nx.exp((f + t) / reg), nx.exp((g - t) / reg) + u = u_ * t + v = v_ / t err_u = nx.max(nx.abs(u - uprev)) / max( nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1. 
@@ -1011,8 +1021,6 @@ def sinkhorn_unbalanced_translation_invariant(a, b, M, reg, reg_m, reg_type="kl" log['logu'] = nx.log(u + 1e-300) log['logv'] = nx.log(v + 1e-300) - K = nx.exp(-M / reg) * a.reshape(-1)[:, None] * b.reshape(-1)[None, :] - if n_hists: # return only loss res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M) if log: diff --git a/test/test_unbalanced.py b/test/test_unbalanced.py index 7007e336b..fa4592ad5 100644 --- a/test/test_unbalanced.py +++ b/test/test_unbalanced.py @@ -14,7 +14,7 @@ from ot.unbalanced import barycenter_unbalanced -@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], ["kl", "entropy"])) +@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling", "sinkhorn_translation_invariant"], ["kl", "entropy"])) def test_unbalanced_convergence(nx, method, reg_type): # test generalized sinkhorn for unbalanced OT n = 100 @@ -78,7 +78,7 @@ def test_unbalanced_convergence(nx, method, reg_type): np.testing.assert_allclose(G_np, nx.to_numpy(G)) -@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], ["kl", "entropy"])) +@pytest.mark.parametrize("method,reg_type", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling", "sinkhorn_translation_invariant"], ["kl", "entropy"])) def test_unbalanced_warmstart(nx, method, reg_type): # test generalized sinkhorn for unbalanced OT n = 100 @@ -140,7 +140,7 @@ def test_unbalanced_warmstart(nx, method, reg_type): np.testing.assert_allclose(nx.to_numpy(loss0), nx.to_numpy(loss1), atol=1e-5) -@pytest.mark.parametrize("method,reg_type, log", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"], ["kl", "entropy"], [True, False])) +@pytest.mark.parametrize("method,reg_type, log", itertools.product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling", 
"sinkhorn_translation_invariant"], ["kl", "entropy"], [True, False])) def test_sinkhorn_unbalanced2(nx, method, reg_type, log): n = 100 rng = np.random.RandomState(42) @@ -214,7 +214,7 @@ def test_unbalanced_relaxation_parameters(nx, method, reg_m): nx.to_numpy(loss), nx.to_numpy(loss_opt), atol=1e-05) -@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling"]) +@pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_stabilized", "sinkhorn_reg_scaling", "sinkhorn_translation_invariant"]) def test_unbalanced_multiple_inputs(nx, method): # test generalized sinkhorn for unbalanced OT n = 100