diff --git a/README.md b/README.md index b7493074c..fbad3086e 100644 --- a/README.md +++ b/README.md @@ -46,7 +46,7 @@ POT provides the following generic OT solvers (links to examples): * [Wasserstein distance on the circle](https://pythonot.github.io/auto_examples/plot_compute_wasserstein_circle.html) [44, 45] * [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46] * [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38]. -* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) with corresponding barycenter solvers (exact and regularized [48]). +* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) with corresponding [barycenter solvers](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.hmtl) (exact and regularized [48]). * [Quantized (Fused) Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_quantized_gromov_wasserstein.html) [68]. * [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://pythonot.github.io/auto_examples/others/plot_demd_gradient_minimize.html) [50]. * [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays. diff --git a/RELEASES.md b/RELEASES.md index 1ec999811..cc18cc91b 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -5,9 +5,8 @@ #### New features - Add feature `mass=True` for `nx.kl_div` (PR #654) - Gaussian Mixture Model OT `ot.gmm` (PR #649) -- Add feature `semirelaxed_fgw_barycenters` and generic FGW-related barycenter - updates `update_barycenter_structure` and `update_barycenter_feature` (PR - #659) +- Add feature `semirelaxed_fgw_barycenters` and generic FGW-related barycenter updates `update_barycenter_structure` and `update_barycenter_feature` (PR #659) +- Add initialization heuristics for sr(F)GW problems via `semirelaxed_init_plan`, integrated in all sr(F)GW solvers (PR #659) - Improved `ot.plot.plot1D_mat` (PR #649) - Added `nx.det` (PR #649) - `nx.sqrtm` is now broadcastable (takes ..., d, d) inputs (PR #649) diff --git a/examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.py b/examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.py new file mode 100644 index 000000000..e555d1e70 --- /dev/null +++ b/examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.py @@ -0,0 +1,266 @@ +# -*- coding: utf-8 -*- + +r""" +===================================================== +Semi-relaxed (Fused) Gromov-Wasserstein Barycenter as Dictionary Learning +===================================================== + +In this example, we illustrate how to learn a semi-relaxed Gromov-Wasserstein +(srGW) barycenter using a Block-Coordinate Descent algorithm, on a dataset of +structured data such as graphs, denoted :math:`\{ \mathbf{C_s} \}_{s \in [S]}` +where every nodes have uniform weights :math:`\{ \mathbf{p_s} \}_{s \in [S]}`. +Given a barycenter structure matrix :math:`\mathbf{C}` with N nodes, +each graph :math:`(\mathbf{C_s}, \mathbf{p_s})` is modeled as a reweighed subgraph +with structure :math:`\mathbf{C}` and weights :math:`\mathbf{w_s} \in \Sigma_N` +where each :math:`\mathbf{w_s}` corresponds to the second marginal of the OT +:math:`\mathbf{T_s}` (s.t :math:`\mathbf{w_s} = \mathbf{T_s}^\top \mathbf{1}`) +minimizing the srGW loss between the s^{th} input and the barycenter. + + +First, we consider a dataset composed of graphs generated by Stochastic Block models +with variable sizes taken in :math:`\{30, ... , 50\}` and number of clusters +varying in :math:`\{ 1, 2, 3\}` with random proportions. We learn a srGW barycenter +with 3 nodes and visualize the learned structure and the embeddings for some inputs. + +Second, we illustrate the extension of this framework to graphs endowed +with node features by using the semi-relaxed Fused Gromov-Wasserstein +divergence (srFGW). Starting from the aforementioned dataset of unattributed graphs, we +add discrete labels uniformly depending on the number of clusters. Then conduct +the analog analysis. + + +[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. +"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs". +International Conference on Learning Representations (ICLR), 2022. + +""" +# Author: Cédric Vincent-Cuaz +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 2 + +import numpy as np +import matplotlib.pylab as pl +from sklearn.manifold import MDS +from ot.gromov import ( + semirelaxed_gromov_barycenters, semirelaxed_fgw_barycenters) +import ot +import networkx +from networkx.generators.community import stochastic_block_model as sbm + +############################################################################# +# +# Generate a dataset composed of graphs following Stochastic Block models of 1, 2 and 3 clusters. +# ----------------------------------------------------------------------------------------------- + +np.random.seed(42) + +n_samples = 60 # number of graphs in the dataset +# For every number of clusters, we generate SBM with fixed inter/intra-clusters probability, +# and variable cluster proportions. +clusters = [1, 2, 3] +Nc = n_samples // len(clusters) # number of graphs by cluster +nlabels = len(clusters) +dataset = [] +node_labels = [] +labels = [] + +p_inter = 0.1 +p_intra = 0.9 +for n_cluster in clusters: + for i in range(Nc): + n_nodes = int(np.random.uniform(low=30, high=50)) + + if n_cluster > 1: + P = p_inter * np.ones((n_cluster, n_cluster)) + np.fill_diagonal(P, p_intra) + props = np.random.uniform(0.2, 1, size=(n_cluster,)) + props /= props.sum() + sizes = np.round(n_nodes * props).astype(np.int32) + else: + P = p_intra * np.eye(1) + sizes = [n_nodes] + + G = sbm(sizes, P, seed=i, directed=False) + part = np.array([G.nodes[i]['block'] for i in range(np.sum(sizes))]) + C = networkx.to_numpy_array(G) + dataset.append(C) + node_labels.append(part) + labels.append(n_cluster) + + +# Visualize samples + +def plot_graph(x, C, binary=True, color='C0', s=None): + for j in range(C.shape[0]): + for i in range(j): + if binary: + if C[i, j] > 0: + pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color='k') + else: # connection intensity proportional to C[i,j] + pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=C[i, j], color='k') + + pl.scatter(x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors='k', cmap='tab10', vmax=9) + + +pl.figure(1, (12, 8)) +pl.clf() +for idx_c, c in enumerate(clusters): + C = dataset[(c - 1) * Nc] # sample with c clusters + # get 2d position for nodes + x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C) + pl.subplot(2, nlabels, c) + pl.title('(graph) sample from label ' + str(c), fontsize=14) + plot_graph(x, C, binary=True, color='C0', s=50.) + pl.axis("off") + pl.subplot(2, nlabels, nlabels + c) + pl.title('(matrix) sample from label %s \n' % c, fontsize=14) + pl.imshow(C, interpolation='nearest') + pl.axis("off") +pl.tight_layout() +pl.show() + +############################################################################# +# +# Estimate the srGW barycenter from the dataset and visualize embeddings +# ----------------------------------------------------------- + + +np.random.seed(0) +ps = [ot.unif(C.shape[0]) for C in dataset] # uniform weights on input nodes +lambdas = [1. / n_samples for _ in range(n_samples)] # uniform barycenter +N = 3 # 3 nodes in the barycenter + +# Here we use the Fluid partitioning method to deduce initial transport plans +# for the barycenter problem. An initlal structure is also deduced from these +# initial transport plans. Then a warmstart strategy is used iteratively to +# init each individual srGW problem within the BCD algorithm. + +init_plan = 'fluid' # notice that several init options are implemented in `ot.gromov.semirelaxed_init_plan` +warmstartT = True + +C, log = semirelaxed_gromov_barycenters( + N=N, Cs=dataset, ps=ps, lambdas=lambdas, loss_fun='square_loss', + tol=1e-6, stop_criterion='loss', warmstartT=warmstartT, log=True, + G0=init_plan, verbose=False) + +print('barycenter structure:', C) + +unmixings = log['p'] +# Compute the 2D representation of the embeddings living in the 2-simplex of probability +unmixings2D = np.zeros(shape=(n_samples, 2)) +for i, w in enumerate(unmixings): + unmixings2D[i, 0] = (2. * w[1] + w[2]) / 2. + unmixings2D[i, 1] = (np.sqrt(3.) * w[2]) / 2. +x = [0., 0.] +y = [1., 0.] +z = [0.5, np.sqrt(3) / 2.] +extremities = np.stack([x, y, z]) + +pl.figure(2, (4, 4)) +pl.clf() +pl.title('Embedding space', fontsize=14) +for cluster in range(nlabels): + start, end = Nc * cluster, Nc * (cluster + 1) + if cluster == 0: + pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=80., label='1 cluster') + else: + pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=80., label='%s clusters' % (cluster + 1)) +pl.scatter(extremities[:, 0], extremities[:, 1], c='black', marker='x', s=100., label='bary. nodes') +pl.plot([x[0], y[0]], [x[1], y[1]], color='black', linewidth=2.) +pl.plot([x[0], z[0]], [x[1], z[1]], color='black', linewidth=2.) +pl.plot([y[0], z[0]], [y[1], z[1]], color='black', linewidth=2.) +pl.axis('off') +pl.legend(fontsize=11) +pl.tight_layout() +pl.show() + +############################################################################# +# +# Endow the dataset with node features +# ------------------------------------ +# node labels, corresponding to the true SBM cluster assignments, +# are set for each graph as one-hot encoded node features. + +dataset_features = [] +for i in range(len(dataset)): + n = dataset[i].shape[0] + F = np.zeros((n, 3)) + F[np.arange(n), node_labels[i]] = 1. + dataset_features.append(F) + +pl.figure(3, (12, 8)) +pl.clf() +for idx_c, c in enumerate(clusters): + C = dataset[(c - 1) * Nc] # sample with c clusters + F = dataset_features[(c - 1) * Nc] + colors = [f'C{labels[i]}' for i in range(F.shape[0])] + # get 2d position for nodes + x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C) + pl.subplot(2, nlabels, c) + pl.title('(graph) sample from label ' + str(c), fontsize=14) + plot_graph(x, C, binary=True, color=colors, s=50) + pl.axis("off") + pl.subplot(2, nlabels, nlabels + c) + pl.title('(matrix) sample from label %s \n' % c, fontsize=14) + pl.imshow(C, interpolation='nearest') + pl.axis("off") +pl.tight_layout() +pl.show() + +############################################################################# +# +# Estimate the srFGW barycenter from the attributed graphs and visualize embeddings +# ----------------------------------------------------------- +# We emphasize the dependence to the trade-off parameter alpha that weights the +# relative importance between structures (alpha=1) and features (alpha=0), +# knowing that embeddings that perfectly cluster graphs w.r.t their features +# should ease the identification of the number of clusters in the graphs. + +list_alphas = [0.0001, 0.5, 0.9999] +list_unmixings2D = [] + +for ialpha, alpha in enumerate(list_alphas): + print('--- alpha:', alpha) + C, F, log = semirelaxed_fgw_barycenters( + N=N, Ys=dataset_features, Cs=dataset, ps=ps, lambdas=lambdas, + alpha=alpha, loss_fun='square_loss', tol=1e-6, stop_criterion='loss', + warmstartT=warmstartT, log=True, G0=init_plan) + + print('barycenter structure:', C) + print('barycenter features:', F) + + unmixings = log['p'] + # Compute the 2D representation of the embeddings living in the 2-simplex of probability + unmixings2D = np.zeros(shape=(n_samples, 2)) + for i, w in enumerate(unmixings): + unmixings2D[i, 0] = (2. * w[1] + w[2]) / 2. + unmixings2D[i, 1] = (np.sqrt(3.) * w[2]) / 2. + list_unmixings2D.append(unmixings2D.copy()) + +x = [0., 0.] +y = [1., 0.] +z = [0.5, np.sqrt(3) / 2.] +extremities = np.stack([x, y, z]) + +pl.figure(4, (12, 4)) +pl.clf() +pl.suptitle('Embedding spaces', fontsize=14) +for ialpha, alpha in enumerate(list_alphas): + pl.subplot(1, len(list_alphas), ialpha + 1) + pl.title(f'alpha = {alpha}', fontsize=14) + for cluster in range(nlabels): + start, end = Nc * cluster, Nc * (cluster + 1) + if cluster == 0: + pl.scatter(list_unmixings2D[ialpha][start:end, 0], list_unmixings2D[ialpha][start:end, 1], c='C' + str(cluster), marker='o', s=80., label='1 cluster') + else: + pl.scatter(list_unmixings2D[ialpha][start:end, 0], list_unmixings2D[ialpha][start:end, 1], c='C' + str(cluster), marker='o', s=80., label='%s clusters' % (cluster + 1)) + pl.scatter(extremities[:, 0], extremities[:, 1], c='black', marker='x', s=100., label='bary. nodes') + pl.plot([x[0], y[0]], [x[1], y[1]], color='black', linewidth=2.) + pl.plot([x[0], z[0]], [x[1], z[1]], color='black', linewidth=2.) + pl.plot([y[0], z[0]], [y[1], z[1]], color='black', linewidth=2.) + pl.axis('off') + pl.legend(fontsize=11) +pl.tight_layout() +pl.show() diff --git a/ot/gromov/__init__.py b/ot/gromov/__init__.py index 675f42ccb..5cf19784b 100644 --- a/ot/gromov/__init__.py +++ b/ot/gromov/__init__.py @@ -11,9 +11,8 @@ # All submodules and packages from ._utils import (init_matrix, tensor_product, gwloss, gwggrad, - init_matrix_semirelaxed, - update_barycenter_structure, update_barycenter_feature, - ) + init_matrix_semirelaxed, semirelaxed_init_plan, + update_barycenter_structure, update_barycenter_feature) from ._gw import (gromov_wasserstein, gromov_wasserstein2, fused_gromov_wasserstein, fused_gromov_wasserstein2, @@ -42,6 +41,7 @@ entropic_semirelaxed_gromov_wasserstein2, entropic_semirelaxed_fused_gromov_wasserstein, entropic_semirelaxed_fused_gromov_wasserstein2, + semirelaxed_gromov_barycenters, semirelaxed_fgw_barycenters) from ._dictionary import (gromov_wasserstein_dictionary_learning, @@ -64,7 +64,7 @@ ) __all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad', - 'init_matrix_semirelaxed', + 'init_matrix_semirelaxed', 'semirelaxed_init_plan', 'update_barycenter_structure', 'update_barycenter_feature', 'gromov_wasserstein', 'gromov_wasserstein2', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2', 'solve_gromov_linesearch', 'gromov_barycenters', @@ -78,11 +78,13 @@ 'semirelaxed_fused_gromov_wasserstein', 'semirelaxed_fused_gromov_wasserstein2', 'solve_semirelaxed_gromov_linesearch', 'entropic_semirelaxed_gromov_wasserstein', 'entropic_semirelaxed_gromov_wasserstein2', 'entropic_semirelaxed_fused_gromov_wasserstein', - 'entropic_semirelaxed_fused_gromov_wasserstein2', 'gromov_wasserstein_dictionary_learning', + 'entropic_semirelaxed_fused_gromov_wasserstein2', + 'semirelaxed_fgw_barycenters', 'semirelaxed_gromov_barycenters', + 'gromov_wasserstein_dictionary_learning', 'gromov_wasserstein_linear_unmixing', 'fused_gromov_wasserstein_dictionary_learning', 'fused_gromov_wasserstein_linear_unmixing', 'lowrank_gromov_wasserstein_samples', 'quantized_fused_gromov_wasserstein_partitioned', 'get_graph_partition', 'get_graph_representants', 'format_partitioned_graph', 'quantized_fused_gromov_wasserstein', 'get_partition_and_representants_samples', 'format_partitioned_samples', 'quantized_fused_gromov_wasserstein_samples', - 'semirelaxed_fgw_barycenters'] + ] diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py index a777239d3..96f776cb1 100644 --- a/ot/gromov/_semirelaxed.py +++ b/ot/gromov/_semirelaxed.py @@ -12,7 +12,7 @@ from ..utils import ( - list_to_array, unif, dist, UndefinedParameter, check_random_state + list_to_array, unif, dist, UndefinedParameter ) from ..optim import semirelaxed_cg, solve_1d_linesearch_quad from ..backend import get_backend @@ -20,11 +20,19 @@ from ._utils import ( init_matrix_semirelaxed, gwloss, gwggrad, update_barycenter_structure, update_barycenter_feature, + semirelaxed_init_plan, ) +try: + from sklearn.cluster import KMeans + sklearn_import = True +except ImportError: + sklearn_import = False -def semirelaxed_gromov_wasserstein(C1, C2, p=None, loss_fun='square_loss', symmetric=None, log=False, G0=None, - max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): + +def semirelaxed_gromov_wasserstein( + C1, C2, p=None, loss_fun='square_loss', symmetric=None, log=False, G0=None, + max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, random_state=0, **kwargs): r""" Returns the semi-relaxed Gromov-Wasserstein divergence transport from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}` (see [48]). @@ -69,15 +77,22 @@ def semirelaxed_gromov_wasserstein(C1, C2, p=None, loss_fun='square_loss', symme Print information along iterations log : bool, optional record log if True - G0: array-like, shape (ns,nt), optional - If None the initial transport plan of the solver is pq^T. - Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + G0: array-like of shape (ns,nt) or string, optional + If `G0=None` the initial transport plan of the solver is :math:`\mathbf{p} \frac{\mathbf{1}_{nt}}{nt}^\top`. + If G0 is a tensor it must satisfy marginal constraints and will be + used as initial transport of the solver. + if G0 is a string it will be interpreted as a method for + :func:`ot.gromov.semirelaxed_init_plan` taking values in "product", + "random_product", "random", "fluid", "fluid_soft", "spectral", + "spectral_soft", "kmeans", "kmeans_soft". max_iter : int, optional Max number of iterations tol_rel : float, optional Stop threshold on relative error (>0) tol_abs : float, optional Stop threshold on absolute error (>0) + random_state: int, optional + Random seed used in stochastic initialization methods. **kwargs : dict parameters can be directly passed to the ot.optim.cg solver @@ -106,16 +121,21 @@ def semirelaxed_gromov_wasserstein(C1, C2, p=None, loss_fun='square_loss', symme else: p = unif(C1.shape[0], type_as=C1) - if G0 is not None: + if (G0 is not None) and (not isinstance(G0, str)): arr.append(G0) nx = get_backend(*arr) if symmetric is None: symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10) + if G0 is None: q = unif(C2.shape[0], type_as=p) G0 = nx.outer(p, q) + elif isinstance(G0, str): + G0 = semirelaxed_init_plan( + C1, C2, p, method=G0, random_state=random_state, nx=nx) + q = nx.sum(G0, 0) else: q = nx.sum(G0, 0) # Check first marginal of G0 @@ -155,8 +175,10 @@ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): return semirelaxed_cg(p, q, 0., 1., f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) -def semirelaxed_gromov_wasserstein2(C1, C2, p=None, loss_fun='square_loss', symmetric=None, log=False, G0=None, - max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): +def semirelaxed_gromov_wasserstein2( + C1, C2, p=None, loss_fun='square_loss', symmetric=None, log=False, + G0=None, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, random_state=0, + **kwargs): r""" Returns the semi-relaxed Gromov-Wasserstein divergence from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}` (see [48]). @@ -204,15 +226,22 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p=None, loss_fun='square_loss', symm Print information along iterations log : bool, optional record log if True - G0: array-like, shape (ns,nt), optional - If None the initial transport plan of the solver is pq^T. - Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + G0: array-like of shape (ns,nt) or string, optional + If `G0=None` the initial transport plan of the solver is :math:`\mathbf{p} \frac{\mathbf{1}_{nt}}{nt}^\top`. + If G0 is a tensor it must satisfy marginal constraints and will be + used as initial transport of the solver. + if G0 is a string it will be interpreted as a method for + :func:`ot.gromov.semirelaxed_init_plan` taking values in "product", + "random_product", "random", "fluid", "fluid_soft", "spectral", + "spectral_soft", "kmeans", "kmeans_soft". max_iter : int, optional Max number of iterations tol_rel : float, optional Stop threshold on relative error (>0) tol_abs : float, optional Stop threshold on absolute error (>0) + random_state: int, optional + Random seed used in stochastic initialization methods. **kwargs : dict parameters can be directly passed to the ot.optim.cg solver @@ -243,7 +272,8 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p=None, loss_fun='square_loss', symm T, log_srgw = semirelaxed_gromov_wasserstein( C1, C2, p, loss_fun, symmetric, log=True, G0=G0, - max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, **kwargs) + max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, + random_state=random_state, **kwargs) q = nx.sum(T, 0) log_srgw['T'] = T @@ -267,7 +297,8 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p=None, loss_fun='square_loss', symm def semirelaxed_fused_gromov_wasserstein( M, C1, C2, p=None, loss_fun='square_loss', symmetric=None, alpha=0.5, - G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): + G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, + random_state=0, **kwargs): r""" Computes the semi-relaxed Fused Gromov-Wasserstein transport between two graphs (see [48]). @@ -311,9 +342,14 @@ def semirelaxed_fused_gromov_wasserstein( Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). alpha : float, optional Trade-off parameter (0 < alpha < 1) - G0: array-like, shape (ns,nt), optional - If None the initial transport plan of the solver is pq^T. - Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + G0: array-like of shape (ns,nt) or string, optional + If `G0=None` the initial transport plan of the solver is :math:`\mathbf{p} \frac{\mathbf{1}_{nt}}{nt}^\top`. + If G0 is a tensor it must satisfy marginal constraints and will be + used as initial transport of the solver. + if G0 is a string it will be interpreted as a method for + :func:`ot.gromov.semirelaxed_init_plan` taking values in "product", + "random_product", "random", "fluid", "fluid_soft", "spectral", + "spectral_soft", "kmeans", "kmeans_soft". log : bool, optional record log if True max_iter : int, optional @@ -322,6 +358,8 @@ def semirelaxed_fused_gromov_wasserstein( Stop threshold on relative error (>0) tol_abs : float, optional Stop threshold on absolute error (>0) + random_state: int, optional + Random seed used in stochastic initialization methods. **kwargs : dict parameters can be directly passed to the ot.optim.cg solver @@ -355,7 +393,7 @@ def semirelaxed_fused_gromov_wasserstein( else: p = unif(C1.shape[0], type_as=C1) - if G0 is not None: + if (G0 is not None) and (not isinstance(G0, str)): arr.append(G0) nx = get_backend(*arr) @@ -366,6 +404,10 @@ def semirelaxed_fused_gromov_wasserstein( if G0 is None: q = unif(C2.shape[0], type_as=p) G0 = nx.outer(p, q) + elif isinstance(G0, str): + G0 = semirelaxed_init_plan( + C1, C2, p, M, alpha, G0, random_state=random_state, nx=nx) + q = nx.sum(G0, 0) else: q = nx.sum(G0, 0) # Check first marginal of G0 @@ -406,8 +448,10 @@ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): return semirelaxed_cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) -def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_loss', symmetric=None, alpha=0.5, G0=None, log=False, - max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, **kwargs): +def semirelaxed_fused_gromov_wasserstein2( + M, C1, C2, p=None, loss_fun='square_loss', symmetric=None, alpha=0.5, + G0=None, log=False, max_iter=1e4, tol_rel=1e-9, tol_abs=1e-9, + random_state=0, **kwargs): r""" Computes the semi-relaxed FGW divergence between two graphs (see [48]). @@ -454,9 +498,14 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). alpha : float, optional Trade-off parameter (0 < alpha < 1) - G0: array-like, shape (ns,nt), optional - If None the initial transport plan of the solver is pq^T. - Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + G0: array-like of shape (ns,nt) or string, optional + If `G0=None` the initial transport plan of the solver is :math:`\mathbf{p} \frac{\mathbf{1}_{nt}}{nt}^\top`. + If G0 is a tensor it must satisfy marginal constraints and will be + used as initial transport of the solver. + if G0 is a string it will be interpreted as a method for + :func:`ot.gromov.semirelaxed_init_plan` taking values in "product", + "random_product", "random", "fluid", "fluid_soft", "spectral", + "spectral_soft", "kmeans", "kmeans_soft". log : bool, optional Record log if True. max_iter : int, optional @@ -465,6 +514,8 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo Stop threshold on relative error (>0) tol_abs : float, optional Stop threshold on absolute error (>0) + random_state: int, optional + Random seed used in stochastic initialization methods. **kwargs : dict Parameters can be directly passed to the ot.optim.cg solver. @@ -501,7 +552,8 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo T, log_fgw = semirelaxed_fused_gromov_wasserstein( M, C1, C2, p, loss_fun, symmetric, alpha, G0, log=True, - max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, **kwargs) + max_iter=max_iter, tol_rel=tol_rel, tol_abs=tol_abs, + random_state=random_state, **kwargs) q = nx.sum(T, 0) srfgw_dist = log_fgw['srfgw_dist'] log_fgw['T'] = T @@ -615,7 +667,8 @@ def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p, def entropic_semirelaxed_gromov_wasserstein( C1, C2, p=None, loss_fun='square_loss', epsilon=0.1, symmetric=None, - G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, **kwargs): + G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, + random_state=0): r""" Returns the entropic-regularized semi-relaxed gromov-wasserstein divergence transport plan from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}` @@ -661,9 +714,14 @@ def entropic_semirelaxed_gromov_wasserstein( Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). verbose : bool, optional Print information along iterations - G0: array-like, shape (ns,nt), optional - If None the initial transport plan of the solver is pq^T. - Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + G0: array-like of shape (ns,nt) or string, optional + If `G0=None` the initial transport plan of the solver is :math:`\mathbf{p} \frac{\mathbf{1}_{nt}}{nt}^\top`. + If G0 is a tensor it must satisfy marginal constraints and will be + used as initial transport of the solver. + if G0 is a string it will be interpreted as a method for + :func:`ot.gromov.semirelaxed_init_plan` taking values in "product", + "random_product", "random", "fluid", "fluid_soft", "spectral", + "spectral_soft", "kmeans", "kmeans_soft". max_iter : int, optional Max number of iterations tol : float, optional @@ -672,6 +730,9 @@ def entropic_semirelaxed_gromov_wasserstein( record log if True verbose : bool, optional Print information along iterations + random_state: int, optional + Random seed used in stochastic initialization methods. + Returns ------- G : array-like, shape (`ns`, `nt`) @@ -693,16 +754,21 @@ def entropic_semirelaxed_gromov_wasserstein( else: p = unif(C1.shape[0], type_as=C1) - if G0 is not None: + if (G0 is not None) and (not isinstance(G0, str)): arr.append(G0) nx = get_backend(*arr) if symmetric is None: symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10) + if G0 is None: q = unif(C2.shape[0], type_as=p) G0 = nx.outer(p, q) + elif isinstance(G0, str): + G0 = semirelaxed_init_plan( + C1, C2, p, method=G0, random_state=random_state, nx=nx) + q = nx.sum(G0, 0) else: q = nx.sum(G0, 0) # Check first marginal of G0 @@ -767,7 +833,8 @@ def df(G): def entropic_semirelaxed_gromov_wasserstein2( C1, C2, p=None, loss_fun='square_loss', epsilon=0.1, symmetric=None, - G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, **kwargs): + G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, + random_state=0, **kwargs): r""" Returns the entropic-regularized semi-relaxed gromov-wasserstein divergence from :math:`(\mathbf{C_1}, \mathbf{p})` to :math:`\mathbf{C_2}` @@ -815,9 +882,14 @@ def entropic_semirelaxed_gromov_wasserstein2( Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). verbose : bool, optional Print information along iterations - G0: array-like, shape (ns,nt), optional - If None the initial transport plan of the solver is pq^T. - Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + G0: array-like of shape (ns,nt) or string, optional + If `G0=None` the initial transport plan of the solver is :math:`\mathbf{p} \frac{\mathbf{1}_{nt}}{nt}^\top`. + If G0 is a tensor it must satisfy marginal constraints and will be + used as initial transport of the solver. + if G0 is a string it will be interpreted as a method for + :func:`ot.gromov.semirelaxed_init_plan` taking values in "product", + "random_product", "random", "fluid", "fluid_soft", "spectral", + "spectral_soft", "kmeans", "kmeans_soft". max_iter : int, optional Max number of iterations tol : float, optional @@ -826,8 +898,8 @@ def entropic_semirelaxed_gromov_wasserstein2( record log if True verbose : bool, optional Print information along iterations - **kwargs : dict - parameters can be directly passed to the ot.optim.cg solver + random_state: int, optional + Random seed used in stochastic initialization methods. Returns ------- @@ -844,8 +916,8 @@ def entropic_semirelaxed_gromov_wasserstein2( International Conference on Learning Representations (ICLR), 2022. """ T, log_srgw = entropic_semirelaxed_gromov_wasserstein( - C1, C2, p, loss_fun, epsilon, symmetric, G0, - max_iter, tol, log=True, verbose=verbose, **kwargs) + C1, C2, p, loss_fun, epsilon, symmetric, G0, max_iter, tol, + log=True, verbose=verbose, random_state=random_state) log_srgw['T'] = T @@ -857,7 +929,8 @@ def entropic_semirelaxed_gromov_wasserstein2( def entropic_semirelaxed_fused_gromov_wasserstein( M, C1, C2, p=None, loss_fun='square_loss', symmetric=None, epsilon=0.1, - alpha=0.5, G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, **kwargs): + alpha=0.5, G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, + random_state=0): r""" Computes the entropic-regularized semi-relaxed FGW transport between two graphs (see :ref:`[48] `) estimated using a Mirror Descent algorithm following the KL geometry. @@ -906,9 +979,14 @@ def entropic_semirelaxed_fused_gromov_wasserstein( Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). alpha : float, optional Trade-off parameter (0 < alpha < 1) - G0: array-like, shape (ns,nt), optional - If None the initial transport plan of the solver is pq^T. - Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + G0: array-like of shape (ns,nt) or string, optional + If `G0=None` the initial transport plan of the solver is :math:`\mathbf{p} \frac{\mathbf{1}_{nt}}{nt}^\top`. + If G0 is a tensor it must satisfy marginal constraints and will be + used as initial transport of the solver. + if G0 is a string it will be interpreted as a method for + :func:`ot.gromov.semirelaxed_init_plan` taking values in "product", + "random_product", "random", "fluid", "fluid_soft", "spectral", + "spectral_soft", "kmeans", "kmeans_soft". max_iter : int, optional Max number of iterations tol : float, optional @@ -917,8 +995,8 @@ def entropic_semirelaxed_fused_gromov_wasserstein( record log if True verbose : bool, optional Print information along iterations - **kwargs : dict - parameters can be directly passed to the ot.optim.cg solver + random_state: int, optional + Random seed used in stochastic initialization methods. Returns ------- @@ -941,16 +1019,21 @@ def entropic_semirelaxed_fused_gromov_wasserstein( else: p = unif(C1.shape[0], type_as=C1) - if G0 is not None: + if (G0 is not None) and (not isinstance(G0, str)): arr.append(G0) nx = get_backend(*arr) if symmetric is None: symmetric = nx.allclose(C1, C1.T, atol=1e-10) and nx.allclose(C2, C2.T, atol=1e-10) + if G0 is None: q = unif(C2.shape[0], type_as=p) G0 = nx.outer(p, q) + elif isinstance(G0, str): + G0 = semirelaxed_init_plan( + C1, C2, p, M, alpha, G0, random_state=random_state, nx=nx) + q = nx.sum(G0, 0) else: q = nx.sum(G0, 0) # Check first marginal of G0 @@ -1017,7 +1100,8 @@ def df(G): def entropic_semirelaxed_fused_gromov_wasserstein2( M, C1, C2, p=None, loss_fun='square_loss', symmetric=None, epsilon=0.1, - alpha=0.5, G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, **kwargs): + alpha=0.5, G0=None, max_iter=1e4, tol=1e-9, log=False, verbose=False, + random_state=0): r""" Computes the entropic-regularized semi-relaxed FGW divergence between two graphs (see :ref:`[48] `) estimated using a Mirror Descent algorithm following the KL geometry. @@ -1066,9 +1150,14 @@ def entropic_semirelaxed_fused_gromov_wasserstein2( Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymetric). alpha : float, optional Trade-off parameter (0 < alpha < 1) - G0: array-like, shape (ns,nt), optional - If None the initial transport plan of the solver is pq^T. - Otherwise G0 must satisfy marginal constraints and will be used as initial transport of the solver. + G0: array-like of shape (ns,nt) or string, optional + If `G0=None` the initial transport plan of the solver is :math:`\mathbf{p} \frac{\mathbf{1}_{nt}}{nt}^\top`. + If G0 is a tensor it must satisfy marginal constraints and will be + used as initial transport of the solver. + if G0 is a string it will be interpreted as a method for + :func:`ot.gromov.semirelaxed_init_plan` taking values in "product", + "random_product", "random", "fluid", "fluid_soft", "spectral", + "spectral_soft", "kmeans", "kmeans_soft". max_iter : int, optional Max number of iterations tol : float, optional @@ -1077,8 +1166,8 @@ def entropic_semirelaxed_fused_gromov_wasserstein2( record log if True verbose : bool, optional Print information along iterations - **kwargs : dict - Parameters can be directly passed to the ot.optim.cg solver. + random_state: int, optional + Random seed used in stochastic initialization methods. Returns ------- @@ -1096,8 +1185,8 @@ def entropic_semirelaxed_fused_gromov_wasserstein2( International Conference on Learning Representations (ICLR), 2022. """ T, log_srfgw = entropic_semirelaxed_fused_gromov_wasserstein( - M, C1, C2, p, loss_fun, symmetric, epsilon, alpha, G0, - max_iter, tol, log=True, verbose=verbose, **kwargs) + M, C1, C2, p, loss_fun, symmetric, epsilon, alpha, G0, max_iter, tol, + log=True, verbose=verbose, random_state=random_state) log_srfgw['T'] = T @@ -1107,12 +1196,241 @@ def entropic_semirelaxed_fused_gromov_wasserstein2( return log_srfgw['srfgw_dist'] +def semirelaxed_gromov_barycenters( + N, Cs, ps=None, lambdas=None, loss_fun='square_loss', + symmetric=True, max_iter=1000, tol=1e-9, + stop_criterion='barycenter', warmstartT=False, verbose=False, + log=False, init_C=None, G0='product', random_state=None, **kwargs): + r""" + Returns the Semi-relaxed Gromov-Wasserstein barycenters of `S` measured similarity matrices + :math:`(\mathbf{C}_s)_{1 \leq s \leq S}` + + The function solves the following optimization problem with block coordinate descent: + + .. math:: + + \mathbf{C}^* = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{srGW}(\mathbf{C}_s, \mathbf{p}_s, \mathbf{C}) + + Where : + + - :math:`\mathbf{C}_s`: input metric cost matrix + - :math:`\mathbf{p}_s`: distribution + + Parameters + ---------- + N : int + Size of the targeted barycenter + Cs : list of S array-like of shape (ns, ns) + Metric cost matrices + ps : list of S array-like of shape (ns,), optional + Sample weights in the `S` spaces. + If let to its default value None, uniform distributions are taken. + lambdas : array-like of shape (S,) , optional + List of the `S` spaces' weights. + If let to its default value None, uniform weights are taken. + loss_fun : callable, optional + tensor-matrix multiplication function based on specific loss function + symmetric : bool, optional. + Either structures are to be assumed symmetric or not. Default value is True. + Else if set to True (resp. False), C1 and C2 will be assumed symmetric (resp. asymmetric). + max_iter : int, optional + Max number of iterations + tol : float, optional + Stop threshold on relative error (>0) + stop_criterion : str, optional. Default is 'barycenter'. + Stop criterion taking values in ['barycenter', 'loss']. If set to 'barycenter' + uses absolute norm variations of estimated barycenters. Else if set to 'loss' + uses the relative variations of the loss. + warmstartT: bool, optional + Either to perform warmstart of transport plans in the successive + fused gromov-wasserstein transport problems.s + verbose : bool, optional + Print information along iterations. + log : bool, optional + Record log if True. + init_C : array-like of shape (N,N), optional. + Random initial value for the :math:`\mathbf{C}` matrix provided by user. + Default is None and relies `G0` to produce an initial structure. + G0: str, optional. Default is 'product'. + Initialization method for transport plans calling :func:`ot.gromov.semirelaxed_init_plan`, + and taking values in "product", "random_product", "random", "fluid", + "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". + Transport plans are used to deduce an initial barycenter structure + if `init_C=None`. + random_state : int or RandomState instance, optional + Fix the seed for reproducibility + + Returns + ------- + C : array-like, shape (`N`, `N`) + Barycenters' structure matrix + log : dict + Only returned when log=True. It contains the keys: + + - :math:`\mathbf{T}`: list of (`N`, `ns`) transport matrices + - :math:`\mathbf{p}`: (`N`,) barycenter weights + - values used in convergence evaluation. + + References + ---------- + .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. + "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" + International Conference on Learning Representations (ICLR), 2022. + + """ + if stop_criterion not in ['barycenter', 'loss']: + raise ValueError(f"Unknown `stop_criterion='{stop_criterion}'`. Use one of: {'barycenter', 'loss'}.") + + arr = [*Cs] + if ps is not None: + arr += [*ps] + else: + ps = [unif(C.shape[0], type_as=C) for C in Cs] + + nx = get_backend(*arr) + + S = len(Cs) + if lambdas is None: + lambdas = nx.ones(S) / S + else: + lambdas = list_to_array(lambdas) + lambdas = nx.from_numpy(lambdas) + + # Initialization of transport plans and C (if not provided by user) + if init_C is None: + init_C = nx.zeros((N, N), type_as=Cs[0]) + if G0 in ['product', 'random_product', 'random']: + T = [semirelaxed_init_plan( + Cs[i], init_C, ps[i], method=G0, use_target=False, + random_state=random_state, nx=nx) for i in range(S)] + C = update_barycenter_structure( + T, Cs, lambdas, loss_fun=loss_fun, nx=nx) + + if G0 in ['product', 'random_product']: + # initial structure is constant so we add a small random noise + # to avoid getting stuck at init + np.random.seed(random_state) + noise = np.random.uniform(-0.01, 0.01, size=(N, N)) + if symmetric: + noise = (noise + noise.T) / 2. + noise = nx.from_numpy(noise) + C = C + noise + + else: # relies on partitioning of inputs + shapes = np.array([C.shape[0] for C in Cs]) + large_graphs_idx = np.where(shapes > N)[0] + small_graphs_idx = np.where(shapes <= N)[0] + T = [None] * S + list_init_C = [] # store different barycenter structure to average + + # we first compute an initial informative barycenter structure + # on graphs we can compress + # then use it on graphs to expand + for indices in [large_graphs_idx, small_graphs_idx]: + if len(indices) > 0: + sub_T = [semirelaxed_init_plan( + Cs[i], init_C, ps[i], method=G0, use_target=False, + random_state=random_state, nx=nx) for i in indices] + sub_Cs = [Cs[i] for i in indices] + sub_lambdas = lambdas[indices] / nx.sum(lambdas[indices]) + init_C = update_barycenter_structure( + sub_T, sub_Cs, sub_lambdas, loss_fun=loss_fun, nx=nx) + for i, idx in enumerate(indices): + T[idx] = sub_T[i] + list_init_C.append(init_C) + + if len(list_init_C) == 2: + init_C = update_barycenter_structure( + T, Cs, lambdas, loss_fun=loss_fun, nx=nx) + C = init_C + + else: + C = init_C + T = [semirelaxed_init_plan( + Cs[i], C, ps[i], method=G0, use_target=True, + random_state=random_state, nx=nx) for i in range(S)] + + if stop_criterion == 'barycenter': + inner_log = False + else: + inner_log = True + curr_loss = 1e15 + + if log: + log_ = {} + log_['err'] = [] + if stop_criterion == 'loss': + log_['loss'] = [] + + for cpt in range(max_iter): + + if stop_criterion == 'barycenter': + Cprev = C + else: + prev_loss = curr_loss + + # get transport plans + if warmstartT: + res = [semirelaxed_gromov_wasserstein( + Cs[s], C, ps[s], loss_fun, symmetric, G0=T[s], + max_iter=max_iter, tol_rel=tol, tol_abs=0., log=inner_log, + verbose=verbose, **kwargs) + for s in range(S)] + else: + res = [semirelaxed_gromov_wasserstein( + Cs[s], C, ps[s], loss_fun, symmetric, G0=G0, + max_iter=max_iter, tol_rel=tol, tol_abs=0., log=inner_log, + verbose=verbose, **kwargs) + for s in range(S)] + + if stop_criterion == 'barycenter': + T = res + else: + T = [output[0] for output in res] + curr_loss = np.sum([output[1]['srgw_dist'] for output in res]) + + # update barycenters + p = nx.concatenate( + [nx.sum(T[s], 0)[None, :] for s in range(S)], axis=0) + + C = update_barycenter_structure(T, Cs, lambdas, p, loss_fun, nx=nx) + + # update convergence criterion + if stop_criterion == 'barycenter': + err = nx.norm(C - Cprev) + if log: + log_['err'].append(err) + + else: + err = abs(curr_loss - prev_loss) / prev_loss if prev_loss != 0. else np.nan + if log: + log_['loss'].append(curr_loss) + log_['err'].append(err) + + if verbose: + if cpt % 200 == 0: + print('{:5s}|{:12s}'.format( + 'It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(cpt, err)) + + if err <= tol: + break + + if log: + log_['T'] = T + log_['p'] = p + + return C, log_ + else: + return C + + def semirelaxed_fgw_barycenters( N, Ys, Cs, ps=None, lambdas=None, alpha=0.5, fixed_structure=False, fixed_features=False, p=None, loss_fun='square_loss', symmetric=True, max_iter=100, tol=1e-9, stop_criterion='barycenter', warmstartT=False, verbose=False, log=False, init_C=None, init_X=None, - random_state=None, **kwargs): + G0='product', random_state=None, **kwargs): r""" Returns the Semi-relaxed Fused Gromov-Wasserstein barycenters of `S` measurable networks with node features :math:`(\mathbf{C}_s, \mathbf{Y}_s, \mathbf{p}_s)_{1 \leq s \leq S}` @@ -1142,7 +1460,7 @@ def semirelaxed_fgw_barycenters( ps : list of array-like, each element has shape (ns,), optional Masses of all samples. If let to its default value None, uniform distributions are taken. - lambdas : list of float, optional + lambdas : array-like of shape (S,) , optional List of the `S` spaces' weights. If let to its default value None, uniform weights are taken. alpha : float, optional @@ -1177,6 +1495,12 @@ def semirelaxed_fgw_barycenters( init_X : array-like, shape (N,d), optional Initialization for the barycenters' features. If not set a random init is used. + G0: str, optional. Default is "product". + Initialization method for transport plans calling :func:`ot.gromov.semirelaxed_init_plan`, + and taking values in "product", "random_product", "random", "fluid", + "fluid_soft", "spectral", "spectral_soft", "kmeans", "kmeans_soft". + Transport plans are used to deduce an initial barycenter structure + if `init_C=None`. random_state : int or RandomState instance, optional Fix the seed for reproducibility @@ -1212,9 +1536,10 @@ def semirelaxed_fgw_barycenters( S = len(Cs) if lambdas is None: - lambdas = [1. / S] * S - - d = Ys[0].shape[1] # dimension on the node features + lambdas = nx.ones(S) / S + else: + lambdas = list_to_array(lambdas) + lambdas = nx.from_numpy(lambdas) if fixed_structure: if init_C is None: @@ -1222,14 +1547,6 @@ def semirelaxed_fgw_barycenters( 'If C is fixed it must be provided in init_C') else: C = init_C - else: - if init_C is None: - rng = check_random_state(random_state) - xalea = rng.randn(N, 2) - C = dist(xalea, xalea) - C = nx.from_numpy(C, type_as=ps[0]) - else: - C = init_C if fixed_features: if init_X is None: @@ -1237,17 +1554,104 @@ def semirelaxed_fgw_barycenters( 'If X is fixed it must be provided in init_X') else: X = init_X - else: + + # Initialization of transport plans, C and X (if not provided by user) + if G0 in ['product', 'random_product', 'random']: + # both init_X and init_C are simply deduced from transport plans + # if not initialized + if init_C is None: + init_C = nx.zeros((N, N), type_as=Cs[0]) # to know the barycenter shape + + T = [semirelaxed_init_plan( + Cs[i], init_C, ps[i], method=G0, use_target=False, + random_state=random_state, nx=nx) for i in range(S)] + + C = update_barycenter_structure( + T, Cs, lambdas, loss_fun=loss_fun, nx=nx) + if G0 in ['product', 'random_product']: + # initial structure is constant so we add a small random noise + # to avoid getting stuck at init + np.random.seed(random_state) + noise = np.random.uniform(-0.01, 0.01, size=(N, N)) + if symmetric: + noise = (noise + noise.T) / 2. + noise = nx.from_numpy(noise) + C = C + noise + + else: + T = [semirelaxed_init_plan( + Cs[i], init_C, ps[i], method=G0, use_target=False, + random_state=random_state, nx=nx) for i in range(S)] + + C = init_C + if init_X is None: - X = nx.zeros((N, d), type_as=ps[0]) + X = update_barycenter_feature( + T, Ys, lambdas, loss_fun=loss_fun, nx=nx) + else: + X = init_X + + Ms = [dist(Ys[s], X) for s in range(len(Ys))] + else: + # more computationally costly inits could be used on structures + # so we assume affordable a Kmeans-like init for features + # and use it by default. + + if init_X is None: + stacked_features = nx.concatenate(Ys, axis=0) + if sklearn_import: + stacked_features = nx.to_numpy(stacked_features) + km = KMeans(n_clusters=N, random_state=random_state, + n_init=1).fit(stacked_features) + X = nx.from_numpy(km.cluster_centers_) + else: + raise ValueError( + f"Scikit-learn must be installed to use G0={G0} while init_X=None." + ) else: X = init_X - Ms = [dist(Ys[s], X) for s in range(len(Ys))] + Ms = [dist(Ys[s], X) for s in range(len(Ys))] - if warmstartT: - T = [None] * S + if (init_C is None): + init_C = nx.zeros((N, N), type_as=Cs[0]) + + # relies on partitioning of inputs + shapes = np.array([C.shape[0] for C in Cs]) + large_graphs_idx = np.where(shapes > N)[0] + small_graphs_idx = np.where(shapes <= N)[0] + T = [None] * S + list_init_C = [] # store different barycenter structure to average + + # we first compute an initial informative barycenter structure + # on graphs we can compress + # then use it on graphs to expand + for indices in [large_graphs_idx, small_graphs_idx]: + if len(indices) > 0: + sub_T = [semirelaxed_init_plan( + Cs[i], init_C, ps[i], Ms[i], alpha, method=G0, use_target=False, + random_state=random_state, nx=nx) for i in indices] + sub_Cs = [Cs[i] for i in indices] + sub_lambdas = lambdas[indices] / nx.sum(lambdas[indices]) + + init_C = update_barycenter_structure( + sub_T, sub_Cs, sub_lambdas, loss_fun=loss_fun, nx=nx) + + for i, idx in enumerate(indices): + T[idx] = sub_T[i] + + list_init_C.append(init_C) + + if len(list_init_C) == 2: + init_C = update_barycenter_structure( + T, Cs, lambdas, loss_fun=loss_fun, nx=nx) + C = init_C + else: + C = init_C + T = [semirelaxed_init_plan( + Cs[i], C, ps[i], Ms[i], alpha, method=G0, use_target=True, + random_state=random_state, nx=nx) for i in range(S)] if stop_criterion == 'barycenter': inner_log = False @@ -1277,13 +1681,14 @@ def semirelaxed_fgw_barycenters( if warmstartT: res = [semirelaxed_fused_gromov_wasserstein( Ms[s], Cs[s], C, ps[s], loss_fun, symmetric, alpha, T[s], - inner_log, max_iter, tol_rel=1e-5, tol_abs=0., **kwargs) + inner_log, max_iter, tol_rel=tol, tol_abs=0., **kwargs) for s in range(S)] else: res = [semirelaxed_fused_gromov_wasserstein( - Ms[s], Cs[s], C, ps[s], loss_fun, symmetric, alpha, None, - inner_log, max_iter, tol_rel=1e-5, tol_abs=0., **kwargs) + Ms[s], Cs[s], C, ps[s], loss_fun, symmetric, alpha, G0, + inner_log, max_iter, tol_rel=tol, tol_abs=0., **kwargs) for s in range(S)] + if stop_criterion == 'barycenter': T = res else: diff --git a/ot/gromov/_utils.py b/ot/gromov/_utils.py index d4928d062..5c465cba8 100644 --- a/ot/gromov/_utils.py +++ b/ot/gromov/_utils.py @@ -12,8 +12,25 @@ # License: MIT License -from ..utils import list_to_array +from ..utils import list_to_array, euclidean_distances from ..backend import get_backend +from ..lp import emd + +try: + from networkx.algorithms.community import asyn_fluidc + from networkx import from_numpy_array + networkx_import = True +except ImportError: + networkx_import = False + +try: + from sklearn.cluster import SpectralClustering, KMeans + sklearn_import = True +except ImportError: + sklearn_import = False + +import numpy as np +import warnings def init_matrix(C1, C2, p, q, loss_fun='square_loss', nx=None): @@ -302,6 +319,7 @@ def init_matrix_semirelaxed(C1, C2, p, loss_fun='square_loss', nx=None): C2 : array-like, shape (nt, nt) Metric cost matrix in the target space p : array-like, shape (ns,) + Probability distribution in the source space loss_fun : str, optional Name of loss function to use: either 'square_loss' or 'kl_loss' (default='square_loss') nx : backend, optional @@ -370,6 +388,216 @@ def h2(b): return constC, hC1, hC2, fC2t +def semirelaxed_init_plan(C1, C2, p, M=None, alpha=1., method='product', + use_target=True, random_state=0, nx=None): + r""" + Heuristics to initialize the semi-relaxed (F)GW transport plan + :math:`\mathbf{T} \in \mathcal{U}_{nt}(\mathbf{p})`, between a graph + :math:`(\mathbf{C1}, \mathbf{p})` and a structure matrix :math:`\mathbf{C2}`, + where :math:`\mathcal{U}_{nt}(\mathbf{p}) = \{\mathbf{T} \in \mathbb{R}_{+}^{ns * nt}, \mathbf{T} \mathbf{1}_{nt} = \mathbf{p} \}`. + Available methods are: + - "product" or "random_product": :math:`\mathbf{T} = \mathbf{pq}^{T}` + with :math:`\mathbf{q}` uniform or randomly samples in the nt probability simplex. + + - "random": random sampling in :math:`\mathcal{U}_{nt}(\mathbf{p})`. + + - "fluid": Fluid algorithm from networkx for graph partitioning. + + - "spectral", "kmeans" : Spectral or Kmeans clustering from sklearn. + + - "fluid_soft", "spectral_soft", "kmeans_soft": :math:`\mathbf{T}_0` given + by corresponding clustering with target marginal :math:`\mathbf{q}_0`, further + centered as :math:`\mathbf{T} = (\mathbf{T}_0 + \mathbf{pq}_0^T) / 2` . + + If a metric cost matrix between features across domains :math:`\mathbf{M}` + is a provided, it will be used as cost matrix in a semi-relaxed Wasserstein + problem providing :math:`\mathbf{T}_M \in \mathcal{U}_{nt}(\mathbf{p})`. Then + the outputed transport plan is :math:`\alpha \mathbf{T} + (1 - \alpha ) \mathbf{T}_{M}`. + + Parameters + ---------- + C1 : array-like, shape (ns, ns) + Metric cost matrix in the source space. + C2 : array-like, shape (nt, nt) + Metric cost matrix in the target space. + p : array-like, shape (ns,), optional. + Probability distribution in the source space. If let to None, uniform + weights are assumed on C1. + M : array-like, shape (ns, nt), optional. + Metric cost matrix between features across domains. + alpha : float, optional + Trade-off parameter (0 <= alpha <= 1) + method : str, optional + Method to initialize the transport plan. The default is 'product'. + use_target : bool, optional. + Whether or not to use the target structure/features to further align + transport plan provided by the `method`. + random_state: int, optional + Random seed used for stochastic methods. + nx : backend, optional + POT backend. + + Returns + ------- + T : array-like, shape (ns, ns) + Admissible transport plan for the sr(F)GW problems. + + References + ---------- + .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty. + "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" + International Conference on Learning Representations (ICLR), 2022. + + """ + list_partitioning_methods = [ + 'fluid', 'spectral', 'kmeans', 'fluid_soft', 'spectral_soft', + 'kmeans_soft'] + + if method not in list_partitioning_methods + ['product', 'random_product', 'random']: + raise ValueError(f'Unsupported initialization method = {method}.') + + if (method in ['kmeans', 'kmeans_soft']) and (not sklearn_import): + raise ValueError(f'Scikit-learn must be installed to use method = {method}') + + if (method in ['fluid', 'fluid_soft']) and (not networkx_import): + raise ValueError(f'Networkx must be installed to use method = {method}') + + if nx is None: + nx = get_backend(C1, C2, p, M) + + n = C1.shape[0] + m = C2.shape[0] + min_size = min(n, m) + + if method in list_partitioning_methods: + if n > m: # partition C1 to deduce map from C1 to C2 + C_to_partition = nx.to_numpy(C1) + elif m > n: # partition C2 to deduce map from C1 to C2 + C_to_partition = nx.to_numpy(C2) + else: # equal size -> simple Wasserstein alignment + C_to_partition = None + warnings.warn( + "Both structures have the same size so no partitioning is" + "performed to initialize the transport plan even though" + f"initialization method is {method}", + stacklevel=2 + ) + + def get_transport_from_partition(part): + if n > m: # partition C1 to deduce map from C1 to C2 + T_ = nx.eye(m, type_as=C1)[part] + T_ = p[:, None] * T_ + q = nx.sum(T_, 0) + + factored_C1 = nx.dot(nx.dot(T_.T, C1), T_) / nx.outer(q, q) + + # alignment of both structure seen as feature matrices + if use_target: + M_structure = euclidean_distances(factored_C1, C2) + T_emd = emd(q, q, M_structure) + inv_q = 1. / q + + T = nx.dot(T_, inv_q[:, None] * T_emd) + else: + T = T_ + + elif m > n: + T_ = nx.eye(n, type_as=C1)[part] / m # assume uniform masses on C2 + q = nx.sum(T_, 0) + + factored_C2 = nx.dot(nx.dot(T_.T, C2), T_) / nx.outer(q, q) + + # alignment of both structure seen as feature matrices + M_structure = euclidean_distances(factored_C2, C1) + T_emd = emd(q, p, M_structure) + inv_q = 1. / q + + T = nx.dot(T_, inv_q[:, None] * T_emd).T + q = nx.sum(T, 0) # uniform one + else: + # alignment of both structure seen as feature matrices + M_structure = euclidean_distances(C1, C2) + q = p + T = emd(p, q, M_structure) + + return T, q + + # Handle initialization via structure information + + if method == 'product': + q = nx.ones(m, type_as=C1) / m + T = nx.outer(p, q) + + elif method == 'random_product': + np.random.seed(random_state) + q = np.random.uniform(0, m, size=(m,)) + q = q / q.sum() + q = nx.from_numpy(q, type_as=p) + T = nx.outer(p, q) + + elif method == 'random': + np.random.seed(random_state) + U = np.random.uniform(0, n * m, size=(n, m)) + U = (p / U.sum(1))[:, None] * U + T = nx.from_numpy(U, type_as=C1) + + elif method in ['fluid', 'fluid_soft']: + # compute fluid partitioning on the biggest graph + if C_to_partition is None: + T, q = get_transport_from_partition(None) + else: + graph = from_numpy_array(C_to_partition) + part_sets = asyn_fluidc(graph, min_size, seed=random_state) + part = np.zeros(C_to_partition.shape[0], dtype=int) + for iset_, set_ in enumerate(part_sets): + set_ = list(set_) + part[set_] = iset_ + part = nx.from_numpy(part) + + T, q = get_transport_from_partition(part) + + if 'soft' in method: + T = (T + nx.outer(p, q)) / 2. + + elif method in ['spectral', 'spectral_soft']: + # compute spectral partitioning on the biggest graph + if C_to_partition is None: + T, q = get_transport_from_partition(None) + else: + sc = SpectralClustering(n_clusters=min_size, + random_state=random_state, + affinity='precomputed').fit(C_to_partition) + part = sc.labels_ + T, q = get_transport_from_partition(part) + + if 'soft' in method: + T = (T + nx.outer(p, q)) / 2. + + elif method in ['kmeans', 'kmeans_soft']: + # compute spectral partitioning on the biggest graph + if C_to_partition is None: + T, q = get_transport_from_partition(None) + else: + km = KMeans(n_clusters=min_size, random_state=random_state, + n_init=1).fit(C_to_partition) + + part = km.labels_ + T, q = get_transport_from_partition(part) + + if 'soft' in method: + T = (T + nx.outer(p, q)) / 2. + + if (M is not None): + # Add feature information solving a semi-relaxed Wasserstein problem + # get minimum by rows as binary mask + TM = nx.ones(1, type_as=p) * (M == nx.reshape(nx.min(M, axis=1), (-1, 1))) + TM *= nx.reshape((p / nx.sum(TM, axis=1)), (-1, 1)) + + T = alpha * T + (1 - alpha) * TM + + return T + + def update_barycenter_structure( Ts, Cs, lambdas, p=None, loss_fun='square_loss', target=True, check_zeros=True, nx=None): @@ -437,10 +665,7 @@ def update_barycenter_structure( """ if nx is None: - arr = [*Ts, *Cs] - if p is not None: - arr += [p] - + arr = [*Ts, *Cs, p] nx = get_backend(*arr) S = len(Ts) @@ -461,7 +686,7 @@ def update_barycenter_structure( prod = nx.outer(inv_p, inv_p) else: - quotient = sum([nx.outer(p[s], p[s]) for s in range(S)]) + quotient = sum([lambdas[s] * nx.outer(p[s], p[s]) for s in range(S)]) if check_zeros: prod = nx.nan_to_num(1. / quotient, nan=1., posinf=1., neginf=1.) else: @@ -541,10 +766,7 @@ def update_barycenter_feature( International Conference on Learning Representations (ICLR), 2022. """ if nx is None: - arr = [*Ts, *Ys] - if p is not None: - arr += [p] - + arr = [*Ts, *Ys, p] nx = get_backend(*arr) if loss_fun != 'square_loss': @@ -568,7 +790,7 @@ def update_barycenter_feature( else: inv_p = 1. / p else: - p_sum = sum(p) + p_sum = sum([lambdas[s] * p[s] for s in range(S)]) if check_zeros: inv_p = nx.nan_to_num(1. / p_sum, nan=1., posinf=1., neginf=1.) else: diff --git a/ot/optim.py b/ot/optim.py index bde0fc814..a5f88bb29 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -476,7 +476,11 @@ def semirelaxed_cg(a, b, M, reg, f, df, G0=None, line_search=line_search_armijo, def lp_solver(a, b, Mi, **kwargs): # get minimum by rows as binary mask - Gc = nx.ones(1, type_as=a) * (Mi == nx.reshape(nx.min(Mi, axis=1), (-1, 1))) + min_ = nx.reshape(nx.min(Mi, axis=1), (-1, 1)) + # instead of exact elements equal to min_ we consider a small margin (1e-15) + # for float precision issues. Then the mass is splitted uniformly + # between these elements. + Gc = nx.ones(1, type_as=a) * (Mi <= min_ + 1e-15) Gc *= nx.reshape((a / nx.sum(Gc, axis=1)), (-1, 1)) # return by default an empty inner_log return Gc, {} diff --git a/ot/solvers.py b/ot/solvers.py index 95165ea11..ae44e860f 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -851,7 +851,7 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, if tol is None: tol = 1e-9 - value_quad, log = entropic_semirelaxed_gromov_wasserstein2(Ca, Cb, a, epsilon=reg, loss_fun=loss_fun, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + value_quad, log = entropic_semirelaxed_gromov_wasserstein2(Ca, Cb, a, epsilon=reg, loss_fun=loss_fun, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol=tol, verbose=verbose) plan = log['T'] value_linear = 0 @@ -865,7 +865,7 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, if tol is None: tol = 1e-9 - value_noreg, log = entropic_semirelaxed_fused_gromov_wasserstein2(M, Ca, Cb, a, loss_fun=loss_fun, alpha=alpha, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose) + value_noreg, log = entropic_semirelaxed_fused_gromov_wasserstein2(M, Ca, Cb, a, loss_fun=loss_fun, alpha=alpha, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol=tol, verbose=verbose) value_linear = log['lin_loss'] value_quad = log['quad_loss'] diff --git a/test/gromov/test_semirelaxed.py b/test/gromov/test_semirelaxed.py index 2e4b2f128..acc49be5f 100644 --- a/test/gromov/test_semirelaxed.py +++ b/test/gromov/test_semirelaxed.py @@ -10,6 +10,9 @@ import ot from ot.backend import torch +from ot.gromov._utils import ( + networkx_import, sklearn_import) + def test_semirelaxed_gromov(nx): rng = np.random.RandomState(0) @@ -19,13 +22,19 @@ def test_semirelaxed_gromov(nx): ns = np.sum(list_n) # create directed sbm with C2 as connectivity matrix C1 = np.zeros((ns, ns), dtype=np.float64) - C2 = np.array([[0.8, 0.05], - [0.05, 1.]], dtype=np.float64) + C2 = np.array([[0.8, 0.1], + [0.1, 1.]], dtype=np.float64) + + pos = [0, 30, 45] + for i in range(nt): for j in range(nt): ni, nj = list_n[i], list_n[j] xij = rng.binomial(size=(ni, nj), n=1, p=C2[i, j]) - C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij + pos_i_min, pos_i_max = pos[i], pos[i + 1] + pos_j_min, pos_j_max = pos[j], pos[j + 1] + C1[pos_i_min: pos_i_max, pos_j_min: pos_j_max] = xij + p = ot.unif(ns, type_as=C1) q0 = ot.unif(C2.shape[0], type_as=C1) G0 = p[:, None] * q0[None, :] @@ -60,19 +69,38 @@ def test_semirelaxed_gromov(nx): np.testing.assert_allclose(log2['srgw_dist'], logb['srgw_dist'], atol=1e-07) np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) - # symmetric + # symmetric - testing various initialization of the OT plan. C1 = 0.5 * (C1 + C1.T) + C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0) - G, log = ot.gromov.semirelaxed_gromov_wasserstein( - C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=None) - Gb = ot.gromov.semirelaxed_gromov_wasserstein( - C1b, C2b, pb, loss_fun='square_loss', symmetric=True, log=False, G0=G0b) + init_plan_list = [ + (None, G0b), ('product', None), ("random_product", "random_product")] - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + if networkx_import: + init_plan_list += [('fluid', 'fluid'), ('fluid_soft', 'fluid_soft')] + + if sklearn_import: + init_plan_list += [ + ("spectral", "spectral"), ("spectral_soft", "spectral_soft"), + ("kmeans", "kmeans"), ("kmeans_soft", "kmeans_soft")] + + for (init, init_b) in init_plan_list: + + G, log = ot.gromov.semirelaxed_gromov_wasserstein( + C1, C2, p, loss_fun='square_loss', symmetric=None, log=True, G0=init) + Gb = ot.gromov.semirelaxed_gromov_wasserstein( + C1b, C2b, pb, loss_fun='square_loss', symmetric=True, log=False, G0=init_b) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov + + if not isinstance(init, str): + np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + else: + if 'spectral' not in init: # issues with spectral clustering related to label switching + np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) srgw, log2 = ot.gromov.semirelaxed_gromov_wasserstein2( C1, C2, p, loss_fun='square_loss', symmetric=True, log=True, G0=G0) @@ -202,11 +230,17 @@ def test_semirelaxed_fgw(nx): C1 = np.zeros((ns, ns)) C2 = np.array([[0.7, 0.05], [0.05, 0.9]]) + + pos = [0, 16, 24] + for i in range(nt): for j in range(nt): ni, nj = list_n[i], list_n[j] xij = rng.binomial(size=(ni, nj), n=1, p=C2[i, j]) - C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij + pos_i_min, pos_i_max = pos[i], pos[i + 1] + pos_j_min, pos_j_max = pos[j], pos[j + 1] + C1[pos_i_min: pos_i_max, pos_j_min: pos_j_max] = xij + F1 = np.zeros((ns, 1)) F1[:16] = rng.normal(loc=0., scale=0.01, size=(16, 1)) F1[16:] = rng.normal(loc=1., scale=0.01, size=(8, 1)) @@ -218,57 +252,72 @@ def test_semirelaxed_fgw(nx): q0 = ot.unif(C2.shape[0]) G0 = p[:, None] * q0[None, :] - # asymmetric + # asymmetric structure - checking constraints and values Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) - G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, None, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) - Gb, logb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0b) + G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein( + M, C1, C2, None, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) + Gb, logb = ot.gromov.semirelaxed_fused_gromov_wasserstein( + Mb, C1b, C2b, pb, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0b) - # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov - srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0) - srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, None, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) + # asymmetric - check consistency between srFGW and srFGW2 + srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2( + M, C1, C2, p, loss_fun='square_loss', alpha=0.5, symmetric=False, log=True, G0=G0) + srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2( + Mb, C1b, C2b, None, loss_fun='square_loss', alpha=0.5, symmetric=None, log=True, G0=None) G = log2['T'] Gb = nx.to_numpy(logb2['T']) - # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], G.sum(0), atol=1e-04) # cf convergence gromov np.testing.assert_allclose(log2['srfgw_dist'], logb['srfgw_dist'], atol=1e-07) np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) - # symmetric + # symmetric structures + checking losses + inits + C1 = 0.5 * (C1 + C1.T) + Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) + + init_plan_list = [ + (None, G0b), ('product', None), ("random_product", "random_product")] + + if networkx_import: + init_plan_list += [('fluid', 'fluid')] + + if sklearn_import: + init_plan_list += [("kmeans", "kmeans")] + for loss_fun in ['square_loss', 'kl_loss']: - C1 = 0.5 * (C1 + C1.T) - Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) + for (init, init_b) in init_plan_list: - G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun=loss_fun, alpha=0.5, symmetric=None, log=True, G0=None) - Gb = ot.gromov.semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun=loss_fun, alpha=0.5, symmetric=True, log=False, G0=G0b) + G, log = ot.gromov.semirelaxed_fused_gromov_wasserstein( + M, C1, C2, p, loss_fun=loss_fun, alpha=0.5, symmetric=None, log=True, G0=init) + Gb = ot.gromov.semirelaxed_fused_gromov_wasserstein( + Mb, C1b, C2b, pb, loss_fun=loss_fun, alpha=0.5, symmetric=True, log=False, G0=init_b) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + + # checking consistency with srFGW and srFGW2 solvers srgw, log2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun=loss_fun, alpha=0.5, symmetric=True, log=True, G0=G0) srgwb, logb2 = ot.gromov.semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun=loss_fun, alpha=0.5, symmetric=None, log=True, G0=None) - srgw_ = ot.gromov.semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun=loss_fun, alpha=0.5, symmetric=True, log=False, G0=G0) - - G = log2['T'] - Gb = nx.to_numpy(logb2['T']) + G2 = log2['T'] + Gb2 = nx.to_numpy(logb2['T']) # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose(G2, Gb2, atol=1e-06) + np.testing.assert_allclose(G2, G, atol=1e-06) np.testing.assert_allclose(log2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) - np.testing.assert_allclose(srgw, srgw_, atol=1e-07) + np.testing.assert_allclose(srgw, srgwb, atol=1e-07) def test_semirelaxed_fgw2_gradients(): @@ -390,14 +439,21 @@ def test_entropic_semirelaxed_gromov(nx): ns = np.sum(list_n) # create directed sbm with C2 as connectivity matrix C1 = np.zeros((ns, ns), dtype=np.float64) - C2 = np.array([[0.8, 0.05], - [0.05, 1.]], dtype=np.float64) + C2 = np.array([[0.8, 0.1], + [0.1, 0.9]], dtype=np.float64) + rng = np.random.RandomState(0) + + pos = [0, 30, 45] + for i in range(nt): for j in range(nt): ni, nj = list_n[i], list_n[j] xij = rng.binomial(size=(ni, nj), n=1, p=C2[i, j]) - C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij + pos_i_min, pos_i_max = pos[i], pos[i + 1] + pos_j_min, pos_j_max = pos[j], pos[j + 1] + C1[pos_i_min: pos_i_max, pos_j_min: pos_j_max] = xij + p = ot.unif(ns, type_as=C1) q0 = ot.unif(C2.shape[0], type_as=C1) G0 = p[:, None] * q0[None, :] @@ -427,29 +483,56 @@ def test_entropic_semirelaxed_gromov(nx): np.testing.assert_allclose(log2['srgw_dist'], logb['srgw_dist'], atol=1e-07) np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) - # symmetric + # symmetric - testing various initialization of the OT plan. + C1 = 0.5 * (C1 + C1.T) C1b, C2b, pb, q0b, G0b = nx.from_numpy(C1, C2, p, q0, G0) - G, log = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=None, log=True, G0=None) - Gb = ot.gromov.entropic_semirelaxed_gromov_wasserstein(C1b, C2b, None, loss_fun='square_loss', epsilon=epsilon, symmetric=True, log=False, G0=G0b) + init_plan_list = [] # tests longer than with CG so we do not test all inits. - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + if networkx_import: + init_plan_list += [('fluid', 'fluid')] - srgw, log2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=True, log=True, G0=G0) - srgwb, logb2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1b, C2b, pb, loss_fun='square_loss', epsilon=epsilon, symmetric=None, log=True, G0=None) + if sklearn_import: + init_plan_list += [("kmeans", "kmeans")] + + init_plan_list += [ + ('product', None), (None, G0b)] + + for (init, init_b) in init_plan_list: + print(f'---- init : {init} / init_b : {init_b}') + G, log = ot.gromov.entropic_semirelaxed_gromov_wasserstein( + C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=None, + log=True, G0=init) + Gb, logb = ot.gromov.entropic_semirelaxed_gromov_wasserstein( + C1b, C2b, pb, loss_fun='square_loss', epsilon=epsilon, symmetric=True, + log=True, G0=init_b) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov + + if not isinstance(init, str): + np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + else: + np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-02) + + # comparison between srGW and srGW2 solvers + + srgw, log2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2( + C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=True, + log=True, G0=init) + srgwb, logb2 = ot.gromov.entropic_semirelaxed_gromov_wasserstein2( + C1b, C2b, pb, loss_fun='square_loss', epsilon=epsilon, symmetric=None, + log=True, G0=init_b) srgw_ = ot.gromov.entropic_semirelaxed_gromov_wasserstein2(C1, C2, p, loss_fun='square_loss', epsilon=epsilon, symmetric=True, log=False, G0=G0) - G = log2['T'] + G2 = log2['T'] + G2b = logb2['T'] # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, nx.sum(Gb, 1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose(list_n / ns, np.sum(G, axis=0), atol=1e-01) - np.testing.assert_allclose(list_n / ns, nx.sum(Gb, axis=0), atol=1e-01) + np.testing.assert_allclose(G2, G2b, atol=1e-06) + np.testing.assert_allclose(G2, G, atol=1e-06) np.testing.assert_allclose(log2['srgw_dist'], log['srgw_dist'], atol=1e-07) np.testing.assert_allclose(logb2['srgw_dist'], log['srgw_dist'], atol=1e-07) @@ -503,11 +586,17 @@ def test_entropic_semirelaxed_fgw(nx): C1 = np.zeros((ns, ns)) C2 = np.array([[0.7, 0.05], [0.05, 0.9]]) + + pos = [0, 16, 24] + for i in range(nt): for j in range(nt): ni, nj = list_n[i], list_n[j] xij = rng.binomial(size=(ni, nj), n=1, p=C2[i, j]) - C1[i * ni: (i + 1) * ni, j * nj: (j + 1) * nj] = xij + pos_i_min, pos_i_max = pos[i], pos[i + 1] + pos_j_min, pos_j_max = pos[j], pos[j + 1] + C1[pos_i_min: pos_i_max, pos_j_min: pos_j_max] = xij + F1 = np.zeros((ns, 1)) F1[:16] = rng.normal(loc=0., scale=0.01, size=(16, 1)) F1[16:] = rng.normal(loc=1., scale=0.01, size=(8, 1)) @@ -519,19 +608,20 @@ def test_entropic_semirelaxed_fgw(nx): q0 = ot.unif(C2.shape[0]) G0 = p[:, None] * q0[None, :] - # asymmetric + # asymmetric structure - checking constraints and values Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) G, log = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(M, C1, C2, None, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) Gb, logb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=False, log=True, G0=G0b) - # check constraints np.testing.assert_allclose(G, Gb, atol=1e-06) np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov - srgw, log2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=False, log=True, G0=G0) - srgwb, logb2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, None, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) + srgw, log2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2( + M, C1, C2, p, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=False, log=True, G0=G0) + srgwb, logb2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2( + Mb, C1b, C2b, None, loss_fun='square_loss', epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) G = log2['T'] Gb = nx.to_numpy(logb2['T']) @@ -543,34 +633,51 @@ def test_entropic_semirelaxed_fgw(nx): np.testing.assert_allclose(log2['srfgw_dist'], logb['srfgw_dist'], atol=1e-07) np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) - # symmetric + # symmetric structures + checking losses + inits C1 = 0.5 * (C1 + C1.T) Mb, C1b, C2b, pb, q0b, G0b = nx.from_numpy(M, C1, C2, p, q0, G0) - for loss_fun in ['square_loss', 'kl_loss']: - G, log = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(M, C1, C2, p, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) - Gb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein(Mb, C1b, C2b, pb, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=True, log=False, G0=G0b) + init_plan_list = [ + (None, G0b), ('product', None), ("random_product", "random_product")] - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + if networkx_import: + init_plan_list += [('fluid', 'fluid')] - srgw, log2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=True, log=True, G0=G0) - srgwb, logb2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(Mb, C1b, C2b, pb, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=None, log=True, G0=None) + if sklearn_import: + init_plan_list += [("kmeans", "kmeans")] - srgw_ = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, symmetric=True, log=False, G0=G0) + for loss_fun in ['square_loss', 'kl_loss']: + for (init, init_b) in init_plan_list: - G = log2['T'] - Gb = nx.to_numpy(logb2['T']) - # check constraints - np.testing.assert_allclose(G, Gb, atol=1e-06) - np.testing.assert_allclose(p, Gb.sum(1), atol=1e-04) # cf convergence gromov - np.testing.assert_allclose([2 / 3, 1 / 3], Gb.sum(0), atol=1e-04) # cf convergence gromov + G, log = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein( + M, C1, C2, p, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, + symmetric=None, log=True, G0=init) + Gb = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein( + Mb, C1b, C2b, pb, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, + symmetric=True, log=False, G0=init_b) + + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(p, nx.sum(Gb, axis=1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], nx.sum(Gb, axis=0), atol=1e-02) # cf convergence gromov + + # checking consistency with srFGW and srFGW2 solvers + srgw, log2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2( + M, C1, C2, p, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, + symmetric=True, log=True, G0=init) + srgwb, logb2 = ot.gromov.entropic_semirelaxed_fused_gromov_wasserstein2( + Mb, C1b, C2b, pb, loss_fun=loss_fun, epsilon=0.1, alpha=0.5, + symmetric=None, log=True, G0=init_b) + + G2 = log2['T'] + Gb2 = nx.to_numpy(logb2['T']) + np.testing.assert_allclose(G2, Gb2, atol=1e-06) + np.testing.assert_allclose(G2, G, atol=1e-06) + np.testing.assert_allclose(p, Gb2.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose([2 / 3, 1 / 3], Gb2.sum(0), atol=1e-04) # cf convergence gromov np.testing.assert_allclose(log2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) np.testing.assert_allclose(logb2['srfgw_dist'], log['srfgw_dist'], atol=1e-07) - np.testing.assert_allclose(srgw, srgw_, atol=1e-07) + np.testing.assert_allclose(srgw, srgwb, atol=1e-07) @pytest.skip_backend("tf", reason="test very slow with tf backend") @@ -615,6 +722,178 @@ def test_entropic_semirelaxed_fgw_dtype_device(nx): nx.assert_same_dtype_device(C1b, fgw_valb) +@pytest.skip_backend("tf", reason="test very slow with tf backend") +@pytest.skip_backend("jax", reason="test very slow with tf backend") +def test_semirelaxed_gromov_barycenter(nx): + ns = 5 + nt = 8 + + Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) + Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) + + C1 = ot.dist(Xs) + C2 = ot.dist(Xt) + p1 = ot.unif(ns) + p2 = ot.unif(nt) + n_samples = 3 + + C1b, C2b, p1b, p2b = nx.from_numpy(C1, C2, p1, p2) + + # test on admissible stopping criterion + with pytest.raises(ValueError): + stop_criterion = 'unknown stop criterion' + _ = ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1, C2], None, [.5, .5], 'square_loss', max_iter=10, + tol=1e-3, stop_criterion=stop_criterion, verbose=False, + random_state=42 + ) + + # test consistency of outputs across backends with 'square_loss' + # using different losses + # + tests on different inits + init_plan_list = [('fluid', 'fluid'), ("kmeans", "kmeans"), + ('random', 'random')] + + for (init, init_b) in init_plan_list: + + for stop_criterion in ['barycenter', 'loss']: + print('--- stop_criterion:', stop_criterion) + + if (init == 'fluid') and (not networkx_import): + with pytest.raises(ValueError): + Cb = ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1, C2], None, [.5, .5], 'square_loss', max_iter=5, + tol=1e-3, stop_criterion=stop_criterion, verbose=False, + random_state=42, G0=init + ) + + elif (init == 'kmeans') and (not sklearn_import): + with pytest.raises(ValueError): + Cb = ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1, C2], None, [.5, .5], 'square_loss', max_iter=5, + tol=1e-3, stop_criterion=stop_criterion, verbose=False, + random_state=42, G0=init + ) + else: + + Cb = ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1, C2], None, [.5, .5], 'square_loss', max_iter=5, + tol=1e-3, stop_criterion=stop_criterion, verbose=False, + random_state=42, G0=init + ) + + Cbb = nx.to_numpy(ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', + max_iter=5, tol=1e-3, stop_criterion=stop_criterion, + verbose=False, random_state=42, G0=init_b + )) + np.testing.assert_allclose(Cb, Cbb, atol=1e-06) + np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) + + # test of gromov_barycenters with `log` on + Cb_, err_ = ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1, C2], [p1, p2], None, 'square_loss', max_iter=5, + tol=1e-3, stop_criterion=stop_criterion, verbose=False, + warmstartT=True, random_state=42, log=True, G0=init, + ) + Cbb_, errb_ = ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', + max_iter=5, tol=1e-3, stop_criterion=stop_criterion, + verbose=False, warmstartT=True, random_state=42, log=True, G0=init_b + ) + + Cbb_ = nx.to_numpy(Cbb_) + + np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06) + np.testing.assert_array_almost_equal(err_['err'], nx.to_numpy(*errb_['err'])) + np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples)) + + # test consistency across backends with larger barycenter than inputs + if sklearn_import: + C = ot.gromov.semirelaxed_gromov_barycenters( + ns, [C1, C2], None, [.5, .5], 'square_loss', max_iter=5, + tol=1e-3, stop_criterion='loss', verbose=False, + random_state=42, G0='kmeans' + ) + Cb = ot.gromov.semirelaxed_gromov_barycenters( + ns, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', + max_iter=5, tol=1e-3, stop_criterion=stop_criterion, + verbose=False, random_state=42, G0='kmeans') + + np.testing.assert_allclose(C, nx.to_numpy(Cb), atol=1e-06) + + # test providing init_C + C_ = ot.gromov.semirelaxed_gromov_barycenters( + ns, [C1, C2], None, [.5, .5], 'square_loss', max_iter=5, + tol=1e-3, stop_criterion='loss', verbose=False, + random_state=42, G0=init, init_C=C1 + ) + + Cb_ = ot.gromov.semirelaxed_gromov_barycenters( + ns, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', + max_iter=5, tol=1e-3, stop_criterion=stop_criterion, + verbose=False, random_state=42, G0=init_b, init_C=C1b) + + np.testing.assert_allclose(C_, Cb_, atol=1e-06) + + # test consistency across backends with 'kl_loss' + Cb2, err = ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1, C2], [p1, p2], [.5, .5], 'kl_loss', max_iter=5, + tol=1e-3, warmstartT=False, stop_criterion='loss', log=True, + G0=init_b, random_state=42 + ) + Cb2b, errb = ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], [.5, .5], 'kl_loss', max_iter=5, + tol=1e-3, warmstartT=False, stop_criterion='loss', log=True, + G0=init_b, random_state=42 + ) + Cb2b = nx.to_numpy(Cb2b) + + try: + np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06) # may differ from permutation + except AssertionError: + np.testing.assert_allclose(err['loss'][-1], errb['loss'][-1], atol=1e-06) + + np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples)) + + # test of gromov_barycenters with `log` on + # providing init_C similarly than in the function. + rng = ot.utils.check_random_state(42) + xalea = rng.randn(n_samples, 2) + init_C = ot.utils.dist(xalea, xalea) + init_C /= init_C.max() + init_Cb = nx.from_numpy(init_C) + + Cb2_, err2_ = ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1, C2], [p1, p2], [.5, .5], 'square_loss', max_iter=10, + tol=1e-3, verbose=False, random_state=42, log=True, init_C=init_C + ) + Cb2b_, err2b_ = ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1b, C2b], [p1b, p2b], [.5, .5], 'square_loss', + max_iter=10, tol=1e-3, verbose=True, random_state=42, + init_C=init_Cb, log=True + ) + Cb2b_ = nx.to_numpy(Cb2b_) + np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06) + np.testing.assert_array_almost_equal(err2_['err'], nx.to_numpy(*err2b_['err'])) + np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples)) + + # test edge cases for gw barycenters: + # unique input structure + Cb = ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1], None, None, 'square_loss', max_iter=1, + tol=1e-3, stop_criterion=stop_criterion, verbose=False, + random_state=42 + ) + Cbb = nx.to_numpy(ot.gromov.semirelaxed_gromov_barycenters( + n_samples, [C1b], None, [1.], 'square_loss', + max_iter=1, tol=1e-3, stop_criterion=stop_criterion, + verbose=False, random_state=42 + )) + np.testing.assert_allclose(Cb, Cbb, atol=1e-06) + np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples)) + + def test_semirelaxed_fgw_barycenter(nx): ns = 10 nt = 20 @@ -723,6 +1002,84 @@ def test_semirelaxed_fgw_barycenter(nx): np.testing.assert_allclose(C, recovered_C) + # test consistency of outputs across backends with 'square_loss' + # with various initialization of G0 + init_plan_list = [('fluid', 'fluid'), ("kmeans", "kmeans"), + ('product', 'product'), ('random', 'random')] + + for (init, init_b) in init_plan_list: + print(f'---- init : {init} / init_b : {init_b}') + + if (init == 'fluid') and (not networkx_import): + with pytest.raises(ValueError): + + X, C, log = ot.gromov.semirelaxed_fgw_barycenters( + n_samples, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5, + fixed_structure=False, fixed_features=False, loss_fun='square_loss', + max_iter=10, tol=1e-3, stop_criterion='loss', G0=init, + warmstartT=True, random_state=12345, log=True, verbose=True + ) + elif (init == 'kmeans') and (not sklearn_import): + with pytest.raises(ValueError): + + X, C, log = ot.gromov.semirelaxed_fgw_barycenters( + n_samples, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5, + fixed_structure=False, fixed_features=False, loss_fun='square_loss', + max_iter=10, tol=1e-3, stop_criterion='loss', G0=init, + warmstartT=True, random_state=12345, log=True, verbose=True + ) + else: + X, C, log = ot.gromov.semirelaxed_fgw_barycenters( + n_samples, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5, + fixed_structure=False, fixed_features=False, loss_fun='square_loss', + max_iter=10, tol=1e-3, stop_criterion='loss', G0=init, + warmstartT=True, random_state=12345, log=True, verbose=True + ) + Xb, Cb, logb = ot.gromov.semirelaxed_fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, + fixed_structure=False, fixed_features=False, loss_fun='square_loss', + max_iter=10, tol=1e-3, stop_criterion='loss', G0=init_b, + warmstartT=True, random_state=12345, log=True, verbose=True + ) + np.testing.assert_allclose(X, nx.to_numpy(Xb)) + np.testing.assert_allclose(C, nx.to_numpy(Cb)) + + # test while providing advanced T inits and init_X != None, and init_C !=None + Xb_, Cb_, logb_ = ot.gromov.semirelaxed_fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, + fixed_structure=False, fixed_features=False, loss_fun='square_loss', + max_iter=10, tol=1e-3, stop_criterion='loss', G0='random', + warmstartT=True, random_state=12345, log=True, verbose=True, + init_C=Cb, init_X=Xb + ) + np.testing.assert_allclose(Xb, Xb_) + np.testing.assert_allclose(Cb, Cb_) + + # test consistency of backends while barycenter size not strictly inferior to sizes + if sklearn_import: + Xb_, Cb_, logb_ = ot.gromov.semirelaxed_fgw_barycenters( + n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, + fixed_structure=False, fixed_features=False, loss_fun='square_loss', + max_iter=10, tol=1e-3, stop_criterion='loss', G0='kmeans', + warmstartT=True, random_state=12345, log=True, verbose=True, + init_C=Cb, init_X=Xb + ) + + X, C, log = ot.gromov.semirelaxed_fgw_barycenters( + ns, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5, + fixed_structure=False, fixed_features=False, loss_fun='square_loss', + max_iter=10, tol=1e-3, stop_criterion='loss', G0='kmeans', + warmstartT=True, random_state=12345, log=True, verbose=True + ) + Xb, Cb, logb = ot.gromov.semirelaxed_fgw_barycenters( + ns, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5, + fixed_structure=False, fixed_features=False, loss_fun='square_loss', + max_iter=10, tol=1e-3, stop_criterion='loss', G0='kmeans', + warmstartT=True, random_state=12345, log=True, verbose=True + ) + np.testing.assert_allclose(X, nx.to_numpy(Xb)) + np.testing.assert_allclose(C, nx.to_numpy(Cb)) + # test edge cases for semirelaxed fgw barycenters: # unique input structure X, C = ot.gromov.semirelaxed_fgw_barycenters( diff --git a/test/gromov/test_utils.py b/test/gromov/test_utils.py index ad94a4042..70894fcfc 100644 --- a/test/gromov/test_utils.py +++ b/test/gromov/test_utils.py @@ -8,6 +8,8 @@ import pytest import ot +from ot.gromov._utils import ( + networkx_import, sklearn_import) def test_update_barycenter(nx): @@ -61,3 +63,51 @@ def test_update_barycenter(nx): with pytest.raises(ValueError): Xbt = ot.gromov.update_barycenter_feature( Tb, Ysb, lambdas, None, loss_fun='unknown', target=True) + + +def test_semirelaxed_init_plan(nx): + ns = 5 + nt = 10 + + Xs, ys = ot.datasets.make_data_classif('3gauss', ns, random_state=42) + Xt, yt = ot.datasets.make_data_classif('3gauss2', nt, random_state=42) + + rng = np.random.RandomState(42) + ys = rng.randn(Xs.shape[0], 2) + yt = rng.randn(Xt.shape[0], 2) + + C1 = ot.dist(Xs) + C2 = ot.dist(Xt) + C1 /= C1.max() + C2 /= C2.max() + + p1, p2 = ot.unif(ns), ot.unif(nt) + + ysb, ytb, C1b, C2b, p1b, p2b = nx.from_numpy(ys, yt, C1, C2, p1, p2) + + # test not supported method + with pytest.raises(ValueError): + _ = ot.gromov.semirelaxed_init_plan(C1b, C2b, p1b, method='unknown') + + if sklearn_import: + # tests consistency across backends with m > n + for method in ['kmeans', 'spectral']: + T = ot.gromov.semirelaxed_init_plan(C1b, C2b, p1b, method=method) + Tb = ot.gromov.semirelaxed_init_plan(C1b, C2b, p1b, method=method) + np.testing.assert_allclose(T, Tb) + + # tests consistency across backends with m = n + T = ot.gromov.semirelaxed_init_plan(C1b, C1b, p1b, method=method) + Tb = ot.gromov.semirelaxed_init_plan(C1b, C1b, p1b, method=method) + np.testing.assert_allclose(T, Tb) + + if networkx_import: + # tests consistency across backends with m > n + T = ot.gromov.semirelaxed_init_plan(C1b, C2b, p1b, method='fluid') + Tb = ot.gromov.semirelaxed_init_plan(C1b, C2b, p1b, method='fluid') + np.testing.assert_allclose(T, Tb) + + # tests consistency across backends with m = n + T = ot.gromov.semirelaxed_init_plan(C1b, C1b, p1b, method='fluid') + Tb = ot.gromov.semirelaxed_init_plan(C1b, C1b, p1b, method='fluid') + np.testing.assert_allclose(T, Tb)