diff --git a/README.md b/README.md
index 6a2cf15fc..266d847c2 100644
--- a/README.md
+++ b/README.md
@@ -28,6 +28,7 @@ POT provides the following generic OT solvers (links to examples):
 * [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12])
 * [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) [24]
 * [Stochastic solver](https://pythonot.github.io/auto_examples/plot_stochastic.html) for Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19])
+* [Stochastic solver of Gromov-Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problems with any loss function [33]
 * Non regularized [free support Wasserstein barycenters](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter.html) [20].
 * [Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25].
 * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3]
@@ -198,6 +199,7 @@ The contributors to this library are
 * [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn)
 * [Ievgen Redko](https://ievred.github.io/) (Laplacian DA, JCPOT)
 * [Adrien Corenflos](https://adriencorenflos.github.io/) (Sliced Wasserstein Distance)
+* [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein)
 * [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance)
 
 This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
@@ -286,3 +288,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
 [31] Bonneel, Nicolas, et al. [Sliced and radon wasserstein barycenters of measures](https://perso.liris.cnrs.fr/nicolas.bonneel/WassersteinSliced-JMIV.pdf), Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
 
 [32] Huang, M., Ma S., Lai, L. (2021). [A Riemannian Block Coordinate Descent Method for Computing the Projection Robust Wasserstein Distance](http://proceedings.mlr.press/v139/huang21e.html), Proceedings of the 38th International Conference on Machine Learning (ICML).
+
+[33] Kerdoncuff T., Emonet R., Sebban M. [Sampled Gromov Wasserstein](https://hal.archives-ouvertes.fr/hal-03232509/document), Machine Learning Journal (MLJ), 2021
diff --git a/examples/gromov/plot_gromov.py b/examples/gromov/plot_gromov.py
index deb2f8669..5a362cf33 100644
--- a/examples/gromov/plot_gromov.py
+++ b/examples/gromov/plot_gromov.py
@@ -104,3 +104,37 @@
 pl.title('Entropic Gromov Wasserstein')
 
 pl.show()
+
+#############################################################################
+#
+# Compute GW with a scalable stochastic method with any loss function
+# ----------------------------------------------------------------------
+
+
+def loss(x, y):
+    return np.abs(x - y)
+
+
+pgw, plog = ot.gromov.pointwise_gromov_wasserstein(C1, C2, p, q, loss, max_iter=100,
+                                                   log=True)
+
+sgw, slog = ot.gromov.sampled_gromov_wasserstein(C1, C2, p, q, loss, epsilon=0.1, max_iter=100,
+                                                 log=True)
+
+print('Pointwise Gromov-Wasserstein distance estimated: ' + str(plog['gw_dist_estimated']))
+print('Variance estimated: ' + str(plog['gw_dist_std']))
+print('Sampled Gromov-Wasserstein distance: ' + str(slog['gw_dist_estimated']))
+print('Variance estimated: ' + str(slog['gw_dist_std']))
+
+
+pl.figure(1, (10, 5))
+
+pl.subplot(1, 2, 1)
+pl.imshow(pgw.toarray(), cmap='jet')
+pl.title('Pointwise Gromov Wasserstein')
+
+pl.subplot(1, 2, 2)
+pl.imshow(sgw, cmap='jet')
+pl.title('Sampled Gromov Wasserstein')
+
+pl.show()
diff --git a/ot/gromov.py b/ot/gromov.py
index 8f457e9ec..a27217a9e 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -16,6 +16,10 @@
 from .bregman import sinkhorn
 from .utils import dist, UndefinedParameter
 from .optim import cg
+from .lp import emd_1d, emd
+from .utils import check_random_state
+
+from scipy.sparse import issparse
 
 
 def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
@@ -572,6 +576,378 @@ def df(G):
     return log['fgw_dist']
 
 
+def GW_distance_estimation(C1, C2, p, q, loss_fun, T,
+                           nb_samples_p=None, nb_samples_q=None, std=True, random_state=None):
+    r"""
+    Returns an approximation of the Gromov-Wasserstein cost between (C1, p) and (C2, q)
+    with a fixed transport plan T.
+
+    The function gives an unbiased approximation of the following equation:
+
+    .. math::
+        GW = \sum_{i,j,k,l} L(C1_{i,k}, C2_{j,l}) T_{i,j} T_{k,l}
+
+    Where :
+
+    - C1 : Metric cost matrix in the source space
+    - C2 : Metric cost matrix in the target space
+    - L : Loss function to account for the misfit between the similarity matrices
+    - T : Matrix with marginals p and q
+
+    Parameters
+    ----------
+    C1 : ndarray, shape (ns, ns)
+        Metric cost matrix in the source space
+    C2 : ndarray, shape (nt, nt)
+        Metric cost matrix in the target space
+    p : ndarray, shape (ns,)
+        Distribution in the source space
+    q : ndarray, shape (nt,)
+        Distribution in the target space
+    loss_fun : function: \mathbb{R} \times \mathbb{R} \rightarrow \mathbb{R}
+        Loss function used for the distance; the transport plan does not depend on the loss function
+    T : csr or ndarray, shape (ns, nt)
+        Transport plan matrix, either a sparse csr matrix or a dense ndarray
+    nb_samples_p : int, optional
+        Number of samples (without replacement) along the first dimension of T
+    nb_samples_q : int, optional
+        Number of samples along the second dimension of T, for each sample along the first
+    std : bool, optional
+        If True, also return the standard deviation of the estimated Gromov-Wasserstein cost
+    random_state : int or RandomState instance, optional
+        Fix the seed to allow reproducibility
+
+    Returns
+    -------
+    : float
+        Gromov-Wasserstein cost
+
+    References
+    ----------
+    .. [33] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc
+        "Sampled Gromov Wasserstein."
+        Machine Learning Journal (MLJ). 2021.
+
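+    Examples
+    --------
+    A minimal sketch of typical usage, on random toy data with the
+    absolute-value loss and the independent coupling ``np.outer(p, q)``
+    as the fixed plan (the setup is purely illustrative):
+
+    >>> import numpy as np
+    >>> import ot
+    >>> xs = np.random.RandomState(0).rand(20, 2)
+    >>> C1 = ot.dist(xs, xs)
+    >>> C2 = ot.dist(xs[::-1], xs[::-1])
+    >>> p = ot.unif(20)
+    >>> q = ot.unif(20)
+    >>> cost = ot.gromov.GW_distance_estimation(
+    ...     C1, C2, p, q, lambda x, y: np.abs(x - y), np.outer(p, q),
+    ...     std=False, random_state=42)
+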
+    """
+    generator = check_random_state(random_state)
+
+    len_p = len(p)
+    len_q = len(q)
+
+    # It is always better to sample from the largest distribution first.
+    if len_p < len_q:
+        p, q = q, p
+        len_p, len_q = len_q, len_p
+        C1, C2 = C2, C1
+        T = T.T
+
+    if nb_samples_p is None:
+        if issparse(T):
+            # If T is sparse, it probably means that PoGroW was used, thus the number of samples is reduced.
+            nb_samples_p = min(int(5 * (len_p * np.log(len_p)) ** 0.5), len_p)
+        else:
+            nb_samples_p = len_p
+    else:
+        # The samples along the first dimension are drawn without replacement.
+        nb_samples_p = min(nb_samples_p, len_p)
+    if nb_samples_q is None:
+        nb_samples_q = 1
+    if std:
+        nb_samples_q = max(2, nb_samples_q)
+
+    index_k = np.zeros((nb_samples_p, nb_samples_q), dtype=int)
+    index_l = np.zeros((nb_samples_p, nb_samples_q), dtype=int)
+    list_value_sample = np.zeros((nb_samples_p, nb_samples_p, nb_samples_q))
+
+    index_i = generator.choice(len_p, size=nb_samples_p, p=p, replace=False)
+    index_j = generator.choice(len_p, size=nb_samples_p, p=p, replace=False)
+
+    for i in range(nb_samples_p):
+        if issparse(T):
+            T_indexi = T[index_i[i], :].toarray()[0]
+            T_indexj = T[index_j[i], :].toarray()[0]
+        else:
+            T_indexi = T[index_i[i], :]
+            T_indexj = T[index_j[i], :]
+        # For each sampled row, a column is sampled.
+        index_k[i] = generator.choice(len_q, size=nb_samples_q, p=T_indexi / T_indexi.sum(), replace=True)
+        index_l[i] = generator.choice(len_q, size=nb_samples_q, p=T_indexj / T_indexj.sum(), replace=True)
+
+    for n in range(nb_samples_q):
+        list_value_sample[:, :, n] = loss_fun(C1[np.ix_(index_i, index_j)], C2[np.ix_(index_k[:, n], index_l[:, n])])
+
+    if std:
+        std_value = np.sum(np.std(list_value_sample, axis=2) ** 2) ** 0.5
+        return np.mean(list_value_sample), std_value / (nb_samples_p * nb_samples_p)
+    else:
+        return np.mean(list_value_sample)
+
+
+def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun,
+                                 alpha=1, max_iter=100, threshold_plan=0, log=False, verbose=False, random_state=None):
+    r"""
+    Returns the Gromov-Wasserstein transport between (C1, p) and (C2, q) using a stochastic Frank-Wolfe.
+    This method has a O(max_iter \times N \log(N)) time complexity by relying on the 1D Optimal Transport solver.
+
+    The function solves the following optimization problem:
+
+    .. math::
+        GW = \arg\min_T \sum_{i,j,k,l} L(C1_{i,k}, C2_{j,l}) T_{i,j} T_{k,l}
+
+        s.t. \ T \mathbf{1} = p
+
+             T^T \mathbf{1} = q
+
+             T \geq 0
+
+    Where :
+
+    - C1 : Metric cost matrix in the source space
+    - C2 : Metric cost matrix in the target space
+    - p : distribution in the source space
+    - q : distribution in the target space
+    - L : loss function to account for the misfit between the similarity matrices
+
+    Parameters
+    ----------
+    C1 : ndarray, shape (ns, ns)
+        Metric cost matrix in the source space
+    C2 : ndarray, shape (nt, nt)
+        Metric cost matrix in the target space
+    p : ndarray, shape (ns,)
+        Distribution in the source space
+    q : ndarray, shape (nt,)
+        Distribution in the target space
+    loss_fun : function: \mathbb{R} \times \mathbb{R} \rightarrow \mathbb{R}
+        Loss function used for the distance; the transport plan does not depend on the loss function
+    alpha : float
+        Step of the Frank-Wolfe algorithm, should be between 0 and 1
+    max_iter : int, optional
+        Max number of iterations
+    threshold_plan : float, optional
+        Delete very small values in the transport plan. If above zero, it violates the marginal constraints.
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        If True, also return the estimated distance and its standard deviation
+    random_state : int or RandomState instance, optional
+        Fix the seed to allow reproducibility
+
+    Returns
+    -------
+    T : ndarray, shape (ns, nt)
+        Optimal coupling between the two spaces
+
+    References
+    ----------
+    .. [33] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc
+        "Sampled Gromov Wasserstein."
+        Machine Learning Journal (MLJ). 2021.
+
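+    Examples
+    --------
+    A minimal sketch of typical usage on random toy data (the setup is
+    purely illustrative); any elementwise loss can be passed:
+
+    >>> import numpy as np
+    >>> import ot
+    >>> xs = np.random.RandomState(0).rand(20, 2)
+    >>> C1 = ot.dist(xs, xs)
+    >>> C2 = ot.dist(xs[::-1], xs[::-1])
+    >>> p = ot.unif(20)
+    >>> q = ot.unif(20)
+    >>> T = ot.gromov.pointwise_gromov_wasserstein(
+    ...     C1, C2, p, q, lambda x, y: np.abs(x - y),
+    ...     max_iter=100, random_state=42)
+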
+    """
+    C1 = np.asarray(C1, dtype=np.float64)
+    C2 = np.asarray(C2, dtype=np.float64)
+    p = np.asarray(p, dtype=np.float64)
+    q = np.asarray(q, dtype=np.float64)
+    len_p = len(p)
+    len_q = len(q)
+
+    generator = check_random_state(random_state)
+
+    index = np.zeros(2, dtype=int)
+
+    # Initialize with default marginal
+    index[0] = generator.choice(len_p, size=1, p=p)
+    index[1] = generator.choice(len_q, size=1, p=q)
+    T = emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False).tocsr()
+
+    best_gw_dist_estimated = np.inf
+    for cpt in range(max_iter):
+        index[0] = generator.choice(len_p, size=1, p=p)
+        T_index0 = T[index[0], :].toarray()[0]
+        index[1] = generator.choice(len_q, size=1, p=T_index0 / T_index0.sum())
+
+        if alpha == 1:
+            T = emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False).tocsr()
+        else:
+            new_T = emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False).tocsr()
+            T = (1 - alpha) * T + alpha * new_T
+            # To limit the number of non-zero entries, the values below the threshold are set to 0.
+            T.data[T.data < threshold_plan] = 0
+            T.eliminate_zeros()
+
+        if cpt % 10 == 0 or cpt == (max_iter - 1):
+            gw_dist_estimated = GW_distance_estimation(C1=C1, C2=C2, loss_fun=loss_fun,
+                                                       p=p, q=q, T=T, std=False, random_state=generator)
+
+            if gw_dist_estimated < best_gw_dist_estimated:
+                best_gw_dist_estimated = gw_dist_estimated
+                best_T = T.copy()
+
+            if verbose:
+                if cpt % 200 == 0:
+                    print('{:5s}|{:12s}'.format('It.', 'Best gw estimated') + '\n' + '-' * 19)
+                print('{:5d}|{:8e}|'.format(cpt, best_gw_dist_estimated))
+
+    if log:
+        log = {}
+        log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(C1=C1, C2=C2, loss_fun=loss_fun,
+                                                                              p=p, q=q, T=best_T,
+                                                                              random_state=generator)
+        return best_T, log
+    return best_T
+
+
+def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun,
+                               nb_samples_grad=100, epsilon=1, max_iter=500, log=False, verbose=False,
+                               random_state=None):
+    r"""
+    Returns the Gromov-Wasserstein transport between (C1, p) and (C2, q) using a 1-stochastic Frank-Wolfe.
+    This method has a O(max_iter \times PN^2) time complexity, with P the number of Sinkhorn iterations.
+
+    The function solves the following optimization problem:
+
+    .. math::
+        GW = \arg\min_T \sum_{i,j,k,l} L(C1_{i,k}, C2_{j,l}) T_{i,j} T_{k,l}
+
+        s.t. \ T \mathbf{1} = p
+
+             T^T \mathbf{1} = q
+
+             T \geq 0
+
+    Where :
+
+    - C1 : Metric cost matrix in the source space
+    - C2 : Metric cost matrix in the target space
+    - p : distribution in the source space
+    - q : distribution in the target space
+    - L : loss function to account for the misfit between the similarity matrices
+
+    Parameters
+    ----------
+    C1 : ndarray, shape (ns, ns)
+        Metric cost matrix in the source space
+    C2 : ndarray, shape (nt, nt)
+        Metric cost matrix in the target space
+    p : ndarray, shape (ns,)
+        Distribution in the source space
+    q : ndarray, shape (nt,)
+        Distribution in the target space
+    loss_fun : function: \mathbb{R} \times \mathbb{R} \rightarrow \mathbb{R}
+        Loss function used for the distance; the transport plan does not depend on the loss function
+    nb_samples_grad : int
+        Number of samples to approximate the gradient
+    epsilon : float
+        Weight of the Kullback-Leibler regularization
+    max_iter : int, optional
+        Max number of iterations
+    verbose : bool, optional
+        Print information along iterations
+    log : bool, optional
+        If True, also return the estimated distance and its standard deviation
+    random_state : int or RandomState instance, optional
+        Fix the seed to allow reproducibility
+
+    Returns
+    -------
+    T : ndarray, shape (ns, nt)
+        Optimal coupling between the two spaces
+
+    References
+    ----------
+    .. [33] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc
+        "Sampled Gromov Wasserstein."
+        Machine Learning Journal (MLJ). 2021.
+
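+    Examples
+    --------
+    A minimal sketch of typical usage on random toy data (the setup is
+    purely illustrative):
+
+    >>> import numpy as np
+    >>> import ot
+    >>> xs = np.random.RandomState(0).rand(20, 2)
+    >>> C1 = ot.dist(xs, xs)
+    >>> C2 = ot.dist(xs[::-1], xs[::-1])
+    >>> p = ot.unif(20)
+    >>> q = ot.unif(20)
+    >>> T = ot.gromov.sampled_gromov_wasserstein(
+    ...     C1, C2, p, q, lambda x, y: np.abs(x - y),
+    ...     epsilon=0.1, max_iter=100, random_state=42)
+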
+    """
+    C1 = np.asarray(C1, dtype=np.float64)
+    C2 = np.asarray(C2, dtype=np.float64)
+    p = np.asarray(p, dtype=np.float64)
+    q = np.asarray(q, dtype=np.float64)
+    len_p = len(p)
+    len_q = len(q)
+
+    generator = check_random_state(random_state)
+
+    # The most natural way to define nb_samples_grad is with a simple integer.
+    if isinstance(nb_samples_grad, int):
+        if nb_samples_grad > len_p:
+            # As the sampling along the first dimension is done without replacement, the rest is carried over to the
+            # second dimension.
+            nb_samples_grad_p, nb_samples_grad_q = len_p, nb_samples_grad // len_p
+        else:
+            nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad, 1
+    else:
+        nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad
+    T = np.outer(p, q)
+    # continue_loop allows stopping the loop after several successive small modifications of T.
+    continue_loop = 0
+
+    # The gradient of GW is more complex if the two matrices are not symmetric.
+    C_are_symmetric = np.allclose(C1, C1.T, rtol=1e-10, atol=1e-10) and np.allclose(C2, C2.T, rtol=1e-10, atol=1e-10)
+
+    for cpt in range(max_iter):
+        index0 = generator.choice(len_p, size=nb_samples_grad_p, p=p, replace=False)
+        Lik = 0
+        for i, index0_i in enumerate(index0):
+            index1 = generator.choice(len_q,
+                                      size=nb_samples_grad_q,
+                                      p=T[index0_i, :] / T[index0_i, :].sum(),
+                                      replace=False)
+            # If the matrices C are not symmetric, the gradient has two terms, so one of them is chosen at random.
+            if (not C_are_symmetric) and generator.rand(1) > 0.5:
+                Lik += np.mean(loss_fun(np.expand_dims(C1[:, np.repeat(index0[i], nb_samples_grad_q)], 1),
+                                        np.expand_dims(C2[:, index1], 0)),
+                               axis=2)
+            else:
+                Lik += np.mean(loss_fun(np.expand_dims(C1[np.repeat(index0[i], nb_samples_grad_q), :], 2),
+                                        np.expand_dims(C2[index1, :], 1)),
+                               axis=0)
+
+        max_Lik = np.max(Lik)
+        if max_Lik == 0:
+            continue
+        # This division by the max is here to facilitate the choice of epsilon.
+        Lik /= max_Lik
+
+        if epsilon > 0:
+            # Clip all the values below exp(-200) before taking the log, to avoid log of 0.
+            log_T = np.log(np.clip(T, np.exp(-200), 1))
+            log_T[log_T == -200] = -np.inf
+            Lik = Lik - epsilon * log_T
+
+            try:
+                new_T = sinkhorn(a=p, b=q, M=Lik, reg=epsilon)
+            except (RuntimeWarning, UserWarning):
+                print("Warning caught in Sinkhorn: returning the last stable T")
+                break
+        else:
+            new_T = emd(a=p, b=q, M=Lik)
+
+        change_T = ((T - new_T) ** 2).mean()
+        if change_T <= 10e-20:
+            continue_loop += 1
+            if continue_loop > 100:  # Max number of small modifications of T
+                T = new_T.copy()
+                break
+        else:
+            continue_loop = 0
+
+        if verbose and cpt % 10 == 0:
+            if cpt % 200 == 0:
+                print('{:5s}|{:12s}'.format('It.', '||T_n - T_{n+1}||') + '\n' + '-' * 19)
+            print('{:5d}|{:8e}|'.format(cpt, change_T))
+        T = new_T.copy()
+
+    if log:
+        log = {}
+        log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(C1=C1, C2=C2, loss_fun=loss_fun,
+                                                                              p=p, q=q, T=T, random_state=generator)
+        return T, log
+    return T
+
+
 def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon,
                                 max_iter=1000, tol=1e-9, verbose=False, log=False):
     r"""
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 56414a887..19d61b104 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -33,7 +33,7 @@ def test_gromov():
 
     G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', verbose=True)
 
-    # check constratints
+    # check constraints
     np.testing.assert_allclose(
         p, G.sum(1), atol=1e-04)  # cf convergence gromov
     np.testing.assert_allclose(
@@ -54,7 +54,7 @@ def test_gromov():
 
     np.testing.assert_allclose(gw, gw_val, atol=1e-1, rtol=1e-1)  # cf log=False
 
-    # check constratints
+    # check constraints
    np.testing.assert_allclose(
         p, G.sum(1), atol=1e-04)  # cf convergence gromov
     np.testing.assert_allclose(
@@ -83,7 +83,7 @@ def test_entropic_gromov():
     G = ot.gromov.entropic_gromov_wasserstein(
         C1, C2, p, q, 'square_loss', epsilon=5e-4, verbose=True)
 
-    # check constratints
+    # check constraints
     np.testing.assert_allclose(
         p, G.sum(1), atol=1e-04)  # cf convergence gromov
     np.testing.assert_allclose(
@@ -96,13 +96,89 @@ def test_entropic_gromov():
     np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1)
 
-    # check constratints
+    # check constraints
     np.testing.assert_allclose(
         p, G.sum(1), atol=1e-04)  # cf convergence gromov
     np.testing.assert_allclose(
         q, G.sum(0), atol=1e-04)  # cf convergence gromov
 
 
+def test_pointwise_gromov():
+    n_samples = 50  # nb samples
+
+    mu_s = np.array([0, 0])
+    cov_s = np.array([[1, 0], [0, 1]])
+
+    xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
+
+    xt = xs[::-1].copy()
+
+    p = ot.unif(n_samples)
+    q = ot.unif(n_samples)
+
+    C1 = ot.dist(xs, xs)
+    C2 = ot.dist(xt, xt)
+
+    C1 /= C1.max()
+    C2 /= C2.max()
+
+    def loss(x, y):
+        return np.abs(x - y)
+
+    G, log = ot.gromov.pointwise_gromov_wasserstein(
+        C1, C2, p, q, loss, max_iter=100, log=True, verbose=True, random_state=42)
+
+    # check constraints
+    np.testing.assert_allclose(
+        p[:, np.newaxis], G.sum(1), atol=1e-04)  # cf convergence gromov
+    np.testing.assert_allclose(
+        q[np.newaxis, :], G.sum(0), atol=1e-04)  # cf convergence gromov
+
+    assert log['gw_dist_estimated'] == 0.0
+    assert log['gw_dist_std'] == 0.0
+
+    G, log = ot.gromov.pointwise_gromov_wasserstein(
+        C1, C2, p, q, loss, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42)
+
+    assert log['gw_dist_estimated'] == 0.10342276348494964
+    assert log['gw_dist_std'] == 0.0015952535464736394
+
+
+def test_sampled_gromov():
+    n_samples = 50  # nb samples
+
+    mu_s = np.array([0, 0])
+    cov_s = np.array([[1, 0], [0, 1]])
+
+    xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42)
+
+    xt = xs[::-1].copy()
+
+    p = ot.unif(n_samples)
+    q = ot.unif(n_samples)
+
+    C1 = ot.dist(xs, xs)
+    C2 = ot.dist(xt, xt)
+
+    C1 /= C1.max()
+    C2 /= C2.max()
+
+    def loss(x, y):
+        return np.abs(x - y)
+
+    G, log = ot.gromov.sampled_gromov_wasserstein(
+        C1, C2, p, q, loss, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42)
+
+    # check constraints
+    np.testing.assert_allclose(
+        p, G.sum(1), atol=1e-04)  # cf convergence gromov
+    np.testing.assert_allclose(
+        q, G.sum(0), atol=1e-04)  # cf convergence gromov
+
+    assert log['gw_dist_estimated'] == 0.05679474884977278
+    assert log['gw_dist_std'] == 0.0005986592106971995
+
+
 def test_gromov_barycenter():
     ns = 50
     nt = 60
@@ -186,7 +262,7 @@ def test_fgw():
 
     G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True)
 
-    # check constratints
+    # check constraints
     np.testing.assert_allclose(
         p, G.sum(1), atol=1e-04)  # cf convergence fgw
     np.testing.assert_allclose(
@@ -203,7 +279,7 @@ def test_fgw():
 
     np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1)
 
-    # check constratints
+    # check constraints
     np.testing.assert_allclose(
         p, G.sum(1), atol=1e-04)  # cf convergence gromov
     np.testing.assert_allclose(