From 0508b3575a06bfadc0a53b0e5bfdee4266b39229 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Fri, 10 Nov 2023 12:54:14 +0100 Subject: [PATCH 01/14] add first geomlowss wrapper --- ot/geomloss.py | 114 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 114 insertions(+) create mode 100644 ot/geomloss.py diff --git a/ot/geomloss.py b/ot/geomloss.py new file mode 100644 index 000000000..437d81b63 --- /dev/null +++ b/ot/geomloss.py @@ -0,0 +1,114 @@ +# -*- coding: utf-8 -*- +""" +wrapper functions for geomloss +""" + +try: + import geomloss + from geomloss import SamplesLoss + import torch + from torch.autograd import grad + from .utils import get_backend, LazyTensor, dist +except: + geomloss = False + +def get_sinkhorn_geomloss_lazytensor(X_a, X_b, f, g, a, b, metric='sqeuclidean', reg=1e-1, nx=None): + """ Get a LazyTensor of sinkhorn solution T = exp((f+g^T-C)/reg)*(ab^T) + + Parameters + ---------- + X_a : array-like, shape (n_samples_a, dim) + samples in the source domain + X_b : array-like, shape (n_samples_b, dim) + samples in the target domain + f : array-like, shape (n_samples_a,) + First dual potentials (log space) + g : array-like, shape (n_samples_b,) + Second dual potentials (log space) + metric : str, default='sqeuclidean' + Metric used for the cost matrix computation + reg : float, default=1e-1 + Regularization term >0 + nx : Backend(), default=None + Numerical backend used + + + Returns + ------- + T : LazyTensor + Lowrank tensor T = exp((f+g^T-C)/reg)*(ab^T) + """ + + if nx is None: + nx = get_backend(X_a, X_b, f, g) + + shape = (X_a.shape[0], X_b.shape[0]) + + def func(i, j, X_a, X_b, f, g, a, b, metric, reg): + C = dist(X_a[i], X_b[j], metric=metric) + return nx.exp((f[i, None] + g[None, j] - C) / reg)* (a[i,None]*b[None,j]) + + T = LazyTensor(shape, func, X_a=X_a, X_b=X_b, f=f, g=g, a=a, b=b, metric=metric, reg=reg) + + return T + +def empirical_sinkhorn2_geomloss(X_s, X_t, reg, a=None, b=None, 
metric='sqeuclidean', + numIterMax=10000, stopThr=1e-9, + verbose=False, log=False, warn=True, warmstart=None): + + if geomloss: + + + nx = get_backend(X_s, X_t, a, b) + + if nx.__name__ != 'torch': + raise ValueError('geomloss only support torch backend') + + # after that we are all in torch + + if a is None: + a = torch.ones(X_s.shape[0], dtype=X_s.dtype, device=X_s.device) / X_s.shape[0] + if b is None: + b = torch.ones(X_t.shape[0], dtype=X_t.dtype, device=X_t.device) / X_t.shape[0] + + if metric == 'sqeuclidean': + p=2 + blur = reg/2 # because geomloss divides cost by two + elif metric == 'euclidean': + p=1 + blur = reg + else: + raise ValueError('geomloss only supports sqeuclidean and euclidean metrics') + + X_s.requires_grad = True + X_t.requires_grad = True + a.requires_grad = True + b.requires_grad = True + + loss = SamplesLoss(loss='sinkhorn', p=p, blur=blur, backend='auto', debias=False, verbose=verbose) + + value = loss(a, X_s, b, X_t) # linear + entropic/KL reg? + + if metric == 'sqeuclidean': + value *= 2 # because geomloss divides cost by two + + f, g = grad(value, [a, b]) + + if log: + log = {} + log['f'] = f + log['g'] = g + log['value'] = value + + log['lazy_tensor'] = get_sinkhorn_geomloss_lazytensor(X_s, X_t, f, g, a, b, metric=metric, reg=reg, nx=nx) + + return value, log + + else: + return value + + + else: + raise ImportError('geomloss not installed') + + \ No newline at end of file From 3e8c7cad9de2ebe631c15ee0651928a695c10f0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Fri, 10 Nov 2023 13:38:40 +0100 Subject: [PATCH 02/14] pep8 --- ot/geomloss.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/ot/geomloss.py b/ot/geomloss.py index 437d81b63..620dbbdfa 100644 --- a/ot/geomloss.py +++ b/ot/geomloss.py @@ -9,9 +9,10 @@ import torch from torch.autograd import grad from .utils import get_backend, LazyTensor, dist -except: +except BaseException: geomloss = False + def 
get_sinkhorn_geomloss_lazytensor(X_a, X_b, f, g, a, b, metric='sqeuclidean', reg=1e-1, nx=None): """ Get a LazyTensor of sinkhorn solution T = exp((f+g^T-C)/reg)*(ab^T) @@ -46,19 +47,19 @@ def get_sinkhorn_geomloss_lazytensor(X_a, X_b, f, g, a, b, metric='sqeuclidean', def func(i, j, X_a, X_b, f, g, a, b, metric, reg): C = dist(X_a[i], X_b[j], metric=metric) - return nx.exp((f[i, None] + g[None, j] - C) / reg)* (a[i,None]*b[None,j]) + return nx.exp((f[i, None] + g[None, j] - C) / reg) * (a[i, None] * b[None, j]) T = LazyTensor(shape, func, X_a=X_a, X_b=X_b, f=f, g=g, a=a, b=b, metric=metric, reg=reg) - return T + return T + def empirical_sinkhorn2_geomloss(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', - numIterMax=10000, stopThr=1e-9, - verbose=False, log=False, warn=True, warmstart=None): + numIterMax=10000, stopThr=1e-9, + verbose=False, log=False, warn=True, warmstart=None): if geomloss: - nx = get_backend(X_s, X_t, a, b) if nx.__name__ != 'torch': @@ -72,10 +73,10 @@ def empirical_sinkhorn2_geomloss(X_s, X_t, reg, a=None, b=None, metric='sqeuclid b = torch.ones(X_t.shape[0], dtype=X_t.dtype, device=X_t.device) / X_t.shape[0] if metric == 'sqeuclidean': - p=2 - blur = reg/2 # because geomloss divides cost by two + p = 2 + blur = reg / 2 # because geomloss divides cost by two elif metric == 'euclidean': - p=1 + p = 1 blur = reg else: raise ValueError('geomloss only supports sqeuclidean and euclidean metrics') @@ -84,10 +85,10 @@ def empirical_sinkhorn2_geomloss(X_s, X_t, reg, a=None, b=None, metric='sqeuclid X_t.requires_grad = True a.requires_grad = True b.requires_grad = True - + loss = SamplesLoss(loss='sinkhorn', p=p, blur=blur, backend='auto', debias=False, verbose=verbose) - value = loss(a, X_s, b, X_t) # linear + entropic/KL reg? + value = loss(a, X_s, b, X_t) # linear + entropic/KL reg? 
if metric == 'sqeuclidean': value *= 2 # because geomloss divides cost by two @@ -107,8 +108,5 @@ def empirical_sinkhorn2_geomloss(X_s, X_t, reg, a=None, b=None, metric='sqeuclid else: return value - else: raise ImportError('geomloss not installed') - - \ No newline at end of file From 550cd8ccb43cb2b57ae2f15ecac438de9bceca89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 21 Nov 2023 12:25:21 +0100 Subject: [PATCH 03/14] working geomlos wrapper --- README.md | 3 + ot/bregman/__init__.py | 6 +- ot/bregman/_geomloss.py | 220 ++++++++++++++++++++++++++++++++++++++++ ot/geomloss.py | 112 -------------------- 4 files changed, 227 insertions(+), 114 deletions(-) create mode 100644 ot/bregman/_geomloss.py delete mode 100644 ot/geomloss.py diff --git a/README.md b/README.md index f622c5aab..dc7c5dfaa 100644 --- a/README.md +++ b/README.md @@ -334,3 +334,6 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil [59] Taylor A. B. (2017). [Convex interpolation and performance estimation of first-order methods for convex optimization.](https://dial.uclouvain.be/pr/boreal/object/boreal%3A182881/datastream/PDF_01/view) PhD thesis, Catholic University of Louvain, Louvain-la-Neuve, Belgium, 2017. +[60] Feydy, J., Roussillon, P., Trouvé, A., & Gori, P. (2019). [Fast and scalable optimal transport for brain tractograms](https://arxiv.org/pdf/2107.02010.pdf). In Medical Image Computing and Computer Assisted Intervention–MICCAI 2019: 22nd International Conference, Shenzhen, China, October 13–17, 2019, Proceedings, Part III 22 (pp. 636-644). Springer International Publishing. + +[61] Charlier, B., Feydy, J., Glaunes, J. A., Collin, F. D., & Durif, G. (2021). [Kernel operations on the gpu, with autodiff, without memory overflows](https://www.jmlr.org/papers/volume22/20-275/20-275.pdf). The Journal of Machine Learning Research, 22(1), 3457-3462. 
\ No newline at end of file diff --git a/ot/bregman/__init__.py b/ot/bregman/__init__.py index 230982e9c..0bcb4214d 100644 --- a/ot/bregman/__init__.py +++ b/ot/bregman/__init__.py @@ -39,6 +39,8 @@ from ._dictionary import (unmix) +from ._geomloss import (empirical_sinkhorn2_geomloss, geomloss) + __all__ = ['geometricBar', 'geometricMean', 'projR', 'projC', 'sinkhorn', 'sinkhorn2', 'sinkhorn_knopp', 'sinkhorn_log', @@ -46,8 +48,8 @@ 'barycenter', 'barycenter_sinkhorn', 'free_support_sinkhorn_barycenter', 'barycenter_stabilized', 'barycenter_debiased', 'jcpot_barycenter', 'convolutional_barycenter2d', 'convolutional_barycenter2d_debiased', - 'empirical_sinkhorn', 'empirical_sinkhorn2', - 'empirical_sinkhorn_divergence', + 'empirical_sinkhorn', 'empirical_sinkhorn2', 'empirical_sinkhorn2_geomloss' + 'empirical_sinkhorn_divergence', 'geomloss', 'screenkhorn', 'unmix' ] diff --git a/ot/bregman/_geomloss.py b/ot/bregman/_geomloss.py new file mode 100644 index 000000000..028672e6b --- /dev/null +++ b/ot/bregman/_geomloss.py @@ -0,0 +1,220 @@ +# -*- coding: utf-8 -*- +""" +Wrapper functions for geomloss +""" + +# Author: Remi Flamary +# +# License: MIT License + +import numpy as np +try: + import geomloss + from geomloss import SamplesLoss + import torch + from torch.autograd import grad + from .utils import get_backend, LazyTensor, dist +except BaseException: + geomloss = False + + +def get_sinkhorn_geomloss_lazytensor(X_a, X_b, f, g, a, b, metric='sqeuclidean', blur=0.1, nx=None): + """ Get a LazyTensor of sinkhorn solution T = exp((f+g^T-C)/reg)*(ab^T) + + Parameters + ---------- + X_a : array-like, shape (n_samples_a, dim) + samples in the source domain + X_torch: array-like, shape (n_samples_b, dim) + samples in the target domain + f : array-like, shape (n_samples_a,) + First dual potentials (log space) + g : array-like, shape (n_samples_b,) + Second dual potentials (log space) + metric : str, default='sqeuclidean' + Metric used for the cost matrix computation + 
blur : float, default=1e-1 + blur term (blur=sqrt(reg)) >0 + nx : Backend(), default=None + Numerical backend used + + + Returns + ------- + T : LazyTensor + Lowrank tensor T = exp((f+g^T-C)/reg)*(ab^T) + """ + + if nx is None: + nx = get_backend(X_a, X_b, f, g) + + shape = (X_a.shape[0], X_b.shape[0]) + + def func(i, j, X_a, X_b, f, g, a, b, metric, blur): + if metric == 'sqeuclidean': + C = dist(X_a[i], X_b[j], metric=metric)/2 + else: + C = dist(X_a[i], X_b[j], metric=metric) + return nx.exp((f[i, None] + g[None, j] - C) / (blur**2)) * (a[i, None] * b[None, j]) + + T = LazyTensor(shape, func, X_a=X_a, X_b=X_b, f=f, g=g, a=a, b=b, metric=metric, blur=blur) + + return T + + +def empirical_sinkhorn2_geomloss(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', scaling=0.95, + verbose=False, debias=False, log=False, backend='auto'): + r""" Solve the entropic regularization optimal transport problem with geomloss + + The function solves the following optimization problem: + + .. math:: + \gamma = arg\min_\gamma <\gamma,C>_F + reg\cdot\Omega(\gamma) + + s.t. \gamma 1 = a + + \gamma^T 1= b + + \gamma\geq 0 + + where : + + - :math:`C` is the cost matrix such that :math:`C_{i,j}=d(x_i^s,x_j^t)` and + :math:`d` is a metric. + - :math:`\Omega` is the entropic regularization term + :math:`\Omega(\gamma)=\sum_{i,j}\gamma_{i,j}\log(\gamma_{i,j})-\gamma_{i,j}+1` + - :math:`a` and :math:`b` are source and target weights (sum to 1) + + The algorithm used for solving the problem is the Sinkhorn-Knopp matrix + scaling algorithm as proposed in and computed in log space for + better stability and epsilon-scaling. The solution is computed ina lzy way + using the Geomloss [60] and the KeOps library [61]. 
+ + Parameters + ---------- + X_s : array-like, shape (n_samples_a, dim) + samples in the source domain + X_t : array-like, shape (n_samples_b, dim) + samples in the target domain + reg : float + Regularization term >0 + a : array-like, shape (n_samples_a,), default=None + samples weights in the source domain + b : array-like, shape (n_samples_b,), default=None + samples weights in the target domain + metric : str, default='sqeuclidean' + Metric used for the cost matrix computation Only acepted values are + 'sqeuclidean' and 'euclidean'. + scaling : float, default=0.95 + Scaling parameter used for epsilon scaling. Value close to one promote + precision while value close to zero promote speed. + verbose : bool, default=False + Print information + debias : bool, default=False + Use the debiased version of Sinkhorn algorithm [12]_. + log : bool, default=False + Return log dictionary containing all computed objects + backend : str, default='auto' + Numerical backend for geomloss. Only 'auto' and 'tensorized' 'online' + and 'multiscale' are accepted values. + + Returns + ------- + value : float + OT value + log : dict + Log dictionary return only if log==True in parameters + + References + ---------- + + .. [60] Feydy, J., Roussillon, P., Trouvé, A., & Gori, P. (2019). [Fast + and scalable optimal transport for brain tractograms. In Medical Image + Computing and Computer Assisted Intervention–MICCAI 2019: 22nd + International Conference, Shenzhen, China, October 13–17, 2019, + Proceedings, Part III 22 (pp. 636-644). Springer International + Publishing. + + .. [61] Charlier, B., Feydy, J., Glaunes, J. A., Collin, F. D., & Durif, G. + (2021). Kernel operations on the gpu, with autodiff, without memory + overflows. The Journal of Machine Learning Research, 22(1), 3457-3462. 
+ + """ + + + + + if geomloss: + + nx = get_backend(X_s, X_t, a, b) + + if nx.__name__ not in ['torch', 'numpy']: + raise ValueError('geomloss only support torch or numpy backend') + + if nx.__name__ == 'numpy': + X_s_torch= torch.tensor(X_s) + X_t_torch = torch.tensor(X_t) + if a is not None: + a_torch= torch.tensor(a) + if b is not None: + b_torch = torch.tensor(b) + else: + X_s_torch= X_s + X_t_torch= X_t + + a_torch= a + b_torch= b + # after that we are all in torch + + # after that we are all in torch + + if a_torch is None: + a_torch= torch.ones(X_s_torch.shape[0], dtype=X_s_torch.dtype, device=X_s_torch.device) / X_s_torch.shape[0] + if b_torch is None: + b_torch= torch.ones(X_t_torch.shape[0], dtype=X_t_torch.dtype, device=X_t_torch.device) / X_t_torch.shape[0] + + # set blur value and p + if metric == 'sqeuclidean': + p = 2 + blur = np.sqrt(reg / 2) # because geomloss divides cost by two + elif metric == 'euclidean': + p = 1 + blur = np.sqrt(reg) + else: + raise ValueError('geomloss only supports sqeuclidean and euclidean metrics') + + # force gradients for computing dual + a_torch.requires_grad = True + b_torch.requires_grad = True + + loss = SamplesLoss(loss='sinkhorn', p=p, blur=blur, backend=backend, debias=debias, scaling=scaling, verbose=verbose) + + # compute value + value = loss(a_torch, X_s_torch, b_torch, X_t_torch) # linear + entropic/KL reg? 
+ + # get dual potentials + f, g = grad(value, [a_torch, b_torch]) + + if metric == 'sqeuclidean': + value *= 2 # because geomloss divides cost by two + + if nx.__name__ == 'numpy': + f = f.cpu().detach().numpy() + g = g.cpu().detach().numpy() + value = value.cpu().detach().numpy() + + if log: + log = {} + log['f'] = f + log['g'] = g + log['value'] = value + + log['lazy_tensor'] = get_sinkhorn_geomloss_lazytensor(X_s, X_t, f, g, a, b, metric=metric, blur=blur, nx=nx) + + return value, log + + else: + return value + + else: + raise ImportError('geomloss not installed') diff --git a/ot/geomloss.py b/ot/geomloss.py deleted file mode 100644 index 620dbbdfa..000000000 --- a/ot/geomloss.py +++ /dev/null @@ -1,112 +0,0 @@ -# -*- coding: utf-8 -*- -""" -wrapper functions for geomloss -""" - -try: - import geomloss - from geomloss import SamplesLoss - import torch - from torch.autograd import grad - from .utils import get_backend, LazyTensor, dist -except BaseException: - geomloss = False - - -def get_sinkhorn_geomloss_lazytensor(X_a, X_b, f, g, a, b, metric='sqeuclidean', reg=1e-1, nx=None): - """ Get a LazyTensor of sinkhorn solution T = exp((f+g^T-C)/reg)*(ab^T) - - Parameters - ---------- - X_a : array-like, shape (n_samples_a, dim) - samples in the source domain - X_b : array-like, shape (n_samples_b, dim) - samples in the target domain - f : array-like, shape (n_samples_a,) - First dual potentials (log space) - g : array-like, shape (n_samples_b,) - Second dual potentials (log space) - metric : str, default='sqeuclidean' - Metric used for the cost matrix computation - reg : float, default=1e-1 - Regularization term >0 - nx : Backend(), default=None - Numerical backend used - - - Returns - ------- - T : LazyTensor - Lowrank tensor T = exp((f+g^T-C)/reg)*(ab^T) - """ - - if nx is None: - nx = get_backend(X_a, X_b, f, g) - - shape = (X_a.shape[0], X_b.shape[0]) - - def func(i, j, X_a, X_b, f, g, a, b, metric, reg): - C = dist(X_a[i], X_b[j], metric=metric) - return 
nx.exp((f[i, None] + g[None, j] - C) / reg) * (a[i, None] * b[None, j]) - - T = LazyTensor(shape, func, X_a=X_a, X_b=X_b, f=f, g=g, a=a, b=b, metric=metric, reg=reg) - - return T - - -def empirical_sinkhorn2_geomloss(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', - numIterMax=10000, stopThr=1e-9, - verbose=False, log=False, warn=True, warmstart=None): - - if geomloss: - - nx = get_backend(X_s, X_t, a, b) - - if nx.__name__ != 'torch': - raise ValueError('geomloss only support torch backend') - - # after that we are all in torch - - if a is None: - a = torch.ones(X_s.shape[0], dtype=X_s.dtype, device=X_s.device) / X_s.shape[0] - if b is None: - b = torch.ones(X_t.shape[0], dtype=X_t.dtype, device=X_t.device) / X_t.shape[0] - - if metric == 'sqeuclidean': - p = 2 - blur = reg / 2 # because geomloss divides cost by two - elif metric == 'euclidean': - p = 1 - blur = reg - else: - raise ValueError('geomloss only supports sqeuclidean and euclidean metrics') - - X_s.requires_grad = True - X_t.requires_grad = True - a.requires_grad = True - b.requires_grad = True - - loss = SamplesLoss(loss='sinkhorn', p=p, blur=blur, backend='auto', debias=False, verbose=verbose) - - value = loss(a, X_s, b, X_t) # linear + entropic/KL reg? 
- - if metric == 'sqeuclidean': - value *= 2 # because geomloss divides cost by two - - f, g = grad(value, [a, b]) - - if log: - log = {} - log['f'] = f - log['g'] = g - log['value'] = value - - log['lazy_tensor'] = get_sinkhorn_geomloss_lazytensor(X_s, X_t, f, g, a, b, metric=metric, reg=reg, nx=nx) - - return value, log - - else: - return value - - else: - raise ImportError('geomloss not installed') From 79d16d8da88ebae1346b41d37cbeed6eccc40577 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 21 Nov 2023 12:25:44 +0100 Subject: [PATCH 04/14] pep8 --- ot/bregman/_geomloss.py | 39 ++++++++++++++++++--------------------- 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/ot/bregman/_geomloss.py b/ot/bregman/_geomloss.py index 028672e6b..2ba2abad9 100644 --- a/ot/bregman/_geomloss.py +++ b/ot/bregman/_geomloss.py @@ -52,7 +52,7 @@ def get_sinkhorn_geomloss_lazytensor(X_a, X_b, f, g, a, b, metric='sqeuclidean', def func(i, j, X_a, X_b, f, g, a, b, metric, blur): if metric == 'sqeuclidean': - C = dist(X_a[i], X_b[j], metric=metric)/2 + C = dist(X_a[i], X_b[j], metric=metric) / 2 else: C = dist(X_a[i], X_b[j], metric=metric) return nx.exp((f[i, None] + g[None, j] - C) / (blur**2)) * (a[i, None] * b[None, j]) @@ -77,7 +77,7 @@ def empirical_sinkhorn2_geomloss(X_s, X_t, reg, a=None, b=None, metric='sqeuclid \gamma\geq 0 - where : + where : - :math:`C` is the cost matrix such that :math:`C_{i,j}=d(x_i^s,x_j^t)` and :math:`d` is a metric. @@ -103,21 +103,21 @@ def empirical_sinkhorn2_geomloss(X_s, X_t, reg, a=None, b=None, metric='sqeuclid b : array-like, shape (n_samples_b,), default=None samples weights in the target domain metric : str, default='sqeuclidean' - Metric used for the cost matrix computation Only acepted values are + Metric used for the cost matrix computation Only acepted values are 'sqeuclidean' and 'euclidean'. scaling : float, default=0.95 - Scaling parameter used for epsilon scaling. 
Value close to one promote + Scaling parameter used for epsilon scaling. Value close to one promote precision while value close to zero promote speed. verbose : bool, default=False - Print information + Print information debias : bool, default=False Use the debiased version of Sinkhorn algorithm [12]_. log : bool, default=False Return log dictionary containing all computed objects backend : str, default='auto' - Numerical backend for geomloss. Only 'auto' and 'tensorized' 'online' + Numerical backend for geomloss. Only 'auto' and 'tensorized' 'online' and 'multiscale' are accepted values. - + Returns ------- value : float @@ -135,15 +135,12 @@ def empirical_sinkhorn2_geomloss(X_s, X_t, reg, a=None, b=None, metric='sqeuclid Proceedings, Part III 22 (pp. 636-644). Springer International Publishing. - .. [61] Charlier, B., Feydy, J., Glaunes, J. A., Collin, F. D., & Durif, G. - (2021). Kernel operations on the gpu, with autodiff, without memory + .. [61] Charlier, B., Feydy, J., Glaunes, J. A., Collin, F. D., & Durif, G. + (2021). Kernel operations on the gpu, with autodiff, without memory overflows. The Journal of Machine Learning Research, 22(1), 3457-3462. 
""" - - - if geomloss: nx = get_backend(X_s, X_t, a, b) @@ -152,26 +149,26 @@ def empirical_sinkhorn2_geomloss(X_s, X_t, reg, a=None, b=None, metric='sqeuclid raise ValueError('geomloss only support torch or numpy backend') if nx.__name__ == 'numpy': - X_s_torch= torch.tensor(X_s) + X_s_torch = torch.tensor(X_s) X_t_torch = torch.tensor(X_t) if a is not None: - a_torch= torch.tensor(a) + a_torch = torch.tensor(a) if b is not None: b_torch = torch.tensor(b) else: - X_s_torch= X_s - X_t_torch= X_t + X_s_torch = X_s + X_t_torch = X_t - a_torch= a - b_torch= b + a_torch = a + b_torch = b # after that we are all in torch # after that we are all in torch if a_torch is None: - a_torch= torch.ones(X_s_torch.shape[0], dtype=X_s_torch.dtype, device=X_s_torch.device) / X_s_torch.shape[0] + a_torch = torch.ones(X_s_torch.shape[0], dtype=X_s_torch.dtype, device=X_s_torch.device) / X_s_torch.shape[0] if b_torch is None: - b_torch= torch.ones(X_t_torch.shape[0], dtype=X_t_torch.dtype, device=X_t_torch.device) / X_t_torch.shape[0] + b_torch = torch.ones(X_t_torch.shape[0], dtype=X_t_torch.dtype, device=X_t_torch.device) / X_t_torch.shape[0] # set blur value and p if metric == 'sqeuclidean': @@ -201,7 +198,7 @@ def empirical_sinkhorn2_geomloss(X_s, X_t, reg, a=None, b=None, metric='sqeuclid if nx.__name__ == 'numpy': f = f.cpu().detach().numpy() g = g.cpu().detach().numpy() - value = value.cpu().detach().numpy() + value = value.cpu().detach().numpy() if log: log = {} From 686f5d35093a9000a3163ff997f3bc84976caa89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 21 Nov 2023 12:29:28 +0100 Subject: [PATCH 05/14] small edit --- ot/bregman/_geomloss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/bregman/_geomloss.py b/ot/bregman/_geomloss.py index 2ba2abad9..dd9aabd28 100644 --- a/ot/bregman/_geomloss.py +++ b/ot/bregman/_geomloss.py @@ -128,7 +128,7 @@ def empirical_sinkhorn2_geomloss(X_s, X_t, reg, a=None, b=None, metric='sqeuclid 
References ---------- - .. [60] Feydy, J., Roussillon, P., Trouvé, A., & Gori, P. (2019). [Fast + .. [60] Feydy, J., Roussillon, P., Trouvé, A., & Gori, P. (2019). [Fast and scalable optimal transport for brain tractograms. In Medical Image Computing and Computer Assisted Intervention–MICCAI 2019: 22nd International Conference, Shenzhen, China, October 13–17, 2019, From 1e9bd8241d407657535ec9aa2dbb9b9fdc531ed4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 21 Nov 2023 13:35:39 +0100 Subject: [PATCH 06/14] test for geomloss wrapper --- RELEASES.md | 1 + ot/bregman/_geomloss.py | 25 ++++++++++++------------- test/test_bregman.py | 27 +++++++++++++++++++++++++++ 3 files changed, 40 insertions(+), 13 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index 68ff9ea9b..65a2b93ce 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -17,6 +17,7 @@ + Add KL loss to all semi-relaxed (Fused) Gromov-Wasserstein solvers (PR #559) + Further upgraded unbalanced OT solvers for more flexibility and future use (PR #551) + New API function `ot.solve_sample` for solving OT problems from empirical samples (PR #563) ++ Wrapper for `geomloss`` solver on empirical samples (PR #571) #### Closed issues - Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504) diff --git a/ot/bregman/_geomloss.py b/ot/bregman/_geomloss.py index dd9aabd28..594d3b8e0 100644 --- a/ot/bregman/_geomloss.py +++ b/ot/bregman/_geomloss.py @@ -13,8 +13,8 @@ from geomloss import SamplesLoss import torch from torch.autograd import grad - from .utils import get_backend, LazyTensor, dist -except BaseException: + from ..utils import get_backend, LazyTensor, dist +except ImportError: geomloss = False @@ -148,28 +148,27 @@ def empirical_sinkhorn2_geomloss(X_s, X_t, reg, a=None, b=None, metric='sqeuclid if nx.__name__ not in ['torch', 'numpy']: raise ValueError('geomloss only support torch or numpy backend') + if a is None: + a = nx.ones(X_s.shape[0], type_as=X_s) / 
X_s.shape[0] + if b is None: + b = nx.ones(X_t.shape[0], type_as=X_t) / X_t.shape[0] + if nx.__name__ == 'numpy': X_s_torch = torch.tensor(X_s) X_t_torch = torch.tensor(X_t) - if a is not None: - a_torch = torch.tensor(a) - if b is not None: - b_torch = torch.tensor(b) + + a_torch = torch.tensor(a) + b_torch = torch.tensor(b) + else: X_s_torch = X_s X_t_torch = X_t a_torch = a b_torch = b - # after that we are all in torch # after that we are all in torch - if a_torch is None: - a_torch = torch.ones(X_s_torch.shape[0], dtype=X_s_torch.dtype, device=X_s_torch.device) / X_s_torch.shape[0] - if b_torch is None: - b_torch = torch.ones(X_t_torch.shape[0], dtype=X_t_torch.dtype, device=X_t_torch.device) / X_t_torch.shape[0] - # set blur value and p if metric == 'sqeuclidean': p = 2 @@ -206,7 +205,7 @@ def empirical_sinkhorn2_geomloss(X_s, X_t, reg, a=None, b=None, metric='sqeuclid log['g'] = g log['value'] = value - log['lazy_tensor'] = get_sinkhorn_geomloss_lazytensor(X_s, X_t, f, g, a, b, metric=metric, blur=blur, nx=nx) + log['lazy_plan'] = get_sinkhorn_geomloss_lazytensor(X_s, X_t, f, g, a, b, metric=metric, blur=blur, nx=nx) return value, log diff --git a/test/test_bregman.py b/test/test_bregman.py index 67257f899..8a85fac0d 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -1057,6 +1057,33 @@ def test_empirical_sinkhorn(nx): np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05) +@pytest.skip_backend('tf') +@pytest.skip_backend("cupy") +@pytest.skip_backend("jax") +@pytest.mark.parametrize("metric", ["sqeuclidean", "euclidean"]) +def test_geomloss_solver(nx, metric): + # test sinkhorn + n = 10 + a = ot.unif(n) + b = ot.unif(n) + + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + ab, bb, X_sb, X_tb = nx.from_numpy(a, b, X_s, X_t) + + G_sqe = nx.to_numpy(ot.bregman.empirical_sinkhorn(X_sb, X_tb, 1, metric=metric)) + + value, log = ot.bregman.empirical_sinkhorn2_geomloss(X_sb, X_tb, 1, 
metric=metric, log=True) + G_geomloss = nx.to_numpy(log['lazy_plan'][:]) + + # call with log = False + value = ot.bregman.empirical_sinkhorn2_geomloss(X_sb, X_tb, 1, metric=metric) + + # check equality of plans + np.testing.assert_allclose(G_sqe, G_geomloss, atol=1e-03) # metric sqeuclidian + + def test_lazy_empirical_sinkhorn(nx): # test sinkhorn n = 10 From bc68d27fe0e622c5771884c37c4a3cf217d6cf56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 21 Nov 2023 13:36:35 +0100 Subject: [PATCH 07/14] test for geomloss wrapper --- test/test_bregman.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/test_bregman.py b/test/test_bregman.py index 8a85fac0d..db21b5b4f 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -1083,6 +1083,10 @@ def test_geomloss_solver(nx, metric): # check equality of plans np.testing.assert_allclose(G_sqe, G_geomloss, atol=1e-03) # metric sqeuclidian + # check error on wrong metric + with pytest.raises(ValueError): + ot.bregman.empirical_sinkhorn2_geomloss(X_sb, X_tb, 1, metric='wrong_metric') + def test_lazy_empirical_sinkhorn(nx): # test sinkhorn From d734129689cbe42cfce95ec4e48292c2c8f51018 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 21 Nov 2023 13:41:30 +0100 Subject: [PATCH 08/14] ad geomloss to tests --- requirements.txt | 3 ++- test/test_bregman.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 6ac25eb3c..e80a57111 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,5 @@ jaxlib tensorflow pytest torch_geometric -cvxpy \ No newline at end of file +cvxpy +geomloss \ No newline at end of file diff --git a/test/test_bregman.py b/test/test_bregman.py index db21b5b4f..40f993e8e 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -15,6 +15,7 @@ import ot from ot.backend import tf, torch +from ot.bregman import geomloss @pytest.mark.parametrize("verbose, warn", product([True, False], 
[True, False])) @@ -1057,6 +1058,7 @@ def test_empirical_sinkhorn(nx): np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05) +@pytest.mark.skipif(not geomloss, reason="pytorch not installed") @pytest.skip_backend('tf') @pytest.skip_backend("cupy") @pytest.skip_backend("jax") From 55333a5a615ccd456d94972fc95321bca03f36ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 21 Nov 2023 13:45:24 +0100 Subject: [PATCH 09/14] pep8 test --- test/test_bregman.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/test_bregman.py b/test/test_bregman.py index 40f993e8e..1a92a1037 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -1079,8 +1079,10 @@ def test_geomloss_solver(nx, metric): value, log = ot.bregman.empirical_sinkhorn2_geomloss(X_sb, X_tb, 1, metric=metric, log=True) G_geomloss = nx.to_numpy(log['lazy_plan'][:]) + print(value) + # call with log = False - value = ot.bregman.empirical_sinkhorn2_geomloss(X_sb, X_tb, 1, metric=metric) + ot.bregman.empirical_sinkhorn2_geomloss(X_sb, X_tb, 1, metric=metric) # check equality of plans np.testing.assert_allclose(G_sqe, G_geomloss, atol=1e-03) # metric sqeuclidian From 48ff260b3f568862d1ec81c88fa5bb064b1cb0c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 21 Nov 2023 14:57:24 +0100 Subject: [PATCH 10/14] add option in solve_sample --- ot/solvers.py | 35 ++++++++++++++++++++++++++++++++--- test/test_solvers.py | 43 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 74 insertions(+), 4 deletions(-) diff --git a/ot/solvers.py b/ot/solvers.py index aed7e8ffe..b65f1149d 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -11,7 +11,7 @@ from .lp import emd2, wasserstein_1d from .backend import get_backend from .unbalanced import mm_unbalanced, sinkhorn_knopp_unbalanced, lbfgsb_unbalanced -from .bregman import sinkhorn_log, empirical_sinkhorn2 +from .bregman import sinkhorn_log, empirical_sinkhorn2, 
empirical_sinkhorn2_geomloss from .partial import partial_wasserstein_lagrange from .smooth import smooth_ot_dual from .gromov import (gromov_wasserstein2, fused_gromov_wasserstein2, @@ -23,6 +23,8 @@ from .gaussian import empirical_bures_wasserstein_distance from .factored import factored_optimal_transport +lst_method_lazy = ['1d', 'gaussian', 'lowrank', 'factored', 'geomloss', 'geomloss_auto', 'geomloss_tensorized', 'geomloss_online', 'geomloss_multiscale'] + def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, unbalanced_type='KL', method=None, n_threads=1, max_iter=None, plan_init=None, @@ -865,7 +867,7 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL", unbalanced=None, - unbalanced_type='KL', lazy=False, batch_size=None, method=None, n_threads=1, max_iter=None, plan_init=None, rank=100, + unbalanced_type='KL', lazy=False, batch_size=None, method=None, n_threads=1, max_iter=None, plan_init=None, rank=100, scaling=0.95, potentials_init=None, X_init=None, tol=None, verbose=False): r"""Solve the discrete optimal transport problem using the samples in the source and target domains. 
@@ -922,6 +924,10 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t Maximum number of iteration, by default None (default values in each solvers) plan_init : array_like, shape (dim_a, dim_b), optional Initialization of the OT plan for iterative methods, by default None + rank : int, optional + Rank of the OT matrix for lazy solvers (method='factored'), by default 100 + scaling : float, optional + Scaling factor for the epsilon scaling lazy solvers (method='geomloss'), by default 0.95 potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional Initialization of the OT dual potentials for iterative methods, by default None tol : _type_, optional @@ -939,6 +945,7 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t - res.potentials : OT dual potentials - res.value : Optimal value of the optimization problem - res.value_linear : Linear OT loss with the optimal OT plan + - res.lazy_plan : Lazy OT plan (when ``lazy=True`` or lazy method) See :any:`OTResult` for more information. 
@@ -1148,7 +1155,7 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t """ - if method is not None and method.lower() in ['1d', 'gaussian', 'lowrank', 'factored']: + if method is not None and method.lower() in lst_method_lazy: lazy0 = lazy lazy = True @@ -1221,6 +1228,28 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t if not lazy0: # store plan if not lazy plan = lazy_plan[:] + elif method.startswith('geomloss'): # Geomloss solver for entropic OT + + split_method = method.split('_') + if len(split_method) == 2: + backend = split_method[1] + else: + if lazy0 is None: + backend = 'auto' + elif lazy0: + backend = 'online' + else: + backend = 'tensorized' + + value, log = empirical_sinkhorn2_geomloss(X_a, X_b, reg=reg, a=a, b=b, metric=metric, log=True, verbose=verbose, scaling=scaling, backend=backend) + + lazy_plan = log['lazy_plan'] + if not lazy0: # store plan if not lazy + plan = lazy_plan[:] + + # return scaled potentials (to be consistent with other solvers) + potentials = (log['f'] / (lazy_plan.blur**2), log['g'] / (lazy_plan.blur**2)) + elif reg is None or reg == 0: # exact OT if unbalanced is None: # balanced EMD solver not available for lazy diff --git a/test/test_solvers.py b/test/test_solvers.py index c6e1a3770..7c94037f5 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -10,7 +10,7 @@ import pytest import ot - +from ot.bregman import geomloss lst_reg = [None, 1] lst_reg_type = ['KL', 'entropy', 'L2'] @@ -348,6 +348,47 @@ def test_solve_sample_lazy(nx): np.testing.assert_allclose(sol0.plan, sol.lazy_plan[:], rtol=1e-5, atol=1e-5) +@pytest.mark.skipif(not geomloss, reason="pytorch not installed") +@pytest.skip_backend('tf') +@pytest.skip_backend("cupy") +@pytest.skip_backend("jax") +@pytest.mark.parametrize("metric", ["sqeuclidean", "euclidean"]) +def test_solve_sample_geomloss(nx, metric): + # test solve_sample when is_Lazy = False + n_samples_s = 13 + n_samples_t = 7 + 
n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples_s, n_features) + y = rng.randn(n_samples_t, n_features) + a = ot.utils.unif(n_samples_s) + b = ot.utils.unif(n_samples_t) + + xb, yb, ab, bb = nx.from_numpy(x, y, a, b) + + sol0 = ot.solve_sample(xb, yb, ab, bb, reg=1) + + # solve signe weights + sol = ot.solve_sample(xb, yb, ab, bb, reg=1, method='geomloss') + assert_allclose_sol(sol0, sol) + + sol1 = ot.solve_sample(xb, yb, ab, bb, reg=1, lazy=False, method='geomloss') + assert_allclose_sol(sol0, sol) + + sol1 = ot.solve_sample(xb, yb, ab, bb, reg=1, lazy=True, method='geomloss_tensorized') + np.testing.assert_allclose(nx.to_numpy(sol1.lazy_plan[:]), nx.to_numpy(sol.lazy_plan[:]), rtol=1e-5, atol=1e-5) + + sol1 = ot.solve_sample(xb, yb, ab, bb, reg=1, lazy=True, method='geomloss_online') + np.testing.assert_allclose(nx.to_numpy(sol1.lazy_plan[:]), nx.to_numpy(sol.lazy_plan[:]), rtol=1e-5, atol=1e-5) + + sol1 = ot.solve_sample(xb, yb, ab, bb, reg=1, lazy=True, method='geomloss_multiscale') + np.testing.assert_allclose(nx.to_numpy(sol1.lazy_plan[:]), nx.to_numpy(sol.lazy_plan[:]), rtol=1e-5, atol=1e-5) + + sol1 = ot.solve_sample(xb, yb, ab, bb, reg=1, lazy=True, method='geomloss') + np.testing.assert_allclose(nx.to_numpy(sol1.lazy_plan[:]), nx.to_numpy(sol.lazy_plan[:]), rtol=1e-5, atol=1e-5) + + @pytest.mark.parametrize("method_params", lst_method_params_solve_sample) def test_solve_sample_methods(nx, method_params): From 54fb598048d4f2b34f3b5e11b99463e8c9451a07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 21 Nov 2023 15:16:01 +0100 Subject: [PATCH 11/14] limit to recent python for geomloss --- ot/utils.py | 2 +- test/test_solvers.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/ot/utils.py b/ot/utils.py index f64c2fea6..cb29b21c9 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -527,7 +527,7 @@ def reduce_lazytensor(a, func, axis=None, nx=None, batch_size=100): """ if nx is None: - nx = 
get_backend(a[0]) + nx = get_backend(a[0:1]) if axis is None: res = 0.0 diff --git a/test/test_solvers.py b/test/test_solvers.py index 7c94037f5..bf07b7af8 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -8,6 +8,7 @@ import itertools import numpy as np import pytest +import sys import ot from ot.bregman import geomloss @@ -348,6 +349,7 @@ def test_solve_sample_lazy(nx): np.testing.assert_allclose(sol0.plan, sol.lazy_plan[:], rtol=1e-5, atol=1e-5) +@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher") @pytest.mark.skipif(not geomloss, reason="pytorch not installed") @pytest.skip_backend('tf') @pytest.skip_backend("cupy") From 591dcb5039b08ce7b32a15e11ecf9a331c6ca253 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 21 Nov 2023 15:35:26 +0100 Subject: [PATCH 12/14] add keops as dependency --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e80a57111..6af50f127 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,5 @@ tensorflow pytest torch_geometric cvxpy -geomloss \ No newline at end of file +geomloss +pykeops \ No newline at end of file From c30a7d7d822cad71e41d01f4e72eed9d59582c70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 21 Nov 2023 21:30:07 +0100 Subject: [PATCH 13/14] update documentation geomloss --- ot/solvers.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/ot/solvers.py b/ot/solvers.py index b65f1149d..6d13e1986 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -1000,6 +1000,26 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t res = ot.solve_sample(xa, xb, a, b, reg=1.0, lazy=True, batch_size=100) # lazy OT plan lazy_plan = res.lazy_plan + + We also have a very efficient solver with compiled CPU/CUDA code using + geomloss/PyKeOps thta cen be used with the following code: + + .. 
code-block:: python + + # automatic solver + res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss') + + # force O(n) memory efficient solver + res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss_online') + + # force pre-computed cost matrix + res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss_tensorized') + + # use multiscale solver + res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss_multiscale') + + # One can play with speed (small scaling factor) and precision (scaling close to 1) + res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss', scaling=0.5) - **Quadratic regularized OT [17]** (when ``reg!=None`` and ``reg_type="L2"``): From a760e3db798d0b3754d19828e521fe10c9c82e06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 22 Nov 2023 08:53:35 +0100 Subject: [PATCH 14/14] pep8 --- README.md | 6 +++++- ot/solvers.py | 8 ++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index dc7c5dfaa..57b845edb 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,11 @@ POT provides the following generic OT solvers (links to examples): * [OT Network Simplex solver](https://pythonot.github.io/auto_examples/plot_OT_1D.html) for the linear program/ Earth Movers Distance [1] . * [Conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) [6] and [Generalized conditional gradient](https://pythonot.github.io/auto_examples/plot_optim_OTreg.html) for regularized OT [7]. -* Entropic regularization OT solver with [Sinkhorn Knopp Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) [2] , stabilized version [9] [10] [34], greedy Sinkhorn [22] and [Screening Sinkhorn [26] ](https://pythonot.github.io/auto_examples/plot_screenkhorn_1D.html). 
+* Entropic regularization OT solver with [Sinkhorn Knopp + Algorithm](https://pythonot.github.io/auto_examples/plot_OT_1D.html) [2] , + stabilized version [9] [10] [34], lazy CPU/GPU solver from geomloss [60] [61], greedy Sinkhorn [22] and [Screening + Sinkhorn [26] + ](https://pythonot.github.io/auto_examples/plot_screenkhorn_1D.html). * Bregman projections for [Wasserstein barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html) [3], [convolutional barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_convolutional_barycenter.html) [21] and unmixing [4]. * Sinkhorn divergence [23] and entropic regularization OT from empirical data. * Debiased Sinkhorn barycenters [Sinkhorn divergence barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_debiased_barycenter.html) [37] diff --git a/ot/solvers.py b/ot/solvers.py index 6d13e1986..a41762a5c 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -1000,9 +1000,9 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t res = ot.solve_sample(xa, xb, a, b, reg=1.0, lazy=True, batch_size=100) # lazy OT plan lazy_plan = res.lazy_plan - + We also have a very efficient solver with compiled CPU/CUDA code using - geomloss/PyKeOps thta cen be used with the following code: + geomloss/PyKeOps that can be used with the following code: .. code-block:: python @@ -1014,10 +1014,10 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t # force pre-computed cost matrix res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss_tensorized') - + # use multiscale solver res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss_multiscale') - + # One can play with speed (small scaling factor) and precision (scaling close to 1) res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss', scaling=0.5)