diff --git a/RELEASES.md b/RELEASES.md index 915b5c34e..68ff9ea9b 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -26,6 +26,7 @@ - Avoid precision change when computing norm using PyTorch backend (Discussion #570, PR #572) - Create `ot/bregman/`repository (Issue #567, PR #569) - Fix matrix feature shape in `entropic_fused_gromov_barycenters`(Issue #574, PR #573) +- Fix (fused) gromov-wasserstein barycenter solvers to support `kl_loss`(PR #576) ## 0.9.1 @@ -602,4 +603,4 @@ It provides the following solvers: * Optimal transport for domain adaptation with group lasso regularization * Conditional gradient and Generalized conditional gradient for regularized OT. -Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder. +Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder. \ No newline at end of file diff --git a/ot/gromov/_bregman.py b/ot/gromov/_bregman.py index 89abfdb95..9ff4169dc 100644 --- a/ot/gromov/_bregman.py +++ b/ot/gromov/_bregman.py @@ -457,17 +457,17 @@ def entropic_gromov_barycenters( Cprev = C if warmstartT: T = [entropic_gromov_wasserstein( - Cs[s], C, ps[s], p, loss_fun, epsilon, symmetric, T[s], + C, Cs[s], p, ps[s], loss_fun, epsilon, symmetric, T[s], max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)] else: T = [entropic_gromov_wasserstein( - Cs[s], C, ps[s], p, loss_fun, epsilon, symmetric, None, + C, Cs[s], p, ps[s], loss_fun, epsilon, symmetric, None, max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)] if loss_fun == 'square_loss': - C = update_square_loss(p, lambdas, T, Cs) + C = update_square_loss(p, lambdas, T, Cs, nx) elif loss_fun == 'kl_loss': - C = update_kl_loss(p, lambdas, T, Cs) + C = update_kl_loss(p, lambdas, T, Cs, nx) if cpt % 10 == 0: # we can speed up the process by checking for the error only all @@ -962,9 +962,9 @@ def entropic_fused_gromov_barycenters( Y = init_Y if warmstartT: - T = 
[nx.outer(p_, p) for p_ in ps] + T = [None] * S - Ms = [dist(Ys[s], Y) for s in range(len(Ys))] + Ms = [dist(Y, Ys[s]) for s in range(len(Ys))] cpt = 0 err = 1 @@ -984,23 +984,22 @@ def entropic_fused_gromov_barycenters( if warmstartT: T = [entropic_fused_gromov_wasserstein( - Ms[s], Cs[s], C, ps[s], p, loss_fun, epsilon, symmetric, alpha, + Ms[s], C, Cs[s], p, ps[s], loss_fun, epsilon, symmetric, alpha, T[s], max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)] else: T = [entropic_fused_gromov_wasserstein( - Ms[s], Cs[s], C, ps[s], p, loss_fun, epsilon, symmetric, alpha, + Ms[s], C, Cs[s], p, ps[s], loss_fun, epsilon, symmetric, alpha, None, max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)] if loss_fun == 'square_loss': - C = update_square_loss(p, lambdas, T, Cs) + C = update_square_loss(p, lambdas, T, Cs, nx) elif loss_fun == 'kl_loss': - C = update_kl_loss(p, lambdas, T, Cs) + C = update_kl_loss(p, lambdas, T, Cs, nx) Ys_temp = [y.T for y in Ys] - T_temp = [Ts.T for Ts in T] - Y = update_feature_matrix(lambdas, Ys_temp, T_temp, p).T - Ms = [dist(Ys[s], Y) for s in range(len(Ys))] + Y = update_feature_matrix(lambdas, Ys_temp, T, p, nx).T + Ms = [dist(Y, Ys[s]) for s in range(len(Ys))] if cpt % 10 == 0: # we can speed up the process by checking for the error only all diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py index aba5fa853..810b094c8 100644 --- a/ot/gromov/_gw.py +++ b/ot/gromov/_gw.py @@ -830,16 +830,18 @@ def gromov_barycenters( Cprev = C if warmstartT: - T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, symmetric=symmetric, armijo=armijo, G0=T[s], - max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, log=False, **kwargs) for s in range(S)] + T = [gromov_wasserstein( + C, Cs[s], p, ps[s], loss_fun, symmetric=symmetric, armijo=armijo, G0=T[s], + max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, log=False, **kwargs) for s in range(S)] else: - T = [gromov_wasserstein(Cs[s], C, ps[s], p, 
loss_fun, symmetric=symmetric, armijo=armijo, G0=None, - max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, log=False, **kwargs) for s in range(S)] + T = [gromov_wasserstein( + C, Cs[s], p, ps[s], loss_fun, symmetric=symmetric, armijo=armijo, G0=None, + max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, log=False, **kwargs) for s in range(S)] if loss_fun == 'square_loss': - C = update_square_loss(p, lambdas, T, Cs) + C = update_square_loss(p, lambdas, T, Cs, nx) elif loss_fun == 'kl_loss': - C = update_kl_loss(p, lambdas, T, Cs) + C = update_kl_loss(p, lambdas, T, Cs, nx) if cpt % 10 == 0: # we can speed up the process by checking for the error only all @@ -898,14 +900,14 @@ def fgw_barycenters( If let to its default value None, uniform weights are taken. alpha : float, optional Alpha parameter for the fgw distance. - fixed_structure : bool - Whether to fix the structure of the barycenter during the updates - fixed_features : bool + fixed_structure : bool, optional + Whether to fix the structure of the barycenter during the updates. + fixed_features : bool, optional Whether to fix the feature of the barycenter during the updates p : array-like, shape (N,), optional Weights in the targeted barycenter. If let to its default value None, uniform distribution is taken. - loss_fun : str + loss_fun : str, optional Loss function used for the solver either 'square_loss' or 'kl_loss' symmetric : bool, optional Either structures are to be assumed symmetric or not. Default value is True. 
@@ -1024,19 +1026,18 @@ def fgw_barycenters( T = [fused_gromov_wasserstein( Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric, G0=None, max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)] - # T is N,ns + if not fixed_features: Ys_temp = [y.T for y in Ys] - X = update_feature_matrix(lambdas, Ys_temp, T, p).T + X = update_feature_matrix(lambdas, Ys_temp, T, p, nx).T Ms = [dist(X, Ys[s]) for s in range(len(Ys))] if not fixed_structure: - T_temp = [t.T for t in T] if loss_fun == 'square_loss': - C = update_square_loss(p, lambdas, T_temp, Cs) + C = update_square_loss(p, lambdas, T, Cs, nx) elif loss_fun == 'kl_loss': - C = update_kl_loss(p, lambdas, T_temp, Cs) + C = update_kl_loss(p, lambdas, T, Cs, nx) err_feature = nx.norm(X - nx.reshape(Xprev, (N, d))) err_structure = nx.norm(C - Cprev) diff --git a/ot/gromov/_utils.py b/ot/gromov/_utils.py index 2c1bda823..c1b744333 100644 --- a/ot/gromov/_utils.py +++ b/ot/gromov/_utils.py @@ -253,10 +253,20 @@ def gwggrad(constC, hC1, hC2, T, nx=None): T, nx) # [12] Prop. 2 misses a 2 factor -def update_square_loss(p, lambdas, T, Cs): +def update_square_loss(p, lambdas, T, Cs, nx=None): r""" - Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` - couplings calculated at each iteration + Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` + :math:`\mathbf{T}_s` couplings calculated at each iteration of the GW + barycenter problem in :ref:`[12]`: + + .. math:: + + \mathbf{C}^* = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s) + + Where : + + - :math:`\mathbf{C}_s`: metric cost matrix + - :math:`\mathbf{p}_s`: distribution Parameters ---------- @@ -264,25 +274,33 @@ def update_square_loss(p, lambdas, T, Cs): Masses in the targeted barycenter. 
lambdas : list of float List of the `S` spaces' weights. - T : list of S array-like of shape (ns,N) + T : list of S array-like of shape (N, ns) The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration. Cs : list of S array-like, shape(ns,ns) Metric cost matrices. + nx : backend, optional + If let to its default value None, a backend test will be conducted. Returns ---------- C : array-like, shape (`nt`, `nt`) Updated :math:`\mathbf{C}` matrix. + + References + ---------- + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. + """ - T = list_to_array(*T) - Cs = list_to_array(*Cs) - p = list_to_array(p) - nx = get_backend(p, *T, *Cs) + if nx is None: + nx = get_backend(p, *T, *Cs) + # Correct order mistake in Equation 14 in [12] tmpsum = sum([ lambdas[s] * nx.dot( - nx.dot(T[s].T, Cs[s]), - T[s] + nx.dot(T[s], Cs[s]), + T[s].T ) for s in range(len(T)) ]) ppt = nx.outer(p, p) @@ -290,9 +308,20 @@ def update_square_loss(p, lambdas, T, Cs): return tmpsum / ppt -def update_kl_loss(p, lambdas, T, Cs): +def update_kl_loss(p, lambdas, T, Cs, nx=None): r""" - Updates :math:`\mathbf{C}` according to the KL Loss kernel with the `S` :math:`\mathbf{T}_s` couplings calculated at each iteration + Updates :math:`\mathbf{C}` according to the KL Loss kernel with the `S` + :math:`\mathbf{T}_s` couplings calculated at each iteration of the GW + barycenter problem in :ref:`[12]`: + + .. math:: + + \mathbf{C}^* = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s) + + Where : + + - :math:`\mathbf{C}_s`: metric cost matrix + - :math:`\mathbf{p}_s`: distribution Parameters @@ -301,25 +330,33 @@ def update_kl_loss(p, lambdas, T, Cs): Weights in the targeted barycenter. 
lambdas : list of float List of the `S` spaces' weights - T : list of S array-like of shape (ns,N) + T : list of S array-like of shape (N, ns) The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration. Cs : list of S array-like, shape(ns,ns) Metric cost matrices. + nx : backend, optional + If let to its default value None, a backend test will be conducted. Returns ---------- C : array-like, shape (`ns`, `ns`) updated :math:`\mathbf{C}` matrix + + References + ---------- + .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, + "Gromov-Wasserstein averaging of kernel and distance matrices." + International Conference on Machine Learning (ICML). 2016. + """ - Cs = list_to_array(*Cs) - T = list_to_array(*T) - p = list_to_array(p) - nx = get_backend(p, *T, *Cs) + if nx is None: + nx = get_backend(p, *T, *Cs) + # Correct order mistake in Equation 15 in [12] tmpsum = sum([ lambdas[s] * nx.dot( - nx.dot(T[s].T, Cs[s]), - T[s] + nx.dot(T[s], nx.log(nx.maximum(Cs[s], 1e-15))), + T[s].T ) for s in range(len(T)) ]) ppt = nx.outer(p, p) @@ -327,7 +364,7 @@ def update_kl_loss(p, lambdas, T, Cs): return nx.exp(tmpsum / ppt) -def update_feature_matrix(lambdas, Ys, Ts, p): +def update_feature_matrix(lambdas, Ys, Ts, p, nx=None): r"""Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings. @@ -340,10 +377,12 @@ def update_feature_matrix(lambdas, Ys, Ts, p): masses in the targeted barycenter lambdas : list of float List of the `S` spaces' weights - Ts : list of S array-like, shape (ns,N) + Ts : list of S array-like, shape (N, ns) The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration Ys : list of S array-like, shape (d,ns) The features. + nx : backend, optional + If let to its default value None, a backend test will be conducted. Returns ------- @@ -357,10 +396,8 @@ def update_feature_matrix(lambdas, Ys, Ts, p): "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. 
""" - p = list_to_array(p) - Ts = list_to_array(*Ts) - Ys = list_to_array(*Ys) - nx = get_backend(*Ys, *Ts, p) + if nx is None: + nx = get_backend(*Ys, *Ts, p) p = 1. / p tmpsum = sum([ diff --git a/test/test_gromov.py b/test/test_gromov.py index 78fcb96ad..15958604f 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -1441,20 +1441,29 @@ def test_fgw_barycenter(nx): p = ot.unif(n_samples) ysb, ytb, C1b, C2b, p1b, p2b, pb = nx.from_numpy(ys, yt, C1, C2, p1, p2, p) - - Xb, Cb = ot.gromov.fgw_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], None, [.5, .5], 0.5, fixed_structure=False, - fixed_features=False, p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, random_state=12345 + lambdas = [.5, .5] + Csb = [C1b, C2b] + Ysb = [ysb, ytb] + Xb, Cb, logb = ot.gromov.fgw_barycenters( + n_samples, Ysb, Csb, None, lambdas, 0.5, fixed_structure=False, + fixed_features=False, p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, + random_state=12345, log=True ) + # test correspondence with utils function + recovered_Cb = ot.gromov.update_square_loss(pb, lambdas, logb['Ts_iter'][-1], Csb) + recovered_Xb = ot.gromov.update_feature_matrix(lambdas, [y.T for y in Ysb], logb['Ts_iter'][-1], pb).T + + np.testing.assert_allclose(Cb, recovered_Cb) + np.testing.assert_allclose(Xb, recovered_Xb) xalea = rng.randn(n_samples, 2) init_C = ot.dist(xalea, xalea) init_C /= init_C.max() init_Cb = nx.from_numpy(init_C) - with pytest.raises(ot.utils.UndefinedParameter): # to raise warning when `fixed_structure=True`and `init_C=None` + with pytest.raises(ot.utils.UndefinedParameter): # to raise an error when `fixed_structure=True` and `init_C=None` Xb, Cb = ot.gromov.fgw_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None, + n_samples, Ysb, Csb, ps=[p1b, p2b], lambdas=None, alpha=0.5, fixed_structure=True, init_C=None, fixed_features=False, p=None, loss_fun='square_loss', max_iter=100, tol=1e-3 ) @@ -1471,7 +1480,7 @@ def test_fgw_barycenter(nx): init_X = 
rng.randn(n_samples, ys.shape[1]) init_Xb = nx.from_numpy(init_X) - with pytest.raises(ot.utils.UndefinedParameter): # to raise warning when `fixed_features=True`and `init_X=None` + with pytest.raises(ot.utils.UndefinedParameter): # to raise an error when `fixed_features=True` and `init_X=None` Xb, Cb, logb = ot.gromov.fgw_barycenters( n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, fixed_structure=False, fixed_features=True, init_X=None, @@ -1490,14 +1499,19 @@ def test_fgw_barycenter(nx): np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) # add test with 'kl_loss' - X, C = ot.gromov.fgw_barycenters( + X, C, log = ot.gromov.fgw_barycenters( n_samples, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5, fixed_structure=False, fixed_features=False, p=p, loss_fun='kl_loss', - max_iter=100, tol=1e-3, init_C=C, init_X=X, warmstartT=True, random_state=12345 + max_iter=100, tol=1e-3, init_C=C, init_X=X, warmstartT=True, + random_state=12345, log=True ) np.testing.assert_allclose(C.shape, (n_samples, n_samples)) np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1])) + # test correspondence with utils function + recovered_C = ot.gromov.update_kl_loss(p, lambdas, log['Ts_iter'][-1], [C1, C2]) + np.testing.assert_allclose(C, recovered_C) + def test_gromov_wasserstein_linear_unmixing(nx): n = 4