Skip to content

Commit

Permalink
[MRG] Fix issue 317 (#318)
Browse files Browse the repository at this point in the history
* Fix issue 317

* Update with docs and tests

Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
  • Loading branch information
cshjin and rflamary committed Dec 6, 2021
1 parent ca69658 commit b3dc68f
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 10 deletions.
24 changes: 14 additions & 10 deletions ot/gromov.py
Expand Up @@ -1368,6 +1368,8 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
-------
C : array-like, shape (`N`, `N`)
Similarity matrix in the barycenter space (permutated arbitrarily)
log : dict
Log dictionary of error during iterations. Return only if `log=True` in parameters.
References
----------
Expand Down Expand Up @@ -1401,7 +1403,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
Cprev = C

T = [entropic_gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun, epsilon,
max_iter, 1e-4, verbose, log) for s in range(S)]
max_iter, 1e-4, verbose, log=False) for s in range(S)]
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs)

Expand All @@ -1414,9 +1416,6 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,
err = nx.norm(C - Cprev)
error.append(err)

if log:
log['err'].append(err)

if verbose:
if cpt % 200 == 0:
print('{:5s}|{:12s}'.format(
Expand All @@ -1425,7 +1424,10 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon,

cpt += 1

return C
if log:
return C, {"err": error}
else:
return C


def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
Expand Down Expand Up @@ -1479,6 +1481,8 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
-------
C : array-like, shape (`N`, `N`)
Similarity matrix in the barycenter space (permutated arbitrarily)
log : dict
Log dictionary of error during iterations. Return only if `log=True` in parameters.
References
----------
Expand Down Expand Up @@ -1513,7 +1517,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
Cprev = C

T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun,
numItermax=max_iter, stopThr=1e-5, verbose=verbose, log=log) for s in range(S)]
numItermax=max_iter, stopThr=1e-5, verbose=verbose, log=False) for s in range(S)]
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs)

Expand All @@ -1526,9 +1530,6 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,
err = nx.norm(C - Cprev)
error.append(err)

if log:
log['err'].append(err)

if verbose:
if cpt % 200 == 0:
print('{:5s}|{:12s}'.format(
Expand All @@ -1537,7 +1538,10 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun,

cpt += 1

return C
if log:
return C, {"err": error}
else:
return C


def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False,
Expand Down
56 changes: 56 additions & 0 deletions test/test_gromov.py
Expand Up @@ -385,6 +385,20 @@ def test_gromov_barycenter(nx):
np.testing.assert_allclose(Cb, Cbb, atol=1e-06)
np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples))

# test of gromov_barycenters with `log` on
Cb_, err_ = ot.gromov.gromov_barycenters(
n_samples, [C1, C2], [p1, p2], p, [.5, .5],
'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
)
Cbb_, errb_ = ot.gromov.gromov_barycenters(
n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
'square_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
)
Cbb_ = nx.to_numpy(Cbb_)
np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06)
np.testing.assert_array_almost_equal(err_['err'], errb_['err'])
np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples))

Cb2 = ot.gromov.gromov_barycenters(
n_samples, [C1, C2], [p1, p2], p, [.5, .5],
'kl_loss', max_iter=100, tol=1e-3, random_state=42
Expand All @@ -396,6 +410,20 @@ def test_gromov_barycenter(nx):
np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06)
np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples))

# test of gromov_barycenters with `log` on
Cb2_, err2_ = ot.gromov.gromov_barycenters(
n_samples, [C1, C2], [p1, p2], p, [.5, .5],
'kl_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
)
Cb2b_, err2b_ = ot.gromov.gromov_barycenters(
n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
'kl_loss', max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
)
Cb2b_ = nx.to_numpy(Cb2b_)
np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06)
np.testing.assert_array_almost_equal(err2_['err'], err2_['err'])
np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples))


@pytest.mark.filterwarnings("ignore:divide")
def test_gromov_entropic_barycenter(nx):
Expand Down Expand Up @@ -429,6 +457,20 @@ def test_gromov_entropic_barycenter(nx):
np.testing.assert_allclose(Cb, Cbb, atol=1e-06)
np.testing.assert_allclose(Cbb.shape, (n_samples, n_samples))

# test of entropic_gromov_barycenters with `log` on
Cb_, err_ = ot.gromov.entropic_gromov_barycenters(
n_samples, [C1, C2], [p1, p2], p, [.5, .5],
'square_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
)
Cbb_, errb_ = ot.gromov.entropic_gromov_barycenters(
n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
'square_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
)
Cbb_ = nx.to_numpy(Cbb_)
np.testing.assert_allclose(Cb_, Cbb_, atol=1e-06)
np.testing.assert_array_almost_equal(err_['err'], errb_['err'])
np.testing.assert_allclose(Cbb_.shape, (n_samples, n_samples))

Cb2 = ot.gromov.entropic_gromov_barycenters(
n_samples, [C1, C2], [p1, p2], p, [.5, .5],
'kl_loss', 1e-3, max_iter=100, tol=1e-3, random_state=42
Expand All @@ -440,6 +482,20 @@ def test_gromov_entropic_barycenter(nx):
np.testing.assert_allclose(Cb2, Cb2b, atol=1e-06)
np.testing.assert_allclose(Cb2b.shape, (n_samples, n_samples))

# test of entropic_gromov_barycenters with `log` on
Cb2_, err2_ = ot.gromov.entropic_gromov_barycenters(
n_samples, [C1, C2], [p1, p2], p, [.5, .5],
'kl_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
)
Cb2b_, err2b_ = ot.gromov.entropic_gromov_barycenters(
n_samples, [C1b, C2b], [p1b, p2b], pb, [.5, .5],
'kl_loss', 1e-3, max_iter=100, tol=1e-3, verbose=True, random_state=42, log=True
)
Cb2b_ = nx.to_numpy(Cb2b_)
np.testing.assert_allclose(Cb2_, Cb2b_, atol=1e-06)
np.testing.assert_array_almost_equal(err2_['err'], err2_['err'])
np.testing.assert_allclose(Cb2b_.shape, (n_samples, n_samples))


def test_fgw(nx):
n_samples = 50 # nb samples
Expand Down

0 comments on commit b3dc68f

Please sign in to comment.