Describe the bug
Hello,
I used the 1D earth mover's distance function `ot.emd2_1d` to measure the distance between two outputs from a PyTorch neural network. With default parameters, all the numbers from the PyTorch model are `float32`.
The distance function raises an `AssertionError` because of what looks to me like a precision error.
File ~/miniconda3/envs/disentanglement_env/lib/python3.9/site-packages/ot/lp/solver_1d.py:361, in emd2_1d(x_a, x_b, a, b, metric, p, dense, log)
276 r"""Solves the Earth Movers distance problem between 1d measures and returns
277 the loss
278
(...)
357 instead of the cost)
358 """
359 # If we do not return G (log==False), then we should not to cast it to dense
360 # (useless overhead)
--> 361 G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric=metric, p=p,
362 dense=dense and log, log=True)
363 cost = log_emd['cost']
364 if log:
File ~/miniconda3/envs/disentanglement_env/lib/python3.9/site-packages/ot/lp/solver_1d.py:237, in emd_1d(x_a, x_b, a, b, metric, p, dense, log)
234 b = nx.ones((x_b.shape[0],), type_as=x_b) / x_b.shape[0]
236 # ensure that same mass
--> 237 np.testing.assert_almost_equal(
238 nx.to_numpy(nx.sum(a, axis=0)),
239 nx.to_numpy(nx.sum(b, axis=0)),
240 err_msg='a and b vector must have the same sum'
241 )
242 b = b * nx.sum(a) / nx.sum(b)
244 x_a_1d = nx.reshape(x_a, (-1,))
File ~/miniconda3/envs/disentanglement_env/lib/python3.9/site-packages/numpy/testing/_private/utils.py:599, in assert_almost_equal(actual, desired, decimal, err_msg, verbose)
597 pass
598 if abs(desired - actual) >= 1.5 * 10.0**(-decimal):
--> 599 raise AssertionError(_build_err_msg())
AssertionError:
Arrays are not almost equal to 7 decimals a and b vector must have the same sum
ACTUAL: 1.0000001
DESIRED: 0.99999994
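To illustrate why I think this is a precision issue, here is a minimal standard-library sketch. It rounds the two reported sums to float32 and applies the same comparison as the `assert_almost_equal` line in the traceback (`abs(desired - actual) >= 1.5 * 10.0**(-decimal)`, with `decimal=7` as in the error message). The helper names `to_float32` and `fails_mass_check` are my own and are not part of POT or NumPy.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def fails_mass_check(actual: float, desired: float, decimal: int = 7) -> bool:
    """Mirror the comparison shown in the traceback (hypothetical helper)."""
    return abs(desired - actual) >= 1.5 * 10.0 ** (-decimal)


# The two sums reported in the error message, as float32 values.
sum_a = to_float32(1.0000001)   # one float32 step above 1.0
sum_b = to_float32(0.99999994)  # one float32 step below 1.0

gap = abs(sum_a - sum_b)
print(gap)                             # about 1.79e-07
print(fails_mass_check(sum_a, sum_b))  # True: the gap exceeds 1.5e-07
```

The two sums are the nearest float32 neighbours of 1.0 on either side, so the gap comes from last-bit rounding alone, yet it is already larger than the 1.5e-7 tolerance.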
Expected behavior
Is this the correct behaviour? If this 0.00000016 discrepancy changes the output of the function, then I will have to consider using higher-precision numbers. However, if it doesn't affect the result much, maybe changing this check to a warning rather than an error would be good.
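In the meantime, one workaround I am considering is to renormalize the weights in double precision before the call, so that both sums agree far more closely than 7 decimals. This is only a sketch under my own assumptions: `normalize_weights` is a hypothetical helper, the example weights are made up, and I have not confirmed that POT keeps the values in float64 once they are passed in as `a` and `b`.

```python
import math


def normalize_weights(weights):
    """Return the weights as float64 values rescaled to sum to 1 (hypothetical helper)."""
    values = [float(w) for w in weights]
    total = math.fsum(values)  # accurate summation in double precision
    if total <= 0.0:
        raise ValueError("weights must have a positive sum")
    return [v / total for v in values]


# Made-up example weights for the two 1D measures.
a = normalize_weights([0.1] * 7 + [0.3])
b = normalize_weights([1.0, 1.0, 1.0])

# Both sums are now within float64 rounding of 1.0, far below 1.5e-07.
print(abs(math.fsum(a) - math.fsum(b)) < 1.5e-7)  # True
```

The traceback also shows `b = b * nx.sum(a) / nx.sum(b)` right after the assertion, so it looks as if `b` is rescaled to match `a` anyway once the check passes.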
Environment (please complete the following information):
- OS (e.g. MacOS, Windows, Linux): observed for same code both on Linux and MacOS
- How was POT installed (source, `pip`, `conda`): conda-forge
Output of the following code snippet:
import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
Output:
Linux-3.10.0-1160.80.1.el7.x86_64-x86_64-with-glibc2.17
Python 3.9.12 (main, Jun 1 2022, 11:38:51)
[GCC 7.5.0]
NumPy 1.21.5
SciPy 1.7.3
POT 0.8.2