Closed
Description
Describe the bug
It seems ot.emd fails to return an optimal plan (up to numerical precision) if there are large entries in the cost matrix, even when the optimal plan puts zero mass on those entries.
To Reproduce
import numpy as np
import ot
M = np.array(
[
[2.50275352e02, 3.74653218e02, 2.41352736e03, 1.00000000e32, 1.51751540e-03],
[2.13082030e02, 3.28812836e02, 2.29487946e03, 1.00000000e32, 1.37109800e-01],
[1.97333083e02, 3.09175848e02, 2.24250550e03, 1.00000000e32, 2.46506283e00],
[1.00000000e32, 1.00000000e32, 1.00000000e32, 5.26223432e00, 2.50000000e31],
[3.84690152e01, 8.09465684e01, 3.33064175e02, 2.50000000e31, 0.00000000e00],
]
)
a = np.array([0.125, 0.125, 0.125, 0.125, 0.5])
b = np.array([0.125, 0.125, 0.125, 0.125, 0.5])
P = ot.emd(a=a, b=b, M=M, numItermax=2000000)
Q = np.array(
[
[0, 0, 0, 0, 0.125],
[0, 0, 0, 0, 0.125],
[0, 0, 0, 0, 0.125],
[0, 0, 0, 0.125, 0],
[0.125, 0.125, 0.125, 0, 0.125],
]
)
# Both plans must satisfy the marginal constraints: rows sum to a, columns sum to b.
assert np.allclose(P.sum(axis=1), a) and np.allclose(P.sum(axis=0), b)
assert np.allclose(Q.sum(axis=1), a) and np.allclose(Q.sum(axis=0), b)
print("my cost matrix:\n", Q)
print("POT matrix:\n", P)
print("POT cost:", np.sum(np.multiply(P, M)))
print("my cost:", np.sum(np.multiply(Q, M)))
returns:
my cost matrix:
[[0. 0. 0. 0. 0.125]
[0. 0. 0. 0. 0.125]
[0. 0. 0. 0. 0.125]
[0. 0. 0. 0.125 0. ]
[0.125 0.125 0.125 0. 0.125]]
POT matrix:
[[0. 0.125 0. 0. 0. ]
[0.125 0. 0. 0. 0. ]
[0. 0. 0.125 0. 0. ]
[0. 0. 0. 0.125 0. ]
[0. 0. 0. 0. 0.5 ]]
POT cost: 354.43787279000003
my cost: 57.54321038317501
Expected behavior
ot.emd should return (up to numerical precision) a transport plan (at least) as good as the Q I manually propose.
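The comparison can also be checked without calling POT at all. The sketch below copies M, the reported POT plan P, and the manual plan Q from this report, then checks feasibility and cost with NumPy only. The helper names `is_feasible` and `plan_cost` are hypothetical, introduced here for illustration, and are not part of POT.

```python
import numpy as np

# Cost matrix and marginals copied from the reproduction script.
M = np.array(
    [
        [2.50275352e02, 3.74653218e02, 2.41352736e03, 1.00000000e32, 1.51751540e-03],
        [2.13082030e02, 3.28812836e02, 2.29487946e03, 1.00000000e32, 1.37109800e-01],
        [1.97333083e02, 3.09175848e02, 2.24250550e03, 1.00000000e32, 2.46506283e00],
        [1.00000000e32, 1.00000000e32, 1.00000000e32, 5.26223432e00, 2.50000000e31],
        [3.84690152e01, 8.09465684e01, 3.33064175e02, 2.50000000e31, 0.00000000e00],
    ]
)
a = np.array([0.125, 0.125, 0.125, 0.125, 0.5])
b = np.array([0.125, 0.125, 0.125, 0.125, 0.5])

# Plan printed by ot.emd in the reported output (copied, not recomputed).
P = np.array(
    [
        [0, 0.125, 0, 0, 0],
        [0.125, 0, 0, 0, 0],
        [0, 0, 0.125, 0, 0],
        [0, 0, 0, 0.125, 0],
        [0, 0, 0, 0, 0.5],
    ]
)
# Manually constructed plan from the report.
Q = np.array(
    [
        [0, 0, 0, 0, 0.125],
        [0, 0, 0, 0, 0.125],
        [0, 0, 0, 0, 0.125],
        [0, 0, 0, 0.125, 0],
        [0.125, 0.125, 0.125, 0, 0.125],
    ]
)


def is_feasible(plan, a, b):
    """Hypothetical helper: non-negative entries, rows sum to a, columns sum to b."""
    rows_ok = np.allclose(plan.sum(axis=1), a)
    cols_ok = np.allclose(plan.sum(axis=0), b)
    return bool((plan >= 0).all() and rows_ok and cols_ok)


def plan_cost(plan, M):
    """Hypothetical helper: total transport cost sum_ij plan_ij * M_ij."""
    return float(np.sum(plan * M))


cost_P = plan_cost(P, M)
cost_Q = plan_cost(Q, M)
print("P feasible:", is_feasible(P, a, b), "cost:", cost_P)
print("Q feasible:", is_feasible(Q, a, b), "cost:", cost_Q)
```

Both plans are feasible and Q is strictly cheaper, so P cannot be optimal for this M.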
Environment (please complete the following information):
- OS (e.g. MacOS, Windows, Linux): Ubuntu 20.04
- Python version: 3.7
- How was POT installed (source, pip, conda): conda
Output of the following code snippet:
>>> import platform; print(platform.platform())
Linux-5.4.0-70-generic-x86_64-with-debian-bullseye-sid
>>> import sys; print("Python", sys.version)
Python 3.7.4 (default, Aug 13 2019, 20:35:49)
[GCC 7.3.0]
>>> import numpy; print("NumPy", numpy.__version__)
NumPy 1.16.4
>>> import scipy; print("SciPy", scipy.__version__)
SciPy 1.3.1
>>> import ot; print("POT", ot.__version__)
POT 0.7.0
Additional context
As shown, I set numItermax to 2000000 and didn't get any warning (and the code ran fast), so the solver appears to have converged rather than stopped at the iteration limit.
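One thing that may be relevant, offered as a possible contributing factor and not a confirmed diagnosis: in float64, the gap between adjacent representable numbers near 1e32 is far larger than any of the moderate entries of M. If an intermediate quantity of magnitude 1e32 is ever added to a moderate cost, the moderate cost is lost entirely. The sketch below only illustrates this floating-point behavior with NumPy. It assumes nothing about how ot.emd is implemented.

```python
import numpy as np

big = 1.0e32            # magnitude of the large entries in M
small = 3.74653218e02   # one of the moderate entries in M

# Distance from 1e32 to the next representable float64.
gap = np.spacing(big)

# Adding a moderate cost to a 1e32-sized value and subtracting it back
# does not recover the moderate cost.
absorbed = (big + small) - big

print("float64 spacing near 1e32:", gap)
print("(1e32 + 374.65...) - 1e32 =", absorbed)
```

The spacing is on the order of 1e16, so the second line prints 0.0.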