From 223396e0168b3e36e358746a7529bfc1646871a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 18 Oct 2023 15:16:00 +0200 Subject: [PATCH 01/13] add new API for gromov --- ot/__init__.py | 2 +- ot/da.py | 1 + ot/gromov/_bregman.py | 10 ++ ot/gromov/_gw.py | 10 +- ot/solvers.py | 264 ++++++++++++++++++++++++++++++++++++++++++ ot/utils.py | 14 ++- test/test_solvers.py | 80 +++++++++++++ 7 files changed, 376 insertions(+), 5 deletions(-) diff --git a/ot/__init__.py b/ot/__init__.py index 44e87eabe..95eccf56a 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -50,7 +50,7 @@ gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) from .weak import weak_optimal_transport from .factored import factored_optimal_transport -from .solvers import solve +from .solvers import solve, solve_gromov # utils functions from .utils import dist, unif, tic, toc, toq diff --git a/ot/da.py b/ot/da.py index 8764268f0..3628db51e 100644 --- a/ot/da.py +++ b/ot/da.py @@ -2274,6 +2274,7 @@ class NearestBrenierPotential(BaseTransport): ot.mapping.nearest_brenier_potential_fit : Fitting the SSNB on source and target data ot.mapping.nearest_brenier_potential_predict_bounds : Predicting SSNB images on new source data """ + def __init__(self, strongly_convex_constant=0.6, gradient_lipschitz_constant=1.4, log=False, its=100, seed=None): self.strongly_convex_constant = strongly_convex_constant self.gradient_lipschitz_constant = gradient_lipschitz_constant diff --git a/ot/gromov/_bregman.py b/ot/gromov/_bregman.py index 6dc705949..654daf232 100644 --- a/ot/gromov/_bregman.py +++ b/ot/gromov/_bregman.py @@ -330,6 +330,7 @@ def entropic_gromov_wasserstein2( learning for graph matching and node embedding. In International Conference on Machine Learning (ICML), 2019. 
""" + T, logv = entropic_gromov_wasserstein( C1, C2, p, q, loss_fun, epsilon, symmetric, G0, max_iter, tol, solver, warmstart, verbose, log=True, **kwargs) @@ -815,12 +816,21 @@ def entropic_fused_gromov_wasserstein2( (ICML). 2019. """ + + nx = get_backend(M, C1, C2) + T, logv = entropic_fused_gromov_wasserstein( M, C1, C2, p, q, loss_fun, epsilon, symmetric, alpha, G0, max_iter, tol, solver, warmstart, verbose, log=True, **kwargs) logv['T'] = T + lin_term = nx.sum(T * M) + gw_term = (logv['gw_dist'] - (1 - alpha) * lin_term) / alpha + + log_fgw['quad_loss'] = gw_term * alpha + log_fgw['lin_loss'] = lin_term * (1 - alpha) + if log: return logv['fgw_dist'], logv else: diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py index 20373f33b..c1d7ace75 100644 --- a/ot/gromov/_gw.py +++ b/ot/gromov/_gw.py @@ -582,6 +582,13 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss', fgw_dist = log_fgw['fgw_dist'] log_fgw['T'] = T + # compute separate terms for gradients and log + lin_term = nx.sum(T * M) + gw_term = (fgw_dist - (1 - alpha) * lin_term) / alpha + + log_fgw['quad_loss'] = gw_term * alpha + log_fgw['lin_loss'] = lin_term * (1 - alpha) + if loss_fun == 'square_loss': gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T)) gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T)) @@ -591,8 +598,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss', log_fgw['v'] - nx.mean(log_fgw['v']), alpha * gC1, alpha * gC2, (1 - alpha) * T)) else: - lin_term = nx.sum(T * M) - gw_term = (fgw_dist - (1 - alpha) * lin_term) / alpha + fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M, alpha), (log_fgw['u'] - nx.mean(log_fgw['u']), log_fgw['v'] - nx.mean(log_fgw['v']), diff --git a/ot/solvers.py b/ot/solvers.py index bba2734e5..2c9ed6992 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -14,6 +14,9 @@ from .bregman import sinkhorn_log from .partial import partial_wasserstein_lagrange from .smooth import 
smooth_ot_dual +from .gromov import gromov_wasserstein2, fused_gromov_wasserstein2, entropic_gromov_wasserstein2, entropic_fused_gromov_wasserstein2 + +#, entropic_gromov_wasserstein2, entropic_fused_gromov_wasserstein2 def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, @@ -345,3 +348,264 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, value_linear=value_linear, plan=plan, status=status, backend=nx) return res + + +def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, alpha=0.5, reg=None, + reg_type="entropy", unbalanced=None, unbalanced_type='KL', + n_threads=1, method=None, max_iter=None, plan_init=None, tol=None, + verbose=False): + r""" Solve the discrete (Fused) Gromov-Wasserstein and return :any:`OTResult` object + + The function solves the following optimization problem: + + .. math:: + \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} + \lambda_r R(\mathbf{T}) + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + + The regularization is selected with `reg` (:math:`\lambda_r`) and + `reg_type`. By default ``reg=None`` and there is no regularization. The + unbalanced marginal penalization can be selected with `unbalanced` + (:math:`\lambda_u`) and `unbalanced_type`. By default ``unbalanced=None`` + and the function solves the exact optimal transport problem (respecting the + marginals). + + Parameters + ---------- + Ca : array_like, shape (dim_a, dim_a) + Cost matrix in the source domain + Cb : array_like, shape (dim_b, dim_b) + Cost matrix in the target domain + M : array_like, shape (dim_a, dim_b), optional + Linear cost matrix for Fused Gromov-Wasserstein (default is None). 
+ a : array-like, shape (dim_a,), optional + Samples weights in the source domain (default is uniform) + b : array-like, shape (dim_b,), optional + Samples weights in the source domain (default is uniform) + loss : str, optional + Type of loss function, either ``"L2"`` or ``"KL"``, by default ``"L2"`` + symmetric : bool, optional + Use symmetric version of the Gromov-Wasserstein problem, by default None + tests wether the matrices are symmetric or True/False to avoid the test. + reg : float, optional + Regularization weight :math:`\lambda_r`, by default None (no reg., exact + OT) + reg_type : str, optional + Type of regularization :math:`R`, by default "entropic" (only used when + ``reg!=None``) + alpha : float, optional + Weight the quadratic term (alpha*Gromov) and the linear term + ((1-alpha)*Wass) in the Fused Gromov-Wasserstein problem. Not used for + Gromov problem (when M is not provided). By default ``alpha=None` corresponds to to + ``alpha=1`` for Gromov problem (``M==None``) and ``alpha=0.5`` for Fused + Gromov-Wasserstein problem (``M!=None``) + unbalanced : float, optional + Unbalanced penalization weight :math:`\lambda_u`, by default None + (balanced OT), Not implemented yet + unbalanced_type : str, optional + Type of unbalanced penalization unction :math:`U` either "KL", "L2", + "TV", by default "KL" , Not implemented yet + n_threads : int, optional + Number of OMP threads for exact OT solver, by default 1 + max_iter : int, optional + Maximum number of iteration, by default None (default values in each + solvers) + plan_init : array_like, shape (dim_a, dim_b), optional + Initialization of the OT plan for iterative methods, by default None + tol : float, optional + Tolerance for solution precision, by default None (default values in + each solvers) + verbose : bool, optional + Print information in the solver, by default False + + Returns + ------- + res : OTResult() + Result of the optimization problem. 
The information can be obtained as follows: + + - res.plan : OT plan :math:`\mathbf{T}` + - res.potentials : OT dual potentials + - res.value : Optimal value of the optimization problem + - res.value_linear : Linear OT loss with the optimal OT plan + - res.value_quad : Quadratic (GW) part of the OT loss with the optimal OT plan + + See :any:`OTResult` for more information. + + Notes + ----- + The following methods are available for solving the Gromov-Wasserstein + problem: + + """ + + # detect backend + nx = get_backend(Ca, Cb, M, a, b) + + # create uniform weights if not given + if a is None: + a = nx.ones(Ca.shape[0], type_as=Ca) / Ca.shape[0] + if b is None: + b = nx.ones(Cb.shape[1], type_as=Cb) / Cb.shape[1] + + # default values for solutions + potentials = None + value = None + value_linear = None + value_quad = None + plan = None + status = None + + loss_dict = {'l2': 'square_loss', 'kl': 'kl_loss'} + + if reg is None or reg == 0: # exact OT + + if unbalanced is None: # Exact balanced OT + + if M is None or alpha == 1: # Gromov-Wasserstein problem + + # default values for solver + if max_iter is None: + max_iter = 10000 + if tol is None: + tol = 1e-9 + + value, log = gromov_wasserstein2(Ca, Cb, a, b, loss_fun=loss_dict[loss.lower()], log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + + value_quad = value + if alpha == 1: # set to 0 for FGW with alpha=1 + value_linear = 0 + plan = log['T'] + potentials = (log['u'], log['v']) + + elif alpha == 0: # Wasserstein problem + + # default values for EMD solver + if max_iter is None: + max_iter = 1000000 + + value_linear, log = emd2(a, b, M, numItermax=max_iter, log=True, return_matrix=True, numThreads=n_threads) + + value = value_linear + potentials = (log['u'], log['v']) + plan = log['G'] + status = log["warning"] if log["warning"] is not None else 'Converged' + value_quad = 0 + + else: # Fused Gromov-Wasserstein problem + + # default values for solver + if 
max_iter is None: + max_iter = 10000 + if tol is None: + tol = 1e-9 + + value, log = fused_gromov_wasserstein2(M, Ca, Cb, a, b, loss_fun=loss_dict[loss.lower()], alpha=alpha, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + + value_linear = log['lin_loss'] + value_quad = log['quad_loss'] + plan = log['T'] + potentials = (log['u'], log['v']) + + elif unbalanced_type.lower() in ['kl', 'l2']: # unbalanced exact OT + + raise (NotImplementedError('Unbalanced_type="{}"'.format(unbalanced_type))) + + else: + raise (NotImplementedError('Unknown unbalanced_type="{}"'.format(unbalanced_type))) + + else: # regularized OT + + if unbalanced is None: # Balanced regularized OT + + if reg_type.lower() in ['entropy'] and (M is None or alpha == 1): # Entropic Gromov-Wasserstein problem + + # default values for solver + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + if method is None: + method = 'PGD' + + value_quad, log = entropic_gromov_wasserstein2(Ca, Cb, a, b, epsilon=reg, loss_fun=loss_dict[loss.lower()], log=True, symmetric=symmetric, solver=method, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + + plan = log['T'] + value_linear = 0 + value = value_quad + reg * nx.sum(plan * nx.log(plan + 1e-16)) + # potentials = (log['log_u'], log['log_v']) #TODO + + elif reg_type.lower() in ['entropy'] and M is not None and alpha == 0: # Entropic Wasserstein problem + + # default values for solver + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + + plan, log = sinkhorn_log(a, b, M, reg=reg, numItermax=max_iter, + stopThr=tol, log=True, + verbose=verbose) + + value_linear = nx.sum(M * plan) + value = value_linear + reg * nx.sum(plan * nx.log(plan + 1e-16)) + potentials = (log['log_u'], log['log_v']) + + elif reg_type.lower() in ['entropy'] and M is not None: # Entropic Fused Gromov-Wasserstein problem + + # default values for solver + if max_iter is None: 
+ max_iter = 1000 + if tol is None: + tol = 1e-9 + if method is None: + method = 'PGD' + + value_noreg, log = entropic_fused_gromov_wasserstein2(M, Ca, Cb, a, b, loss_fun=loss_dict[loss.lower()], alpha=alpha, log=True, symmetric=symmetric, solver=method, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + + value_linear = log['lin_loss'] + value_quad = log['quad_loss'] + plan = log['T'] + # potentials = (log['u'], log['v']) + value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16)) + + else: + raise (NotImplementedError('Not implemented reg_type="{}"'.format(reg_type))) + + else: # unbalanced AND regularized OT + + raise (NotImplementedError('Not implemented reg_type="{}" and unbalanced_type="{}"'.format(reg_type, unbalanced_type))) + + # if reg_type.lower() in ['kl'] and unbalanced_type.lower() == 'kl': + + # if max_iter is None: + # max_iter = 1000 + # if tol is None: + # tol = 1e-9 + + # plan, log = sinkhorn_knopp_unbalanced(a, b, M, reg=reg, reg_m=unbalanced, numItermax=max_iter, stopThr=tol, verbose=verbose, log=True) + + # value_linear = nx.sum(M * plan) + + # value = value_linear + reg * nx.kl_div(plan, a[:, None] * b[None, :]) + unbalanced * (nx.kl_div(nx.sum(plan, 1), a) + nx.kl_div(nx.sum(plan, 0), b)) + + # potentials = (log['logu'], log['logv']) + + # elif reg_type.lower() in ['kl', 'l2', 'entropy'] and unbalanced_type.lower() in ['kl', 'l2']: + + # if max_iter is None: + # max_iter = 1000 + # if tol is None: + # tol = 1e-12 + + # plan, log = lbfgsb_unbalanced(a, b, M, reg=reg, reg_m=unbalanced, reg_div=reg_type.lower(), regm_div=unbalanced_type.lower(), numItermax=max_iter, stopThr=tol, verbose=verbose, log=True) + + # value_linear = nx.sum(M * plan) + + # value = log['loss'] + + # else: + # raise (NotImplementedError('Not implemented reg_type="{}" and unbalanced_type="{}"'.format(reg_type, unbalanced_type))) + + res = OTResult(potentials=potentials, value=value, + value_linear=value_linear, value_quad=value_quad, 
plan=plan, status=status, backend=nx) + + return res diff --git a/ot/utils.py b/ot/utils.py index 72df4294f..8cbb0db25 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -731,11 +731,12 @@ class UndefinedParameter(Exception): class OTResult: - def __init__(self, potentials=None, value=None, value_linear=None, plan=None, log=None, backend=None, sparse_plan=None, lazy_plan=None, status=None): + def __init__(self, potentials=None, value=None, value_linear=None, value_quad=None, plan=None, log=None, backend=None, sparse_plan=None, lazy_plan=None, status=None): self._potentials = potentials self._value = value self._value_linear = value_linear + self._value_quad = value_quad self._plan = plan self._log = log self._sparse_plan = sparse_plan @@ -828,7 +829,8 @@ def lazy_plan(self): @property def value(self): - """Full transport cost, including possible regularization terms.""" + """Full transport cost, including possible regularization terms and + quadratic term for Gromov Wasserstein solutions.""" if self._value is not None: return self._value else: @@ -842,6 +844,14 @@ def value_linear(self): else: raise NotImplementedError() + @property + def value_quad(self): + """The quadratic part of the transport cost for Gromov-Wasserstein solutions.""" + if self._value_quad is not None: + return self._value_quad + else: + raise NotImplementedError() + # Marginal constraints ------------------------- @property def marginals(self): diff --git a/test/test_solvers.py b/test/test_solvers.py index b792aca94..0624a426e 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -17,6 +17,11 @@ lst_unbalanced = [None, 0.9] lst_unbalanced_type = ['KL', 'L2', 'TV'] +lst_reg_type_gromov = ['entropy'] +lst_unbalanced_type_gromov = ['KL'] +lst_unbalanced_gromov = [None, 0.9] +lst_alpha = [0, 0.4, 0.9, 1] + def assert_allclose_sol(sol1, sol2): @@ -131,3 +136,78 @@ def test_solve_not_implemented(nx): # pairs of incompatible divergences with pytest.raises(NotImplementedError): ot.solve(M, 
reg=1.0, reg_type='kl', unbalanced=1.0, unbalanced_type='tv') + + +def test_solve_gromov(nx): + + np.random.seed(0) + + n_samples_s = 5 + n_samples_t = 10 + + Ca = np.random.rand(n_samples_s, n_samples_s) + Ca = (Ca + Ca.T) / 2 + + Cb = np.random.rand(n_samples_t, n_samples_t) + Cb = (Cb + Cb.T) / 2 + + a = ot.utils.unif(n_samples_s) + b = ot.utils.unif(n_samples_t) + + M = np.random.rand(n_samples_s, n_samples_t) + + sol0 = ot.solve_gromov(Ca, Cb) # GW + sol = ot.solve_gromov(Ca, Cb, a=a, b=b) # GW + sol0_fgw = ot.solve_gromov(Ca, Cb, M) # FGW + + # check some attributes + sol.potentials + sol.marginals + + assert_allclose_sol(sol0, sol) + + # solve in backend + ax, bx, Mx, Cax, Cbx = nx.from_numpy(a, b, M, Ca, Cb) + + solx = ot.solve_gromov(Cax, Cbx, a=ax, b=bx) # GW + solx_fgw = ot.solve_gromov(Cax, Cbx, Mx) # FGW + + assert_allclose_sol(sol, solx) + assert_allclose_sol(sol0_fgw, solx_fgw) + + +@pytest.mark.parametrize("reg,reg_type,unbalanced,unbalanced_type,alpha", itertools.product(lst_reg, lst_reg_type_gromov, lst_unbalanced_gromov, lst_unbalanced_type_gromov, lst_alpha)) +def test_solve_gromov_grid(nx, reg, reg_type, unbalanced, unbalanced_type, alpha): + + np.random.seed(0) + + n_samples_s = 5 + n_samples_t = 10 + + Ca = np.random.rand(n_samples_s, n_samples_s) + Ca = (Ca + Ca.T) / 2 + + Cb = np.random.rand(n_samples_t, n_samples_t) + Cb = (Cb + Cb.T) / 2 + + a = ot.utils.unif(n_samples_s) + b = ot.utils.unif(n_samples_t) + + M = np.random.rand(n_samples_s, n_samples_t) + + try: + + sol0 = ot.solve_gromov(Ca, Cb, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type) # GW + sol0_fgw = ot.solve_gromov(Ca, Cb, M, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type, alpha=alpha) # FGW + + # solve in backend + ax, bx, Mx, Cax, Cbx = nx.from_numpy(a, b, M, Ca, Cb) + + solx = ot.solve_gromov(Cax, Cbx, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type) # GW + solx_fgw = 
ot.solve_gromov(Cax, Cbx, Mx, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type, alpha=alpha) # FGW + + assert_allclose_sol(sol0, solx) + assert_allclose_sol(sol0_fgw, solx_fgw) + + except NotImplementedError: + pass From 242c10d93e17eed006307d8f1fa0af8733cf3307 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 18 Oct 2023 15:41:32 +0200 Subject: [PATCH 02/13] small bug entropic fgw --- ot/gromov/_bregman.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ot/gromov/_bregman.py b/ot/gromov/_bregman.py index 654daf232..ddf34f793 100644 --- a/ot/gromov/_bregman.py +++ b/ot/gromov/_bregman.py @@ -826,10 +826,10 @@ def entropic_fused_gromov_wasserstein2( logv['T'] = T lin_term = nx.sum(T * M) - gw_term = (logv['gw_dist'] - (1 - alpha) * lin_term) / alpha + gw_term = (logv['fgw_dist'] - (1 - alpha) * lin_term) / alpha - log_fgw['quad_loss'] = gw_term * alpha - log_fgw['lin_loss'] = lin_term * (1 - alpha) + logv['quad_loss'] = gw_term * alpha + logv['lin_loss'] = lin_term * (1 - alpha) if log: return logv['fgw_dist'], logv From c60b412d8410c7f416ae377a9461297ea994094c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 18 Oct 2023 15:52:47 +0200 Subject: [PATCH 03/13] f* pep8 --- ot/solvers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ot/solvers.py b/ot/solvers.py index 2c9ed6992..483e8e735 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -542,8 +542,8 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, alph tol = 1e-9 plan, log = sinkhorn_log(a, b, M, reg=reg, numItermax=max_iter, - stopThr=tol, log=True, - verbose=verbose) + stopThr=tol, log=True, + verbose=verbose) value_linear = nx.sum(M * plan) value = value_linear + reg * nx.sum(plan * nx.log(plan + 1e-16)) From 1bb8fbab9f1dc6dbd498063fcc7e8a4425d69153 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Thu, 19 Oct 2023 16:18:08 +0200 Subject: 
[PATCH 04/13] add semirelaxed --- ot/gromov/_semirelaxed.py | 6 ++- ot/solvers.py | 77 +++++++++++++++++++++++++++++++++++++-- test/test_solvers.py | 14 +++---- 3 files changed, 85 insertions(+), 12 deletions(-) diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py index b36a81c75..0b905c1fa 100644 --- a/ot/gromov/_semirelaxed.py +++ b/ot/gromov/_semirelaxed.py @@ -488,6 +488,8 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo q = nx.sum(T, 0) srfgw_dist = log_fgw['srfgw_dist'] log_fgw['T'] = T + log_fgw['lin_loss'] = nx.sum(M * T) * (1 - alpha) + log_fgw['quad_loss'] = srfgw_dist - log_fgw['lin_loss'] if loss_fun == 'square_loss': gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T)) @@ -979,7 +981,9 @@ def df(G): if log: qG = nx.sum(G, 0) marginal_product = nx.outer(ones_p, nx.dot(qG, fC2t)) - log['srfgw_dist'] = alpha * gwloss(constC + marginal_product, hC1, hC2, G, nx) + (1 - alpha) * nx.sum(M * G) + log['lin_loss'] = nx.sum(M * G) * (1 - alpha) + log['quad_loss'] = alpha * gwloss(constC + marginal_product, hC1, hC2, G, nx) + log['srfgw_dist'] = log['lin_loss'] + log['quad_loss'] return G, log else: return G diff --git a/ot/solvers.py b/ot/solvers.py index 483e8e735..8a867170f 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -14,7 +14,11 @@ from .bregman import sinkhorn_log from .partial import partial_wasserstein_lagrange from .smooth import smooth_ot_dual -from .gromov import gromov_wasserstein2, fused_gromov_wasserstein2, entropic_gromov_wasserstein2, entropic_fused_gromov_wasserstein2 +from .gromov import (gromov_wasserstein2, fused_gromov_wasserstein2, + entropic_gromov_wasserstein2, entropic_fused_gromov_wasserstein2, + semirelaxed_gromov_wasserstein2, semirelaxed_fused_gromov_wasserstein2, + entropic_semirelaxed_fused_gromov_wasserstein2, + entropic_semirelaxed_gromov_wasserstein2) #, entropic_gromov_wasserstein2, entropic_fused_gromov_wasserstein2 @@ -350,7 +354,8 @@ def solve(M, a=None, b=None, 
reg=None, reg_type="KL", unbalanced=None, return res -def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, alpha=0.5, reg=None, +def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, + alpha=0.5, reg=None, reg_type="entropy", unbalanced=None, unbalanced_type='KL', n_threads=1, method=None, max_iter=None, plan_init=None, tol=None, verbose=False): @@ -458,7 +463,7 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, alph if reg is None or reg == 0: # exact OT - if unbalanced is None: # Exact balanced OT + if unbalanced is None and unbalanced_type.lower() not in ['semirelaxed']: # Exact balanced OT if M is None or alpha == 1: # Gromov-Wasserstein problem @@ -505,6 +510,39 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, alph plan = log['T'] potentials = (log['u'], log['v']) + elif unbalanced_type.lower() in ['semirelaxed']: # Semi-relaxed OT + + if M is None or alpha == 1: # Semi relaxed Gromov-Wasserstein problem + + # default values for solver + if max_iter is None: + max_iter = 10000 + if tol is None: + tol = 1e-9 + + value, log = semirelaxed_gromov_wasserstein2(Ca, Cb, a, loss_fun=loss_dict[loss.lower()], log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + + value_quad = value + if alpha == 1: # set to 0 for FGW with alpha=1 + value_linear = 0 + plan = log['T'] + # potentials = (log['u'], log['v']) TODO + + else: # Semi relaxed Fused Gromov-Wasserstein problem + + # default values for solver + if max_iter is None: + max_iter = 10000 + if tol is None: + tol = 1e-9 + + value, log = semirelaxed_fused_gromov_wasserstein2(M, Ca, Cb, a, loss_fun=loss_dict[loss.lower()], alpha=alpha, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + + value_linear = log['lin_loss'] + value_quad = log['quad_loss'] + plan = log['T'] + # potentials = (log['u'], log['v']) TODO + 
elif unbalanced_type.lower() in ['kl', 'l2']: # unbalanced exact OT raise (NotImplementedError('Unbalanced_type="{}"'.format(unbalanced_type))) @@ -514,7 +552,7 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, alph else: # regularized OT - if unbalanced is None: # Balanced regularized OT + if unbalanced is None and unbalanced_type.lower() not in ['semirelaxed']: # Balanced regularized OT if reg_type.lower() in ['entropy'] and (M is None or alpha == 1): # Entropic Gromov-Wasserstein problem @@ -570,6 +608,37 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, alph else: raise (NotImplementedError('Not implemented reg_type="{}"'.format(reg_type))) + elif unbalanced_type.lower() in ['semirelaxed']: # Semi-relaxed OT + + if reg_type.lower() in ['entropy'] and (M is None or alpha == 1): # Entropic Semi-relaxed Gromov-Wasserstein problem + + # default values for solver + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + + value_quad, log = entropic_semirelaxed_gromov_wasserstein2(Ca, Cb, a, epsilon=reg, loss_fun=loss_dict[loss.lower()], log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + + plan = log['T'] + value_linear = 0 + value = value_quad + reg * nx.sum(plan * nx.log(plan + 1e-16)) + + else: # Entropic Semi-relaxed FGW problem + + # default values for solver + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + + value_noreg, log = entropic_semirelaxed_fused_gromov_wasserstein2(M, Ca, Cb, a, loss_fun=loss_dict[loss.lower()], alpha=alpha, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + + value_linear = log['lin_loss'] + value_quad = log['quad_loss'] + plan = log['T'] + value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16)) + else: # unbalanced AND regularized OT raise (NotImplementedError('Not implemented reg_type="{}" and 
unbalanced_type="{}"'.format(reg_type, unbalanced_type))) diff --git a/test/test_solvers.py b/test/test_solvers.py index 0624a426e..220b6e554 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -18,7 +18,7 @@ lst_unbalanced_type = ['KL', 'L2', 'TV'] lst_reg_type_gromov = ['entropy'] -lst_unbalanced_type_gromov = ['KL'] +lst_unbalanced_type_gromov = ['semirelaxed'] lst_unbalanced_gromov = [None, 0.9] lst_alpha = [0, 0.4, 0.9, 1] @@ -112,7 +112,7 @@ def test_solve_grid(nx, reg, reg_type, unbalanced, unbalanced_type): assert_allclose_sol(sol, solb) except NotImplementedError: - pass + pytest.skip("Not implemented") def test_solve_not_implemented(nx): @@ -142,8 +142,8 @@ def test_solve_gromov(nx): np.random.seed(0) - n_samples_s = 5 - n_samples_t = 10 + n_samples_s = 3 + n_samples_t = 5 Ca = np.random.rand(n_samples_s, n_samples_s) Ca = (Ca + Ca.T) / 2 @@ -181,8 +181,8 @@ def test_solve_gromov_grid(nx, reg, reg_type, unbalanced, unbalanced_type, alpha np.random.seed(0) - n_samples_s = 5 - n_samples_t = 10 + n_samples_s = 3 + n_samples_t = 5 Ca = np.random.rand(n_samples_s, n_samples_s) Ca = (Ca + Ca.T) / 2 @@ -210,4 +210,4 @@ def test_solve_gromov_grid(nx, reg, reg_type, unbalanced, unbalanced_type, alpha assert_allclose_sol(sol0_fgw, solx_fgw) except NotImplementedError: - pass + pytest.skip("Not implemented") From c38047fb1c15dcd48cc860a685f68713932ebabe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Thu, 19 Oct 2023 16:44:33 +0200 Subject: [PATCH 05/13] al is working and is tested --- ot/__init__.py | 2 +- ot/solvers.py | 51 +++++++++++--------------------------------- test/test_solvers.py | 15 +++++++------ 3 files changed, 21 insertions(+), 47 deletions(-) diff --git a/ot/__init__.py b/ot/__init__.py index 95eccf56a..f16b6fcfc 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -65,7 +65,7 @@ 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'sliced_wasserstein_sphere', 'gromov_wasserstein', 'gromov_wasserstein2', 
'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2', 'max_sliced_wasserstein_distance', 'weak_optimal_transport', - 'factored_optimal_transport', 'solve', + 'factored_optimal_transport', 'solve', 'solve_gromov', 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers', 'binary_search_circle', 'wasserstein_circle', 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif'] diff --git a/ot/solvers.py b/ot/solvers.py index 8a867170f..d49f33c96 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -461,6 +461,10 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, loss_dict = {'l2': 'square_loss', 'kl': 'kl_loss'} + if loss.lower() not in loss_dict.keys(): + raise (ValueError('Unknown GW loss="{}"'.format(loss))) + loss_fun = loss_dict[loss.lower()] + if reg is None or reg == 0: # exact OT if unbalanced is None and unbalanced_type.lower() not in ['semirelaxed']: # Exact balanced OT @@ -473,7 +477,7 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, if tol is None: tol = 1e-9 - value, log = gromov_wasserstein2(Ca, Cb, a, b, loss_fun=loss_dict[loss.lower()], log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + value, log = gromov_wasserstein2(Ca, Cb, a, b, loss_fun=loss_fun, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) value_quad = value if alpha == 1: # set to 0 for FGW with alpha=1 @@ -503,7 +507,7 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, if tol is None: tol = 1e-9 - value, log = fused_gromov_wasserstein2(M, Ca, Cb, a, b, loss_fun=loss_dict[loss.lower()], alpha=alpha, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + value, log = fused_gromov_wasserstein2(M, Ca, Cb, a, b, loss_fun=loss_fun, alpha=alpha, log=True, symmetric=symmetric, max_iter=max_iter, 
G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) value_linear = log['lin_loss'] value_quad = log['quad_loss'] @@ -520,7 +524,7 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, if tol is None: tol = 1e-9 - value, log = semirelaxed_gromov_wasserstein2(Ca, Cb, a, loss_fun=loss_dict[loss.lower()], log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + value, log = semirelaxed_gromov_wasserstein2(Ca, Cb, a, loss_fun=loss_fun, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) value_quad = value if alpha == 1: # set to 0 for FGW with alpha=1 @@ -536,7 +540,7 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, if tol is None: tol = 1e-9 - value, log = semirelaxed_fused_gromov_wasserstein2(M, Ca, Cb, a, loss_fun=loss_dict[loss.lower()], alpha=alpha, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + value, log = semirelaxed_fused_gromov_wasserstein2(M, Ca, Cb, a, loss_fun=loss_fun, alpha=alpha, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) value_linear = log['lin_loss'] value_quad = log['quad_loss'] @@ -564,7 +568,7 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, if method is None: method = 'PGD' - value_quad, log = entropic_gromov_wasserstein2(Ca, Cb, a, b, epsilon=reg, loss_fun=loss_dict[loss.lower()], log=True, symmetric=symmetric, solver=method, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + value_quad, log = entropic_gromov_wasserstein2(Ca, Cb, a, b, epsilon=reg, loss_fun=loss_fun, log=True, symmetric=symmetric, solver=method, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) plan = log['T'] value_linear = 0 @@ -597,7 +601,7 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', 
symmetric=None, if method is None: method = 'PGD' - value_noreg, log = entropic_fused_gromov_wasserstein2(M, Ca, Cb, a, b, loss_fun=loss_dict[loss.lower()], alpha=alpha, log=True, symmetric=symmetric, solver=method, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + value_noreg, log = entropic_fused_gromov_wasserstein2(M, Ca, Cb, a, b, loss_fun=loss_fun, alpha=alpha, log=True, symmetric=symmetric, solver=method, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) value_linear = log['lin_loss'] value_quad = log['quad_loss'] @@ -618,7 +622,7 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, if tol is None: tol = 1e-9 - value_quad, log = entropic_semirelaxed_gromov_wasserstein2(Ca, Cb, a, epsilon=reg, loss_fun=loss_dict[loss.lower()], log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + value_quad, log = entropic_semirelaxed_gromov_wasserstein2(Ca, Cb, a, epsilon=reg, loss_fun=loss_fun, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) plan = log['T'] value_linear = 0 @@ -632,7 +636,7 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, if tol is None: tol = 1e-9 - value_noreg, log = entropic_semirelaxed_fused_gromov_wasserstein2(M, Ca, Cb, a, loss_fun=loss_dict[loss.lower()], alpha=alpha, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + value_noreg, log = entropic_semirelaxed_fused_gromov_wasserstein2(M, Ca, Cb, a, loss_fun=loss_fun, alpha=alpha, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) value_linear = log['lin_loss'] value_quad = log['quad_loss'] @@ -643,37 +647,6 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, raise (NotImplementedError('Not implemented reg_type="{}" and 
unbalanced_type="{}"'.format(reg_type, unbalanced_type))) - # if reg_type.lower() in ['kl'] and unbalanced_type.lower() == 'kl': - - # if max_iter is None: - # max_iter = 1000 - # if tol is None: - # tol = 1e-9 - - # plan, log = sinkhorn_knopp_unbalanced(a, b, M, reg=reg, reg_m=unbalanced, numItermax=max_iter, stopThr=tol, verbose=verbose, log=True) - - # value_linear = nx.sum(M * plan) - - # value = value_linear + reg * nx.kl_div(plan, a[:, None] * b[None, :]) + unbalanced * (nx.kl_div(nx.sum(plan, 1), a) + nx.kl_div(nx.sum(plan, 0), b)) - - # potentials = (log['logu'], log['logv']) - - # elif reg_type.lower() in ['kl', 'l2', 'entropy'] and unbalanced_type.lower() in ['kl', 'l2']: - - # if max_iter is None: - # max_iter = 1000 - # if tol is None: - # tol = 1e-12 - - # plan, log = lbfgsb_unbalanced(a, b, M, reg=reg, reg_m=unbalanced, reg_div=reg_type.lower(), regm_div=unbalanced_type.lower(), numItermax=max_iter, stopThr=tol, verbose=verbose, log=True) - - # value_linear = nx.sum(M * plan) - - # value = log['loss'] - - # else: - # raise (NotImplementedError('Not implemented reg_type="{}" and unbalanced_type="{}"'.format(reg_type, unbalanced_type))) - res = OTResult(potentials=potentials, value=value, value_linear=value_linear, value_quad=value_quad, plan=plan, status=status, backend=nx) diff --git a/test/test_solvers.py b/test/test_solvers.py index 220b6e554..479b1b3cd 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -18,7 +18,8 @@ lst_unbalanced_type = ['KL', 'L2', 'TV'] lst_reg_type_gromov = ['entropy'] -lst_unbalanced_type_gromov = ['semirelaxed'] +lst_gw_losses = ['L2', 'KL'] +lst_unbalanced_type_gromov = ['KL', 'semirelaxed'] lst_unbalanced_gromov = [None, 0.9] lst_alpha = [0, 0.4, 0.9, 1] @@ -176,8 +177,8 @@ def test_solve_gromov(nx): assert_allclose_sol(sol0_fgw, solx_fgw) -@pytest.mark.parametrize("reg,reg_type,unbalanced,unbalanced_type,alpha", itertools.product(lst_reg, lst_reg_type_gromov, lst_unbalanced_gromov, lst_unbalanced_type_gromov, 
lst_alpha)) -def test_solve_gromov_grid(nx, reg, reg_type, unbalanced, unbalanced_type, alpha): +@pytest.mark.parametrize("reg,reg_type,unbalanced,unbalanced_type,alpha,loss", itertools.product(lst_reg, lst_reg_type_gromov, lst_unbalanced_gromov, lst_unbalanced_type_gromov, lst_alpha, lst_gw_losses)) +def test_solve_gromov_grid(nx, reg, reg_type, unbalanced, unbalanced_type, alpha, loss): np.random.seed(0) @@ -197,14 +198,14 @@ def test_solve_gromov_grid(nx, reg, reg_type, unbalanced, unbalanced_type, alpha try: - sol0 = ot.solve_gromov(Ca, Cb, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type) # GW - sol0_fgw = ot.solve_gromov(Ca, Cb, M, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type, alpha=alpha) # FGW + sol0 = ot.solve_gromov(Ca, Cb, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type, loss=loss) # GW + sol0_fgw = ot.solve_gromov(Ca, Cb, M, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type, alpha=alpha, loss=loss) # FGW # solve in backend ax, bx, Mx, Cax, Cbx = nx.from_numpy(a, b, M, Ca, Cb) - solx = ot.solve_gromov(Cax, Cbx, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type) # GW - solx_fgw = ot.solve_gromov(Cax, Cbx, Mx, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type, alpha=alpha) # FGW + solx = ot.solve_gromov(Cax, Cbx, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type, loss=loss) # GW + solx_fgw = ot.solve_gromov(Cax, Cbx, Mx, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type, alpha=alpha, loss=loss) # FGW assert_allclose_sol(sol0, solx) assert_allclose_sol(sol0_fgw, solx_fgw) From 91f54dcbc9bcb3fdc6bad62fd511dd44774b87ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Thu, 19 Oct 2023 17:19:11 +0200 Subject: [PATCH 06/13] documentation for solve_gromov --- ot/solvers.py | 
109 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) diff --git a/ot/solvers.py b/ot/solvers.py index d49f33c96..02da6ab1c 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -440,6 +440,115 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, The following methods are available for solving the Gromov-Wasserstein problem: + - **Classical Gromov-Wasserstein (GW) problem :ref:`[3] `** (default parameters): + + .. math:: + \min_{\mathbf{T}\geq 0} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T}^T \mathbf{1} = \mathbf{b} + + \mathbf{T} \geq 0 + + can be solved with the following code: + + .. code-block:: python + res = ot.solve_gromov(Ca, Cb) # uniform weights + res = ot.solve_gromov(Ca, Cb, a=a, b=) # given weights + res = ot.solve_gromov(Ca, Cb, loss='KL') # KL loss + + plan = res.plan # GW plan + value = res.value # GW value + + - **Fused Gromov-Wasserstein (FGW) problem ref:`[24] `** (when ``M!=None``): + + .. math:: + \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T}^T \mathbf{1} = \mathbf{b} + + \mathbf{T} \geq 0 + + can be solved with the following code: + + .. code-block:: python + res = ot.solve_gromov(Ca, Cb, M) # uniform weights, alpha=0.5 (default) + res = ot.solve_gromov(Ca, Cb, M, a=a, b=b, alpha=0.1) # given weights and alpha + + plan = res.plan # FGW plan + loss_linear_term = res.value_linear # Wasserstein part of the loss + loss_quad_term = res.value_quad # Gromov part of the loss + loss = res.value # FGW value + + - **Regularized (Fused) Gromov-Wasserstein (GW) problem ref:`[12] `** (when ``reg!=None``): + + .. 
math:: + \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} + \lambda_r R(\mathbf{T}) + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T}^T \mathbf{1} = \mathbf{b} + + \mathbf{T} \geq 0 + + can be solved with the following code: + + .. code-block:: python + res = ot.solve_gromov(Ca, Cb, reg=1.0) # GW entropy regularization (default) + res = ot.solve_gromov(Ca, Cb, M, a=a, b=b, reg=10, alpha=0.1) # FGW with entropy + + plan = res.plan # FGW plan + loss_linear_term = res.value_linear # Wasserstein part of the loss + loss_quad_term = res.value_quad # Gromov part of the loss + loss = res.value # FGW value (including regularization) + + - **Semi-relaxed (Fused) Gromov-Wasserstein (GW) problemref:`[48] `** (when ``unbalanced='semirelaxed'``): + + .. math:: + \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T} \geq 0 + + can be solved with the following code: + + .. code-block:: python + res = ot.solve_gromov(Ca, Cb, unbalanced='semirelaxed') # semirelaxed GW + res = ot.solve_gromov(Ca, Cb, unbalanced='semirelaxed', reg=1) # entropic semirelaxed GW + res = ot.solve_gromov(Ca, Cb, M, unbalanced='semirelaxed', alpha=0.1) # semirelaxed FGW + + plan = res.plan # FGW plan + right_marginal = res.marginal_b # right marginal of the plan + + .. _references-solve-gromov: + References + ---------- + + .. [3] Mémoli, F. (2011). Gromov–Wasserstein distances and the metric + approach to object matching. Foundations of computational mathematics, + 11(4), 417-487. + + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon (2016), + Gromov-Wasserstein averaging of kernel and distance matrices + International Conference on Machine Learning (ICML). + + .. 
[24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N. + (2019). Optimal Transport for structured data with application on graphs + Proceedings of the 36th International Conference on Machine Learning + (ICML). + + .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, + Nicolas Courty (2022). Semi-relaxed Gromov-Wasserstein divergence and + applications on graphs. International Conference on Learning + Representations (ICLR), 2022. + """ # detect backend From 840af7de290abbab37e7ffd681b0db0cc96628c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Thu, 19 Oct 2023 18:26:47 +0200 Subject: [PATCH 07/13] =?UTF-8?q?update=20documentaton=20and=20add=20?= =?UTF-8?q?=C3=A2rtial?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ot/solvers.py | 107 ++++++++++++++++++++++++++++++++++++++----- test/test_gromov.py | 20 ++++---- test/test_solvers.py | 4 +- 3 files changed, 108 insertions(+), 23 deletions(-) diff --git a/ot/solvers.py b/ot/solvers.py index 02da6ab1c..8bc8fed49 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -19,6 +19,7 @@ semirelaxed_gromov_wasserstein2, semirelaxed_fused_gromov_wasserstein2, entropic_semirelaxed_fused_gromov_wasserstein2, entropic_semirelaxed_gromov_wasserstein2) +from .partial import partial_gromov_wasserstein2, entropic_partial_gromov_wasserstein2 #, entropic_gromov_wasserstein2, entropic_fused_gromov_wasserstein2 @@ -400,15 +401,16 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, alpha : float, optional Weight the quadratic term (alpha*Gromov) and the linear term ((1-alpha)*Wass) in the Fused Gromov-Wasserstein problem. Not used for - Gromov problem (when M is not provided). By default ``alpha=None` corresponds to to + Gromov problem (when M is not provided). 
By default ``alpha=None`` + corresponds to to ``alpha=1`` for Gromov problem (``M==None``) and ``alpha=0.5`` for Fused Gromov-Wasserstein problem (``M!=None``) unbalanced : float, optional Unbalanced penalization weight :math:`\lambda_u`, by default None (balanced OT), Not implemented yet unbalanced_type : str, optional - Type of unbalanced penalization unction :math:`U` either "KL", "L2", - "TV", by default "KL" , Not implemented yet + Type of unbalanced penalization unction :math:`U` either "KL", "semirelaxed", + "partial", by default "KL" , Not implemented yet n_threads : int, optional Number of OMP threads for exact OT solver, by default 1 max_iter : int, optional @@ -440,10 +442,10 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, The following methods are available for solving the Gromov-Wasserstein problem: - - **Classical Gromov-Wasserstein (GW) problem :ref:`[3] `** (default parameters): + - **Classical Gromov-Wasserstein (GW) problem [3]** (default parameters): .. math:: - \min_{\mathbf{T}\geq 0} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} + \min_{\mathbf{T}\geq 0} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} @@ -454,6 +456,7 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, can be solved with the following code: .. code-block:: python + res = ot.solve_gromov(Ca, Cb) # uniform weights res = ot.solve_gromov(Ca, Cb, a=a, b=) # given weights res = ot.solve_gromov(Ca, Cb, loss='KL') # KL loss @@ -461,11 +464,11 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, plan = res.plan # GW plan value = res.value # GW value - - **Fused Gromov-Wasserstein (FGW) problem ref:`[24] `** (when ``M!=None``): + - **Fused Gromov-Wasserstein (FGW) problem [24]** (when ``M!=None``): .. 
math:: \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + - \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} @@ -476,6 +479,7 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, can be solved with the following code: .. code-block:: python + res = ot.solve_gromov(Ca, Cb, M) # uniform weights, alpha=0.5 (default) res = ot.solve_gromov(Ca, Cb, M, a=a, b=b, alpha=0.1) # given weights and alpha @@ -484,11 +488,11 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, loss_quad_term = res.value_quad # Gromov part of the loss loss = res.value # FGW value - - **Regularized (Fused) Gromov-Wasserstein (GW) problem ref:`[12] `** (when ``reg!=None``): + - **Regularized (Fused) Gromov-Wasserstein (GW) problem [12]** (when ``reg!=None``): .. math:: \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + - \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} + \lambda_r R(\mathbf{T}) + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} + \lambda_r R(\mathbf{T}) s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} @@ -499,6 +503,7 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, can be solved with the following code: .. 
code-block:: python + res = ot.solve_gromov(Ca, Cb, reg=1.0) # GW entropy regularization (default) res = ot.solve_gromov(Ca, Cb, M, a=a, b=b, reg=10, alpha=0.1) # FGW with entropy @@ -507,11 +512,11 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, loss_quad_term = res.value_quad # Gromov part of the loss loss = res.value # FGW value (including regularization) - - **Semi-relaxed (Fused) Gromov-Wasserstein (GW) problemref:`[48] `** (when ``unbalanced='semirelaxed'``): + - **Semi-relaxed (Fused) Gromov-Wasserstein (GW) problem [48]** (when ``unbalanced_type='semirelaxed'``): .. math:: \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + - \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} @@ -520,6 +525,7 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, can be solved with the following code: .. code-block:: python + res = ot.solve_gromov(Ca, Cb, unbalanced='semirelaxed') # semirelaxed GW res = ot.solve_gromov(Ca, Cb, unbalanced='semirelaxed', reg=1) # entropic semirelaxed GW res = ot.solve_gromov(Ca, Cb, M, unbalanced='semirelaxed', alpha=0.1) # semirelaxed FGW @@ -527,6 +533,27 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, plan = res.plan # FGW plan right_marginal = res.marginal_b # right marginal of the plan + - **Partial (Fused) Gromov-Wasserstein (GW) problem [29]** (when ``unbalanced_type='partial'``): + + .. math:: + \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F + + \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} + + s.t.
\ \mathbf{T} \mathbf{1} \leq \mathbf{a} + + \mathbf{T}^T \mathbf{1} \leq \mathbf{b} + + \mathbf{T} \geq 0 + + \mathbf{1}^T\mathbf{T}\mathbf{1} = m + + can be solved with the following code: + + .. code-block:: python + + res = ot.solve_gromov(Ca, Cb, unbalanced_type='partial', unbalanced=0.8) # partial GW with m=0.8 + + .. _references-solve-gromov: References ---------- @@ -549,6 +576,10 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, applications on graphs. International Conference on Learning Representations (ICLR), 2022. + .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). Partial Optimal Transport + with Applications on Positive-Unlabeled Learning, Advances in Neural + Information Processing Systems (NeurIPS), 2020. + """ # detect backend @@ -656,6 +687,33 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, plan = log['T'] # potentials = (log['u'], log['v']) TODO + elif unbalanced_type.lower() in ['partial']: # Partial OT + + if M is None: # Partial Gromov-Wasserstein problem + + if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): + raise (ValueError('Partial GW mass given in reg is too large')) + if loss.lower() != 'l2': + raise (NotImplementedError('Partial GW only implemented with L2 loss')) + if symmetric is not None: + raise (NotImplementedError('Partial GW only implemented with symmetric=True')) + + # default values for solver + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-7 + + value, log = partial_gromov_wasserstein2(Ca, Cb, a, b, m=unbalanced, log=True, numItermax=max_iter, G0=plan_init, tol=tol, verbose=verbose) + + value_quad = value + plan = log['T'] + # potentials = (log['u'], log['v']) TODO + + else: # partial FGW + + raise(NotImplementedError('Partial FGW not implemented yet')) + elif unbalanced_type.lower() in ['kl', 'l2']: # unbalanced exact OT raise (NotImplementedError('Unbalanced_type="{}"'.format(unbalanced_type))) @@ -752,6 +810,33 @@ def solve_gromov(Ca, Cb, 
M=None, a=None, b=None, loss='L2', symmetric=None, plan = log['T'] value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16)) + elif unbalanced_type.lower() in ['partial']: # Partial OT + + if M is None: # Partial Gromov-Wasserstein problem + + if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): + raise (ValueError('Partial GW mass given in reg is too large')) + if loss.lower() != 'l2': + raise (NotImplementedError('Partial GW only implemented with L2 loss')) + if symmetric is not None: + raise (NotImplementedError('Partial GW only implemented with symmetric=True')) + + # default values for solver + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-7 + + value_quad, log = entropic_partial_gromov_wasserstein2(Ca, Cb, a, b, reg=reg, m=unbalanced, log=True, numItermax=max_iter, G0=plan_init, tol=tol, verbose=verbose) + + value_quad = value + plan = log['T'] + # potentials = (log['u'], log['v']) TODO + + else: # partial FGW + + raise(NotImplementedError('Partial entropic FGW not implemented yet')) + else: # unbalanced AND regularized OT raise (NotImplementedError('Not implemented reg_type="{}" and unbalanced_type="{}"'.format(reg_type, unbalanced_type))) diff --git a/test/test_gromov.py b/test/test_gromov.py index 846e69f2b..06f843a4a 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -17,7 +17,7 @@ def test_gromov(nx): - n_samples = 50 # nb samples + n_samples = 20 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) @@ -80,7 +80,7 @@ def test_gromov(nx): def test_asymmetric_gromov(nx): - n_samples = 30 # nb samples + n_samples = 20 # nb samples rng = np.random.RandomState(0) C1 = rng.uniform(low=0., high=10, size=(n_samples, n_samples)) idx = np.arange(n_samples) @@ -124,7 +124,7 @@ def test_asymmetric_gromov(nx): def test_gromov_dtype_device(nx): # setup - n_samples = 50 # nb samples + n_samples = 20 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) @@ -160,7 +160,7 @@ def 
test_gromov_dtype_device(nx): @pytest.mark.skipif(not tf, reason="tf not installed") def test_gromov_device_tf(): nx = ot.backend.TensorflowBackend() - n_samples = 50 # nb samples + n_samples = 20 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) @@ -192,7 +192,7 @@ def test_gromov_device_tf(): def test_gromov2_gradients(): - n_samples = 50 # nb samples + n_samples = 20 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) @@ -257,7 +257,7 @@ def test_gromov2_gradients(): def test_gw_helper_backend(nx): - n_samples = 20 # nb samples + n_samples = 10 # nb samples mu = np.array([0, 0]) cov = np.array([[1, 0], [0, 1]]) @@ -301,7 +301,7 @@ def line_search(cost, G, deltaG, Mi, cost_G): pytest.param('unknown_loss', marks=pytest.mark.xfail(raises=ValueError)), ]) def test_gw_helper_validation(loss_fun): - n_samples = 20 # nb samples + n_samples = 10 # nb samples mu = np.array([0, 0]) cov = np.array([[1, 0], [0, 1]]) xs = ot.datasets.make_2D_samples_gauss(n_samples, mu, cov, random_state=0) @@ -548,7 +548,7 @@ def test_entropic_gromov_dtype_device(nx): @pytest.skip_backend("tf", reason="test very slow with tf backend") def test_entropic_fgw(nx): - n_samples = 10 # nb samples + n_samples = 5 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) @@ -613,7 +613,7 @@ def test_entropic_fgw(nx): def test_entropic_proximal_fgw(nx): - n_samples = 10 # nb samples + n_samples = 5 # nb samples mu_s = np.array([0, 0]) cov_s = np.array([[1, 0], [0, 1]]) @@ -678,7 +678,7 @@ def test_entropic_proximal_fgw(nx): def test_asymmetric_entropic_fgw(nx): - n_samples = 10 # nb samples + n_samples = 5 # nb samples rng = np.random.RandomState(0) C1 = rng.uniform(low=0., high=10, size=(n_samples, n_samples)) idx = np.arange(n_samples) diff --git a/test/test_solvers.py b/test/test_solvers.py index 479b1b3cd..3e6849a10 100644 --- a/test/test_solvers.py +++ 
b/test/test_solvers.py @@ -12,14 +12,14 @@ import ot -lst_reg = [None, 1.0] +lst_reg = [None, 1] lst_reg_type = ['KL', 'entropy', 'L2'] lst_unbalanced = [None, 0.9] lst_unbalanced_type = ['KL', 'L2', 'TV'] lst_reg_type_gromov = ['entropy'] lst_gw_losses = ['L2', 'KL'] -lst_unbalanced_type_gromov = ['KL', 'semirelaxed'] +lst_unbalanced_type_gromov = ['KL', 'semirelaxed', 'partial'] lst_unbalanced_gromov = [None, 0.9] lst_alpha = [0, 0.4, 0.9, 1] From 0e1198dab6ced93af3694b4d01072963f7fb04cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Thu, 19 Oct 2023 18:28:41 +0200 Subject: [PATCH 08/13] pep8 --- ot/solvers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/solvers.py b/ot/solvers.py index 8bc8fed49..3eeac9d7d 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -402,7 +402,7 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, Weight the quadratic term (alpha*Gromov) and the linear term ((1-alpha)*Wass) in the Fused Gromov-Wasserstein problem. Not used for Gromov problem (when M is not provided). 
By default ``alpha=None`` - corresponds to to + corresponds to to ``alpha=1`` for Gromov problem (``M==None``) and ``alpha=0.5`` for Fused Gromov-Wasserstein problem (``M!=None``) unbalanced : float, optional From c5323cef6bf307e13d19f492d148e04acac9abf7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Thu, 19 Oct 2023 18:31:25 +0200 Subject: [PATCH 09/13] pep8 --- ot/solvers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ot/solvers.py b/ot/solvers.py index 3eeac9d7d..82e715f7d 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -712,7 +712,7 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, else: # partial FGW - raise(NotImplementedError('Partial FGW not implemented yet')) + raise (NotImplementedError('Partial FGW not implemented yet')) elif unbalanced_type.lower() in ['kl', 'l2']: # unbalanced exact OT @@ -835,7 +835,7 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, else: # partial FGW - raise(NotImplementedError('Partial entropic FGW not implemented yet')) + raise (NotImplementedError('Partial entropic FGW not implemented yet')) else: # unbalanced AND regularized OT From 11225652b044510c2ec849105cc23f18b203ce01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Fri, 20 Oct 2023 16:57:38 +0200 Subject: [PATCH 10/13] better tests + release file --- RELEASES.md | 1 + ot/solvers.py | 2 +- test/test_solvers.py | 43 +++++++++++++++++++++++++++++++++++++++++++ test/test_utils.py | 3 ++- 4 files changed, 47 insertions(+), 2 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index 951b5f327..30c248044 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -8,6 +8,7 @@ + Callbacks for generalized conditional gradient in `ot.da.sinkhorn_l1l2_gl` are now vectorized to improve performance (PR #507) + The `linspace` method of the backends now has the `type_as` argument to convert to the same dtype and device. 
(PR #533) + The `convolutional_barycenter2d` and `convolutional_barycenter2d_debiased` functions now work with different devices.. (PR #533) ++ New API for Gromov-Wasserstein solvers with `ot.solve_gromov` function (PR #536) #### Closed issues - Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504) diff --git a/ot/solvers.py b/ot/solvers.py index 82e715f7d..8f8f242a7 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -602,7 +602,7 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, loss_dict = {'l2': 'square_loss', 'kl': 'kl_loss'} if loss.lower() not in loss_dict.keys(): - raise (ValueError('Unknown GW loss="{}"'.format(loss))) + raise (NotImplementedError('Not implemented GW loss="{}"'.format(loss))) loss_fun = loss_dict[loss.lower()] if reg is None or reg == 0: # exact OT diff --git a/test/test_solvers.py b/test/test_solvers.py index 3e6849a10..f0f5b638f 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -207,8 +207,51 @@ def test_solve_gromov_grid(nx, reg, reg_type, unbalanced, unbalanced_type, alpha solx = ot.solve_gromov(Cax, Cbx, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type, loss=loss) # GW solx_fgw = ot.solve_gromov(Cax, Cbx, Mx, reg=reg, reg_type=reg_type, unbalanced=unbalanced, unbalanced_type=unbalanced_type, alpha=alpha, loss=loss) # FGW + solx.value_quad + assert_allclose_sol(sol0, solx) assert_allclose_sol(sol0_fgw, solx_fgw) except NotImplementedError: pytest.skip("Not implemented") + + +def test_solve_gromov_not_implemented(nx): + + np.random.seed(0) + + n_samples_s = 3 + n_samples_t = 5 + + Ca = np.random.rand(n_samples_s, n_samples_s) + Ca = (Ca + Ca.T) / 2 + + Cb = np.random.rand(n_samples_t, n_samples_t) + Cb = (Cb + Cb.T) / 2 + + a = ot.utils.unif(n_samples_s) + b = ot.utils.unif(n_samples_t) + + M = np.random.rand(n_samples_s, n_samples_t) + + Ca, Cb, M, a, b = nx.from_numpy(Ca, Cb, M, a, b) + + # test not implemented and check raise 
+ with pytest.raises(NotImplementedError): + ot.solve_gromov(Ca, Cb, loss='weird loss') + with pytest.raises(NotImplementedError): + ot.solve_gromov(Ca, Cb, unbalanced=1, unbalanced_type='cryptic divergence') + with pytest.raises(NotImplementedError): + ot.solve_gromov(Ca, Cb, reg=1, reg_type='cryptic divergence') + + # detect partial not implemented and error detect in value + with pytest.raises(ValueError): + ot.solve_gromov(Ca, Cb, unbalanced_type='partial', unbalanced=1.5) + with pytest.raises(NotImplementedError): + ot.solve_gromov(Ca, Cb, unbalanced_type='partial', unbalanced=0.5, symmetric=False) + with pytest.raises(NotImplementedError): + ot.solve_gromov(Ca, Cb, M, unbalanced_type='partial', unbalanced=0.5) + with pytest.raises(ValueError): + ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type='partial', unbalanced=1.5) + with pytest.raises(NotImplementedError): + ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type='partial', unbalanced=0.5, symmetric=False) diff --git a/test/test_utils.py b/test/test_utils.py index 787fbe68a..40324518e 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -386,7 +386,8 @@ def test_OTResult(): 'sparse_plan', 'status', 'value', - 'value_linear'] + 'value_linear', + 'value_quad'] for at in lst_attributes: with pytest.raises(NotImplementedError): getattr(res, at) From 71ad677cd8bc5909e387ac8f71507fd2bda771cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Mon, 23 Oct 2023 16:16:19 +0200 Subject: [PATCH 11/13] take comments into account$ --- ot/gromov/_bregman.py | 4 +--- ot/gromov/_gw.py | 4 +--- ot/solvers.py | 19 +++++++++---------- 3 files changed, 11 insertions(+), 16 deletions(-) diff --git a/ot/gromov/_bregman.py b/ot/gromov/_bregman.py index ddf34f793..146e82631 100644 --- a/ot/gromov/_bregman.py +++ b/ot/gromov/_bregman.py @@ -826,9 +826,7 @@ def entropic_fused_gromov_wasserstein2( logv['T'] = T lin_term = nx.sum(T * M) - gw_term = (logv['fgw_dist'] - (1 - alpha) * lin_term) / alpha - - 
logv['quad_loss'] = gw_term * alpha + logv['quad_loss'] = (logv['fgw_dist'] - (1 - alpha) * lin_term) logv['lin_loss'] = lin_term * (1 - alpha) if log: diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py index c1d7ace75..a19a28052 100644 --- a/ot/gromov/_gw.py +++ b/ot/gromov/_gw.py @@ -584,9 +584,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss', # compute separate terms for gradients and log lin_term = nx.sum(T * M) - gw_term = (fgw_dist - (1 - alpha) * lin_term) / alpha - - log_fgw['quad_loss'] = gw_term * alpha + log_fgw['quad_loss'] = (fgw_dist - (1 - alpha) * lin_term) log_fgw['lin_loss'] = lin_term * (1 - alpha) if loss_fun == 'square_loss': diff --git a/ot/solvers.py b/ot/solvers.py index 8f8f242a7..f22e72f96 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -59,11 +59,11 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, Unbalanced penalization weight :math:`\lambda_u`, by default None (balanced OT) unbalanced_type : str, optional - Type of unbalanced penalization unction :math:`U` either "KL", "L2", "TV", by default "KL" + Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL" n_threads : int, optional Number of OMP threads for exact OT solver, by default 1 max_iter : int, optional - Maximum number of iteration, by default None (default values in each solvers) + Maximum number of iterations, by default None (default values in each solvers) plan_init : array_like, shape (dim_a, dim_b), optional Initialization of the OT plan for iterative methods, by default None potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional @@ -391,30 +391,29 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, Type of loss function, either ``"L2"`` or ``"KL"``, by default ``"L2"`` symmetric : bool, optional Use symmetric version of the Gromov-Wasserstein problem, by default None - tests wether the matrices are symmetric or True/False to avoid the 
test. + tests whether the matrices are symmetric or True/False to avoid the test. reg : float, optional Regularization weight :math:`\lambda_r`, by default None (no reg., exact OT) reg_type : str, optional - Type of regularization :math:`R`, by default "entropic" (only used when + Type of regularization :math:`R`, by default "entropy" (only used when ``reg!=None``) alpha : float, optional Weight the quadratic term (alpha*Gromov) and the linear term ((1-alpha)*Wass) in the Fused Gromov-Wasserstein problem. Not used for Gromov problem (when M is not provided). By default ``alpha=None`` - corresponds to to - ``alpha=1`` for Gromov problem (``M==None``) and ``alpha=0.5`` for Fused - Gromov-Wasserstein problem (``M!=None``) + corresponds to ``alpha=1`` for Gromov problem (``M==None``) and + ``alpha=0.5`` for Fused Gromov-Wasserstein problem (``M!=None``) unbalanced : float, optional Unbalanced penalization weight :math:`\lambda_u`, by default None (balanced OT), Not implemented yet unbalanced_type : str, optional - Type of unbalanced penalization unction :math:`U` either "KL", "semirelaxed", - "partial", by default "KL" , Not implemented yet + Type of unbalanced penalization function :math:`U` either "KL", "semirelaxed", + "partial", by default "KL" but note that it is not implemented yet. 
n_threads : int, optional Number of OMP threads for exact OT solver, by default 1 max_iter : int, optional - Maximum number of iteration, by default None (default values in each + Maximum number of iterations, by default None (default values in each solvers) plan_init : array_like, shape (dim_a, dim_b), optional Initialization of the OT plan for iterative methods, by default None From e39476b9e9a06a8f236ab3ec177ce8651474457f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Mon, 23 Oct 2023 16:19:22 +0200 Subject: [PATCH 12/13] it should work now --- ot/gromov/_gw.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py index a19a28052..d5e4c7f13 100644 --- a/ot/gromov/_gw.py +++ b/ot/gromov/_gw.py @@ -586,6 +586,7 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss', lin_term = nx.sum(T * M) log_fgw['quad_loss'] = (fgw_dist - (1 - alpha) * lin_term) log_fgw['lin_loss'] = lin_term * (1 - alpha) + gw_term = log_fgw['quad_loss'] / alpha if loss_fun == 'square_loss': gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T)) From 5d78f60f706521aaf6e8b1a0f3a728399cc06766 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Mon, 23 Oct 2023 16:41:27 +0200 Subject: [PATCH 13/13] last comments cedric --- ot/solvers.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ot/solvers.py b/ot/solvers.py index f22e72f96..0313cf588 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -412,6 +412,10 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, "partial", by default "KL" but note that it is not implemented yet. n_threads : int, optional Number of OMP threads for exact OT solver, by default 1 + method : str, optional + Method for solving the problem. For entropic problems, "PGD" selects projected + gradient descent and "PPA" the proximal point algorithm; by default None, + which selects "PGD" automatically.
max_iter : int, optional Maximum number of iterations, by default None (default values in each solvers) @@ -457,7 +461,7 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, .. code-block:: python res = ot.solve_gromov(Ca, Cb) # uniform weights - res = ot.solve_gromov(Ca, Cb, a=a, b=) # given weights + res = ot.solve_gromov(Ca, Cb, a=a, b=b) # given weights res = ot.solve_gromov(Ca, Cb, loss='KL') # KL loss plan = res.plan # GW plan