2 changes: 1 addition & 1 deletion README.md
@@ -46,7 +46,7 @@ POT provides the following generic OT solvers (links to examples):
* [Wasserstein distance on the circle](https://pythonot.github.io/auto_examples/plot_compute_wasserstein_circle.html) [44, 45]
* [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46]
* [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38].
* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) with corresponding barycenter solvers (exact and regularized [48]).
* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) with corresponding [barycenter solvers](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.html) (exact and regularized [48]).
* [Quantized (Fused) Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_quantized_gromov_wasserstein.html) [68].
* [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://pythonot.github.io/auto_examples/others/plot_demd_gradient_minimize.html) [50].
* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays.
5 changes: 2 additions & 3 deletions RELEASES.md
@@ -5,9 +5,8 @@
#### New features
- Add feature `mass=True` for `nx.kl_div` (PR #654)
- Gaussian Mixture Model OT `ot.gmm` (PR #649)
- Add feature `semirelaxed_fgw_barycenters` and generic FGW-related barycenter
updates `update_barycenter_structure` and `update_barycenter_feature` (PR
#659)
- Add feature `semirelaxed_fgw_barycenters` and generic FGW-related barycenter updates `update_barycenter_structure` and `update_barycenter_feature` (PR #659)
- Add initialization heuristics for sr(F)GW problems via `semirelaxed_init_plan`, integrated in all sr(F)GW solvers (PR #659)
- Improved `ot.plot.plot1D_mat` (PR #649)
- Added `nx.det` (PR #649)
- `nx.sqrtm` is now broadcastable (takes ..., d, d) inputs (PR #649)
266 changes: 266 additions & 0 deletions examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.py
@@ -0,0 +1,266 @@
# -*- coding: utf-8 -*-

r"""
================================================================================
Semi-relaxed (Fused) Gromov-Wasserstein Barycenter as Dictionary Learning
================================================================================

In this example, we illustrate how to learn a semi-relaxed Gromov-Wasserstein
(srGW) barycenter using a Block-Coordinate Descent algorithm, on a dataset of
structured data such as graphs, denoted :math:`\{ \mathbf{C_s} \}_{s \in [S]}`,
whose nodes are endowed with uniform weights :math:`\{ \mathbf{p_s} \}_{s \in [S]}`.
Given a barycenter structure matrix :math:`\mathbf{C}` with :math:`N` nodes,
each graph :math:`(\mathbf{C_s}, \mathbf{p_s})` is modeled as a reweighted subgraph
with structure :math:`\mathbf{C}` and weights :math:`\mathbf{w_s} \in \Sigma_N`,
where each :math:`\mathbf{w_s}` corresponds to the second marginal of the OT plan
:math:`\mathbf{T_s}` (s.t. :math:`\mathbf{w_s} = \mathbf{T_s}^\top \mathbf{1}`)
minimizing the srGW loss between the :math:`s^{th}` input and the barycenter.
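
Informally, with the squared loss used in this example, the srGW barycenter
problem can be written (see [48]):

.. math::
    \min_{\mathbf{C}} \sum_{s \in [S]} \lambda_s \min_{\substack{\mathbf{T_s} \geq 0 \\ \mathbf{T_s} \mathbf{1} = \mathbf{p_s}}} \sum_{i,j,k,l} \left( \mathbf{C_s}[i, k] - \mathbf{C}[j, l] \right)^2 \mathbf{T_s}[i, j] \mathbf{T_s}[k, l]

where the :math:`\lambda_s` are barycenter weights (uniform below).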


First, we consider a dataset composed of graphs generated by Stochastic Block models
with variable sizes taken in :math:`\{30, ... , 50\}` and number of clusters
varying in :math:`\{ 1, 2, 3\}` with random proportions. We learn an srGW barycenter
with 3 nodes and visualize the learned structure and the embeddings for some inputs.

Second, we illustrate the extension of this framework to graphs endowed
with node features, using the semi-relaxed Fused Gromov-Wasserstein
divergence (srFGW). Starting from the aforementioned dataset of unattributed
graphs, we endow each node with a one-hot feature encoding its cluster
assignment, and then conduct the analogous analysis.


[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs".
International Conference on Learning Representations (ICLR), 2022.

"""
# Author: Cédric Vincent-Cuaz <cedric.vincent-cuaz@inria.fr>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 2

import numpy as np
import matplotlib.pylab as pl
from sklearn.manifold import MDS
from ot.gromov import (
    semirelaxed_gromov_barycenters, semirelaxed_fgw_barycenters)
import ot
import networkx
from networkx.generators.community import stochastic_block_model as sbm

#############################################################################
#
# Generate a dataset composed of graphs following Stochastic Block models of 1, 2 and 3 clusters.
# -----------------------------------------------------------------------------------------------

np.random.seed(42)

n_samples = 60 # number of graphs in the dataset
# For every number of clusters, we generate SBM with fixed inter/intra-clusters probability,
# and variable cluster proportions.
clusters = [1, 2, 3]
Nc = n_samples // len(clusters) # number of graphs by cluster
nlabels = len(clusters)
dataset = []
node_labels = []
labels = []

p_inter = 0.1
p_intra = 0.9
for n_cluster in clusters:
    for i in range(Nc):
        n_nodes = int(np.random.uniform(low=30, high=50))

        if n_cluster > 1:
            P = p_inter * np.ones((n_cluster, n_cluster))
            np.fill_diagonal(P, p_intra)
            props = np.random.uniform(0.2, 1, size=(n_cluster,))
            props /= props.sum()
            sizes = np.round(n_nodes * props).astype(np.int32)
        else:
            P = p_intra * np.eye(1)
            sizes = [n_nodes]

        G = sbm(sizes, P, seed=i, directed=False)
        part = np.array([G.nodes[i]['block'] for i in range(np.sum(sizes))])
        C = networkx.to_numpy_array(G)
        dataset.append(C)
        node_labels.append(part)
        labels.append(n_cluster)
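
# Quick summary of the generated dataset (added for illustration): we expect
# 60 graphs, 20 per label, with sizes roughly between 30 and 50 nodes.
print(f"{len(dataset)} graphs; node counts between "
      f"{min(C.shape[0] for C in dataset)} and {max(C.shape[0] for C in dataset)}")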


# Visualize samples

def plot_graph(x, C, binary=True, color='C0', s=None):
    for j in range(C.shape[0]):
        for i in range(j):
            if binary:
                if C[i, j] > 0:
                    pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color='k')
            else:  # connection intensity proportional to C[i,j]
                pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=C[i, j], color='k')

    pl.scatter(x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors='k', cmap='tab10', vmax=9)


pl.figure(1, (12, 8))
pl.clf()
for idx_c, c in enumerate(clusters):
    C = dataset[(c - 1) * Nc]  # sample with c clusters
    # get 2d position for nodes
    x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C)
    pl.subplot(2, nlabels, c)
    pl.title('(graph) sample from label ' + str(c), fontsize=14)
    plot_graph(x, C, binary=True, color='C0', s=50.)
    pl.axis("off")
    pl.subplot(2, nlabels, nlabels + c)
    pl.title('(matrix) sample from label %s \n' % c, fontsize=14)
    pl.imshow(C, interpolation='nearest')
    pl.axis("off")
pl.tight_layout()
pl.show()

#############################################################################
#
# Estimate the srGW barycenter from the dataset and visualize embeddings
# ---------------------------------------------------------------------------


np.random.seed(0)
ps = [ot.unif(C.shape[0]) for C in dataset] # uniform weights on input nodes
lambdas = [1. / n_samples for _ in range(n_samples)] # uniform barycenter
N = 3 # 3 nodes in the barycenter

# Here we use the Fluid partitioning method to deduce initial transport plans
# for the barycenter problem. An initial structure is also deduced from these
# initial transport plans. Then a warmstart strategy is used to iteratively
# initialize each individual srGW problem within the BCD algorithm.

init_plan = 'fluid' # notice that several init options are implemented in `ot.gromov.semirelaxed_init_plan`
warmstartT = True

C, log = semirelaxed_gromov_barycenters(
    N=N, Cs=dataset, ps=ps, lambdas=lambdas, loss_fun='square_loss',
    tol=1e-6, stop_criterion='loss', warmstartT=warmstartT, log=True,
    G0=init_plan, verbose=False)

print('barycenter structure:', C)
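
# Quick sanity check (illustrative sketch): since each w_s is the second marginal
# of a transport plan whose first marginal is the probability vector p_s, every
# embedding stored in log['p'] should be a nonnegative vector summing to one.
for w in log['p']:
    assert abs(w.sum() - 1.) < 1e-8 and np.all(w >= -1e-12)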

unmixings = log['p']
# Compute the 2D representation of the embeddings living in the probability 2-simplex
unmixings2D = np.zeros(shape=(n_samples, 2))
for i, w in enumerate(unmixings):
    unmixings2D[i, 0] = (2. * w[1] + w[2]) / 2.
    unmixings2D[i, 1] = (np.sqrt(3.) * w[2]) / 2.
x = [0., 0.]
y = [1., 0.]
z = [0.5, np.sqrt(3) / 2.]
extremities = np.stack([x, y, z])
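
# Note (hypothetical helper, added for clarity): the two formulas above are simply
# the barycentric combination of the triangle vertices stacked in `extremities`,
# i.e. w[0] * x + w[1] * y + w[2] * z.
def simplex_to_2d(w, vertices=extremities):
    return np.asarray(w) @ vertices

print(np.allclose(unmixings2D[0], simplex_to_2d(unmixings[0])))  # expected: True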

pl.figure(2, (4, 4))
pl.clf()
pl.title('Embedding space', fontsize=14)
for cluster in range(nlabels):
    start, end = Nc * cluster, Nc * (cluster + 1)
    if cluster == 0:
        pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=80., label='1 cluster')
    else:
        pl.scatter(unmixings2D[start:end, 0], unmixings2D[start:end, 1], c='C' + str(cluster), marker='o', s=80., label='%s clusters' % (cluster + 1))
pl.scatter(extremities[:, 0], extremities[:, 1], c='black', marker='x', s=100., label='bary. nodes')
pl.plot([x[0], y[0]], [x[1], y[1]], color='black', linewidth=2.)
pl.plot([x[0], z[0]], [x[1], z[1]], color='black', linewidth=2.)
pl.plot([y[0], z[0]], [y[1], z[1]], color='black', linewidth=2.)
pl.axis('off')
pl.legend(fontsize=11)
pl.tight_layout()
pl.show()
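
# As a side note (illustrative sketch; the call below assumes the srGW solver
# signature ot.gromov.semirelaxed_gromov_wasserstein(C1, C2, p, loss_fun=...)):
# solving a single srGW problem from an input graph to the learned barycenter
# structure C returns an OT plan whose second marginal T^T 1 is an embedding of
# the same kind as those stored in log['p'] (possibly slightly different, since
# the problem is non-convex and sensitive to initialization).
T0 = ot.gromov.semirelaxed_gromov_wasserstein(
    dataset[0], C, ps[0], loss_fun='square_loss')
print('embedding of graph 0 recomputed from its srGW plan:', T0.sum(axis=0))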

#############################################################################
#
# Endow the dataset with node features
# ------------------------------------
# Node labels, corresponding to the true SBM cluster assignments,
# are used for each graph as one-hot encoded node features.

dataset_features = []
for i in range(len(dataset)):
    n = dataset[i].shape[0]
    F = np.zeros((n, 3))
    F[np.arange(n), node_labels[i]] = 1.
    dataset_features.append(F)
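
# Equivalent construction (illustrative): the one-hot features can also be obtained
# by indexing an identity matrix with the node labels.
print(np.allclose(dataset_features[0], np.eye(nlabels)[node_labels[0]]))  # expected: True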

pl.figure(3, (12, 8))
pl.clf()
for idx_c, c in enumerate(clusters):
    C = dataset[(c - 1) * Nc]  # sample with c clusters
    F = dataset_features[(c - 1) * Nc]
    colors = [f'C{label}' for label in node_labels[(c - 1) * Nc]]  # color nodes by their SBM cluster
    # get 2d position for nodes
    x = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C)
    pl.subplot(2, nlabels, c)
    pl.title('(graph) sample from label ' + str(c), fontsize=14)
    plot_graph(x, C, binary=True, color=colors, s=50)
    pl.axis("off")
    pl.subplot(2, nlabels, nlabels + c)
    pl.title('(matrix) sample from label %s \n' % c, fontsize=14)
    pl.imshow(C, interpolation='nearest')
    pl.axis("off")
pl.tight_layout()
pl.show()

#############################################################################
#
# Estimate the srFGW barycenter from the attributed graphs and visualize embeddings
# ------------------------------------------------------------------------------------------
# We emphasize the dependence on the trade-off parameter alpha that weights the
# relative importance of the structures (alpha=1) and the features (alpha=0).
# Since the node features encode the true cluster assignments, embeddings that
# cluster graphs well w.r.t. their features should ease the identification of
# the number of clusters in the graphs.

list_alphas = [0.0001, 0.5, 0.9999]
list_unmixings2D = []

for ialpha, alpha in enumerate(list_alphas):
    print('--- alpha:', alpha)
    C, F, log = semirelaxed_fgw_barycenters(
        N=N, Ys=dataset_features, Cs=dataset, ps=ps, lambdas=lambdas,
        alpha=alpha, loss_fun='square_loss', tol=1e-6, stop_criterion='loss',
        warmstartT=warmstartT, log=True, G0=init_plan)

    print('barycenter structure:', C)
    print('barycenter features:', F)

    unmixings = log['p']
    # Compute the 2D representation of the embeddings living in the probability 2-simplex
    unmixings2D = np.zeros(shape=(n_samples, 2))
    for i, w in enumerate(unmixings):
        unmixings2D[i, 0] = (2. * w[1] + w[2]) / 2.
        unmixings2D[i, 1] = (np.sqrt(3.) * w[2]) / 2.
    list_unmixings2D.append(unmixings2D.copy())

x = [0., 0.]
y = [1., 0.]
z = [0.5, np.sqrt(3) / 2.]
extremities = np.stack([x, y, z])

pl.figure(4, (12, 4))
pl.clf()
pl.suptitle('Embedding spaces', fontsize=14)
for ialpha, alpha in enumerate(list_alphas):
    pl.subplot(1, len(list_alphas), ialpha + 1)
    pl.title(f'alpha = {alpha}', fontsize=14)
    for cluster in range(nlabels):
        start, end = Nc * cluster, Nc * (cluster + 1)
        if cluster == 0:
            pl.scatter(list_unmixings2D[ialpha][start:end, 0], list_unmixings2D[ialpha][start:end, 1], c='C' + str(cluster), marker='o', s=80., label='1 cluster')
        else:
            pl.scatter(list_unmixings2D[ialpha][start:end, 0], list_unmixings2D[ialpha][start:end, 1], c='C' + str(cluster), marker='o', s=80., label='%s clusters' % (cluster + 1))
    pl.scatter(extremities[:, 0], extremities[:, 1], c='black', marker='x', s=100., label='bary. nodes')
    pl.plot([x[0], y[0]], [x[1], y[1]], color='black', linewidth=2.)
    pl.plot([x[0], z[0]], [x[1], z[1]], color='black', linewidth=2.)
    pl.plot([y[0], z[0]], [y[1], z[1]], color='black', linewidth=2.)
    pl.axis('off')
    pl.legend(fontsize=11)
pl.tight_layout()
pl.show()
14 changes: 8 additions & 6 deletions ot/gromov/__init__.py
@@ -11,9 +11,8 @@

# All submodules and packages
from ._utils import (init_matrix, tensor_product, gwloss, gwggrad,
init_matrix_semirelaxed,
update_barycenter_structure, update_barycenter_feature,
)
init_matrix_semirelaxed, semirelaxed_init_plan,
update_barycenter_structure, update_barycenter_feature)

from ._gw import (gromov_wasserstein, gromov_wasserstein2,
fused_gromov_wasserstein, fused_gromov_wasserstein2,
@@ -42,6 +41,7 @@
entropic_semirelaxed_gromov_wasserstein2,
entropic_semirelaxed_fused_gromov_wasserstein,
entropic_semirelaxed_fused_gromov_wasserstein2,
semirelaxed_gromov_barycenters,
semirelaxed_fgw_barycenters)

from ._dictionary import (gromov_wasserstein_dictionary_learning,
@@ -64,7 +64,7 @@
)

__all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad',
'init_matrix_semirelaxed',
'init_matrix_semirelaxed', 'semirelaxed_init_plan',
'update_barycenter_structure', 'update_barycenter_feature',
'gromov_wasserstein', 'gromov_wasserstein2', 'fused_gromov_wasserstein',
'fused_gromov_wasserstein2', 'solve_gromov_linesearch', 'gromov_barycenters',
@@ -78,11 +78,13 @@
'semirelaxed_fused_gromov_wasserstein', 'semirelaxed_fused_gromov_wasserstein2',
'solve_semirelaxed_gromov_linesearch', 'entropic_semirelaxed_gromov_wasserstein',
'entropic_semirelaxed_gromov_wasserstein2', 'entropic_semirelaxed_fused_gromov_wasserstein',
'entropic_semirelaxed_fused_gromov_wasserstein2', 'gromov_wasserstein_dictionary_learning',
'entropic_semirelaxed_fused_gromov_wasserstein2',
'semirelaxed_fgw_barycenters', 'semirelaxed_gromov_barycenters',
'gromov_wasserstein_dictionary_learning',
'gromov_wasserstein_linear_unmixing', 'fused_gromov_wasserstein_dictionary_learning',
'fused_gromov_wasserstein_linear_unmixing', 'lowrank_gromov_wasserstein_samples',
'quantized_fused_gromov_wasserstein_partitioned', 'get_graph_partition',
'get_graph_representants', 'format_partitioned_graph',
'quantized_fused_gromov_wasserstein', 'get_partition_and_representants_samples',
'format_partitioned_samples', 'quantized_fused_gromov_wasserstein_samples',
'semirelaxed_fgw_barycenters']
]