diff --git a/RELEASES.md b/RELEASES.md index 974dbbf73..df87425e9 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -23,6 +23,8 @@ - Handle documentation and warnings when integers are provided to (f)gw solvers based on cg (Issue #530, PR #559) - Correct independence of `fgw_barycenters` to `init_C` and `init_X` (Issue #547, PR #566) - Avoid precision change when computing norm using PyTorch backend (Discussion #570, PR #572) +- Create `ot/bregman/` subpackage (Issue #567, PR #569) + ## 0.9.1 *August 2023* diff --git a/ot/bregman.py b/ot/bregman.py deleted file mode 100644 index c90d89986..000000000 --- a/ot/bregman.py +++ /dev/null @@ -1,3794 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Bregman projection solvers for entropic regularized OT -""" - -# Author: Remi Flamary -# Nicolas Courty -# Kilian Fatras -# Titouan Vayer -# Hicham Janati -# Mokhtar Z. Alaya -# Alexander Tong -# Ievgen Redko -# Quang Huy Tran -# -# License: MIT License - -import warnings - -import numpy as np -from scipy.optimize import fmin_l_bfgs_b - -from ot.utils import dist, list_to_array, unif - -from .backend import get_backend - - -def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, - verbose=False, log=False, warn=True, warmstart=None, **kwargs): - r""" - Solve the entropic regularization optimal transport problem and return the OT matrix - - The function solves the following optimization problem: - - .. math:: - \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + - \mathrm{reg}\cdot\Omega(\gamma) - - s.t. \ \gamma \mathbf{1} &= \mathbf{a} - - \gamma^T \mathbf{1} &= \mathbf{b} - - \gamma &\geq 0 - - where : - - - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term - :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target - weights (histograms, both sum to 1) - - .. note:: This function is backend-compatible and will work on arrays - from all compatible backends. - - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix - scaling algorithm as proposed in :ref:`[2] ` - - **Choosing a Sinkhorn solver** - - By default, and when using a regularization parameter that is not too small, - the default sinkhorn solver should be enough. If you need to use a small - regularization to get sharper OT matrices, you should use the - :py:func:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical - errors. This last solver can be very slow in practice and might not even - converge to a reasonable OT matrix in a finite time. This is why - :py:func:`ot.bregman.sinkhorn_epsilon_scaling`, which relies on iterating over the value - of the regularization (and using warm start), sometimes leads to better - solutions. Note that the greedy version of the sinkhorn, - :py:func:`ot.bregman.greenkhorn`, can also lead to a speedup, and the screening - version of the sinkhorn, :py:func:`ot.bregman.screenkhorn`, aims at providing a - fast approximation of the Sinkhorn problem. For GPU use and gradient - computation with a small number of iterations we strongly recommend the - :py:func:`ot.bregman.sinkhorn_log` solver, which does not need to check for - numerical problems.
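A minimal sketch of picking a solver through the `method` argument, reusing the toy data from the docstring example below (the `reg` values are illustrative):

    import ot

    a, b = [.5, .5], [.5, .5]          # uniform source / target histograms
    M = [[0., 1.], [1., 0.]]           # cost matrix
    G = ot.sinkhorn(a, b, M, reg=1)                                # default Sinkhorn-Knopp
    G_log = ot.sinkhorn(a, b, M, reg=1e-2, method='sinkhorn_log')  # log-domain, robust to small reg
    G_stab = ot.sinkhorn(a, b, M, reg=1e-2, method='sinkhorn_stabilized')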
- - - Parameters - ---------- - a : array-like, shape (dim_a,) - samples weights in the source domain - b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists) - samples in the target domain, compute sinkhorn with multiple targets - and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix - (return OT loss + dual variables in log) - M : array-like, shape (dim_a, dim_b) - loss matrix - reg : float - Regularization term >0 - method : str - method used for the solver either 'sinkhorn','sinkhorn_log', - 'greenkhorn', 'sinkhorn_stabilized' or 'sinkhorn_epsilon_scaling', see - those function for specific parameters - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - warn : bool, optional - if True, raises a warning if the algorithm doesn't convergence. - warmstart: tuple of arrays, shape (dim_a, dim_b), optional - Initialization of dual potentials. If provided, the dual potentials should be given - (that is the logarithm of the u,v sinkhorn scaling vectors) - - Returns - ------- - gamma : array-like, shape (dim_a, dim_b) - Optimal transportation matrix for the given parameters - log : dict - log dictionary return only if log==True in parameters - - Examples - -------- - - >>> import ot - >>> a=[.5, .5] - >>> b=[.5, .5] - >>> M=[[0., 1.], [1., 0.]] - >>> ot.sinkhorn(a, b, M, 1) - array([[0.36552929, 0.13447071], - [0.13447071, 0.36552929]]) - - .. _references-sinkhorn: - References - ---------- - - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation - of Optimal Transport, Advances in Neural Information Processing - Systems (NIPS) 26, 2013 - - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms - for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. - - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). - Scaling algorithms for unbalanced transport problems. - arXiv preprint arXiv:1607.05816. - - .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, - A., & Peyré, G. (2019, April). Interpolating between optimal transport - and MMD using Sinkhorn divergences. In The 22nd International Conference - on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. 
- - - See Also - -------- - ot.lp.emd : Unregularized OT - ot.optim.cg : General regularized OT - ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] ` - ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn - :ref:`[9] ` :ref:`[10] ` - ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epsilon scaling - :ref:`[9] ` :ref:`[10] ` - - """ - - if method.lower() == 'sinkhorn': - return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, - warn=warn, warmstart=warmstart, - **kwargs) - elif method.lower() == 'sinkhorn_log': - return sinkhorn_log(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, - warn=warn, warmstart=warmstart, - **kwargs) - elif method.lower() == 'greenkhorn': - return greenkhorn(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, - warn=warn, warmstart=warmstart) - elif method.lower() == 'sinkhorn_stabilized': - return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, warmstart=warmstart, - verbose=verbose, log=log, warn=warn, - **kwargs) - elif method.lower() == 'sinkhorn_epsilon_scaling': - return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, warmstart=warmstart, - verbose=verbose, log=log, warn=warn, - **kwargs) - else: - raise ValueError("Unknown method '%s'." % method) - - -def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, - stopThr=1e-9, verbose=False, log=False, warn=False, warmstart=None, **kwargs): - r""" - Solve the entropic regularization optimal transport problem and return the loss - - The function solves the following optimization problem: - - .. math:: - W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + - \mathrm{reg}\cdot\Omega(\gamma) - - s.t. \ \gamma \mathbf{1} &= \mathbf{a} - - \gamma^T \mathbf{1} &= \mathbf{b} - - \gamma &\geq 0 - - where : - - - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term - :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target - weights (histograms, both sum to 1) - - and returns :math:`\langle \gamma^*, \mathbf{M} \rangle_F` (without - the entropic contribution). - - .. note:: This function is backend-compatible and will work on arrays - from all compatible backends. - - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix - scaling algorithm as proposed in :ref:`[2] ` - - - **Choosing a Sinkhorn solver** - - By default and when using a regularization parameter that is not too small - the default sinkhorn solver should be enough. If you need to use a small - regularization to get sharper OT matrices, you should use the - :py:func:`ot.bregman.sinkhorn_log` solver that will avoid numerical - errors. This last solver can be very slow in practice and might not even - converge to a reasonable OT matrix in a finite time. This is why - :py:func:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value - of the regularization (and using warm start) sometimes leads to better - solutions. Note that the greedy version of the sinkhorn - :py:func:`ot.bregman.greenkhorn` can also lead to a speedup and the screening - version of the sinkhorn :py:func:`ot.bregman.screenkhorn` aim a providing a - fast approximation of the Sinkhorn problem. 
For use of GPU and gradient - computation with small number of iterations we strongly recommend the - :py:func:`ot.bregman.sinkhorn_log` solver that will no need to check for - numerical problems. - - Parameters - ---------- - a : array-like, shape (dim_a,) - samples weights in the source domain - b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists) - samples in the target domain, compute sinkhorn with multiple targets - and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix - (return OT loss + dual variables in log) - M : array-like, shape (dim_a, dim_b) - loss matrix - reg : float - Regularization term >0 - method : str - method used for the solver either 'sinkhorn','sinkhorn_log', - 'sinkhorn_stabilized', see those function for specific parameters - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - warn : bool, optional - if True, raises a warning if the algorithm doesn't convergence. - warmstart: tuple of arrays, shape (dim_a, dim_b), optional - Initialization of dual potentials. If provided, the dual potentials should be given - (that is the logarithm of the u,v sinkhorn scaling vectors) - - Returns - ------- - W : (n_hists) float/array-like - Optimal transportation loss for the given parameters - log : dict - log dictionary return only if log==True in parameters - - - Examples - -------- - - >>> import ot - >>> a=[.5, .5] - >>> b=[.5, .5] - >>> M=[[0., 1.], [1., 0.]] - >>> ot.sinkhorn2(a, b, M, 1) - 0.26894142136999516 - - - .. _references-sinkhorn2: - References - ---------- - - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of - Optimal Transport, Advances in Neural Information - Processing Systems (NIPS) 26, 2013 - - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms - for Entropy Regularized Transport Problems. - arXiv preprint arXiv:1610.06519. - - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). - Scaling algorithms for unbalanced transport problems. - arXiv preprint arXiv:1607.05816. - - .. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation - algorithms for optimal transport via Sinkhorn iteration, - Advances in Neural Information Processing Systems (NIPS) 31, 2017 - - .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., - Trouvé, A., & Peyré, G. (2019, April). - Interpolating between optimal transport and MMD using Sinkhorn - divergences. In The 22nd International Conference on Artificial - Intelligence and Statistics (pp. 2681-2690). PMLR. 
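A quick sketch of the difference with :py:func:`ot.sinkhorn` (toy data and the approximate value are taken from the two docstring examples):

    import ot

    a, b = [.5, .5], [.5, .5]
    M = [[0., 1.], [1., 0.]]
    G = ot.sinkhorn(a, b, M, 1)    # transport plan, shape (2, 2)
    W = ot.sinkhorn2(a, b, M, 1)   # scalar loss <G, M>, approximately 0.2689
    # W matches (G * M).sum() up to solver tolerance (entropic term excluded)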
- - - See Also - -------- - ot.lp.emd : Unregularized OT - ot.optim.cg : General regularized OT - ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] ` - ot.bregman.greenkhorn : Greenkhorn :ref:`[21] ` - ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn - :ref:`[9] ` :ref:`[10] ` - """ - - M, a, b = list_to_array(M, a, b) - nx = get_backend(M, a, b) - - if len(b.shape) < 2: - if method.lower() == 'sinkhorn': - res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, warmstart=warmstart, - **kwargs) - elif method.lower() == 'sinkhorn_log': - res = sinkhorn_log(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, warmstart=warmstart, - **kwargs) - elif method.lower() == 'sinkhorn_stabilized': - res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, warmstart=warmstart, - verbose=verbose, log=log, warn=warn, - **kwargs) - else: - raise ValueError("Unknown method '%s'." % method) - if log: - return nx.sum(M * res[0]), res[1] - else: - return nx.sum(M * res) - - else: - - if method.lower() == 'sinkhorn': - return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, warmstart=warmstart, - **kwargs) - elif method.lower() == 'sinkhorn_log': - return sinkhorn_log(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, warmstart=warmstart, - **kwargs) - elif method.lower() == 'sinkhorn_stabilized': - return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, - stopThr=stopThr, warmstart=warmstart, - verbose=verbose, log=log, warn=warn, - **kwargs) - else: - raise ValueError("Unknown method '%s'." % method) - - -def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, - verbose=False, log=False, warn=True, warmstart=None, **kwargs): - r""" - Solve the entropic regularization optimal transport problem and return the OT matrix - - The function solves the following optimization problem: - - .. math:: - \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + - \mathrm{reg}\cdot\Omega(\gamma) - - s.t. \ \gamma \mathbf{1} &= \mathbf{a} - - \gamma^T \mathbf{1} &= \mathbf{b} - - \gamma &\geq 0 - - where : - - - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term - :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target - weights (histograms, both sum to 1) - - The algorithm used for solving the problem is the Sinkhorn-Knopp - matrix scaling algorithm as proposed in :ref:`[2] ` - - - Parameters - ---------- - a : array-like, shape (dim_a,) - samples weights in the source domain - b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists) - samples in the target domain, compute sinkhorn with multiple targets - and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix - (return OT loss + dual variables in log) - M : array-like, shape (dim_a, dim_b) - loss matrix - reg : float - Regularization term >0 - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - warn : bool, optional - if True, raises a warning if the algorithm doesn't convergence. 
- warmstart: tuple of arrays, shape (dim_a, dim_b), optional - Initialization of dual potentials. If provided, the dual potentials should be given - (that is the logarithm of the u,v sinkhorn scaling vectors) - - Returns - ------- - gamma : array-like, shape (dim_a, dim_b) - Optimal transportation matrix for the given parameters - log : dict - log dictionary return only if log==True in parameters - - Examples - -------- - - >>> import ot - >>> a=[.5, .5] - >>> b=[.5, .5] - >>> M=[[0., 1.], [1., 0.]] - >>> ot.sinkhorn(a, b, M, 1) - array([[0.36552929, 0.13447071], - [0.13447071, 0.36552929]]) - - - .. _references-sinkhorn-knopp: - References - ---------- - - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation - of Optimal Transport, Advances in Neural Information - Processing Systems (NIPS) 26, 2013 - - - See Also - -------- - ot.lp.emd : Unregularized OT - ot.optim.cg : General regularized OT - - """ - - a, b, M = list_to_array(a, b, M) - - nx = get_backend(M, a, b) - - if len(a) == 0: - a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M) - if len(b) == 0: - b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M) - - # init data - dim_a = len(a) - dim_b = b.shape[0] - - if len(b.shape) > 1: - n_hists = b.shape[1] - else: - n_hists = 0 - - if log: - log = {'err': []} - - # we assume that no distances are null except those of the diagonal of - # distances - if warmstart is None: - if n_hists: - u = nx.ones((dim_a, n_hists), type_as=M) / dim_a - v = nx.ones((dim_b, n_hists), type_as=M) / dim_b - else: - u = nx.ones(dim_a, type_as=M) / dim_a - v = nx.ones(dim_b, type_as=M) / dim_b - else: - u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) - - K = nx.exp(M / (-reg)) - - Kp = (1 / a).reshape(-1, 1) * K - - err = 1 - for ii in range(numItermax): - uprev = u - vprev = v - KtransposeU = nx.dot(K.T, u) - v = b / KtransposeU - u = 1. / nx.dot(Kp, v) - - if (nx.any(KtransposeU == 0) - or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) - or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): - # we have reached the machine precision - # come back to previous solution and quit loop - warnings.warn('Warning: numerical errors at iteration %d' % ii) - u = uprev - v = vprev - break - if ii % 10 == 0: - # we can speed up the process by checking for the error only all - # the 10th iterations - if n_hists: - tmp2 = nx.einsum('ik,ij,jk->jk', u, K, v) - else: - # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 - tmp2 = nx.einsum('i,ij,j->j', u, K, v) - err = nx.norm(tmp2 - b) # violation of marginal - if log: - log['err'].append(err) - - if err < stopThr: - break - if verbose: - if ii % 200 == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(ii, err)) - else: - if warn: - warnings.warn("Sinkhorn did not converge. 
You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") - if log: - log['niter'] = ii - log['u'] = u - log['v'] = v - - if n_hists: # return only loss - res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M) - if log: - return res, log - else: - return res - - else: # return OT matrix - - if log: - return u.reshape((-1, 1)) * K * v.reshape((1, -1)), log - else: - return u.reshape((-1, 1)) * K * v.reshape((1, -1)) - - -def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, - log=False, warn=True, warmstart=None, **kwargs): - r""" - Solve the entropic regularization optimal transport problem in log space - and return the OT matrix - - The function solves the following optimization problem: - - .. math:: - \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + - \mathrm{reg}\cdot\Omega(\gamma) - - s.t. \ \gamma \mathbf{1} &= \mathbf{a} - - \gamma^T \mathbf{1} &= \mathbf{b} - - \gamma &\geq 0 - - where : - - - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term - :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) - - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix - scaling algorithm :ref:`[2] ` with the - implementation from :ref:`[34] ` - - - Parameters - ---------- - a : array-like, shape (dim_a,) - samples weights in the source domain - b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists) - samples in the target domain, compute sinkhorn with multiple targets - and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log) - M : array-like, shape (dim_a, dim_b) - loss matrix - reg : float - Regularization term >0 - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - warn : bool, optional - if True, raises a warning if the algorithm doesn't convergence. - warmstart: tuple of arrays, shape (dim_a, dim_b), optional - Initialization of dual potentials. If provided, the dual potentials should be given - (that is the logarithm of the u,v sinkhorn scaling vectors) - - Returns - ------- - gamma : array-like, shape (dim_a, dim_b) - Optimal transportation matrix for the given parameters - log : dict - log dictionary return only if log==True in parameters - - Examples - -------- - - >>> import ot - >>> a=[.5, .5] - >>> b=[.5, .5] - >>> M=[[0., 1.], [1., 0.]] - >>> ot.sinkhorn(a, b, M, 1) - array([[0.36552929, 0.13447071], - [0.13447071, 0.36552929]]) - - - .. _references-sinkhorn-log: - References - ---------- - - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of - Optimal Transport, Advances in Neural Information Processing - Systems (NIPS) 26, 2013 - - .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., - Trouvé, A., & Peyré, G. (2019, April). Interpolating between - optimal transport and MMD using Sinkhorn divergences. In The - 22nd International Conference on Artificial Intelligence and - Statistics (pp. 2681-2690). PMLR. 
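A short sketch of recovering the plan from the log-domain potentials returned in the log dict (keys `log_u`/`log_v` as set by this implementation; the `reg` value is illustrative):

    import numpy as np
    import ot

    a, b = np.array([.5, .5]), np.array([.5, .5])
    M = np.array([[0., 1.], [1., 0.]])
    reg = 0.1
    G, log = ot.bregman.sinkhorn_log(a, b, M, reg, log=True)
    f, g = log['log_u'], log['log_v']        # logs of the Sinkhorn scaling vectors
    G_rebuilt = np.exp(-M / reg + f[:, None] + g[None, :])   # equals G up to numerical precision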
- - - See Also - -------- - ot.lp.emd : Unregularized OT - ot.optim.cg : General regularized OT - - """ - - a, b, M = list_to_array(a, b, M) - - nx = get_backend(M, a, b) - - if len(a) == 0: - a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M) - if len(b) == 0: - b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M) - - # init data - dim_a = len(a) - dim_b = b.shape[0] - - if len(b.shape) > 1: - n_hists = b.shape[1] - else: - n_hists = 0 - - # in case of multiple historgrams - if n_hists > 1 and warmstart is None: - warmstart = [None] * n_hists - - if n_hists: # we do not want to use tensors sor we do a loop - - lst_loss = [] - lst_u = [] - lst_v = [] - - for k in range(n_hists): - res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax, stopThr=stopThr, - verbose=verbose, log=log, warmstart=warmstart[k], **kwargs) - - if log: - lst_loss.append(nx.sum(M * res[0])) - lst_u.append(res[1]['log_u']) - lst_v.append(res[1]['log_v']) - else: - lst_loss.append(nx.sum(M * res)) - res = nx.stack(lst_loss) - if log: - log = {'log_u': nx.stack(lst_u, 1), - 'log_v': nx.stack(lst_v, 1), } - log['u'] = nx.exp(log['log_u']) - log['v'] = nx.exp(log['log_v']) - return res, log - else: - return res - - else: - - if log: - log = {'err': []} - - Mr = - M / reg - - # we assume that no distances are null except those of the diagonal of - # distances - if warmstart is None: - u = nx.zeros(dim_a, type_as=M) - v = nx.zeros(dim_b, type_as=M) - else: - u, v = warmstart - - def get_logT(u, v): - if n_hists: - return Mr[:, :, None] + u + v - else: - return Mr + u[:, None] + v[None, :] - - loga = nx.log(a) - logb = nx.log(b) - - err = 1 - for ii in range(numItermax): - - v = logb - nx.logsumexp(Mr + u[:, None], 0) - u = loga - nx.logsumexp(Mr + v[None, :], 1) - - if ii % 10 == 0: - # we can speed up the process by checking for the error only all - # the 10th iterations - - # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 - tmp2 = nx.sum(nx.exp(get_logT(u, v)), 0) - err = nx.norm(tmp2 - b) # violation of marginal - if log: - log['err'].append(err) - - if verbose: - if ii % 200 == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(ii, err)) - if err < stopThr: - break - else: - if warn: - warnings.warn("Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") - - if log: - log['niter'] = ii - log['log_u'] = u - log['log_v'] = v - log['u'] = nx.exp(u) - log['v'] = nx.exp(v) - - return nx.exp(get_logT(u, v)), log - - else: - return nx.exp(get_logT(u, v)) - - -def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, - log=False, warn=True, warmstart=None): - r""" - Solve the entropic regularization optimal transport problem and return the OT matrix - - The algorithm used is based on the paper :ref:`[22] ` - which is a stochastic version of the Sinkhorn-Knopp - algorithm :ref:`[2] ` - - The function solves the following optimization problem: - - .. math:: - \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + - \mathrm{reg}\cdot\Omega(\gamma) - - s.t. 
\ \gamma \mathbf{1} &= \mathbf{a} - - \gamma^T \mathbf{1} &= \mathbf{b} - - \gamma &\geq 0 - - where : - - - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term - :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target - weights (histograms, both sum to 1) - - - Parameters - ---------- - a : array-like, shape (dim_a,) - samples weights in the source domain - b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists) - samples in the target domain, compute sinkhorn with multiple targets - and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix - (return OT loss + dual variables in log) - M : array-like, shape (dim_a, dim_b) - loss matrix - reg : float - Regularization term >0 - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) - log : bool, optional - record log if True - warn : bool, optional - if True, raises a warning if the algorithm doesn't convergence. - warmstart: tuple of arrays, shape (dim_a, dim_b), optional - Initialization of dual potentials. If provided, the dual potentials should be given - (that is the logarithm of the u,v sinkhorn scaling vectors) - - Returns - ------- - gamma : array-like, shape (dim_a, dim_b) - Optimal transportation matrix for the given parameters - log : dict - log dictionary return only if log==True in parameters - - Examples - -------- - - >>> import ot - >>> a=[.5, .5] - >>> b=[.5, .5] - >>> M=[[0., 1.], [1., 0.]] - >>> ot.bregman.greenkhorn(a, b, M, 1) - array([[0.36552929, 0.13447071], - [0.13447071, 0.36552929]]) - - - .. _references-greenkhorn: - References - ---------- - - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation - of Optimal Transport, Advances in Neural Information - Processing Systems (NIPS) 26, 2013 - - .. [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time - approximation algorithms for optimal transport via Sinkhorn - iteration, Advances in Neural Information Processing - Systems (NIPS) 31, 2017 - - - See Also - -------- - ot.lp.emd : Unregularized OT - ot.optim.cg : General regularized OT - - """ - - a, b, M = list_to_array(a, b, M) - - nx = get_backend(M, a, b) - if nx.__name__ in ("jax", "tf"): - raise TypeError("JAX or TF arrays have been received. Greenkhorn is not " - "compatible with neither JAX nor TF") - - if len(a) == 0: - a = nx.ones((M.shape[0],), type_as=M) / M.shape[0] - if len(b) == 0: - b = nx.ones((M.shape[1],), type_as=M) / M.shape[1] - - dim_a = a.shape[0] - dim_b = b.shape[0] - - K = nx.exp(-M / reg) - - if warmstart is None: - u = nx.full((dim_a,), 1. / dim_a, type_as=K) - v = nx.full((dim_b,), 1. 
/ dim_b, type_as=K) - else: - u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) - G = u[:, None] * K * v[None, :] - - viol = nx.sum(G, axis=1) - a - viol_2 = nx.sum(G, axis=0) - b - stopThr_val = 1 - if log: - log = dict() - log['u'] = u - log['v'] = v - - for ii in range(numItermax): - i_1 = nx.argmax(nx.abs(viol)) - i_2 = nx.argmax(nx.abs(viol_2)) - m_viol_1 = nx.abs(viol[i_1]) - m_viol_2 = nx.abs(viol_2[i_2]) - stopThr_val = nx.maximum(m_viol_1, m_viol_2) - - if m_viol_1 > m_viol_2: - old_u = u[i_1] - new_u = a[i_1] / nx.dot(K[i_1, :], v) - G[i_1, :] = new_u * K[i_1, :] * v - - viol[i_1] = nx.dot(new_u * K[i_1, :], v) - a[i_1] - viol_2 += (K[i_1, :].T * (new_u - old_u) * v) - u[i_1] = new_u - else: - old_v = v[i_2] - new_v = b[i_2] / nx.dot(K[:, i_2].T, u) - G[:, i_2] = u * K[:, i_2] * new_v - # aviol = (G@one_m - a) - # aviol_2 = (G.T@one_n - b) - viol += (-old_v + new_v) * K[:, i_2] * u - viol_2[i_2] = new_v * nx.dot(K[:, i_2], u) - b[i_2] - v[i_2] = new_v - - if stopThr_val <= stopThr: - break - else: - if warn: - warnings.warn("Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") - - if log: - log["n_iter"] = ii - log['u'] = u - log['v'] = v - - if log: - return G, log - else: - return G - - -def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, - warmstart=None, verbose=False, print_period=20, - log=False, warn=True, **kwargs): - r""" - Solve the entropic regularization OT problem with log stabilization - - The function solves the following optimization problem: - - .. math:: - \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + - \mathrm{reg}\cdot\Omega(\gamma) - - s.t. \ \gamma \mathbf{1} &= \mathbf{a} - - \gamma^T \mathbf{1} &= \mathbf{b} - - \gamma &\geq 0 - - where : - - - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term - :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target - weights (histograms, both sum to 1) - - - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix - scaling algorithm as proposed in :ref:`[2] ` - but with the log stabilization - proposed in :ref:`[10] ` an defined in - :ref:`[9] ` (Algo 3.1) . - - - Parameters - ---------- - a : array-like, shape (dim_a,) - samples weights in the source domain - b : array-like, shape (dim_b,) - samples in the target domain - M : array-like, shape (dim_a, dim_b) - loss matrix - reg : float - Regularization term >0 - tau : float - threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` - for log scaling - warmstart : table of vectors - if given then starting values for alpha and beta log scalings - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - warn : bool, optional - if True, raises a warning if the algorithm doesn't convergence. 
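A minimal sketch of calling the stabilized solver with a small regularization (toy data as in the example below; `reg` is illustrative, `tau=1e3` is the default absorption threshold):

    import ot

    a, b = [.5, .5], [.5, .5]
    M = [[0., 1.], [1., 0.]]
    # with a small reg the kernel exp(-M/reg) can under/overflow; large scalings are
    # periodically absorbed into the log-domain potentials alpha/beta (controlled by tau)
    G = ot.bregman.sinkhorn_stabilized(a, b, M, reg=1e-2, tau=1e3)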
- - Returns - ------- - gamma : array-like, shape (dim_a, dim_b) - Optimal transportation matrix for the given parameters - log : dict - log dictionary return only if log==True in parameters - - Examples - -------- - - >>> import ot - >>> a=[.5,.5] - >>> b=[.5,.5] - >>> M=[[0.,1.],[1.,0.]] - >>> ot.bregman.sinkhorn_stabilized(a, b, M, 1) - array([[0.36552929, 0.13447071], - [0.13447071, 0.36552929]]) - - - .. _references-sinkhorn-stabilized: - References - ---------- - - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of - Optimal Transport, Advances in Neural Information Processing - Systems (NIPS) 26, 2013 - - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms - for Entropy Regularized Transport Problems. - arXiv preprint arXiv:1610.06519. - - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). - Scaling algorithms for unbalanced transport problems. - arXiv preprint arXiv:1607.05816. - - - See Also - -------- - ot.lp.emd : Unregularized OT - ot.optim.cg : General regularized OT - - """ - - a, b, M = list_to_array(a, b, M) - - nx = get_backend(M, a, b) - - if len(a) == 0: - a = nx.ones((M.shape[0],), type_as=M) / M.shape[0] - if len(b) == 0: - b = nx.ones((M.shape[1],), type_as=M) / M.shape[1] - - # test if multiple target - if len(b.shape) > 1: - n_hists = b.shape[1] - a = a[:, None] - else: - n_hists = 0 - - # init data - dim_a = len(a) - dim_b = len(b) - - if log: - log = {'err': []} - - # we assume that no distances are null except those of the diagonal of - # distances - if warmstart is None: - alpha, beta = nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M) - else: - alpha, beta = warmstart - - if n_hists: - u = nx.ones((dim_a, n_hists), type_as=M) / dim_a - v = nx.ones((dim_b, n_hists), type_as=M) / dim_b - else: - u, v = nx.ones(dim_a, type_as=M), nx.ones(dim_b, type_as=M) - u /= dim_a - v /= dim_b - - def get_K(alpha, beta): - """log space computation""" - return nx.exp(-(M - alpha.reshape((dim_a, 1)) - - beta.reshape((1, dim_b))) / reg) - - def get_Gamma(alpha, beta, u, v): - """log space gamma computation""" - return nx.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b))) - / reg + nx.log(u.reshape((dim_a, 1))) + nx.log(v.reshape((1, dim_b)))) - - K = get_K(alpha, beta) - transp = K - err = 1 - for ii in range(numItermax): - - uprev = u - vprev = v - - # sinkhorn update - v = b / (nx.dot(K.T, u)) - u = a / (nx.dot(K, v)) - - # remove numerical problems and store them in K - if nx.max(nx.abs(u)) > tau or nx.max(nx.abs(v)) > tau: - if n_hists: - alpha, beta = alpha + reg * \ - nx.max(nx.log(u), 1), beta + reg * nx.max(nx.log(v)) - else: - alpha, beta = alpha + reg * nx.log(u), beta + reg * nx.log(v) - if n_hists: - u = nx.ones((dim_a, n_hists), type_as=M) / dim_a - v = nx.ones((dim_b, n_hists), type_as=M) / dim_b - else: - u = nx.ones(dim_a, type_as=M) / dim_a - v = nx.ones(dim_b, type_as=M) / dim_b - K = get_K(alpha, beta) - - if ii % print_period == 0: - # we can speed up the process by checking for the error only all - # the 10th iterations - if n_hists: - err_u = nx.max(nx.abs(u - uprev)) - err_u /= max(nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1.0) - err_v = nx.max(nx.abs(v - vprev)) - err_v /= max(nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1.0) - err = 0.5 * (err_u + err_v) - else: - transp = get_Gamma(alpha, beta, u, v) - err = nx.norm(nx.sum(transp, axis=0) - b) - if log: - log['err'].append(err) - - if verbose: - if ii % (print_period * 20) == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 
19) - print('{:5d}|{:8e}|'.format(ii, err)) - - if err <= stopThr: - break - - if nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)): - # we have reached the machine precision - # come back to previous solution and quit loop - warnings.warn('Numerical errors at iteration %d' % ii) - u = uprev - v = vprev - break - else: - if warn: - warnings.warn("Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") - if log: - if n_hists: - alpha = alpha[:, None] - beta = beta[:, None] - logu = alpha / reg + nx.log(u) - logv = beta / reg + nx.log(v) - log["n_iter"] = ii - log['logu'] = logu - log['logv'] = logv - log['alpha'] = alpha + reg * nx.log(u) - log['beta'] = beta + reg * nx.log(v) - log['warmstart'] = (log['alpha'], log['beta']) - if n_hists: - res = nx.stack([ - nx.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M) - for i in range(n_hists) - ]) - return res, log - - else: - return get_Gamma(alpha, beta, u, v), log - else: - if n_hists: - res = nx.stack([ - nx.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M) - for i in range(n_hists) - ]) - return res - else: - return get_Gamma(alpha, beta, u, v) - - -def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, - numInnerItermax=100, tau=1e3, stopThr=1e-9, - warmstart=None, verbose=False, print_period=10, - log=False, warn=True, **kwargs): - r""" - Solve the entropic regularization optimal transport problem with log - stabilization and epsilon scaling. - The function solves the following optimization problem: - - .. math:: - \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + - \mathrm{reg}\cdot\Omega(\gamma) - - s.t. \ \gamma \mathbf{1} &= \mathbf{a} - - \gamma^T \mathbf{1} &= \mathbf{b} - - \gamma &\geq 0 - - where : - - - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term - :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) - - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix - scaling algorithm as proposed in :ref:`[2] ` - but with the log stabilization - proposed in :ref:`[10] ` and the log scaling - proposed in :ref:`[9] ` algorithm 3.2 - - Parameters - ---------- - a : array-like, shape (dim_a,) - samples weights in the source domain - b : array-like, shape (dim_b,) - samples in the target domain - M : array-like, shape (dim_a, dim_b) - loss matrix - reg : float - Regularization term >0 - tau : float - threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{b}` - for log scaling - warmstart : tuple of vectors - if given then starting values for alpha and beta log scalings - numItermax : int, optional - Max number of iterations - numInnerItermax : int, optional - Max number of iterations in the inner slog stabilized sinkhorn - epsilon0 : int, optional - first epsilon regularization value (then exponential decrease to reg) - stopThr : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - warn : bool, optional - if True, raises a warning if the algorithm doesn't convergence. 
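The epsilon-scaling strategy decreases the regularization exponentially from `epsilon0` down to `reg`, warm-starting each inner stabilized solve; a small numeric sketch of that schedule (values are approximate):

    import numpy as np

    reg, epsilon0 = 1e-2, 1e4
    def get_reg(n):                                  # same decay as in the solver below
        return (epsilon0 - reg) * np.exp(-n) + reg
    [round(get_reg(n), 1) for n in range(5)]
    # approximately [10000.0, 3678.8, 1353.4, 497.9, 183.2], tending to reg as n grows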
- - Returns - ------- - gamma : array-like, shape (dim_a, dim_b) - Optimal transportation matrix for the given parameters - log : dict - log dictionary return only if log==True in parameters - Examples - -------- - >>> import ot - >>> a=[.5, .5] - >>> b=[.5, .5] - >>> M=[[0., 1.], [1., 0.]] - >>> ot.bregman.sinkhorn_epsilon_scaling(a, b, M, 1) - array([[0.36552929, 0.13447071], - [0.13447071, 0.36552929]]) - - .. _references-sinkhorn-epsilon-scaling: - References - ---------- - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal - Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 - - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for - Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. - - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). - Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. - - See Also - -------- - ot.lp.emd : Unregularized OT - ot.optim.cg : General regularized OT - """ - - a, b, M = list_to_array(a, b, M) - - nx = get_backend(M, a, b) - - if len(a) == 0: - a = nx.ones((M.shape[0],), type_as=M) / M.shape[0] - if len(b) == 0: - b = nx.ones((M.shape[1],), type_as=M) / M.shape[1] - - # init data - dim_a = len(a) - dim_b = len(b) - - # nrelative umerical precision with 64 bits - numItermin = 35 - numItermax = max(numItermin, numItermax) # ensure that last velue is exact - - ii = 0 - if log: - log = {'err': []} - - # we assume that no distances are null except those of the diagonal of - # distances - if warmstart is None: - alpha, beta = nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M) - else: - alpha, beta = warmstart - - # print(np.min(K)) - def get_reg(n): # exponential decreasing - return (epsilon0 - reg) * np.exp(-n) + reg - - err = 1 - for ii in range(numItermax): - - regi = get_reg(ii) - - G, logi = sinkhorn_stabilized(a, b, M, regi, - numItermax=numInnerItermax, stopThr=stopThr, - warmstart=(alpha, beta), verbose=False, - print_period=20, tau=tau, log=True) - - alpha = logi['alpha'] - beta = logi['beta'] - - if ii % (print_period) == 0: # spsion nearly converged - # we can speed up the process by checking for the error only all - # the 10th iterations - transp = G - err = nx.norm(nx.sum(transp, axis=0) - b) ** 2 + \ - nx.norm(nx.sum(transp, axis=1) - a) ** 2 - if log: - log['err'].append(err) - - if verbose: - if ii % (print_period * 10) == 0: - print('{:5s}|{:12s}'.format( - 'It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(ii, err)) - - if err <= stopThr and ii > numItermin: - break - else: - if warn: - warnings.warn("Sinkhorn did not converge. 
You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") - if log: - log['alpha'] = alpha - log['beta'] = beta - log['warmstart'] = (log['alpha'], log['beta']) - log['niter'] = ii - return G, log - else: - return G - - -def geometricBar(weights, alldistribT): - """return the weighted geometric mean of distributions""" - weights, alldistribT = list_to_array(weights, alldistribT) - nx = get_backend(weights, alldistribT) - assert (len(weights) == alldistribT.shape[1]) - return nx.exp(nx.dot(nx.log(alldistribT), weights.T)) - - -def geometricMean(alldistribT): - """return the geometric mean of distributions""" - alldistribT = list_to_array(alldistribT) - nx = get_backend(alldistribT) - return nx.exp(nx.mean(nx.log(alldistribT), axis=1)) - - -def projR(gamma, p): - """return the KL projection on the row constrints """ - gamma, p = list_to_array(gamma, p) - nx = get_backend(gamma, p) - return (gamma.T * p / nx.maximum(nx.sum(gamma, axis=1), 1e-10)).T - - -def projC(gamma, q): - """return the KL projection on the column constrints """ - gamma, q = list_to_array(gamma, q) - nx = get_backend(gamma, q) - return gamma * q / nx.maximum(nx.sum(gamma, axis=0), 1e-10) - - -def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, - stopThr=1e-4, verbose=False, log=False, warn=True, **kwargs): - r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` - - The function solves the following optimization problem: - - .. math:: - \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) - - where : - - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein - distance (see :py:func:`ot.bregman.sinkhorn`) - if `method` is `sinkhorn` or `sinkhorn_stabilized` or `sinkhorn_log`. - - :math:`\mathbf{a}_i` are training distributions in the columns of matrix - :math:`\mathbf{A}` - - `reg` and :math:`\mathbf{M}` are respectively the regularization term and - the cost matrix for OT - - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling - algorithm as proposed in :ref:`[3] ` - - Parameters - ---------- - A : array-like, shape (dim, n_hists) - `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim` - M : array-like, shape (dim, dim) - loss matrix for OT - reg : float - Regularization term > 0 - method : str (optional) - method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized' or 'sinkhorn_log' - weights : array-like, shape (n_hists,) - Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates) - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - warn : bool, optional - if True, raises a warning if the algorithm doesn't convergence. - - - Returns - ------- - a : (dim,) array-like - Wasserstein barycenter - log : dict - log dictionary return only if log==True in parameters - - - .. _references-barycenter: - References - ---------- - - .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). - Iterative Bregman projections for regularized transportation problems. - SIAM Journal on Scientific Computing, 37(2), A1111-A1138. 
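A minimal usage sketch for this dispatcher on two random 1D histograms (the grid, cost normalization, random data and `reg` are illustrative assumptions):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    x = np.arange(50, dtype=np.float64).reshape(-1, 1)
    M = ot.dist(x, x)                          # (dim, dim) squared Euclidean cost
    M /= M.max()
    A = rng.rand(50, 2)
    A /= A.sum(axis=0, keepdims=True)          # two histograms as columns, each summing to 1
    bary = ot.bregman.barycenter(A, M, reg=1e-2, method='sinkhorn_stabilized')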
- - """ - - if method.lower() == 'sinkhorn': - return barycenter_sinkhorn(A, M, reg, weights=weights, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, - warn=warn, - **kwargs) - elif method.lower() == 'sinkhorn_stabilized': - return barycenter_stabilized(A, M, reg, weights=weights, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, **kwargs) - elif method.lower() == 'sinkhorn_log': - return _barycenter_sinkhorn_log(A, M, reg, weights=weights, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, **kwargs) - else: - raise ValueError("Unknown method '%s'." % method) - - -def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, - stopThr=1e-4, verbose=False, log=False, warn=True): - r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` - - The function solves the following optimization problem: - - .. math:: - \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) - - where : - - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance - (see :py:func:`ot.bregman.sinkhorn`) - - :math:`\mathbf{a}_i` are training distributions in the columns of matrix - :math:`\mathbf{A}` - - `reg` and :math:`\mathbf{M}` are respectively the regularization term and - the cost matrix for OT - - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix - scaling algorithm as proposed in :ref:`[3]`. - - Parameters - ---------- - A : array-like, shape (dim, n_hists) - `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim` - M : array-like, shape (dim, dim) - loss matrix for OT - reg : float - Regularization term > 0 - weights : array-like, shape (n_hists,) - Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates) - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - warn : bool, optional - if True, raises a warning if the algorithm doesn't convergence. - - - Returns - ------- - a : (dim,) array-like - Wasserstein barycenter - log : dict - log dictionary return only if log==True in parameters - - - .. _references-barycenter-sinkhorn: - References - ---------- - - .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). - Iterative Bregman projections for regularized transportation problems. - SIAM Journal on Scientific Computing, 37(2), A1111-A1138. - - """ - - A, M = list_to_array(A, M) - - nx = get_backend(A, M) - - if weights is None: - weights = nx.ones((A.shape[1],), type_as=A) / A.shape[1] - else: - assert (len(weights) == A.shape[1]) - - if log: - log = {'err': []} - - K = nx.exp(-M / reg) - - err = 1 - - UKv = nx.dot(K, (A.T / nx.sum(K, axis=0)).T) - - u = (geometricMean(UKv) / UKv.T).T - - for ii in range(numItermax): - - UKv = u * nx.dot(K.T, A / nx.dot(K, u)) - u = (u.T * geometricBar(weights, UKv)).T / UKv - - if ii % 10 == 1: - err = nx.sum(nx.std(UKv, axis=1)) - - # log and verbose print - if log: - log['err'].append(err) - - if err < stopThr: - break - if verbose: - if ii % 200 == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(ii, err)) - else: - if warn: - warnings.warn("Sinkhorn did not converge. 
You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") - if log: - log['niter'] = ii - return geometricBar(weights, UKv), log - else: - return geometricBar(weights, UKv) - - -def free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg, b=None, weights=None, - numItermax=100, numInnerItermax=1000, stopThr=1e-7, verbose=False, log=None, - **kwargs): - r""" - Solves the free support (locations of the barycenters are optimized, not the weights) regularized Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Sinkhorn divergence), formally: - - .. math:: - \min_\mathbf{X} \quad \sum_{i=1}^N w_i W_{reg}^2(\mathbf{b}, \mathbf{X}, \mathbf{a}_i, \mathbf{X}_i) - - where : - - - :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one - - `measure_weights` denotes the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}`: empirical measures weights (on simplex) - - `measures_locations` denotes the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}`: empirical measures atoms locations - - :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter - - This problem is considered in :ref:`[20] ` (Algorithm 2). - There are two differences with the following codes: - - - we do not optimize over the weights - - we do not do line search for the locations updates, we use i.e. :math:`\theta = 1` in - :ref:`[20] ` (Algorithm 2). This can be seen as a discrete - implementation of the fixed-point algorithm of - :ref:`[43] ` proposed in the continuous setting. - - at each iteration, instead of solving an exact OT problem, we use the Sinkhorn algorithm for calculating the - transport plan in :ref:`[20] ` (Algorithm 2). - - Parameters - ---------- - measures_locations : list of N (k_i,d) array-like - The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space - (:math:`k_i` can be different for each element of the list) - measures_weights : list of N (k_i,) array-like - Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one - representing the weights of each discrete input measure - - X_init : (k,d) array-like - Initialization of the support locations (on `k` atoms) of the barycenter - reg : float - Regularization term >0 - b : (k,) array-like - Initialization of the weights of the barycenter (non-negatives, sum to 1) - weights : (N,) array-like - Initialization of the coefficients of the barycenter (non-negatives, sum to 1) - - numItermax : int, optional - Max number of iterations - numInnerItermax : int, optional - Max number of iterations when calculating the transport plans with Sinkhorn - stopThr : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - - Returns - ------- - X : (k,d) array-like - Support locations (on k atoms) of the barycenter - - See Also - -------- - ot.bregman.sinkhorn : Entropic regularized OT solver - ot.lp.free_support_barycenter : Barycenter solver based on Linear Programming - - .. _references-free-support-barycenter: - References - ---------- - .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. - - .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. 
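A small sketch of calling this free-support solver on two toy point clouds (the random measures, `k = 10` support size and `reg` are illustrative assumptions):

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    measures_locations = [rng.randn(30, 2), rng.randn(40, 2) + 2.]
    measures_weights = [ot.unif(30), ot.unif(40)]
    X_init = rng.randn(10, 2)                  # k = 10 barycenter atoms
    X = ot.bregman.free_support_sinkhorn_barycenter(
        measures_locations, measures_weights, X_init, reg=1e-1)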
- - """ - nx = get_backend(*measures_locations, *measures_weights, X_init) - - iter_count = 0 - - N = len(measures_locations) - k = X_init.shape[0] - d = X_init.shape[1] - if b is None: - b = nx.ones((k,), type_as=X_init) / k - if weights is None: - weights = nx.ones((N,), type_as=X_init) / N - - X = X_init - - log_dict = {} - displacement_square_norms = [] - - displacement_square_norm = stopThr + 1. - - while (displacement_square_norm > stopThr and iter_count < numItermax): - - T_sum = nx.zeros((k, d), type_as=X_init) - - for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights): - M_i = dist(X, measure_locations_i) - T_i = sinkhorn(b, measure_weights_i, M_i, reg=reg, - numItermax=numInnerItermax, **kwargs) - T_sum = T_sum + weight_i * 1. / \ - b[:, None] * nx.dot(T_i, measure_locations_i) - - displacement_square_norm = nx.sum((T_sum - X) ** 2) - if log: - displacement_square_norms.append(displacement_square_norm) - - X = T_sum - - if verbose: - print('iteration %d, displacement_square_norm=%f\n', - iter_count, displacement_square_norm) - - iter_count += 1 - - if log: - log_dict['displacement_square_norms'] = displacement_square_norms - return X, log_dict - else: - return X - - -def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000, - stopThr=1e-4, verbose=False, log=False, warn=True): - r"""Compute the entropic wasserstein barycenter in log-domain - """ - - A, M = list_to_array(A, M) - dim, n_hists = A.shape - - nx = get_backend(A, M) - - if nx.__name__ in ("jax", "tf"): - raise NotImplementedError( - "Log-domain functions are not yet implemented" - " for Jax and tf. Use numpy or torch arrays instead." - ) - - if weights is None: - weights = nx.ones(n_hists, type_as=A) / n_hists - else: - assert (len(weights) == A.shape[1]) - - if log: - log = {'err': []} - - M = - M / reg - logA = nx.log(A + 1e-15) - log_KU, G = nx.zeros((2, *logA.shape), type_as=A) - err = 1 - for ii in range(numItermax): - log_bar = nx.zeros(dim, type_as=A) - for k in range(n_hists): - f = logA[:, k] - nx.logsumexp(M + G[None, :, k], axis=1) - log_KU[:, k] = nx.logsumexp(M + f[:, None], axis=0) - log_bar = log_bar + weights[k] * log_KU[:, k] - - if ii % 10 == 1: - err = nx.exp(G + log_KU).std(axis=1).sum() - - # log and verbose print - if log: - log['err'].append(err) - - if err < stopThr: - break - if verbose: - if ii % 200 == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(ii, err)) - - G = log_bar[:, None] - log_KU - - else: - if warn: - warnings.warn("Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") - if log: - log['niter'] = ii - return nx.exp(log_bar), log - else: - return nx.exp(log_bar) - - -def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, - stopThr=1e-4, verbose=False, log=False, warn=True): - r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` with stabilization. - - The function solves the following optimization problem: - - .. 
math:: - \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) - - where : - - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein - distance (see :py:func:`ot.bregman.sinkhorn`) - - :math:`\mathbf{a}_i` are training distributions in the columns of matrix - :math:`\mathbf{A}` - - `reg` and :math:`\mathbf{M}` are respectively the regularization term and - the cost matrix for OT - - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling - algorithm as proposed in :ref:`[3] ` - - Parameters - ---------- - A : array-like, shape (dim, n_hists) - `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim` - M : array-like, shape (dim, dim) - loss matrix for OT - reg : float - Regularization term > 0 - tau : float - threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` - for log scaling - weights : array-like, shape (n_hists,) - Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates) - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - warn : bool, optional - if True, raises a warning if the algorithm doesn't convergence. - - - Returns - ------- - a : (dim,) array-like - Wasserstein barycenter - log : dict - log dictionary return only if log==True in parameters - - - .. _references-barycenter-stabilized: - References - ---------- - - .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). - Iterative Bregman projections for regularized transportation problems. - SIAM Journal on Scientific Computing, 37(2), A1111-A1138. - - """ - - A, M = list_to_array(A, M) - - nx = get_backend(A, M) - - dim, n_hists = A.shape - if weights is None: - weights = nx.ones((n_hists,), type_as=M) / n_hists - else: - assert (len(weights) == A.shape[1]) - - if log: - log = {'err': []} - - u = nx.ones((dim, n_hists), type_as=M) / dim - v = nx.ones((dim, n_hists), type_as=M) / dim - - K = nx.exp(-M / reg) - - err = 1. - alpha = nx.zeros((dim,), type_as=M) - beta = nx.zeros((dim,), type_as=M) - q = nx.ones((dim,), type_as=M) / dim - for ii in range(numItermax): - qprev = q - Kv = nx.dot(K, v) - u = A / Kv - Ktu = nx.dot(K.T, u) - q = geometricBar(weights, Ktu) - Q = q[:, None] - v = Q / Ktu - absorbing = False - if nx.any(u > tau) or nx.any(v > tau): - absorbing = True - alpha += reg * nx.log(nx.max(u, 1)) - beta += reg * nx.log(nx.max(v, 1)) - K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg) - v = nx.ones(tuple(v.shape), type_as=v) - Kv = nx.dot(K, v) - if (nx.any(Ktu == 0.) - or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) - or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): - # we have reached the machine precision - # come back to previous solution and quit loop - warnings.warn('Numerical errors at iteration %s' % ii) - q = qprev - break - if (ii % 10 == 0 and not absorbing) or ii == 0: - # we can speed up the process by checking for the error only all - # the 10th iterations - err = nx.max(nx.abs(u * Kv - A)) - if log: - log['err'].append(err) - if err < stopThr: - break - if verbose: - if ii % 50 == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(ii, err)) - - else: - if warn: - warnings.warn("Stabilized Sinkhorn did not converge." 
+ - "Try a larger entropy `reg`" + - "Or a larger absorption threshold `tau`.") - if log: - log['niter'] = ii - log['logu'] = nx.log(u + 1e-16) - log['logv'] = nx.log(v + 1e-16) - return q, log - else: - return q - - -def barycenter_debiased(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, - stopThr=1e-4, verbose=False, log=False, warn=True, **kwargs): - r"""Compute the debiased Sinkhorn barycenter of distributions A - - The function solves the following optimization problem: - - .. math:: - \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i S_{reg}(\mathbf{a},\mathbf{a}_i) - - where : - - - :math:`S_{reg}(\cdot,\cdot)` is the debiased Sinkhorn divergence - (see :py:func:`ot.bregman.empirical_sinkhorn_divergence`) - - :math:`\mathbf{a}_i` are training distributions in the columns of matrix - :math:`\mathbf{A}` - - `reg` and :math:`\mathbf{M}` are respectively the regularization term and - the cost matrix for OT - - The algorithm used for solving the problem is the debiased Sinkhorn - algorithm as proposed in :ref:`[37] ` - - Parameters - ---------- - A : array-like, shape (dim, n_hists) - `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim` - M : array-like, shape (dim, dim) - loss matrix for OT - reg : float - Regularization term > 0 - method : str (optional) - method used for the solver either 'sinkhorn' or 'sinkhorn_log' - weights : array-like, shape (n_hists,) - Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates) - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - warn : bool, optional - if True, raises a warning if the algorithm doesn't convergence. - - - Returns - ------- - a : (dim,) array-like - Wasserstein barycenter - log : dict - log dictionary return only if log==True in parameters - - - .. _references-barycenter-debiased: - References - ---------- - .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International - Conference on Machine Learning, PMLR 119:4692-4701, 2020 - """ - - if method.lower() == 'sinkhorn': - return _barycenter_debiased(A, M, reg, weights=weights, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, log=log, - warn=warn, **kwargs) - elif method.lower() == 'sinkhorn_log': - return _barycenter_debiased_log(A, M, reg, weights=weights, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, **kwargs) - else: - raise ValueError("Unknown method '%s'." % method) - - -def _barycenter_debiased(A, M, reg, weights=None, numItermax=1000, - stopThr=1e-4, verbose=False, log=False, warn=True): - r"""Compute the debiased sinkhorn barycenter of distributions A. - """ - - A, M = list_to_array(A, M) - - nx = get_backend(A, M) - - if weights is None: - weights = nx.ones((A.shape[1],), type_as=A) / A.shape[1] - else: - assert (len(weights) == A.shape[1]) - - if log: - log = {'err': []} - - K = nx.exp(-M / reg) - - err = 1 - - UKv = nx.dot(K, (A.T / nx.sum(K, axis=0)).T) - - u = (geometricMean(UKv) / UKv.T).T - c = nx.ones(A.shape[0], type_as=A) - bar = nx.ones(A.shape[0], type_as=A) - - for ii in range(numItermax): - bold = bar - UKv = nx.dot(K, A / nx.dot(K, u)) - bar = c * geometricBar(weights, UKv) - u = bar[:, None] / UKv - c = (c * bar / nx.dot(K, c)) ** 0.5 - - if ii % 10 == 9: - err = abs(bar - bold).max() / max(bar.max(), 1.) 
- - # log and verbose print - if log: - log['err'].append(err) - - # debiased Sinkhorn does not converge monotonically - # guarantee a few iterations are done before stopping - if err < stopThr and ii > 20: - break - if verbose: - if ii % 200 == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(ii, err)) - else: - if warn: - warnings.warn("Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") - if log: - log['niter'] = ii - return bar, log - else: - return bar - - -def _barycenter_debiased_log(A, M, reg, weights=None, numItermax=1000, - stopThr=1e-4, verbose=False, log=False, - warn=True): - r"""Compute the debiased sinkhorn barycenter in log domain. - """ - - A, M = list_to_array(A, M) - dim, n_hists = A.shape - - nx = get_backend(A, M) - if nx.__name__ in ("jax", "tf"): - raise NotImplementedError( - "Log-domain functions are not yet implemented" - " for Jax and TF. Use numpy or torch arrays instead." - ) - - if weights is None: - weights = nx.ones(n_hists, type_as=A) / n_hists - else: - assert (len(weights) == A.shape[1]) - - if log: - log = {'err': []} - - M = - M / reg - logA = nx.log(A + 1e-15) - log_KU, G = nx.zeros((2, *logA.shape), type_as=A) - c = nx.zeros(dim, type_as=A) - err = 1 - for ii in range(numItermax): - log_bar = nx.zeros(dim, type_as=A) - for k in range(n_hists): - f = logA[:, k] - nx.logsumexp(M + G[None, :, k], axis=1) - log_KU[:, k] = nx.logsumexp(M + f[:, None], axis=0) - log_bar += weights[k] * log_KU[:, k] - log_bar += c - if ii % 10 == 1: - err = nx.exp(G + log_KU).std(axis=1).sum() - - # log and verbose print - if log: - log['err'].append(err) - - if err < stopThr and ii > 20: - break - if verbose: - if ii % 200 == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(ii, err)) - - G = log_bar[:, None] - log_KU - for _ in range(10): - c = 0.5 * (c + log_bar - nx.logsumexp(M + c[:, None], axis=0)) - - else: - if warn: - warnings.warn("Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") - if log: - log['niter'] = ii - return nx.exp(log_bar), log - else: - return nx.exp(log_bar) - - -def convolutional_barycenter2d(A, reg, weights=None, method="sinkhorn", numItermax=10000, - stopThr=1e-4, verbose=False, log=False, - warn=True, **kwargs): - r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` - where :math:`\mathbf{A}` is a collection of 2D images. - - The function solves the following optimization problem: - - .. 
math:: - \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) - - where : - - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein - distance (see :py:func:`ot.bregman.sinkhorn`) - - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two dimensions - of matrix :math:`\mathbf{A}` - - `reg` is the regularization strength scalar value - - The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm - as proposed in :ref:`[21] ` - - Parameters - ---------- - A : array-like, shape (n_hists, width, height) - `n` distributions (2D images) of size `width` x `height` - reg : float - Regularization term >0 - weights : array-like, shape (n_hists,) - Weights of each image on the simplex (barycentric coodinates) - method : string, optional - method used for the solver either 'sinkhorn' or 'sinkhorn_log' - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (> 0) - stabThr : float, optional - Stabilization threshold to avoid numerical precision issue - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - warn : bool, optional - if True, raises a warning if the algorithm doesn't convergence. - - Returns - ------- - a : array-like, shape (width, height) - 2D Wasserstein barycenter - log : dict - log dictionary return only if log==True in parameters - - - .. _references-convolutional-barycenter-2d: - References - ---------- - - .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, - A., Nguyen, A. & Guibas, L. (2015). Convolutional wasserstein distances: - Efficient optimal transportation on geometric domains. ACM Transactions - on Graphics (TOG), 34(4), 66 - - .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th - International Conference on Machine Learning, PMLR 119:4692-4701, 2020 - """ - - if method.lower() == 'sinkhorn': - return _convolutional_barycenter2d(A, reg, weights=weights, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, - **kwargs) - elif method.lower() == 'sinkhorn_log': - return _convolutional_barycenter2d_log(A, reg, weights=weights, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, - **kwargs) - else: - raise ValueError("Unknown method '%s'." % method) - - -def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, - stopThr=1e-9, stabThr=1e-30, verbose=False, - log=False, warn=True): - r"""Compute the entropic regularized wasserstein barycenter of distributions A - where A is a collection of 2D images. 
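For illustration only (not part of the patch), a small sketch of calling the 2D solver above on random normalized images; the image size and regularization value are arbitrary:

>>> import numpy as np
>>> import ot
>>> rng = np.random.RandomState(42)
>>> A = rng.rand(2, 10, 10)                       # two 2D histograms ("images")
>>> A /= A.sum(axis=(1, 2), keepdims=True)        # each image sums to 1
>>> bary = ot.bregman.convolutional_barycenter2d(A, reg=4e-2)
>>> bary.shape
(10, 10)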
- """ - - A = list_to_array(A) - - nx = get_backend(A) - - if weights is None: - weights = nx.ones((A.shape[0],), type_as=A) / A.shape[0] - else: - assert (len(weights) == A.shape[0]) - - if log: - log = {'err': []} - - bar = nx.ones(A.shape[1:], type_as=A) - bar /= nx.sum(bar) - U = nx.ones(A.shape, type_as=A) - V = nx.ones(A.shape, type_as=A) - err = 1 - - # build the convolution operator - # this is equivalent to blurring on horizontal then vertical directions - t = nx.linspace(0, 1, A.shape[1], type_as=A) - [Y, X] = nx.meshgrid(t, t) - K1 = nx.exp(-(X - Y) ** 2 / reg) - - t = nx.linspace(0, 1, A.shape[2], type_as=A) - [Y, X] = nx.meshgrid(t, t) - K2 = nx.exp(-(X - Y) ** 2 / reg) - - def convol_imgs(imgs): - kx = nx.einsum("...ij,kjl->kil", K1, imgs) - kxy = nx.einsum("...ij,klj->kli", K2, kx) - return kxy - - KU = convol_imgs(U) - for ii in range(numItermax): - V = bar[None] / KU - KV = convol_imgs(V) - U = A / KV - KU = convol_imgs(U) - bar = nx.exp( - nx.sum(weights[:, None, None] * nx.log(KU + stabThr), axis=0) - ) - if ii % 10 == 9: - err = nx.sum(nx.std(V * KU, axis=0)) - # log and verbose print - if log: - log['err'].append(err) - - if verbose: - if ii % 200 == 0: - print('{:5s}|{:12s}'.format( - 'It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(ii, err)) - if err < stopThr: - break - - else: - if warn: - warnings.warn("Convolutional Sinkhorn did not converge. " - "Try a larger number of iterations `numItermax` " - "or a larger entropy `reg`.") - if log: - log['niter'] = ii - log['U'] = U - return bar, log - else: - return bar - - -def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000, - stopThr=1e-4, stabThr=1e-30, verbose=False, - log=False, warn=True): - r"""Compute the entropic regularized wasserstein barycenter of distributions A - where A is a collection of 2D images in log-domain. - """ - - A = list_to_array(A) - - nx = get_backend(A) - if nx.__name__ in ("jax", "tf"): - raise NotImplementedError( - "Log-domain functions are not yet implemented" - " for Jax and TF. Use numpy or torch arrays instead." - ) - - n_hists, width, height = A.shape - - if weights is None: - weights = nx.ones((n_hists,), type_as=A) / n_hists - else: - assert (len(weights) == n_hists) - - if log: - log = {'err': []} - - err = 1 - # build the convolution operator - # this is equivalent to blurring on horizontal then vertical directions - t = nx.linspace(0, 1, width, type_as=A) - [Y, X] = nx.meshgrid(t, t) - M1 = - (X - Y) ** 2 / reg - - t = nx.linspace(0, 1, height, type_as=A) - [Y, X] = nx.meshgrid(t, t) - M2 = - (X - Y) ** 2 / reg - - def convol_img(log_img): - log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1) - log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T - return log_img - - logA = nx.log(A + stabThr) - log_KU, G, F = nx.zeros((3, *logA.shape), type_as=A) - err = 1 - for ii in range(numItermax): - log_bar = nx.zeros((width, height), type_as=A) - for k in range(n_hists): - f = logA[k] - convol_img(G[k]) - log_KU[k] = convol_img(f) - log_bar = log_bar + weights[k] * log_KU[k] - - if ii % 10 == 9: - err = nx.exp(G + log_KU).std(axis=0).sum() - # log and verbose print - if log: - log['err'].append(err) - - if verbose: - if ii % 200 == 0: - print('{:5s}|{:12s}'.format( - 'It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(ii, err)) - if err < stopThr: - break - G = log_bar[None, :, :] - log_KU - - else: - if warn: - warnings.warn("Convolutional Sinkhorn did not converge. 
" - "Try a larger number of iterations `numItermax` " - "or a larger entropy `reg`.") - if log: - log['niter'] = ii - return nx.exp(log_bar), log - else: - return nx.exp(log_bar) - - -def convolutional_barycenter2d_debiased(A, reg, weights=None, method="sinkhorn", - numItermax=10000, stopThr=1e-3, - verbose=False, log=False, warn=True, - **kwargs): - r"""Compute the debiased sinkhorn barycenter of distributions :math:`\mathbf{A}` - where :math:`\mathbf{A}` is a collection of 2D images. - - The function solves the following optimization problem: - - .. math:: - \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i S_{reg}(\mathbf{a},\mathbf{a}_i) - - where : - - - :math:`S_{reg}(\cdot,\cdot)` is the debiased entropic regularized Wasserstein - distance (see :py:func:`ot.bregman.barycenter_debiased`) - - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two - dimensions of matrix :math:`\mathbf{A}` - - `reg` is the regularization strength scalar value - - The algorithm used for solving the problem is the debiased Sinkhorn scaling - algorithm as proposed in :ref:`[37] ` - - Parameters - ---------- - A : array-like, shape (n_hists, width, height) - `n` distributions (2D images) of size `width` x `height` - reg : float - Regularization term >0 - weights : array-like, shape (n_hists,) - Weights of each image on the simplex (barycentric coodinates) - method : string, optional - method used for the solver either 'sinkhorn' or 'sinkhorn_log' - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (> 0) - stabThr : float, optional - Stabilization threshold to avoid numerical precision issue - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - warn : bool, optional - if True, raises a warning if the algorithm doesn't convergence. - - - Returns - ------- - a : array-like, shape (width, height) - 2D Wasserstein barycenter - log : dict - log dictionary return only if log==True in parameters - - - .. _references-convolutional-barycenter2d-debiased: - References - ---------- - - .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International - Conference on Machine Learning, PMLR 119:4692-4701, 2020 - """ - - if method.lower() == 'sinkhorn': - return _convolutional_barycenter2d_debiased(A, reg, weights=weights, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, - **kwargs) - elif method.lower() == 'sinkhorn_log': - return _convolutional_barycenter2d_debiased_log(A, reg, weights=weights, - numItermax=numItermax, - stopThr=stopThr, verbose=verbose, - log=log, warn=warn, - **kwargs) - else: - raise ValueError("Unknown method '%s'." % method) - - -def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000, - stopThr=1e-3, stabThr=1e-15, verbose=False, - log=False, warn=True): - r"""Compute the debiased barycenter of 2D images via sinkhorn convolutions. 
- """ - - A = list_to_array(A) - n_hists, width, height = A.shape - - nx = get_backend(A) - - if weights is None: - weights = nx.ones((n_hists,), type_as=A) / n_hists - else: - assert (len(weights) == n_hists) - - if log: - log = {'err': []} - - bar = nx.ones((width, height), type_as=A) - bar /= width * height - U = nx.ones(A.shape, type_as=A) - V = nx.ones(A.shape, type_as=A) - c = nx.ones(A.shape[1:], type_as=A) - err = 1 - - # build the convolution operator - # this is equivalent to blurring on horizontal then vertical directions - t = nx.linspace(0, 1, width, type_as=A) - [Y, X] = nx.meshgrid(t, t) - K1 = nx.exp(-(X - Y) ** 2 / reg) - - t = nx.linspace(0, 1, height, type_as=A) - [Y, X] = nx.meshgrid(t, t) - K2 = nx.exp(-(X - Y) ** 2 / reg) - - def convol_imgs(imgs): - kx = nx.einsum("...ij,kjl->kil", K1, imgs) - kxy = nx.einsum("...ij,klj->kli", K2, kx) - return kxy - - KU = convol_imgs(U) - for ii in range(numItermax): - V = bar[None] / KU - KV = convol_imgs(V) - U = A / KV - KU = convol_imgs(U) - bar = c * nx.exp( - nx.sum(weights[:, None, None] * nx.log(KU + stabThr), axis=0) - ) - - for _ in range(10): - c = (c * bar / nx.squeeze(convol_imgs(c[None]))) ** 0.5 - - if ii % 10 == 9: - err = nx.sum(nx.std(V * KU, axis=0)) - # log and verbose print - if log: - log['err'].append(err) - - if verbose: - if ii % 200 == 0: - print('{:5s}|{:12s}'.format( - 'It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(ii, err)) - - # debiased Sinkhorn does not converge monotonically - # guarantee a few iterations are done before stopping - if err < stopThr and ii > 20: - break - else: - if warn: - warnings.warn("Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") - if log: - log['niter'] = ii - log['U'] = U - return bar, log - else: - return bar - - -def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10000, - stopThr=1e-3, stabThr=1e-30, verbose=False, - log=False, warn=True): - r"""Compute the debiased barycenter of 2D images in log-domain. - """ - - A = list_to_array(A) - n_hists, width, height = A.shape - nx = get_backend(A) - if nx.__name__ in ("jax", "tf"): - raise NotImplementedError( - "Log-domain functions are not yet implemented" - " for Jax and TF. Use numpy or torch arrays instead." 
- ) - if weights is None: - weights = nx.ones((n_hists,), type_as=A) / n_hists - else: - assert (len(weights) == A.shape[0]) - - if log: - log = {'err': []} - - err = 1 - # build the convolution operator - # this is equivalent to blurring on horizontal then vertical directions - t = nx.linspace(0, 1, width, type_as=A) - [Y, X] = nx.meshgrid(t, t) - M1 = - (X - Y) ** 2 / reg - - t = nx.linspace(0, 1, height, type_as=A) - [Y, X] = nx.meshgrid(t, t) - M2 = - (X - Y) ** 2 / reg - - def convol_img(log_img): - log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1) - log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T - return log_img - - logA = nx.log(A + stabThr) - log_bar, c = nx.zeros((2, width, height), type_as=A) - log_KU, G, F = nx.zeros((3, *logA.shape), type_as=A) - err = 1 - for ii in range(numItermax): - log_bar = nx.zeros((width, height), type_as=A) - for k in range(n_hists): - f = logA[k] - convol_img(G[k]) - log_KU[k] = convol_img(f) - log_bar = log_bar + weights[k] * log_KU[k] - log_bar += c - for _ in range(10): - c = 0.5 * (c + log_bar - convol_img(c)) - - if ii % 10 == 9: - err = nx.sum(nx.std(nx.exp(G + log_KU), axis=0)) - # log and verbose print - if log: - log['err'].append(err) - - if verbose: - if ii % 200 == 0: - print('{:5s}|{:12s}'.format( - 'It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(ii, err)) - if err < stopThr and ii > 20: - break - G = log_bar[None, :, :] - log_KU - - else: - if warn: - warnings.warn("Convolutional Sinkhorn did not converge. " - "Try a larger number of iterations `numItermax` " - "or a larger entropy `reg`.") - if log: - log['niter'] = ii - return nx.exp(log_bar), log - else: - return nx.exp(log_bar) - - -def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, - stopThr=1e-3, verbose=False, log=False, warn=True): - r""" - Compute the unmixing of an observation with a given dictionary using Wasserstein distance - - The function solve the following optimization problem: - - .. 
math:: - - \mathbf{h} = \mathop{\arg \min}_\mathbf{h} \quad - (1 - \alpha) W_{\mathbf{M}, \mathrm{reg}}(\mathbf{a}, \mathbf{Dh}) + - \alpha W_{\mathbf{M_0}, \mathrm{reg}_0}(\mathbf{h}_0, \mathbf{h}) - - - where : - - - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance - with :math:`\mathbf{M}` loss matrix (see :py:func:`ot.bregman.sinkhorn`) - - :math:`\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`, - its expected shape is `(dim_a, n_atoms)` - - :math:`\mathbf{h}` is the estimated unmixing of dimension `n_atoms` - - :math:`\mathbf{a}` is an observed distribution of dimension `dim_a` - - :math:`\mathbf{h}_0` is a prior on :math:`\mathbf{h}` of dimension `dim_prior` - - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the - cost matrix (`dim_a`, `dim_a`) for OT data fitting - - `reg`:math:`_0` and :math:`\mathbf{M_0}` are respectively the regularization - term and the cost matrix (`dim_prior`, `n_atoms`) regularization - - :math:`\alpha` weight data fitting and regularization - - The optimization problem is solved following the algorithm described - in :ref:`[4] ` - - - Parameters - ---------- - a : array-like, shape (dim_a) - observed distribution (histogram, sums to 1) - D : array-like, shape (dim_a, n_atoms) - dictionary matrix - M : array-like, shape (dim_a, dim_a) - loss matrix - M0 : array-like, shape (n_atoms, dim_prior) - loss matrix - h0 : array-like, shape (n_atoms,) - prior on the estimated unmixing h - reg : float - Regularization term >0 (Wasserstein data fitting) - reg0 : float - Regularization term >0 (Wasserstein reg with h0) - alpha : float - How much should we trust the prior ([0,1]) - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - warn : bool, optional - if True, raises a warning if the algorithm doesn't convergence. - - Returns - ------- - h : array-like, shape (n_atoms,) - Wasserstein barycenter - log : dict - log dictionary return only if log==True in parameters - - - .. _references-unmix: - References - ---------- - - .. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti, - Supervised planetary unmixing with optimal transport, Workshop - on Hyperspectral Image and Signal Processing : - Evolution in Remote Sensing (WHISPERS), 2016. - """ - - a, D, M, M0, h0 = list_to_array(a, D, M, M0, h0) - - nx = get_backend(a, D, M, M0, h0) - - # M = M/np.median(M) - K = nx.exp(-M / reg) - - # M0 = M0/np.median(M0) - K0 = nx.exp(-M0 / reg0) - old = h0 - - err = 1 - # log = {'niter':0, 'all_err':[]} - if log: - log = {'err': []} - - for ii in range(numItermax): - K = projC(K, a) - K0 = projC(K0, h0) - new = nx.sum(K0, axis=1) - # we recombine the current selection from dictionnary - inv_new = nx.dot(D, new) - other = nx.sum(K, axis=1) - # geometric interpolation - delta = nx.exp(alpha * nx.log(other) + (1 - alpha) * nx.log(inv_new)) - K = projR(K, delta) - K0 = nx.dot(D.T, delta / inv_new)[:, None] * K0 - err = nx.norm(nx.sum(K0, axis=1) - old) - old = new - if log: - log['err'].append(err) - - if verbose: - if ii % 200 == 0: - print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(ii, err)) - if err < stopThr: - break - else: - if warn: - warnings.warn("Unmixing algorithm did not converge. 
You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") - if log: - log['niter'] = ii - return nx.sum(K0, axis=1), log - else: - return nx.sum(K0, axis=1) - - -def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, - stopThr=1e-6, verbose=False, log=False, warn=True, **kwargs): - r'''Joint OT and proportion estimation for multi-source target shift as - proposed in :ref:`[27] ` - - The function solves the following optimization problem: - - .. math:: - - \mathbf{h} = \mathop{\arg \min}_{\mathbf{h}} \quad \sum_{k=1}^{K} \lambda_k - W_{reg}((\mathbf{D}_2^{(k)} \mathbf{h})^T, \mathbf{a}) - - s.t. \ \forall k, \mathbf{D}_1^{(k)} \gamma_k \mathbf{1}_n= \mathbf{h} - - where : - - - :math:`\lambda_k` is the weight of `k`-th source domain - - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance - (see :py:func:`ot.bregman.sinkhorn`) - - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to `k`-th source domain - defined as in [p. 5, :ref:`27 `], its expected shape - is :math:`(n_k, C)` where :math:`n_k` is the number of elements in the `k`-th source - domain and `C` is the number of classes - - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size `C` - - :math:`\mathbf{a}` is a uniform vector of weights in the target domain of size `n` - - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in - [p. 5, :ref:`27 `], its expected shape is :math:`(n_k, C)` - - The problem consist in solving a Wasserstein barycenter problem to estimate - the proportions :math:`\mathbf{h}` in the target domain. - - The algorithm used for solving the problem is the Iterative Bregman projections algorithm - with two sets of marginal constraints related to the unknown vector - :math:`\mathbf{h}` and uniform target distribution. - - Parameters - ---------- - Xs : list of K array-like(nsk,d) - features of all source domains' samples - Ys : list of K array-like(nsk,) - labels of all source domains' samples - Xt : array-like (nt,d) - samples in the target domain - reg : float - Regularization term > 0 - metric : string, optional (default="sqeuclidean") - The ground metric for the Wasserstein problem - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on relative change in the barycenter (>0) - verbose : bool, optional (default=False) - Controls the verbosity of the optimization algorithm - log : bool, optional - record log if True - warn : bool, optional - if True, raises a warning if the algorithm doesn't convergence. - - Returns - ------- - h : (C,) array-like - proportion estimation in the target domain - log : dict - log dictionary return only if log==True in parameters - - - .. _references-jcpot-barycenter: - References - ---------- - - .. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia - "Optimal transport for multi-source domain adaptation under target shift", - International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. 
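An illustrative sketch of calling the proportion estimator above on toy multi-source data (not part of the patch; domain sizes, labels and `reg` are arbitrary):

>>> import numpy as np
>>> import ot
>>> rng = np.random.RandomState(0)
>>> Xs = [rng.randn(10, 2), rng.randn(12, 2)]                          # two source domains
>>> Ys = [np.array([0] * 5 + [1] * 5), np.array([0] * 6 + [1] * 6)]    # class labels per domain
>>> Xt = rng.randn(15, 2)                                              # unlabeled target samples
>>> h = ot.bregman.jcpot_barycenter(Xs, Ys, Xt, reg=1.0)
>>> h.shape                                                            # one proportion per class
(2,)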
- ''' - - Xs = list_to_array(*Xs) - Ys = list_to_array(*Ys) - Xt = list_to_array(Xt) - - nx = get_backend(*Xs, *Ys, Xt) - - nbclasses = len(nx.unique(Ys[0])) - nbdomains = len(Xs) - - # log dictionary - if log: - log = {'niter': 0, 'err': [], 'M': [], 'D1': [], 'D2': [], 'gamma': []} - - K = [] - M = [] - D1 = [] - D2 = [] - - # For each source domain, build cost matrices M, Gibbs kernels K and corresponding matrices D_1 and D_2 - for d in range(nbdomains): - dom = {} - nsk = Xs[d].shape[0] # get number of elements for this domain - dom['nbelem'] = nsk - classes = nx.unique(Ys[d]) # get number of classes for this domain - - # format classes to start from 0 for convenience - if nx.min(classes) != 0: - Ys[d] -= nx.min(classes) - classes = nx.unique(Ys[d]) - - # build the corresponding D_1 and D_2 matrices - Dtmp1 = np.zeros((nbclasses, nsk)) - Dtmp2 = np.zeros((nbclasses, nsk)) - - for c in classes: - nbelemperclass = float(nx.sum(Ys[d] == c)) - if nbelemperclass != 0: - Dtmp1[int(c), nx.to_numpy(Ys[d] == c)] = 1. - Dtmp2[int(c), nx.to_numpy(Ys[d] == c)] = 1. / (nbelemperclass) - D1.append(nx.from_numpy(Dtmp1, type_as=Xs[0])) - D2.append(nx.from_numpy(Dtmp2, type_as=Xs[0])) - - # build the cost matrix and the Gibbs kernel - Mtmp = dist(Xs[d], Xt, metric=metric) - M.append(Mtmp) - - Ktmp = nx.exp(-Mtmp / reg) - K.append(Ktmp) - - # uniform target distribution - a = nx.from_numpy(unif(Xt.shape[0]), type_as=Xs[0]) - - err = 1 - old_bary = nx.ones((nbclasses,), type_as=Xs[0]) - - for ii in range(numItermax): - - bary = nx.zeros((nbclasses,), type_as=Xs[0]) - - # update coupling matrices for marginal constraints w.r.t. uniform target distribution - for d in range(nbdomains): - K[d] = projC(K[d], a) - other = nx.sum(K[d], axis=1) - bary += nx.log(nx.dot(D1[d], other)) / nbdomains - - bary = nx.exp(bary) - - # update coupling matrices for marginal constraints w.r.t. unknown proportions based on [Prop 4., 27] - for d in range(nbdomains): - new = nx.dot(D2[d].T, bary) - K[d] = projR(K[d], new) - - err = nx.norm(bary - old_bary) - - old_bary = bary - - if log: - log['err'].append(err) - - if err < stopThr: - break - if verbose: - if ii % 200 == 0: - print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(ii, err)) - else: - if warn: - warnings.warn("Algorithm did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") - bary = bary / nx.sum(bary) - - if log: - log['niter'] = ii - log['M'] = M - log['D1'] = D1 - log['D2'] = D2 - log['gamma'] = K - return bary, log - else: - return bary - - -def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', - numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False, - log=False, warn=True, warmstart=None, **kwargs): - r''' - Solve the entropic regularization optimal transport problem and return the - OT matrix from empirical data - - The function solves the following optimization problem: - - .. math:: - \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + - \mathrm{reg} \cdot\Omega(\gamma) - - s.t. 
\ \gamma \mathbf{1} &= \mathbf{a} - - \gamma^T \mathbf{1} &= \mathbf{b} - - \gamma &\geq 0 - - where : - - - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term - :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) - - - Parameters - ---------- - X_s : array-like, shape (n_samples_a, dim) - samples in the source domain - X_t : array-like, shape (n_samples_b, dim) - samples in the target domain - reg : float - Regularization term >0 - a : array-like, shape (n_samples_a,) - samples weights in the source domain - b : array-like, shape (n_samples_b,) - samples weights in the target domain - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) - isLazy: boolean, optional - If True, then only calculate the cost matrix by block and return - the dual potentials only (to save memory). If False, calculate full - cost matrix and return outputs of sinkhorn function. - batchSize: int or tuple of 2 int, optional - Size of the batches used to compute the sinkhorn update without memory overhead. - When a tuple is provided it sets the size of the left/right batches. - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - warn : bool, optional - if True, raises a warning if the algorithm doesn't convergence. - warmstart: tuple of arrays, shape (dim_a, dim_b), optional - Initialization of dual potentials. If provided, the dual potentials should be given - (that is the logarithm of the u,v sinkhorn scaling vectors) - - - Returns - ------- - gamma : array-like, shape (n_samples_a, n_samples_b) - Regularized optimal transportation matrix for the given parameters - log : dict - log dictionary return only if log==True in parameters - - Examples - -------- - - >>> n_samples_a = 2 - >>> n_samples_b = 2 - >>> reg = 0.1 - >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1)) - >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1)) - >>> empirical_sinkhorn(X_s, X_t, reg=reg, verbose=False) # doctest: +NORMALIZE_WHITESPACE - array([[4.99977301e-01, 2.26989344e-05], - [2.26989344e-05, 4.99977301e-01]]) - - - References - ---------- - - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal - Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 - - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for - Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. - - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). - Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. 
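A small sketch (illustrative, not part of the patch) of the lazy mode described above, which computes the cost matrix block by block and returns only the dual potentials:

>>> import numpy as np
>>> import ot
>>> rng = np.random.RandomState(0)
>>> X_s, X_t = rng.randn(50, 2), rng.randn(60, 2)
>>> f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, reg=1.0, isLazy=True, batchSize=16)
>>> f.shape, g.shape                      # dual potentials, one value per sample
((50,), (60,))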
- ''' - - X_s, X_t = list_to_array(X_s, X_t) - - nx = get_backend(X_s, X_t) - - ns, nt = X_s.shape[0], X_t.shape[0] - if a is None: - a = nx.from_numpy(unif(ns), type_as=X_s) - if b is None: - b = nx.from_numpy(unif(nt), type_as=X_s) - - if isLazy: - if log: - dict_log = {"err": []} - - log_a, log_b = nx.log(a), nx.log(b) - if warmstart is None: - f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a) - else: - f, g = warmstart - - if isinstance(batchSize, int): - bs, bt = batchSize, batchSize - elif isinstance(batchSize, tuple) and len(batchSize) == 2: - bs, bt = batchSize[0], batchSize[1] - else: - raise ValueError( - "Batch size must be in integer or a tuple of two integers") - - range_s, range_t = range(0, ns, bs), range(0, nt, bt) - - lse_f = nx.zeros((ns,), type_as=a) - lse_g = nx.zeros((nt,), type_as=a) - - X_s_np = nx.to_numpy(X_s) - X_t_np = nx.to_numpy(X_t) - - for i_ot in range(numIterMax): - - lse_f_cols = [] - for i in range_s: - M = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric) - M = nx.from_numpy(M, type_as=a) - lse_f_cols.append( - nx.logsumexp(g[None, :] - M / reg, axis=1) - ) - lse_f = nx.concatenate(lse_f_cols, axis=0) - f = log_a - lse_f - - lse_g_cols = [] - for j in range_t: - M = dist(X_s_np, X_t_np[j:j + bt, :], metric=metric) - M = nx.from_numpy(M, type_as=a) - lse_g_cols.append( - nx.logsumexp(f[:, None] - M / reg, axis=0) - ) - lse_g = nx.concatenate(lse_g_cols, axis=0) - g = log_b - lse_g - - if (i_ot + 1) % 10 == 0: - m1_cols = [] - for i in range_s: - M = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric) - M = nx.from_numpy(M, type_as=a) - m1_cols.append( - nx.sum(nx.exp(f[i:i + bs, None] + - g[None, :] - M / reg), axis=1) - ) - m1 = nx.concatenate(m1_cols, axis=0) - err = nx.sum(nx.abs(m1 - a)) - if log: - dict_log["err"].append(err) - - if verbose and (i_ot + 1) % 100 == 0: - print("Error in marginal at iteration {} = {}".format( - i_ot + 1, err)) - - if err <= stopThr: - break - else: - if warn: - warnings.warn("Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") - if log: - dict_log["u"] = f - dict_log["v"] = g - return (f, g, dict_log) - else: - return (f, g) - - else: - M = dist(X_s, X_t, metric=metric) - if log: - pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, - verbose=verbose, log=True, warmstart=warmstart, **kwargs) - return pi, log - else: - pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, - verbose=verbose, log=False, warmstart=warmstart, **kwargs) - return pi - - -def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', - numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, - verbose=False, log=False, warn=True, warmstart=None, **kwargs): - r''' - Solve the entropic regularization optimal transport problem from empirical - data and return the OT loss - - - The function solves the following optimization problem: - - .. math:: - W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + - \mathrm{reg} \cdot\Omega(\gamma) - - s.t. 
\ \gamma \mathbf{1} &= \mathbf{a} - - \gamma^T \mathbf{1} &= \mathbf{b} - - \gamma &\geq 0 - - where : - - - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix - - :math:`\Omega` is the entropic regularization term - :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) - - and returns :math:`\langle \gamma^*, \mathbf{M} \rangle_F` (without - the entropic contribution). - - - Parameters - ---------- - X_s : array-like, shape (n_samples_a, dim) - samples in the source domain - X_t : array-like, shape (n_samples_b, dim) - samples in the target domain - reg : float - Regularization term >0 - a : array-like, shape (n_samples_a,) - samples weights in the source domain - b : array-like, shape (n_samples_b,) - samples weights in the target domain - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) - isLazy: boolean, optional - If True, then only calculate the cost matrix by block and return - the dual potentials only (to save memory). If False, calculate - full cost matrix and return outputs of sinkhorn function. - batchSize: int or tuple of 2 int, optional - Size of the batches used to compute the sinkhorn update without memory overhead. - When a tuple is provided it sets the size of the left/right batches. - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - warn : bool, optional - if True, raises a warning if the algorithm doesn't convergence. - warmstart: tuple of arrays, shape (dim_a, dim_b), optional - Initialization of dual potentials. If provided, the dual potentials should be given - (that is the logarithm of the u,v sinkhorn scaling vectors) - - Returns - ------- - W : (n_hists) array-like or float - Optimal transportation loss for the given parameters - log : dict - log dictionary return only if log==True in parameters - - Examples - -------- - - >>> n_samples_a = 2 - >>> n_samples_b = 2 - >>> reg = 0.1 - >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1)) - >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1)) - >>> b = np.full((n_samples_b, 3), 1/n_samples_b) - >>> empirical_sinkhorn2(X_s, X_t, b=b, reg=reg, verbose=False) - array([4.53978687e-05, 4.53978687e-05, 4.53978687e-05]) - - - References - ---------- - - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation - of Optimal Transport, Advances in Neural Information - Processing Systems (NIPS) 26, 2013 - - .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling - Algorithms for Entropy Regularized Transport Problems. - arXiv preprint arXiv:1610.06519. - - .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). - Scaling algorithms for unbalanced transport problems. - arXiv preprint arXiv:1607.05816. 
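For the lazy path documented above, a short sketch (illustrative only, not part of the patch) of computing the loss without materializing the full cost matrix; sizes and `reg` are arbitrary:

>>> import numpy as np
>>> import ot
>>> rng = np.random.RandomState(1)
>>> X_s, X_t = rng.randn(40, 3), rng.randn(50, 3)
>>> loss = ot.bregman.empirical_sinkhorn2(X_s, X_t, reg=1.0, isLazy=True, batchSize=20)
>>> float(loss) >= 0                      # transport cost is non-negative
True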
- ''' - - X_s, X_t = list_to_array(X_s, X_t) - - nx = get_backend(X_s, X_t) - - ns, nt = X_s.shape[0], X_t.shape[0] - if a is None: - a = nx.from_numpy(unif(ns), type_as=X_s) - if b is None: - b = nx.from_numpy(unif(nt), type_as=X_s) - - if isLazy: - if log: - f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, - numIterMax=numIterMax, - stopThr=stopThr, - isLazy=isLazy, - batchSize=batchSize, - verbose=verbose, log=log, - warn=warn, - warmstart=warmstart) - else: - f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, - numIterMax=numIterMax, - stopThr=stopThr, - isLazy=isLazy, batchSize=batchSize, - verbose=verbose, log=log, - warn=warn, - warmstart=warmstart) - - bs = batchSize if isinstance(batchSize, int) else batchSize[0] - range_s = range(0, ns, bs) - - loss = 0 - - X_s_np = nx.to_numpy(X_s) - X_t_np = nx.to_numpy(X_t) - - for i in range_s: - M_block = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric) - M_block = nx.from_numpy(M_block, type_as=a) - pi_block = nx.exp(f[i:i + bs, None] + g[None, :] - M_block / reg) - loss += nx.sum(M_block * pi_block) - - if log: - return loss, dict_log - else: - return loss - - else: - M = dist(X_s, X_t, metric=metric) - - if log: - sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, - stopThr=stopThr, verbose=verbose, log=log, - warn=warn, warmstart=warmstart, **kwargs) - return sinkhorn_loss, log - else: - sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, - stopThr=stopThr, verbose=verbose, log=log, - warn=warn, warmstart=warmstart, **kwargs) - return sinkhorn_loss - - -def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', - numIterMax=10000, stopThr=1e-9, verbose=False, - log=False, warn=True, warmstart=None, **kwargs): - r''' - Compute the sinkhorn divergence loss from empirical data - - The function solves the following optimization problems and return the - sinkhorn divergence :math:`S`: - - .. math:: - W &= \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + - \mathrm{reg} \cdot\Omega(\gamma) - - W_a &= \min_{\gamma_a} \quad \langle \gamma_a, \mathbf{M_a} \rangle_F + - \mathrm{reg} \cdot\Omega(\gamma_a) - - W_b &= \min_{\gamma_b} \quad \langle \gamma_b, \mathbf{M_b} \rangle_F + - \mathrm{reg} \cdot\Omega(\gamma_b) - - S &= W - \frac{W_a + W_b}{2} - - .. math:: - s.t. \ \gamma \mathbf{1} &= \mathbf{a} - - \gamma^T \mathbf{1} &= \mathbf{b} - - \gamma &\geq 0 - - \gamma_a \mathbf{1} &= \mathbf{a} - - \gamma_a^T \mathbf{1} &= \mathbf{a} - - \gamma_a &\geq 0 - - \gamma_b \mathbf{1} &= \mathbf{b} - - \gamma_b^T \mathbf{1} &= \mathbf{b} - - \gamma_b &\geq 0 - - where : - - - :math:`\mathbf{M}` (resp. :math:`\mathbf{M_a}`, :math:`\mathbf{M_b}`) - is the (`n_samples_a`, `n_samples_b`) metric cost matrix - (resp (`n_samples_a, n_samples_a`) and (`n_samples_b`, `n_samples_b`)) - - :math:`\Omega` is the entropic regularization term - :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` - - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) - - and returns :math:`\langle \gamma^*, \mathbf{M} \rangle_F -(\langle \gamma^*_a, \mathbf{M_a} \rangle_F + \langle - \gamma^*_b , \mathbf{M_b} \rangle_F)/2`. - - .. note: The current implementation does not account for the entropic contributions and thus differs from the - Sinkhorn divergence as introduced in the literature. The possibility to account for the entropic contributions - will be provided in a future release. 
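To make the note above concrete, the returned quantity is essentially the combination of three regularized OT losses computed with :py:func:`ot.bregman.empirical_sinkhorn2`, clamped at zero (illustrative sketch, not part of the patch; sizes and `reg` are arbitrary, and equality with this function holds up to solver defaults such as warm starts):

>>> import numpy as np
>>> import ot
>>> rng = np.random.RandomState(0)
>>> X_s, X_t = rng.randn(20, 2), rng.randn(30, 2)
>>> reg = 1.0
>>> W_ab = ot.bregman.empirical_sinkhorn2(X_s, X_t, reg)   # cross term
>>> W_aa = ot.bregman.empirical_sinkhorn2(X_s, X_s, reg)   # source autocorrelation term
>>> W_bb = ot.bregman.empirical_sinkhorn2(X_t, X_t, reg)   # target autocorrelation term
>>> div = max(0., W_ab - 0.5 * (W_aa + W_bb))              # symmetrized, non-negative divergence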
- - - Parameters - ---------- - X_s : array-like, shape (n_samples_a, dim) - samples in the source domain - X_t : array-like, shape (n_samples_b, dim) - samples in the target domain - reg : float - Regularization term >0 - a : array-like, shape (n_samples_a,) - samples weights in the source domain - b : array-like, shape (n_samples_b,) - samples weights in the target domain - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - warn : bool, optional - if True, raises a warning if the algorithm doesn't convergence. - warmstart: tuple of arrays, shape (dim_a, dim_b), optional - Initialization of dual potentials. If provided, the dual potentials should be given - (that is the logarithm of the u,v sinkhorn scaling vectors) - - Returns - ------- - W : (1,) array-like - Optimal transportation symmetrized loss for the given parameters - log : dict - log dictionary return only if log==True in parameters - - Examples - -------- - >>> n_samples_a = 2 - >>> n_samples_b = 4 - >>> reg = 0.1 - >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1)) - >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1)) - >>> empirical_sinkhorn_divergence(X_s, X_t, reg) # doctest: +ELLIPSIS - 1.499887176049052 - - - References - ---------- - .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative - Models with Sinkhorn Divergences, Proceedings of the Twenty-First - International Conference on Artificial Intelligence and Statistics, - (AISTATS) 21, 2018 - ''' - X_s, X_t = list_to_array(X_s, X_t) - - nx = get_backend(X_s, X_t) - if warmstart is None: - warmstart_a, warmstart_b = None, None - else: - u, v = warmstart - warmstart_a = (u, u) - warmstart_b = (v, v) - - if log: - sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, - numIterMax=numIterMax, stopThr=stopThr, - verbose=verbose, log=log, warn=warn, - warmstart=warmstart, **kwargs) - - sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, - numIterMax=numIterMax, stopThr=stopThr, - verbose=verbose, log=log, warn=warn, - warmstart=warmstart_a, **kwargs) - - sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, - numIterMax=numIterMax, stopThr=stopThr, - verbose=verbose, log=log, warn=warn, - warmstart=warmstart_b, **kwargs) - - sinkhorn_div = sinkhorn_loss_ab - 0.5 * \ - (sinkhorn_loss_a + sinkhorn_loss_b) - - log = {} - log['sinkhorn_loss_ab'] = sinkhorn_loss_ab - log['sinkhorn_loss_a'] = sinkhorn_loss_a - log['sinkhorn_loss_b'] = sinkhorn_loss_b - log['log_sinkhorn_ab'] = log_ab - log['log_sinkhorn_a'] = log_a - log['log_sinkhorn_b'] = log_b - - return nx.maximum(0, sinkhorn_div), log - - else: - sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, - numIterMax=numIterMax, stopThr=stopThr, - verbose=verbose, log=log, warn=warn, - warmstart=warmstart, **kwargs) - - sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, - numIterMax=numIterMax, stopThr=stopThr, - verbose=verbose, log=log, warn=warn, - warmstart=warmstart_a, **kwargs) - - sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, - numIterMax=numIterMax, stopThr=stopThr, - verbose=verbose, log=log, warn=warn, - warmstart=warmstart_b, **kwargs) - - sinkhorn_div = sinkhorn_loss_ab - 0.5 * \ - (sinkhorn_loss_a + sinkhorn_loss_b) - 
return nx.maximum(0, sinkhorn_div) - - -def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, - restricted=True, maxiter=10000, maxfun=10000, pgtol=1e-09, - verbose=False, log=False): - r""" - Screening Sinkhorn Algorithm for Regularized Optimal Transport - - The function solves an approximate dual of Sinkhorn divergence :ref:`[2] - ` which is written as the following optimization problem: - - .. math:: - - (\mathbf{u}, \mathbf{v}) = \mathop{\arg \min}_{\mathbf{u}, \mathbf{v}} \quad - \mathbf{1}_{ns}^T \mathbf{B}(\mathbf{u}, \mathbf{v}) \mathbf{1}_{nt} - - \langle \kappa \mathbf{u}, \mathbf{a} \rangle - - \langle \frac{1}{\kappa} \mathbf{v}, \mathbf{b} \rangle - - where: - - .. math:: - - \mathbf{B}(\mathbf{u}, \mathbf{v}) = \mathrm{diag}(e^\mathbf{u}) \mathbf{K} \mathrm{diag}(e^\mathbf{v}) \text{, with } \mathbf{K} = e^{-\mathbf{M} / \mathrm{reg}} \text{ and} - - .. math:: - - s.t. \ e^{u_i} &\geq \epsilon / \kappa, \forall i \in \{1, \ldots, ns\} - - e^{v_j} &\geq \epsilon \kappa, \forall j \in \{1, \ldots, nt\} - - The parameters `kappa` and `epsilon` are determined w.r.t the couple number - budget of points (`ns_budget`, `nt_budget`), see Equation (5) - in :ref:`[26] ` - - - Parameters - ---------- - a: array-like, shape=(ns,) - samples weights in the source domain - b: array-like, shape=(nt,) - samples weights in the target domain - M: array-like, shape=(ns, nt) - Cost matrix - reg: `float` - Level of the entropy regularisation - ns_budget: `int`, default=None - Number budget of points to be kept in the source domain. - If it is None then 50% of the source sample points will be kept - nt_budget: `int`, default=None - Number budget of points to be kept in the target domain. - If it is None then 50% of the target sample points will be kept - uniform: `bool`, default=False - If `True`, the source and target distribution are supposed to be uniform, - i.e., :math:`a_i = 1 / ns` and :math:`b_j = 1 / nt` - restricted : `bool`, default=True - If `True`, a warm-start initialization for the L-BFGS-B solver - using a restricted Sinkhorn algorithm with at most 5 iterations - maxiter: `int`, default=10000 - Maximum number of iterations in LBFGS solver - maxfun: `int`, default=10000 - Maximum number of function evaluations in LBFGS solver - pgtol: `float`, default=1e-09 - Final objective function accuracy in LBFGS solver - verbose: `bool`, default=False - If `True`, display informations about the cardinals of the active sets - and the parameters kappa and epsilon - - - .. admonition:: Dependency - - To gain more efficiency, :py:func:`ot.bregman.screenkhorn` needs to call the "Bottleneck" - package (https://pypi.org/project/Bottleneck/) in the screening pre-processing step. - - If Bottleneck isn't installed, the following error message appears: - - "Bottleneck module doesn't exist. Install it from https://pypi.org/project/Bottleneck/" - - - Returns - ------- - gamma : array-like, shape=(ns, nt) - Screened optimal transportation matrix for the given parameters - - log : `dict`, default=False - Log dictionary return only if log==True in parameters - - - .. _references-screenkhorn: - References - ----------- - - .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, - Advances in Neural Information Processing Systems (NIPS) 26, 2013 - - .. [26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). 
- Screening Sinkhorn Algorithm for Regularized Optimal Transport (NIPS) 33, 2019 - - """ - # check if bottleneck module exists - try: - import bottleneck - except ImportError: - warnings.warn( - "Bottleneck module is not installed. Install it from" - " https://pypi.org/project/Bottleneck/ for better performance.") - bottleneck = np - - a, b, M = list_to_array(a, b, M) - - nx = get_backend(M, a, b) - if nx.__name__ in ("jax", "tf"): - raise TypeError("JAX or TF arrays have been received but screenkhorn is not " - "compatible with neither JAX nor TF.") - - ns, nt = M.shape - - # by default, we keep only 50% of the sample data points - if ns_budget is None: - ns_budget = int(np.floor(0.5 * ns)) - if nt_budget is None: - nt_budget = int(np.floor(0.5 * nt)) - - # calculate the Gibbs kernel - K = nx.exp(-M / reg) - - def projection(u, epsilon): - u = nx.maximum(u, epsilon) - return u - - # ----------------------------------------------------------------------------------------------------------------# - # Step 1: Screening pre-processing # - # ----------------------------------------------------------------------------------------------------------------# - - if ns_budget == ns and nt_budget == nt: - # full number of budget points (ns, nt) = (ns_budget, nt_budget) - Isel = nx.from_numpy(np.ones(ns, dtype=bool)) - Jsel = nx.from_numpy(np.ones(nt, dtype=bool)) - epsilon = 0.0 - kappa = 1.0 - - cst_u = 0. - cst_v = 0. - - bounds_u = [(0.0, np.inf)] * ns - bounds_v = [(0.0, np.inf)] * nt - - a_I = a - b_J = b - K_IJ = K - K_IJc = [] - K_IcJ = [] - - vec_eps_IJc = nx.zeros((nt,), type_as=M) - vec_eps_IcJ = nx.zeros((ns,), type_as=M) - - else: - # sum of rows and columns of K - K_sum_cols = nx.sum(K, axis=1) - K_sum_rows = nx.sum(K, axis=0) - - if uniform: - if ns / ns_budget < 4: - aK_sort = nx.sort(K_sum_cols) - epsilon_u_square = a[0] / aK_sort[ns_budget - 1] - else: - aK_sort = nx.from_numpy( - bottleneck.partition(nx.to_numpy( - K_sum_cols), ns_budget - 1)[ns_budget - 1], - type_as=M - ) - epsilon_u_square = a[0] / aK_sort - - if nt / nt_budget < 4: - bK_sort = nx.sort(K_sum_rows) - epsilon_v_square = b[0] / bK_sort[nt_budget - 1] - else: - bK_sort = nx.from_numpy( - bottleneck.partition(nx.to_numpy( - K_sum_rows), nt_budget - 1)[nt_budget - 1], - type_as=M - ) - epsilon_v_square = b[0] / bK_sort - else: - aK = a / K_sum_cols - bK = b / K_sum_rows - - aK_sort = nx.flip(nx.sort(aK), axis=0) - epsilon_u_square = aK_sort[ns_budget - 1] - - bK_sort = nx.flip(nx.sort(bK), axis=0) - epsilon_v_square = bK_sort[nt_budget - 1] - - # active sets I and J (see Lemma 1 in [26]) - Isel = a >= epsilon_u_square * K_sum_cols - Jsel = b >= epsilon_v_square * K_sum_rows - - if nx.sum(Isel) != ns_budget: - if uniform: - aK = a / K_sum_cols - aK_sort = nx.flip(nx.sort(aK), axis=0) - epsilon_u_square = nx.mean(aK_sort[ns_budget - 1:ns_budget + 1]) - Isel = a >= epsilon_u_square * K_sum_cols - ns_budget = nx.sum(Isel) - - if nx.sum(Jsel) != nt_budget: - if uniform: - bK = b / K_sum_rows - bK_sort = nx.flip(nx.sort(bK), axis=0) - epsilon_v_square = nx.mean(bK_sort[nt_budget - 1:nt_budget + 1]) - Jsel = b >= epsilon_v_square * K_sum_rows - nt_budget = nx.sum(Jsel) - - epsilon = (epsilon_u_square * epsilon_v_square) ** (1 / 4) - kappa = (epsilon_v_square / epsilon_u_square) ** (1 / 2) - - if verbose: - print("epsilon = %s\n" % epsilon) - print("kappa = %s\n" % kappa) - print('Cardinality of selected points: |Isel| = %s \t |Jsel| = %s \n' - % (sum(Isel), sum(Jsel))) - - # Ic, Jc: complementary of the active sets I and J - 
Ic = ~Isel - Jc = ~Jsel - - K_IJ = K[np.ix_(Isel, Jsel)] - K_IcJ = K[np.ix_(Ic, Jsel)] - K_IJc = K[np.ix_(Isel, Jc)] - - K_min = nx.min(K_IJ) - if K_min == 0: - K_min = float(np.finfo(float).tiny) - - # a_I, b_J, a_Ic, b_Jc - a_I = a[Isel] - b_J = b[Jsel] - if not uniform: - a_I_min = nx.min(a_I) - a_I_max = nx.max(a_I) - b_J_max = nx.max(b_J) - b_J_min = nx.min(b_J) - else: - a_I_min = a_I[0] - a_I_max = a_I[0] - b_J_max = b_J[0] - b_J_min = b_J[0] - - # box constraints in L-BFGS-B (see Proposition 1 in [26]) - bounds_u = [(max(a_I_min / ((nt - nt_budget) * epsilon + nt_budget * (b_J_max / ( - ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget - - bounds_v = [( - max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))), - epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget - - # pre-calculated constants for the objective - vec_eps_IJc = epsilon * kappa * nx.sum( - K_IJc * nx.ones((nt - nt_budget,), type_as=M)[None, :], - axis=1 - ) - vec_eps_IcJ = (epsilon / kappa) * nx.sum( - nx.ones((ns - ns_budget,), type_as=M)[:, None] * K_IcJ, - axis=0 - ) - - # initialisation - u0 = nx.full((ns_budget,), 1. / ns_budget + epsilon / kappa, type_as=M) - v0 = nx.full((nt_budget,), 1. / nt_budget + epsilon * kappa, type_as=M) - - # pre-calculed constants for Restricted Sinkhorn (see Algorithm 1 in supplementary of [26]) - if restricted: - if ns_budget != ns or nt_budget != nt: - cst_u = kappa * epsilon * nx.sum(K_IJc, axis=1) - cst_v = epsilon * nx.sum(K_IcJ, axis=0) / kappa - - for _ in range(5): # 5 iterations - K_IJ_v = nx.dot(K_IJ.T, u0) + cst_v - v0 = b_J / (kappa * K_IJ_v) - KIJ_u = nx.dot(K_IJ, v0) + cst_u - u0 = (kappa * a_I) / KIJ_u - - u0 = projection(u0, epsilon / kappa) - v0 = projection(v0, epsilon * kappa) - - else: - u0 = u0 - v0 = v0 - - def restricted_sinkhorn(usc, vsc, max_iter=5): - """ - Restricted Sinkhorn Algorithm as a warm-start initialized pointfor L-BFGS-B) - """ - for _ in range(max_iter): - K_IJ_v = nx.dot(K_IJ.T, usc) + cst_v - vsc = b_J / (kappa * K_IJ_v) - KIJ_u = nx.dot(K_IJ, vsc) + cst_u - usc = (kappa * a_I) / KIJ_u - - usc = projection(usc, epsilon / kappa) - vsc = projection(vsc, epsilon * kappa) - - return usc, vsc - - def screened_obj(usc, vsc): - part_IJ = ( - nx.dot(nx.dot(usc, K_IJ), vsc) - - kappa * nx.dot(a_I, nx.log(usc)) - - (1. / kappa) * nx.dot(b_J, nx.log(vsc)) - ) - part_IJc = nx.dot(usc, vec_eps_IJc) - part_IcJ = nx.dot(vec_eps_IcJ, vsc) - psi_epsilon = part_IJ + part_IJc + part_IcJ - return psi_epsilon - - def screened_grad(usc, vsc): - # gradients of Psi_(kappa,epsilon) w.r.t u and v - grad_u = nx.dot(K_IJ, vsc) + vec_eps_IJc - kappa * a_I / usc - grad_v = nx.dot(K_IJ.T, usc) + vec_eps_IcJ - (1. 
/ kappa) * b_J / vsc - return grad_u, grad_v - - def bfgspost(theta): - u = theta[:ns_budget] - v = theta[ns_budget:] - # objective - f = screened_obj(u, v) - # gradient - g_u, g_v = screened_grad(u, v) - g = nx.concatenate([g_u, g_v], axis=0) - return nx.to_numpy(f), nx.to_numpy(g) - - # ----------------------------------------------------------------------------------------------------------------# - # Step 2: L-BFGS-B solver # - # ----------------------------------------------------------------------------------------------------------------# - - u0, v0 = restricted_sinkhorn(u0, v0) - theta0 = nx.concatenate([u0, v0], axis=0) - - bounds = bounds_u + bounds_v # constraint bounds - - def obj(theta): - return bfgspost(nx.from_numpy(theta, type_as=M)) - - theta, _, _ = fmin_l_bfgs_b(func=obj, - x0=theta0, - bounds=bounds, - maxfun=maxfun, - pgtol=pgtol, - maxiter=maxiter) - theta = nx.from_numpy(theta, type_as=M) - - usc = theta[:ns_budget] - vsc = theta[ns_budget:] - - usc_full = nx.full((ns,), epsilon / kappa, type_as=M) - vsc_full = nx.full((nt,), epsilon * kappa, type_as=M) - usc_full[Isel] = usc - vsc_full[Jsel] = vsc - - if log: - log = {} - log['u'] = usc_full - log['v'] = vsc_full - log['Isel'] = Isel - log['Jsel'] = Jsel - - gamma = usc_full[:, None] * K * vsc_full[None, :] - gamma = gamma / nx.sum(gamma) - - if log: - return gamma, log - else: - return gamma diff --git a/ot/bregman/__init__.py b/ot/bregman/__init__.py new file mode 100644 index 000000000..230982e9c --- /dev/null +++ b/ot/bregman/__init__.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +""" +Solvers related to Bregman projections for entropic regularized OT +""" + +# Author: Remi Flamary +# Cédric Vincent-Cuaz +# +# License: MIT License + +from ._utils import (geometricBar, + geometricMean, + projR, + projC) + +from ._sinkhorn import (sinkhorn, + sinkhorn2, + sinkhorn_knopp, + sinkhorn_log, + greenkhorn, + sinkhorn_stabilized, + sinkhorn_epsilon_scaling) + +from ._barycenter import (barycenter, + barycenter_sinkhorn, + free_support_sinkhorn_barycenter, + barycenter_stabilized, + barycenter_debiased, + jcpot_barycenter) + +from ._convolutional import (convolutional_barycenter2d, + convolutional_barycenter2d_debiased) + +from ._empirical import (empirical_sinkhorn, + empirical_sinkhorn2, + empirical_sinkhorn_divergence) + +from ._screenkhorn import (screenkhorn) + +from ._dictionary import (unmix) + + +__all__ = ['geometricBar', 'geometricMean', 'projR', 'projC', + 'sinkhorn', 'sinkhorn2', 'sinkhorn_knopp', 'sinkhorn_log', + 'greenkhorn', 'sinkhorn_stabilized', 'sinkhorn_epsilon_scaling', + 'barycenter', 'barycenter_sinkhorn', 'free_support_sinkhorn_barycenter', + 'barycenter_stabilized', 'barycenter_debiased', 'jcpot_barycenter', + 'convolutional_barycenter2d', 'convolutional_barycenter2d_debiased', + 'empirical_sinkhorn', 'empirical_sinkhorn2', + 'empirical_sinkhorn_divergence', + 'screenkhorn', + 'unmix' + ] diff --git a/ot/bregman/_barycenter.py b/ot/bregman/_barycenter.py new file mode 100644 index 000000000..abef77d63 --- /dev/null +++ b/ot/bregman/_barycenter.py @@ -0,0 +1,922 @@ +# -*- coding: utf-8 -*- +""" +Bregman projections solvers for entropic regularized wasserstein barycenters +""" + +# Author: Remi Flamary +# Nicolas Courty +# Hicham Janati +# Ievgen Redko +# +# License: MIT License + +import warnings +import numpy as np + +from ..utils import dist, list_to_array, unif +from ..backend import get_backend + +from ._utils import geometricBar, geometricMean, projR, projC +from ._sinkhorn import sinkhorn + + 
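As an editorial aside (not part of the patch itself), the re-exports in the new `ot/bregman/__init__.py` above are meant to keep the historical import paths usable after the split into submodules, e.g.:

>>> import numpy as np
>>> import ot
>>> from ot.bregman import sinkhorn, barycenter    # same public names as before the split
>>> a, b = ot.unif(3), ot.unif(4)
>>> M = ot.dist(np.arange(3.)[:, None], np.arange(4.)[:, None])
>>> sinkhorn(a, b, M, reg=1.0).shape
(3, 4)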
+def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, + stopThr=1e-4, verbose=False, log=False, warn=True, **kwargs): + r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` + + The function solves the following optimization problem: + + .. math:: + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein + distance (see :py:func:`ot.bregman.sinkhorn`) + if `method` is `sinkhorn` or `sinkhorn_stabilized` or `sinkhorn_log`. + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix + :math:`\mathbf{A}` + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and + the cost matrix for OT + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling + algorithm as proposed in :ref:`[3] ` + + Parameters + ---------- + A : array-like, shape (dim, n_hists) + `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim` + M : array-like, shape (dim, dim) + loss matrix for OT + reg : float + Regularization term > 0 + method : str (optional) + method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized' or 'sinkhorn_log' + weights : array-like, shape (n_hists,) + Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. + + + Returns + ------- + a : (dim,) array-like + Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + + .. _references-barycenter: + References + ---------- + + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). + Iterative Bregman projections for regularized transportation problems. + SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + + """ + + if method.lower() == 'sinkhorn': + return barycenter_sinkhorn(A, M, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + warn=warn, + **kwargs) + elif method.lower() == 'sinkhorn_stabilized': + return barycenter_stabilized(A, M, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, **kwargs) + elif method.lower() == 'sinkhorn_log': + return _barycenter_sinkhorn_log(A, M, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) + + +def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000, + stopThr=1e-4, verbose=False, log=False, warn=True): + r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` + + The function solves the following optimization problem: + + .. 
math:: + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance + (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix + :math:`\mathbf{A}` + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and + the cost matrix for OT + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix + scaling algorithm as proposed in :ref:`[3]`. + + Parameters + ---------- + A : array-like, shape (dim, n_hists) + `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim` + M : array-like, shape (dim, dim) + loss matrix for OT + reg : float + Regularization term > 0 + weights : array-like, shape (n_hists,) + Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. + + + Returns + ------- + a : (dim,) array-like + Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + + .. _references-barycenter-sinkhorn: + References + ---------- + + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). + Iterative Bregman projections for regularized transportation problems. + SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + + """ + + A, M = list_to_array(A, M) + + nx = get_backend(A, M) + + if weights is None: + weights = nx.ones((A.shape[1],), type_as=A) / A.shape[1] + else: + assert (len(weights) == A.shape[1]) + + if log: + log = {'err': []} + + K = nx.exp(-M / reg) + + err = 1 + + UKv = nx.dot(K, (A.T / nx.sum(K, axis=0)).T) + + u = (geometricMean(UKv) / UKv.T).T + + for ii in range(numItermax): + + UKv = u * nx.dot(K.T, A / nx.dot(K, u)) + u = (u.T * geometricBar(weights, UKv)).T / UKv + + if ii % 10 == 1: + err = nx.sum(nx.std(UKv, axis=1)) + + # log and verbose print + if log: + log['err'].append(err) + + if err < stopThr: + break + if verbose: + if ii % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") + if log: + log['niter'] = ii + return geometricBar(weights, UKv), log + else: + return geometricBar(weights, UKv) + + +def free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg, b=None, weights=None, + numItermax=100, numInnerItermax=1000, stopThr=1e-7, verbose=False, log=None, + **kwargs): + r""" + Solves the free support (locations of the barycenters are optimized, not the weights) regularized Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Sinkhorn divergence), formally: + + .. 
math:: + \min_\mathbf{X} \quad \sum_{i=1}^N w_i W_{reg}^2(\mathbf{b}, \mathbf{X}, \mathbf{a}_i, \mathbf{X}_i) + + where : + + - :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one + - `measure_weights` denotes the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}`: empirical measures weights (on simplex) + - `measures_locations` denotes the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}`: empirical measures atoms locations + - :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter + + This problem is considered in :ref:`[20] ` (Algorithm 2). + There are two differences with the following codes: + + - we do not optimize over the weights + - we do not do line search for the locations updates, we use i.e. :math:`\theta = 1` in + :ref:`[20] ` (Algorithm 2). This can be seen as a discrete + implementation of the fixed-point algorithm of + :ref:`[43] ` proposed in the continuous setting. + - at each iteration, instead of solving an exact OT problem, we use the Sinkhorn algorithm for calculating the + transport plan in :ref:`[20] ` (Algorithm 2). + + Parameters + ---------- + measures_locations : list of N (k_i,d) array-like + The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space + (:math:`k_i` can be different for each element of the list) + measures_weights : list of N (k_i,) array-like + Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one + representing the weights of each discrete input measure + + X_init : (k,d) array-like + Initialization of the support locations (on `k` atoms) of the barycenter + reg : float + Regularization term >0 + b : (k,) array-like + Initialization of the weights of the barycenter (non-negatives, sum to 1) + weights : (N,) array-like + Initialization of the coefficients of the barycenter (non-negatives, sum to 1) + + numItermax : int, optional + Max number of iterations + numInnerItermax : int, optional + Max number of iterations when calculating the transport plans with Sinkhorn + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + Returns + ------- + X : (k,d) array-like + Support locations (on k atoms) of the barycenter + + See Also + -------- + ot.bregman.sinkhorn : Entropic regularized OT solver + ot.lp.free_support_barycenter : Barycenter solver based on Linear Programming + + .. _references-free-support-barycenter: + References + ---------- + .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014. + + .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. + + """ + nx = get_backend(*measures_locations, *measures_weights, X_init) + + iter_count = 0 + + N = len(measures_locations) + k = X_init.shape[0] + d = X_init.shape[1] + if b is None: + b = nx.ones((k,), type_as=X_init) / k + if weights is None: + weights = nx.ones((N,), type_as=X_init) / N + + X = X_init + + log_dict = {} + displacement_square_norms = [] + + displacement_square_norm = stopThr + 1. 
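+
+    # Fixed-point iteration (Algorithm 2 of [20] with step theta = 1): at each
+    # pass an entropic plan T_i is computed between the current support X and
+    # every input measure, X is replaced by the weighted barycentric projection
+    # of these plans, and the loop stops once the squared displacement of the
+    # support falls below stopThr.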
+ + while (displacement_square_norm > stopThr and iter_count < numItermax): + + T_sum = nx.zeros((k, d), type_as=X_init) + + for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights): + M_i = dist(X, measure_locations_i) + T_i = sinkhorn(b, measure_weights_i, M_i, reg=reg, + numItermax=numInnerItermax, **kwargs) + T_sum = T_sum + weight_i * 1. / \ + b[:, None] * nx.dot(T_i, measure_locations_i) + + displacement_square_norm = nx.sum((T_sum - X) ** 2) + if log: + displacement_square_norms.append(displacement_square_norm) + + X = T_sum + + if verbose: + print('iteration %d, displacement_square_norm=%f\n', + iter_count, displacement_square_norm) + + iter_count += 1 + + if log: + log_dict['displacement_square_norms'] = displacement_square_norms + return X, log_dict + else: + return X + + +def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000, + stopThr=1e-4, verbose=False, log=False, warn=True): + r"""Compute the entropic wasserstein barycenter in log-domain + """ + + A, M = list_to_array(A, M) + dim, n_hists = A.shape + + nx = get_backend(A, M) + + if nx.__name__ in ("jax", "tf"): + raise NotImplementedError( + "Log-domain functions are not yet implemented" + " for Jax and tf. Use numpy or torch arrays instead." + ) + + if weights is None: + weights = nx.ones(n_hists, type_as=A) / n_hists + else: + assert (len(weights) == A.shape[1]) + + if log: + log = {'err': []} + + M = - M / reg + logA = nx.log(A + 1e-15) + log_KU, G = nx.zeros((2, *logA.shape), type_as=A) + err = 1 + for ii in range(numItermax): + log_bar = nx.zeros(dim, type_as=A) + for k in range(n_hists): + f = logA[:, k] - nx.logsumexp(M + G[None, :, k], axis=1) + log_KU[:, k] = nx.logsumexp(M + f[:, None], axis=0) + log_bar = log_bar + weights[k] * log_KU[:, k] + + if ii % 10 == 1: + err = nx.exp(G + log_KU).std(axis=1).sum() + + # log and verbose print + if log: + log['err'].append(err) + + if err < stopThr: + break + if verbose: + if ii % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + + G = log_bar[:, None] - log_KU + + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") + if log: + log['niter'] = ii + return nx.exp(log_bar), log + else: + return nx.exp(log_bar) + + +def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000, + stopThr=1e-4, verbose=False, log=False, warn=True): + r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` with stabilization. + + The function solves the following optimization problem: + + .. 
math:: + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein + distance (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix + :math:`\mathbf{A}` + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and + the cost matrix for OT + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling + algorithm as proposed in :ref:`[3] ` + + Parameters + ---------- + A : array-like, shape (dim, n_hists) + `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim` + M : array-like, shape (dim, dim) + loss matrix for OT + reg : float + Regularization term > 0 + tau : float + threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` + for log scaling + weights : array-like, shape (n_hists,) + Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. + + + Returns + ------- + a : (dim,) array-like + Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + + .. _references-barycenter-stabilized: + References + ---------- + + .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). + Iterative Bregman projections for regularized transportation problems. + SIAM Journal on Scientific Computing, 37(2), A1111-A1138. + + """ + + A, M = list_to_array(A, M) + + nx = get_backend(A, M) + + dim, n_hists = A.shape + if weights is None: + weights = nx.ones((n_hists,), type_as=M) / n_hists + else: + assert (len(weights) == A.shape[1]) + + if log: + log = {'err': []} + + u = nx.ones((dim, n_hists), type_as=M) / dim + v = nx.ones((dim, n_hists), type_as=M) / dim + + K = nx.exp(-M / reg) + + err = 1. + alpha = nx.zeros((dim,), type_as=M) + beta = nx.zeros((dim,), type_as=M) + q = nx.ones((dim,), type_as=M) / dim + for ii in range(numItermax): + qprev = q + Kv = nx.dot(K, v) + u = A / Kv + Ktu = nx.dot(K.T, u) + q = geometricBar(weights, Ktu) + Q = q[:, None] + v = Q / Ktu + absorbing = False + if nx.any(u > tau) or nx.any(v > tau): + absorbing = True + alpha += reg * nx.log(nx.max(u, 1)) + beta += reg * nx.log(nx.max(v, 1)) + K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg) + v = nx.ones(tuple(v.shape), type_as=v) + Kv = nx.dot(K, v) + if (nx.any(Ktu == 0.) + or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): + # we have reached the machine precision + # come back to previous solution and quit loop + warnings.warn('Numerical errors at iteration %s' % ii) + q = qprev + break + if (ii % 10 == 0 and not absorbing) or ii == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + err = nx.max(nx.abs(u * Kv - A)) + if log: + log['err'].append(err) + if err < stopThr: + break + if verbose: + if ii % 50 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + + else: + if warn: + warnings.warn("Stabilized Sinkhorn did not converge." 
+ + "Try a larger entropy `reg`" + + "Or a larger absorption threshold `tau`.") + if log: + log['niter'] = ii + log['logu'] = nx.log(u + 1e-16) + log['logv'] = nx.log(v + 1e-16) + return q, log + else: + return q + + +def barycenter_debiased(A, M, reg, weights=None, method="sinkhorn", numItermax=10000, + stopThr=1e-4, verbose=False, log=False, warn=True, **kwargs): + r"""Compute the debiased Sinkhorn barycenter of distributions A + + The function solves the following optimization problem: + + .. math:: + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i S_{reg}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`S_{reg}(\cdot,\cdot)` is the debiased Sinkhorn divergence + (see :py:func:`ot.bregman.empirical_sinkhorn_divergence`) + - :math:`\mathbf{a}_i` are training distributions in the columns of matrix + :math:`\mathbf{A}` + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and + the cost matrix for OT + + The algorithm used for solving the problem is the debiased Sinkhorn + algorithm as proposed in :ref:`[37] ` + + Parameters + ---------- + A : array-like, shape (dim, n_hists) + `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim` + M : array-like, shape (dim, dim) + loss matrix for OT + reg : float + Regularization term > 0 + method : str (optional) + method used for the solver either 'sinkhorn' or 'sinkhorn_log' + weights : array-like, shape (n_hists,) + Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coodinates) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. + + + Returns + ------- + a : (dim,) array-like + Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + + .. _references-barycenter-debiased: + References + ---------- + .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International + Conference on Machine Learning, PMLR 119:4692-4701, 2020 + """ + + if method.lower() == 'sinkhorn': + return _barycenter_debiased(A, M, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + warn=warn, **kwargs) + elif method.lower() == 'sinkhorn_log': + return _barycenter_debiased_log(A, M, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) + + +def _barycenter_debiased(A, M, reg, weights=None, numItermax=1000, + stopThr=1e-4, verbose=False, log=False, warn=True): + r"""Compute the debiased sinkhorn barycenter of distributions A. + """ + + A, M = list_to_array(A, M) + + nx = get_backend(A, M) + + if weights is None: + weights = nx.ones((A.shape[1],), type_as=A) / A.shape[1] + else: + assert (len(weights) == A.shape[1]) + + if log: + log = {'err': []} + + K = nx.exp(-M / reg) + + err = 1 + + UKv = nx.dot(K, (A.T / nx.sum(K, axis=0)).T) + + u = (geometricMean(UKv) / UKv.T).T + c = nx.ones(A.shape[0], type_as=A) + bar = nx.ones(A.shape[0], type_as=A) + + for ii in range(numItermax): + bold = bar + UKv = nx.dot(K, A / nx.dot(K, u)) + bar = c * geometricBar(weights, UKv) + u = bar[:, None] / UKv + c = (c * bar / nx.dot(K, c)) ** 0.5 + + if ii % 10 == 9: + err = abs(bar - bold).max() / max(bar.max(), 1.) 
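+            # relative sup-norm change of the barycenter since the previous check;
+            # the extra scaling vector c updated above implements the debiasing of [37]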
+ + # log and verbose print + if log: + log['err'].append(err) + + # debiased Sinkhorn does not converge monotonically + # guarantee a few iterations are done before stopping + if err < stopThr and ii > 20: + break + if verbose: + if ii % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") + if log: + log['niter'] = ii + return bar, log + else: + return bar + + +def _barycenter_debiased_log(A, M, reg, weights=None, numItermax=1000, + stopThr=1e-4, verbose=False, log=False, + warn=True): + r"""Compute the debiased sinkhorn barycenter in log domain. + """ + + A, M = list_to_array(A, M) + dim, n_hists = A.shape + + nx = get_backend(A, M) + if nx.__name__ in ("jax", "tf"): + raise NotImplementedError( + "Log-domain functions are not yet implemented" + " for Jax and TF. Use numpy or torch arrays instead." + ) + + if weights is None: + weights = nx.ones(n_hists, type_as=A) / n_hists + else: + assert (len(weights) == A.shape[1]) + + if log: + log = {'err': []} + + M = - M / reg + logA = nx.log(A + 1e-15) + log_KU, G = nx.zeros((2, *logA.shape), type_as=A) + c = nx.zeros(dim, type_as=A) + err = 1 + for ii in range(numItermax): + log_bar = nx.zeros(dim, type_as=A) + for k in range(n_hists): + f = logA[:, k] - nx.logsumexp(M + G[None, :, k], axis=1) + log_KU[:, k] = nx.logsumexp(M + f[:, None], axis=0) + log_bar += weights[k] * log_KU[:, k] + log_bar += c + if ii % 10 == 1: + err = nx.exp(G + log_KU).std(axis=1).sum() + + # log and verbose print + if log: + log['err'].append(err) + + if err < stopThr and ii > 20: + break + if verbose: + if ii % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + + G = log_bar[:, None] - log_KU + for _ in range(10): + c = 0.5 * (c + log_bar - nx.logsumexp(M + c[:, None], axis=0)) + + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") + if log: + log['niter'] = ii + return nx.exp(log_bar), log + else: + return nx.exp(log_bar) + + +def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100, + stopThr=1e-6, verbose=False, log=False, warn=True, **kwargs): + r'''Joint OT and proportion estimation for multi-source target shift as + proposed in :ref:`[27] ` + + The function solves the following optimization problem: + + .. math:: + + \mathbf{h} = \mathop{\arg \min}_{\mathbf{h}} \quad \sum_{k=1}^{K} \lambda_k + W_{reg}((\mathbf{D}_2^{(k)} \mathbf{h})^T, \mathbf{a}) + + s.t. \ \forall k, \mathbf{D}_1^{(k)} \gamma_k \mathbf{1}_n= \mathbf{h} + + where : + + - :math:`\lambda_k` is the weight of `k`-th source domain + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance + (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to `k`-th source domain + defined as in [p. 
5, :ref:`27 `], its expected shape + is :math:`(n_k, C)` where :math:`n_k` is the number of elements in the `k`-th source + domain and `C` is the number of classes + - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size `C` + - :math:`\mathbf{a}` is a uniform vector of weights in the target domain of size `n` + - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in + [p. 5, :ref:`27 `], its expected shape is :math:`(n_k, C)` + + The problem consist in solving a Wasserstein barycenter problem to estimate + the proportions :math:`\mathbf{h}` in the target domain. + + The algorithm used for solving the problem is the Iterative Bregman projections algorithm + with two sets of marginal constraints related to the unknown vector + :math:`\mathbf{h}` and uniform target distribution. + + Parameters + ---------- + Xs : list of K array-like(nsk,d) + features of all source domains' samples + Ys : list of K array-like(nsk,) + labels of all source domains' samples + Xt : array-like (nt,d) + samples in the target domain + reg : float + Regularization term > 0 + metric : string, optional (default="sqeuclidean") + The ground metric for the Wasserstein problem + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on relative change in the barycenter (>0) + verbose : bool, optional (default=False) + Controls the verbosity of the optimization algorithm + log : bool, optional + record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. + + Returns + ------- + h : (C,) array-like + proportion estimation in the target domain + log : dict + log dictionary return only if log==True in parameters + + + .. _references-jcpot-barycenter: + References + ---------- + + .. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia + "Optimal transport for multi-source domain adaptation under target shift", + International Conference on Artificial Intelligence and Statistics (AISTATS), 2019. + ''' + + Xs = list_to_array(*Xs) + Ys = list_to_array(*Ys) + Xt = list_to_array(Xt) + + nx = get_backend(*Xs, *Ys, Xt) + + nbclasses = len(nx.unique(Ys[0])) + nbdomains = len(Xs) + + # log dictionary + if log: + log = {'niter': 0, 'err': [], 'M': [], 'D1': [], 'D2': [], 'gamma': []} + + K = [] + M = [] + D1 = [] + D2 = [] + + # For each source domain, build cost matrices M, Gibbs kernels K and corresponding matrices D_1 and D_2 + for d in range(nbdomains): + dom = {} + nsk = Xs[d].shape[0] # get number of elements for this domain + dom['nbelem'] = nsk + classes = nx.unique(Ys[d]) # get number of classes for this domain + + # format classes to start from 0 for convenience + if nx.min(classes) != 0: + Ys[d] -= nx.min(classes) + classes = nx.unique(Ys[d]) + + # build the corresponding D_1 and D_2 matrices + Dtmp1 = np.zeros((nbclasses, nsk)) + Dtmp2 = np.zeros((nbclasses, nsk)) + + for c in classes: + nbelemperclass = float(nx.sum(Ys[d] == c)) + if nbelemperclass != 0: + Dtmp1[int(c), nx.to_numpy(Ys[d] == c)] = 1. + Dtmp2[int(c), nx.to_numpy(Ys[d] == c)] = 1. 
/ (nbelemperclass) + D1.append(nx.from_numpy(Dtmp1, type_as=Xs[0])) + D2.append(nx.from_numpy(Dtmp2, type_as=Xs[0])) + + # build the cost matrix and the Gibbs kernel + Mtmp = dist(Xs[d], Xt, metric=metric) + M.append(Mtmp) + + Ktmp = nx.exp(-Mtmp / reg) + K.append(Ktmp) + + # uniform target distribution + a = nx.from_numpy(unif(Xt.shape[0]), type_as=Xs[0]) + + err = 1 + old_bary = nx.ones((nbclasses,), type_as=Xs[0]) + + for ii in range(numItermax): + + bary = nx.zeros((nbclasses,), type_as=Xs[0]) + + # update coupling matrices for marginal constraints w.r.t. uniform target distribution + for d in range(nbdomains): + K[d] = projC(K[d], a) + other = nx.sum(K[d], axis=1) + bary += nx.log(nx.dot(D1[d], other)) / nbdomains + + bary = nx.exp(bary) + + # update coupling matrices for marginal constraints w.r.t. unknown proportions based on [Prop 4., 27] + for d in range(nbdomains): + new = nx.dot(D2[d].T, bary) + K[d] = projR(K[d], new) + + err = nx.norm(bary - old_bary) + + old_bary = bary + + if log: + log['err'].append(err) + + if err < stopThr: + break + if verbose: + if ii % 200 == 0: + print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + else: + if warn: + warnings.warn("Algorithm did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") + bary = bary / nx.sum(bary) + + if log: + log['niter'] = ii + log['M'] = M + log['D1'] = D1 + log['D2'] = D2 + log['gamma'] = K + return bary, log + else: + return bary diff --git a/ot/bregman/_convolutional.py b/ot/bregman/_convolutional.py new file mode 100644 index 000000000..baea2aec3 --- /dev/null +++ b/ot/bregman/_convolutional.py @@ -0,0 +1,492 @@ +# -*- coding: utf-8 -*- +""" +Bregman projections solvers for entropic regularized Wasserstein convolutional barycenters +""" + +# Author: Remi Flamary +# Nicolas Courty +# +# License: MIT License + +import warnings + +from ..utils import list_to_array +from ..backend import get_backend + + +def convolutional_barycenter2d(A, reg, weights=None, method="sinkhorn", numItermax=10000, + stopThr=1e-4, verbose=False, log=False, + warn=True, **kwargs): + r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` + where :math:`\mathbf{A}` is a collection of 2D images. + + The function solves the following optimization problem: + + .. 
math:: + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein + distance (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two dimensions + of matrix :math:`\mathbf{A}` + - `reg` is the regularization strength scalar value + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm + as proposed in :ref:`[21] ` + + Parameters + ---------- + A : array-like, shape (n_hists, width, height) + `n` distributions (2D images) of size `width` x `height` + reg : float + Regularization term >0 + weights : array-like, shape (n_hists,) + Weights of each image on the simplex (barycentric coodinates) + method : string, optional + method used for the solver either 'sinkhorn' or 'sinkhorn_log' + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (> 0) + stabThr : float, optional + Stabilization threshold to avoid numerical precision issue + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. + + Returns + ------- + a : array-like, shape (width, height) + 2D Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + + .. _references-convolutional-barycenter-2d: + References + ---------- + + .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, + A., Nguyen, A. & Guibas, L. (2015). Convolutional wasserstein distances: + Efficient optimal transportation on geometric domains. ACM Transactions + on Graphics (TOG), 34(4), 66 + + .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th + International Conference on Machine Learning, PMLR 119:4692-4701, 2020 + """ + + if method.lower() == 'sinkhorn': + return _convolutional_barycenter2d(A, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, + **kwargs) + elif method.lower() == 'sinkhorn_log': + return _convolutional_barycenter2d_log(A, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, + **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) + + +def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000, + stopThr=1e-9, stabThr=1e-30, verbose=False, + log=False, warn=True): + r"""Compute the entropic regularized wasserstein barycenter of distributions A + where A is a collection of 2D images. 
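+
+    The squared-distance Gibbs kernel on the pixel grid factorizes into two 1D
+    Gaussian kernels (one per image axis), so each Sinkhorn sweep is applied
+    with two small matrix products (see ``convol_imgs`` below) instead of a
+    dense (width x height)^2 kernel, following [21].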
+ """ + + A = list_to_array(A) + + nx = get_backend(A) + + if weights is None: + weights = nx.ones((A.shape[0],), type_as=A) / A.shape[0] + else: + assert (len(weights) == A.shape[0]) + + if log: + log = {'err': []} + + bar = nx.ones(A.shape[1:], type_as=A) + bar /= nx.sum(bar) + U = nx.ones(A.shape, type_as=A) + V = nx.ones(A.shape, type_as=A) + err = 1 + + # build the convolution operator + # this is equivalent to blurring on horizontal then vertical directions + t = nx.linspace(0, 1, A.shape[1], type_as=A) + [Y, X] = nx.meshgrid(t, t) + K1 = nx.exp(-(X - Y) ** 2 / reg) + + t = nx.linspace(0, 1, A.shape[2], type_as=A) + [Y, X] = nx.meshgrid(t, t) + K2 = nx.exp(-(X - Y) ** 2 / reg) + + def convol_imgs(imgs): + kx = nx.einsum("...ij,kjl->kil", K1, imgs) + kxy = nx.einsum("...ij,klj->kli", K2, kx) + return kxy + + KU = convol_imgs(U) + for ii in range(numItermax): + V = bar[None] / KU + KV = convol_imgs(V) + U = A / KV + KU = convol_imgs(U) + bar = nx.exp( + nx.sum(weights[:, None, None] * nx.log(KU + stabThr), axis=0) + ) + if ii % 10 == 9: + err = nx.sum(nx.std(V * KU, axis=0)) + # log and verbose print + if log: + log['err'].append(err) + + if verbose: + if ii % 200 == 0: + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + if err < stopThr: + break + + else: + if warn: + warnings.warn("Convolutional Sinkhorn did not converge. " + "Try a larger number of iterations `numItermax` " + "or a larger entropy `reg`.") + if log: + log['niter'] = ii + log['U'] = U + return bar, log + else: + return bar + + +def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000, + stopThr=1e-4, stabThr=1e-30, verbose=False, + log=False, warn=True): + r"""Compute the entropic regularized wasserstein barycenter of distributions A + where A is a collection of 2D images in log-domain. + """ + + A = list_to_array(A) + + nx = get_backend(A) + if nx.__name__ in ("jax", "tf"): + raise NotImplementedError( + "Log-domain functions are not yet implemented" + " for Jax and TF. Use numpy or torch arrays instead." + ) + + n_hists, width, height = A.shape + + if weights is None: + weights = nx.ones((n_hists,), type_as=A) / n_hists + else: + assert (len(weights) == n_hists) + + if log: + log = {'err': []} + + err = 1 + # build the convolution operator + # this is equivalent to blurring on horizontal then vertical directions + t = nx.linspace(0, 1, width, type_as=A) + [Y, X] = nx.meshgrid(t, t) + M1 = - (X - Y) ** 2 / reg + + t = nx.linspace(0, 1, height, type_as=A) + [Y, X] = nx.meshgrid(t, t) + M2 = - (X - Y) ** 2 / reg + + def convol_img(log_img): + log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1) + log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T + return log_img + + logA = nx.log(A + stabThr) + log_KU, G, F = nx.zeros((3, *logA.shape), type_as=A) + err = 1 + for ii in range(numItermax): + log_bar = nx.zeros((width, height), type_as=A) + for k in range(n_hists): + f = logA[k] - convol_img(G[k]) + log_KU[k] = convol_img(f) + log_bar = log_bar + weights[k] * log_KU[k] + + if ii % 10 == 9: + err = nx.exp(G + log_KU).std(axis=0).sum() + # log and verbose print + if log: + log['err'].append(err) + + if verbose: + if ii % 200 == 0: + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + if err < stopThr: + break + G = log_bar[None, :, :] - log_KU + + else: + if warn: + warnings.warn("Convolutional Sinkhorn did not converge. 
" + "Try a larger number of iterations `numItermax` " + "or a larger entropy `reg`.") + if log: + log['niter'] = ii + return nx.exp(log_bar), log + else: + return nx.exp(log_bar) + + +def convolutional_barycenter2d_debiased(A, reg, weights=None, method="sinkhorn", + numItermax=10000, stopThr=1e-3, + verbose=False, log=False, warn=True, + **kwargs): + r"""Compute the debiased sinkhorn barycenter of distributions :math:`\mathbf{A}` + where :math:`\mathbf{A}` is a collection of 2D images. + + The function solves the following optimization problem: + + .. math:: + \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i S_{reg}(\mathbf{a},\mathbf{a}_i) + + where : + + - :math:`S_{reg}(\cdot,\cdot)` is the debiased entropic regularized Wasserstein + distance (see :py:func:`ot.bregman.barycenter_debiased`) + - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two + dimensions of matrix :math:`\mathbf{A}` + - `reg` is the regularization strength scalar value + + The algorithm used for solving the problem is the debiased Sinkhorn scaling + algorithm as proposed in :ref:`[37] ` + + Parameters + ---------- + A : array-like, shape (n_hists, width, height) + `n` distributions (2D images) of size `width` x `height` + reg : float + Regularization term >0 + weights : array-like, shape (n_hists,) + Weights of each image on the simplex (barycentric coodinates) + method : string, optional + method used for the solver either 'sinkhorn' or 'sinkhorn_log' + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (> 0) + stabThr : float, optional + Stabilization threshold to avoid numerical precision issue + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. + + + Returns + ------- + a : array-like, shape (width, height) + 2D Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + + .. _references-convolutional-barycenter2d-debiased: + References + ---------- + + .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International + Conference on Machine Learning, PMLR 119:4692-4701, 2020 + """ + + if method.lower() == 'sinkhorn': + return _convolutional_barycenter2d_debiased(A, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, + **kwargs) + elif method.lower() == 'sinkhorn_log': + return _convolutional_barycenter2d_debiased_log(A, reg, weights=weights, + numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, + **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) + + +def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000, + stopThr=1e-3, stabThr=1e-15, verbose=False, + log=False, warn=True): + r"""Compute the debiased barycenter of 2D images via sinkhorn convolutions. 
+ """ + + A = list_to_array(A) + n_hists, width, height = A.shape + + nx = get_backend(A) + + if weights is None: + weights = nx.ones((n_hists,), type_as=A) / n_hists + else: + assert (len(weights) == n_hists) + + if log: + log = {'err': []} + + bar = nx.ones((width, height), type_as=A) + bar /= width * height + U = nx.ones(A.shape, type_as=A) + V = nx.ones(A.shape, type_as=A) + c = nx.ones(A.shape[1:], type_as=A) + err = 1 + + # build the convolution operator + # this is equivalent to blurring on horizontal then vertical directions + t = nx.linspace(0, 1, width, type_as=A) + [Y, X] = nx.meshgrid(t, t) + K1 = nx.exp(-(X - Y) ** 2 / reg) + + t = nx.linspace(0, 1, height, type_as=A) + [Y, X] = nx.meshgrid(t, t) + K2 = nx.exp(-(X - Y) ** 2 / reg) + + def convol_imgs(imgs): + kx = nx.einsum("...ij,kjl->kil", K1, imgs) + kxy = nx.einsum("...ij,klj->kli", K2, kx) + return kxy + + KU = convol_imgs(U) + for ii in range(numItermax): + V = bar[None] / KU + KV = convol_imgs(V) + U = A / KV + KU = convol_imgs(U) + bar = c * nx.exp( + nx.sum(weights[:, None, None] * nx.log(KU + stabThr), axis=0) + ) + + for _ in range(10): + c = (c * bar / nx.squeeze(convol_imgs(c[None]))) ** 0.5 + + if ii % 10 == 9: + err = nx.sum(nx.std(V * KU, axis=0)) + # log and verbose print + if log: + log['err'].append(err) + + if verbose: + if ii % 200 == 0: + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + + # debiased Sinkhorn does not converge monotonically + # guarantee a few iterations are done before stopping + if err < stopThr and ii > 20: + break + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") + if log: + log['niter'] = ii + log['U'] = U + return bar, log + else: + return bar + + +def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10000, + stopThr=1e-3, stabThr=1e-30, verbose=False, + log=False, warn=True): + r"""Compute the debiased barycenter of 2D images in log-domain. + """ + + A = list_to_array(A) + n_hists, width, height = A.shape + nx = get_backend(A) + if nx.__name__ in ("jax", "tf"): + raise NotImplementedError( + "Log-domain functions are not yet implemented" + " for Jax and TF. Use numpy or torch arrays instead." 
+ ) + if weights is None: + weights = nx.ones((n_hists,), type_as=A) / n_hists + else: + assert (len(weights) == A.shape[0]) + + if log: + log = {'err': []} + + err = 1 + # build the convolution operator + # this is equivalent to blurring on horizontal then vertical directions + t = nx.linspace(0, 1, width, type_as=A) + [Y, X] = nx.meshgrid(t, t) + M1 = - (X - Y) ** 2 / reg + + t = nx.linspace(0, 1, height, type_as=A) + [Y, X] = nx.meshgrid(t, t) + M2 = - (X - Y) ** 2 / reg + + def convol_img(log_img): + log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1) + log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T + return log_img + + logA = nx.log(A + stabThr) + log_bar, c = nx.zeros((2, width, height), type_as=A) + log_KU, G, F = nx.zeros((3, *logA.shape), type_as=A) + err = 1 + for ii in range(numItermax): + log_bar = nx.zeros((width, height), type_as=A) + for k in range(n_hists): + f = logA[k] - convol_img(G[k]) + log_KU[k] = convol_img(f) + log_bar = log_bar + weights[k] * log_KU[k] + log_bar += c + for _ in range(10): + c = 0.5 * (c + log_bar - convol_img(c)) + + if ii % 10 == 9: + err = nx.sum(nx.std(nx.exp(G + log_KU), axis=0)) + # log and verbose print + if log: + log['err'].append(err) + + if verbose: + if ii % 200 == 0: + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + if err < stopThr and ii > 20: + break + G = log_bar[None, :, :] - log_KU + + else: + if warn: + warnings.warn("Convolutional Sinkhorn did not converge. " + "Try a larger number of iterations `numItermax` " + "or a larger entropy `reg`.") + if log: + log['niter'] = ii + return nx.exp(log_bar), log + else: + return nx.exp(log_bar) diff --git a/ot/bregman/_dictionary.py b/ot/bregman/_dictionary.py new file mode 100644 index 000000000..80f00f762 --- /dev/null +++ b/ot/bregman/_dictionary.py @@ -0,0 +1,147 @@ +# -*- coding: utf-8 -*- +""" +Dictionary Learning based on Bregman projections for entropic regularized OT +""" + +# Author: Remi Flamary +# Nicolas Courty +# +# License: MIT License + +import warnings + + +from ..utils import list_to_array +from ..backend import get_backend + +from ._utils import projC, projR + + +def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, + stopThr=1e-3, verbose=False, log=False, warn=True): + r""" + Compute the unmixing of an observation with a given dictionary using Wasserstein distance + + The function solve the following optimization problem: + + .. 
math:: + + \mathbf{h} = \mathop{\arg \min}_\mathbf{h} \quad + (1 - \alpha) W_{\mathbf{M}, \mathrm{reg}}(\mathbf{a}, \mathbf{Dh}) + + \alpha W_{\mathbf{M_0}, \mathrm{reg}_0}(\mathbf{h}_0, \mathbf{h}) + + + where : + + - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance + with :math:`\mathbf{M}` loss matrix (see :py:func:`ot.bregman.sinkhorn`) + - :math:`\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`, + its expected shape is `(dim_a, n_atoms)` + - :math:`\mathbf{h}` is the estimated unmixing of dimension `n_atoms` + - :math:`\mathbf{a}` is an observed distribution of dimension `dim_a` + - :math:`\mathbf{h}_0` is a prior on :math:`\mathbf{h}` of dimension `dim_prior` + - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the + cost matrix (`dim_a`, `dim_a`) for OT data fitting + - `reg`:math:`_0` and :math:`\mathbf{M_0}` are respectively the regularization + term and the cost matrix (`dim_prior`, `n_atoms`) regularization + - :math:`\alpha` weight data fitting and regularization + + The optimization problem is solved following the algorithm described + in :ref:`[4] ` + + + Parameters + ---------- + a : array-like, shape (dim_a) + observed distribution (histogram, sums to 1) + D : array-like, shape (dim_a, n_atoms) + dictionary matrix + M : array-like, shape (dim_a, dim_a) + loss matrix + M0 : array-like, shape (n_atoms, dim_prior) + loss matrix + h0 : array-like, shape (n_atoms,) + prior on the estimated unmixing h + reg : float + Regularization term >0 (Wasserstein data fitting) + reg0 : float + Regularization term >0 (Wasserstein reg with h0) + alpha : float + How much should we trust the prior ([0,1]) + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. + + Returns + ------- + h : array-like, shape (n_atoms,) + Wasserstein barycenter + log : dict + log dictionary return only if log==True in parameters + + + .. _references-unmix: + References + ---------- + + .. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti, + Supervised planetary unmixing with optimal transport, Workshop + on Hyperspectral Image and Signal Processing : + Evolution in Remote Sensing (WHISPERS), 2016. + """ + + a, D, M, M0, h0 = list_to_array(a, D, M, M0, h0) + + nx = get_backend(a, D, M, M0, h0) + + # M = M/np.median(M) + K = nx.exp(-M / reg) + + # M0 = M0/np.median(M0) + K0 = nx.exp(-M0 / reg0) + old = h0 + + err = 1 + # log = {'niter':0, 'all_err':[]} + if log: + log = {'err': []} + + for ii in range(numItermax): + K = projC(K, a) + K0 = projC(K0, h0) + new = nx.sum(K0, axis=1) + # we recombine the current selection from dictionnary + inv_new = nx.dot(D, new) + other = nx.sum(K, axis=1) + # geometric interpolation + delta = nx.exp(alpha * nx.log(other) + (1 - alpha) * nx.log(inv_new)) + K = projR(K, delta) + K0 = nx.dot(D.T, delta / inv_new)[:, None] * K0 + err = nx.norm(nx.sum(K0, axis=1) - old) + old = new + if log: + log['err'].append(err) + + if verbose: + if ii % 200 == 0: + print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + if err < stopThr: + break + else: + if warn: + warnings.warn("Unmixing algorithm did not converge. 
You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") + if log: + log['niter'] = ii + return nx.sum(K0, axis=1), log + else: + return nx.sum(K0, axis=1) diff --git a/ot/bregman/_empirical.py b/ot/bregman/_empirical.py new file mode 100644 index 000000000..55c5ada3f --- /dev/null +++ b/ot/bregman/_empirical.py @@ -0,0 +1,548 @@ +# -*- coding: utf-8 -*- +""" +Bregman projections solvers for entropic regularized OT for empirical distributions +""" + +# Author: Remi Flamary +# Kilian Fatras +# Quang Huy Tran +# +# License: MIT License + +import warnings + +from ..utils import dist, list_to_array, unif +from ..backend import get_backend + +from ._sinkhorn import sinkhorn, sinkhorn2 + + +def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', + numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False, + log=False, warn=True, warmstart=None, **kwargs): + r''' + Solve the entropic regularization optimal transport problem and return the + OT matrix from empirical data + + The function solves the following optimization problem: + + .. math:: + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot\Omega(\gamma) + + s.t. \ \gamma \mathbf{1} &= \mathbf{a} + + \gamma^T \mathbf{1} &= \mathbf{b} + + \gamma &\geq 0 + + where : + + - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) + + + Parameters + ---------- + X_s : array-like, shape (n_samples_a, dim) + samples in the source domain + X_t : array-like, shape (n_samples_b, dim) + samples in the target domain + reg : float + Regularization term >0 + a : array-like, shape (n_samples_a,) + samples weights in the source domain + b : array-like, shape (n_samples_b,) + samples weights in the target domain + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + isLazy: boolean, optional + If True, then only calculate the cost matrix by block and return + the dual potentials only (to save memory). If False, calculate full + cost matrix and return outputs of sinkhorn function. + batchSize: int or tuple of 2 int, optional + Size of the batches used to compute the sinkhorn update without memory overhead. + When a tuple is provided it sets the size of the left/right batches. + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. 
If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) + + + Returns + ------- + gamma : array-like, shape (n_samples_a, n_samples_b) + Regularized optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> import numpy as np + >>> n_samples_a = 2 + >>> n_samples_b = 2 + >>> reg = 0.1 + >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1)) + >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1)) + >>> empirical_sinkhorn(X_s, X_t, reg=reg, verbose=False) # doctest: +NORMALIZE_WHITESPACE + array([[4.99977301e-01, 2.26989344e-05], + [2.26989344e-05, 4.99977301e-01]]) + + + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal + Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for + Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + ''' + + X_s, X_t = list_to_array(X_s, X_t) + + nx = get_backend(X_s, X_t) + + ns, nt = X_s.shape[0], X_t.shape[0] + if a is None: + a = nx.from_numpy(unif(ns), type_as=X_s) + if b is None: + b = nx.from_numpy(unif(nt), type_as=X_s) + + if isLazy: + if log: + dict_log = {"err": []} + + log_a, log_b = nx.log(a), nx.log(b) + if warmstart is None: + f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a) + else: + f, g = warmstart + + if isinstance(batchSize, int): + bs, bt = batchSize, batchSize + elif isinstance(batchSize, tuple) and len(batchSize) == 2: + bs, bt = batchSize[0], batchSize[1] + else: + raise ValueError( + "Batch size must be in integer or a tuple of two integers") + + range_s, range_t = range(0, ns, bs), range(0, nt, bt) + + lse_f = nx.zeros((ns,), type_as=a) + lse_g = nx.zeros((nt,), type_as=a) + + X_s_np = nx.to_numpy(X_s) + X_t_np = nx.to_numpy(X_t) + + for i_ot in range(numIterMax): + + lse_f_cols = [] + for i in range_s: + M = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric) + M = nx.from_numpy(M, type_as=a) + lse_f_cols.append( + nx.logsumexp(g[None, :] - M / reg, axis=1) + ) + lse_f = nx.concatenate(lse_f_cols, axis=0) + f = log_a - lse_f + + lse_g_cols = [] + for j in range_t: + M = dist(X_s_np, X_t_np[j:j + bt, :], metric=metric) + M = nx.from_numpy(M, type_as=a) + lse_g_cols.append( + nx.logsumexp(f[:, None] - M / reg, axis=0) + ) + lse_g = nx.concatenate(lse_g_cols, axis=0) + g = log_b - lse_g + + if (i_ot + 1) % 10 == 0: + m1_cols = [] + for i in range_s: + M = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric) + M = nx.from_numpy(M, type_as=a) + m1_cols.append( + nx.sum(nx.exp(f[i:i + bs, None] + + g[None, :] - M / reg), axis=1) + ) + m1 = nx.concatenate(m1_cols, axis=0) + err = nx.sum(nx.abs(m1 - a)) + if log: + dict_log["err"].append(err) + + if verbose and (i_ot + 1) % 100 == 0: + print("Error in marginal at iteration {} = {}".format( + i_ot + 1, err)) + + if err <= stopThr: + break + else: + if warn: + warnings.warn("Sinkhorn did not converge. 
You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") + if log: + dict_log["u"] = f + dict_log["v"] = g + return (f, g, dict_log) + else: + return (f, g) + + else: + M = dist(X_s, X_t, metric=metric) + if log: + pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, + verbose=verbose, log=True, warmstart=warmstart, **kwargs) + return pi, log + else: + pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, + verbose=verbose, log=False, warmstart=warmstart, **kwargs) + return pi + + +def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', + numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, + verbose=False, log=False, warn=True, warmstart=None, **kwargs): + r''' + Solve the entropic regularization optimal transport problem from empirical + data and return the OT loss + + + The function solves the following optimization problem: + + .. math:: + W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot\Omega(\gamma) + + s.t. \ \gamma \mathbf{1} &= \mathbf{a} + + \gamma^T \mathbf{1} &= \mathbf{b} + + \gamma &\geq 0 + + where : + + - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) + + and returns :math:`\langle \gamma^*, \mathbf{M} \rangle_F` (without + the entropic contribution). + + + Parameters + ---------- + X_s : array-like, shape (n_samples_a, dim) + samples in the source domain + X_t : array-like, shape (n_samples_b, dim) + samples in the target domain + reg : float + Regularization term >0 + a : array-like, shape (n_samples_a,) + samples weights in the source domain + b : array-like, shape (n_samples_b,) + samples weights in the target domain + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + isLazy: boolean, optional + If True, then only calculate the cost matrix by block and return + the dual potentials only (to save memory). If False, calculate + full cost matrix and return outputs of sinkhorn function. + batchSize: int or tuple of 2 int, optional + Size of the batches used to compute the sinkhorn update without memory overhead. + When a tuple is provided it sets the size of the left/right batches. + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. 
If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) + + Returns + ------- + W : (n_hists) array-like or float + Optimal transportation loss for the given parameters + log : dict + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> import numpy as np + >>> n_samples_a = 2 + >>> n_samples_b = 2 + >>> reg = 0.1 + >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1)) + >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1)) + >>> b = np.full((n_samples_b, 3), 1/n_samples_b) + >>> empirical_sinkhorn2(X_s, X_t, b=b, reg=reg, verbose=False) + array([4.53978687e-05, 4.53978687e-05, 4.53978687e-05]) + + + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation + of Optimal Transport, Advances in Neural Information + Processing Systems (NIPS) 26, 2013 + + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling + Algorithms for Entropy Regularized Transport Problems. + arXiv preprint arXiv:1610.06519. + + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. + arXiv preprint arXiv:1607.05816. + ''' + + X_s, X_t = list_to_array(X_s, X_t) + + nx = get_backend(X_s, X_t) + + ns, nt = X_s.shape[0], X_t.shape[0] + if a is None: + a = nx.from_numpy(unif(ns), type_as=X_s) + if b is None: + b = nx.from_numpy(unif(nt), type_as=X_s) + + if isLazy: + if log: + f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, + numIterMax=numIterMax, + stopThr=stopThr, + isLazy=isLazy, + batchSize=batchSize, + verbose=verbose, log=log, + warn=warn, + warmstart=warmstart) + else: + f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, + numIterMax=numIterMax, + stopThr=stopThr, + isLazy=isLazy, batchSize=batchSize, + verbose=verbose, log=log, + warn=warn, + warmstart=warmstart) + + bs = batchSize if isinstance(batchSize, int) else batchSize[0] + range_s = range(0, ns, bs) + + loss = 0 + + X_s_np = nx.to_numpy(X_s) + X_t_np = nx.to_numpy(X_t) + + for i in range_s: + M_block = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric) + M_block = nx.from_numpy(M_block, type_as=a) + pi_block = nx.exp(f[i:i + bs, None] + g[None, :] - M_block / reg) + loss += nx.sum(M_block * pi_block) + + if log: + return loss, dict_log + else: + return loss + + else: + M = dist(X_s, X_t, metric=metric) + + if log: + sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, + stopThr=stopThr, verbose=verbose, log=log, + warn=warn, warmstart=warmstart, **kwargs) + return sinkhorn_loss, log + else: + sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, + stopThr=stopThr, verbose=verbose, log=log, + warn=warn, warmstart=warmstart, **kwargs) + return sinkhorn_loss + + +def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', + numIterMax=10000, stopThr=1e-9, verbose=False, + log=False, warn=True, warmstart=None, **kwargs): + r''' + Compute the sinkhorn divergence loss from empirical data + + The function solves the following optimization problems and return the + sinkhorn divergence :math:`S`: + + .. 
math:: + W &= \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg} \cdot\Omega(\gamma) + + W_a &= \min_{\gamma_a} \quad \langle \gamma_a, \mathbf{M_a} \rangle_F + + \mathrm{reg} \cdot\Omega(\gamma_a) + + W_b &= \min_{\gamma_b} \quad \langle \gamma_b, \mathbf{M_b} \rangle_F + + \mathrm{reg} \cdot\Omega(\gamma_b) + + S &= W - \frac{W_a + W_b}{2} + + .. math:: + s.t. \ \gamma \mathbf{1} &= \mathbf{a} + + \gamma^T \mathbf{1} &= \mathbf{b} + + \gamma &\geq 0 + + \gamma_a \mathbf{1} &= \mathbf{a} + + \gamma_a^T \mathbf{1} &= \mathbf{a} + + \gamma_a &\geq 0 + + \gamma_b \mathbf{1} &= \mathbf{b} + + \gamma_b^T \mathbf{1} &= \mathbf{b} + + \gamma_b &\geq 0 + + where : + + - :math:`\mathbf{M}` (resp. :math:`\mathbf{M_a}`, :math:`\mathbf{M_b}`) + is the (`n_samples_a`, `n_samples_b`) metric cost matrix + (resp (`n_samples_a, n_samples_a`) and (`n_samples_b`, `n_samples_b`)) + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1) + + and returns :math:`\langle \gamma^*, \mathbf{M} \rangle_F -(\langle \gamma^*_a, \mathbf{M_a} \rangle_F + \langle + \gamma^*_b , \mathbf{M_b} \rangle_F)/2`. + + .. note: The current implementation does not account for the entropic contributions and thus differs from the + Sinkhorn divergence as introduced in the literature. The possibility to account for the entropic contributions + will be provided in a future release. + + + Parameters + ---------- + X_s : array-like, shape (n_samples_a, dim) + samples in the source domain + X_t : array-like, shape (n_samples_b, dim) + samples in the target domain + reg : float + Regularization term >0 + a : array-like, shape (n_samples_a,) + samples weights in the source domain + b : array-like, shape (n_samples_b,) + samples weights in the target domain + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) + + Returns + ------- + W : (1,) array-like + Optimal transportation symmetrized loss for the given parameters + log : dict + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> import numpy as np + >>> n_samples_a = 2 + >>> n_samples_b = 4 + >>> reg = 0.1 + >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1)) + >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1)) + >>> empirical_sinkhorn_divergence(X_s, X_t, reg) # doctest: +ELLIPSIS + 1.499887176049052 + + + References + ---------- + .. 
[23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative + Models with Sinkhorn Divergences, Proceedings of the Twenty-First + International Conference on Artificial Intelligence and Statistics, + (AISTATS) 21, 2018 + ''' + X_s, X_t = list_to_array(X_s, X_t) + + nx = get_backend(X_s, X_t) + if warmstart is None: + warmstart_a, warmstart_b = None, None + else: + u, v = warmstart + warmstart_a = (u, u) + warmstart_b = (v, v) + + if log: + sinkhorn_loss_ab, log_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, + numIterMax=numIterMax, stopThr=stopThr, + verbose=verbose, log=log, warn=warn, + warmstart=warmstart, **kwargs) + + sinkhorn_loss_a, log_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, + numIterMax=numIterMax, stopThr=stopThr, + verbose=verbose, log=log, warn=warn, + warmstart=warmstart_a, **kwargs) + + sinkhorn_loss_b, log_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, + numIterMax=numIterMax, stopThr=stopThr, + verbose=verbose, log=log, warn=warn, + warmstart=warmstart_b, **kwargs) + + sinkhorn_div = sinkhorn_loss_ab - 0.5 * \ + (sinkhorn_loss_a + sinkhorn_loss_b) + + log = {} + log['sinkhorn_loss_ab'] = sinkhorn_loss_ab + log['sinkhorn_loss_a'] = sinkhorn_loss_a + log['sinkhorn_loss_b'] = sinkhorn_loss_b + log['log_sinkhorn_ab'] = log_ab + log['log_sinkhorn_a'] = log_a + log['log_sinkhorn_b'] = log_b + + return nx.maximum(0, sinkhorn_div), log + + else: + sinkhorn_loss_ab = empirical_sinkhorn2(X_s, X_t, reg, a, b, metric=metric, + numIterMax=numIterMax, stopThr=stopThr, + verbose=verbose, log=log, warn=warn, + warmstart=warmstart, **kwargs) + + sinkhorn_loss_a = empirical_sinkhorn2(X_s, X_s, reg, a, a, metric=metric, + numIterMax=numIterMax, stopThr=stopThr, + verbose=verbose, log=log, warn=warn, + warmstart=warmstart_a, **kwargs) + + sinkhorn_loss_b = empirical_sinkhorn2(X_t, X_t, reg, b, b, metric=metric, + numIterMax=numIterMax, stopThr=stopThr, + verbose=verbose, log=log, warn=warn, + warmstart=warmstart_b, **kwargs) + + sinkhorn_div = sinkhorn_loss_ab - 0.5 * \ + (sinkhorn_loss_a + sinkhorn_loss_b) + return nx.maximum(0, sinkhorn_div) diff --git a/ot/bregman/_screenkhorn.py b/ot/bregman/_screenkhorn.py new file mode 100644 index 000000000..8c53f73ae --- /dev/null +++ b/ot/bregman/_screenkhorn.py @@ -0,0 +1,389 @@ +# -*- coding: utf-8 -*- +""" +Screening Sinkhorn Algorithms for Regularized Optimal Transport +""" + +# Author: Remi Flamary +# Mokhtar Z. Alaya +# +# License: MIT License + +import warnings + +import numpy as np +from scipy.optimize import fmin_l_bfgs_b + +from ..utils import list_to_array +from ..backend import get_backend + + +def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False, + restricted=True, maxiter=10000, maxfun=10000, pgtol=1e-09, + verbose=False, log=False): + r""" + Screening Sinkhorn Algorithm for Regularized Optimal Transport + + The function solves an approximate dual of Sinkhorn divergence :ref:`[2] + ` which is written as the following optimization problem: + + .. math:: + + (\mathbf{u}, \mathbf{v}) = \mathop{\arg \min}_{\mathbf{u}, \mathbf{v}} \quad + \mathbf{1}_{ns}^T \mathbf{B}(\mathbf{u}, \mathbf{v}) \mathbf{1}_{nt} - + \langle \kappa \mathbf{u}, \mathbf{a} \rangle - + \langle \frac{1}{\kappa} \mathbf{v}, \mathbf{b} \rangle + + where: + + .. math:: + + \mathbf{B}(\mathbf{u}, \mathbf{v}) = \mathrm{diag}(e^\mathbf{u}) \mathbf{K} \mathrm{diag}(e^\mathbf{v}) \text{, with } \mathbf{K} = e^{-\mathbf{M} / \mathrm{reg}} \text{ and} + + .. math:: + + s.t. 
\ e^{u_i} &\geq \epsilon / \kappa, \forall i \in \{1, \ldots, ns\} + + e^{v_j} &\geq \epsilon \kappa, \forall j \in \{1, \ldots, nt\} + + The parameters `kappa` and `epsilon` are determined w.r.t the couple number + budget of points (`ns_budget`, `nt_budget`), see Equation (5) + in :ref:`[26] ` + + + Parameters + ---------- + a: array-like, shape=(ns,) + samples weights in the source domain + b: array-like, shape=(nt,) + samples weights in the target domain + M: array-like, shape=(ns, nt) + Cost matrix + reg: `float` + Level of the entropy regularisation + ns_budget: `int`, default=None + Number budget of points to be kept in the source domain. + If it is None then 50% of the source sample points will be kept + nt_budget: `int`, default=None + Number budget of points to be kept in the target domain. + If it is None then 50% of the target sample points will be kept + uniform: `bool`, default=False + If `True`, the source and target distribution are supposed to be uniform, + i.e., :math:`a_i = 1 / ns` and :math:`b_j = 1 / nt` + restricted : `bool`, default=True + If `True`, a warm-start initialization for the L-BFGS-B solver + using a restricted Sinkhorn algorithm with at most 5 iterations + maxiter: `int`, default=10000 + Maximum number of iterations in LBFGS solver + maxfun: `int`, default=10000 + Maximum number of function evaluations in LBFGS solver + pgtol: `float`, default=1e-09 + Final objective function accuracy in LBFGS solver + verbose: `bool`, default=False + If `True`, display informations about the cardinals of the active sets + and the parameters kappa and epsilon + + + .. admonition:: Dependency + + To gain more efficiency, :py:func:`ot.bregman.screenkhorn` needs to call the "Bottleneck" + package (https://pypi.org/project/Bottleneck/) in the screening pre-processing step. + + If Bottleneck isn't installed, the following error message appears: + + "Bottleneck module doesn't exist. Install it from https://pypi.org/project/Bottleneck/" + + + Returns + ------- + gamma : array-like, shape=(ns, nt) + Screened optimal transportation matrix for the given parameters + + log : `dict`, default=False + Log dictionary return only if log==True in parameters + + + .. _references-screenkhorn: + References + ----------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, + Advances in Neural Information Processing Systems (NIPS) 26, 2013 + + .. [26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). + Screening Sinkhorn Algorithm for Regularized Optimal Transport (NIPS) 33, 2019 + + """ + # check if bottleneck module exists + try: + import bottleneck + except ImportError: + warnings.warn( + "Bottleneck module is not installed. 
Install it from" + " https://pypi.org/project/Bottleneck/ for better performance.") + bottleneck = np + + a, b, M = list_to_array(a, b, M) + + nx = get_backend(M, a, b) + if nx.__name__ in ("jax", "tf"): + raise TypeError("JAX or TF arrays have been received but screenkhorn is not " + "compatible with neither JAX nor TF.") + + ns, nt = M.shape + + # by default, we keep only 50% of the sample data points + if ns_budget is None: + ns_budget = int(np.floor(0.5 * ns)) + if nt_budget is None: + nt_budget = int(np.floor(0.5 * nt)) + + # calculate the Gibbs kernel + K = nx.exp(-M / reg) + + def projection(u, epsilon): + u = nx.maximum(u, epsilon) + return u + + # ----------------------------------------------------------------------------------------------------------------# + # Step 1: Screening pre-processing # + # ----------------------------------------------------------------------------------------------------------------# + + if ns_budget == ns and nt_budget == nt: + # full number of budget points (ns, nt) = (ns_budget, nt_budget) + Isel = nx.from_numpy(np.ones(ns, dtype=bool)) + Jsel = nx.from_numpy(np.ones(nt, dtype=bool)) + epsilon = 0.0 + kappa = 1.0 + + cst_u = 0. + cst_v = 0. + + bounds_u = [(0.0, np.inf)] * ns + bounds_v = [(0.0, np.inf)] * nt + + a_I = a + b_J = b + K_IJ = K + K_IJc = [] + K_IcJ = [] + + vec_eps_IJc = nx.zeros((nt,), type_as=M) + vec_eps_IcJ = nx.zeros((ns,), type_as=M) + + else: + # sum of rows and columns of K + K_sum_cols = nx.sum(K, axis=1) + K_sum_rows = nx.sum(K, axis=0) + + if uniform: + if ns / ns_budget < 4: + aK_sort = nx.sort(K_sum_cols) + epsilon_u_square = a[0] / aK_sort[ns_budget - 1] + else: + aK_sort = nx.from_numpy( + bottleneck.partition(nx.to_numpy( + K_sum_cols), ns_budget - 1)[ns_budget - 1], + type_as=M + ) + epsilon_u_square = a[0] / aK_sort + + if nt / nt_budget < 4: + bK_sort = nx.sort(K_sum_rows) + epsilon_v_square = b[0] / bK_sort[nt_budget - 1] + else: + bK_sort = nx.from_numpy( + bottleneck.partition(nx.to_numpy( + K_sum_rows), nt_budget - 1)[nt_budget - 1], + type_as=M + ) + epsilon_v_square = b[0] / bK_sort + else: + aK = a / K_sum_cols + bK = b / K_sum_rows + + aK_sort = nx.flip(nx.sort(aK), axis=0) + epsilon_u_square = aK_sort[ns_budget - 1] + + bK_sort = nx.flip(nx.sort(bK), axis=0) + epsilon_v_square = bK_sort[nt_budget - 1] + + # active sets I and J (see Lemma 1 in [26]) + Isel = a >= epsilon_u_square * K_sum_cols + Jsel = b >= epsilon_v_square * K_sum_rows + + if nx.sum(Isel) != ns_budget: + if uniform: + aK = a / K_sum_cols + aK_sort = nx.flip(nx.sort(aK), axis=0) + epsilon_u_square = nx.mean(aK_sort[ns_budget - 1:ns_budget + 1]) + Isel = a >= epsilon_u_square * K_sum_cols + ns_budget = nx.sum(Isel) + + if nx.sum(Jsel) != nt_budget: + if uniform: + bK = b / K_sum_rows + bK_sort = nx.flip(nx.sort(bK), axis=0) + epsilon_v_square = nx.mean(bK_sort[nt_budget - 1:nt_budget + 1]) + Jsel = b >= epsilon_v_square * K_sum_rows + nt_budget = nx.sum(Jsel) + + epsilon = (epsilon_u_square * epsilon_v_square) ** (1 / 4) + kappa = (epsilon_v_square / epsilon_u_square) ** (1 / 2) + + if verbose: + print("epsilon = %s\n" % epsilon) + print("kappa = %s\n" % kappa) + print('Cardinality of selected points: |Isel| = %s \t |Jsel| = %s \n' + % (sum(Isel), sum(Jsel))) + + # Ic, Jc: complementary of the active sets I and J + Ic = ~Isel + Jc = ~Jsel + + K_IJ = K[np.ix_(Isel, Jsel)] + K_IcJ = K[np.ix_(Ic, Jsel)] + K_IJc = K[np.ix_(Isel, Jc)] + + K_min = nx.min(K_IJ) + if K_min == 0: + K_min = float(np.finfo(float).tiny) + + # a_I, b_J, a_Ic, b_Jc + a_I = 
a[Isel] + b_J = b[Jsel] + if not uniform: + a_I_min = nx.min(a_I) + a_I_max = nx.max(a_I) + b_J_max = nx.max(b_J) + b_J_min = nx.min(b_J) + else: + a_I_min = a_I[0] + a_I_max = a_I[0] + b_J_max = b_J[0] + b_J_min = b_J[0] + + # box constraints in L-BFGS-B (see Proposition 1 in [26]) + bounds_u = [(max(a_I_min / ((nt - nt_budget) * epsilon + nt_budget * (b_J_max / ( + ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget + + bounds_v = [( + max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))), + epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget + + # pre-calculated constants for the objective + vec_eps_IJc = epsilon * kappa * nx.sum( + K_IJc * nx.ones((nt - nt_budget,), type_as=M)[None, :], + axis=1 + ) + vec_eps_IcJ = (epsilon / kappa) * nx.sum( + nx.ones((ns - ns_budget,), type_as=M)[:, None] * K_IcJ, + axis=0 + ) + + # initialisation + u0 = nx.full((ns_budget,), 1. / ns_budget + epsilon / kappa, type_as=M) + v0 = nx.full((nt_budget,), 1. / nt_budget + epsilon * kappa, type_as=M) + + # pre-calculed constants for Restricted Sinkhorn (see Algorithm 1 in supplementary of [26]) + if restricted: + if ns_budget != ns or nt_budget != nt: + cst_u = kappa * epsilon * nx.sum(K_IJc, axis=1) + cst_v = epsilon * nx.sum(K_IcJ, axis=0) / kappa + + for _ in range(5): # 5 iterations + K_IJ_v = nx.dot(K_IJ.T, u0) + cst_v + v0 = b_J / (kappa * K_IJ_v) + KIJ_u = nx.dot(K_IJ, v0) + cst_u + u0 = (kappa * a_I) / KIJ_u + + u0 = projection(u0, epsilon / kappa) + v0 = projection(v0, epsilon * kappa) + + else: + u0 = u0 + v0 = v0 + + def restricted_sinkhorn(usc, vsc, max_iter=5): + """ + Restricted Sinkhorn Algorithm as a warm-start initialized pointfor L-BFGS-B) + """ + for _ in range(max_iter): + K_IJ_v = nx.dot(K_IJ.T, usc) + cst_v + vsc = b_J / (kappa * K_IJ_v) + KIJ_u = nx.dot(K_IJ, vsc) + cst_u + usc = (kappa * a_I) / KIJ_u + + usc = projection(usc, epsilon / kappa) + vsc = projection(vsc, epsilon * kappa) + + return usc, vsc + + def screened_obj(usc, vsc): + part_IJ = ( + nx.dot(nx.dot(usc, K_IJ), vsc) + - kappa * nx.dot(a_I, nx.log(usc)) + - (1. / kappa) * nx.dot(b_J, nx.log(vsc)) + ) + part_IJc = nx.dot(usc, vec_eps_IJc) + part_IcJ = nx.dot(vec_eps_IcJ, vsc) + psi_epsilon = part_IJ + part_IJc + part_IcJ + return psi_epsilon + + def screened_grad(usc, vsc): + # gradients of Psi_(kappa,epsilon) w.r.t u and v + grad_u = nx.dot(K_IJ, vsc) + vec_eps_IJc - kappa * a_I / usc + grad_v = nx.dot(K_IJ.T, usc) + vec_eps_IcJ - (1. 
/ kappa) * b_J / vsc + return grad_u, grad_v + + def bfgspost(theta): + u = theta[:ns_budget] + v = theta[ns_budget:] + # objective + f = screened_obj(u, v) + # gradient + g_u, g_v = screened_grad(u, v) + g = nx.concatenate([g_u, g_v], axis=0) + return nx.to_numpy(f), nx.to_numpy(g) + + # ----------------------------------------------------------------------------------------------------------------# + # Step 2: L-BFGS-B solver # + # ----------------------------------------------------------------------------------------------------------------# + + u0, v0 = restricted_sinkhorn(u0, v0) + theta0 = nx.concatenate([u0, v0], axis=0) + + bounds = bounds_u + bounds_v # constraint bounds + + def obj(theta): + return bfgspost(nx.from_numpy(theta, type_as=M)) + + theta, _, _ = fmin_l_bfgs_b(func=obj, + x0=theta0, + bounds=bounds, + maxfun=maxfun, + pgtol=pgtol, + maxiter=maxiter) + theta = nx.from_numpy(theta, type_as=M) + + usc = theta[:ns_budget] + vsc = theta[ns_budget:] + + usc_full = nx.full((ns,), epsilon / kappa, type_as=M) + vsc_full = nx.full((nt,), epsilon * kappa, type_as=M) + usc_full[Isel] = usc + vsc_full[Jsel] = vsc + + if log: + log = {} + log['u'] = usc_full + log['v'] = vsc_full + log['Isel'] = Isel + log['Jsel'] = Jsel + + gamma = usc_full[:, None] * K * vsc_full[None, :] + gamma = gamma / nx.sum(gamma) + + if log: + return gamma, log + else: + return gamma diff --git a/ot/bregman/_sinkhorn.py b/ot/bregman/_sinkhorn.py new file mode 100644 index 000000000..beb101fe8 --- /dev/null +++ b/ot/bregman/_sinkhorn.py @@ -0,0 +1,1350 @@ +# -*- coding: utf-8 -*- +""" +Bregman projections solvers for entropic regularized OT +""" + +# Author: Remi Flamary +# Nicolas Courty +# Titouan Vayer +# Alexander Tong +# Quang Huy Tran +# +# License: MIT License + +import warnings + +import numpy as np + +from ..utils import list_to_array +from ..backend import get_backend + + +def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, + verbose=False, log=False, warn=True, warmstart=None, **kwargs): + r""" + Solve the entropic regularization optimal transport problem and return the OT matrix + + The function solves the following optimization problem: + + .. math:: + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg}\cdot\Omega(\gamma) + + s.t. \ \gamma \mathbf{1} &= \mathbf{a} + + \gamma^T \mathbf{1} &= \mathbf{b} + + \gamma &\geq 0 + + where : + + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + weights (histograms, both sum to 1) + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix + scaling algorithm as proposed in :ref:`[2] ` + + **Choosing a Sinkhorn solver** + + By default and when using a regularization parameter that is not too small + the default sinkhorn solver should be enough. If you need to use a small + regularization to get sharper OT matrices, you should use the + :py:func:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical + errors. This last solver can be very slow in practice and might not even + converge to a reasonable OT matrix in a finite time. 
This is why + :py:func:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value + of the regularization (and using warm start) sometimes leads to better + solutions. Note that the greedy version of the sinkhorn + :py:func:`ot.bregman.greenkhorn` can also lead to a speedup and the screening + version of the sinkhorn :py:func:`ot.bregman.screenkhorn` aim at providing a + fast approximation of the Sinkhorn problem. For use of GPU and gradient + computation with small number of iterations we strongly recommend the + :py:func:`ot.bregman.sinkhorn_log` solver that will no need to check for + numerical problems. + + + Parameters + ---------- + a : array-like, shape (dim_a,) + samples weights in the source domain + b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists) + samples in the target domain, compute sinkhorn with multiple targets + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix + (return OT loss + dual variables in log) + M : array-like, shape (dim_a, dim_b) + loss matrix + reg : float + Regularization term >0 + method : str + method used for the solver either 'sinkhorn','sinkhorn_log', + 'greenkhorn', 'sinkhorn_stabilized' or 'sinkhorn_epsilon_scaling', see + those function for specific parameters + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) + + Returns + ------- + gamma : array-like, shape (dim_a, dim_b) + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[0., 1.], [1., 0.]] + >>> ot.sinkhorn(a, b, M, 1) + array([[0.36552929, 0.13447071], + [0.13447071, 0.36552929]]) + + .. _references-sinkhorn: + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation + of Optimal Transport, Advances in Neural Information Processing + Systems (NIPS) 26, 2013 + + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms + for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. + arXiv preprint arXiv:1607.05816. + + .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, + A., & Peyré, G. (2019, April). Interpolating between optimal transport + and MMD using Sinkhorn divergences. In The 22nd International Conference + on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. 
+ + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] ` + ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn + :ref:`[9] ` :ref:`[10] ` + ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epsilon scaling + :ref:`[9] ` :ref:`[10] ` + + """ + + if method.lower() == 'sinkhorn': + return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + warn=warn, warmstart=warmstart, + **kwargs) + elif method.lower() == 'sinkhorn_log': + return sinkhorn_log(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + warn=warn, warmstart=warmstart, + **kwargs) + elif method.lower() == 'greenkhorn': + return greenkhorn(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, log=log, + warn=warn, warmstart=warmstart) + elif method.lower() == 'sinkhorn_stabilized': + return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, warmstart=warmstart, + verbose=verbose, log=log, warn=warn, + **kwargs) + elif method.lower() == 'sinkhorn_epsilon_scaling': + return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, warmstart=warmstart, + verbose=verbose, log=log, warn=warn, + **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) + + +def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, + stopThr=1e-9, verbose=False, log=False, warn=False, warmstart=None, **kwargs): + r""" + Solve the entropic regularization optimal transport problem and return the loss + + The function solves the following optimization problem: + + .. math:: + W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg}\cdot\Omega(\gamma) + + s.t. \ \gamma \mathbf{1} &= \mathbf{a} + + \gamma^T \mathbf{1} &= \mathbf{b} + + \gamma &\geq 0 + + where : + + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + weights (histograms, both sum to 1) + + and returns :math:`\langle \gamma^*, \mathbf{M} \rangle_F` (without + the entropic contribution). + + .. note:: This function is backend-compatible and will work on arrays + from all compatible backends. + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix + scaling algorithm as proposed in :ref:`[2] ` + + + **Choosing a Sinkhorn solver** + + By default and when using a regularization parameter that is not too small + the default sinkhorn solver should be enough. If you need to use a small + regularization to get sharper OT matrices, you should use the + :py:func:`ot.bregman.sinkhorn_log` solver that will avoid numerical + errors. This last solver can be very slow in practice and might not even + converge to a reasonable OT matrix in a finite time. This is why + :py:func:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value + of the regularization (and using warm start) sometimes leads to better + solutions. Note that the greedy version of the sinkhorn + :py:func:`ot.bregman.greenkhorn` can also lead to a speedup and the screening + version of the sinkhorn :py:func:`ot.bregman.screenkhorn` aim a providing a + fast approximation of the Sinkhorn problem. 
For use of GPU and gradient + computation with small number of iterations we strongly recommend the + :py:func:`ot.bregman.sinkhorn_log` solver that will no need to check for + numerical problems. + + Parameters + ---------- + a : array-like, shape (dim_a,) + samples weights in the source domain + b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists) + samples in the target domain, compute sinkhorn with multiple targets + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix + (return OT loss + dual variables in log) + M : array-like, shape (dim_a, dim_b) + loss matrix + reg : float + Regularization term >0 + method : str + method used for the solver either 'sinkhorn','sinkhorn_log', + 'sinkhorn_stabilized', see those function for specific parameters + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) + + Returns + ------- + W : (n_hists) float/array-like + Optimal transportation loss for the given parameters + log : dict + log dictionary return only if log==True in parameters + + + Examples + -------- + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[0., 1.], [1., 0.]] + >>> ot.sinkhorn2(a, b, M, 1) + 0.26894142136999516 + + + .. _references-sinkhorn2: + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of + Optimal Transport, Advances in Neural Information + Processing Systems (NIPS) 26, 2013 + + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms + for Entropy Regularized Transport Problems. + arXiv preprint arXiv:1610.06519. + + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. + arXiv preprint arXiv:1607.05816. + + .. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation + algorithms for optimal transport via Sinkhorn iteration, + Advances in Neural Information Processing Systems (NIPS) 31, 2017 + + .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., + Trouvé, A., & Peyré, G. (2019, April). + Interpolating between optimal transport and MMD using Sinkhorn + divergences. In The 22nd International Conference on Artificial + Intelligence and Statistics (pp. 2681-2690). PMLR. 
+ + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] ` + ot.bregman.greenkhorn : Greenkhorn :ref:`[21] ` + ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn + :ref:`[9] ` :ref:`[10] ` + """ + + M, a, b = list_to_array(M, a, b) + nx = get_backend(M, a, b) + + if len(b.shape) < 2: + if method.lower() == 'sinkhorn': + res = sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, warmstart=warmstart, + **kwargs) + elif method.lower() == 'sinkhorn_log': + res = sinkhorn_log(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, warmstart=warmstart, + **kwargs) + elif method.lower() == 'sinkhorn_stabilized': + res = sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, warmstart=warmstart, + verbose=verbose, log=log, warn=warn, + **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) + if log: + return nx.sum(M * res[0]), res[1] + else: + return nx.sum(M * res) + + else: + + if method.lower() == 'sinkhorn': + return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, warmstart=warmstart, + **kwargs) + elif method.lower() == 'sinkhorn_log': + return sinkhorn_log(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, verbose=verbose, + log=log, warn=warn, warmstart=warmstart, + **kwargs) + elif method.lower() == 'sinkhorn_stabilized': + return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax, + stopThr=stopThr, warmstart=warmstart, + verbose=verbose, log=log, warn=warn, + **kwargs) + else: + raise ValueError("Unknown method '%s'." % method) + + +def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, + verbose=False, log=False, warn=True, warmstart=None, **kwargs): + r""" + Solve the entropic regularization optimal transport problem and return the OT matrix + + The function solves the following optimization problem: + + .. math:: + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg}\cdot\Omega(\gamma) + + s.t. \ \gamma \mathbf{1} &= \mathbf{a} + + \gamma^T \mathbf{1} &= \mathbf{b} + + \gamma &\geq 0 + + where : + + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + weights (histograms, both sum to 1) + + The algorithm used for solving the problem is the Sinkhorn-Knopp + matrix scaling algorithm as proposed in :ref:`[2] ` + + + Parameters + ---------- + a : array-like, shape (dim_a,) + samples weights in the source domain + b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists) + samples in the target domain, compute sinkhorn with multiple targets + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix + (return OT loss + dual variables in log) + M : array-like, shape (dim_a, dim_b) + loss matrix + reg : float + Regularization term >0 + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. 
+ warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) + + Returns + ------- + gamma : array-like, shape (dim_a, dim_b) + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[0., 1.], [1., 0.]] + >>> ot.sinkhorn(a, b, M, 1) + array([[0.36552929, 0.13447071], + [0.13447071, 0.36552929]]) + + + .. _references-sinkhorn-knopp: + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation + of Optimal Transport, Advances in Neural Information + Processing Systems (NIPS) 26, 2013 + + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + + """ + + a, b, M = list_to_array(a, b, M) + + nx = get_backend(M, a, b) + + if len(a) == 0: + a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M) + if len(b) == 0: + b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M) + + # init data + dim_a = len(a) + dim_b = b.shape[0] + + if len(b.shape) > 1: + n_hists = b.shape[1] + else: + n_hists = 0 + + if log: + log = {'err': []} + + # we assume that no distances are null except those of the diagonal of + # distances + if warmstart is None: + if n_hists: + u = nx.ones((dim_a, n_hists), type_as=M) / dim_a + v = nx.ones((dim_b, n_hists), type_as=M) / dim_b + else: + u = nx.ones(dim_a, type_as=M) / dim_a + v = nx.ones(dim_b, type_as=M) / dim_b + else: + u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) + + K = nx.exp(M / (-reg)) + + Kp = (1 / a).reshape(-1, 1) * K + + err = 1 + for ii in range(numItermax): + uprev = u + vprev = v + KtransposeU = nx.dot(K.T, u) + v = b / KtransposeU + u = 1. / nx.dot(Kp, v) + + if (nx.any(KtransposeU == 0) + or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)) + or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))): + # we have reached the machine precision + # come back to previous solution and quit loop + warnings.warn('Warning: numerical errors at iteration %d' % ii) + u = uprev + v = vprev + break + if ii % 10 == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + if n_hists: + tmp2 = nx.einsum('ik,ij,jk->jk', u, K, v) + else: + # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 + tmp2 = nx.einsum('i,ij,j->j', u, K, v) + err = nx.norm(tmp2 - b) # violation of marginal + if log: + log['err'].append(err) + + if err < stopThr: + break + if verbose: + if ii % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + else: + if warn: + warnings.warn("Sinkhorn did not converge. 
You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") + if log: + log['niter'] = ii + log['u'] = u + log['v'] = v + + if n_hists: # return only loss + res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M) + if log: + return res, log + else: + return res + + else: # return OT matrix + + if log: + return u.reshape((-1, 1)) * K * v.reshape((1, -1)), log + else: + return u.reshape((-1, 1)) * K * v.reshape((1, -1)) + + +def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, + log=False, warn=True, warmstart=None, **kwargs): + r""" + Solve the entropic regularization optimal transport problem in log space + and return the OT matrix + + The function solves the following optimization problem: + + .. math:: + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg}\cdot\Omega(\gamma) + + s.t. \ \gamma \mathbf{1} &= \mathbf{a} + + \gamma^T \mathbf{1} &= \mathbf{b} + + \gamma &\geq 0 + + where : + + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix + scaling algorithm :ref:`[2] ` with the + implementation from :ref:`[34] ` + + + Parameters + ---------- + a : array-like, shape (dim_a,) + samples weights in the source domain + b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists) + samples in the target domain, compute sinkhorn with multiple targets + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log) + M : array-like, shape (dim_a, dim_b) + loss matrix + reg : float + Regularization term >0 + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) + + Returns + ------- + gamma : array-like, shape (dim_a, dim_b) + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[0., 1.], [1., 0.]] + >>> ot.sinkhorn(a, b, M, 1) + array([[0.36552929, 0.13447071], + [0.13447071, 0.36552929]]) + + + .. _references-sinkhorn-log: + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of + Optimal Transport, Advances in Neural Information Processing + Systems (NIPS) 26, 2013 + + .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., + Trouvé, A., & Peyré, G. (2019, April). Interpolating between + optimal transport and MMD using Sinkhorn divergences. In The + 22nd International Conference on Artificial Intelligence and + Statistics (pp. 2681-2690). PMLR. 
+ + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + + """ + + a, b, M = list_to_array(a, b, M) + + nx = get_backend(M, a, b) + + if len(a) == 0: + a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M) + if len(b) == 0: + b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M) + + # init data + dim_a = len(a) + dim_b = b.shape[0] + + if len(b.shape) > 1: + n_hists = b.shape[1] + else: + n_hists = 0 + + # in case of multiple historgrams + if n_hists > 1 and warmstart is None: + warmstart = [None] * n_hists + + if n_hists: # we do not want to use tensors sor we do a loop + + lst_loss = [] + lst_u = [] + lst_v = [] + + for k in range(n_hists): + res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax, stopThr=stopThr, + verbose=verbose, log=log, warmstart=warmstart[k], **kwargs) + + if log: + lst_loss.append(nx.sum(M * res[0])) + lst_u.append(res[1]['log_u']) + lst_v.append(res[1]['log_v']) + else: + lst_loss.append(nx.sum(M * res)) + res = nx.stack(lst_loss) + if log: + log = {'log_u': nx.stack(lst_u, 1), + 'log_v': nx.stack(lst_v, 1), } + log['u'] = nx.exp(log['log_u']) + log['v'] = nx.exp(log['log_v']) + return res, log + else: + return res + + else: + + if log: + log = {'err': []} + + Mr = - M / reg + + # we assume that no distances are null except those of the diagonal of + # distances + if warmstart is None: + u = nx.zeros(dim_a, type_as=M) + v = nx.zeros(dim_b, type_as=M) + else: + u, v = warmstart + + def get_logT(u, v): + if n_hists: + return Mr[:, :, None] + u + v + else: + return Mr + u[:, None] + v[None, :] + + loga = nx.log(a) + logb = nx.log(b) + + err = 1 + for ii in range(numItermax): + + v = logb - nx.logsumexp(Mr + u[:, None], 0) + u = loga - nx.logsumexp(Mr + v[None, :], 1) + + if ii % 10 == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + + # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 + tmp2 = nx.sum(nx.exp(get_logT(u, v)), 0) + err = nx.norm(tmp2 - b) # violation of marginal + if log: + log['err'].append(err) + + if verbose: + if ii % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + if err < stopThr: + break + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") + + if log: + log['niter'] = ii + log['log_u'] = u + log['log_v'] = v + log['u'] = nx.exp(u) + log['v'] = nx.exp(v) + + return nx.exp(get_logT(u, v)), log + + else: + return nx.exp(get_logT(u, v)) + + +def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, + log=False, warn=True, warmstart=None): + r""" + Solve the entropic regularization optimal transport problem and return the OT matrix + + The algorithm used is based on the paper :ref:`[22] ` + which is a stochastic version of the Sinkhorn-Knopp + algorithm :ref:`[2] ` + + The function solves the following optimization problem: + + .. math:: + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg}\cdot\Omega(\gamma) + + s.t. 
\ \gamma \mathbf{1} &= \mathbf{a} + + \gamma^T \mathbf{1} &= \mathbf{b} + + \gamma &\geq 0 + + where : + + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + weights (histograms, both sum to 1) + + + Parameters + ---------- + a : array-like, shape (dim_a,) + samples weights in the source domain + b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists) + samples in the target domain, compute sinkhorn with multiple targets + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix + (return OT loss + dual variables in log) + M : array-like, shape (dim_a, dim_b) + loss matrix + reg : float + Regularization term >0 + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + log : bool, optional + record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. + warmstart: tuple of arrays, shape (dim_a, dim_b), optional + Initialization of dual potentials. If provided, the dual potentials should be given + (that is the logarithm of the u,v sinkhorn scaling vectors) + + Returns + ------- + gamma : array-like, shape (dim_a, dim_b) + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[0., 1.], [1., 0.]] + >>> ot.bregman.greenkhorn(a, b, M, 1) + array([[0.36552929, 0.13447071], + [0.13447071, 0.36552929]]) + + + .. _references-greenkhorn: + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation + of Optimal Transport, Advances in Neural Information + Processing Systems (NIPS) 26, 2013 + + .. [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time + approximation algorithms for optimal transport via Sinkhorn + iteration, Advances in Neural Information Processing + Systems (NIPS) 31, 2017 + + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + + """ + + a, b, M = list_to_array(a, b, M) + + nx = get_backend(M, a, b) + if nx.__name__ in ("jax", "tf"): + raise TypeError("JAX or TF arrays have been received. Greenkhorn is not " + "compatible with neither JAX nor TF") + + if len(a) == 0: + a = nx.ones((M.shape[0],), type_as=M) / M.shape[0] + if len(b) == 0: + b = nx.ones((M.shape[1],), type_as=M) / M.shape[1] + + dim_a = a.shape[0] + dim_b = b.shape[0] + + K = nx.exp(-M / reg) + + if warmstart is None: + u = nx.full((dim_a,), 1. / dim_a, type_as=K) + v = nx.full((dim_b,), 1. 
/ dim_b, type_as=K) + else: + u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1]) + G = u[:, None] * K * v[None, :] + + viol = nx.sum(G, axis=1) - a + viol_2 = nx.sum(G, axis=0) - b + stopThr_val = 1 + if log: + log = dict() + log['u'] = u + log['v'] = v + + for ii in range(numItermax): + i_1 = nx.argmax(nx.abs(viol)) + i_2 = nx.argmax(nx.abs(viol_2)) + m_viol_1 = nx.abs(viol[i_1]) + m_viol_2 = nx.abs(viol_2[i_2]) + stopThr_val = nx.maximum(m_viol_1, m_viol_2) + + if m_viol_1 > m_viol_2: + old_u = u[i_1] + new_u = a[i_1] / nx.dot(K[i_1, :], v) + G[i_1, :] = new_u * K[i_1, :] * v + + viol[i_1] = nx.dot(new_u * K[i_1, :], v) - a[i_1] + viol_2 += (K[i_1, :].T * (new_u - old_u) * v) + u[i_1] = new_u + else: + old_v = v[i_2] + new_v = b[i_2] / nx.dot(K[:, i_2].T, u) + G[:, i_2] = u * K[:, i_2] * new_v + # aviol = (G@one_m - a) + # aviol_2 = (G.T@one_n - b) + viol += (-old_v + new_v) * K[:, i_2] * u + viol_2[i_2] = new_v * nx.dot(K[:, i_2], u) - b[i_2] + v[i_2] = new_v + + if stopThr_val <= stopThr: + break + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") + + if log: + log["n_iter"] = ii + log['u'] = u + log['v'] = v + + if log: + return G, log + else: + return G + + +def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, + warmstart=None, verbose=False, print_period=20, + log=False, warn=True, **kwargs): + r""" + Solve the entropic regularization OT problem with log stabilization + + The function solves the following optimization problem: + + .. math:: + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg}\cdot\Omega(\gamma) + + s.t. \ \gamma \mathbf{1} &= \mathbf{a} + + \gamma^T \mathbf{1} &= \mathbf{b} + + \gamma &\geq 0 + + where : + + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target + weights (histograms, both sum to 1) + + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix + scaling algorithm as proposed in :ref:`[2] ` + but with the log stabilization + proposed in :ref:`[10] ` an defined in + :ref:`[9] ` (Algo 3.1) . + + + Parameters + ---------- + a : array-like, shape (dim_a,) + samples weights in the source domain + b : array-like, shape (dim_b,) + samples in the target domain + M : array-like, shape (dim_a, dim_b) + loss matrix + reg : float + Regularization term >0 + tau : float + threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}` + for log scaling + warmstart : table of vectors + if given then starting values for alpha and beta log scalings + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. 
+ + Returns + ------- + gamma : array-like, shape (dim_a, dim_b) + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> import ot + >>> a=[.5,.5] + >>> b=[.5,.5] + >>> M=[[0.,1.],[1.,0.]] + >>> ot.bregman.sinkhorn_stabilized(a, b, M, 1) + array([[0.36552929, 0.13447071], + [0.13447071, 0.36552929]]) + + + .. _references-sinkhorn-stabilized: + References + ---------- + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of + Optimal Transport, Advances in Neural Information Processing + Systems (NIPS) 26, 2013 + + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms + for Entropy Regularized Transport Problems. + arXiv preprint arXiv:1610.06519. + + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. + arXiv preprint arXiv:1607.05816. + + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + + """ + + a, b, M = list_to_array(a, b, M) + + nx = get_backend(M, a, b) + + if len(a) == 0: + a = nx.ones((M.shape[0],), type_as=M) / M.shape[0] + if len(b) == 0: + b = nx.ones((M.shape[1],), type_as=M) / M.shape[1] + + # test if multiple target + if len(b.shape) > 1: + n_hists = b.shape[1] + a = a[:, None] + else: + n_hists = 0 + + # init data + dim_a = len(a) + dim_b = len(b) + + if log: + log = {'err': []} + + # we assume that no distances are null except those of the diagonal of + # distances + if warmstart is None: + alpha, beta = nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M) + else: + alpha, beta = warmstart + + if n_hists: + u = nx.ones((dim_a, n_hists), type_as=M) / dim_a + v = nx.ones((dim_b, n_hists), type_as=M) / dim_b + else: + u, v = nx.ones(dim_a, type_as=M), nx.ones(dim_b, type_as=M) + u /= dim_a + v /= dim_b + + def get_K(alpha, beta): + """log space computation""" + return nx.exp(-(M - alpha.reshape((dim_a, 1)) + - beta.reshape((1, dim_b))) / reg) + + def get_Gamma(alpha, beta, u, v): + """log space gamma computation""" + return nx.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b))) + / reg + nx.log(u.reshape((dim_a, 1))) + nx.log(v.reshape((1, dim_b)))) + + K = get_K(alpha, beta) + transp = K + err = 1 + for ii in range(numItermax): + + uprev = u + vprev = v + + # sinkhorn update + v = b / (nx.dot(K.T, u)) + u = a / (nx.dot(K, v)) + + # remove numerical problems and store them in K + if nx.max(nx.abs(u)) > tau or nx.max(nx.abs(v)) > tau: + if n_hists: + alpha, beta = alpha + reg * \ + nx.max(nx.log(u), 1), beta + reg * nx.max(nx.log(v)) + else: + alpha, beta = alpha + reg * nx.log(u), beta + reg * nx.log(v) + if n_hists: + u = nx.ones((dim_a, n_hists), type_as=M) / dim_a + v = nx.ones((dim_b, n_hists), type_as=M) / dim_b + else: + u = nx.ones(dim_a, type_as=M) / dim_a + v = nx.ones(dim_b, type_as=M) / dim_b + K = get_K(alpha, beta) + + if ii % print_period == 0: + # we can speed up the process by checking for the error only all + # the 10th iterations + if n_hists: + err_u = nx.max(nx.abs(u - uprev)) + err_u /= max(nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1.0) + err_v = nx.max(nx.abs(v - vprev)) + err_v /= max(nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1.0) + err = 0.5 * (err_u + err_v) + else: + transp = get_Gamma(alpha, beta, u, v) + err = nx.norm(nx.sum(transp, axis=0) - b) + if log: + log['err'].append(err) + + if verbose: + if ii % (print_period * 20) == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 
19) + print('{:5d}|{:8e}|'.format(ii, err)) + + if err <= stopThr: + break + + if nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)): + # we have reached the machine precision + # come back to previous solution and quit loop + warnings.warn('Numerical errors at iteration %d' % ii) + u = uprev + v = vprev + break + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") + if log: + if n_hists: + alpha = alpha[:, None] + beta = beta[:, None] + logu = alpha / reg + nx.log(u) + logv = beta / reg + nx.log(v) + log["n_iter"] = ii + log['logu'] = logu + log['logv'] = logv + log['alpha'] = alpha + reg * nx.log(u) + log['beta'] = beta + reg * nx.log(v) + log['warmstart'] = (log['alpha'], log['beta']) + if n_hists: + res = nx.stack([ + nx.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M) + for i in range(n_hists) + ]) + return res, log + + else: + return get_Gamma(alpha, beta, u, v), log + else: + if n_hists: + res = nx.stack([ + nx.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M) + for i in range(n_hists) + ]) + return res + else: + return get_Gamma(alpha, beta, u, v) + + +def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, + numInnerItermax=100, tau=1e3, stopThr=1e-9, + warmstart=None, verbose=False, print_period=10, + log=False, warn=True, **kwargs): + r""" + Solve the entropic regularization optimal transport problem with log + stabilization and epsilon scaling. + The function solves the following optimization problem: + + .. math:: + \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + + \mathrm{reg}\cdot\Omega(\gamma) + + s.t. \ \gamma \mathbf{1} &= \mathbf{a} + + \gamma^T \mathbf{1} &= \mathbf{b} + + \gamma &\geq 0 + + where : + + - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1) + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix + scaling algorithm as proposed in :ref:`[2] ` + but with the log stabilization + proposed in :ref:`[10] ` and the log scaling + proposed in :ref:`[9] ` algorithm 3.2 + + Parameters + ---------- + a : array-like, shape (dim_a,) + samples weights in the source domain + b : array-like, shape (dim_b,) + samples in the target domain + M : array-like, shape (dim_a, dim_b) + loss matrix + reg : float + Regularization term >0 + tau : float + threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{b}` + for log scaling + warmstart : tuple of vectors + if given then starting values for alpha and beta log scalings + numItermax : int, optional + Max number of iterations + numInnerItermax : int, optional + Max number of iterations in the inner slog stabilized sinkhorn + epsilon0 : int, optional + first epsilon regularization value (then exponential decrease to reg) + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + warn : bool, optional + if True, raises a warning if the algorithm doesn't convergence. 
+ + Returns + ------- + gamma : array-like, shape (dim_a, dim_b) + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + Examples + -------- + >>> import ot + >>> a=[.5, .5] + >>> b=[.5, .5] + >>> M=[[0., 1.], [1., 0.]] + >>> ot.bregman.sinkhorn_epsilon_scaling(a, b, M, 1) + array([[0.36552929, 0.13447071], + [0.13447071, 0.36552929]]) + + .. _references-sinkhorn-epsilon-scaling: + References + ---------- + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal + Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 + + .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for + Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519. + + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816. + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.optim.cg : General regularized OT + """ + + a, b, M = list_to_array(a, b, M) + + nx = get_backend(M, a, b) + + if len(a) == 0: + a = nx.ones((M.shape[0],), type_as=M) / M.shape[0] + if len(b) == 0: + b = nx.ones((M.shape[1],), type_as=M) / M.shape[1] + + # init data + dim_a = len(a) + dim_b = len(b) + + # nrelative umerical precision with 64 bits + numItermin = 35 + numItermax = max(numItermin, numItermax) # ensure that last velue is exact + + ii = 0 + if log: + log = {'err': []} + + # we assume that no distances are null except those of the diagonal of + # distances + if warmstart is None: + alpha, beta = nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M) + else: + alpha, beta = warmstart + + # print(np.min(K)) + def get_reg(n): # exponential decreasing + return (epsilon0 - reg) * np.exp(-n) + reg + + err = 1 + for ii in range(numItermax): + + regi = get_reg(ii) + + G, logi = sinkhorn_stabilized(a, b, M, regi, + numItermax=numInnerItermax, stopThr=stopThr, + warmstart=(alpha, beta), verbose=False, + print_period=20, tau=tau, log=True) + + alpha = logi['alpha'] + beta = logi['beta'] + + if ii % (print_period) == 0: # spsion nearly converged + # we can speed up the process by checking for the error only all + # the 10th iterations + transp = G + err = nx.norm(nx.sum(transp, axis=0) - b) ** 2 + \ + nx.norm(nx.sum(transp, axis=1) - a) ** 2 + if log: + log['err'].append(err) + + if verbose: + if ii % (print_period * 10) == 0: + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + + if err <= stopThr and ii > numItermin: + break + else: + if warn: + warnings.warn("Sinkhorn did not converge. 
You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") + if log: + log['alpha'] = alpha + log['beta'] = beta + log['warmstart'] = (log['alpha'], log['beta']) + log['niter'] = ii + return G, log + else: + return G diff --git a/ot/bregman/_utils.py b/ot/bregman/_utils.py new file mode 100644 index 000000000..9535cd1c0 --- /dev/null +++ b/ot/bregman/_utils.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +""" +Common tools of Bregman projections solvers for entropic regularized OT +""" + +# Author: Remi Flamary +# Nicolas Courty +# +# License: MIT License + +from ..utils import list_to_array +from ..backend import get_backend + + +def geometricBar(weights, alldistribT): + """return the weighted geometric mean of distributions""" + weights, alldistribT = list_to_array(weights, alldistribT) + nx = get_backend(weights, alldistribT) + assert (len(weights) == alldistribT.shape[1]) + return nx.exp(nx.dot(nx.log(alldistribT), weights.T)) + + +def geometricMean(alldistribT): + """return the geometric mean of distributions""" + alldistribT = list_to_array(alldistribT) + nx = get_backend(alldistribT) + return nx.exp(nx.mean(nx.log(alldistribT), axis=1)) + + +def projR(gamma, p): + """return the KL projection on the row constraints """ + gamma, p = list_to_array(gamma, p) + nx = get_backend(gamma, p) + return (gamma.T * p / nx.maximum(nx.sum(gamma, axis=1), 1e-10)).T + + +def projC(gamma, q): + """return the KL projection on the column constraints """ + gamma, q = list_to_array(gamma, q) + nx = get_backend(gamma, q) + return gamma * q / nx.maximum(nx.sum(gamma, axis=0), 1e-10)
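
The new layout splits the former monolithic `bregman.py` into focused submodules (`_sinkhorn.py`, `_screenkhorn.py`, `_utils.py`, ...), while the public entry points are still expected to be reachable as `ot.bregman.*`. The snippet below is a minimal smoke-test sketch, not part of the diff: it assumes the package `__init__.py` re-exports the same names as the old flat module (the stated intent of the refactor), and it only calls functions that appear in the new files above.

```python
# Hedged usage sketch: exercises the ot.bregman API after the repackaging.
# Assumption (not shown in this diff): ot/bregman/__init__.py re-exports
# sinkhorn, sinkhorn2, screenkhorn and empirical_sinkhorn_divergence.
import numpy as np
import ot

rng = np.random.RandomState(0)

# toy problem: two 2-D point clouds, uniform weights, squared Euclidean cost
X_s = rng.randn(50, 2)
X_t = rng.randn(60, 2) + 1.0
a = ot.unif(50)
b = ot.unif(60)
M = ot.dist(X_s, X_t)  # 'sqeuclidean' by default

# dense OT plan with the default Sinkhorn-Knopp solver (now in _sinkhorn.py)
G = ot.bregman.sinkhorn(a, b, M, reg=1.0)

# same problem solved in log space, the recommended solver for small regularization
G_log = ot.bregman.sinkhorn(a, b, M, reg=5e-2, method='sinkhorn_log')

# entropic OT loss without the entropic term (sinkhorn2)
loss = ot.bregman.sinkhorn2(a, b, M, reg=1.0)

# screened approximation keeping 50% of the points by default (now in _screenkhorn.py);
# installing the optional Bottleneck package speeds up the screening step
G_screen = ot.bregman.screenkhorn(a, b, M, reg=1.0)

# Sinkhorn divergence computed directly from the empirical samples
div = ot.bregman.empirical_sinkhorn_divergence(X_s, X_t, reg=1.0)

assert G.shape == (50, 60) and G_screen.shape == (50, 60)
```

If the re-exports are in place, existing user code such as `ot.sinkhorn(a, b, M, 1)` (the top-level alias used throughout the docstrings above) keeps working unchanged, which is the main design constraint of this split.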