From 1c99c285d6716bbafdab44665294bb377f71b129 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Sun, 21 Jul 2024 21:16:57 +0200 Subject: [PATCH 01/22] init commit srgw bary --- ot/gromov/__init__.py | 7 +- ot/gromov/_semirelaxed.py | 178 ++++++++++++++++++++++++++++++++ test/gromov/test_semirelaxed.py | 107 +++++++++++++++++++ 3 files changed, 290 insertions(+), 2 deletions(-) diff --git a/ot/gromov/__init__.py b/ot/gromov/__init__.py index 675f42ccb..748b1a97b 100644 --- a/ot/gromov/__init__.py +++ b/ot/gromov/__init__.py @@ -42,6 +42,7 @@ entropic_semirelaxed_gromov_wasserstein2, entropic_semirelaxed_fused_gromov_wasserstein, entropic_semirelaxed_fused_gromov_wasserstein2, + semirelaxed_gromov_barycenters, semirelaxed_fgw_barycenters) from ._dictionary import (gromov_wasserstein_dictionary_learning, @@ -78,11 +79,13 @@ 'semirelaxed_fused_gromov_wasserstein', 'semirelaxed_fused_gromov_wasserstein2', 'solve_semirelaxed_gromov_linesearch', 'entropic_semirelaxed_gromov_wasserstein', 'entropic_semirelaxed_gromov_wasserstein2', 'entropic_semirelaxed_fused_gromov_wasserstein', - 'entropic_semirelaxed_fused_gromov_wasserstein2', 'gromov_wasserstein_dictionary_learning', + 'entropic_semirelaxed_fused_gromov_wasserstein2', + 'semirelaxed_fgw_barycenters', 'semirelaxed_gromov_barycenters', + 'gromov_wasserstein_dictionary_learning', 'gromov_wasserstein_linear_unmixing', 'fused_gromov_wasserstein_dictionary_learning', 'fused_gromov_wasserstein_linear_unmixing', 'lowrank_gromov_wasserstein_samples', 'quantized_fused_gromov_wasserstein_partitioned', 'get_graph_partition', 'get_graph_representants', 'format_partitioned_graph', 'quantized_fused_gromov_wasserstein', 'get_partition_and_representants_samples', 'format_partitioned_samples', 'quantized_fused_gromov_wasserstein_samples', - 'semirelaxed_fgw_barycenters'] + ] diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py index a777239d3..14401631c 100644 --- a/ot/gromov/_semirelaxed.py +++ b/ot/gromov/_semirelaxed.py @@ -1107,6 +1107,184 @@ def entropic_semirelaxed_fused_gromov_wasserstein2( return log_srfgw['srfgw_dist'] +def semirelaxed_gromov_barycenters( + N, Cs, ps=None, lambdas=None, loss_fun='square_loss', + symmetric=True, max_iter=1000, tol=1e-9, + stop_criterion='barycenter', warmstartT=False, verbose=False, + log=False, init_C=None, random_state=None, **kwargs): + r""" + Returns the Semi-relaxed Gromov-Wasserstein barycenters of `S` measured similarity matrices + :math:`(\mathbf{C}_s)_{1 \leq s \leq S}` + + The function solves the following optimization problem with block coordinate descent: + + .. math:: + + \mathbf{C}^* = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{srGW}(\mathbf{C}_s, \mathbf{p}_s, \mathbf{C}) + + Where : + + - :math:`\mathbf{C}_s`: input metric cost matrix + - :math:`\mathbf{p}_s`: distribution + + Parameters + ---------- + N : int + Size of the targeted barycenter + Cs : list of S array-like of shape (ns, ns) + Metric cost matrices + ps : list of S array-like of shape (ns,), optional + Sample weights in the `S` spaces. + If let to its default value None, uniform distributions are taken. + lambdas : list of float, optional + List of the `S` spaces' weights. + If let to its default value None, uniform weights are taken. + loss_fun : callable, optional + tensor-matrix multiplication function based on specific loss function + symmetric : bool, optional. + Either structures are to be assumed symmetric or not. Default value is True. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). + max_iter : int, optional + Max number of iterations + tol : float, optional + Stop threshold on relative error (>0) + stop_criterion : str, optional. Default is 'barycenter'. + Stop criterion taking values in ['barycenter', 'loss']. If set to 'barycenter' + uses absolute norm variations of estimated barycenters. Else if set to 'loss' + uses the relative variations of the loss. + warmstartT: bool, optional + Either to perform warmstart of transport plans in the successive + fused gromov-wasserstein transport problems.s + verbose : bool, optional + Print information along iterations. + log : bool, optional + Record log if True. + init_C : bool | array-like, shape(N,N) + Random initial value for the :math:`\mathbf{C}` matrix provided by user. + random_state : int or RandomState instance, optional + Fix the seed for reproducibility + + Returns + ------- + C : array-like, shape (`N`, `N`) + Barycenters' structure matrix + log : dict + Only returned when log=True. It contains the keys: + + - :math:`\mathbf{T}`: list of (`N`, `ns`) transport matrices + - :math:`\mathbf{p}`: (`N`,) barycenter weights + - values used in convergence evaluation. + + References + ---------- + .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. + "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" + International Conference on Learning Representations (ICLR), 2022. + + """ + if stop_criterion not in ['barycenter', 'loss']: + raise ValueError(f"Unknown `stop_criterion='{stop_criterion}'`. Use one of: {'barycenter', 'loss'}.") + + arr = [*Cs] + if ps is not None: + arr += [*ps] + else: + ps = [unif(C.shape[0], type_as=C) for C in Cs] + + nx = get_backend(*arr) + + S = len(Cs) + if lambdas is None: + lambdas = [1. / S] * S + + # Initialization of C : random SPD matrix (if not provided by user) + if init_C is None: + rng = check_random_state(random_state) + xalea = rng.randn(N, 2) + C = dist(xalea, xalea) + C /= C.max() + C = nx.from_numpy(C, type_as=Cs[0]) + else: + C = init_C + + if warmstartT: + T = [None] * S + + if stop_criterion == 'barycenter': + inner_log = False + else: + inner_log = True + curr_loss = 1e15 + + if log: + log_ = {} + log_['err'] = [] + if stop_criterion == 'loss': + log_['loss'] = [] + + for cpt in range(max_iter): + + if stop_criterion == 'barycenter': + Cprev = C + else: + prev_loss = curr_loss + + # get transport plans + if warmstartT: + res = [semirelaxed_gromov_wasserstein( + Cs[s], C, ps[s], loss_fun, symmetric, G0=T[s], + max_iter=max_iter, tol_rel=1e-5, tol_abs=0., log=inner_log, + verbose=verbose, **kwargs) + for s in range(S)] + else: + res = [semirelaxed_gromov_wasserstein( + Cs[s], C, ps[s], loss_fun, symmetric, G0=None, + max_iter=max_iter, tol_rel=1e-5, tol_abs=0., log=inner_log, + verbose=verbose, **kwargs) + for s in range(S)] + + if stop_criterion == 'barycenter': + T = res + else: + T = [output[0] for output in res] + curr_loss = np.sum([output[1]['srgw_dist'] for output in res]) + + # update barycenters + p = nx.concatenate( + [nx.sum(T[s], 0)[None, :] for s in range(S)], axis=0) + + C = update_barycenter_structure(T, Cs, lambdas, p, loss_fun, nx=nx) + + # update convergence criterion + if stop_criterion == 'barycenter': + err = nx.norm(C - Cprev) + if log: + log_['err'].append(err) + + else: + err = abs(curr_loss - prev_loss) / prev_loss if prev_loss != 0. else np.nan + if log: + log_['loss'].append(curr_loss) + log_['err'].append(err) + + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + + if err <= tol: + break + + if log: + log_['T'] = T + log_['p'] = p + + return C, log_ + else: + return C + + def semirelaxed_fgw_barycenters( N, Ys, Cs, ps=None, lambdas=None, alpha=0.5, fixed_structure=False, fixed_features=False, p=None, loss_fun='square_loss', diff --git a/test/gromov/test_semirelaxed.py b/test/gromov/test_semirelaxed.py index 2e4b2f128..8e6805d41 100644 --- a/test/gromov/test_semirelaxed.py +++ b/test/gromov/test_semirelaxed.py @@ -615,6 +615,113 @@ def test_entropic_semirelaxed_fgw_dtype_device(nx): nx.assert_same_dtype_device(C1b, fgw_valb) +def test_semirelaxed_gromov_barycenter(nx): + ns = 5 + nt = 8 + + Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) + Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) + + C1 = ot.dist(Xs) + C2 = ot.dist(Xt) + p1 = ot.unif(ns) + p2 = ot.unif(nt) + n_samples = 3 + + C1b, C2b, p1b, p2b = nx.from_numpy(C1, C2, p1, p2) + + # test on admissible stopping criterion + with pytest.raises(ValueError): + stop_criterion = 'unknown stop criterion' + _ = ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1, C2], None, [.5, .5], 'square_loss', max_iter=10, + tol=1e-3, stop_criterion=stop_criterion, verbose=False, + random_state=42 + ) + + # test consistency of outputs across backends with 'square_loss' + for stop_criterion in ['barycenter', 'loss']: + Cb = ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1, C2], None, [.5, .5], 'square_loss', max_iter=10, + tol=1e-3, stop_criterion=stop_criterion, verbose=False, + random_state=42 + ) + Cbb = nx.to_numpy(ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', + max_iter=10, tol=1e-3, stop_criterion=stop_criterion, + verbose=False, random_state=42 + )) + np.testing.assert_allclose(Cb, Cbb, atol=1e-06) + np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) + + # test of gromov_barycenters with `log` on + Cb_, err_ = ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1, C2], [p1, p2], None, 'square_loss', max_iter=10, + tol=1e-3, stop_criterion=stop_criterion, verbose=False, + warmstartT=True, random_state=42, log=True + ) + Cbb_, errb_ = ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', + max_iter=10, tol=1e-3, stop_criterion=stop_criterion, + verbose=False, warmstartT=True, random_state=42, log=True + ) + + Cbb_ = nx.to_numpy(Cbb_) + + np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06) + np.testing.assert_array_almost_equal(err_['err'], nx.to_numpy(*errb_['err'])) + np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples)) + + # test consistency across backends with 'kl_loss' + Cb2 = ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1, C2], [p1, p2], [.5, .5], + 'kl_loss', max_iter=10, tol=1e-3, warmstartT=True, random_state=42 + ) + Cb2b = nx.to_numpy(ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], [.5, .5], + 'kl_loss', max_iter=10, tol=1e-3, warmstartT=True, random_state=42 + )) + np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06) + np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples)) + + # test of gromov_barycenters with `log` on + # providing init_C similarly than in the function. + rng = ot.utils.check_random_state(42) + xalea = rng.randn(n_samples, 2) + init_C = ot.utils.dist(xalea, xalea) + init_C /= init_C.max() + init_Cb = nx.from_numpy(init_C) + + Cb2_, err2_ = ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1, C2], [p1, p2], [.5, .5], 'kl_loss', max_iter=10, + tol=1e-3, verbose=False, random_state=42, log=True, init_C=init_C + ) + Cb2b_, err2b_ = ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], [.5, .5], 'kl_loss', + max_iter=10, tol=1e-3, verbose=True, random_state=42, + init_C=init_Cb, log=True + ) + Cb2b_ = nx.to_numpy(Cb2b_) + np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06) + np.testing.assert_array_almost_equal(err2_['err'], nx.to_numpy(*err2b_['err'])) + np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples)) + + # test edge cases for gw barycenters: + # unique input structure + Cb = ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1], None, None, 'square_loss', max_iter=1, + tol=1e-3, stop_criterion=stop_criterion, verbose=False, + random_state=42 + ) + Cbb = nx.to_numpy(ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1b], None, [1.], 'square_loss', + max_iter=1, tol=1e-3, stop_criterion=stop_criterion, + verbose=False, random_state=42 + )) + np.testing.assert_allclose(Cb, Cbb, atol=1e-06) + np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) + + def test_semirelaxed_fgw_barycenter(nx): ns = 10 nt = 20 From 4de90e7a9e6deff99931bf107e751a2229c773e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Mon, 22 Jul 2024 00:29:50 +0200 Subject: [PATCH 02/22] start exemple --- ...mirelaxed_gromov_wasserstein_barycenter.py | 149 ++++++++++++++++++ ot/gromov/_semirelaxed.py | 19 +-- 2 files changed, 159 insertions(+), 9 deletions(-) create mode 100644 examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.py diff --git a/examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.py b/examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.py new file mode 100644 index 000000000..cef4ba9a7 --- /dev/null +++ b/examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.py @@ -0,0 +1,149 @@ +# -*- coding: utf-8 -*- + +r""" +===================================================== +Semi-relaxed (Fused) Gromov-Wasserstein Barycenter as Dictionary Learning +===================================================== + +In this example, we illustrate how to learn a semi-relaxed Gromov-Wasserstein +(srGW) barycenter using a Block-Coordinate Descent algorithm, on a dataset of +structured data such as graphs, denoted :math:`\{ \mathbf{C_s} \}_{s \in [S]}` +where every nodes have uniform weights :math:`\{ \mathbf{p_s} \}_{s \in [S]}`. +Given a barycenter structure matrix :math:`\mathbf{C}` with N nodes, +each graph :math:`(\mathbf{C_s}, \mathbf{p_s})` is modeled as a reweighed subgraph +with structure :math:`\mathbf{C}` and weights :math:`\mathbf{w_s} \in \Sigma_N` +where each :math:`\mathbf{w_s}` corresponds to the second marginal of the OT +:math:`\mathbf{T_s}` (s.t :math:`\mathbf{w_s} = \mathbf{T_s}^\top \mathbf{1}`) +minimizing the srGW loss between the s^{th} input and the barycenter. + + +First, we consider a dataset composed of graphs generated by Stochastic Block models +with variable sizes taken in :math:`\{30, ... , 50\}` and quantities of clusters +varying in :math:`\{ 1, 2, 3\}`. We learn a srGW barycenter with 6 nodes and +visualize the learned structure and the embeddings for some inputs. + +Second, we illustrate the extension of this framework to graphs endowed +with node features by using the semi-relaxed Fused Gromov-Wasserstein +divergence (srFGW). Starting from the aforementioned dataset of unattributed graphs, we +add discrete labels uniformly depending on the number of clusters. Then conduct +the analog analysis. + + +[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. +"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs". +International Conference on Learning Representations (ICLR), 2022. + +""" +# Author: Cédric Vincent-Cuaz +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 4 + +import numpy as np +import matplotlib.pylab as pl +from sklearn.manifold import MDS +from ot.gromov import ( + semirelaxed_gromov_barycenters, semirelaxed_fgw_barycenters) +import ot +import networkx +from networkx.generators.community import stochastic_block_model as sbm + +############################################################################# +# +# Generate a dataset composed of graphs following Stochastic Block models of 1, 2 and 3 clusters. +# ----------------------------------------------------------------------------------------------- + +np.random.seed(42) + +n_samples = 60 # number of graphs in the dataset +# For every number of clusters, we generate SBM with fixed inter/intra-clusters probability. +clusters = [1, 2, 3] +Nc = n_samples // len(clusters) # number of graphs by cluster +nlabels = len(clusters) +dataset = [] +labels = [] + +p_inter = 0.1 +p_intra = 0.9 +for n_cluster in clusters: + for i in range(Nc): + n_nodes = int(np.random.uniform(low=30, high=50)) + + if n_cluster > 1: + P = p_inter * np.ones((n_cluster, n_cluster)) + np.fill_diagonal(P, p_intra) + else: + P = p_intra * np.eye(1) + sizes = np.round(n_nodes * np.ones(n_cluster) / n_cluster).astype(np.int32) + G = sbm(sizes, P, seed=i, directed=False) + C = networkx.to_numpy_array(G) + dataset.append(C) + labels.append(n_cluster) + + +# Visualize samples + +def plot_graph(x, C, binary=True, color='C0', s=None): + for j in range(C.shape[0]): + for i in range(j): + if binary: + if C[i, j] > 0: + pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color='k') + else: # connection intensity proportional to C[i,j] + pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=C[i, j], color='k') + + pl.scatter(x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors='k', cmap='tab10', vmax=9) + + +pl.figure(1, (12, 8)) +pl.clf() +for idx_c, c in enumerate(clusters): + C = dataset[(c - 1) * Nc] # sample with c clusters + # get 2d position for nodes + x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C) + pl.subplot(2, nlabels, c) + pl.title('(graph) sample from label ' + str(c), fontsize=14) + plot_graph(x, C, binary=True, color='C0', s=50.) + pl.axis("off") + pl.subplot(2, nlabels, nlabels + c) + pl.title('(matrix) sample from label %s \n' % c, fontsize=14) + pl.imshow(C, interpolation='nearest') + pl.axis("off") +pl.tight_layout() +pl.show() + +############################################################################# +# +# Estimate the Gromov-Wasserstein dictionary from the dataset +# ----------------------------------------------------------- + + +np.random.seed(0) +ps = [ot.unif(C.shape[0]) for C in dataset] # uniform weights on nodes +lambdas = [1. / n_samples for _ in range(n_samples)] # uniform barycenter +N = 3 # 3 nodes in the barycenter + +# Optionally provide an initial barycenter structure `init_C` + +init_C = np.array([[0.6, 0.2, 0.2], + [0.2, 0.6, 0.2], + [0.2, 0.2, 0.6]]) + +print('init_C:', init_C) + +C, log = semirelaxed_gromov_barycenters( + N=N, Cs=dataset, ps=ps, lambdas=lambdas, loss_fun='square_loss', tol=1e-6, + stop_criterion='loss', warmstartT=True, log=True, init_C=init_C) + +# visualize loss evolution over epochs +pl.figure(2, (4, 3)) +pl.clf() +pl.title('loss evolution by iteration', fontsize=14) +pl.plot(log['loss']) +pl.xlabel('BCD iterations', fontsize=12) +pl.ylabel('loss', fontsize=12) +pl.tight_layout() +pl.show() + +print('C:', C) diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py index 14401631c..90fe0e9cb 100644 --- a/ot/gromov/_semirelaxed.py +++ b/ot/gromov/_semirelaxed.py @@ -11,13 +11,13 @@ import numpy as np -from ..utils import ( +from ot.utils import ( list_to_array, unif, dist, UndefinedParameter, check_random_state ) -from ..optim import semirelaxed_cg, solve_1d_linesearch_quad -from ..backend import get_backend +from ot.optim import semirelaxed_cg, solve_1d_linesearch_quad +from ot.backend import get_backend -from ._utils import ( +from ot.gromov import ( init_matrix_semirelaxed, gwloss, gwggrad, update_barycenter_structure, update_barycenter_feature, ) @@ -1233,13 +1233,13 @@ def semirelaxed_gromov_barycenters( if warmstartT: res = [semirelaxed_gromov_wasserstein( Cs[s], C, ps[s], loss_fun, symmetric, G0=T[s], - max_iter=max_iter, tol_rel=1e-5, tol_abs=0., log=inner_log, + max_iter=max_iter, tol_rel=tol, tol_abs=0., log=inner_log, verbose=verbose, **kwargs) for s in range(S)] else: res = [semirelaxed_gromov_wasserstein( Cs[s], C, ps[s], loss_fun, symmetric, G0=None, - max_iter=max_iter, tol_rel=1e-5, tol_abs=0., log=inner_log, + max_iter=max_iter, tol_rel=tol, tol_abs=0., log=inner_log, verbose=verbose, **kwargs) for s in range(S)] @@ -1254,7 +1254,8 @@ def semirelaxed_gromov_barycenters( [nx.sum(T[s], 0)[None, :] for s in range(S)], axis=0) C = update_barycenter_structure(T, Cs, lambdas, p, loss_fun, nx=nx) - + print('p:', p) + print('C:', C) # update convergence criterion if stop_criterion == 'barycenter': err = nx.norm(C - Cprev) @@ -1455,12 +1456,12 @@ def semirelaxed_fgw_barycenters( if warmstartT: res = [semirelaxed_fused_gromov_wasserstein( Ms[s], Cs[s], C, ps[s], loss_fun, symmetric, alpha, T[s], - inner_log, max_iter, tol_rel=1e-5, tol_abs=0., **kwargs) + inner_log, max_iter, tol_rel=tol, tol_abs=0., **kwargs) for s in range(S)] else: res = [semirelaxed_fused_gromov_wasserstein( Ms[s], Cs[s], C, ps[s], loss_fun, symmetric, alpha, None, - inner_log, max_iter, tol_rel=1e-5, tol_abs=0., **kwargs) + inner_log, max_iter, tol_rel=tol, tol_abs=0., **kwargs) for s in range(S)] if stop_criterion == 'barycenter': T = res From 7883c0e527f7cbaa646f57ac0245ae8da89a0a6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Thu, 25 Jul 2024 15:06:08 +0200 Subject: [PATCH 03/22] start integrating smart init for srgw --- ot/gromov/_semirelaxed.py | 9 +- ot/gromov/_utils.py | 176 +++++++++++++++++++++++++++++++++++++- 2 files changed, 180 insertions(+), 5 deletions(-) diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py index 90fe0e9cb..43723084f 100644 --- a/ot/gromov/_semirelaxed.py +++ b/ot/gromov/_semirelaxed.py @@ -11,15 +11,16 @@ import numpy as np -from ot.utils import ( +from ..utils import ( list_to_array, unif, dist, UndefinedParameter, check_random_state ) -from ot.optim import semirelaxed_cg, solve_1d_linesearch_quad -from ot.backend import get_backend +from ..optim import semirelaxed_cg, solve_1d_linesearch_quad +from ..backend import get_backend -from ot.gromov import ( +from ._utils import ( init_matrix_semirelaxed, gwloss, gwggrad, update_barycenter_structure, update_barycenter_feature, + _semirelaxed_init_plan ) diff --git a/ot/gromov/_utils.py b/ot/gromov/_utils.py index d4928d062..87e8301bc 100644 --- a/ot/gromov/_utils.py +++ b/ot/gromov/_utils.py @@ -12,8 +12,25 @@ # License: MIT License -from ..utils import list_to_array +from ..utils import list_to_array, euclidean_distances from ..backend import get_backend +from .lp import emd + +try: + from networkx.algorithms.community import asyn_fluidc + from networkx import from_numpy_array + networkx_import = True +except ImportError: + networkx_import = False + +try: + from sklearn.cluster import SpectralClustering, KMeans + sklearn_import = True +except ImportError: + sklearn_import = False + +import numpy as np +import warnings def init_matrix(C1, C2, p, q, loss_fun='square_loss', nx=None): @@ -302,6 +319,7 @@ def init_matrix_semirelaxed(C1, C2, p, loss_fun='square_loss', nx=None): C2 : array-like, shape (nt, nt) Metric cost matrix in the target space p : array-like, shape (ns,) + Probability distribution in the source space loss_fun : str, optional Name of loss function to use: either 'square_loss' or 'kl_loss' (default='square_loss') nx : backend, optional @@ -575,3 +593,159 @@ def update_barycenter_feature( inv_p = 1. / p_sum return sum(list_features) * inv_p[:, None] + + +def _semirelaxed_init_plan(C1, C2, p, method='product', random_state=0, nx=None): + """ + Heuristics to initialize the semi-relaxed (F)GW transport plan between a + graph :math:`(\mathbf{C1}, \mathbf{p})` and a structure matrix :math:`\mathbf{C2}`. + + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space. + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space. + p : array-like, shape (ns,) + Probability distribution in the source space. + method : str, optional + Method to initialize the transport plan. The default is 'product'. + random_state: int, optional + Random seed used for stochastic methods. + nx : backend, optional + POT backend. + + Returns + ------- + T : array-like, shape (ns, ns) + Admissible transport plan for the sr(F)GW problems. + + References + ---------- + .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. + "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" + International Conference on Learning Representations (ICLR), 2022. + + """ + list_partitioning_methods = [ + 'fluid', 'spectral', 'kmeans', 'fluid_soft', 'spectral_soft', + 'kmeans_soft'] + + if method not in list_partitioning_methods + ['product', 'random_product']: + raise ValueError(f'Unsupported initialization method = {method}.') + + if nx is None: + C1, C2, p = list_to_array(C1, C2, p) + nx = get_backend(C1, C2, p) + + n = C1.shape[0] + m = C2.shape[0] + if method in list_partitioning_methods: + min_size = min(n, m) + if n > m: # partition C1 to deduce map from C1 to C2 + C_to_partition = nx.to_numpy(C1) + elif m > n: # partition C2 to deduce map from C1 to C2 + C_to_partition = nx.to_numpy(C2) + else: # equal size -> simple Wasserstein alignment + C_to_partition = None + warnings.warn( + "Both structures have the same size so no partitioning is" + "performed to initialize the transport plan even though" + f"initialization method is {method}", + stacklevel=2 + ) + + + def get_transport_from_partition(part): + if n > m: # partition C1 to deduce map from C1 to C2 + T_ = nx.zeros((n, m)) + T_[nx.arange(n), part] = 1. + T_ = p[:, None] * T_ + q = nx.sum(T_, 0) + + factored_C1 = nx.dot(nx.dot(T_.T, C1), T_) / nx.outer(q, q) + + # alignment of both structure seen as feature matrices + M_structure = euclidean_distances(factored_C1, C2) + T_emd = emd(q, q, M_structure) + inv_q = 1. / q + + T = nx.dot(T_, inv_q[:, None] * T_emd) + elif m > n: + T_ = nx.zeros((m, n)) + T_[nx.arange(m), part] = 1. / m # assume uniform masses on C2 + q = nx.sum(T_, 0) + + factored_C2 = nx.dot(nx.dot(T_.T, C2), T_) / nx.outer(q, q) + + # alignment of both structure seen as feature matrices + M_structure = euclidean_distances(factored_C2, C1) + T_emd = emd(q, p, M_structure) + inv_q = 1. / q + + T = nx.dot(T_, inv_q[:, None] * T_emd).T + q = nx.sum(T, 0) # uniform one + else: + # alignment of both structure seen as feature matrices + M_structure = euclidean_distances(C1, C2) + q = p + T = emd(p, q, M_structure) + + return T, q + + if method == 'product': + q = nx.ones(m, type_as=C2) + T = nx.outer(p, q) + + elif method == 'random_product': + q = np.random.uniform(0, m, size=(m,)) + q = q / q.sum() + q = nx.from_numpy(q) + T = nx.outer(p, q) + + elif method in ['fluid', 'fluid_soft']: + # compute fluid partitioning on the biggest graph + if C_to_partition is None: + T, q = get_transport_from_partition(None) + else: + graph = from_numpy_array(C_to_partition) + part_sets = asyn_fluidc(graph, min_size, seed=random_state) + part = nx.zeros(C_to_partition) + for iset_, set_ in enumerate(part_sets): + set_ = list(set_) + part[set_] = iset_ + T, q = get_transport_from_partition(part) + + if 'soft' in method: + T = (T + nx.outer(p, q)) / 2. + + elif method in ['spectral', 'spectral_soft']: + # compute spectral partitioning on the biggest graph + if C_to_partition is None: + T, q = get_transport_from_partition(None) + else: + sc = SpectralClustering(n_clusters=min_size, + random_state=random_state, + affinity='precomputed').fit(C_to_partition) + part = sc.labels_ + T, q = get_transport_from_partition(part) + + if 'soft' in method: + T = (T + nx.outer(p, q)) / 2. + + elif method in ['kmeans', 'kmeans_soft']: + # compute spectral partitioning on the biggest graph + if C_to_partition is None: + T, q = get_transport_from_partition(None) + else: + km = KMeans(n_clusters=min_size, random_state=random_state, + n_init=1).fit(C_to_partition) + + part = km.labels_ + T, q = get_transport_from_partition(part) + + if 'soft' in method: + T = (T + nx.outer(p, q)) / 2. + + return T \ No newline at end of file From eccc5b9ba8f3e3b77e1a540c6cad79ced34e2015 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Sun, 28 Jul 2024 21:12:37 +0200 Subject: [PATCH 04/22] update init srge --- ot/gromov/__init__.py | 7 +- ot/gromov/_semirelaxed.py | 166 +++++++++++++------ ot/gromov/_utils.py | 330 ++++++++++++++++++++------------------ 3 files changed, 294 insertions(+), 209 deletions(-) diff --git a/ot/gromov/__init__.py b/ot/gromov/__init__.py index 748b1a97b..5cf19784b 100644 --- a/ot/gromov/__init__.py +++ b/ot/gromov/__init__.py @@ -11,9 +11,8 @@ # All submodules and packages from ._utils import (init_matrix, tensor_product, gwloss, gwggrad, - init_matrix_semirelaxed, - update_barycenter_structure, update_barycenter_feature, - ) + init_matrix_semirelaxed, semirelaxed_init_plan, + update_barycenter_structure, update_barycenter_feature) from ._gw import (gromov_wasserstein, gromov_wasserstein2, fused_gromov_wasserstein, fused_gromov_wasserstein2, @@ -65,7 +64,7 @@ ) __all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad', - 'init_matrix_semirelaxed', + 'init_matrix_semirelaxed', 'semirelaxed_init_plan', 'update_barycenter_structure', 'update_barycenter_feature', 'gromov_wasserstein', 'gromov_wasserstein2', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2', 'solve_gromov_linesearch', 'gromov_barycenters', diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py index 43723084f..454bbbe3c 100644 --- a/ot/gromov/_semirelaxed.py +++ b/ot/gromov/_semirelaxed.py @@ -20,12 +20,13 @@ from ._utils import ( init_matrix_semirelaxed, gwloss, gwggrad, update_barycenter_structure, update_barycenter_feature, - _semirelaxed_init_plan + semirelaxed_init_plan ) -def semirelaxed_gromov_wasserstein(C1, C2, p=None, loss_fun='square_loss', symmetric=None, log=False, G0=None, - max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): +def semirelaxed_gromov_wasserstein( + C1, C2, p=None, loss_fun='square_loss', symmetric=None, log=False, G0=None, + max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, random_state=0, **kwargs): r""" Returns the semi-relaxed Gromov-Wasserstein divergence transport from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}` (see [48]). @@ -70,15 +71,21 @@ def semirelaxed_gromov_wasserstein(C1, C2, p=None, loss_fun='square_loss', symme Print information along iterations log : bool, optional record log if True - G0: array-like, shape (ns,nt), optional - If None the initial transport plan of the solver is pq^T. - Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + G0: array-like of shape (ns,nt) or string, optional + If `G0=None` the initial transport plan of the solver is pq^T. + If G0 is a tensor it must satisfy marginal constraints and will be + used as initial transport of the solver. + if G0 is a string it will be interpreted as a method for + `semirelaxed_init_plan` taking values in "product", "random_product", + "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". max_iter : int, optional Max number of iterations tol_rel : float, optional Stop threshold on relative error (>0) tol_abs : float, optional Stop threshold on absolute error (>0) + random_state: int, optional + Random seed used in stochastic initialization methods. **kwargs : dict parameters can be directly passed to the ot.optim.cg solver @@ -114,9 +121,13 @@ def semirelaxed_gromov_wasserstein(C1, C2, p=None, loss_fun='square_loss', symme if symmetric is None: symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10) + if G0 is None: q = unif(C2.shape[0], type_as=p) G0 = nx.outer(p, q) + elif isinstance(G0, str): + G0 = semirelaxed_init_plan( + C1, C2, p, method=G0, random_state=random_state, nx=nx) else: q = nx.sum(G0, 0) # Check first marginal of G0 @@ -156,8 +167,10 @@ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): return semirelaxed_cg(p, q, 0., 1., f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) -def semirelaxed_gromov_wasserstein2(C1, C2, p=None, loss_fun='square_loss', symmetric=None, log=False, G0=None, - max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): +def semirelaxed_gromov_wasserstein2( + C1, C2, p=None, loss_fun='square_loss', symmetric=None, log=False, + G0=None, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, random_state=0, + **kwargs): r""" Returns the semi-relaxed Gromov-Wasserstein divergence from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}` (see [48]). @@ -205,15 +218,21 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p=None, loss_fun='square_loss', symm Print information along iterations log : bool, optional record log if True - G0: array-like, shape (ns,nt), optional - If None the initial transport plan of the solver is pq^T. - Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + G0: array-like of shape (ns,nt) or string, optional + If `G0=None` the initial transport plan of the solver is pq^T. + If G0 is a tensor it must satisfy marginal constraints and will be + used as initial transport of the solver. + if G0 is a string it will be interpreted as a method for + `semirelaxed_init_plan` taking values in "product", "random_product", + "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". max_iter : int, optional Max number of iterations tol_rel : float, optional Stop threshold on relative error (>0) tol_abs : float, optional Stop threshold on absolute error (>0) + random_state: int, optional + Random seed used in stochastic initialization methods. **kwargs : dict parameters can be directly passed to the ot.optim.cg solver @@ -244,7 +263,8 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p=None, loss_fun='square_loss', symm T, log_srgw = semirelaxed_gromov_wasserstein( C1, C2, p, loss_fun, symmetric, log=True, G0=G0, - max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, **kwargs) + max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, + random_state=random_state, **kwargs) q = nx.sum(T, 0) log_srgw['T'] = T @@ -268,7 +288,8 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p=None, loss_fun='square_loss', symm def semirelaxed_fused_gromov_wasserstein( M, C1, C2, p=None, loss_fun='square_loss', symmetric=None, alpha=0.5, - G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): + G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, + random_state=0, **kwargs): r""" Computes the semi-relaxed Fused Gromov-Wasserstein transport between two graphs (see [48]). @@ -312,9 +333,13 @@ def semirelaxed_fused_gromov_wasserstein( Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). alpha : float, optional Trade-off parameter (0 < alpha < 1) - G0: array-like, shape (ns,nt), optional - If None the initial transport plan of the solver is pq^T. - Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + G0: array-like of shape (ns,nt) or string, optional + If `G0=None` the initial transport plan of the solver is pq^T. + If G0 is a tensor it must satisfy marginal constraints and will be + used as initial transport of the solver. + if G0 is a string it will be interpreted as a method for + `semirelaxed_init_plan` taking values in "product", "random_product", + "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". log : bool, optional record log if True max_iter : int, optional @@ -323,6 +348,8 @@ def semirelaxed_fused_gromov_wasserstein( Stop threshold on relative error (>0) tol_abs : float, optional Stop threshold on absolute error (>0) + random_state: int, optional + Random seed used in stochastic initialization methods. **kwargs : dict parameters can be directly passed to the ot.optim.cg solver @@ -367,6 +394,9 @@ def semirelaxed_fused_gromov_wasserstein( if G0 is None: q = unif(C2.shape[0], type_as=p) G0 = nx.outer(p, q) + elif isinstance(G0, str): + G0 = semirelaxed_init_plan( + C1, C2, p, M, alpha, G0, random_state, nx) else: q = nx.sum(G0, 0) # Check first marginal of G0 @@ -407,8 +437,10 @@ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): return semirelaxed_cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) -def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_loss', symmetric=None, alpha=0.5, G0=None, log=False, - max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): +def semirelaxed_fused_gromov_wasserstein2( + M, C1, C2, p=None, loss_fun='square_loss', symmetric=None, alpha=0.5, + G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, + random_state=0, **kwargs): r""" Computes the semi-relaxed FGW divergence between two graphs (see [48]). @@ -455,9 +487,13 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). alpha : float, optional Trade-off parameter (0 < alpha < 1) - G0: array-like, shape (ns,nt), optional - If None the initial transport plan of the solver is pq^T. - Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + G0: array-like of shape (ns,nt) or string, optional + If `G0=None` the initial transport plan of the solver is pq^T. + If G0 is a tensor it must satisfy marginal constraints and will be + used as initial transport of the solver. + if G0 is a string it will be interpreted as a method for + `semirelaxed_init_plan` taking values in "product", "random_product", + "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". log : bool, optional Record log if True. max_iter : int, optional @@ -466,6 +502,8 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo Stop threshold on relative error (>0) tol_abs : float, optional Stop threshold on absolute error (>0) + random_state: int, optional + Random seed used in stochastic initialization methods. **kwargs : dict Parameters can be directly passed to the ot.optim.cg solver. @@ -502,7 +540,8 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo T, log_fgw = semirelaxed_fused_gromov_wasserstein( M, C1, C2, p, loss_fun, symmetric, alpha, G0, log=True, - max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, **kwargs) + max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, + random_state=random_state, **kwargs) q = nx.sum(T, 0) srfgw_dist = log_fgw['srfgw_dist'] log_fgw['T'] = T @@ -616,7 +655,8 @@ def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p, def entropic_semirelaxed_gromov_wasserstein( C1, C2, p=None, loss_fun='square_loss', epsilon=0.1, symmetric=None, - G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, **kwargs): + G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, + random_state=0): r""" Returns the entropic-regularized semi-relaxed gromov-wasserstein divergence transport plan from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}` @@ -662,9 +702,13 @@ def entropic_semirelaxed_gromov_wasserstein( Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). verbose : bool, optional Print information along iterations - G0: array-like, shape (ns,nt), optional - If None the initial transport plan of the solver is pq^T. - Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + G0: array-like of shape (ns,nt) or string, optional + If `G0=None` the initial transport plan of the solver is pq^T. + If G0 is a tensor it must satisfy marginal constraints and will be + used as initial transport of the solver. + if G0 is a string it will be interpreted as a method for + `semirelaxed_init_plan` taking values in "product", "random_product", + "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". max_iter : int, optional Max number of iterations tol : float, optional @@ -673,6 +717,9 @@ def entropic_semirelaxed_gromov_wasserstein( record log if True verbose : bool, optional Print information along iterations + random_state: int, optional + Random seed used in stochastic initialization methods. + Returns ------- G : array-like, shape (`ns`, `nt`) @@ -701,9 +748,13 @@ def entropic_semirelaxed_gromov_wasserstein( if symmetric is None: symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10) + if G0 is None: q = unif(C2.shape[0], type_as=p) G0 = nx.outer(p, q) + elif isinstance(G0, str): + G0 = semirelaxed_init_plan( + C1, C2, p, method=G0, random_state=random_state, nx=nx) else: q = nx.sum(G0, 0) # Check first marginal of G0 @@ -768,7 +819,8 @@ def df(G): def entropic_semirelaxed_gromov_wasserstein2( C1, C2, p=None, loss_fun='square_loss', epsilon=0.1, symmetric=None, - G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, **kwargs): + G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, + random_state=0, **kwargs): r""" Returns the entropic-regularized semi-relaxed gromov-wasserstein divergence from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}` @@ -816,9 +868,13 @@ def entropic_semirelaxed_gromov_wasserstein2( Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). verbose : bool, optional Print information along iterations - G0: array-like, shape (ns,nt), optional - If None the initial transport plan of the solver is pq^T. - Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + G0: array-like of shape (ns,nt) or string, optional + If `G0=None` the initial transport plan of the solver is pq^T. + If G0 is a tensor it must satisfy marginal constraints and will be + used as initial transport of the solver. + if G0 is a string it will be interpreted as a method for + `semirelaxed_init_plan` taking values in "product", "random_product", + "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". max_iter : int, optional Max number of iterations tol : float, optional @@ -827,8 +883,8 @@ def entropic_semirelaxed_gromov_wasserstein2( record log if True verbose : bool, optional Print information along iterations - **kwargs : dict - parameters can be directly passed to the ot.optim.cg solver + random_state: int, optional + Random seed used in stochastic initialization methods. Returns ------- @@ -845,8 +901,8 @@ def entropic_semirelaxed_gromov_wasserstein2( International Conference on Learning Representations (ICLR), 2022. """ T, log_srgw = entropic_semirelaxed_gromov_wasserstein( - C1, C2, p, loss_fun, epsilon, symmetric, G0, - max_iter, tol, log=True, verbose=verbose, **kwargs) + C1, C2, p, loss_fun, epsilon, symmetric, G0, max_iter, tol, + log=True, verbose=verbose, random_state=random_state) log_srgw['T'] = T @@ -858,7 +914,8 @@ def entropic_semirelaxed_gromov_wasserstein2( def entropic_semirelaxed_fused_gromov_wasserstein( M, C1, C2, p=None, loss_fun='square_loss', symmetric=None, epsilon=0.1, - alpha=0.5, G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, **kwargs): + alpha=0.5, G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, + random_state=0): r""" Computes the entropic-regularized semi-relaxed FGW transport between two graphs (see :ref:`[48] `) estimated using a Mirror Descent algorithm following the KL geometry. @@ -907,9 +964,13 @@ def entropic_semirelaxed_fused_gromov_wasserstein( Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). alpha : float, optional Trade-off parameter (0 < alpha < 1) - G0: array-like, shape (ns,nt), optional - If None the initial transport plan of the solver is pq^T. - Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + G0: array-like of shape (ns,nt) or string, optional + If `G0=None` the initial transport plan of the solver is pq^T. + If G0 is a tensor it must satisfy marginal constraints and will be + used as initial transport of the solver. + if G0 is a string it will be interpreted as a method for + `semirelaxed_init_plan` taking values in "product", "random_product", + "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". max_iter : int, optional Max number of iterations tol : float, optional @@ -918,8 +979,8 @@ def entropic_semirelaxed_fused_gromov_wasserstein( record log if True verbose : bool, optional Print information along iterations - **kwargs : dict - parameters can be directly passed to the ot.optim.cg solver + random_state: int, optional + Random seed used in stochastic initialization methods. Returns ------- @@ -949,9 +1010,13 @@ def entropic_semirelaxed_fused_gromov_wasserstein( if symmetric is None: symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10) + if G0 is None: q = unif(C2.shape[0], type_as=p) G0 = nx.outer(p, q) + elif isinstance(G0, str): + G0 = semirelaxed_init_plan( + C1, C2, p, M, alpha, G0, random_state, nx) else: q = nx.sum(G0, 0) # Check first marginal of G0 @@ -1018,7 +1083,8 @@ def df(G): def entropic_semirelaxed_fused_gromov_wasserstein2( M, C1, C2, p=None, loss_fun='square_loss', symmetric=None, epsilon=0.1, - alpha=0.5, G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, **kwargs): + alpha=0.5, G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, + random_state=0): r""" Computes the entropic-regularized semi-relaxed FGW divergence between two graphs (see :ref:`[48] `) estimated using a Mirror Descent algorithm following the KL geometry. @@ -1067,9 +1133,13 @@ def entropic_semirelaxed_fused_gromov_wasserstein2( Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). alpha : float, optional Trade-off parameter (0 < alpha < 1) - G0: array-like, shape (ns,nt), optional - If None the initial transport plan of the solver is pq^T. - Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + G0: array-like of shape (ns,nt) or string, optional + If `G0=None` the initial transport plan of the solver is pq^T. + If G0 is a tensor it must satisfy marginal constraints and will be + used as initial transport of the solver. + if G0 is a string it will be interpreted as a method for + `semirelaxed_init_plan` taking values in "product", "random_product", + "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". max_iter : int, optional Max number of iterations tol : float, optional @@ -1078,8 +1148,8 @@ def entropic_semirelaxed_fused_gromov_wasserstein2( record log if True verbose : bool, optional Print information along iterations - **kwargs : dict - Parameters can be directly passed to the ot.optim.cg solver. + random_state: int, optional + Random seed used in stochastic initialization methods. Returns ------- @@ -1097,8 +1167,8 @@ def entropic_semirelaxed_fused_gromov_wasserstein2( International Conference on Learning Representations (ICLR), 2022. """ T, log_srfgw = entropic_semirelaxed_fused_gromov_wasserstein( - M, C1, C2, p, loss_fun, symmetric, epsilon, alpha, G0, - max_iter, tol, log=True, verbose=verbose, **kwargs) + M, C1, C2, p, loss_fun, symmetric, epsilon, alpha, G0, max_iter, tol, + log=True, verbose=verbose, random_state=random_state) log_srfgw['T'] = T diff --git a/ot/gromov/_utils.py b/ot/gromov/_utils.py index 87e8301bc..26af6311e 100644 --- a/ot/gromov/_utils.py +++ b/ot/gromov/_utils.py @@ -17,7 +17,7 @@ from .lp import emd try: - from networkx.algorithms.community import asyn_fluidc + from networkx.algorithms.community import asyn_fluidc from networkx import from_numpy_array networkx_import = True except ImportError: @@ -388,6 +388,178 @@ def h2(b): return constC, hC1, hC2, fC2t +def semirelaxed_init_plan(C1, C2, p, M=None, alpha=1., + method='product', random_state=0, nx=None): + """ + Heuristics to initialize the semi-relaxed (F)GW transport plan between a + graph :math:`(\mathbf{C1}, \mathbf{p})` and a structure matrix :math:`\mathbf{C2}`. + + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space. + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space. + p : array-like, shape (ns,) + Probability distribution in the source space. + M : array-like, shape (ns, nt) + Metric cost matrix between features across domains + alpha : float, optional + Trade-off parameter (0 <= alpha <= 1) + method : str, optional + Method to initialize the transport plan. The default is 'product'. + random_state: int, optional + Random seed used for stochastic methods. + nx : backend, optional + POT backend. + + Returns + ------- + T : array-like, shape (ns, ns) + Admissible transport plan for the sr(F)GW problems. + + References + ---------- + .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. + "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" + International Conference on Learning Representations (ICLR), 2022. + + """ + list_partitioning_methods = [ + 'fluid', 'spectral', 'kmeans', 'fluid_soft', 'spectral_soft', + 'kmeans_soft'] + + if method not in list_partitioning_methods + ['product', 'random_product']: + raise ValueError(f'Unsupported initialization method = {method}.') + + if nx is None: + arr = [C1, C2, p] + if M is not None: + arr += arr + nx = get_backend(*arr) + + n = C1.shape[0] + m = C2.shape[0] + if method in list_partitioning_methods: + min_size = min(n, m) + if n > m: # partition C1 to deduce map from C1 to C2 + C_to_partition = nx.to_numpy(C1) + elif m > n: # partition C2 to deduce map from C1 to C2 + C_to_partition = nx.to_numpy(C2) + else: # equal size -> simple Wasserstein alignment + C_to_partition = None + warnings.warn( + "Both structures have the same size so no partitioning is" + "performed to initialize the transport plan even though" + f"initialization method is {method}", + stacklevel=2 + ) + + def get_transport_from_partition(part): + if n > m: # partition C1 to deduce map from C1 to C2 + T_ = nx.zeros((n, m)) + T_[nx.arange(n), part] = 1. + T_ = p[:, None] * T_ + q = nx.sum(T_, 0) + + factored_C1 = nx.dot(nx.dot(T_.T, C1), T_) / nx.outer(q, q) + + # alignment of both structure seen as feature matrices + M_structure = euclidean_distances(factored_C1, C2) + T_emd = emd(q, q, M_structure) + inv_q = 1. / q + + T = nx.dot(T_, inv_q[:, None] * T_emd) + elif m > n: + T_ = nx.zeros((m, n)) + T_[nx.arange(m), part] = 1. / m # assume uniform masses on C2 + q = nx.sum(T_, 0) + + factored_C2 = nx.dot(nx.dot(T_.T, C2), T_) / nx.outer(q, q) + + # alignment of both structure seen as feature matrices + M_structure = euclidean_distances(factored_C2, C1) + T_emd = emd(q, p, M_structure) + inv_q = 1. / q + + T = nx.dot(T_, inv_q[:, None] * T_emd).T + q = nx.sum(T, 0) # uniform one + else: + # alignment of both structure seen as feature matrices + M_structure = euclidean_distances(C1, C2) + q = p + T = emd(p, q, M_structure) + + return T, q + + # Handle initialization via structure information + + if method == 'product': + q = nx.ones(m, type_as=C2) + T = nx.outer(p, q) + + elif method == 'random_product': + q = np.random.uniform(0, m, size=(m,)) + q = q / q.sum() + q = nx.from_numpy(q) + T = nx.outer(p, q) + + elif method in ['fluid', 'fluid_soft']: + # compute fluid partitioning on the biggest graph + if C_to_partition is None: + T, q = get_transport_from_partition(None) + else: + graph = from_numpy_array(C_to_partition) + part_sets = asyn_fluidc(graph, min_size, seed=random_state) + part = nx.zeros(C_to_partition) + for iset_, set_ in enumerate(part_sets): + set_ = list(set_) + part[set_] = iset_ + T, q = get_transport_from_partition(part) + + if 'soft' in method: + T = (T + nx.outer(p, q)) / 2. + + elif method in ['spectral', 'spectral_soft']: + # compute spectral partitioning on the biggest graph + if C_to_partition is None: + T, q = get_transport_from_partition(None) + else: + sc = SpectralClustering(n_clusters=min_size, + random_state=random_state, + affinity='precomputed').fit(C_to_partition) + part = sc.labels_ + T, q = get_transport_from_partition(part) + + if 'soft' in method: + T = (T + nx.outer(p, q)) / 2. + + elif method in ['kmeans', 'kmeans_soft']: + # compute spectral partitioning on the biggest graph + if C_to_partition is None: + T, q = get_transport_from_partition(None) + else: + km = KMeans(n_clusters=min_size, random_state=random_state, + n_init=1).fit(C_to_partition) + + part = km.labels_ + T, q = get_transport_from_partition(part) + + if 'soft' in method: + T = (T + nx.outer(p, q)) / 2. + + # Add feature information solving a semi-relaxed Wasserstein problem + if M is not None: + # get minimum by rows as binary mask + TM = nx.ones(1, type_as=p) * (M == nx.reshape(nx.min(M, axis=1), (-1, 1))) + TM *= nx.reshape((p / nx.sum(TM, axis=1)), (-1, 1)) + + T = alpha * T + (1 - alpha) * TM + + return T + + def update_barycenter_structure( Ts, Cs, lambdas, p=None, loss_fun='square_loss', target=True, check_zeros=True, nx=None): @@ -593,159 +765,3 @@ def update_barycenter_feature( inv_p = 1. / p_sum return sum(list_features) * inv_p[:, None] - - -def _semirelaxed_init_plan(C1, C2, p, method='product', random_state=0, nx=None): - """ - Heuristics to initialize the semi-relaxed (F)GW transport plan between a - graph :math:`(\mathbf{C1}, \mathbf{p})` and a structure matrix :math:`\mathbf{C2}`. - - - Parameters - ---------- - C1 : array-like, shape (ns, ns) - Metric cost matrix in the source space. - C2 : array-like, shape (nt, nt) - Metric cost matrix in the target space. - p : array-like, shape (ns,) - Probability distribution in the source space. - method : str, optional - Method to initialize the transport plan. The default is 'product'. - random_state: int, optional - Random seed used for stochastic methods. - nx : backend, optional - POT backend. - - Returns - ------- - T : array-like, shape (ns, ns) - Admissible transport plan for the sr(F)GW problems. - - References - ---------- - .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. - "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" - International Conference on Learning Representations (ICLR), 2022. - - """ - list_partitioning_methods = [ - 'fluid', 'spectral', 'kmeans', 'fluid_soft', 'spectral_soft', - 'kmeans_soft'] - - if method not in list_partitioning_methods + ['product', 'random_product']: - raise ValueError(f'Unsupported initialization method = {method}.') - - if nx is None: - C1, C2, p = list_to_array(C1, C2, p) - nx = get_backend(C1, C2, p) - - n = C1.shape[0] - m = C2.shape[0] - if method in list_partitioning_methods: - min_size = min(n, m) - if n > m: # partition C1 to deduce map from C1 to C2 - C_to_partition = nx.to_numpy(C1) - elif m > n: # partition C2 to deduce map from C1 to C2 - C_to_partition = nx.to_numpy(C2) - else: # equal size -> simple Wasserstein alignment - C_to_partition = None - warnings.warn( - "Both structures have the same size so no partitioning is" - "performed to initialize the transport plan even though" - f"initialization method is {method}", - stacklevel=2 - ) - - - def get_transport_from_partition(part): - if n > m: # partition C1 to deduce map from C1 to C2 - T_ = nx.zeros((n, m)) - T_[nx.arange(n), part] = 1. - T_ = p[:, None] * T_ - q = nx.sum(T_, 0) - - factored_C1 = nx.dot(nx.dot(T_.T, C1), T_) / nx.outer(q, q) - - # alignment of both structure seen as feature matrices - M_structure = euclidean_distances(factored_C1, C2) - T_emd = emd(q, q, M_structure) - inv_q = 1. / q - - T = nx.dot(T_, inv_q[:, None] * T_emd) - elif m > n: - T_ = nx.zeros((m, n)) - T_[nx.arange(m), part] = 1. / m # assume uniform masses on C2 - q = nx.sum(T_, 0) - - factored_C2 = nx.dot(nx.dot(T_.T, C2), T_) / nx.outer(q, q) - - # alignment of both structure seen as feature matrices - M_structure = euclidean_distances(factored_C2, C1) - T_emd = emd(q, p, M_structure) - inv_q = 1. / q - - T = nx.dot(T_, inv_q[:, None] * T_emd).T - q = nx.sum(T, 0) # uniform one - else: - # alignment of both structure seen as feature matrices - M_structure = euclidean_distances(C1, C2) - q = p - T = emd(p, q, M_structure) - - return T, q - - if method == 'product': - q = nx.ones(m, type_as=C2) - T = nx.outer(p, q) - - elif method == 'random_product': - q = np.random.uniform(0, m, size=(m,)) - q = q / q.sum() - q = nx.from_numpy(q) - T = nx.outer(p, q) - - elif method in ['fluid', 'fluid_soft']: - # compute fluid partitioning on the biggest graph - if C_to_partition is None: - T, q = get_transport_from_partition(None) - else: - graph = from_numpy_array(C_to_partition) - part_sets = asyn_fluidc(graph, min_size, seed=random_state) - part = nx.zeros(C_to_partition) - for iset_, set_ in enumerate(part_sets): - set_ = list(set_) - part[set_] = iset_ - T, q = get_transport_from_partition(part) - - if 'soft' in method: - T = (T + nx.outer(p, q)) / 2. - - elif method in ['spectral', 'spectral_soft']: - # compute spectral partitioning on the biggest graph - if C_to_partition is None: - T, q = get_transport_from_partition(None) - else: - sc = SpectralClustering(n_clusters=min_size, - random_state=random_state, - affinity='precomputed').fit(C_to_partition) - part = sc.labels_ - T, q = get_transport_from_partition(part) - - if 'soft' in method: - T = (T + nx.outer(p, q)) / 2. - - elif method in ['kmeans', 'kmeans_soft']: - # compute spectral partitioning on the biggest graph - if C_to_partition is None: - T, q = get_transport_from_partition(None) - else: - km = KMeans(n_clusters=min_size, random_state=random_state, - n_init=1).fit(C_to_partition) - - part = km.labels_ - T, q = get_transport_from_partition(part) - - if 'soft' in method: - T = (T + nx.outer(p, q)) / 2. - - return T \ No newline at end of file From 3ec7c38444aa6a91f42d4c24b794b784bfc043fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Sun, 28 Jul 2024 23:32:37 +0200 Subject: [PATCH 05/22] first tests for srgw inits --- ot/gromov/_semirelaxed.py | 12 ++++--- ot/gromov/_utils.py | 16 +++++---- test/gromov/test_semirelaxed.py | 63 +++++++++++++++++++++++++-------- 3 files changed, 67 insertions(+), 24 deletions(-) diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py index 454bbbe3c..657bae299 100644 --- a/ot/gromov/_semirelaxed.py +++ b/ot/gromov/_semirelaxed.py @@ -114,7 +114,7 @@ def semirelaxed_gromov_wasserstein( else: p = unif(C1.shape[0], type_as=C1) - if G0 is not None: + if (G0 is not None) and (not isinstance(G0, str)): arr.append(G0) nx = get_backend(*arr) @@ -128,6 +128,7 @@ def semirelaxed_gromov_wasserstein( elif isinstance(G0, str): G0 = semirelaxed_init_plan( C1, C2, p, method=G0, random_state=random_state, nx=nx) + q = nx.sum(G0, 0) else: q = nx.sum(G0, 0) # Check first marginal of G0 @@ -383,7 +384,7 @@ def semirelaxed_fused_gromov_wasserstein( else: p = unif(C1.shape[0], type_as=C1) - if G0 is not None: + if (G0 is not None) and (not isinstance(G0, str)): arr.append(G0) nx = get_backend(*arr) @@ -397,6 +398,7 @@ def semirelaxed_fused_gromov_wasserstein( elif isinstance(G0, str): G0 = semirelaxed_init_plan( C1, C2, p, M, alpha, G0, random_state, nx) + q = nx.sum(G0, 0) else: q = nx.sum(G0, 0) # Check first marginal of G0 @@ -741,7 +743,7 @@ def entropic_semirelaxed_gromov_wasserstein( else: p = unif(C1.shape[0], type_as=C1) - if G0 is not None: + if (G0 is not None) and (not isinstance(G0, str)): arr.append(G0) nx = get_backend(*arr) @@ -755,6 +757,7 @@ def entropic_semirelaxed_gromov_wasserstein( elif isinstance(G0, str): G0 = semirelaxed_init_plan( C1, C2, p, method=G0, random_state=random_state, nx=nx) + q = nx.sum(G0, 0) else: q = nx.sum(G0, 0) # Check first marginal of G0 @@ -1003,7 +1006,7 @@ def entropic_semirelaxed_fused_gromov_wasserstein( else: p = unif(C1.shape[0], type_as=C1) - if G0 is not None: + if (G0 is not None) and (not isinstance(G0, str)): arr.append(G0) nx = get_backend(*arr) @@ -1017,6 +1020,7 @@ def entropic_semirelaxed_fused_gromov_wasserstein( elif isinstance(G0, str): G0 = semirelaxed_init_plan( C1, C2, p, M, alpha, G0, random_state, nx) + q = nx.sum(G0, 0) else: q = nx.sum(G0, 0) # Check first marginal of G0 diff --git a/ot/gromov/_utils.py b/ot/gromov/_utils.py index 26af6311e..f33ea330a 100644 --- a/ot/gromov/_utils.py +++ b/ot/gromov/_utils.py @@ -14,7 +14,7 @@ from ..utils import list_to_array, euclidean_distances from ..backend import get_backend -from .lp import emd +from ..lp import emd try: from networkx.algorithms.community import asyn_fluidc @@ -441,6 +441,7 @@ def semirelaxed_init_plan(C1, C2, p, M=None, alpha=1., n = C1.shape[0] m = C2.shape[0] + if method in list_partitioning_methods: min_size = min(n, m) if n > m: # partition C1 to deduce map from C1 to C2 @@ -458,7 +459,7 @@ def semirelaxed_init_plan(C1, C2, p, M=None, alpha=1., def get_transport_from_partition(part): if n > m: # partition C1 to deduce map from C1 to C2 - T_ = nx.zeros((n, m)) + T_ = nx.zeros((n, m), type_as=C1) T_[nx.arange(n), part] = 1. T_ = p[:, None] * T_ q = nx.sum(T_, 0) @@ -472,7 +473,7 @@ def get_transport_from_partition(part): T = nx.dot(T_, inv_q[:, None] * T_emd) elif m > n: - T_ = nx.zeros((m, n)) + T_ = nx.zeros((m, n), type_as=C1) T_[nx.arange(m), part] = 1. / m # assume uniform masses on C2 q = nx.sum(T_, 0) @@ -491,15 +492,16 @@ def get_transport_from_partition(part): q = p T = emd(p, q, M_structure) - return T, q + return T, q # Handle initialization via structure information if method == 'product': - q = nx.ones(m, type_as=C2) + q = nx.ones(m, type_as=C2) / m T = nx.outer(p, q) elif method == 'random_product': + np.random.seed(random_state) q = np.random.uniform(0, m, size=(m,)) q = q / q.sum() q = nx.from_numpy(q) @@ -512,10 +514,12 @@ def get_transport_from_partition(part): else: graph = from_numpy_array(C_to_partition) part_sets = asyn_fluidc(graph, min_size, seed=random_state) - part = nx.zeros(C_to_partition) + part = np.zeros(C_to_partition.shape[0], dtype=int) for iset_, set_ in enumerate(part_sets): set_ = list(set_) part[set_] = iset_ + part = nx.from_numpy(part) + T, q = get_transport_from_partition(part) if 'soft' in method: diff --git a/test/gromov/test_semirelaxed.py b/test/gromov/test_semirelaxed.py index 8e6805d41..7a322e20e 100644 --- a/test/gromov/test_semirelaxed.py +++ b/test/gromov/test_semirelaxed.py @@ -10,7 +10,12 @@ import ot from ot.backend import torch +from ot.gromov._utils import ( + networkx_import, sklearn_import) + +@pytest.skip_backend("jax", reason="test very slow with jax backend") +@pytest.skip_backend("tf", reason="test very slow with tf backend") def test_semirelaxed_gromov(nx): rng = np.random.RandomState(0) # unbalanced proportions @@ -19,19 +24,26 @@ def test_semirelaxed_gromov(nx): ns = np.sum(list_n) # create directed sbm with C2 as connectivity matrix C1 = np.zeros((ns, ns), dtype=np.float64) - C2 = np.array([[0.8, 0.05], - [0.05, 1.]], dtype=np.float64) + C2 = np.array([[0.8, 0.1], + [0.1, 1.]], dtype=np.float64) + + pos = [0, 30, 45] + for i in range(nt): for j in range(nt): ni, nj = list_n[i], list_n[j] xij = rng.binomial(size=(ni, nj), n=1, p=C2[i, j]) - C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij + pos_i_min, pos_i_max = pos[i], pos[i+1] + pos_j_min, pos_j_max = pos[j], pos[j+1] + C1[pos_i_min: pos_i_max, pos_j_min: pos_j_max] = xij + + p = ot.unif(ns, type_as=C1) q0 = ot.unif(C2.shape[0], type_as=C1) G0 = p[:, None] * q0[None, :] # asymmetric C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0) - + for loss_fun in ['square_loss', 'kl_loss']: G, log = ot.gromov.semirelaxed_gromov_wasserstein( C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=G0) @@ -60,19 +72,42 @@ def test_semirelaxed_gromov(nx): np.testing.assert_allclose(log2['srgw_dist'], logb['srgw_dist'], atol=1e-07) np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) - # symmetric + ## symmetric - testing various initialization of the OT plan. C1 = 0.5 * (C1 + C1.T) + print('deg:', C1.sum(0)) + C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0) + + init_plan_list = [ + (None, G0b), ('product', None), ("random_product", "random_product")] + + if networkx_import: + init_plan_list += [('fluid', 'fluid'), ('fluid_soft', 'fluid_soft')] + + if sklearn_import: + init_plan_list += [ + ("spectral", "spectral"), ("spectral_soft", "spectral_soft"), + ("kmeans", "kmeans"), ("kmeans_soft", "kmeans_soft")] + + for (init, init_b) in init_plan_list: + print('------') + print('init:', init) + print('init_b:', init_b) - G, log = ot.gromov.semirelaxed_gromov_wasserstein( - C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=None) - Gb = ot.gromov.semirelaxed_gromov_wasserstein( - C1b, C2b, pb, loss_fun='square_loss', symmetric=True, log=False, G0=G0b) - - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + G, log = ot.gromov.semirelaxed_gromov_wasserstein( + C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=init) + Gb = ot.gromov.semirelaxed_gromov_wasserstein( + C1b, C2b, pb, loss_fun='square_loss', symmetric=True, log=False, G0=init_b) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov + + if not isinstance(init, str): + np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + else: + if not 'spectral' in init: # issues with spectral clustering related to label switching + np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2( C1, C2, p, loss_fun='square_loss', symmetric=True, log=True, G0=G0) From 3e7a9fa11b159379cb1f3b01dc37c1d3aabaf3e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Sun, 28 Jul 2024 23:45:54 +0200 Subject: [PATCH 06/22] fix tol params in solve_gromov --- ot/solvers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ot/solvers.py b/ot/solvers.py index 95165ea11..ae44e860f 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -851,7 +851,7 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, if tol is None: tol = 1e-9 - value_quad, log = entropic_semirelaxed_gromov_wasserstein2(Ca, Cb, a, epsilon=reg, loss_fun=loss_fun, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + value_quad, log = entropic_semirelaxed_gromov_wasserstein2(Ca, Cb, a, epsilon=reg, loss_fun=loss_fun, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol=tol, verbose=verbose) plan = log['T'] value_linear = 0 @@ -865,7 +865,7 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, if tol is None: tol = 1e-9 - value_noreg, log = entropic_semirelaxed_fused_gromov_wasserstein2(M, Ca, Cb, a, loss_fun=loss_fun, alpha=alpha, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + value_noreg, log = entropic_semirelaxed_fused_gromov_wasserstein2(M, Ca, Cb, a, loss_fun=loss_fun, alpha=alpha, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol=tol, verbose=verbose) value_linear = log['lin_loss'] value_quad = log['quad_loss'] From bf80ea028c2fa083064021817438cb51a6370559 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Mon, 29 Jul 2024 00:28:59 +0200 Subject: [PATCH 07/22] fix bary tests --- ot/gromov/_utils.py | 6 ++-- test/gromov/test_semirelaxed.py | 57 +++++++++++++++++---------------- 2 files changed, 32 insertions(+), 31 deletions(-) diff --git a/ot/gromov/_utils.py b/ot/gromov/_utils.py index f33ea330a..23da552e5 100644 --- a/ot/gromov/_utils.py +++ b/ot/gromov/_utils.py @@ -459,8 +459,7 @@ def semirelaxed_init_plan(C1, C2, p, M=None, alpha=1., def get_transport_from_partition(part): if n > m: # partition C1 to deduce map from C1 to C2 - T_ = nx.zeros((n, m), type_as=C1) - T_[nx.arange(n), part] = 1. + T_ = nx.eye(m, type_as=C1)[part] T_ = p[:, None] * T_ q = nx.sum(T_, 0) @@ -473,8 +472,7 @@ def get_transport_from_partition(part): T = nx.dot(T_, inv_q[:, None] * T_emd) elif m > n: - T_ = nx.zeros((m, n), type_as=C1) - T_[nx.arange(m), part] = 1. / m # assume uniform masses on C2 + T_ = nx.eye(n, type_as=C1)[part] / m # assume uniform masses on C2 q = nx.sum(T_, 0) factored_C2 = nx.dot(nx.dot(T_.T, C2), T_) / nx.outer(q, q) diff --git a/test/gromov/test_semirelaxed.py b/test/gromov/test_semirelaxed.py index 7a322e20e..016273a6c 100644 --- a/test/gromov/test_semirelaxed.py +++ b/test/gromov/test_semirelaxed.py @@ -14,8 +14,6 @@ networkx_import, sklearn_import) -@pytest.skip_backend("jax", reason="test very slow with jax backend") -@pytest.skip_backend("tf", reason="test very slow with tf backend") def test_semirelaxed_gromov(nx): rng = np.random.RandomState(0) # unbalanced proportions @@ -26,24 +24,23 @@ def test_semirelaxed_gromov(nx): C1 = np.zeros((ns, ns), dtype=np.float64) C2 = np.array([[0.8, 0.1], [0.1, 1.]], dtype=np.float64) - + pos = [0, 30, 45] - + for i in range(nt): for j in range(nt): ni, nj = list_n[i], list_n[j] xij = rng.binomial(size=(ni, nj), n=1, p=C2[i, j]) - pos_i_min, pos_i_max = pos[i], pos[i+1] - pos_j_min, pos_j_max = pos[j], pos[j+1] + pos_i_min, pos_i_max = pos[i], pos[i + 1] + pos_j_min, pos_j_max = pos[j], pos[j + 1] C1[pos_i_min: pos_i_max, pos_j_min: pos_j_max] = xij - - + p = ot.unif(ns, type_as=C1) q0 = ot.unif(C2.shape[0], type_as=C1) G0 = p[:, None] * q0[None, :] # asymmetric C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0) - + for loss_fun in ['square_loss', 'kl_loss']: G, log = ot.gromov.semirelaxed_gromov_wasserstein( C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=G0) @@ -75,20 +72,20 @@ def test_semirelaxed_gromov(nx): ## symmetric - testing various initialization of the OT plan. C1 = 0.5 * (C1 + C1.T) print('deg:', C1.sum(0)) - + C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0) - + init_plan_list = [ (None, G0b), ('product', None), ("random_product", "random_product")] - + if networkx_import: init_plan_list += [('fluid', 'fluid'), ('fluid_soft', 'fluid_soft')] - + if sklearn_import: init_plan_list += [ ("spectral", "spectral"), ("spectral_soft", "spectral_soft"), ("kmeans", "kmeans"), ("kmeans_soft", "kmeans_soft")] - + for (init, init_b) in init_plan_list: print('------') print('init:', init) @@ -98,15 +95,15 @@ def test_semirelaxed_gromov(nx): C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=init) Gb = ot.gromov.semirelaxed_gromov_wasserstein( C1b, C2b, pb, loss_fun='square_loss', symmetric=True, log=False, G0=init_b) - + # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov - + if not isinstance(init, str): np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov else: - if not 'spectral' in init: # issues with spectral clustering related to label switching + if not 'spectral' in init: # issues with spectral clustering related to label switching np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2( @@ -708,15 +705,21 @@ def test_semirelaxed_gromov_barycenter(nx): np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples)) # test consistency across backends with 'kl_loss' - Cb2 = ot.gromov.semirelaxed_gromov_barycenters( - n_samples, [C1, C2], [p1, p2], [.5, .5], - 'kl_loss', max_iter=10, tol=1e-3, warmstartT=True, random_state=42 + Cb2, err = ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1, C2], [p1, p2], [.5, .5], 'kl_loss', max_iter=10, + tol=1e-3, warmstartT=False, stop_criterion='loss', log=True, random_state=42 ) - Cb2b = nx.to_numpy(ot.gromov.semirelaxed_gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], [.5, .5], - 'kl_loss', max_iter=10, tol=1e-3, warmstartT=True, random_state=42 - )) - np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06) + Cb2b, errb = ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], [.5, .5], 'kl_loss', max_iter=10, + tol=1e-3, warmstartT=False, stop_criterion='loss', log=True, random_state=42 + ) + Cb2b = nx.to_numpy(Cb2b) + + try: + np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06) # may differ from permutation + except: + np.testing.assert_allclose(err['loss'][-1], errb['loss'][-1], atol=1e-06) + np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples)) # test of gromov_barycenters with `log` on @@ -728,11 +731,11 @@ def test_semirelaxed_gromov_barycenter(nx): init_Cb = nx.from_numpy(init_C) Cb2_, err2_ = ot.gromov.semirelaxed_gromov_barycenters( - n_samples, [C1, C2], [p1, p2], [.5, .5], 'kl_loss', max_iter=10, + n_samples, [C1, C2], [p1, p2], [.5, .5], 'square_loss', max_iter=10, tol=1e-3, verbose=False, random_state=42, log=True, init_C=init_C ) Cb2b_, err2b_ = ot.gromov.semirelaxed_gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], [.5, .5], 'kl_loss', + n_samples, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', max_iter=10, tol=1e-3, verbose=True, random_state=42, init_C=init_Cb, log=True ) From 45d48d7eecfd5e34ff691ff2d1d77a7a578ad60d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Mon, 29 Jul 2024 00:43:42 +0200 Subject: [PATCH 08/22] update pep8 --- ot/gromov/_semirelaxed.py | 32 ++++++++++++++++---------------- test/gromov/test_semirelaxed.py | 9 +++------ 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py index 657bae299..820163e91 100644 --- a/ot/gromov/_semirelaxed.py +++ b/ot/gromov/_semirelaxed.py @@ -75,9 +75,9 @@ def semirelaxed_gromov_wasserstein( If `G0=None` the initial transport plan of the solver is pq^T. If G0 is a tensor it must satisfy marginal constraints and will be used as initial transport of the solver. - if G0 is a string it will be interpreted as a method for + if G0 is a string it will be interpreted as a method for `semirelaxed_init_plan` taking values in "product", "random_product", - "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". + "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". max_iter : int, optional Max number of iterations tol_rel : float, optional @@ -223,9 +223,9 @@ def semirelaxed_gromov_wasserstein2( If `G0=None` the initial transport plan of the solver is pq^T. If G0 is a tensor it must satisfy marginal constraints and will be used as initial transport of the solver. - if G0 is a string it will be interpreted as a method for + if G0 is a string it will be interpreted as a method for `semirelaxed_init_plan` taking values in "product", "random_product", - "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". + "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". max_iter : int, optional Max number of iterations tol_rel : float, optional @@ -338,9 +338,9 @@ def semirelaxed_fused_gromov_wasserstein( If `G0=None` the initial transport plan of the solver is pq^T. If G0 is a tensor it must satisfy marginal constraints and will be used as initial transport of the solver. - if G0 is a string it will be interpreted as a method for + if G0 is a string it will be interpreted as a method for `semirelaxed_init_plan` taking values in "product", "random_product", - "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". + "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". log : bool, optional record log if True max_iter : int, optional @@ -493,9 +493,9 @@ def semirelaxed_fused_gromov_wasserstein2( If `G0=None` the initial transport plan of the solver is pq^T. If G0 is a tensor it must satisfy marginal constraints and will be used as initial transport of the solver. - if G0 is a string it will be interpreted as a method for + if G0 is a string it will be interpreted as a method for `semirelaxed_init_plan` taking values in "product", "random_product", - "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". + "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". log : bool, optional Record log if True. max_iter : int, optional @@ -708,9 +708,9 @@ def entropic_semirelaxed_gromov_wasserstein( If `G0=None` the initial transport plan of the solver is pq^T. If G0 is a tensor it must satisfy marginal constraints and will be used as initial transport of the solver. - if G0 is a string it will be interpreted as a method for + if G0 is a string it will be interpreted as a method for `semirelaxed_init_plan` taking values in "product", "random_product", - "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". + "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". max_iter : int, optional Max number of iterations tol : float, optional @@ -875,9 +875,9 @@ def entropic_semirelaxed_gromov_wasserstein2( If `G0=None` the initial transport plan of the solver is pq^T. If G0 is a tensor it must satisfy marginal constraints and will be used as initial transport of the solver. - if G0 is a string it will be interpreted as a method for + if G0 is a string it will be interpreted as a method for `semirelaxed_init_plan` taking values in "product", "random_product", - "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". + "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". max_iter : int, optional Max number of iterations tol : float, optional @@ -971,9 +971,9 @@ def entropic_semirelaxed_fused_gromov_wasserstein( If `G0=None` the initial transport plan of the solver is pq^T. If G0 is a tensor it must satisfy marginal constraints and will be used as initial transport of the solver. - if G0 is a string it will be interpreted as a method for + if G0 is a string it will be interpreted as a method for `semirelaxed_init_plan` taking values in "product", "random_product", - "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". + "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". max_iter : int, optional Max number of iterations tol : float, optional @@ -1141,9 +1141,9 @@ def entropic_semirelaxed_fused_gromov_wasserstein2( If `G0=None` the initial transport plan of the solver is pq^T. If G0 is a tensor it must satisfy marginal constraints and will be used as initial transport of the solver. - if G0 is a string it will be interpreted as a method for + if G0 is a string it will be interpreted as a method for `semirelaxed_init_plan` taking values in "product", "random_product", - "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". + "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". max_iter : int, optional Max number of iterations tol : float, optional diff --git a/test/gromov/test_semirelaxed.py b/test/gromov/test_semirelaxed.py index 016273a6c..b7acb8748 100644 --- a/test/gromov/test_semirelaxed.py +++ b/test/gromov/test_semirelaxed.py @@ -69,7 +69,7 @@ def test_semirelaxed_gromov(nx): np.testing.assert_allclose(log2['srgw_dist'], logb['srgw_dist'], atol=1e-07) np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) - ## symmetric - testing various initialization of the OT plan. + # symmetric - testing various initialization of the OT plan. C1 = 0.5 * (C1 + C1.T) print('deg:', C1.sum(0)) @@ -87,9 +87,6 @@ def test_semirelaxed_gromov(nx): ("kmeans", "kmeans"), ("kmeans_soft", "kmeans_soft")] for (init, init_b) in init_plan_list: - print('------') - print('init:', init) - print('init_b:', init_b) G, log = ot.gromov.semirelaxed_gromov_wasserstein( C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=init) @@ -103,7 +100,7 @@ def test_semirelaxed_gromov(nx): if not isinstance(init, str): np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov else: - if not 'spectral' in init: # issues with spectral clustering related to label switching + if 'spectral' not in init: # issues with spectral clustering related to label switching np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2( @@ -717,7 +714,7 @@ def test_semirelaxed_gromov_barycenter(nx): try: np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06) # may differ from permutation - except: + except AssertionError: np.testing.assert_allclose(err['loss'][-1], errb['loss'][-1], atol=1e-06) np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples)) From 17dd7ca3b85d652104904ad5556349ebfa83734d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Tue, 30 Jul 2024 00:38:26 +0200 Subject: [PATCH 09/22] setting up informative inits for sr(F)GW barycenters --- ot/gromov/_semirelaxed.py | 178 +++++++++++++++++++++------ ot/gromov/_utils.py | 39 +++--- test/gromov/test_semirelaxed.py | 212 ++++++++++++++++++++++---------- 3 files changed, 312 insertions(+), 117 deletions(-) diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py index 820163e91..5c14cafac 100644 --- a/ot/gromov/_semirelaxed.py +++ b/ot/gromov/_semirelaxed.py @@ -20,9 +20,17 @@ from ._utils import ( init_matrix_semirelaxed, gwloss, gwggrad, update_barycenter_structure, update_barycenter_feature, - semirelaxed_init_plan + semirelaxed_init_plan, ) +try: + from sklearn.cluster import KMeans + sklearn_import = True +except ImportError: + sklearn_import = False + +import warnings + def semirelaxed_gromov_wasserstein( C1, C2, p=None, loss_fun='square_loss', symmetric=None, log=False, G0=None, @@ -1186,7 +1194,7 @@ def semirelaxed_gromov_barycenters( N, Cs, ps=None, lambdas=None, loss_fun='square_loss', symmetric=True, max_iter=1000, tol=1e-9, stop_criterion='barycenter', warmstartT=False, verbose=False, - log=False, init_C=None, random_state=None, **kwargs): + log=False, init_C=None, G0='product', random_state=None, **kwargs): r""" Returns the Semi-relaxed Gromov-Wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}` @@ -1211,7 +1219,7 @@ def semirelaxed_gromov_barycenters( ps : list of S array-like of shape (ns,), optional Sample weights in the `S` spaces. If let to its default value None, uniform distributions are taken. - lambdas : list of float, optional + lambdas : array-like of shape (S,) , optional List of the `S` spaces' weights. If let to its default value None, uniform weights are taken. loss_fun : callable, optional @@ -1234,8 +1242,13 @@ def semirelaxed_gromov_barycenters( Print information along iterations. log : bool, optional Record log if True. - init_C : bool | array-like, shape(N,N) + init_C : array-like of shape (N,N), optional. Random initial value for the :math:`\mathbf{C}` matrix provided by user. + Default is None and relies `G0` to produce an initial structure. + G0: str, optional. Default is 'product'. + Initialization method following heuristics developed in `semirelaxed_init_plan`. + Methods based on the clustering of inputs are used to deduce an initial + barycenter structure if `init_C=None`. random_state : int or RandomState instance, optional Fix the seed for reproducibility @@ -1270,20 +1283,52 @@ def semirelaxed_gromov_barycenters( S = len(Cs) if lambdas is None: - lambdas = [1. / S] * S + lambdas = nx.ones(S) / S + else: + lambdas = list_to_array(lambdas, nx) - # Initialization of C : random SPD matrix (if not provided by user) + # Initialization of transport plans and C (if not provided by user) if init_C is None: - rng = check_random_state(random_state) - xalea = rng.randn(N, 2) - C = dist(xalea, xalea) - C /= C.max() - C = nx.from_numpy(C, type_as=Cs[0]) + init_C = nx.zeros((N, N), type_as=Cs[0]) + if G0 in ['product', 'random_product']: + T = [semirelaxed_init_plan( + Cs[i], init_C, ps[i], method=G0, use_target=False, + random_state=random_state, nx=nx) for i in range(S)] + init_C = update_barycenter_structure( + T, Cs, lambdas, loss_fun=loss_fun, nx=nx) + + else: # relies on partitioning of inputs + shapes = np.array([C.shape[0] for C in Cs]) + large_graphs_idx = np.where(shapes > N)[0] + small_graphs_idx = np.where(shapes <= N) + T = [] + list_init_C = [] # store different barycenter structure to average + + # we first compute an initial informative barycenter structure + # on graphs we can compress + # then use it on graphs to expand + for indices in [large_graphs_idx, small_graphs_idx]: + if indices.shape[0] > 0: + sub_T = [semirelaxed_init_plan( + Cs[i], init_C, ps[i], method=G0, use_target=False, + random_state=random_state, nx=nx) for i in indices] + sub_Cs = [Cs[i] for i in indices] + sub_lambdas = lambdas[indices] + + init_C = update_barycenter_structure( + sub_T, sub_Cs, sub_lambdas, loss_fun=loss_fun, nx=nx) + T += sub_T + list_init_C.append(init_C) + + if len(list_init_C) == 2: + init_C = update_barycenter_structure( + T, Cs, lambdas, loss_fun=loss_fun, nx=nx) + C = init_C else: C = init_C - - if warmstartT: - T = [None] * S + T = [semirelaxed_init_plan( + Cs[i], C, ps[i], method=G0, use_target=True, + random_state=random_state, nx=nx) for i in range(S)] if stop_criterion == 'barycenter': inner_log = False @@ -1313,7 +1358,7 @@ def semirelaxed_gromov_barycenters( for s in range(S)] else: res = [semirelaxed_gromov_wasserstein( - Cs[s], C, ps[s], loss_fun, symmetric, G0=None, + Cs[s], C, ps[s], loss_fun, symmetric, G0=G0, max_iter=max_iter, tol_rel=tol, tol_abs=0., log=inner_log, verbose=verbose, **kwargs) for s in range(S)] @@ -1329,8 +1374,7 @@ def semirelaxed_gromov_barycenters( [nx.sum(T[s], 0)[None, :] for s in range(S)], axis=0) C = update_barycenter_structure(T, Cs, lambdas, p, loss_fun, nx=nx) - print('p:', p) - print('C:', C) + # update convergence criterion if stop_criterion == 'barycenter': err = nx.norm(C - Cprev) @@ -1366,7 +1410,7 @@ def semirelaxed_fgw_barycenters( fixed_features=False, p=None, loss_fun='square_loss', symmetric=True, max_iter=100, tol=1e-9, stop_criterion='barycenter', warmstartT=False, verbose=False, log=False, init_C=None, init_X=None, - random_state=None, **kwargs): + G0='product', random_state=None, **kwargs): r""" Returns the Semi-relaxed Fused Gromov-Wasserstein barycenters of `S` measurable networks with node features :math:`(\mathbf{C}_s, \mathbf{Y}_s, \mathbf{p}_s)_{1 \leq s \leq S}` @@ -1396,7 +1440,7 @@ def semirelaxed_fgw_barycenters( ps : list of array-like, each element has shape (ns,), optional Masses of all samples. If let to its default value None, uniform distributions are taken. - lambdas : list of float, optional + lambdas : array-like of shape (S,) , optional List of the `S` spaces' weights. If let to its default value None, uniform weights are taken. alpha : float, optional @@ -1431,6 +1475,8 @@ def semirelaxed_fgw_barycenters( init_X : array-like, shape (N,d), optional Initialization for the barycenters' features. If not set a random init is used. + G0: str, optional. Default is 'product'. + Initialization method following heuristics developed in `semirelaxed_init_plan`. random_state : int or RandomState instance, optional Fix the seed for reproducibility @@ -1466,7 +1512,9 @@ def semirelaxed_fgw_barycenters( S = len(Cs) if lambdas is None: - lambdas = [1. / S] * S + lambdas = nx.ones(S) / S + else: + lambdas = list_to_array(lambdas, nx) d = Ys[0].shape[1] # dimension on the node features @@ -1476,14 +1524,6 @@ def semirelaxed_fgw_barycenters( 'If C is fixed it must be provided in init_C') else: C = init_C - else: - if init_C is None: - rng = check_random_state(random_state) - xalea = rng.randn(N, 2) - C = dist(xalea, xalea) - C = nx.from_numpy(C, type_as=ps[0]) - else: - C = init_C if fixed_features: if init_X is None: @@ -1491,17 +1531,84 @@ def semirelaxed_fgw_barycenters( 'If X is fixed it must be provided in init_X') else: X = init_X - else: + + # Initialization of transport plans, C and X (if not provided by user) + if G0 in ['product', 'random_product']: + # both init_X and init_C are simply deduced from transport plans + # if not initialized + if init_C is None: + init_C = nx.zeros((N, N), type_as=Cs[0]) # to know the barycenter shape + + if G0 in ['product', 'random_product']: + T = [semirelaxed_init_plan( + Cs[i], init_C, ps[i], method=G0, use_target=False, + random_state=random_state, nx=nx) for i in range(S)] + + if init_C is None: + init_C = update_barycenter_structure( + T, Cs, lambdas, loss_fun=loss_fun, nx=nx) if init_X is None: - X = nx.zeros((N, d), type_as=ps[0]) + init_X = update_barycenter_feature( + T, Ys, lambdas, loss_fun=loss_fun, nx=nx) + else: + # more computationally costly inits could be used on structures + # so we assume affordable a Kmeans-like init for features + # and use it by default. + + if init_X is None: + stacked_features = nx.to_numpy(nx.concatenate(Ys, axis=0)) + if sklearn_import: + km = KMeans(n_clusters=N, random_state=random_state, + n_init=1).fit(stacked_features) + init_X = nx.from_numpy(km.cluster_centers_) + else: + warnings.warn( + "Kmeans clustering cannot be performed to init barycenter features," + "consider installing scikit-learn.", + stacklevel=2 + ) + X = init_X else: X = init_X - Ms = [dist(Ys[s], X) for s in range(len(Ys))] - - if warmstartT: - T = [None] * S + Ms = [dist(Ys[s], X) for s in range(len(Ys))] + + if (init_C is None): + init_C = nx.zeros((N, N), type_as=Cs[0]) + + # relies on partitioning of inputs + shapes = np.array([C.shape[0] for C in Cs]) + large_graphs_idx = np.where(shapes > N)[0] + small_graphs_idx = np.where(shapes <= N) + T = [] + list_init_C = [] # store different barycenter structure to average + + # we first compute an initial informative barycenter structure + # on graphs we can compress + # then use it on graphs to expand + for indices in [large_graphs_idx, small_graphs_idx]: + if indices.shape[0] > 0: + sub_T = [semirelaxed_init_plan( + Cs[i], init_C, ps[i], Ms[i], alpha, method=G0, use_target=False, + random_state=random_state, nx=nx) for i in indices] + sub_Cs = [Cs[i] for i in indices] + sub_lambdas = lambdas[indices] + + init_C = update_barycenter_structure( + sub_T, sub_Cs, sub_lambdas, loss_fun=loss_fun, nx=nx) + T += sub_T + list_init_C.append(init_C) + + if len(list_init_C) == 2: + init_C = update_barycenter_structure( + T, Cs, lambdas, loss_fun=loss_fun, nx=nx) + C = init_C + else: + C = init_C + T = [semirelaxed_init_plan( + Cs[i], C, ps[i], Ms[i], alpha, method=G0, use_target=True, + random_state=random_state, nx=nx) for i in range(S)] if stop_criterion == 'barycenter': inner_log = False @@ -1535,9 +1642,10 @@ def semirelaxed_fgw_barycenters( for s in range(S)] else: res = [semirelaxed_fused_gromov_wasserstein( - Ms[s], Cs[s], C, ps[s], loss_fun, symmetric, alpha, None, + Ms[s], Cs[s], C, ps[s], loss_fun, symmetric, alpha, G0, inner_log, max_iter, tol_rel=tol, tol_abs=0., **kwargs) for s in range(S)] + if stop_criterion == 'barycenter': T = res else: diff --git a/ot/gromov/_utils.py b/ot/gromov/_utils.py index 23da552e5..b21d57dc1 100644 --- a/ot/gromov/_utils.py +++ b/ot/gromov/_utils.py @@ -388,8 +388,8 @@ def h2(b): return constC, hC1, hC2, fC2t -def semirelaxed_init_plan(C1, C2, p, M=None, alpha=1., - method='product', random_state=0, nx=None): +def semirelaxed_init_plan(C1, C2, p, M=None, alpha=1., method='product', + use_target=True, random_state=0, nx=None): """ Heuristics to initialize the semi-relaxed (F)GW transport plan between a graph :math:`(\mathbf{C1}, \mathbf{p})` and a structure matrix :math:`\mathbf{C2}`. @@ -399,16 +399,20 @@ def semirelaxed_init_plan(C1, C2, p, M=None, alpha=1., ---------- C1 : array-like, shape (ns, ns) Metric cost matrix in the source space. - C2 : array-like, shape (nt, nt) + C2 : array-like, shape (nt, nt) Metric cost matrix in the target space. - p : array-like, shape (ns,) - Probability distribution in the source space. - M : array-like, shape (ns, nt) - Metric cost matrix between features across domains + p : array-like, shape (ns,), optional. + Probability distribution in the source space. If let to None, uniform + weights are assumed on C1. + M : array-like, shape (ns, nt), optional. + Metric cost matrix between features across domains. alpha : float, optional Trade-off parameter (0 <= alpha <= 1) method : str, optional Method to initialize the transport plan. The default is 'product'. + use_target : bool, optional. + Whether or not to use the target structure/features to further align + transport plan provided by the `method`. random_state: int, optional Random seed used for stochastic methods. nx : backend, optional @@ -436,14 +440,15 @@ def semirelaxed_init_plan(C1, C2, p, M=None, alpha=1., if nx is None: arr = [C1, C2, p] if M is not None: - arr += arr + arr.append(M) + nx = get_backend(*arr) n = C1.shape[0] m = C2.shape[0] + min_size = min(n, m) if method in list_partitioning_methods: - min_size = min(n, m) if n > m: # partition C1 to deduce map from C1 to C2 C_to_partition = nx.to_numpy(C1) elif m > n: # partition C2 to deduce map from C1 to C2 @@ -466,11 +471,15 @@ def get_transport_from_partition(part): factored_C1 = nx.dot(nx.dot(T_.T, C1), T_) / nx.outer(q, q) # alignment of both structure seen as feature matrices - M_structure = euclidean_distances(factored_C1, C2) - T_emd = emd(q, q, M_structure) - inv_q = 1. / q + if use_target: + M_structure = euclidean_distances(factored_C1, C2) + T_emd = emd(q, q, M_structure) + inv_q = 1. / q + + T = nx.dot(T_, inv_q[:, None] * T_emd) + else: + T = T_ - T = nx.dot(T_, inv_q[:, None] * T_emd) elif m > n: T_ = nx.eye(n, type_as=C1)[part] / m # assume uniform masses on C2 q = nx.sum(T_, 0) @@ -551,8 +560,8 @@ def get_transport_from_partition(part): if 'soft' in method: T = (T + nx.outer(p, q)) / 2. - # Add feature information solving a semi-relaxed Wasserstein problem - if M is not None: + if (M is not None): + # Add feature information solving a semi-relaxed Wasserstein problem # get minimum by rows as binary mask TM = nx.ones(1, type_as=p) * (M == nx.reshape(nx.min(M, axis=1), (-1, 1))) TM *= nx.reshape((p / nx.sum(TM, axis=1)), (-1, 1)) diff --git a/test/gromov/test_semirelaxed.py b/test/gromov/test_semirelaxed.py index b7acb8748..6d6f4ec2c 100644 --- a/test/gromov/test_semirelaxed.py +++ b/test/gromov/test_semirelaxed.py @@ -71,7 +71,6 @@ def test_semirelaxed_gromov(nx): # symmetric - testing various initialization of the OT plan. C1 = 0.5 * (C1 + C1.T) - print('deg:', C1.sum(0)) C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0) @@ -231,11 +230,17 @@ def test_semirelaxed_fgw(nx): C1 = np.zeros((ns, ns)) C2 = np.array([[0.7, 0.05], [0.05, 0.9]]) + + pos = [0, 16, 24] + for i in range(nt): for j in range(nt): ni, nj = list_n[i], list_n[j] xij = rng.binomial(size=(ni, nj), n=1, p=C2[i, j]) - C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij + pos_i_min, pos_i_max = pos[i], pos[i + 1] + pos_j_min, pos_j_max = pos[j], pos[j + 1] + C1[pos_i_min: pos_i_max, pos_j_min: pos_j_max] = xij + F1 = np.zeros((ns, 1)) F1[:16] = rng.normal(loc=0., scale=0.01, size=(16, 1)) F1[16:] = rng.normal(loc=1., scale=0.01, size=(8, 1)) @@ -247,57 +252,72 @@ def test_semirelaxed_fgw(nx): q0 = ot.unif(C2.shape[0]) G0 = p[:, None] * q0[None, :] - # asymmetric + # asymmetric structure - checking constraints and values Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) - G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, None, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) - Gb, logb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0b) + G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein( + M, C1, C2, None, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) + Gb, logb = ot.gromov.semirelaxed_fused_gromov_wasserstein( + Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0b) - # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov - srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0) - srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, None, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) + # asymmetric - check consistency between srFGW and srFGW2 + srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2( + M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0) + srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2( + Mb, C1b, C2b, None, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) G = log2['T'] Gb = nx.to_numpy(logb2['T']) - # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], G.sum(0), atol=1e-04) # cf convergence gromov np.testing.assert_allclose(log2['srfgw_dist'], logb['srfgw_dist'], atol=1e-07) np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) - # symmetric + # symmetric structures + checking losses + inits + C1 = 0.5 * (C1 + C1.T) + Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) + + init_plan_list = [ + (None, G0b), ('product', None), ("random_product", "random_product")] + + if networkx_import: + init_plan_list += [('fluid', 'fluid')] + + if sklearn_import: + init_plan_list += [("kmeans", "kmeans")] + for loss_fun in ['square_loss', 'kl_loss']: - C1 = 0.5 * (C1 + C1.T) - Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) + for (init, init_b) in init_plan_list: - G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun=loss_fun, alpha=0.5, symmetric=None, log=True, G0=None) - Gb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun=loss_fun, alpha=0.5, symmetric=True, log=False, G0=G0b) + G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein( + M, C1, C2, p, loss_fun=loss_fun, alpha=0.5, symmetric=None, log=True, G0=init) + Gb = ot.gromov.semirelaxed_fused_gromov_wasserstein( + Mb, C1b, C2b, pb, loss_fun=loss_fun, alpha=0.5, symmetric=True, log=False, G0=init_b) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + + # checking consistency with srFGW and srFGW2 solvers srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun=loss_fun, alpha=0.5, symmetric=True, log=True, G0=G0) srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun=loss_fun, alpha=0.5, symmetric=None, log=True, G0=None) - srgw_ = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun=loss_fun, alpha=0.5, symmetric=True, log=False, G0=G0) - - G = log2['T'] - Gb = nx.to_numpy(logb2['T']) + G2 = log2['T'] + Gb2 = nx.to_numpy(logb2['T']) # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(G2, Gb2, atol=1e-06) + np.testing.assert_allclose(G2, G, atol=1e-06) np.testing.assert_allclose(log2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) - np.testing.assert_allclose(srgw, srgw_, atol=1e-07) + np.testing.assert_allclose(srgw, srgwb, atol=1e-07) def test_semirelaxed_fgw2_gradients(): @@ -419,14 +439,21 @@ def test_entropic_semirelaxed_gromov(nx): ns = np.sum(list_n) # create directed sbm with C2 as connectivity matrix C1 = np.zeros((ns, ns), dtype=np.float64) - C2 = np.array([[0.8, 0.05], - [0.05, 1.]], dtype=np.float64) + C2 = np.array([[0.8, 0.1], + [0.1, 0.9]], dtype=np.float64) + rng = np.random.RandomState(0) + + pos = [0, 30, 45] + for i in range(nt): for j in range(nt): ni, nj = list_n[i], list_n[j] xij = rng.binomial(size=(ni, nj), n=1, p=C2[i, j]) - C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij + pos_i_min, pos_i_max = pos[i], pos[i + 1] + pos_j_min, pos_j_max = pos[j], pos[j + 1] + C1[pos_i_min: pos_i_max, pos_j_min: pos_j_max] = xij + p = ot.unif(ns, type_as=C1) q0 = ot.unif(C2.shape[0], type_as=C1) G0 = p[:, None] * q0[None, :] @@ -456,29 +483,56 @@ def test_entropic_semirelaxed_gromov(nx): np.testing.assert_allclose(log2['srgw_dist'], logb['srgw_dist'], atol=1e-07) np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) - # symmetric + # symmetric - testing various initialization of the OT plan. + C1 = 0.5 * (C1 + C1.T) C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0) - G, log = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=None, log=True, G0=None) - Gb = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1b, C2b, None, loss_fun='square_loss', epsilon=epsilon, symmetric=True, log=False, G0=G0b) + init_plan_list = [] # tests longer than with CG so we do not test all inits. - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + if networkx_import: + init_plan_list += [('fluid', 'fluid')] + + if sklearn_import: + init_plan_list += [("kmeans", "kmeans")] + + init_plan_list += [ + ('product', None), (None, G0b)] + + for (init, init_b) in init_plan_list: + print(f'---- init : {init} / init_b : {init_b}') + G, log = ot.gromov.entropic_semirelaxed_gromov_wasserstein( + C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=None, + log=True, G0=init) + Gb, logb = ot.gromov.entropic_semirelaxed_gromov_wasserstein( + C1b, C2b, pb, loss_fun='square_loss', epsilon=epsilon, symmetric=True, + log=True, G0=init_b) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov + + if not isinstance(init, str): + np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + else: + np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) - srgw, log2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=True, log=True, G0=G0) - srgwb, logb2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun='square_loss', epsilon=epsilon, symmetric=None, log=True, G0=None) + # comparison between srGW and srGW2 solvers + + srgw, log2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2( + C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=True, + log=True, G0=init) + srgwb, logb2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2( + C1b, C2b, pb, loss_fun='square_loss', epsilon=epsilon, symmetric=None, + log=True, G0=init_b) srgw_ = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=True, log=False, G0=G0) - G = log2['T'] + G2 = log2['T'] + G2b = logb2['T'] # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, nx.sum(Gb, 1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01) - np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01) + np.testing.assert_allclose(G2, G2b, atol=1e-06) + np.testing.assert_allclose(G2, G, atol=1e-06) np.testing.assert_allclose(log2['srgw_dist'], log['srgw_dist'], atol=1e-07) np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) @@ -532,11 +586,17 @@ def test_entropic_semirelaxed_fgw(nx): C1 = np.zeros((ns, ns)) C2 = np.array([[0.7, 0.05], [0.05, 0.9]]) + + pos = [0, 16, 24] + for i in range(nt): for j in range(nt): ni, nj = list_n[i], list_n[j] xij = rng.binomial(size=(ni, nj), n=1, p=C2[i, j]) - C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij + pos_i_min, pos_i_max = pos[i], pos[i + 1] + pos_j_min, pos_j_max = pos[j], pos[j + 1] + C1[pos_i_min: pos_i_max, pos_j_min: pos_j_max] = xij + F1 = np.zeros((ns, 1)) F1[:16] = rng.normal(loc=0., scale=0.01, size=(16, 1)) F1[16:] = rng.normal(loc=1., scale=0.01, size=(8, 1)) @@ -548,19 +608,20 @@ def test_entropic_semirelaxed_fgw(nx): q0 = ot.unif(C2.shape[0]) G0 = p[:, None] * q0[None, :] - # asymmetric + # asymmetric structure - checking constraints and values Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) G, log = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(M, C1, C2, None, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) Gb, logb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=False, log=True, G0=G0b) - # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov - srgw, log2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=False, log=True, G0=G0) - srgwb, logb2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, None, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) + srgw, log2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2( + M, C1, C2, p, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=False, log=True, G0=G0) + srgwb, logb2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2( + Mb, C1b, C2b, None, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) G = log2['T'] Gb = nx.to_numpy(logb2['T']) @@ -572,34 +633,51 @@ def test_entropic_semirelaxed_fgw(nx): np.testing.assert_allclose(log2['srfgw_dist'], logb['srfgw_dist'], atol=1e-07) np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) - # symmetric + # symmetric structures + checking losses + inits C1 = 0.5 * (C1 + C1.T) Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) - for loss_fun in ['square_loss', 'kl_loss']: - G, log = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) - Gb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=True, log=False, G0=G0b) + init_plan_list = [ + (None, G0b), ('product', None), ("random_product", "random_product")] - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + if networkx_import: + init_plan_list += [('fluid', 'fluid')] - srgw, log2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=True, log=True, G0=G0) - srgwb, logb2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) + if sklearn_import: + init_plan_list += [("kmeans", "kmeans")] - srgw_ = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=True, log=False, G0=G0) + for loss_fun in ['square_loss', 'kl_loss']: + for (init, init_b) in init_plan_list: - G = log2['T'] - Gb = nx.to_numpy(logb2['T']) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov + G, log = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein( + M, C1, C2, p, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, + symmetric=None, log=True, G0=init) + Gb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein( + Mb, C1b, C2b, pb, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, + symmetric=True, log=False, G0=init_b) + + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + + # checking consistency with srFGW and srFGW2 solvers + srgw, log2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2( + M, C1, C2, p, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, + symmetric=True, log=True, G0=init) + srgwb, logb2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2( + Mb, C1b, C2b, pb, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, + symmetric=None, log=True, G0=init_b) + + G2 = log2['T'] + Gb2 = nx.to_numpy(logb2['T']) + np.testing.assert_allclose(G2, Gb2, atol=1e-06) + np.testing.assert_allclose(G2, G, atol=1e-06) + np.testing.assert_allclose(p, Gb2.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], Gb2.sum(0), atol=1e-04) # cf convergence gromov np.testing.assert_allclose(log2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) - np.testing.assert_allclose(srgw, srgw_, atol=1e-07) + np.testing.assert_allclose(srgw, srgwb, atol=1e-07) @pytest.skip_backend("tf", reason="test very slow with tf backend") From 1613d510d6219eeb7cd518c9a89d61090aa18db1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Tue, 30 Jul 2024 01:00:58 +0200 Subject: [PATCH 10/22] updating inits for sr(F)GW bary --- ot/gromov/_semirelaxed.py | 26 +++++++++++++++----------- ot/gromov/_utils.py | 2 +- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py index 5c14cafac..196820d5d 100644 --- a/ot/gromov/_semirelaxed.py +++ b/ot/gromov/_semirelaxed.py @@ -12,7 +12,7 @@ from ..utils import ( - list_to_array, unif, dist, UndefinedParameter, check_random_state + list_to_array, unif, dist, UndefinedParameter ) from ..optim import semirelaxed_cg, solve_1d_linesearch_quad from ..backend import get_backend @@ -405,7 +405,7 @@ def semirelaxed_fused_gromov_wasserstein( G0 = nx.outer(p, q) elif isinstance(G0, str): G0 = semirelaxed_init_plan( - C1, C2, p, M, alpha, G0, random_state, nx) + C1, C2, p, M, alpha, G0, random_state=random_state, nx=nx) q = nx.sum(G0, 0) else: q = nx.sum(G0, 0) @@ -1027,7 +1027,7 @@ def entropic_semirelaxed_fused_gromov_wasserstein( G0 = nx.outer(p, q) elif isinstance(G0, str): G0 = semirelaxed_init_plan( - C1, C2, p, M, alpha, G0, random_state, nx) + C1, C2, p, M, alpha, G0, random_state=random_state, nx=nx) q = nx.sum(G0, 0) else: q = nx.sum(G0, 0) @@ -1285,7 +1285,8 @@ def semirelaxed_gromov_barycenters( if lambdas is None: lambdas = nx.ones(S) / S else: - lambdas = list_to_array(lambdas, nx) + lambdas = list_to_array(lambdas) + lambdas = nx.from_numpy(lambdas) # Initialization of transport plans and C (if not provided by user) if init_C is None: @@ -1294,7 +1295,7 @@ def semirelaxed_gromov_barycenters( T = [semirelaxed_init_plan( Cs[i], init_C, ps[i], method=G0, use_target=False, random_state=random_state, nx=nx) for i in range(S)] - init_C = update_barycenter_structure( + C = update_barycenter_structure( T, Cs, lambdas, loss_fun=loss_fun, nx=nx) else: # relies on partitioning of inputs @@ -1514,9 +1515,8 @@ def semirelaxed_fgw_barycenters( if lambdas is None: lambdas = nx.ones(S) / S else: - lambdas = list_to_array(lambdas, nx) - - d = Ys[0].shape[1] # dimension on the node features + lambdas = list_to_array(lambdas) + lambdas = nx.from_numpy(lambdas) if fixed_structure: if init_C is None: @@ -1545,12 +1545,16 @@ def semirelaxed_fgw_barycenters( random_state=random_state, nx=nx) for i in range(S)] if init_C is None: - init_C = update_barycenter_structure( + C = update_barycenter_structure( T, Cs, lambdas, loss_fun=loss_fun, nx=nx) + else: + C = init_C + if init_X is None: - init_X = update_barycenter_feature( + X = update_barycenter_feature( T, Ys, lambdas, loss_fun=loss_fun, nx=nx) - + else: + X = init_X else: # more computationally costly inits could be used on structures # so we assume affordable a Kmeans-like init for features diff --git a/ot/gromov/_utils.py b/ot/gromov/_utils.py index b21d57dc1..509314f77 100644 --- a/ot/gromov/_utils.py +++ b/ot/gromov/_utils.py @@ -399,7 +399,7 @@ def semirelaxed_init_plan(C1, C2, p, M=None, alpha=1., method='product', ---------- C1 : array-like, shape (ns, ns) Metric cost matrix in the source space. - C2 : array-like, shape (nt, nt) + C2 : array-like, shape (nt, nt) Metric cost matrix in the target space. p : array-like, shape (ns,), optional. Probability distribution in the source space. If let to None, uniform From 5484f4ed1d893168d13aa0266f7f994898a0e583 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Tue, 30 Jul 2024 17:52:33 +0200 Subject: [PATCH 11/22] working on bary tests --- ot/gromov/_semirelaxed.py | 62 ++++++++++++----- ot/gromov/_utils.py | 12 +++- ot/optim.py | 4 +- test/gromov/test_semirelaxed.py | 118 ++++++++++++++++++++++++-------- 4 files changed, 145 insertions(+), 51 deletions(-) diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py index 196820d5d..1ab3ad54e 100644 --- a/ot/gromov/_semirelaxed.py +++ b/ot/gromov/_semirelaxed.py @@ -84,7 +84,7 @@ def semirelaxed_gromov_wasserstein( If G0 is a tensor it must satisfy marginal constraints and will be used as initial transport of the solver. if G0 is a string it will be interpreted as a method for - `semirelaxed_init_plan` taking values in "product", "random_product", + `semirelaxed_init_plan` taking values in "product", "random_product", "random", "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". max_iter : int, optional Max number of iterations @@ -232,7 +232,7 @@ def semirelaxed_gromov_wasserstein2( If G0 is a tensor it must satisfy marginal constraints and will be used as initial transport of the solver. if G0 is a string it will be interpreted as a method for - `semirelaxed_init_plan` taking values in "product", "random_product", + `semirelaxed_init_plan` taking values in "product", "random_product", "random", "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". max_iter : int, optional Max number of iterations @@ -347,7 +347,7 @@ def semirelaxed_fused_gromov_wasserstein( If G0 is a tensor it must satisfy marginal constraints and will be used as initial transport of the solver. if G0 is a string it will be interpreted as a method for - `semirelaxed_init_plan` taking values in "product", "random_product", + `semirelaxed_init_plan` taking values in "product", "random_product", "random", "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". log : bool, optional record log if True @@ -502,7 +502,7 @@ def semirelaxed_fused_gromov_wasserstein2( If G0 is a tensor it must satisfy marginal constraints and will be used as initial transport of the solver. if G0 is a string it will be interpreted as a method for - `semirelaxed_init_plan` taking values in "product", "random_product", + `semirelaxed_init_plan` taking values in "product", "random_product", "random", "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". log : bool, optional Record log if True. @@ -717,7 +717,7 @@ def entropic_semirelaxed_gromov_wasserstein( If G0 is a tensor it must satisfy marginal constraints and will be used as initial transport of the solver. if G0 is a string it will be interpreted as a method for - `semirelaxed_init_plan` taking values in "product", "random_product", + `semirelaxed_init_plan` taking values in "product", "random_product", "random", "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". max_iter : int, optional Max number of iterations @@ -884,7 +884,7 @@ def entropic_semirelaxed_gromov_wasserstein2( If G0 is a tensor it must satisfy marginal constraints and will be used as initial transport of the solver. if G0 is a string it will be interpreted as a method for - `semirelaxed_init_plan` taking values in "product", "random_product", + `semirelaxed_init_plan` taking values in "product", "random_product", "random", "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". max_iter : int, optional Max number of iterations @@ -980,7 +980,7 @@ def entropic_semirelaxed_fused_gromov_wasserstein( If G0 is a tensor it must satisfy marginal constraints and will be used as initial transport of the solver. if G0 is a string it will be interpreted as a method for - `semirelaxed_init_plan` taking values in "product", "random_product", + `semirelaxed_init_plan` taking values in "product", "random_product", "random", "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". max_iter : int, optional Max number of iterations @@ -1150,7 +1150,7 @@ def entropic_semirelaxed_fused_gromov_wasserstein2( If G0 is a tensor it must satisfy marginal constraints and will be used as initial transport of the solver. if G0 is a string it will be interpreted as a method for - `semirelaxed_init_plan` taking values in "product", "random_product", + `semirelaxed_init_plan` taking values in "product", "random_product", "random", "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". max_iter : int, optional Max number of iterations @@ -1245,7 +1245,7 @@ def semirelaxed_gromov_barycenters( init_C : array-like of shape (N,N), optional. Random initial value for the :math:`\mathbf{C}` matrix provided by user. Default is None and relies `G0` to produce an initial structure. - G0: str, optional. Default is 'product'. + G0: str, optional. Default is 'random'. Initialization method following heuristics developed in `semirelaxed_init_plan`. Methods based on the clustering of inputs are used to deduce an initial barycenter structure if `init_C=None`. @@ -1291,23 +1291,34 @@ def semirelaxed_gromov_barycenters( # Initialization of transport plans and C (if not provided by user) if init_C is None: init_C = nx.zeros((N, N), type_as=Cs[0]) - if G0 in ['product', 'random_product']: + if G0 in ['product', 'random_product', 'random']: T = [semirelaxed_init_plan( Cs[i], init_C, ps[i], method=G0, use_target=False, random_state=random_state, nx=nx) for i in range(S)] C = update_barycenter_structure( T, Cs, lambdas, loss_fun=loss_fun, nx=nx) + if G0 in ['product', 'random_product']: + # initial structure is constant so we add a small random noise + # to avoid getting stuck at init + np.random.seed(random_state) + noise = np.random.uniform(-0.01, 0.01, size=(N, N)) + if symmetric: + noise = (noise + noise.T) / 2. + noise = nx.from_numpy(noise) + C = C + noise + else: # relies on partitioning of inputs shapes = np.array([C.shape[0] for C in Cs]) large_graphs_idx = np.where(shapes > N)[0] - small_graphs_idx = np.where(shapes <= N) + small_graphs_idx = np.where(shapes <= N)[0] T = [] list_init_C = [] # store different barycenter structure to average # we first compute an initial informative barycenter structure # on graphs we can compress # then use it on graphs to expand + print('--- looping on indices ---') for indices in [large_graphs_idx, small_graphs_idx]: if indices.shape[0] > 0: sub_T = [semirelaxed_init_plan( @@ -1315,16 +1326,19 @@ def semirelaxed_gromov_barycenters( random_state=random_state, nx=nx) for i in indices] sub_Cs = [Cs[i] for i in indices] sub_lambdas = lambdas[indices] - + print('sub_T:', sub_T) init_C = update_barycenter_structure( sub_T, sub_Cs, sub_lambdas, loss_fun=loss_fun, nx=nx) + print('init_C:', init_C.shape) T += sub_T list_init_C.append(init_C) if len(list_init_C) == 2: + print('len(T):', len(T), T[0].shape, T[1].shape) init_C = update_barycenter_structure( T, Cs, lambdas, loss_fun=loss_fun, nx=nx) C = init_C + else: C = init_C T = [semirelaxed_init_plan( @@ -1533,20 +1547,29 @@ def semirelaxed_fgw_barycenters( X = init_X # Initialization of transport plans, C and X (if not provided by user) - if G0 in ['product', 'random_product']: + if G0 in ['product', 'random_product', 'random']: # both init_X and init_C are simply deduced from transport plans # if not initialized if init_C is None: init_C = nx.zeros((N, N), type_as=Cs[0]) # to know the barycenter shape - if G0 in ['product', 'random_product']: - T = [semirelaxed_init_plan( - Cs[i], init_C, ps[i], method=G0, use_target=False, - random_state=random_state, nx=nx) for i in range(S)] + T = [semirelaxed_init_plan( + Cs[i], init_C, ps[i], method=G0, use_target=False, + random_state=random_state, nx=nx) for i in range(S)] if init_C is None: C = update_barycenter_structure( T, Cs, lambdas, loss_fun=loss_fun, nx=nx) + if G0 in ['product', 'random_product']: + # initial structure is constant so we add a small random noise + # to avoid getting stuck at init + np.random.seed(random_state) + noise = np.random.uniform(-0.01, 0.01, size=(N, N)) + if symmetric: + noise = (noise + noise.T) / 2. + noise = nx.from_numpy(noise) + C = C + noise + else: C = init_C @@ -1555,6 +1578,9 @@ def semirelaxed_fgw_barycenters( T, Ys, lambdas, loss_fun=loss_fun, nx=nx) else: X = init_X + + Ms = [dist(Ys[s], X) for s in range(len(Ys))] + else: # more computationally costly inits could be used on structures # so we assume affordable a Kmeans-like init for features @@ -1584,7 +1610,7 @@ def semirelaxed_fgw_barycenters( # relies on partitioning of inputs shapes = np.array([C.shape[0] for C in Cs]) large_graphs_idx = np.where(shapes > N)[0] - small_graphs_idx = np.where(shapes <= N) + small_graphs_idx = np.where(shapes <= N)[0] T = [] list_init_C = [] # store different barycenter structure to average diff --git a/ot/gromov/_utils.py b/ot/gromov/_utils.py index 509314f77..1c33f1bab 100644 --- a/ot/gromov/_utils.py +++ b/ot/gromov/_utils.py @@ -434,7 +434,7 @@ def semirelaxed_init_plan(C1, C2, p, M=None, alpha=1., method='product', 'fluid', 'spectral', 'kmeans', 'fluid_soft', 'spectral_soft', 'kmeans_soft'] - if method not in list_partitioning_methods + ['product', 'random_product']: + if method not in list_partitioning_methods + ['product', 'random_product', 'random']: raise ValueError(f'Unsupported initialization method = {method}.') if nx is None: @@ -504,16 +504,22 @@ def get_transport_from_partition(part): # Handle initialization via structure information if method == 'product': - q = nx.ones(m, type_as=C2) / m + q = nx.ones(m, type_as=C1) / m T = nx.outer(p, q) elif method == 'random_product': np.random.seed(random_state) q = np.random.uniform(0, m, size=(m,)) q = q / q.sum() - q = nx.from_numpy(q) + q = nx.from_numpy(q, type_as=p) T = nx.outer(p, q) + elif method == 'random': + np.random.seed(random_state) + U = np.random.uniform(0, n * m, size=(n, m)) + U = (p / U.sum(1))[:, None] * U + T = nx.from_numpy(U, type_as=C1) + elif method in ['fluid', 'fluid_soft']: # compute fluid partitioning on the biggest graph if C_to_partition is None: diff --git a/ot/optim.py b/ot/optim.py index bde0fc814..185669a88 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -476,7 +476,9 @@ def semirelaxed_cg(a, b, M, reg, f, df, G0=None, line_search=line_search_armijo, def lp_solver(a, b, Mi, **kwargs): # get minimum by rows as binary mask - Gc = nx.ones(1, type_as=a) * (Mi == nx.reshape(nx.min(Mi, axis=1), (-1, 1))) + min_ = nx.reshape(nx.min(Mi, axis=1), (-1, 1)) + # instead of exact equality to min_ we consider a small margin for float precision issues. + Gc = nx.ones(1, type_as=a) * (Mi <= min_ + 1e-15) Gc *= nx.reshape((a / nx.sum(Gc, axis=1)), (-1, 1)) # return by default an empty inner_log return Gc, {} diff --git a/test/gromov/test_semirelaxed.py b/test/gromov/test_semirelaxed.py index 6d6f4ec2c..7237ce3e8 100644 --- a/test/gromov/test_semirelaxed.py +++ b/test/gromov/test_semirelaxed.py @@ -747,37 +747,69 @@ def test_semirelaxed_gromov_barycenter(nx): ) # test consistency of outputs across backends with 'square_loss' - for stop_criterion in ['barycenter', 'loss']: - Cb = ot.gromov.semirelaxed_gromov_barycenters( - n_samples, [C1, C2], None, [.5, .5], 'square_loss', max_iter=10, - tol=1e-3, stop_criterion=stop_criterion, verbose=False, - random_state=42 - ) - Cbb = nx.to_numpy(ot.gromov.semirelaxed_gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', - max_iter=10, tol=1e-3, stop_criterion=stop_criterion, - verbose=False, random_state=42 - )) - np.testing.assert_allclose(Cb, Cbb, atol=1e-06) - np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) - - # test of gromov_barycenters with `log` on - Cb_, err_ = ot.gromov.semirelaxed_gromov_barycenters( - n_samples, [C1, C2], [p1, p2], None, 'square_loss', max_iter=10, - tol=1e-3, stop_criterion=stop_criterion, verbose=False, - warmstartT=True, random_state=42, log=True - ) - Cbb_, errb_ = ot.gromov.semirelaxed_gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', - max_iter=10, tol=1e-3, stop_criterion=stop_criterion, - verbose=False, warmstartT=True, random_state=42, log=True - ) + # using different losses + # + tests on different inits + init_plan_list = [('random', 'random')] - Cbb_ = nx.to_numpy(Cbb_) + if networkx_import: + init_plan_list += [('fluid', 'fluid')] - np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06) - np.testing.assert_array_almost_equal(err_['err'], nx.to_numpy(*errb_['err'])) - np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples)) + if sklearn_import: + init_plan_list += [("kmeans", "kmeans")] + + for (init, init_b) in init_plan_list: + + for stop_criterion in ['barycenter', 'loss']: + print('--- stop_criterion:', stop_criterion) + Cb = ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1, C2], None, [.5, .5], 'square_loss', max_iter=10, + tol=1e-3, stop_criterion=stop_criterion, verbose=False, + random_state=42, G0=init + ) + Cbb = nx.to_numpy(ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', + max_iter=10, tol=1e-3, stop_criterion=stop_criterion, + verbose=False, random_state=42, G0=init_b + )) + np.testing.assert_allclose(Cb, Cbb, atol=1e-06) + np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) + + # test of gromov_barycenters with `log` on + Cb_, err_ = ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1, C2], [p1, p2], None, 'square_loss', max_iter=10, + tol=1e-3, stop_criterion=stop_criterion, verbose=False, + warmstartT=True, random_state=42, log=True, G0=init, + ) + Cbb_, errb_ = ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', + max_iter=10, tol=1e-3, stop_criterion=stop_criterion, + verbose=False, warmstartT=True, random_state=42, log=True, G0=init_b + ) + + Cbb_ = nx.to_numpy(Cbb_) + + np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06) + np.testing.assert_array_almost_equal(err_['err'], nx.to_numpy(*errb_['err'])) + np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples)) + # test consistency across backends with larger barycenter than inputs + C = ot.gromov.semirelaxed_gromov_barycenters( + ns, [C1, C2], None, [.5, .5], 'square_loss', max_iter=10, + tol=1e-3, stop_criterion='loss', verbose=False, + random_state=42, G0='kmeans' + ) + Cb = nx.to_numpy(ot.gromov.semirelaxed_gromov_barycenters( + ns, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', + max_iter=10, tol=1e-3, stop_criterion=stop_criterion, + verbose=False, random_state=42, G0='kmeans' + )) + np.testing.assert_allclose(C, Cb, atol=1e-06) + # test providing init_C + Cb_ = nx.to_numpy(ot.gromov.semirelaxed_gromov_barycenters( + ns, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', + max_iter=10, tol=1e-3, stop_criterion=stop_criterion, + verbose=False, random_state=42, G0='kmeans', init_C=Cb + )) + np.testing.assert_allclose(Cb, Cb_, atol=1e-06) # test consistency across backends with 'kl_loss' Cb2, err = ot.gromov.semirelaxed_gromov_barycenters( @@ -943,6 +975,34 @@ def test_semirelaxed_fgw_barycenter(nx): np.testing.assert_allclose(C, recovered_C) + # test consistency of outputs across backends with 'square_loss' + # with various initialization of G0 + init_plan_list = [('random', 'random')] + + if networkx_import: + init_plan_list += [('fluid', 'fluid')] + + if sklearn_import: + init_plan_list += [("kmeans", "kmeans")] + + for (init, init_b) in init_plan_list: + print(f'---- init : {init} / init_b : {init_b}') + + X, C, log = ot.gromov.semirelaxed_fgw_barycenters( + n_samples, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5, + fixed_structure=False, fixed_features=False, loss_fun='square_loss', + max_iter=10, tol=1e-3, stop_criterion='loss', G0=init, + warmstartT=True, random_state=12345, log=True, verbose=True + ) + Xb, Cb, logb = ot.gromov.semirelaxed_fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, + fixed_structure=False, fixed_features=False, loss_fun='square_loss', + max_iter=10, tol=1e-3, stop_criterion='loss', G0=init_b, + warmstartT=True, random_state=12345, log=True, verbose=True + ) + np.testing.assert_allclose(X, nx.to_numpy(Xb)) + np.testing.assert_allclose(C, nx.to_numpy(Cb)) + # test edge cases for semirelaxed fgw barycenters: # unique input structure X, C = ot.gromov.semirelaxed_fgw_barycenters( From e7ee31e9255e2f4864dff43ddf7395419afe6253 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Tue, 30 Jul 2024 21:39:34 +0200 Subject: [PATCH 12/22] complete tests sr(F)GW barycenters with inits --- ot/gromov/_semirelaxed.py | 47 ++++++++++++++++++++++----------- test/gromov/test_semirelaxed.py | 45 +++++++++++++++++++++++++------ 2 files changed, 69 insertions(+), 23 deletions(-) diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py index 1ab3ad54e..2d5e7cd22 100644 --- a/ot/gromov/_semirelaxed.py +++ b/ot/gromov/_semirelaxed.py @@ -12,7 +12,7 @@ from ..utils import ( - list_to_array, unif, dist, UndefinedParameter + list_to_array, unif, dist, UndefinedParameter, euclidean_distances ) from ..optim import semirelaxed_cg, solve_1d_linesearch_quad from ..backend import get_backend @@ -1312,7 +1312,7 @@ def semirelaxed_gromov_barycenters( shapes = np.array([C.shape[0] for C in Cs]) large_graphs_idx = np.where(shapes > N)[0] small_graphs_idx = np.where(shapes <= N)[0] - T = [] + T = [None] * S list_init_C = [] # store different barycenter structure to average # we first compute an initial informative barycenter structure @@ -1320,21 +1320,25 @@ def semirelaxed_gromov_barycenters( # then use it on graphs to expand print('--- looping on indices ---') for indices in [large_graphs_idx, small_graphs_idx]: - if indices.shape[0] > 0: + if len(indices) > 0: sub_T = [semirelaxed_init_plan( Cs[i], init_C, ps[i], method=G0, use_target=False, random_state=random_state, nx=nx) for i in indices] sub_Cs = [Cs[i] for i in indices] - sub_lambdas = lambdas[indices] + print('shapes sub_CS:', [x.shape for x in sub_Cs]) + sub_lambdas = lambdas[indices] / nx.sum(lambdas[indices]) + print('sub_lambdas:', sub_lambdas) print('sub_T:', sub_T) init_C = update_barycenter_structure( sub_T, sub_Cs, sub_lambdas, loss_fun=loss_fun, nx=nx) print('init_C:', init_C.shape) - T += sub_T + for i, idx in enumerate(indices): + T[idx] = sub_T[i] list_init_C.append(init_C) if len(list_init_C) == 2: print('len(T):', len(T), T[0].shape, T[1].shape) + print('len(Cs):', [x.shape for x in Cs]) init_C = update_barycenter_structure( T, Cs, lambdas, loss_fun=loss_fun, nx=nx) C = init_C @@ -1553,11 +1557,10 @@ def semirelaxed_fgw_barycenters( if init_C is None: init_C = nx.zeros((N, N), type_as=Cs[0]) # to know the barycenter shape - T = [semirelaxed_init_plan( - Cs[i], init_C, ps[i], method=G0, use_target=False, - random_state=random_state, nx=nx) for i in range(S)] + T = [semirelaxed_init_plan( + Cs[i], init_C, ps[i], method=G0, use_target=False, + random_state=random_state, nx=nx) for i in range(S)] - if init_C is None: C = update_barycenter_structure( T, Cs, lambdas, loss_fun=loss_fun, nx=nx) if G0 in ['product', 'random_product']: @@ -1571,6 +1574,10 @@ def semirelaxed_fgw_barycenters( C = C + noise else: + T = [semirelaxed_init_plan( + Cs[i], init_C, ps[i], method=G0, use_target=False, + random_state=random_state, nx=nx) for i in range(S)] + C = init_C if init_X is None: @@ -1587,17 +1594,24 @@ def semirelaxed_fgw_barycenters( # and use it by default. if init_X is None: - stacked_features = nx.to_numpy(nx.concatenate(Ys, axis=0)) + stacked_features = nx.concatenate(Ys, axis=0) if sklearn_import: + stacked_features = nx.to_numpy(stacked_features) km = KMeans(n_clusters=N, random_state=random_state, n_init=1).fit(stacked_features) init_X = nx.from_numpy(km.cluster_centers_) else: warnings.warn( - "Kmeans clustering cannot be performed to init barycenter features," - "consider installing scikit-learn.", + "Kmeans clustering cannot be performed to init barycenter " + "features, so we only randomly sample input features. " + "Consider installing scikit-learn.", stacklevel=2 ) + np.random.seed(random_state) + selected_idx = np.random.choice( + np.arange(stacked_features.shape[0]), size=N, replace=False) + init_X = stacked_features[selected_idx] + X = init_X else: X = init_X @@ -1611,14 +1625,14 @@ def semirelaxed_fgw_barycenters( shapes = np.array([C.shape[0] for C in Cs]) large_graphs_idx = np.where(shapes > N)[0] small_graphs_idx = np.where(shapes <= N)[0] - T = [] + T = [None] * S list_init_C = [] # store different barycenter structure to average # we first compute an initial informative barycenter structure # on graphs we can compress # then use it on graphs to expand for indices in [large_graphs_idx, small_graphs_idx]: - if indices.shape[0] > 0: + if len(indices) > 0: sub_T = [semirelaxed_init_plan( Cs[i], init_C, ps[i], Ms[i], alpha, method=G0, use_target=False, random_state=random_state, nx=nx) for i in indices] @@ -1627,7 +1641,10 @@ def semirelaxed_fgw_barycenters( init_C = update_barycenter_structure( sub_T, sub_Cs, sub_lambdas, loss_fun=loss_fun, nx=nx) - T += sub_T + + for i, idx in enumerate(indices): + T[idx] = sub_T[i] + list_init_C.append(init_C) if len(list_init_C) == 2: diff --git a/test/gromov/test_semirelaxed.py b/test/gromov/test_semirelaxed.py index 7237ce3e8..bfa7ee2e1 100644 --- a/test/gromov/test_semirelaxed.py +++ b/test/gromov/test_semirelaxed.py @@ -722,6 +722,8 @@ def test_entropic_semirelaxed_fgw_dtype_device(nx): nx.assert_same_dtype_device(C1b, fgw_valb) +@pytest.skip_backend("tf", reason="test very slow with tf backend") +@pytest.skip_backend("jax", reason="test very slow with tf backend") def test_semirelaxed_gromov_barycenter(nx): ns = 5 nt = 8 @@ -797,18 +799,18 @@ def test_semirelaxed_gromov_barycenter(nx): tol=1e-3, stop_criterion='loss', verbose=False, random_state=42, G0='kmeans' ) - Cb = nx.to_numpy(ot.gromov.semirelaxed_gromov_barycenters( + Cb = ot.gromov.semirelaxed_gromov_barycenters( ns, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', max_iter=10, tol=1e-3, stop_criterion=stop_criterion, - verbose=False, random_state=42, G0='kmeans' - )) - np.testing.assert_allclose(C, Cb, atol=1e-06) + verbose=False, random_state=42, G0='kmeans') + + np.testing.assert_allclose(C, nx.to_numpy(Cb), atol=1e-06) # test providing init_C - Cb_ = nx.to_numpy(ot.gromov.semirelaxed_gromov_barycenters( + Cb_ = ot.gromov.semirelaxed_gromov_barycenters( ns, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', max_iter=10, tol=1e-3, stop_criterion=stop_criterion, - verbose=False, random_state=42, G0='kmeans', init_C=Cb - )) + verbose=False, random_state=42, G0='kmeans', init_C=Cb) + np.testing.assert_allclose(Cb, Cb_, atol=1e-06) # test consistency across backends with 'kl_loss' @@ -977,7 +979,7 @@ def test_semirelaxed_fgw_barycenter(nx): # test consistency of outputs across backends with 'square_loss' # with various initialization of G0 - init_plan_list = [('random', 'random')] + init_plan_list = [('product', 'product'), ('random', 'random')] if networkx_import: init_plan_list += [('fluid', 'fluid')] @@ -1003,6 +1005,33 @@ def test_semirelaxed_fgw_barycenter(nx): np.testing.assert_allclose(X, nx.to_numpy(Xb)) np.testing.assert_allclose(C, nx.to_numpy(Cb)) + # test while providing advanced T inits and init_X != None + Xb_, Cb_, logb_ = ot.gromov.semirelaxed_fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, + fixed_structure=False, fixed_features=False, loss_fun='square_loss', + max_iter=10, tol=1e-3, stop_criterion='loss', G0=init_b, + warmstartT=True, random_state=12345, log=True, verbose=True, + init_C=Cb, init_X=Xb + ) + np.testing.assert_allclose(Xb, Xb_) + np.testing.assert_allclose(Cb, Cb_) + + # test consistency of backends while barycenter size not strictly inferior to sizes + X, C, log = ot.gromov.semirelaxed_fgw_barycenters( + ns, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5, + fixed_structure=False, fixed_features=False, loss_fun='square_loss', + max_iter=10, tol=1e-3, stop_criterion='loss', G0=init, + warmstartT=True, random_state=12345, log=True, verbose=True + ) + Xb, Cb, logb = ot.gromov.semirelaxed_fgw_barycenters( + ns, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, + fixed_structure=False, fixed_features=False, loss_fun='square_loss', + max_iter=10, tol=1e-3, stop_criterion='loss', G0=init_b, + warmstartT=True, random_state=12345, log=True, verbose=True + ) + np.testing.assert_allclose(X, nx.to_numpy(Xb)) + np.testing.assert_allclose(C, nx.to_numpy(Cb)) + # test edge cases for semirelaxed fgw barycenters: # unique input structure X, C = ot.gromov.semirelaxed_fgw_barycenters( From 935087ceb9b59dca10c404220eae310a102745e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Wed, 7 Aug 2024 01:35:07 +0200 Subject: [PATCH 13/22] update tests and exemple --- ...mirelaxed_gromov_wasserstein_barycenter.py | 23 +++++++++---------- ot/gromov/_semirelaxed.py | 11 ++------- ot/gromov/_utils.py | 22 ++++++++---------- test/gromov/test_semirelaxed.py | 7 +++--- 4 files changed, 26 insertions(+), 37 deletions(-) diff --git a/examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.py b/examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.py index cef4ba9a7..8c069fb10 100644 --- a/examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.py +++ b/examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.py @@ -44,7 +44,7 @@ import matplotlib.pylab as pl from sklearn.manifold import MDS from ot.gromov import ( - semirelaxed_gromov_barycenters, semirelaxed_fgw_barycenters) + semirelaxed_gromov_barycenters) import ot import networkx from networkx.generators.community import stochastic_block_model as sbm @@ -115,7 +115,7 @@ def plot_graph(x, C, binary=True, color='C0', s=None): ############################################################################# # -# Estimate the Gromov-Wasserstein dictionary from the dataset +# Estimate the srGW barycenter from the dataset # ----------------------------------------------------------- @@ -124,17 +124,18 @@ def plot_graph(x, C, binary=True, color='C0', s=None): lambdas = [1. / n_samples for _ in range(n_samples)] # uniform barycenter N = 3 # 3 nodes in the barycenter -# Optionally provide an initial barycenter structure `init_C` +# Here we use the Fluid partitioning method to deduce initial transport plans +# for the barycenter problem. An initlal structure is also deduced from these +# initial transport plans. Then a warmstart strategy is used iteratively to +# init each individual srGW problem within the BCD algorithm. -init_C = np.array([[0.6, 0.2, 0.2], - [0.2, 0.6, 0.2], - [0.2, 0.2, 0.6]]) - -print('init_C:', init_C) +init_plan = 'fluid' # notice that several init options are implemented in `ot.gromov.semirelaxed_init_plan` +warmstartT = True C, log = semirelaxed_gromov_barycenters( - N=N, Cs=dataset, ps=ps, lambdas=lambdas, loss_fun='square_loss', tol=1e-6, - stop_criterion='loss', warmstartT=True, log=True, init_C=init_C) + N=N, Cs=dataset, ps=ps, lambdas=lambdas, loss_fun='square_loss', + tol=1e-6, stop_criterion='loss', warmstartT=warmstartT, log=True, + G0=init_plan, verbose=False) # visualize loss evolution over epochs pl.figure(2, (4, 3)) @@ -145,5 +146,3 @@ def plot_graph(x, C, binary=True, color='C0', s=None): pl.ylabel('loss', fontsize=12) pl.tight_layout() pl.show() - -print('C:', C) diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py index 2d5e7cd22..c1c6303b1 100644 --- a/ot/gromov/_semirelaxed.py +++ b/ot/gromov/_semirelaxed.py @@ -12,7 +12,7 @@ from ..utils import ( - list_to_array, unif, dist, UndefinedParameter, euclidean_distances + list_to_array, unif, dist, UndefinedParameter ) from ..optim import semirelaxed_cg, solve_1d_linesearch_quad from ..backend import get_backend @@ -1318,27 +1318,20 @@ def semirelaxed_gromov_barycenters( # we first compute an initial informative barycenter structure # on graphs we can compress # then use it on graphs to expand - print('--- looping on indices ---') for indices in [large_graphs_idx, small_graphs_idx]: if len(indices) > 0: sub_T = [semirelaxed_init_plan( Cs[i], init_C, ps[i], method=G0, use_target=False, random_state=random_state, nx=nx) for i in indices] sub_Cs = [Cs[i] for i in indices] - print('shapes sub_CS:', [x.shape for x in sub_Cs]) sub_lambdas = lambdas[indices] / nx.sum(lambdas[indices]) - print('sub_lambdas:', sub_lambdas) - print('sub_T:', sub_T) init_C = update_barycenter_structure( sub_T, sub_Cs, sub_lambdas, loss_fun=loss_fun, nx=nx) - print('init_C:', init_C.shape) for i, idx in enumerate(indices): T[idx] = sub_T[i] list_init_C.append(init_C) if len(list_init_C) == 2: - print('len(T):', len(T), T[0].shape, T[1].shape) - print('len(Cs):', [x.shape for x in Cs]) init_C = update_barycenter_structure( T, Cs, lambdas, loss_fun=loss_fun, nx=nx) C = init_C @@ -1637,7 +1630,7 @@ def semirelaxed_fgw_barycenters( Cs[i], init_C, ps[i], Ms[i], alpha, method=G0, use_target=False, random_state=random_state, nx=nx) for i in indices] sub_Cs = [Cs[i] for i in indices] - sub_lambdas = lambdas[indices] + sub_lambdas = lambdas[indices] / nx.sum(lambdas[indices]) init_C = update_barycenter_structure( sub_T, sub_Cs, sub_lambdas, loss_fun=loss_fun, nx=nx) diff --git a/ot/gromov/_utils.py b/ot/gromov/_utils.py index 1c33f1bab..999e325e8 100644 --- a/ot/gromov/_utils.py +++ b/ot/gromov/_utils.py @@ -437,12 +437,14 @@ def semirelaxed_init_plan(C1, C2, p, M=None, alpha=1., method='product', if method not in list_partitioning_methods + ['product', 'random_product', 'random']: raise ValueError(f'Unsupported initialization method = {method}.') - if nx is None: - arr = [C1, C2, p] - if M is not None: - arr.append(M) + if (method in ['kmeans', 'kmeans_soft']) and (not sklearn_import): + raise ValueError(f'Scikit-learn must be installed to use method = {method}') - nx = get_backend(*arr) + if (method in ['fluid', 'fluid_soft']) and (not networkx_import): + raise ValueError(f'Networkx must be installed to use method = {method}') + + if nx is None: + nx = get_backend(C1, C2, p, M) n = C1.shape[0] m = C2.shape[0] @@ -644,10 +646,7 @@ def update_barycenter_structure( """ if nx is None: - arr = [*Ts, *Cs] - if p is not None: - arr += [p] - + arr = [*Ts, *Cs, p] nx = get_backend(*arr) S = len(Ts) @@ -748,10 +747,7 @@ def update_barycenter_feature( International Conference on Learning Representations (ICLR), 2022. """ if nx is None: - arr = [*Ts, *Ys] - if p is not None: - arr += [p] - + arr = [*Ts, *Ys, p] nx = get_backend(*arr) if loss_fun != 'square_loss': diff --git a/test/gromov/test_semirelaxed.py b/test/gromov/test_semirelaxed.py index bfa7ee2e1..bee978528 100644 --- a/test/gromov/test_semirelaxed.py +++ b/test/gromov/test_semirelaxed.py @@ -793,23 +793,24 @@ def test_semirelaxed_gromov_barycenter(nx): np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06) np.testing.assert_array_almost_equal(err_['err'], nx.to_numpy(*errb_['err'])) np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples)) + # test consistency across backends with larger barycenter than inputs C = ot.gromov.semirelaxed_gromov_barycenters( ns, [C1, C2], None, [.5, .5], 'square_loss', max_iter=10, tol=1e-3, stop_criterion='loss', verbose=False, - random_state=42, G0='kmeans' + random_state=42, G0=init ) Cb = ot.gromov.semirelaxed_gromov_barycenters( ns, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', max_iter=10, tol=1e-3, stop_criterion=stop_criterion, - verbose=False, random_state=42, G0='kmeans') + verbose=False, random_state=42, G0=init_b) np.testing.assert_allclose(C, nx.to_numpy(Cb), atol=1e-06) # test providing init_C Cb_ = ot.gromov.semirelaxed_gromov_barycenters( ns, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', max_iter=10, tol=1e-3, stop_criterion=stop_criterion, - verbose=False, random_state=42, G0='kmeans', init_C=Cb) + verbose=False, random_state=42, G0=init_b, init_C=Cb) np.testing.assert_allclose(Cb, Cb_, atol=1e-06) From 6cb10bec9d04217e87ae1e36a5a9508e257dbd37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Wed, 7 Aug 2024 02:07:43 +0200 Subject: [PATCH 14/22] up --- test/gromov/test_semirelaxed.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/gromov/test_semirelaxed.py b/test/gromov/test_semirelaxed.py index bee978528..f069cfff6 100644 --- a/test/gromov/test_semirelaxed.py +++ b/test/gromov/test_semirelaxed.py @@ -817,11 +817,13 @@ def test_semirelaxed_gromov_barycenter(nx): # test consistency across backends with 'kl_loss' Cb2, err = ot.gromov.semirelaxed_gromov_barycenters( n_samples, [C1, C2], [p1, p2], [.5, .5], 'kl_loss', max_iter=10, - tol=1e-3, warmstartT=False, stop_criterion='loss', log=True, random_state=42 + tol=1e-3, warmstartT=False, stop_criterion='loss', log=True, + G0=init_b, random_state=42 ) Cb2b, errb = ot.gromov.semirelaxed_gromov_barycenters( n_samples, [C1b, C2b], [p1b, p2b], [.5, .5], 'kl_loss', max_iter=10, - tol=1e-3, warmstartT=False, stop_criterion='loss', log=True, random_state=42 + tol=1e-3, warmstartT=False, stop_criterion='loss', log=True, + G0=init_b, random_state=42 ) Cb2b = nx.to_numpy(Cb2b) From c16980e5ea8a1a1afb82db353f86f5db6fb48d6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Sun, 25 Aug 2024 21:15:18 +0200 Subject: [PATCH 15/22] complete exemple for barycenters --- README.md | 2 +- RELEASES.md | 5 +- ...mirelaxed_gromov_wasserstein_barycenter.py | 152 ++++++++++++++++-- ot/gromov/_utils.py | 4 +- 4 files changed, 142 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index b7493074c..fbad3086e 100644 --- a/README.md +++ b/README.md @@ -46,7 +46,7 @@ POT provides the following generic OT solvers (links to examples): * [Wasserstein distance on the circle](https://pythonot.github.io/auto_examples/plot_compute_wasserstein_circle.html) [44, 45] * [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46] * [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38]. -* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) with corresponding barycenter solvers (exact and regularized [48]). +* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) with corresponding [barycenter solvers](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.hmtl) (exact and regularized [48]). * [Quantized (Fused) Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_quantized_gromov_wasserstein.html) [68]. * [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://pythonot.github.io/auto_examples/others/plot_demd_gradient_minimize.html) [50]. * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays. diff --git a/RELEASES.md b/RELEASES.md index 1ec999811..cc18cc91b 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -5,9 +5,8 @@ #### New features - Add feature `mass=True` for `nx.kl_div` (PR #654) - Gaussian Mixture Model OT `ot.gmm` (PR #649) -- Add feature `semirelaxed_fgw_barycenters` and generic FGW-related barycenter - updates `update_barycenter_structure` and `update_barycenter_feature` (PR - #659) +- Add feature `semirelaxed_fgw_barycenters` and generic FGW-related barycenter updates `update_barycenter_structure` and `update_barycenter_feature` (PR #659) +- Add initialization heuristics for sr(F)GW problems via `semirelaxed_init_plan`, integrated in all sr(F)GW solvers (PR #659) - Improved `ot.plot.plot1D_mat` (PR #649) - Added `nx.det` (PR #649) - `nx.sqrtm` is now broadcastable (takes ..., d, d) inputs (PR #649) diff --git a/examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.py b/examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.py index 8c069fb10..0620e14b9 100644 --- a/examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.py +++ b/examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.py @@ -18,9 +18,9 @@ First, we consider a dataset composed of graphs generated by Stochastic Block models -with variable sizes taken in :math:`\{30, ... , 50\}` and quantities of clusters -varying in :math:`\{ 1, 2, 3\}`. We learn a srGW barycenter with 6 nodes and -visualize the learned structure and the embeddings for some inputs. +with variable sizes taken in :math:`\{30, ... , 50\}` and number of clusters +varying in :math:`\{ 1, 2, 3\}` with random proportions. We learn a srGW barycenter +with 3 nodes and visualize the learned structure and the embeddings for some inputs. Second, we illustrate the extension of this framework to graphs endowed with node features by using the semi-relaxed Fused Gromov-Wasserstein @@ -38,13 +38,13 @@ # # License: MIT License -# sphinx_gallery_thumbnail_number = 4 +# sphinx_gallery_thumbnail_number = 2 import numpy as np import matplotlib.pylab as pl from sklearn.manifold import MDS from ot.gromov import ( - semirelaxed_gromov_barycenters) + semirelaxed_gromov_barycenters, semirelaxed_fgw_barycenters) import ot import networkx from networkx.generators.community import stochastic_block_model as sbm @@ -57,7 +57,8 @@ np.random.seed(42) n_samples = 60 # number of graphs in the dataset -# For every number of clusters, we generate SBM with fixed inter/intra-clusters probability. +# For every number of clusters, we generate SBM with fixed inter/intra-clusters probability, +# and variable cluster proportions. clusters = [1, 2, 3] Nc = n_samples // len(clusters) # number of graphs by cluster nlabels = len(clusters) @@ -73,9 +74,13 @@ if n_cluster > 1: P = p_inter * np.ones((n_cluster, n_cluster)) np.fill_diagonal(P, p_intra) + props = np.random.uniform(0.2, 1, size=(n_cluster,)) + props /= props.sum() + sizes = np.round(n_nodes * props).astype(np.int32) else: P = p_intra * np.eye(1) - sizes = np.round(n_nodes * np.ones(n_cluster) / n_cluster).astype(np.int32) + sizes = [n_nodes] + G = sbm(sizes, P, seed=i, directed=False) C = networkx.to_numpy_array(G) dataset.append(C) @@ -115,12 +120,12 @@ def plot_graph(x, C, binary=True, color='C0', s=None): ############################################################################# # -# Estimate the srGW barycenter from the dataset +# Estimate the srGW barycenter from the dataset and visualize embeddings # ----------------------------------------------------------- np.random.seed(0) -ps = [ot.unif(C.shape[0]) for C in dataset] # uniform weights on nodes +ps = [ot.unif(C.shape[0]) for C in dataset] # uniform weights on input nodes lambdas = [1. / n_samples for _ in range(n_samples)] # uniform barycenter N = 3 # 3 nodes in the barycenter @@ -137,12 +142,129 @@ def plot_graph(x, C, binary=True, color='C0', s=None): tol=1e-6, stop_criterion='loss', warmstartT=warmstartT, log=True, G0=init_plan, verbose=False) -# visualize loss evolution over epochs -pl.figure(2, (4, 3)) +print('barycenter structure:', C) + +unmixings = log['p'] +# Compute the 2D representation of the embeddings living in the 2-simplex of probability +unmixings2D = np.zeros(shape=(n_samples, 2)) +for i, w in enumerate(unmixings): + unmixings2D[i, 0] = (2. * w[1] + w[2]) / 2. + unmixings2D[i, 1] = (np.sqrt(3.) * w[2]) / 2. +x = [0., 0.] +y = [1., 0.] +z = [0.5, np.sqrt(3) / 2.] +extremities = np.stack([x, y, z]) + +pl.figure(2, (4, 4)) +pl.clf() +pl.title('Embedding space', fontsize=14) +for cluster in range(nlabels): + start, end = Nc * cluster, Nc * (cluster + 1) + if cluster == 0: + pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=80., label='1 cluster') + else: + pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=80., label='%s clusters' % (cluster + 1)) +pl.scatter(extremities[:, 0], extremities[:, 1], c='black', marker='x', s=100., label='bary. nodes') +pl.plot([x[0], y[0]], [x[1], y[1]], color='black', linewidth=2.) +pl.plot([x[0], z[0]], [x[1], z[1]], color='black', linewidth=2.) +pl.plot([y[0], z[0]], [y[1], z[1]], color='black', linewidth=2.) +pl.axis('off') +pl.legend(fontsize=11) +pl.tight_layout() +pl.show() + +############################################################################# +# +# Endow the dataset with node features +# ------------------------------------ +# We follow this feature assignment on all nodes of a graph depending on its label/number of clusters +# 1 cluster --> 0 as nodes feature +# 2 clusters --> 1 as nodes feature +# 3 clusters --> 2 as nodes feature +# features are one-hot encoded following these assignments +dataset_features = [] +for i in range(len(dataset)): + n = dataset[i].shape[0] + F = np.zeros((n, 3)) + if i < Nc: # graph with 1 cluster + F[:, 0] = 1. + elif i < 2 * Nc: # graph with 2 clusters + F[:, 1] = 1. + else: # graph with 3 clusters + F[:, 2] = 1. + dataset_features.append(F) + +pl.figure(3, (12, 8)) pl.clf() -pl.title('loss evolution by iteration', fontsize=14) -pl.plot(log['loss']) -pl.xlabel('BCD iterations', fontsize=12) -pl.ylabel('loss', fontsize=12) +for idx_c, c in enumerate(clusters): + C = dataset[(c - 1) * Nc] # sample with c clusters + F = dataset_features[(c - 1) * Nc] + colors = ['C' + str(np.argmax(F[i])) for i in range(F.shape[0])] + # get 2d position for nodes + x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C) + pl.subplot(2, nlabels, c) + pl.title('(graph) sample from label ' + str(c), fontsize=14) + plot_graph(x, C, binary=True, color=colors, s=50) + pl.axis("off") + pl.subplot(2, nlabels, nlabels + c) + pl.title('(matrix) sample from label %s \n' % c, fontsize=14) + pl.imshow(C, interpolation='nearest') + pl.axis("off") +pl.tight_layout() +pl.show() + +############################################################################# +# +# Estimate the srFGW barycenter from the attributed graphs and visualize embeddings +# ----------------------------------------------------------- +# We emphasize the dependence to the trade-off parameter alpha that weights the +# relative importance between structures (alpha=1) and features (alpha=0), +# knowing that embeddings that perfectly cluster graphs w.r.t their features +# should collapse in one node with the proper feature/label. + +list_alphas = [0.8, 0.85, 0.9, 0.9999] +list_unmixings2D = [] + +for ialpha, alpha in enumerate(list_alphas): + print('--- alpha:', alpha) + C, F, log = semirelaxed_fgw_barycenters( + N=N, Ys=dataset_features, Cs=dataset, ps=ps, lambdas=lambdas, + alpha=alpha, loss_fun='square_loss', tol=1e-6, stop_criterion='loss', + warmstartT=warmstartT, log=True, G0=init_plan) + + print('barycenter structure:', C) + print('barycenter features:', F) + + unmixings = log['p'] + # Compute the 2D representation of the embeddings living in the 2-simplex of probability + unmixings2D = np.zeros(shape=(n_samples, 2)) + for i, w in enumerate(unmixings): + unmixings2D[i, 0] = (2. * w[1] + w[2]) / 2. + unmixings2D[i, 1] = (np.sqrt(3.) * w[2]) / 2. + list_unmixings2D.append(unmixings2D.copy()) + +x = [0., 0.] +y = [1., 0.] +z = [0.5, np.sqrt(3) / 2.] +extremities = np.stack([x, y, z]) + +pl.figure(4, (16, 4)) +pl.clf() +pl.suptitle('Embedding spaces', fontsize=14) +for ialpha, alpha in enumerate(list_alphas): + pl.subplot(1, len(list_alphas), ialpha + 1) + pl.title(f'alpha = {alpha}', fontsize=14) + for cluster in range(nlabels): + start, end = Nc * cluster, Nc * (cluster + 1) + if cluster == 0: + pl.scatter(list_unmixings2D[ialpha][start:end, 0], list_unmixings2D[ialpha][start:end, 1], c='C' + str(cluster), marker='o', s=80., label='1 cluster') + else: + pl.scatter(list_unmixings2D[ialpha][start:end, 0], list_unmixings2D[ialpha][start:end, 1], c='C' + str(cluster), marker='o', s=80., label='%s clusters' % (cluster + 1)) + pl.scatter(extremities[:, 0], extremities[:, 1], c='black', marker='x', s=100., label='bary. nodes') + pl.plot([x[0], y[0]], [x[1], y[1]], color='black', linewidth=2.) + pl.plot([x[0], z[0]], [x[1], z[1]], color='black', linewidth=2.) + pl.plot([y[0], z[0]], [y[1], z[1]], color='black', linewidth=2.) + pl.axis('off') + pl.legend(fontsize=11) pl.tight_layout() pl.show() diff --git a/ot/gromov/_utils.py b/ot/gromov/_utils.py index 999e325e8..1c5b92e04 100644 --- a/ot/gromov/_utils.py +++ b/ot/gromov/_utils.py @@ -667,7 +667,7 @@ def update_barycenter_structure( prod = nx.outer(inv_p, inv_p) else: - quotient = sum([nx.outer(p[s], p[s]) for s in range(S)]) + quotient = sum([lambdas[s] * nx.outer(p[s], p[s]) for s in range(S)]) if check_zeros: prod = nx.nan_to_num(1. / quotient, nan=1., posinf=1., neginf=1.) else: @@ -771,7 +771,7 @@ def update_barycenter_feature( else: inv_p = 1. / p else: - p_sum = sum(p) + p_sum = sum([lambdas[s] * p[s] for s in range(S)]) if check_zeros: inv_p = nx.nan_to_num(1. / p_sum, nan=1., posinf=1., neginf=1.) else: From a633660e2d77681353f372fab6ccf580a1a238f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Wed, 28 Aug 2024 18:40:06 +0200 Subject: [PATCH 16/22] review docstrings --- ot/gromov/_semirelaxed.py | 74 +++++++++++++++++++++++---------------- ot/gromov/_utils.py | 25 +++++++++++-- 2 files changed, 66 insertions(+), 33 deletions(-) diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py index c1c6303b1..adc0668e0 100644 --- a/ot/gromov/_semirelaxed.py +++ b/ot/gromov/_semirelaxed.py @@ -80,12 +80,13 @@ def semirelaxed_gromov_wasserstein( log : bool, optional record log if True G0: array-like of shape (ns,nt) or string, optional - If `G0=None` the initial transport plan of the solver is pq^T. + If `G0=None` the initial transport plan of the solver is :math:`\mathbf{p} \frac{\mathbf{1}_{nt}}{nt}^\top`. If G0 is a tensor it must satisfy marginal constraints and will be used as initial transport of the solver. if G0 is a string it will be interpreted as a method for - `semirelaxed_init_plan` taking values in "product", "random_product", "random", - "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". + :func:`ot.gromov.semirelaxed_init_plan` taking values in "product", + "random_product", "random", "fluid", "fluid_soft", "spectral", + "spectral_soft", "kmeans", "kmeans_soft". max_iter : int, optional Max number of iterations tol_rel : float, optional @@ -228,12 +229,13 @@ def semirelaxed_gromov_wasserstein2( log : bool, optional record log if True G0: array-like of shape (ns,nt) or string, optional - If `G0=None` the initial transport plan of the solver is pq^T. + If `G0=None` the initial transport plan of the solver is :math:`\mathbf{p} \frac{\mathbf{1}_{nt}}{nt}^\top`. If G0 is a tensor it must satisfy marginal constraints and will be used as initial transport of the solver. if G0 is a string it will be interpreted as a method for - `semirelaxed_init_plan` taking values in "product", "random_product", "random", - "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". + :func:`ot.gromov.semirelaxed_init_plan` taking values in "product", + "random_product", "random", "fluid", "fluid_soft", "spectral", + "spectral_soft", "kmeans", "kmeans_soft". max_iter : int, optional Max number of iterations tol_rel : float, optional @@ -343,12 +345,13 @@ def semirelaxed_fused_gromov_wasserstein( alpha : float, optional Trade-off parameter (0 < alpha < 1) G0: array-like of shape (ns,nt) or string, optional - If `G0=None` the initial transport plan of the solver is pq^T. + If `G0=None` the initial transport plan of the solver is :math:`\mathbf{p} \frac{\mathbf{1}_{nt}}{nt}^\top`. If G0 is a tensor it must satisfy marginal constraints and will be used as initial transport of the solver. if G0 is a string it will be interpreted as a method for - `semirelaxed_init_plan` taking values in "product", "random_product", "random", - "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". + :func:`ot.gromov.semirelaxed_init_plan` taking values in "product", + "random_product", "random", "fluid", "fluid_soft", "spectral", + "spectral_soft", "kmeans", "kmeans_soft". log : bool, optional record log if True max_iter : int, optional @@ -498,12 +501,13 @@ def semirelaxed_fused_gromov_wasserstein2( alpha : float, optional Trade-off parameter (0 < alpha < 1) G0: array-like of shape (ns,nt) or string, optional - If `G0=None` the initial transport plan of the solver is pq^T. + If `G0=None` the initial transport plan of the solver is :math:`\mathbf{p} \frac{\mathbf{1}_{nt}}{nt}^\top`. If G0 is a tensor it must satisfy marginal constraints and will be used as initial transport of the solver. if G0 is a string it will be interpreted as a method for - `semirelaxed_init_plan` taking values in "product", "random_product", "random", - "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". + :func:`ot.gromov.semirelaxed_init_plan` taking values in "product", + "random_product", "random", "fluid", "fluid_soft", "spectral", + "spectral_soft", "kmeans", "kmeans_soft". log : bool, optional Record log if True. max_iter : int, optional @@ -713,12 +717,13 @@ def entropic_semirelaxed_gromov_wasserstein( verbose : bool, optional Print information along iterations G0: array-like of shape (ns,nt) or string, optional - If `G0=None` the initial transport plan of the solver is pq^T. + If `G0=None` the initial transport plan of the solver is :math:`\mathbf{p} \frac{\mathbf{1}_{nt}}{nt}^\top`. If G0 is a tensor it must satisfy marginal constraints and will be used as initial transport of the solver. if G0 is a string it will be interpreted as a method for - `semirelaxed_init_plan` taking values in "product", "random_product", "random", - "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". + :func:`ot.gromov.semirelaxed_init_plan` taking values in "product", + "random_product", "random", "fluid", "fluid_soft", "spectral", + "spectral_soft", "kmeans", "kmeans_soft". max_iter : int, optional Max number of iterations tol : float, optional @@ -880,12 +885,13 @@ def entropic_semirelaxed_gromov_wasserstein2( verbose : bool, optional Print information along iterations G0: array-like of shape (ns,nt) or string, optional - If `G0=None` the initial transport plan of the solver is pq^T. + If `G0=None` the initial transport plan of the solver is :math:`\mathbf{p} \frac{\mathbf{1}_{nt}}{nt}^\top`. If G0 is a tensor it must satisfy marginal constraints and will be used as initial transport of the solver. if G0 is a string it will be interpreted as a method for - `semirelaxed_init_plan` taking values in "product", "random_product", "random", - "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". + :func:`ot.gromov.semirelaxed_init_plan` taking values in "product", + "random_product", "random", "fluid", "fluid_soft", "spectral", + "spectral_soft", "kmeans", "kmeans_soft". max_iter : int, optional Max number of iterations tol : float, optional @@ -976,12 +982,13 @@ def entropic_semirelaxed_fused_gromov_wasserstein( alpha : float, optional Trade-off parameter (0 < alpha < 1) G0: array-like of shape (ns,nt) or string, optional - If `G0=None` the initial transport plan of the solver is pq^T. + If `G0=None` the initial transport plan of the solver is :math:`\mathbf{p} \frac{\mathbf{1}_{nt}}{nt}^\top`. If G0 is a tensor it must satisfy marginal constraints and will be used as initial transport of the solver. if G0 is a string it will be interpreted as a method for - `semirelaxed_init_plan` taking values in "product", "random_product", "random", - "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". + :func:`ot.gromov.semirelaxed_init_plan` taking values in "product", + "random_product", "random", "fluid", "fluid_soft", "spectral", + "spectral_soft", "kmeans", "kmeans_soft". max_iter : int, optional Max number of iterations tol : float, optional @@ -1146,12 +1153,13 @@ def entropic_semirelaxed_fused_gromov_wasserstein2( alpha : float, optional Trade-off parameter (0 < alpha < 1) G0: array-like of shape (ns,nt) or string, optional - If `G0=None` the initial transport plan of the solver is pq^T. + If `G0=None` the initial transport plan of the solver is :math:`\mathbf{p} \frac{\mathbf{1}_{nt}}{nt}^\top`. If G0 is a tensor it must satisfy marginal constraints and will be used as initial transport of the solver. if G0 is a string it will be interpreted as a method for - `semirelaxed_init_plan` taking values in "product", "random_product", "random", - "fluid", "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". + :func:`ot.gromov.semirelaxed_init_plan` taking values in "product", + "random_product", "random", "fluid", "fluid_soft", "spectral", + "spectral_soft", "kmeans", "kmeans_soft". max_iter : int, optional Max number of iterations tol : float, optional @@ -1245,10 +1253,12 @@ def semirelaxed_gromov_barycenters( init_C : array-like of shape (N,N), optional. Random initial value for the :math:`\mathbf{C}` matrix provided by user. Default is None and relies `G0` to produce an initial structure. - G0: str, optional. Default is 'random'. - Initialization method following heuristics developed in `semirelaxed_init_plan`. - Methods based on the clustering of inputs are used to deduce an initial - barycenter structure if `init_C=None`. + G0: str, optional. Default is 'product'. + Initialization method for transport plans calling :func:`ot.gromov.semirelaxed_init_plan`, + and taking values in "product", "random_product", "random", "fluid", + "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". + Transport plans are used to deduce an initial barycenter structure + if `init_C=None`. random_state : int or RandomState instance, optional Fix the seed for reproducibility @@ -1487,8 +1497,12 @@ def semirelaxed_fgw_barycenters( init_X : array-like, shape (N,d), optional Initialization for the barycenters' features. If not set a random init is used. - G0: str, optional. Default is 'product'. - Initialization method following heuristics developed in `semirelaxed_init_plan`. + G0: str, optional. Default is "product". + Initialization method for transport plans calling :func:`ot.gromov.semirelaxed_init_plan`, + and taking values in "product", "random_product", "random", "fluid", + "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". + Transport plans are used to deduce an initial barycenter structure + if `init_C=None`. random_state : int or RandomState instance, optional Fix the seed for reproducibility diff --git a/ot/gromov/_utils.py b/ot/gromov/_utils.py index 1c5b92e04..5c465cba8 100644 --- a/ot/gromov/_utils.py +++ b/ot/gromov/_utils.py @@ -390,10 +390,29 @@ def h2(b): def semirelaxed_init_plan(C1, C2, p, M=None, alpha=1., method='product', use_target=True, random_state=0, nx=None): - """ - Heuristics to initialize the semi-relaxed (F)GW transport plan between a - graph :math:`(\mathbf{C1}, \mathbf{p})` and a structure matrix :math:`\mathbf{C2}`. + r""" + Heuristics to initialize the semi-relaxed (F)GW transport plan + :math:`\mathbf{T} \in \mathcal{U}_{nt}(\mathbf{p})`, between a graph + :math:`(\mathbf{C1}, \mathbf{p})` and a structure matrix :math:`\mathbf{C2}`, + where :math:`\mathcal{U}_{nt}(\mathbf{p}) = \{\mathbf{T} \in \mathbb{R}_{+}^{ns * nt}, \mathbf{T} \mathbf{1}_{nt} = \mathbf{p} \}`. + Available methods are: + - "product" or "random_product": :math:`\mathbf{T} = \mathbf{pq}^{T}` + with :math:`\mathbf{q}` uniform or randomly samples in the nt probability simplex. + + - "random": random sampling in :math:`\mathcal{U}_{nt}(\mathbf{p})`. + + - "fluid": Fluid algorithm from networkx for graph partitioning. + + - "spectral", "kmeans" : Spectral or Kmeans clustering from sklearn. + + - "fluid_soft", "spectral_soft", "kmeans_soft": :math:`\mathbf{T}_0` given + by corresponding clustering with target marginal :math:`\mathbf{q}_0`, further + centered as :math:`\mathbf{T} = (\mathbf{T}_0 + \mathbf{pq}_0^T) / 2` . + If a metric cost matrix between features across domains :math:`\mathbf{M}` + is a provided, it will be used as cost matrix in a semi-relaxed Wasserstein + problem providing :math:`\mathbf{T}_M \in \mathcal{U}_{nt}(\mathbf{p})`. Then + the outputed transport plan is :math:`\alpha \mathbf{T} + (1 - \alpha ) \mathbf{T}_{M}`. Parameters ---------- From a2de63acb93c67baf9fde62f546300ab4ac0339b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Wed, 28 Aug 2024 23:10:08 +0200 Subject: [PATCH 17/22] complete tests --- ot/gromov/_semirelaxed.py | 15 +--- test/gromov/test_semirelaxed.py | 152 +++++++++++++++++++------------- 2 files changed, 93 insertions(+), 74 deletions(-) diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py index adc0668e0..99211cea0 100644 --- a/ot/gromov/_semirelaxed.py +++ b/ot/gromov/_semirelaxed.py @@ -1606,20 +1606,11 @@ def semirelaxed_fgw_barycenters( stacked_features = nx.to_numpy(stacked_features) km = KMeans(n_clusters=N, random_state=random_state, n_init=1).fit(stacked_features) - init_X = nx.from_numpy(km.cluster_centers_) + X = nx.from_numpy(km.cluster_centers_) else: - warnings.warn( - "Kmeans clustering cannot be performed to init barycenter " - "features, so we only randomly sample input features. " - "Consider installing scikit-learn.", - stacklevel=2 + raise ValueError( + f"Scikit-learn must be installed to use G0={G0} while init_X=None." ) - np.random.seed(random_state) - selected_idx = np.random.choice( - np.arange(stacked_features.shape[0]), size=N, replace=False) - init_X = stacked_features[selected_idx] - - X = init_X else: X = init_X diff --git a/test/gromov/test_semirelaxed.py b/test/gromov/test_semirelaxed.py index f069cfff6..e1512ee30 100644 --- a/test/gromov/test_semirelaxed.py +++ b/test/gromov/test_semirelaxed.py @@ -751,77 +751,91 @@ def test_semirelaxed_gromov_barycenter(nx): # test consistency of outputs across backends with 'square_loss' # using different losses # + tests on different inits - init_plan_list = [('random', 'random')] - - if networkx_import: - init_plan_list += [('fluid', 'fluid')] - - if sklearn_import: - init_plan_list += [("kmeans", "kmeans")] + init_plan_list = [('fluid', 'fluid'), ("kmeans", "kmeans"), + ('random', 'random')] for (init, init_b) in init_plan_list: for stop_criterion in ['barycenter', 'loss']: print('--- stop_criterion:', stop_criterion) - Cb = ot.gromov.semirelaxed_gromov_barycenters( - n_samples, [C1, C2], None, [.5, .5], 'square_loss', max_iter=10, - tol=1e-3, stop_criterion=stop_criterion, verbose=False, - random_state=42, G0=init - ) - Cbb = nx.to_numpy(ot.gromov.semirelaxed_gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', - max_iter=10, tol=1e-3, stop_criterion=stop_criterion, - verbose=False, random_state=42, G0=init_b - )) - np.testing.assert_allclose(Cb, Cbb, atol=1e-06) - np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) - - # test of gromov_barycenters with `log` on - Cb_, err_ = ot.gromov.semirelaxed_gromov_barycenters( - n_samples, [C1, C2], [p1, p2], None, 'square_loss', max_iter=10, - tol=1e-3, stop_criterion=stop_criterion, verbose=False, - warmstartT=True, random_state=42, log=True, G0=init, - ) - Cbb_, errb_ = ot.gromov.semirelaxed_gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', - max_iter=10, tol=1e-3, stop_criterion=stop_criterion, - verbose=False, warmstartT=True, random_state=42, log=True, G0=init_b - ) - - Cbb_ = nx.to_numpy(Cbb_) - np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06) - np.testing.assert_array_almost_equal(err_['err'], nx.to_numpy(*errb_['err'])) - np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples)) + if (init == 'fluid') and (not networkx_import): + with pytest.raises(ValueError): + Cb = ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1, C2], None, [.5, .5], 'square_loss', max_iter=5, + tol=1e-3, stop_criterion=stop_criterion, verbose=False, + random_state=42, G0=init + ) + + elif (init == 'kmeans') and (not sklearn_import): + with pytest.raises(ValueError): + Cb = ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1, C2], None, [.5, .5], 'square_loss', max_iter=5, + tol=1e-3, stop_criterion=stop_criterion, verbose=False, + random_state=42, G0=init + ) + else: + + Cb = ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1, C2], None, [.5, .5], 'square_loss', max_iter=5, + tol=1e-3, stop_criterion=stop_criterion, verbose=False, + random_state=42, G0=init + ) + + Cbb = nx.to_numpy(ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', + max_iter=5, tol=1e-3, stop_criterion=stop_criterion, + verbose=False, random_state=42, G0=init_b + )) + np.testing.assert_allclose(Cb, Cbb, atol=1e-06) + np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) + + # test of gromov_barycenters with `log` on + Cb_, err_ = ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1, C2], [p1, p2], None, 'square_loss', max_iter=5, + tol=1e-3, stop_criterion=stop_criterion, verbose=False, + warmstartT=True, random_state=42, log=True, G0=init, + ) + Cbb_, errb_ = ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', + max_iter=5, tol=1e-3, stop_criterion=stop_criterion, + verbose=False, warmstartT=True, random_state=42, log=True, G0=init_b + ) + + Cbb_ = nx.to_numpy(Cbb_) + + np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06) + np.testing.assert_array_almost_equal(err_['err'], nx.to_numpy(*errb_['err'])) + np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples)) # test consistency across backends with larger barycenter than inputs C = ot.gromov.semirelaxed_gromov_barycenters( - ns, [C1, C2], None, [.5, .5], 'square_loss', max_iter=10, + ns, [C1, C2], None, [.5, .5], 'square_loss', max_iter=5, tol=1e-3, stop_criterion='loss', verbose=False, random_state=42, G0=init ) Cb = ot.gromov.semirelaxed_gromov_barycenters( ns, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', - max_iter=10, tol=1e-3, stop_criterion=stop_criterion, + max_iter=5, tol=1e-3, stop_criterion=stop_criterion, verbose=False, random_state=42, G0=init_b) np.testing.assert_allclose(C, nx.to_numpy(Cb), atol=1e-06) # test providing init_C Cb_ = ot.gromov.semirelaxed_gromov_barycenters( ns, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', - max_iter=10, tol=1e-3, stop_criterion=stop_criterion, + max_iter=5, tol=1e-3, stop_criterion=stop_criterion, verbose=False, random_state=42, G0=init_b, init_C=Cb) np.testing.assert_allclose(Cb, Cb_, atol=1e-06) # test consistency across backends with 'kl_loss' Cb2, err = ot.gromov.semirelaxed_gromov_barycenters( - n_samples, [C1, C2], [p1, p2], [.5, .5], 'kl_loss', max_iter=10, + n_samples, [C1, C2], [p1, p2], [.5, .5], 'kl_loss', max_iter=5, tol=1e-3, warmstartT=False, stop_criterion='loss', log=True, G0=init_b, random_state=42 ) Cb2b, errb = ot.gromov.semirelaxed_gromov_barycenters( - n_samples, [C1b, C2b], [p1b, p2b], [.5, .5], 'kl_loss', max_iter=10, + n_samples, [C1b, C2b], [p1b, p2b], [.5, .5], 'kl_loss', max_iter=5, tol=1e-3, warmstartT=False, stop_criterion='loss', log=True, G0=init_b, random_state=42 ) @@ -982,31 +996,45 @@ def test_semirelaxed_fgw_barycenter(nx): # test consistency of outputs across backends with 'square_loss' # with various initialization of G0 - init_plan_list = [('product', 'product'), ('random', 'random')] - - if networkx_import: - init_plan_list += [('fluid', 'fluid')] - - if sklearn_import: - init_plan_list += [("kmeans", "kmeans")] + init_plan_list = [('fluid', 'fluid'), ("kmeans", "kmeans"), + ('product', 'product'), ('random', 'random')] for (init, init_b) in init_plan_list: print(f'---- init : {init} / init_b : {init_b}') - X, C, log = ot.gromov.semirelaxed_fgw_barycenters( - n_samples, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5, - fixed_structure=False, fixed_features=False, loss_fun='square_loss', - max_iter=10, tol=1e-3, stop_criterion='loss', G0=init, - warmstartT=True, random_state=12345, log=True, verbose=True - ) - Xb, Cb, logb = ot.gromov.semirelaxed_fgw_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, - fixed_structure=False, fixed_features=False, loss_fun='square_loss', - max_iter=10, tol=1e-3, stop_criterion='loss', G0=init_b, - warmstartT=True, random_state=12345, log=True, verbose=True - ) - np.testing.assert_allclose(X, nx.to_numpy(Xb)) - np.testing.assert_allclose(C, nx.to_numpy(Cb)) + if (init == 'fluid') and (not networkx_import): + with pytest.raises(ValueError): + + X, C, log = ot.gromov.semirelaxed_fgw_barycenters( + n_samples, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5, + fixed_structure=False, fixed_features=False, loss_fun='square_loss', + max_iter=10, tol=1e-3, stop_criterion='loss', G0=init, + warmstartT=True, random_state=12345, log=True, verbose=True + ) + elif (init == 'kmeans') and (not sklearn_import): + with pytest.raises(ValueError): + + X, C, log = ot.gromov.semirelaxed_fgw_barycenters( + n_samples, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5, + fixed_structure=False, fixed_features=False, loss_fun='square_loss', + max_iter=10, tol=1e-3, stop_criterion='loss', G0=init, + warmstartT=True, random_state=12345, log=True, verbose=True + ) + else: + X, C, log = ot.gromov.semirelaxed_fgw_barycenters( + n_samples, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5, + fixed_structure=False, fixed_features=False, loss_fun='square_loss', + max_iter=10, tol=1e-3, stop_criterion='loss', G0=init, + warmstartT=True, random_state=12345, log=True, verbose=True + ) + Xb, Cb, logb = ot.gromov.semirelaxed_fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, + fixed_structure=False, fixed_features=False, loss_fun='square_loss', + max_iter=10, tol=1e-3, stop_criterion='loss', G0=init_b, + warmstartT=True, random_state=12345, log=True, verbose=True + ) + np.testing.assert_allclose(X, nx.to_numpy(Xb)) + np.testing.assert_allclose(C, nx.to_numpy(Cb)) # test while providing advanced T inits and init_X != None Xb_, Cb_, logb_ = ot.gromov.semirelaxed_fgw_barycenters( From d751cac6b9dcc7bb0610eb821e6a9cdbaa32b5e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Wed, 28 Aug 2024 23:12:40 +0200 Subject: [PATCH 18/22] complete tests --- ot/gromov/_semirelaxed.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py index 99211cea0..96f776cb1 100644 --- a/ot/gromov/_semirelaxed.py +++ b/ot/gromov/_semirelaxed.py @@ -29,8 +29,6 @@ except ImportError: sklearn_import = False -import warnings - def semirelaxed_gromov_wasserstein( C1, C2, p=None, loss_fun='square_loss', symmetric=None, log=False, G0=None, From d41bb4cbfaa82d3011b0250bd8295227281b0348 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Thu, 29 Aug 2024 00:01:21 +0200 Subject: [PATCH 19/22] complete tests --- test/gromov/test_utils.py | 50 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/test/gromov/test_utils.py b/test/gromov/test_utils.py index ad94a4042..70894fcfc 100644 --- a/test/gromov/test_utils.py +++ b/test/gromov/test_utils.py @@ -8,6 +8,8 @@ import pytest import ot +from ot.gromov._utils import ( + networkx_import, sklearn_import) def test_update_barycenter(nx): @@ -61,3 +63,51 @@ def test_update_barycenter(nx): with pytest.raises(ValueError): Xbt = ot.gromov.update_barycenter_feature( Tb, Ysb, lambdas, None, loss_fun='unknown', target=True) + + +def test_semirelaxed_init_plan(nx): + ns = 5 + nt = 10 + + Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) + Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) + + rng = np.random.RandomState(42) + ys = rng.randn(Xs.shape[0], 2) + yt = rng.randn(Xt.shape[0], 2) + + C1 = ot.dist(Xs) + C2 = ot.dist(Xt) + C1 /= C1.max() + C2 /= C2.max() + + p1, p2 = ot.unif(ns), ot.unif(nt) + + ysb, ytb, C1b, C2b, p1b, p2b = nx.from_numpy(ys, yt, C1, C2, p1, p2) + + # test not supported method + with pytest.raises(ValueError): + _ = ot.gromov.semirelaxed_init_plan(C1b, C2b, p1b, method='unknown') + + if sklearn_import: + # tests consistency across backends with m > n + for method in ['kmeans', 'spectral']: + T = ot.gromov.semirelaxed_init_plan(C1b, C2b, p1b, method=method) + Tb = ot.gromov.semirelaxed_init_plan(C1b, C2b, p1b, method=method) + np.testing.assert_allclose(T, Tb) + + # tests consistency across backends with m = n + T = ot.gromov.semirelaxed_init_plan(C1b, C1b, p1b, method=method) + Tb = ot.gromov.semirelaxed_init_plan(C1b, C1b, p1b, method=method) + np.testing.assert_allclose(T, Tb) + + if networkx_import: + # tests consistency across backends with m > n + T = ot.gromov.semirelaxed_init_plan(C1b, C2b, p1b, method='fluid') + Tb = ot.gromov.semirelaxed_init_plan(C1b, C2b, p1b, method='fluid') + np.testing.assert_allclose(T, Tb) + + # tests consistency across backends with m = n + T = ot.gromov.semirelaxed_init_plan(C1b, C1b, p1b, method='fluid') + Tb = ot.gromov.semirelaxed_init_plan(C1b, C1b, p1b, method='fluid') + np.testing.assert_allclose(T, Tb) From dee391bd0ea83a7258ed4f4cf784dfb368ac56c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Thu, 29 Aug 2024 01:31:45 +0200 Subject: [PATCH 20/22] finish taking into account review --- ...mirelaxed_gromov_wasserstein_barycenter.py | 26 ++++---- ot/optim.py | 4 +- test/gromov/test_semirelaxed.py | 61 +++++++++++-------- 3 files changed, 50 insertions(+), 41 deletions(-) diff --git a/examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.py b/examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.py index 0620e14b9..e555d1e70 100644 --- a/examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.py +++ b/examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.py @@ -63,6 +63,7 @@ Nc = n_samples // len(clusters) # number of graphs by cluster nlabels = len(clusters) dataset = [] +node_labels = [] labels = [] p_inter = 0.1 @@ -82,8 +83,10 @@ sizes = [n_nodes] G = sbm(sizes, P, seed=i, directed=False) + part = np.array([G.nodes[i]['block'] for i in range(np.sum(sizes))]) C = networkx.to_numpy_array(G) dataset.append(C) + node_labels.append(part) labels.append(n_cluster) @@ -177,21 +180,14 @@ def plot_graph(x, C, binary=True, color='C0', s=None): # # Endow the dataset with node features # ------------------------------------ -# We follow this feature assignment on all nodes of a graph depending on its label/number of clusters -# 1 cluster --> 0 as nodes feature -# 2 clusters --> 1 as nodes feature -# 3 clusters --> 2 as nodes feature -# features are one-hot encoded following these assignments +# node labels, corresponding to the true SBM cluster assignments, +# are set for each graph as one-hot encoded node features. + dataset_features = [] for i in range(len(dataset)): n = dataset[i].shape[0] F = np.zeros((n, 3)) - if i < Nc: # graph with 1 cluster - F[:, 0] = 1. - elif i < 2 * Nc: # graph with 2 clusters - F[:, 1] = 1. - else: # graph with 3 clusters - F[:, 2] = 1. + F[np.arange(n), node_labels[i]] = 1. dataset_features.append(F) pl.figure(3, (12, 8)) @@ -199,7 +195,7 @@ def plot_graph(x, C, binary=True, color='C0', s=None): for idx_c, c in enumerate(clusters): C = dataset[(c - 1) * Nc] # sample with c clusters F = dataset_features[(c - 1) * Nc] - colors = ['C' + str(np.argmax(F[i])) for i in range(F.shape[0])] + colors = [f'C{labels[i]}' for i in range(F.shape[0])] # get 2d position for nodes x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C) pl.subplot(2, nlabels, c) @@ -220,9 +216,9 @@ def plot_graph(x, C, binary=True, color='C0', s=None): # We emphasize the dependence to the trade-off parameter alpha that weights the # relative importance between structures (alpha=1) and features (alpha=0), # knowing that embeddings that perfectly cluster graphs w.r.t their features -# should collapse in one node with the proper feature/label. +# should ease the identification of the number of clusters in the graphs. -list_alphas = [0.8, 0.85, 0.9, 0.9999] +list_alphas = [0.0001, 0.5, 0.9999] list_unmixings2D = [] for ialpha, alpha in enumerate(list_alphas): @@ -248,7 +244,7 @@ def plot_graph(x, C, binary=True, color='C0', s=None): z = [0.5, np.sqrt(3) / 2.] extremities = np.stack([x, y, z]) -pl.figure(4, (16, 4)) +pl.figure(4, (12, 4)) pl.clf() pl.suptitle('Embedding spaces', fontsize=14) for ialpha, alpha in enumerate(list_alphas): diff --git a/ot/optim.py b/ot/optim.py index 185669a88..a5f88bb29 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -477,7 +477,9 @@ def semirelaxed_cg(a, b, M, reg, f, df, G0=None, line_search=line_search_armijo, def lp_solver(a, b, Mi, **kwargs): # get minimum by rows as binary mask min_ = nx.reshape(nx.min(Mi, axis=1), (-1, 1)) - # instead of exact equality to min_ we consider a small margin for float precision issues. + # instead of exact elements equal to min_ we consider a small margin (1e-15) + # for float precision issues. Then the mass is splitted uniformly + # between these elements. Gc = nx.ones(1, type_as=a) * (Mi <= min_ + 1e-15) Gc *= nx.reshape((a / nx.sum(Gc, axis=1)), (-1, 1)) # return by default an empty inner_log diff --git a/test/gromov/test_semirelaxed.py b/test/gromov/test_semirelaxed.py index e1512ee30..d4cbcb558 100644 --- a/test/gromov/test_semirelaxed.py +++ b/test/gromov/test_semirelaxed.py @@ -809,17 +809,19 @@ def test_semirelaxed_gromov_barycenter(nx): np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples)) # test consistency across backends with larger barycenter than inputs - C = ot.gromov.semirelaxed_gromov_barycenters( - ns, [C1, C2], None, [.5, .5], 'square_loss', max_iter=5, - tol=1e-3, stop_criterion='loss', verbose=False, - random_state=42, G0=init - ) - Cb = ot.gromov.semirelaxed_gromov_barycenters( - ns, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', - max_iter=5, tol=1e-3, stop_criterion=stop_criterion, - verbose=False, random_state=42, G0=init_b) + if sklearn_import: + C = ot.gromov.semirelaxed_gromov_barycenters( + ns, [C1, C2], None, [.5, .5], 'square_loss', max_iter=5, + tol=1e-3, stop_criterion='loss', verbose=False, + random_state=42, G0='kmeans' + ) + Cb = ot.gromov.semirelaxed_gromov_barycenters( + ns, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', + max_iter=5, tol=1e-3, stop_criterion=stop_criterion, + verbose=False, random_state=42, G0='kmeans') + + np.testing.assert_allclose(C, nx.to_numpy(Cb), atol=1e-06) - np.testing.assert_allclose(C, nx.to_numpy(Cb), atol=1e-06) # test providing init_C Cb_ = ot.gromov.semirelaxed_gromov_barycenters( ns, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', @@ -1036,32 +1038,41 @@ def test_semirelaxed_fgw_barycenter(nx): np.testing.assert_allclose(X, nx.to_numpy(Xb)) np.testing.assert_allclose(C, nx.to_numpy(Cb)) - # test while providing advanced T inits and init_X != None + # test while providing advanced T inits and init_X != None, and init_C !=None Xb_, Cb_, logb_ = ot.gromov.semirelaxed_fgw_barycenters( n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, fixed_structure=False, fixed_features=False, loss_fun='square_loss', - max_iter=10, tol=1e-3, stop_criterion='loss', G0=init_b, + max_iter=10, tol=1e-3, stop_criterion='loss', G0='random', warmstartT=True, random_state=12345, log=True, verbose=True, init_C=Cb, init_X=Xb ) np.testing.assert_allclose(Xb, Xb_) np.testing.assert_allclose(Cb, Cb_) - # test consistency of backends while barycenter size not strictly inferior to sizes - X, C, log = ot.gromov.semirelaxed_fgw_barycenters( - ns, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5, - fixed_structure=False, fixed_features=False, loss_fun='square_loss', - max_iter=10, tol=1e-3, stop_criterion='loss', G0=init, - warmstartT=True, random_state=12345, log=True, verbose=True - ) - Xb, Cb, logb = ot.gromov.semirelaxed_fgw_barycenters( - ns, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, + Xb_, Cb_, logb_ = ot.gromov.semirelaxed_fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, fixed_structure=False, fixed_features=False, loss_fun='square_loss', - max_iter=10, tol=1e-3, stop_criterion='loss', G0=init_b, - warmstartT=True, random_state=12345, log=True, verbose=True + max_iter=10, tol=1e-3, stop_criterion='loss', G0='kmeans', + warmstartT=True, random_state=12345, log=True, verbose=True, + init_C=Cb, init_X=Xb ) - np.testing.assert_allclose(X, nx.to_numpy(Xb)) - np.testing.assert_allclose(C, nx.to_numpy(Cb)) + + # test consistency of backends while barycenter size not strictly inferior to sizes + if sklearn_import: + X, C, log = ot.gromov.semirelaxed_fgw_barycenters( + ns, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5, + fixed_structure=False, fixed_features=False, loss_fun='square_loss', + max_iter=10, tol=1e-3, stop_criterion='loss', G0='kmeans', + warmstartT=True, random_state=12345, log=True, verbose=True + ) + Xb, Cb, logb = ot.gromov.semirelaxed_fgw_barycenters( + ns, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, + fixed_structure=False, fixed_features=False, loss_fun='square_loss', + max_iter=10, tol=1e-3, stop_criterion='loss', G0='kmeans', + warmstartT=True, random_state=12345, log=True, verbose=True + ) + np.testing.assert_allclose(X, nx.to_numpy(Xb)) + np.testing.assert_allclose(C, nx.to_numpy(Cb)) # test edge cases for semirelaxed fgw barycenters: # unique input structure From e7870d7ed48969b2eab40c9b6b5df9ba4dfd783f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Thu, 29 Aug 2024 02:07:16 +0200 Subject: [PATCH 21/22] fix --- test/gromov/test_semirelaxed.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/test/gromov/test_semirelaxed.py b/test/gromov/test_semirelaxed.py index d4cbcb558..ee5324751 100644 --- a/test/gromov/test_semirelaxed.py +++ b/test/gromov/test_semirelaxed.py @@ -823,12 +823,18 @@ def test_semirelaxed_gromov_barycenter(nx): np.testing.assert_allclose(C, nx.to_numpy(Cb), atol=1e-06) # test providing init_C + C_ = ot.gromov.semirelaxed_gromov_barycenters( + ns, [C1, C2], None, [.5, .5], 'square_loss', max_iter=5, + tol=1e-3, stop_criterion='loss', verbose=False, + random_state=42, G0=init, init_C=C + ) + Cb_ = ot.gromov.semirelaxed_gromov_barycenters( ns, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', max_iter=5, tol=1e-3, stop_criterion=stop_criterion, verbose=False, random_state=42, G0=init_b, init_C=Cb) - np.testing.assert_allclose(Cb, Cb_, atol=1e-06) + np.testing.assert_allclose(C_, Cb_, atol=1e-06) # test consistency across backends with 'kl_loss' Cb2, err = ot.gromov.semirelaxed_gromov_barycenters( @@ -1049,16 +1055,16 @@ def test_semirelaxed_fgw_barycenter(nx): np.testing.assert_allclose(Xb, Xb_) np.testing.assert_allclose(Cb, Cb_) - Xb_, Cb_, logb_ = ot.gromov.semirelaxed_fgw_barycenters( - n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, - fixed_structure=False, fixed_features=False, loss_fun='square_loss', - max_iter=10, tol=1e-3, stop_criterion='loss', G0='kmeans', - warmstartT=True, random_state=12345, log=True, verbose=True, - init_C=Cb, init_X=Xb - ) - # test consistency of backends while barycenter size not strictly inferior to sizes if sklearn_import: + Xb_, Cb_, logb_ = ot.gromov.semirelaxed_fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, + fixed_structure=False, fixed_features=False, loss_fun='square_loss', + max_iter=10, tol=1e-3, stop_criterion='loss', G0='kmeans', + warmstartT=True, random_state=12345, log=True, verbose=True, + init_C=Cb, init_X=Xb + ) + X, C, log = ot.gromov.semirelaxed_fgw_barycenters( ns, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5, fixed_structure=False, fixed_features=False, loss_fun='square_loss', From 06120312e1f78f0dad6b20e894fe696190de44c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Vincent-Cuaz?= Date: Thu, 29 Aug 2024 11:03:57 +0200 Subject: [PATCH 22/22] fix test minimal --- test/gromov/test_semirelaxed.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/gromov/test_semirelaxed.py b/test/gromov/test_semirelaxed.py index ee5324751..acc49be5f 100644 --- a/test/gromov/test_semirelaxed.py +++ b/test/gromov/test_semirelaxed.py @@ -826,13 +826,13 @@ def test_semirelaxed_gromov_barycenter(nx): C_ = ot.gromov.semirelaxed_gromov_barycenters( ns, [C1, C2], None, [.5, .5], 'square_loss', max_iter=5, tol=1e-3, stop_criterion='loss', verbose=False, - random_state=42, G0=init, init_C=C + random_state=42, G0=init, init_C=C1 ) Cb_ = ot.gromov.semirelaxed_gromov_barycenters( ns, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', max_iter=5, tol=1e-3, stop_criterion=stop_criterion, - verbose=False, random_state=42, G0=init_b, init_C=Cb) + verbose=False, random_state=42, G0=init_b, init_C=C1b) np.testing.assert_allclose(C_, Cb_, atol=1e-06)