Inconsistent log handling in entropic_gromov_barycenters and gromov_barycenters when computing GW with log information #317

Closed
cshjin opened this issue Dec 2, 2021 · 1 comment

cshjin commented Dec 2, 2021

Describe the bug

To Reproduce

Steps to reproduce the behavior:

  1. Compute the barycenter with the optional argument log=True.
  2. With log=True, gromov_wasserstein returns an additional log dictionary, as does entropic_gromov_wasserstein.

Code sample

import networkx as nx
import numpy as np
from scipy.sparse.csgraph import shortest_path
from ot.gromov import gromov_barycenters

Gs = [nx.cycle_graph(4)]
Ds = [shortest_path(nx.adjacency_matrix(g)) for g in Gs]
ps = [np.ones(4) / 4]
lambdas = np.ones(len(Gs)) / len(Gs)
N = 4
p = np.ones(N) / N

C = gromov_barycenters(N, Ds, ps, p, lambdas, "square_loss", log=True)

Expected behavior

  • The internal log information in gromov_barycenters is not necessary.
  • Only the error at each iteration of the while loop needs to be recorded.
  • Return C if log=False; return C, {"err": [...]} if log=True.
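
The return convention described above can be sketched without POT. This is a minimal toy, not the library's implementation: `inner_solver` and `barycenter_sketch` are hypothetical names, the update rule is an arbitrary fixed-point iteration, and the error is recorded at every iteration for simplicity. The points it illustrates are that the inner solver is always called with log=False, and that errors go into a separate dictionary rather than into the boolean flag.

```python
def inner_solver(x, log=False):
    """Hypothetical stand-in for an inner solver such as gromov_wasserstein."""
    plan = [v / 2 for v in x]
    if log:
        return plan, {"inner": "details"}
    return plan


def barycenter_sketch(xs, max_iter=50, tol=1e-9, log=False):
    """Toy fixed-point loop that mimics the shape of a barycenter update."""
    log_dict = {"err": []}  # separate dict; the boolean flag `log` is never indexed
    c = 0.0
    err, it = 1.0, 0
    while err > tol and it < max_iter:
        c_prev = c
        # Inner logs are not needed, so the inner solver is called with log=False.
        plans = [inner_solver(x, log=False) for x in xs]
        target = sum(sum(p) for p in plans) / len(plans)
        c = 0.5 * c_prev + 0.5 * target  # arbitrary toy update, not the GW update
        err = abs(c - c_prev)
        if log:
            log_dict["err"].append(err)
        it += 1
    if log:
        return c, log_dict
    return c
```

Called with log=False the function returns only the value; with log=True it returns the value and a dictionary whose "err" entry lists the per-iteration errors.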

Environment (please complete the following information):

Output of the following code snippet:

import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)

# output:
Linux-5.4.0-70-generic-x86_64-with-debian-bullseye-sid
Python 3.7.10 (default, Feb 26 2021, 18:47:35) 
[GCC 7.3.0]
NumPy 1.19.2
SciPy 1.6.2
POT 0.7.0

Additional context

The issue occurs in version 0.7.0. I also checked the code in the latest version (0.8.0), and the problem is still present.

The issue arises in the following lines when log=True:

POT/ot/gromov.py

Lines 1506 to 1507 in cb51064

T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun,
                        numItermax=max_iter, stopThr=1e-5, verbose=verbose, log=log) for s in range(S)]

POT/ot/gromov.py

Lines 1520 to 1521 in cb51064

if log:
    log['err'].append(err)
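
The two excerpts above can be mimicked in a few lines of plain Python to show why log=True is a problem. This is only a sketch: `inner_solver`, `buggy_step` and `describe_failure` are hypothetical stand-ins, not POT functions. It assumes, as the excerpts suggest, that `log` is the boolean argument passed by the caller. Which of the two problems surfaces first inside POT itself is not shown here.

```python
def inner_solver(log=False):
    """Hypothetical stand-in: returns a plan, or (plan, dict) when log=True."""
    plan = [[0.25, 0.0], [0.0, 0.25]]
    return (plan, {"loss": []}) if log else plan


def buggy_step(log=False):
    # With log=True every element of T is a (plan, dict) tuple, not a plan.
    T = [inner_solver(log=log) for _ in range(2)]
    err = 0.0
    if log:
        log["err"].append(err)  # `log` is the boolean flag here, not a dict
    return T


def describe_failure():
    """Return the type of one inner result and the exception type raised."""
    first = inner_solver(log=True)
    try:
        buggy_step(log=True)
    except TypeError as exc:
        return type(first).__name__, type(exc).__name__
    return type(first).__name__, None
```

With log=False the step runs normally; with log=True the inner results are tuples and indexing the boolean flag raises a TypeError.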


cshjin commented Dec 6, 2021

Fixed in b3dc68f

@cshjin cshjin closed this as completed Dec 6, 2021