Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
- Avoid precision change when computing norm using PyTorch backend (Discussion #570, PR #572)
- Create `ot/bregman/` repository (Issue #567, PR #569)
- Fix matrix feature shape in `entropic_fused_gromov_barycenters` (Issue #574, PR #573)
- Fix (fused) gromov-wasserstein barycenter solvers to support `kl_loss` (PR #576)


## 0.9.1
Expand Down Expand Up @@ -602,4 +603,4 @@ It provides the following solvers:
* Optimal transport for domain adaptation with group lasso regularization
* Conditional gradient and Generalized conditional gradient for regularized OT.

Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.
Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.
25 changes: 12 additions & 13 deletions ot/gromov/_bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,17 +457,17 @@ def entropic_gromov_barycenters(
Cprev = C
if warmstartT:
T = [entropic_gromov_wasserstein(
Cs[s], C, ps[s], p, loss_fun, epsilon, symmetric, T[s],
C, Cs[s], p, ps[s], loss_fun, epsilon, symmetric, T[s],
max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)]
else:
T = [entropic_gromov_wasserstein(
Cs[s], C, ps[s], p, loss_fun, epsilon, symmetric, None,
C, Cs[s], p, ps[s], loss_fun, epsilon, symmetric, None,
max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)]

if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs)
C = update_square_loss(p, lambdas, T, Cs, nx)
elif loss_fun == 'kl_loss':
C = update_kl_loss(p, lambdas, T, Cs)
C = update_kl_loss(p, lambdas, T, Cs, nx)

if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
Expand Down Expand Up @@ -962,9 +962,9 @@ def entropic_fused_gromov_barycenters(
Y = init_Y

if warmstartT:
T = [nx.outer(p_, p) for p_ in ps]
T = [None] * S

Ms = [dist(Ys[s], Y) for s in range(len(Ys))]
Ms = [dist(Y, Ys[s]) for s in range(len(Ys))]

cpt = 0
err = 1
Expand All @@ -984,23 +984,22 @@ def entropic_fused_gromov_barycenters(

if warmstartT:
T = [entropic_fused_gromov_wasserstein(
Ms[s], Cs[s], C, ps[s], p, loss_fun, epsilon, symmetric, alpha,
Ms[s], C, Cs[s], p, ps[s], loss_fun, epsilon, symmetric, alpha,
T[s], max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)]

else:
T = [entropic_fused_gromov_wasserstein(
Ms[s], Cs[s], C, ps[s], p, loss_fun, epsilon, symmetric, alpha,
Ms[s], C, Cs[s], p, ps[s], loss_fun, epsilon, symmetric, alpha,
None, max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)]

if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs)
C = update_square_loss(p, lambdas, T, Cs, nx)
elif loss_fun == 'kl_loss':
C = update_kl_loss(p, lambdas, T, Cs)
C = update_kl_loss(p, lambdas, T, Cs, nx)

Ys_temp = [y.T for y in Ys]
T_temp = [Ts.T for Ts in T]
Y = update_feature_matrix(lambdas, Ys_temp, T_temp, p).T
Ms = [dist(Ys[s], Y) for s in range(len(Ys))]
Y = update_feature_matrix(lambdas, Ys_temp, T, p, nx).T
Ms = [dist(Y, Ys[s]) for s in range(len(Ys))]

if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
Expand Down
31 changes: 16 additions & 15 deletions ot/gromov/_gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,16 +830,18 @@ def gromov_barycenters(
Cprev = C

if warmstartT:
T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, symmetric=symmetric, armijo=armijo, G0=T[s],
max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, log=False, **kwargs) for s in range(S)]
T = [gromov_wasserstein(
C, Cs[s], p, ps[s], loss_fun, symmetric=symmetric, armijo=armijo, G0=T[s],
max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, log=False, **kwargs) for s in range(S)]
else:
T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, symmetric=symmetric, armijo=armijo, G0=None,
max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, log=False, **kwargs) for s in range(S)]
T = [gromov_wasserstein(
C, Cs[s], p, ps[s], loss_fun, symmetric=symmetric, armijo=armijo, G0=None,
max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, log=False, **kwargs) for s in range(S)]
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs)
C = update_square_loss(p, lambdas, T, Cs, nx)

elif loss_fun == 'kl_loss':
C = update_kl_loss(p, lambdas, T, Cs)
C = update_kl_loss(p, lambdas, T, Cs, nx)

if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
Expand Down Expand Up @@ -898,14 +900,14 @@ def fgw_barycenters(
If left to its default value None, uniform weights are taken.
alpha : float, optional
Alpha parameter for the fgw distance.
fixed_structure : bool
Whether to fix the structure of the barycenter during the updates
fixed_features : bool
fixed_structure : bool, optional
Whether to fix the structure of the barycenter during the updates.
fixed_features : bool, optional
Whether to fix the features of the barycenter during the updates.
p : array-like, shape (N,), optional
Weights in the targeted barycenter.
If left to its default value None, a uniform distribution is taken.
loss_fun : str
loss_fun : str, optional
Loss function used for the solver either 'square_loss' or 'kl_loss'
symmetric : bool, optional
Either structures are to be assumed symmetric or not. Default value is True.
Expand Down Expand Up @@ -1024,19 +1026,18 @@ def fgw_barycenters(
T = [fused_gromov_wasserstein(
Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric,
G0=None, max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)]
# T is N,ns

if not fixed_features:
Ys_temp = [y.T for y in Ys]
X = update_feature_matrix(lambdas, Ys_temp, T, p).T
X = update_feature_matrix(lambdas, Ys_temp, T, p, nx).T
Ms = [dist(X, Ys[s]) for s in range(len(Ys))]

if not fixed_structure:
T_temp = [t.T for t in T]
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T_temp, Cs)
C = update_square_loss(p, lambdas, T, Cs, nx)

elif loss_fun == 'kl_loss':
C = update_kl_loss(p, lambdas, T_temp, Cs)
C = update_kl_loss(p, lambdas, T, Cs, nx)

err_feature = nx.norm(X - nx.reshape(Xprev, (N, d)))
err_structure = nx.norm(C - Cprev)
Expand Down
87 changes: 62 additions & 25 deletions ot/gromov/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,46 +253,75 @@ def gwggrad(constC, hC1, hC2, T, nx=None):
T, nx) # [12] Prop. 2 misses a 2 factor


def update_square_loss(p, lambdas, T, Cs):
def update_square_loss(p, lambdas, T, Cs, nx=None):
r"""
Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s`
couplings calculated at each iteration
Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S`
:math:`\mathbf{T}_s` couplings calculated at each iteration of the GW
barycenter problem in :ref:`[12]`:

.. math::

\mathbf{C}^* = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s)

Where :

- :math:`\mathbf{C}_s`: metric cost matrix
- :math:`\mathbf{p}_s`: distribution

Parameters
----------
p : array-like, shape (N,)
Masses in the targeted barycenter.
lambdas : list of float
List of the `S` spaces' weights.
T : list of S array-like of shape (ns,N)
T : list of S array-like of shape (N, ns)
The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration.
Cs : list of S array-like, shape(ns,ns)
Metric cost matrices.
nx : backend, optional
If left to its default value None, the backend is inferred from the input arrays.

Returns
----------
C : array-like, shape (N, N)
Updated :math:`\mathbf{C}` matrix.

References
----------
.. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.

"""
T = list_to_array(*T)
Cs = list_to_array(*Cs)
p = list_to_array(p)
nx = get_backend(p, *T, *Cs)
if nx is None:
nx = get_backend(p, *T, *Cs)

# Correct order mistake in Equation 14 in [12]
tmpsum = sum([
lambdas[s] * nx.dot(
nx.dot(T[s].T, Cs[s]),
T[s]
nx.dot(T[s], Cs[s]),
T[s].T
) for s in range(len(T))
])
ppt = nx.outer(p, p)

return tmpsum / ppt


def update_kl_loss(p, lambdas, T, Cs):
def update_kl_loss(p, lambdas, T, Cs, nx=None):
r"""
Updates :math:`\mathbf{C}` according to the KL Loss kernel with the `S` :math:`\mathbf{T}_s` couplings calculated at each iteration
Updates :math:`\mathbf{C}` according to the KL Loss kernel with the `S`
:math:`\mathbf{T}_s` couplings calculated at each iteration of the GW
barycenter problem in :ref:`[12]`:

.. math::

\mathbf{C}^* = \mathop{\arg \min}_{\mathbf{C}\in \mathbb{R}^{N \times N}} \quad \sum_s \lambda_s \mathrm{GW}(\mathbf{C}, \mathbf{C}_s, \mathbf{p}, \mathbf{p}_s)

Where :

- :math:`\mathbf{C}_s`: metric cost matrix
- :math:`\mathbf{p}_s`: distribution


Parameters
Expand All @@ -301,33 +330,41 @@ def update_kl_loss(p, lambdas, T, Cs):
Weights in the targeted barycenter.
lambdas : list of float
List of the `S` spaces' weights
T : list of S array-like of shape (ns,N)
T : list of S array-like of shape (N, ns)
The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration.
Cs : list of S array-like, shape(ns,ns)
Metric cost matrices.
nx : backend, optional
If left to its default value None, the backend is inferred from the input arrays.

Returns
----------
C : array-like, shape (N, N)
updated :math:`\mathbf{C}` matrix

References
----------
.. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon,
"Gromov-Wasserstein averaging of kernel and distance matrices."
International Conference on Machine Learning (ICML). 2016.

"""
Cs = list_to_array(*Cs)
T = list_to_array(*T)
p = list_to_array(p)
nx = get_backend(p, *T, *Cs)
if nx is None:
nx = get_backend(p, *T, *Cs)

# Correct order mistake in Equation 15 in [12]
tmpsum = sum([
lambdas[s] * nx.dot(
nx.dot(T[s].T, Cs[s]),
T[s]
nx.dot(T[s], nx.log(nx.maximum(Cs[s], 1e-15))),
T[s].T
) for s in range(len(T))
])
ppt = nx.outer(p, p)

return nx.exp(tmpsum / ppt)


def update_feature_matrix(lambdas, Ys, Ts, p):
def update_feature_matrix(lambdas, Ys, Ts, p, nx=None):
r"""Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings.


Expand All @@ -340,10 +377,12 @@ def update_feature_matrix(lambdas, Ys, Ts, p):
masses in the targeted barycenter
lambdas : list of float
List of the `S` spaces' weights
Ts : list of S array-like, shape (ns,N)
Ts : list of S array-like, shape (N, ns)
The `S` :math:`\mathbf{T}_s` couplings calculated at each iteration
Ys : list of S array-like, shape (d,ns)
The features.
nx : backend, optional
If left to its default value None, the backend is inferred from the input arrays.

Returns
-------
Expand All @@ -357,10 +396,8 @@ def update_feature_matrix(lambdas, Ys, Ts, p):
"Optimal Transport for structured data with application on graphs"
International Conference on Machine Learning (ICML). 2019.
"""
p = list_to_array(p)
Ts = list_to_array(*Ts)
Ys = list_to_array(*Ys)
nx = get_backend(*Ys, *Ts, p)
if nx is None:
nx = get_backend(*Ys, *Ts, p)

p = 1. / p
tmpsum = sum([
Expand Down
32 changes: 23 additions & 9 deletions test/test_gromov.py
Original file line number Diff line number Diff line change
Expand Up @@ -1441,20 +1441,29 @@ def test_fgw_barycenter(nx):
p = ot.unif(n_samples)

ysb, ytb, C1b, C2b, p1b, p2b, pb = nx.from_numpy(ys, yt, C1, C2, p1, p2, p)

Xb, Cb = ot.gromov.fgw_barycenters(
n_samples, [ysb, ytb], [C1b, C2b], None, [.5, .5], 0.5, fixed_structure=False,
fixed_features=False, p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3, random_state=12345
lambdas = [.5, .5]
Csb = [C1b, C2b]
Ysb = [ysb, ytb]
Xb, Cb, logb = ot.gromov.fgw_barycenters(
n_samples, Ysb, Csb, None, lambdas, 0.5, fixed_structure=False,
fixed_features=False, p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3,
random_state=12345, log=True
)
# test correspondence with the utils functions
recovered_Cb = ot.gromov.update_square_loss(pb, lambdas, logb['Ts_iter'][-1], Csb)
recovered_Xb = ot.gromov.update_feature_matrix(lambdas, [y.T for y in Ysb], logb['Ts_iter'][-1], pb).T

np.testing.assert_allclose(Cb, recovered_Cb)
np.testing.assert_allclose(Xb, recovered_Xb)

xalea = rng.randn(n_samples, 2)
init_C = ot.dist(xalea, xalea)
init_C /= init_C.max()
init_Cb = nx.from_numpy(init_C)

with pytest.raises(ot.utils.UndefinedParameter): # to raise warning when `fixed_structure=True`and `init_C=None`
with pytest.raises(ot.utils.UndefinedParameter): # to raise an error when `fixed_structure=True`and `init_C=None`
Xb, Cb = ot.gromov.fgw_barycenters(
n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None,
n_samples, Ysb, Csb, ps=[p1b, p2b], lambdas=None,
alpha=0.5, fixed_structure=True, init_C=None, fixed_features=False,
p=None, loss_fun='square_loss', max_iter=100, tol=1e-3
)
Expand All @@ -1471,7 +1480,7 @@ def test_fgw_barycenter(nx):
init_X = rng.randn(n_samples, ys.shape[1])
init_Xb = nx.from_numpy(init_X)

with pytest.raises(ot.utils.UndefinedParameter): # to raise warning when `fixed_features=True`and `init_X=None`
with pytest.raises(ot.utils.UndefinedParameter): # to raise an error when `fixed_features=True`and `init_X=None`
Xb, Cb, logb = ot.gromov.fgw_barycenters(
n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5,
fixed_structure=False, fixed_features=True, init_X=None,
Expand All @@ -1490,14 +1499,19 @@ def test_fgw_barycenter(nx):
np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))

# add test with 'kl_loss'
X, C = ot.gromov.fgw_barycenters(
X, C, log = ot.gromov.fgw_barycenters(
n_samples, [ys, yt], [C1, C2], [p1, p2], [.5, .5], 0.5,
fixed_structure=False, fixed_features=False, p=p, loss_fun='kl_loss',
max_iter=100, tol=1e-3, init_C=C, init_X=X, warmstartT=True, random_state=12345
max_iter=100, tol=1e-3, init_C=C, init_X=X, warmstartT=True,
random_state=12345, log=True
)
np.testing.assert_allclose(C.shape, (n_samples, n_samples))
np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))

# test correspondence with the utils function
recovered_C = ot.gromov.update_kl_loss(p, lambdas, log['Ts_iter'][-1], [C1, C2])
np.testing.assert_allclose(C, recovered_C)


def test_gromov_wasserstein_linear_unmixing(nx):
n = 4
Expand Down