From 5a19690cd50da039f65710270d0871829d0d1eda Mon Sep 17 00:00:00 2001
From: Tanguy Kerdoncuff
Date: Sat, 10 Jul 2021 19:54:32 +0200
Subject: [PATCH 1/5] =?UTF-8?q?Add=20two=20new=20algorithms=20to=20solve?=
 =?UTF-8?q?=20Gromov=20Wasserstein:=C2=A0Sampled=20Gromov=20Wasserstein=20?=
 =?UTF-8?q?and=20Pointwise=20Gromov=20Wasserstein.?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 README.md                      |   4 +
 examples/gromov/plot_gromov.py |  34 ++++
 ot/gromov.py                   | 345 +++++++++++++++++++++++++++++++++
 test/test_gromov.py            |  91 ++++++++-
 4 files changed, 468 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index 20e060620..90c99b6bf 100644
--- a/README.md
+++ b/README.md
@@ -28,6 +28,7 @@ POT provides the following generic OT solvers (links to examples):
 * [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12])
 * [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) [24]
 * [Stochastic solver](https://pythonot.github.io/auto_examples/plot_stochastic.html) for Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19])
+* [Stochastic solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problems with any loss function [33]
 * Non regularized [free support Wasserstein barycenters](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter.html) [20].
 * [Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25].
 * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3]
@@ -198,6 +199,7 @@ The contributors to this library are
 * [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn)
 * [Ievgen Redko](https://ievred.github.io/) (Laplacian DA, JCPOT)
 * [Adrien Corenflos](https://adriencorenflos.github.io/) (Sliced Wasserstein Distance)
+* [Tanguy Kerdoncuff](https://hv0nnus.github.io/) (Sampled Gromov Wasserstein)

 This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):
@@ -283,3 +285,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
 [30] Flamary R., Courty N., Tuia D., Rakotomamonjy A. (2014). [Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching](https://remi.flamary.com/biblio/flamary2014optlaplace.pdf), NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.

 [31] Bonneel, Nicolas, et al. [Sliced and radon wasserstein barycenters of measures](https://perso.liris.cnrs.fr/nicolas.bonneel/WassersteinSliced-JMIV.pdf), Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
+
+[33] Kerdoncuff T., Emonet R., Sebban M.
[Sampled Gromov Wasserstein](https://hal.archives-ouvertes.fr/hal-03232509/document), Machine Learning Journal (MLJ), 2021

diff --git a/examples/gromov/plot_gromov.py b/examples/gromov/plot_gromov.py
index deb2f8669..5e2028911 100644
--- a/examples/gromov/plot_gromov.py
+++ b/examples/gromov/plot_gromov.py
@@ -104,3 +104,37 @@
 pl.title('Entropic Gromov Wasserstein')

 pl.show()
+
+#############################################################################
+#
+# Compute GW with a scalable stochastic method with any loss function
+# ----------------------------------------------------------------------
+
+
+def loss(x,y):
+    return np.abs(x - y)
+
+
+pgw, plog = ot.gromov.pointwise_gromov_wasserstein(C1, C2, p, q, loss, max_iter=100,
+                                                   log=True)
+
+sgw, slog = ot.gromov.sampled_gromov_wasserstein(C1, C2, p, q, loss, epsilon=0.1, max_iter=100,
+                                                 log=True)
+
+print('Pointwise Gromov-Wasserstein distance estimated: ' + str(plog['gw_dist_estimated']))
+print('Variance estimated: ' + str(plog['gw_dist_std']))
+print('Sampled Gromov-Wasserstein distance: ' + str(slog['gw_dist_estimated']))
+print('Variance estimated: ' + str(slog['gw_dist_std']))
+
+
+pl.figure(1, (10, 5))
+
+pl.subplot(1, 2, 1)
+pl.imshow(pgw.toarray(), cmap='jet')
+pl.title('Pointwise Gromov Wasserstein')
+
+pl.subplot(1, 2, 2)
+pl.imshow(sgw, cmap='jet')
+pl.title('Sampled Gromov Wasserstein')
+
+pl.show()
diff --git a/ot/gromov.py b/ot/gromov.py
index 8f457e9ec..1a49174b3 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -16,6 +16,8 @@
 from .bregman import sinkhorn
 from .utils import dist, UndefinedParameter
 from .optim import cg
+from .lp import emd_1d, emd
+from scipy.sparse import issparse


 def init_matrix(C1, C2, p, q, loss_fun='square_loss'):
@@ -572,6 +574,349 @@ def df(G):
     return log['fgw_dist']


+def GW_distance_estimation(C1, C2, p, q, loss_fun, T, number_sample=[-1, -1], std=True):
+    r"""
+    Returns an approximation of the gromov-wasserstein cost between (C1,p) and (C2,q)
+    with a fixed transport plan T.
+
+    The function gives an unbiased approximation of the following equation:
+
+    .. math::
+        GW = \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}
+
+    Where :
+    - C1 : Metric cost matrix in the source space
+    - C2 : Metric cost matrix in the target space
+    - L : Loss function to account for the misfit between the similarity matrices
+    - T : Matrix with marginal p and q
+
+    Parameters
+    ----------
+    C1 : ndarray, shape (ns, ns)
+        Metric cost matrix in the source space
+    C2 : ndarray, shape (nt, nt)
+        Metric costfr matrix in the target space
+    p : ndarray, shape (ns,)
+        Distribution in the source space
+    q : ndarray, shape (nt,)
+        Distribution in the target space
+    loss_fun : function: \mathcal{R} \times \mathcal{R} \shortarrow \mathcal{R}
+        Loss function used for the distance, the transport plan does not depend on the loss function
+    T : csr or ndarray, shape (ns, nt)
+        Transport plan matrix, either a sparse csr matrix or
+    number_sample : (int, int), optional
+        number_sample[0] * number_sample[1] is the total number of samples of each of the two matrix T.
+    std : bool, optional
+        Standard deviation associated with the prediction of the gromov-wasserstein cost.
+
+    Returns
+    -------
+     : float
+        Gromov-wasserstein cost
+
+    References
+    ----------
+    .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc
+        "Sampled Gromov Wasserstein."
+        Machine Learning Journal (MLJ). 2021.
+
+    """
+
+    # It is always better to sample from the biggest distribution first.
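+    # (the first dimension is sampled without replacement below, so starting from the
+    # larger support allows more samples and hence a lower-variance estimate)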
+ if len(p) < len(q): + p, q = q, p + C1, C2 = C2, C1 + T = T.T + + if number_sample[0] < 0: + if issparse(T): + # If T is sparse, it probably mean that PoGroW was used, thus the number of sample is reduced + number_sample[0] = min(int(5*(len(p) * np.log(len(p))) ** 0.5), len(p)) + else: + number_sample[0] = len(p) + else: + # The number of sample along the first dimension is without replacement. + number_sample[0] = min(number_sample[0], len(p)) + if number_sample[1] < 0: + number_sample[1] = 1 + if std: + number_sample[1] = max(2, number_sample[1]) + + index_k = np.zeros((number_sample[0], number_sample[1]), dtype=int) + index_l = np.zeros((number_sample[0], number_sample[1]), dtype=int) + list_value_sample = np.zeros((number_sample[0], number_sample[0], number_sample[1])) + + index_i = np.random.choice(len(p), size=number_sample[0], p=p, replace=False) + index_j = np.random.choice(len(p), size=number_sample[0], p=p, replace=False) + + for i in range(number_sample[0]): + if issparse(T): + T_indexi = T[index_i[i], :].toarray()[0] + T_indexj = T[index_j[i], :].toarray()[0] + else: + T_indexi = T[index_i[i], :] + T_indexj = T[index_j[i], :] + # For each of the row sampled, the column is sampled. + index_k[i] = np.random.choice(len(q), size=number_sample[1], p=T_indexi / T_indexi.sum(), replace=True) + index_l[i] = np.random.choice(len(q), size=number_sample[1], p=T_indexj / T_indexj.sum(), replace=True) + + for n in range(number_sample[1]): + list_value_sample[:, :, n] = loss_fun(C1[np.ix_(index_i, index_j)], C2[np.ix_(index_k[:, n], index_l[:, n])]) + + if std: + std_value = np.sum(np.std(list_value_sample, axis=2) ** 2) ** 0.5 + return np.mean(list_value_sample), std_value / (number_sample[0] * number_sample[0]) + else: + return np.mean(list_value_sample) + + +def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun, + alpha=1, max_iter=100, threshold=1e-20, log=False, verbose=False): + r""" + Returns the gromov-wasserstein transport between (C1,p) and (C2,q) using a stochastic Frank-Wolfe. + This method as a O(max_iter \times PN^2) time complexity with P the number of Sinkhorn iterations. + + The function solves the following optimization problem: + + .. math:: + GW = arg\min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + + s.t. T 1 = p + + T^T 1= q + + T\geq 0 + + Where : + - C1 : Metric cost matrix in the source space + - C2 : Metric cost matrix in the target space + - p : distribution in the source space + - q : distribution in the target space + - L : loss function to account for the misfit between the similarity matrices + + Parameters + ---------- + C1 : ndarray, shape (ns, ns) + Metric cost matrix in the source space + C2 : ndarray, shape (nt, nt) + Metric costfr matrix in the target space + p : ndarray, shape (ns,) + Distribution in the source space + q : ndarray, shape (nt,) + Distribution in the target space + loss_fun : function: \mathcal{R} \times \mathcal{R} \shortarrow \mathcal{R} + Loss function used for the distance, the transport plan does not depend on the loss function + alpha : float + Step of the Frank-Wolfe algorithm, should be between 0 and 1 + max_iter : int, optional + Max number of iterations + threshold : float, optional + Deleting very small value in the transport plan (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + Gives the distance estimated and the standard deviation + + Returns + ------- + T : ndarray, shape (ns, nt) + Optimal coupling between the two spaces + + References + ---------- + .. 
[14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc + "Sampled Gromov Wasserstein." + Machine Learning Journal (MLJ). 2021. + + """ + C1 = np.asarray(C1, dtype=np.float64) + C2 = np.asarray(C2, dtype=np.float64) + p = np.asarray(p, dtype=np.float64) + q = np.asarray(q, dtype=np.float64) + + index = np.zeros(2, dtype=int) + + # Initialize with default marginal + index[0] = np.random.choice(len(p), size=1, p=p) + index[1] = np.random.choice(len(q), size=1, p=q) + T = emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False).tocsr() + + best_gw_dist_estimated = np.inf + for cpt in range(max_iter): + index[0] = np.random.choice(len(p), size=1, p=p) + T_index0 = T[index[0], :].toarray()[0] + index[1] = np.random.choice(len(q), size=1, p=T_index0 / T_index0.sum()) + + if alpha == 1: + T = emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False).tocsr() + else: + new_T = emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False).tocsr() + T = (1 - alpha) * T + alpha * new_T + # To limit the number of non 0, the values bellow the threshold are set to 0. + T.data[T.data < threshold] = 0 + T.eliminate_zeros() + + if cpt % 10 == 0 or cpt == (max_iter - 1): + gw_dist_estimated = GW_distance_estimation(C1=C1, C2=C2, loss_fun=loss_fun, + p=p, q=q, T=T, std=False) + + if gw_dist_estimated < best_gw_dist_estimated: + best_gw_dist_estimated = gw_dist_estimated + best_T = T.copy() + + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format('It.', 'Best gw estimated') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, best_gw_dist_estimated)) + + if log: + log = {} + log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(C1=C1, C2=C2, loss_fun=loss_fun, + p=p, q=q, T=best_T) + return best_T, log + return best_T + + +def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, + nb_samples=100, epsilon=1, max_iter=500, log=False, verbose=False): + r""" + Returns the gromov-wasserstein transport between (C1,p) and (C2,q) using a 1-stochastic Frank-Wolfe. + This method as a O(max_iter \times Nlog(N)) time complexity by relying on the 1D Optimal Transport solver. + + The function solves the following optimization problem: + + .. math:: + GW = arg\min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + + s.t. T 1 = p + + T^T 1= q + + T\geq 0 + + Where : + - C1 : Metric cost matrix in the source space + - C2 : Metric cost matrix in the target space + - p : distribution in the source space + - q : distribution in the target space + - L : loss function to account for the misfit between the similarity matrices + + Parameters + ---------- + C1 : ndarray, shape (ns, ns) + Metric cost matrix in the source space + C2 : ndarray, shape (nt, nt) + Metric costfr matrix in the target space + p : ndarray, shape (ns,) + Distribution in the source space + q : ndarray, shape (nt,) + Distribution in the target space + loss_fun : function: \mathcal{R} \times \mathcal{R} \shortarrow \mathcal{R} + Loss function used for the distance, the transport plan does not depend on the loss function + nb_samples : int + Number of sample to approximate the gradient + epsilon : float + Weight of the Kullback-Leiber regularization + max_iter : int, optional + Max number of iterations + verbose : bool, optional + Print information along iterations + log : bool, optional + Gives the distance estimated and the standard deviation + + Returns + ------- + T : ndarray, shape (ns, nt) + Optimal coupling between the two spaces + + References + ---------- + .. 
[14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc + "Sampled Gromov Wasserstein." + Machine Learning Journal (MLJ). 2021. + """ + C1 = np.asarray(C1, dtype=np.float64) + C2 = np.asarray(C2, dtype=np.float64) + p = np.asarray(p, dtype=np.float64) + q = np.asarray(q, dtype=np.float64) + + # The most natural way to define nb_sample is with a simple integer. + if isinstance(nb_samples, int): + if nb_samples > len(p): + # As the sampling along the first dimension is done without replacement, the rest is reported to the second + # dimension. + nb_samples = [len(p), nb_samples // len(p)] + else: + nb_samples = [nb_samples, 1] + + T = np.outer(p, q) + # continue_loop allows to stop the loop if there is several successive small modification of T. + continue_loop = 0 + + # The gradient of GW is more complex if the two matrices are not symmetric. + C_are_symmetric = np.all(np.abs(C1-C1.T) < 1e-5) and np.all(np.abs(C2-C2.T) < 1e-5) + + for cpt in range(max_iter): + index0 = np.random.choice(len(p), size=nb_samples[0], p=p, replace=False) + Lik = 0 + for i, index0_i in enumerate(index0): + index1 = np.random.choice(len(q), + size=nb_samples[1], + p=T[index0_i, :] / T[index0_i, :].sum(), + replace=False) + # If the matrices C are not symmetric, the gradient has 2 terms, thus the term is chosen randomly. + if (not C_are_symmetric) and np.random.rand(1) > 0.5: + Lik += np.mean(loss_fun(np.expand_dims(C1[:, np.repeat(index0[i], nb_samples[1])], 1), + np.expand_dims(C2[:, index1], 0)), + axis=2) + else: + Lik += np.mean(loss_fun(np.expand_dims(C1[np.repeat(index0[i], nb_samples[1]), :], 2), + np.expand_dims(C2[index1, :], 1)), + axis=0) + + max_Lik = np.max(Lik) + if max_Lik == 0: + continue + # This division by the max is here to facilitate the choice of epsilon. + Lik /= max_Lik + + if epsilon > 0: + # Set to infinity all the numbers bellow exp(-200) to avoid log of 0. 
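+            # (entries of T clipped at exp(-200) get log_T == -200, which the next line maps to -inf)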
+ log_T = np.log(np.clip(T, np.exp(-200), 1)) + log_T[log_T == -200] = -np.inf + Lik = Lik - epsilon * log_T + + try: + new_T = sinkhorn(a=p, b=q, M=Lik, reg=epsilon) + except (RuntimeWarning, UserWarning): + print("Warning catched in Sinkhorn: Return last stable T") + break + else: + new_T = emd(a=p, b=q, M=Lik) + + change_T = ((T - new_T) ** 2).mean() + if change_T <= 10e-20: + continue_loop += 1 + if continue_loop > 100: # Number max of low modifications of T + T = new_T.copy() + break + else: + continue_loop = 0 + + if verbose and cpt % 10 == 0: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format('It.', '||T_n - T_{n+1}||') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, change_T)) + T = new_T.copy() + + if log: + log = {} + log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(C1=C1, C2=C2, loss_fun=loss_fun, + p=p, q=q, T=T) + return T, log + return T + + def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False): r""" diff --git a/test/test_gromov.py b/test/test_gromov.py index 56414a887..18e373f21 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -33,7 +33,7 @@ def test_gromov(): G = ot.gromov.gromov_wasserstein(C1, C2, p, q, 'square_loss', verbose=True) - # check constratints + # check constraints np.testing.assert_allclose( p, G.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( @@ -54,7 +54,7 @@ def test_gromov(): np.testing.assert_allclose(gw, gw_val, atol=1e-1, rtol=1e-1) # cf log=False - # check constratints + # check constraints np.testing.assert_allclose( p, G.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( @@ -83,7 +83,7 @@ def test_entropic_gromov(): G = ot.gromov.entropic_gromov_wasserstein( C1, C2, p, q, 'square_loss', epsilon=5e-4, verbose=True) - # check constratints + # check constraints np.testing.assert_allclose( p, G.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( @@ -96,13 +96,92 @@ def test_entropic_gromov(): np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1) - # check constratints + # check constraints np.testing.assert_allclose( p, G.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( q, G.sum(0), atol=1e-04) # cf convergence gromov +def test_pointwise_gromov(): + n_samples = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=42) + + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + def loss(x, y): + return np.abs(x - y) + + np.random.seed(42) + G, log = ot.gromov.pointwise_gromov_wasserstein( + C1, C2, p, q, loss, max_iter=100, log=True) + + # check constraints + np.testing.assert_allclose( + p[:, np.newaxis], G.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q[np.newaxis, :], G.sum(0), atol=1e-04) # cf convergence gromov + + assert log['gw_dist_estimated'] == 0.0 + assert log['gw_dist_std'] == 0.0 + + np.random.seed(42) + G, log = ot.gromov.pointwise_gromov_wasserstein( + C1, C2, p, q, loss, max_iter=100, alpha=0.1, log=True) + + assert log['gw_dist_estimated'] == 0.15115630200400282 + assert log['gw_dist_std'] == 0.0016271189171144765 + + +def test_sampled_gromov(): + n_samples = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, 
random_state=42) + + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + def loss(x, y): + return np.abs(x - y) + + np.random.seed(42) + G, log = ot.gromov.sampled_gromov_wasserstein( + C1, C2, p, q, loss, max_iter=100, epsilon=1, log=True) + + # check constraints + np.testing.assert_allclose( + p, G.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, G.sum(0), atol=1e-04) # cf convergence gromov + + assert log['gw_dist_estimated'] == 0.05679474884977278 + assert log['gw_dist_std'] == 0.0005986592106971995 + + def test_gromov_barycenter(): ns = 50 nt = 60 @@ -186,7 +265,7 @@ def test_fgw(): G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, log=True) - # check constratints + # check constraints np.testing.assert_allclose( p, G.sum(1), atol=1e-04) # cf convergence fgw np.testing.assert_allclose( @@ -203,7 +282,7 @@ def test_fgw(): np.testing.assert_allclose(fgw, 0, atol=1e-1, rtol=1e-1) - # check constratints + # check constraints np.testing.assert_allclose( p, G.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( From 870107e18fb10aec8d58edd7b8f04d9c71afd4a0 Mon Sep 17 00:00:00 2001 From: Tanguy Kerdoncuff Date: Mon, 6 Sep 2021 17:14:29 +0200 Subject: [PATCH 2/5] Correct some lines in SaGroW and PoGroW to follow pep8 guide. --- examples/gromov/plot_gromov.py | 2 +- ot/gromov.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/gromov/plot_gromov.py b/examples/gromov/plot_gromov.py index 5e2028911..5a362cf33 100644 --- a/examples/gromov/plot_gromov.py +++ b/examples/gromov/plot_gromov.py @@ -111,7 +111,7 @@ # ---------------------------------------------------------------------- -def loss(x,y): +def loss(x, y): return np.abs(x - y) diff --git a/ot/gromov.py b/ot/gromov.py index 1a49174b3..cf2c8985c 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -631,7 +631,7 @@ def GW_distance_estimation(C1, C2, p, q, loss_fun, T, number_sample=[-1, -1], st if number_sample[0] < 0: if issparse(T): # If T is sparse, it probably mean that PoGroW was used, thus the number of sample is reduced - number_sample[0] = min(int(5*(len(p) * np.log(len(p))) ** 0.5), len(p)) + number_sample[0] = min(int(5 * (len(p) * np.log(len(p))) ** 0.5), len(p)) else: number_sample[0] = len(p) else: @@ -854,7 +854,7 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, continue_loop = 0 # The gradient of GW is more complex if the two matrices are not symmetric. - C_are_symmetric = np.all(np.abs(C1-C1.T) < 1e-5) and np.all(np.abs(C2-C2.T) < 1e-5) + C_are_symmetric = np.all(np.abs(C1 - C1.T) < 1e-5) and np.all(np.abs(C2 - C2.T) < 1e-5) for cpt in range(max_iter): index0 = np.random.choice(len(p), size=nb_samples[0], p=p, replace=False) From a8b3437776924d18de660e12b415ef939119b213 Mon Sep 17 00:00:00 2001 From: Tanguy Kerdoncuff Date: Thu, 9 Sep 2021 09:46:03 +0200 Subject: [PATCH 3/5] Change nb_samples name. Use rdm state. Change symmetric check. 
--- ot/gromov.py | 122 +++++++++++++++++++++++++++----------------- test/test_gromov.py | 13 ++--- 2 files changed, 80 insertions(+), 55 deletions(-) diff --git a/ot/gromov.py b/ot/gromov.py index cf2c8985c..f3a7580ca 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -17,6 +17,8 @@ from .utils import dist, UndefinedParameter from .optim import cg from .lp import emd_1d, emd +from .utils import check_random_state + from scipy.sparse import issparse @@ -574,7 +576,8 @@ def df(G): return log['fgw_dist'] -def GW_distance_estimation(C1, C2, p, q, loss_fun, T, number_sample=[-1, -1], std=True): +def GW_distance_estimation(C1, C2, p, q, loss_fun, T, + nb_samples_p=None, nb_samples_q=None, std=True, random_state=None): r""" Returns an approximation of the gromov-wasserstein cost between (C1,p) and (C2,q) with a fixed transport plan T. @@ -585,6 +588,7 @@ def GW_distance_estimation(C1, C2, p, q, loss_fun, T, number_sample=[-1, -1], st GW = \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} Where : + - C1 : Metric cost matrix in the source space - C2 : Metric cost matrix in the target space - L : Loss function to account for the misfit between the similarity matrices @@ -604,10 +608,14 @@ def GW_distance_estimation(C1, C2, p, q, loss_fun, T, number_sample=[-1, -1], st Loss function used for the distance, the transport plan does not depend on the loss function T : csr or ndarray, shape (ns, nt) Transport plan matrix, either a sparse csr matrix or - number_sample : (int, int), optional - number_sample[0] * number_sample[1] is the total number of samples of each of the two matrix T. + nb_samples_p : int, optional + nb_samples_p is the number of samples (without replacement) along the first dimension of T. + nb_samples_q : int, optional + nb_samples_q is the number of samples along the second dimension of T, for each sample along the first. std : bool, optional Standard deviation associated with the prediction of the gromov-wasserstein cost. + random_state : int or RandomState instance, optional + Fix the seed for to allow reproducibility Returns ------- @@ -621,35 +629,38 @@ def GW_distance_estimation(C1, C2, p, q, loss_fun, T, number_sample=[-1, -1], st Machine Learning Journal (MLJ). 2021. """ + generator = check_random_state(random_state) + I = len(p) + K = len(q) # It is always better to sample from the biggest distribution first. - if len(p) < len(q): + if I < K: p, q = q, p C1, C2 = C2, C1 T = T.T - if number_sample[0] < 0: + if nb_samples_p is None: if issparse(T): # If T is sparse, it probably mean that PoGroW was used, thus the number of sample is reduced - number_sample[0] = min(int(5 * (len(p) * np.log(len(p))) ** 0.5), len(p)) + nb_samples_p = min(int(5 * (I * np.log(I)) ** 0.5), I) else: - number_sample[0] = len(p) + nb_samples_p = I else: # The number of sample along the first dimension is without replacement. 
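+        # (sampling without replacement caps nb_samples_p at the support size of p)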
- number_sample[0] = min(number_sample[0], len(p)) - if number_sample[1] < 0: - number_sample[1] = 1 + nb_samples_p = min(nb_samples_p, I) + if nb_samples_q is None: + nb_samples_q = 1 if std: - number_sample[1] = max(2, number_sample[1]) + nb_samples_q = max(2, nb_samples_q) - index_k = np.zeros((number_sample[0], number_sample[1]), dtype=int) - index_l = np.zeros((number_sample[0], number_sample[1]), dtype=int) - list_value_sample = np.zeros((number_sample[0], number_sample[0], number_sample[1])) + index_k = np.zeros((nb_samples_p, nb_samples_q), dtype=int) + index_l = np.zeros((nb_samples_p, nb_samples_q), dtype=int) + list_value_sample = np.zeros((nb_samples_p, nb_samples_p, nb_samples_q)) - index_i = np.random.choice(len(p), size=number_sample[0], p=p, replace=False) - index_j = np.random.choice(len(p), size=number_sample[0], p=p, replace=False) + index_i = generator.choice(I, size=nb_samples_p, p=p, replace=False) + index_j = generator.choice(I, size=nb_samples_p, p=p, replace=False) - for i in range(number_sample[0]): + for i in range(nb_samples_p): if issparse(T): T_indexi = T[index_i[i], :].toarray()[0] T_indexj = T[index_j[i], :].toarray()[0] @@ -657,21 +668,21 @@ def GW_distance_estimation(C1, C2, p, q, loss_fun, T, number_sample=[-1, -1], st T_indexi = T[index_i[i], :] T_indexj = T[index_j[i], :] # For each of the row sampled, the column is sampled. - index_k[i] = np.random.choice(len(q), size=number_sample[1], p=T_indexi / T_indexi.sum(), replace=True) - index_l[i] = np.random.choice(len(q), size=number_sample[1], p=T_indexj / T_indexj.sum(), replace=True) + index_k[i] = generator.choice(K, size=nb_samples_q, p=T_indexi / T_indexi.sum(), replace=True) + index_l[i] = generator.choice(K, size=nb_samples_q, p=T_indexj / T_indexj.sum(), replace=True) - for n in range(number_sample[1]): + for n in range(nb_samples_q): list_value_sample[:, :, n] = loss_fun(C1[np.ix_(index_i, index_j)], C2[np.ix_(index_k[:, n], index_l[:, n])]) if std: std_value = np.sum(np.std(list_value_sample, axis=2) ** 2) ** 0.5 - return np.mean(list_value_sample), std_value / (number_sample[0] * number_sample[0]) + return np.mean(list_value_sample), std_value / (nb_samples_p * nb_samples_p) else: return np.mean(list_value_sample) def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun, - alpha=1, max_iter=100, threshold=1e-20, log=False, verbose=False): + alpha=1, max_iter=100, threshold_plan=0, log=False, verbose=False, random_state=None): r""" Returns the gromov-wasserstein transport between (C1,p) and (C2,q) using a stochastic Frank-Wolfe. This method as a O(max_iter \times PN^2) time complexity with P the number of Sinkhorn iterations. @@ -688,6 +699,7 @@ def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun, T\geq 0 Where : + - C1 : Metric cost matrix in the source space - C2 : Metric cost matrix in the target space - p : distribution in the source space @@ -710,12 +722,14 @@ def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun, Step of the Frank-Wolfe algorithm, should be between 0 and 1 max_iter : int, optional Max number of iterations - threshold : float, optional - Deleting very small value in the transport plan (>0) + threshold_plan : float, optional + Deleting very small value in the transport plan. If above zero, it violate the marginal constraints. 
verbose : bool, optional Print information along iterations log : bool, optional Gives the distance estimated and the standard deviation + random_state : int or RandomState instance, optional + Fix the seed for to allow reproducibility Returns ------- @@ -733,19 +747,23 @@ def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun, C2 = np.asarray(C2, dtype=np.float64) p = np.asarray(p, dtype=np.float64) q = np.asarray(q, dtype=np.float64) + I = len(p) + K = len(q) + + generator = check_random_state(random_state) index = np.zeros(2, dtype=int) # Initialize with default marginal - index[0] = np.random.choice(len(p), size=1, p=p) - index[1] = np.random.choice(len(q), size=1, p=q) + index[0] = generator.choice(I, size=1, p=p) + index[1] = generator.choice(K, size=1, p=q) T = emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False).tocsr() best_gw_dist_estimated = np.inf for cpt in range(max_iter): - index[0] = np.random.choice(len(p), size=1, p=p) + index[0] = generator.choice(I, size=1, p=p) T_index0 = T[index[0], :].toarray()[0] - index[1] = np.random.choice(len(q), size=1, p=T_index0 / T_index0.sum()) + index[1] = generator.choice(K, size=1, p=T_index0 / T_index0.sum()) if alpha == 1: T = emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False).tocsr() @@ -753,12 +771,12 @@ def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun, new_T = emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False).tocsr() T = (1 - alpha) * T + alpha * new_T # To limit the number of non 0, the values bellow the threshold are set to 0. - T.data[T.data < threshold] = 0 + T.data[T.data < threshold_plan] = 0 T.eliminate_zeros() if cpt % 10 == 0 or cpt == (max_iter - 1): gw_dist_estimated = GW_distance_estimation(C1=C1, C2=C2, loss_fun=loss_fun, - p=p, q=q, T=T, std=False) + p=p, q=q, T=T, std=False, random_state=generator) if gw_dist_estimated < best_gw_dist_estimated: best_gw_dist_estimated = gw_dist_estimated @@ -772,13 +790,15 @@ def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun, if log: log = {} log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(C1=C1, C2=C2, loss_fun=loss_fun, - p=p, q=q, T=best_T) + p=p, q=q, T=best_T, + random_state=generator) return best_T, log return best_T def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, - nb_samples=100, epsilon=1, max_iter=500, log=False, verbose=False): + nb_samples_grad=100, epsilon=1, max_iter=500, log=False, verbose=False, + random_state=None): r""" Returns the gromov-wasserstein transport between (C1,p) and (C2,q) using a 1-stochastic Frank-Wolfe. This method as a O(max_iter \times Nlog(N)) time complexity by relying on the 1D Optimal Transport solver. 
@@ -795,6 +815,7 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, T\geq 0 Where : + - C1 : Metric cost matrix in the source space - C2 : Metric cost matrix in the target space - p : distribution in the source space @@ -813,8 +834,8 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, Distribution in the target space loss_fun : function: \mathcal{R} \times \mathcal{R} \shortarrow \mathcal{R} Loss function used for the distance, the transport plan does not depend on the loss function - nb_samples : int - Number of sample to approximate the gradient + nb_samples_grad : int + Number of samples to approximate the gradient epsilon : float Weight of the Kullback-Leiber regularization max_iter : int, optional @@ -823,6 +844,8 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, Print information along iterations log : bool, optional Gives the distance estimated and the standard deviation + random_state : int or RandomState instance, optional + Fix the seed for to allow reproducibility Returns ------- @@ -839,38 +862,43 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, C2 = np.asarray(C2, dtype=np.float64) p = np.asarray(p, dtype=np.float64) q = np.asarray(q, dtype=np.float64) + I = len(p) + K = len(q) + + generator = check_random_state(random_state) # The most natural way to define nb_sample is with a simple integer. - if isinstance(nb_samples, int): - if nb_samples > len(p): + if isinstance(nb_samples_grad, int): + if nb_samples_grad > I: # As the sampling along the first dimension is done without replacement, the rest is reported to the second # dimension. - nb_samples = [len(p), nb_samples // len(p)] + nb_samples_grad_p, nb_samples_grad_q = I, nb_samples_grad // I else: - nb_samples = [nb_samples, 1] - + nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad, 1 + else: + nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad T = np.outer(p, q) # continue_loop allows to stop the loop if there is several successive small modification of T. continue_loop = 0 # The gradient of GW is more complex if the two matrices are not symmetric. - C_are_symmetric = np.all(np.abs(C1 - C1.T) < 1e-5) and np.all(np.abs(C2 - C2.T) < 1e-5) + C_are_symmetric = np.allclose(C1, C1.T, rtol=1e-10, atol=1e-10) and np.allclose(C2, C2.T, rtol=1e-10, atol=1e-10) for cpt in range(max_iter): - index0 = np.random.choice(len(p), size=nb_samples[0], p=p, replace=False) + index0 = generator.choice(I, size=nb_samples_grad_p, p=p, replace=False) Lik = 0 for i, index0_i in enumerate(index0): - index1 = np.random.choice(len(q), - size=nb_samples[1], + index1 = generator.choice(K, + size=nb_samples_grad_q, p=T[index0_i, :] / T[index0_i, :].sum(), replace=False) # If the matrices C are not symmetric, the gradient has 2 terms, thus the term is chosen randomly. 
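+            # (each of the two terms is selected with probability 1/2; the resulting scale
+            # is harmless since Lik is renormalized by its maximum before the OT step)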
- if (not C_are_symmetric) and np.random.rand(1) > 0.5: - Lik += np.mean(loss_fun(np.expand_dims(C1[:, np.repeat(index0[i], nb_samples[1])], 1), + if (not C_are_symmetric) and generator.rand(1) > 0.5: + Lik += np.mean(loss_fun(np.expand_dims(C1[:, np.repeat(index0[i], nb_samples_grad_q)], 1), np.expand_dims(C2[:, index1], 0)), axis=2) else: - Lik += np.mean(loss_fun(np.expand_dims(C1[np.repeat(index0[i], nb_samples[1]), :], 2), + Lik += np.mean(loss_fun(np.expand_dims(C1[np.repeat(index0[i], nb_samples_grad_q), :], 2), np.expand_dims(C2[index1, :], 1)), axis=0) @@ -912,7 +940,7 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, if log: log = {} log["gw_dist_estimated"], log["gw_dist_std"] = GW_distance_estimation(C1=C1, C2=C2, loss_fun=loss_fun, - p=p, q=q, T=T) + p=p, q=q, T=T, random_state=generator) return T, log return T diff --git a/test/test_gromov.py b/test/test_gromov.py index 18e373f21..19d61b104 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -125,9 +125,8 @@ def test_pointwise_gromov(): def loss(x, y): return np.abs(x - y) - np.random.seed(42) G, log = ot.gromov.pointwise_gromov_wasserstein( - C1, C2, p, q, loss, max_iter=100, log=True) + C1, C2, p, q, loss, max_iter=100, log=True, verbose=True, random_state=42) # check constraints np.testing.assert_allclose( @@ -138,12 +137,11 @@ def loss(x, y): assert log['gw_dist_estimated'] == 0.0 assert log['gw_dist_std'] == 0.0 - np.random.seed(42) G, log = ot.gromov.pointwise_gromov_wasserstein( - C1, C2, p, q, loss, max_iter=100, alpha=0.1, log=True) + C1, C2, p, q, loss, max_iter=100, alpha=0.1, log=True, verbose=True, random_state=42) - assert log['gw_dist_estimated'] == 0.15115630200400282 - assert log['gw_dist_std'] == 0.0016271189171144765 + assert log['gw_dist_estimated'] == 0.10342276348494964 + assert log['gw_dist_std'] == 0.0015952535464736394 def test_sampled_gromov(): @@ -168,9 +166,8 @@ def test_sampled_gromov(): def loss(x, y): return np.abs(x - y) - np.random.seed(42) G, log = ot.gromov.sampled_gromov_wasserstein( - C1, C2, p, q, loss, max_iter=100, epsilon=1, log=True) + C1, C2, p, q, loss, max_iter=100, epsilon=1, log=True, verbose=True, random_state=42) # check constraints np.testing.assert_allclose( From 12897bbd3ade8b200c49aa6e4af5e202da909d19 Mon Sep 17 00:00:00 2001 From: Tanguy Kerdoncuff Date: Thu, 9 Sep 2021 16:26:38 +0200 Subject: [PATCH 4/5] Change names of len(p) and len(q) in SaGroW and PoGroW. --- ot/gromov.py | 185 +++++++-------------------------------------------- 1 file changed, 25 insertions(+), 160 deletions(-) diff --git a/ot/gromov.py b/ot/gromov.py index f3a7580ca..24dd1e10a 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -24,31 +24,25 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'): r"""Return loss matrices and tensors for Gromov-Wasserstein fast computation - Returns the value of \mathcal{L}(C1,C2) \otimes T with the selected loss function as the loss function of Gromow-Wasserstein discrepancy. 
- The matrices are computed as described in Proposition 1 in [12] - Where : * C1 : Metric cost matrix in the source space * C2 : Metric cost matrix in the target space * T : A coupling between those two spaces - The square-loss function L(a,b)=|a-b|^2 is read as : L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with : * f1(a)=(a^2) * f2(b)=(b^2) * h1(a)=a * h2(b)=2*b - The kl-loss function L(a,b)=a*log(a/b)-a+b is read as : L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with : * f1(a)=a*log(a)-a * f2(b)=b * h1(a)=a * h2(b)=log(b) - Parameters ---------- C1 : ndarray, shape (ns, ns) @@ -58,7 +52,6 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'): T : ndarray, shape (ns, nt) Coupling between source and target spaces p : ndarray, shape (ns,) - Returns ------- constC : ndarray, shape (ns, nt) @@ -67,13 +60,11 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'): h1(C1) matrix in Eq. (6) hC2 : ndarray, shape (nt, nt) h2(C) matrix in Eq. (6) - References ---------- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. - """ if loss_fun == 'square_loss': @@ -114,9 +105,7 @@ def h2(b): def tensor_product(constC, hC1, hC2, T): r"""Return the tensor for Gromov-Wasserstein fast computation - The tensor is computed as described in Proposition 1 Eq. (6) in [12]. - Parameters ---------- constC : ndarray, shape (ns, nt) @@ -125,18 +114,15 @@ def tensor_product(constC, hC1, hC2, T): h1(C1) matrix in Eq. (6) hC2 : ndarray, shape (nt, nt) h2(C) matrix in Eq. (6) - Returns ------- tens : ndarray, shape (ns, nt) \mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result - References ---------- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. - """ A = -np.dot(hC1, T).dot(hC2.T) tens = constC + A @@ -146,9 +132,7 @@ def tensor_product(constC, hC1, hC2, T): def gwloss(constC, hC1, hC2, T): """Return the Loss for Gromov-Wasserstein - The loss is computed as described in Proposition 1 Eq. (6) in [12]. - Parameters ---------- constC : ndarray, shape (ns, nt) @@ -159,18 +143,15 @@ def gwloss(constC, hC1, hC2, T): h2(C) matrix in Eq. (6) T : ndarray, shape (ns, nt) Current value of transport matrix T - Returns ------- loss : float Gromov Wasserstein loss - References ---------- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. - """ tens = tensor_product(constC, hC1, hC2, T) @@ -180,9 +161,7 @@ def gwloss(constC, hC1, hC2, T): def gwggrad(constC, hC1, hC2, T): """Return the gradient for Gromov-Wasserstein - The gradient is computed as described in Proposition 2 in [12]. - Parameters ---------- constC : ndarray, shape (ns, nt) @@ -193,18 +172,15 @@ def gwggrad(constC, hC1, hC2, T): h2(C) matrix in Eq. (6) T : ndarray, shape (ns, nt) Current value of transport matrix T - Returns ------- grad : ndarray, shape (ns, nt) Gromov Wasserstein gradient - References ---------- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. - """ return 2 * tensor_product(constC, hC1, hC2, T) # [12] Prop. 
2 misses a 2 factor @@ -214,7 +190,6 @@ def update_square_loss(p, lambdas, T, Cs): """ Updates C according to the L2 Loss kernel with the S Ts couplings calculated at each iteration - Parameters ---------- p : ndarray, shape (N,) @@ -225,7 +200,6 @@ def update_square_loss(p, lambdas, T, Cs): The S Ts couplings calculated at each iteration. Cs : list of S ndarray, shape(ns,ns) Metric cost matrices. - Returns ---------- C : ndarray, shape (nt, nt) @@ -241,8 +215,6 @@ def update_square_loss(p, lambdas, T, Cs): def update_kl_loss(p, lambdas, T, Cs): """ Updates C according to the KL Loss kernel with the S Ts couplings calculated at each iteration - - Parameters ---------- p : ndarray, shape (N,) @@ -252,7 +224,6 @@ def update_kl_loss(p, lambdas, T, Cs): The S Ts couplings calculated at each iteration. Cs : list of S ndarray, shape(ns,ns) Metric cost matrices. - Returns ---------- C : ndarray, shape (ns,ns) @@ -268,19 +239,15 @@ def update_kl_loss(p, lambdas, T, Cs): def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs): r""" Returns the gromov-wasserstein transport between (C1,p) and (C2,q) - The function solves the following optimization problem: - .. math:: GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} - Where : - C1 : Metric cost matrix in the source space - C2 : Metric cost matrix in the target space - p : distribution in the source space - q : distribution in the target space - L : loss function to account for the misfit between the similarity matrices - Parameters ---------- C1 : ndarray, shape (ns, ns) @@ -293,7 +260,6 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs Distribution in the target space loss_fun : str loss function used for the solver either 'square_loss' or 'kl_loss' - max_iter : int, optional Max number of iterations tol : float, optional @@ -307,7 +273,6 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs If there is convergence issues use False. **kwargs : dict parameters can be directly passed to the ot.optim.cg solver - Returns ------- T : ndarray, shape (ns, nt) @@ -315,17 +280,14 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} log : dict Convergence information and loss. - References ---------- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. - .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the metric approach to object matching. Foundations of computational mathematics 11.4 (2011): 417-487. - """ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) @@ -349,19 +311,15 @@ def df(G): def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs): r""" Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q) - The function solves the following optimization problem: - .. 
math:: GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} - Where : - C1 : Metric cost matrix in the source space - C2 : Metric cost matrix in the target space - p : distribution in the source space - q : distribution in the target space - L : loss function to account for the misfit between the similarity matrices - Parameters ---------- C1 : ndarray, shape (ns, ns) @@ -385,24 +343,20 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg armijo : bool, optional If True the steps of the line-search is found via an armijo research. Else closed form is used. If there is convergence issues use False. - Returns ------- gw_dist : float Gromov-Wasserstein distance log : dict convergence information and Coupling marix - References ---------- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. - .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the metric approach to object matching. Foundations of computational mathematics 11.4 (2011): 417-487. - """ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) @@ -426,22 +380,17 @@ def df(G): def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): r""" Computes the FGW transport between two graphs see [24] - .. math:: \gamma = arg\min_\gamma (1-\\alpha)*<\gamma,M>_F + \\alpha* \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} - s.t. \gamma 1 = p \gamma^T 1= q \gamma\geq 0 - where : - M is the (ns,nt) metric cost matrix - p and q are source and target weights (sum to 1) - L is a loss function to account for the misfit between the similarity matrices - The algorithm used for solving the problem is conditional gradient as discussed in [24]_ - Parameters ---------- M : ndarray, shape (ns, nt) @@ -465,14 +414,12 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, record log if True **kwargs : dict parameters can be directly passed to the ot.optim.cg solver - Returns ------- gamma : ndarray, shape (ns, nt) Optimal transportation matrix for the given parameters. log : dict Log dictionary return only if log==True in parameters. - References ---------- .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain @@ -502,22 +449,17 @@ def df(G): def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): r""" Computes the FGW distance between two graphs see [24] - .. math:: \min_\gamma (1-\\alpha)*<\gamma,M>_F + \\alpha* \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} - - s.t. \gamma 1 = p \gamma^T 1= q \gamma\geq 0 - where : - M is the (ns,nt) metric cost matrix - p and q are source and target weights (sum to 1) - L is a loss function to account for the misfit between the similarity matrices The algorithm used for solving the problem is conditional gradient as discussed in [1]_ - Parameters ---------- M : ndarray, shape (ns, nt) @@ -541,14 +483,12 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 Record log if True. **kwargs : dict Parameters can be directly pased to the ot.optim.cg solver. - Returns ------- gamma : ndarray, shape (ns, nt) Optimal transportation matrix for the given parameters. log : dict Log dictionary return only if log==True in parameters. - References ---------- .. 
[24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain @@ -581,19 +521,14 @@ def GW_distance_estimation(C1, C2, p, q, loss_fun, T, r""" Returns an approximation of the gromov-wasserstein cost between (C1,p) and (C2,q) with a fixed transport plan T. - The function gives an unbiased approximation of the following equation: - .. math:: GW = \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} - Where : - - C1 : Metric cost matrix in the source space - C2 : Metric cost matrix in the target space - L : Loss function to account for the misfit between the similarity matrices - T : Matrix with marginal p and q - Parameters ---------- C1 : ndarray, shape (ns, ns) @@ -616,38 +551,37 @@ def GW_distance_estimation(C1, C2, p, q, loss_fun, T, Standard deviation associated with the prediction of the gromov-wasserstein cost. random_state : int or RandomState instance, optional Fix the seed for to allow reproducibility - Returns ------- : float Gromov-wasserstein cost - References ---------- .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc "Sampled Gromov Wasserstein." Machine Learning Journal (MLJ). 2021. - """ generator = check_random_state(random_state) - I = len(p) - K = len(q) + len_p = len(p) + len_q = len(q) + # It is always better to sample from the biggest distribution first. - if I < K: + if len_p < len_q: p, q = q, p + len_p, len_q = len_q, len_p C1, C2 = C2, C1 T = T.T if nb_samples_p is None: if issparse(T): # If T is sparse, it probably mean that PoGroW was used, thus the number of sample is reduced - nb_samples_p = min(int(5 * (I * np.log(I)) ** 0.5), I) + nb_samples_p = min(int(5 * (len_p * np.log(len_p)) ** 0.5), len_p) else: - nb_samples_p = I + nb_samples_p = len_p else: # The number of sample along the first dimension is without replacement. - nb_samples_p = min(nb_samples_p, I) + nb_samples_p = min(nb_samples_p, len_p) if nb_samples_q is None: nb_samples_q = 1 if std: @@ -657,8 +591,8 @@ def GW_distance_estimation(C1, C2, p, q, loss_fun, T, index_l = np.zeros((nb_samples_p, nb_samples_q), dtype=int) list_value_sample = np.zeros((nb_samples_p, nb_samples_p, nb_samples_q)) - index_i = generator.choice(I, size=nb_samples_p, p=p, replace=False) - index_j = generator.choice(I, size=nb_samples_p, p=p, replace=False) + index_i = generator.choice(len_p, size=nb_samples_p, p=p, replace=False) + index_j = generator.choice(len_p, size=nb_samples_p, p=p, replace=False) for i in range(nb_samples_p): if issparse(T): @@ -668,8 +602,8 @@ def GW_distance_estimation(C1, C2, p, q, loss_fun, T, T_indexi = T[index_i[i], :] T_indexj = T[index_j[i], :] # For each of the row sampled, the column is sampled. - index_k[i] = generator.choice(K, size=nb_samples_q, p=T_indexi / T_indexi.sum(), replace=True) - index_l[i] = generator.choice(K, size=nb_samples_q, p=T_indexj / T_indexj.sum(), replace=True) + index_k[i] = generator.choice(len_q, size=nb_samples_q, p=T_indexi / T_indexi.sum(), replace=True) + index_l[i] = generator.choice(len_q, size=nb_samples_q, p=T_indexj / T_indexj.sum(), replace=True) for n in range(nb_samples_q): list_value_sample[:, :, n] = loss_fun(C1[np.ix_(index_i, index_j)], C2[np.ix_(index_k[:, n], index_l[:, n])]) @@ -686,26 +620,18 @@ def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun, r""" Returns the gromov-wasserstein transport between (C1,p) and (C2,q) using a stochastic Frank-Wolfe. This method as a O(max_iter \times PN^2) time complexity with P the number of Sinkhorn iterations. - The function solves the following optimization problem: - .. 
math:: GW = arg\min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} - s.t. T 1 = p - T^T 1= q - T\geq 0 - Where : - - C1 : Metric cost matrix in the source space - C2 : Metric cost matrix in the target space - p : distribution in the source space - q : distribution in the target space - L : loss function to account for the misfit between the similarity matrices - Parameters ---------- C1 : ndarray, shape (ns, ns) @@ -730,40 +656,37 @@ def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun, Gives the distance estimated and the standard deviation random_state : int or RandomState instance, optional Fix the seed for to allow reproducibility - Returns ------- T : ndarray, shape (ns, nt) Optimal coupling between the two spaces - References ---------- .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc "Sampled Gromov Wasserstein." Machine Learning Journal (MLJ). 2021. - """ C1 = np.asarray(C1, dtype=np.float64) C2 = np.asarray(C2, dtype=np.float64) p = np.asarray(p, dtype=np.float64) q = np.asarray(q, dtype=np.float64) - I = len(p) - K = len(q) + len_p = len(p) + len_q = len(q) generator = check_random_state(random_state) index = np.zeros(2, dtype=int) # Initialize with default marginal - index[0] = generator.choice(I, size=1, p=p) - index[1] = generator.choice(K, size=1, p=q) + index[0] = generator.choice(len_p, size=1, p=p) + index[1] = generator.choice(len_q, size=1, p=q) T = emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False).tocsr() best_gw_dist_estimated = np.inf for cpt in range(max_iter): - index[0] = generator.choice(I, size=1, p=p) + index[0] = generator.choice(len_p, size=1, p=p) T_index0 = T[index[0], :].toarray()[0] - index[1] = generator.choice(K, size=1, p=T_index0 / T_index0.sum()) + index[1] = generator.choice(len_q, size=1, p=T_index0 / T_index0.sum()) if alpha == 1: T = emd_1d(C1[index[0]], C2[index[1]], a=p, b=q, dense=False).tocsr() @@ -802,26 +725,18 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, r""" Returns the gromov-wasserstein transport between (C1,p) and (C2,q) using a 1-stochastic Frank-Wolfe. This method as a O(max_iter \times Nlog(N)) time complexity by relying on the 1D Optimal Transport solver. - The function solves the following optimization problem: - .. math:: GW = arg\min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} - s.t. T 1 = p - T^T 1= q - T\geq 0 - Where : - - C1 : Metric cost matrix in the source space - C2 : Metric cost matrix in the target space - p : distribution in the source space - q : distribution in the target space - L : loss function to account for the misfit between the similarity matrices - Parameters ---------- C1 : ndarray, shape (ns, ns) @@ -846,12 +761,10 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, Gives the distance estimated and the standard deviation random_state : int or RandomState instance, optional Fix the seed for to allow reproducibility - Returns ------- T : ndarray, shape (ns, nt) Optimal coupling between the two spaces - References ---------- .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc @@ -862,17 +775,17 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, C2 = np.asarray(C2, dtype=np.float64) p = np.asarray(p, dtype=np.float64) q = np.asarray(q, dtype=np.float64) - I = len(p) - K = len(q) + len_p = len(p) + len_q = len(q) generator = check_random_state(random_state) # The most natural way to define nb_sample is with a simple integer. 
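     # (a pair (nb_samples_grad_p, nb_samples_grad_q) can also be passed directly and is
     # unpacked as-is in the else branch below)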
if isinstance(nb_samples_grad, int): - if nb_samples_grad > I: + if nb_samples_grad > len_p: # As the sampling along the first dimension is done without replacement, the rest is reported to the second # dimension. - nb_samples_grad_p, nb_samples_grad_q = I, nb_samples_grad // I + nb_samples_grad_p, nb_samples_grad_q = len_p, nb_samples_grad // len_p else: nb_samples_grad_p, nb_samples_grad_q = nb_samples_grad, 1 else: @@ -885,10 +798,10 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, C_are_symmetric = np.allclose(C1, C1.T, rtol=1e-10, atol=1e-10) and np.allclose(C2, C2.T, rtol=1e-10, atol=1e-10) for cpt in range(max_iter): - index0 = generator.choice(I, size=nb_samples_grad_p, p=p, replace=False) + index0 = generator.choice(len_p, size=nb_samples_grad_p, p=p, replace=False) Lik = 0 for i, index0_i in enumerate(index0): - index1 = generator.choice(K, + index1 = generator.choice(len_q, size=nb_samples_grad_q, p=T[index0_i, :] / T[index0_i, :].sum(), replace=False) @@ -949,20 +862,13 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False): r""" Returns the gromov-wasserstein transport between (C1,p) and (C2,q) - (C1,p) and (C2,q) - The function solves the following optimization problem: - .. math:: GW = arg\min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T)) - s.t. T 1 = p - T^T 1= q - T\geq 0 - Where : - C1 : Metric cost matrix in the source space - C2 : Metric cost matrix in the target space @@ -970,7 +876,6 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, - q : distribution in the target space - L : loss function to account for the misfit between the similarity matrices - H : entropy - Parameters ---------- C1 : ndarray, shape (ns, ns) @@ -993,18 +898,15 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, Print information along iterations log : bool, optional Record log if True. - Returns ------- T : ndarray, shape (ns, nt) Optimal coupling between the two spaces - References ---------- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. - """ C1 = np.asarray(C1, dtype=np.float64) @@ -1056,14 +958,10 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False): r""" Returns the entropic gromov-wasserstein discrepancy between the two measured similarity matrices - (C1,p) and (C2,q) - The function solves the following optimization problem: - .. math:: GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T)) - Where : - C1 : Metric cost matrix in the source space - C2 : Metric cost matrix in the target space @@ -1071,7 +969,6 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, - q : distribution in the target space - L : loss function to account for the misfit between the similarity matrices - H : entropy - Parameters ---------- C1 : ndarray, shape (ns, ns) @@ -1094,18 +991,15 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, Print information along iterations log : bool, optional Record log if True. - Returns ------- gw_dist : float Gromov-Wasserstein distance - References ---------- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. 
- """ gw, logv = entropic_gromov_wasserstein( C1, C2, p, q, loss_fun, epsilon, max_iter, tol, verbose, log=True) @@ -1122,20 +1016,13 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None): r""" Returns the gromov-wasserstein barycenters of S measured similarity matrices - (Cs)_{s=1}^{s=S} - The function solves the following optimization problem: - .. math:: C = argmin_{C\in R^{NxN}} \sum_s \lambda_s GW(C,C_s,p,p_s) - - Where : - - :math:`C_s` : metric cost matrix - :math:`p_s` : distribution - Parameters ---------- N : int @@ -1165,12 +1052,10 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, Record log if True. init_C : bool | ndarray, shape (N, N) Random initial value for the C matrix provided by user. - Returns ------- C : ndarray, shape (N, N) Similarity matrix in the barycenter space (permutated arbitrarily) - References ---------- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, @@ -1232,20 +1117,14 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None): r""" Returns the gromov-wasserstein barycenters of S measured similarity matrices - (Cs)_{s=1}^{s=S} - The function solves the following optimization problem with block coordinate descent: - .. math:: C = argmin_C\in R^NxN \sum_s \lambda_s GW(C,Cs,p,ps) - Where : - - Cs : metric cost matrix - ps : distribution - Parameters ---------- N : int @@ -1271,18 +1150,15 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, Record log if True. init_C : bool | ndarray, shape(N,N) Random initial value for the C matrix provided by user. - Returns ------- C : ndarray, shape (N, N) Similarity matrix in the barycenter space (permutated arbitrarily) - References ---------- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. - """ S = len(Cs) @@ -1338,7 +1214,6 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ p=None, loss_fun='square_loss', max_iter=100, tol=1e-9, verbose=False, log=False, init_C=None, init_X=None): """Compute the fgw barycenter as presented eq (5) in [24]. - Parameters ---------- N : integer @@ -1373,7 +1248,6 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ init_X : ndarray, shape (N,d), optional Initialization for the barycenters' features. If not set a random init is used. - Returns ------- X : ndarray, shape (N, d) @@ -1385,7 +1259,6 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ T : list of (N,ns) transport matrices Ms : all distance matrices between the feature of the barycenter and the other features dist(X,Ys) shape (N,ns) - References ---------- .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain @@ -1489,9 +1362,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ def update_sructure_matrix(p, lambdas, T, Cs): """Updates C according to the L2 Loss kernel with the S Ts couplings. - It is calculated at each iteration - Parameters ---------- p : ndarray, shape (N,) @@ -1502,7 +1373,6 @@ def update_sructure_matrix(p, lambdas, T, Cs): The S Ts couplings calculated at each iteration. Cs : list of S ndarray, shape (ns, ns) Metric cost matrices. 
- Returns ------- C : ndarray, shape (nt, nt) @@ -1516,11 +1386,8 @@ def update_sructure_matrix(p, lambdas, T, Cs): def update_feature_matrix(lambdas, Ys, Ts, p): """Updates the feature with respect to the S Ts couplings. - - See "Solving the barycenter problem with Block Coordinate Descent (BCD)" in [24] calculated at each iteration - Parameters ---------- p : ndarray, shape (N,) @@ -1531,11 +1398,9 @@ def update_feature_matrix(lambdas, Ys, Ts, p): the S Ts couplings calculated at each iteration Ys : list of S ndarray, shape(d,ns) The features. - Returns ------- X : ndarray, shape (d, N) - References ---------- .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain @@ -1547,4 +1412,4 @@ def update_feature_matrix(lambdas, Ys, Ts, p): tmpsum = sum([lambdas[s] * np.dot(Ys[s], Ts[s].T) * p[None, :] for s in range(len(Ts))]) - return tmpsum + return tmpsum \ No newline at end of file From 10df3f660438ad25b9025118c40a94f3c23aeb74 Mon Sep 17 00:00:00 2001 From: Tanguy Kerdoncuff Date: Fri, 10 Sep 2021 11:13:08 +0200 Subject: [PATCH 5/5] Re-add some deleted lines in the comments of gromov.py --- ot/gromov.py | 140 ++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 139 insertions(+), 1 deletion(-) diff --git a/ot/gromov.py b/ot/gromov.py index 24dd1e10a..a27217a9e 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -24,25 +24,31 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'): r"""Return loss matrices and tensors for Gromov-Wasserstein fast computation + Returns the value of \mathcal{L}(C1,C2) \otimes T with the selected loss function as the loss function of Gromow-Wasserstein discrepancy. + The matrices are computed as described in Proposition 1 in [12] + Where : * C1 : Metric cost matrix in the source space * C2 : Metric cost matrix in the target space * T : A coupling between those two spaces + The square-loss function L(a,b)=|a-b|^2 is read as : L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with : * f1(a)=(a^2) * f2(b)=(b^2) * h1(a)=a * h2(b)=2*b + The kl-loss function L(a,b)=a*log(a/b)-a+b is read as : L(a,b) = f1(a)+f2(b)-h1(a)*h2(b) with : * f1(a)=a*log(a)-a * f2(b)=b * h1(a)=a * h2(b)=log(b) + Parameters ---------- C1 : ndarray, shape (ns, ns) @@ -52,6 +58,7 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'): T : ndarray, shape (ns, nt) Coupling between source and target spaces p : ndarray, shape (ns,) + Returns ------- constC : ndarray, shape (ns, nt) @@ -60,11 +67,13 @@ def init_matrix(C1, C2, p, q, loss_fun='square_loss'): h1(C1) matrix in Eq. (6) hC2 : ndarray, shape (nt, nt) h2(C) matrix in Eq. (6) + References ---------- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. + """ if loss_fun == 'square_loss': @@ -105,7 +114,9 @@ def h2(b): def tensor_product(constC, hC1, hC2, T): r"""Return the tensor for Gromov-Wasserstein fast computation + The tensor is computed as described in Proposition 1 Eq. (6) in [12]. + Parameters ---------- constC : ndarray, shape (ns, nt) @@ -114,15 +125,18 @@ def tensor_product(constC, hC1, hC2, T): h1(C1) matrix in Eq. (6) hC2 : ndarray, shape (nt, nt) h2(C) matrix in Eq. (6) + Returns ------- tens : ndarray, shape (ns, nt) \mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result + References ---------- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. 
+ """ A = -np.dot(hC1, T).dot(hC2.T) tens = constC + A @@ -132,7 +146,9 @@ def tensor_product(constC, hC1, hC2, T): def gwloss(constC, hC1, hC2, T): """Return the Loss for Gromov-Wasserstein + The loss is computed as described in Proposition 1 Eq. (6) in [12]. + Parameters ---------- constC : ndarray, shape (ns, nt) @@ -143,15 +159,18 @@ def gwloss(constC, hC1, hC2, T): h2(C) matrix in Eq. (6) T : ndarray, shape (ns, nt) Current value of transport matrix T + Returns ------- loss : float Gromov Wasserstein loss + References ---------- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. + """ tens = tensor_product(constC, hC1, hC2, T) @@ -161,7 +180,9 @@ def gwloss(constC, hC1, hC2, T): def gwggrad(constC, hC1, hC2, T): """Return the gradient for Gromov-Wasserstein + The gradient is computed as described in Proposition 2 in [12]. + Parameters ---------- constC : ndarray, shape (ns, nt) @@ -172,15 +193,18 @@ def gwggrad(constC, hC1, hC2, T): h2(C) matrix in Eq. (6) T : ndarray, shape (ns, nt) Current value of transport matrix T + Returns ------- grad : ndarray, shape (ns, nt) Gromov Wasserstein gradient + References ---------- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. + """ return 2 * tensor_product(constC, hC1, hC2, T) # [12] Prop. 2 misses a 2 factor @@ -190,6 +214,7 @@ def update_square_loss(p, lambdas, T, Cs): """ Updates C according to the L2 Loss kernel with the S Ts couplings calculated at each iteration + Parameters ---------- p : ndarray, shape (N,) @@ -200,6 +225,7 @@ def update_square_loss(p, lambdas, T, Cs): The S Ts couplings calculated at each iteration. Cs : list of S ndarray, shape(ns,ns) Metric cost matrices. + Returns ---------- C : ndarray, shape (nt, nt) @@ -215,6 +241,8 @@ def update_square_loss(p, lambdas, T, Cs): def update_kl_loss(p, lambdas, T, Cs): """ Updates C according to the KL Loss kernel with the S Ts couplings calculated at each iteration + + Parameters ---------- p : ndarray, shape (N,) @@ -224,6 +252,7 @@ def update_kl_loss(p, lambdas, T, Cs): The S Ts couplings calculated at each iteration. Cs : list of S ndarray, shape(ns,ns) Metric cost matrices. + Returns ---------- C : ndarray, shape (ns,ns) @@ -239,15 +268,19 @@ def update_kl_loss(p, lambdas, T, Cs): def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs): r""" Returns the gromov-wasserstein transport between (C1,p) and (C2,q) + The function solves the following optimization problem: + .. math:: GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + Where : - C1 : Metric cost matrix in the source space - C2 : Metric cost matrix in the target space - p : distribution in the source space - q : distribution in the target space - L : loss function to account for the misfit between the similarity matrices + Parameters ---------- C1 : ndarray, shape (ns, ns) @@ -260,6 +293,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs Distribution in the target space loss_fun : str loss function used for the solver either 'square_loss' or 'kl_loss' + max_iter : int, optional Max number of iterations tol : float, optional @@ -273,6 +307,7 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs If there is convergence issues use False. 
**kwargs : dict parameters can be directly passed to the ot.optim.cg solver + Returns ------- T : ndarray, shape (ns, nt) @@ -280,14 +315,17 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} log : dict Convergence information and loss. + References ---------- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. + .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the metric approach to object matching. Foundations of computational mathematics 11.4 (2011): 417-487. + """ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) @@ -311,15 +349,19 @@ def df(G): def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs): r""" Returns the gromov-wasserstein discrepancy between (C1,p) and (C2,q) + The function solves the following optimization problem: + .. math:: GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + Where : - C1 : Metric cost matrix in the source space - C2 : Metric cost matrix in the target space - p : distribution in the source space - q : distribution in the target space - L : loss function to account for the misfit between the similarity matrices + Parameters ---------- C1 : ndarray, shape (ns, ns) @@ -343,20 +385,24 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg armijo : bool, optional If True, the step of the line-search is found via an Armijo search. Else the closed form is used. If there are convergence issues, use False. + Returns ------- gw_dist : float Gromov-Wasserstein distance log : dict Convergence information and coupling matrix + References ---------- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. + .. [13] Mémoli, Facundo. Gromov–Wasserstein distances and the metric approach to object matching. Foundations of computational mathematics 11.4 (2011): 417-487. + """ constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) @@ -380,17 +426,22 @@ def df(G): def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): r""" Computes the FGW transport between two graphs, see [24] + .. math:: \gamma = arg\min_\gamma (1-\\alpha)*<\gamma,M>_F + \\alpha* \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + s.t. \gamma 1 = p \gamma^T 1= q \gamma\geq 0 + where : - M is the (ns,nt) metric cost matrix - p and q are source and target weights (sum to 1) - L is a loss function to account for the misfit between the similarity matrices + The algorithm used for solving the problem is conditional gradient as discussed in [24]_ + Parameters ---------- M : ndarray, shape (ns, nt) @@ -414,12 +465,14 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, record log if True **kwargs : dict parameters can be directly passed to the ot.optim.cg solver + Returns ------- gamma : ndarray, shape (ns, nt) Optimal transportation matrix for the given parameters. log : dict Log dictionary returned only if log==True in parameters. + References ---------- ..
[24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain @@ -449,17 +502,22 @@ def df(G): def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs): r""" Computes the FGW distance between two graphs, see [24] + .. math:: \min_\gamma (1-\\alpha)*<\gamma,M>_F + \\alpha* \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + + s.t. \gamma 1 = p \gamma^T 1= q \gamma\geq 0 + where : - M is the (ns,nt) metric cost matrix - p and q are source and target weights (sum to 1) - L is a loss function to account for the misfit between the similarity matrices The algorithm used for solving the problem is conditional gradient as discussed in [24]_ + Parameters ---------- M : ndarray, shape (ns, nt) @@ -483,12 +541,14 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 Record log if True. **kwargs : dict Parameters can be directly passed to the ot.optim.cg solver. + Returns ------- gamma : ndarray, shape (ns, nt) Optimal transportation matrix for the given parameters. log : dict Log dictionary returned only if log==True in parameters. + References ---------- .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain @@ -521,14 +581,19 @@ def GW_distance_estimation(C1, C2, p, q, loss_fun, T, r""" Returns an approximation of the gromov-wasserstein cost between (C1,p) and (C2,q) with a fixed transport plan T. + The function gives an unbiased approximation of the following equation: + .. math:: GW = \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + Where : + - C1 : Metric cost matrix in the source space - C2 : Metric cost matrix in the target space - L : Loss function to account for the misfit between the similarity matrices - T : Matrix with marginal p and q + Parameters ---------- C1 : ndarray, shape (ns, ns) @@ -551,15 +616,18 @@ def GW_distance_estimation(C1, C2, p, q, loss_fun, T, Standard deviation associated with the prediction of the gromov-wasserstein cost. random_state : int or RandomState instance, optional Fix the seed to allow reproducibility + Returns ------- : float Gromov-wasserstein cost + References ---------- .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc "Sampled Gromov Wasserstein." Machine Learning Journal (MLJ). 2021. + """ generator = check_random_state(random_state) @@ -620,18 +688,26 @@ def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun, r""" Returns the gromov-wasserstein transport between (C1,p) and (C2,q) using a stochastic Frank-Wolfe. This method has an O(max_iter \times PN^2) time complexity, with P the number of Sinkhorn iterations. + The function solves the following optimization problem: + .. math:: GW = arg\min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + s.t. T 1 = p + T^T 1= q + T\geq 0 + Where : + - C1 : Metric cost matrix in the source space - C2 : Metric cost matrix in the target space - p : distribution in the source space - q : distribution in the target space - L : loss function to account for the misfit between the similarity matrices + Parameters ---------- C1 : ndarray, shape (ns, ns) @@ -656,15 +732,18 @@ def pointwise_gromov_wasserstein(C1, C2, p, q, loss_fun, Gives the distance estimated and the standard deviation random_state : int or RandomState instance, optional Fix the seed to allow reproducibility + Returns ------- T : ndarray, shape (ns, nt) Optimal coupling between the two spaces + References ---------- .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc "Sampled Gromov Wasserstein."
Machine Learning Journal (MLJ). 2021. + """ C1 = np.asarray(C1, dtype=np.float64) @@ -725,18 +804,26 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, r""" Returns the gromov-wasserstein transport between (C1,p) and (C2,q) using a 1-stochastic Frank-Wolfe. This method has an O(max_iter \times N\log(N)) time complexity by relying on the 1D Optimal Transport solver. + The function solves the following optimization problem: + .. math:: GW = arg\min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l} + s.t. T 1 = p + T^T 1= q + T\geq 0 + Where : + - C1 : Metric cost matrix in the source space - C2 : Metric cost matrix in the target space - p : distribution in the source space - q : distribution in the target space - L : loss function to account for the misfit between the similarity matrices + Parameters ---------- C1 : ndarray, shape (ns, ns) @@ -761,15 +848,18 @@ def sampled_gromov_wasserstein(C1, C2, p, q, loss_fun, Gives the distance estimated and the standard deviation random_state : int or RandomState instance, optional Fix the seed to allow reproducibility + Returns ------- T : ndarray, shape (ns, nt) Optimal coupling between the two spaces + References ---------- .. [14] Kerdoncuff, Tanguy, Emonet, Rémi, Sebban, Marc "Sampled Gromov Wasserstein." Machine Learning Journal (MLJ). 2021. + """ C1 = np.asarray(C1, dtype=np.float64) C2 = np.asarray(C2, dtype=np.float64) p = np.asarray(p, dtype=np.float64) q = np.asarray(q, dtype=np.float64) @@ -862,13 +952,20 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False): r""" Returns the gromov-wasserstein transport between (C1,p) and (C2,q) + (C1,p) and (C2,q) + The function solves the following optimization problem: + .. math:: GW = arg\min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T)) + s.t. T 1 = p + T^T 1= q + T\geq 0 + Where : - C1 : Metric cost matrix in the source space - C2 : Metric cost matrix in the target space @@ -876,6 +973,7 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, - q : distribution in the target space - L : loss function to account for the misfit between the similarity matrices - H : entropy + Parameters ---------- C1 : ndarray, shape (ns, ns) @@ -898,15 +996,18 @@ def entropic_gromov_wasserstein(C1, C2, p, q, loss_fun, epsilon, Print information along iterations log : bool, optional Record log if True. + Returns ------- T : ndarray, shape (ns, nt) Optimal coupling between the two spaces + References ---------- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. + """ C1 = np.asarray(C1, dtype=np.float64) @@ -958,10 +1059,14 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False): r""" Returns the entropic gromov-wasserstein discrepancy between the two measured similarity matrices + (C1,p) and (C2,q) + The function solves the following optimization problem: + ..
math:: GW = \min_T \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*T_{i,j}*T_{k,l}-\epsilon(H(T)) + Where : - C1 : Metric cost matrix in the source space - C2 : Metric cost matrix in the target space @@ -969,6 +1074,7 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, - q : distribution in the target space - L : loss function to account for the misfit between the similarity matrices - H : entropy + Parameters ---------- C1 : ndarray, shape (ns, ns) @@ -991,15 +1097,18 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, Print information along iterations log : bool, optional Record log if True. + Returns ------- gw_dist : float Gromov-Wasserstein distance + References ---------- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. + """ gw, logv = entropic_gromov_wasserstein( C1, C2, p, q, loss_fun, epsilon, max_iter, tol, verbose, log=True) @@ -1016,13 +1125,20 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None): r""" Returns the gromov-wasserstein barycenters of S measured similarity matrices + (Cs)_{s=1}^{s=S} + The function solves the following optimization problem: + .. math:: C = argmin_{C\in R^{NxN}} \sum_s \lambda_s GW(C,C_s,p,p_s) + + Where : + - :math:`C_s` : metric cost matrix - :math:`p_s` : distribution + Parameters ---------- N : int @@ -1052,10 +1168,12 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, Record log if True. init_C : bool | ndarray, shape (N, N) Random initial value for the C matrix provided by user. + Returns ------- C : ndarray, shape (N, N) Similarity matrix in the barycenter space (permutated arbitrarily) + References ---------- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, @@ -1117,14 +1235,20 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, max_iter=1000, tol=1e-9, verbose=False, log=False, init_C=None): r""" Returns the gromov-wasserstein barycenters of S measured similarity matrices + (Cs)_{s=1}^{s=S} + The function solves the following optimization problem with block coordinate descent: + .. math:: C = argmin_C\in R^NxN \sum_s \lambda_s GW(C,Cs,p,ps) + Where : + - Cs : metric cost matrix - ps : distribution + Parameters ---------- N : int @@ -1150,15 +1274,18 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, Record log if True. init_C : bool | ndarray, shape(N,N) Random initial value for the C matrix provided by user. + Returns ------- C : ndarray, shape (N, N) Similarity matrix in the barycenter space (permutated arbitrarily) + References ---------- .. [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, "Gromov-Wasserstein averaging of kernel and distance matrices." International Conference on Machine Learning (ICML). 2016. + """ S = len(Cs) @@ -1214,6 +1341,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ p=None, loss_fun='square_loss', max_iter=100, tol=1e-9, verbose=False, log=False, init_C=None, init_X=None): """Compute the fgw barycenter as presented eq (5) in [24]. + Parameters ---------- N : integer @@ -1248,6 +1376,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ init_X : ndarray, shape (N,d), optional Initialization for the barycenters' features. If not set a random init is used. 
+ Returns ------- X : ndarray, shape (N, d) @@ -1259,6 +1388,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ T : list of (N,ns) transport matrices Ms : all distance matrices between the feature of the barycenter and the other features dist(X,Ys) shape (N,ns) + References ---------- .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain @@ -1362,7 +1492,9 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ def update_sructure_matrix(p, lambdas, T, Cs): """Updates C according to the L2 Loss kernel with the S Ts couplings. + It is calculated at each iteration + Parameters ---------- p : ndarray, shape (N,) @@ -1373,6 +1505,7 @@ def update_sructure_matrix(p, lambdas, T, Cs): The S Ts couplings calculated at each iteration. Cs : list of S ndarray, shape (ns, ns) Metric cost matrices. + Returns ------- C : ndarray, shape (nt, nt) @@ -1386,8 +1519,11 @@ def update_sructure_matrix(p, lambdas, T, Cs): def update_feature_matrix(lambdas, Ys, Ts, p): """Updates the feature with respect to the S Ts couplings. + + See "Solving the barycenter problem with Block Coordinate Descent (BCD)" in [24] calculated at each iteration + Parameters ---------- p : ndarray, shape (N,) @@ -1398,9 +1534,11 @@ def update_feature_matrix(lambdas, Ys, Ts, p): the S Ts couplings calculated at each iteration Ys : list of S ndarray, shape(d,ns) The features. + Returns ------- X : ndarray, shape (d, N) + References ---------- .. [24] Vayer Titouan, Chapel Laetitia, Flamary R{\'e}mi, Tavenard Romain @@ -1412,4 +1550,4 @@ def update_feature_matrix(lambdas, Ys, Ts, p): tmpsum = sum([lambdas[s] * np.dot(Ys[s], Ts[s].T) * p[None, :] for s in range(len(Ts))]) - return tmpsum \ No newline at end of file + return tmpsum
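Editor's note: below is a minimal usage sketch of the estimator added by this patch series (an illustration, not part of the patch itself). The point clouds, their sizes, and the absolute-difference loss are assumptions; only the positional call GW_distance_estimation(C1, C2, p, q, loss_fun, T) is taken from the docstring above.

# Minimal sketch: estimate the GW cost of a fixed coupling T with
# GW_distance_estimation. All data below is hypothetical.
import numpy as np
import ot

rng = np.random.RandomState(42)
xs = rng.randn(30, 2)              # hypothetical source point cloud
xt = rng.randn(40, 2)              # hypothetical target point cloud
C1 = ot.dist(xs, xs)               # intra-source cost matrix, shape (ns, ns)
C2 = ot.dist(xt, xt)               # intra-target cost matrix, shape (nt, nt)
p = ot.unif(30)                    # uniform distribution on the source
q = ot.unif(40)                    # uniform distribution on the target


def loss(a, b):
    # Any pointwise loss is accepted; the plan T does not depend on it.
    return np.abs(a - b)


T = np.outer(p, q)                 # independent coupling, marginals p and q
gw_cost = ot.gromov.GW_distance_estimation(C1, C2, p, q, loss, T)
print(gw_cost)

Any coupling with marginals p and q can be evaluated this way; np.outer(p, q) is simply the easiest valid choice, and a plan returned by pointwise_gromov_wasserstein or sampled_gromov_wasserstein can be passed instead.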