Assertion Error when using low-precision numbers #429

Closed
@TheSeparatrix

Description

@TheSeparatrix

Describe the bug

Hello,
I used the 1-d earth mover's distance function ot.emd2_1d to measure the distance between two outputs of a PyTorch neural network. With default parameters, all the numbers produced by the PyTorch model are float32.
The distance function raises an AssertionError because of what looks to me like a floating-point precision error.

File ~/miniconda3/envs/disentanglement_env/lib/python3.9/site-packages/ot/lp/solver_1d.py:361, in emd2_1d(x_a, x_b, a, b, metric, p, dense, log)
    276 r"""Solves the Earth Movers distance problem between 1d measures and returns
    277 the loss
    278 
   (...)
    357     instead of the cost)
    358 """
    359 # If we do not return G (log==False), then we should not to cast it to dense
    360 # (useless overhead)
--> 361 G, log_emd = emd_1d(x_a=x_a, x_b=x_b, a=a, b=b, metric=metric, p=p,
    362                     dense=dense and log, log=True)
    363 cost = log_emd['cost']
    364 if log:

File ~/miniconda3/envs/disentanglement_env/lib/python3.9/site-packages/ot/lp/solver_1d.py:237, in emd_1d(x_a, x_b, a, b, metric, p, dense, log)
    234     b = nx.ones((x_b.shape[0],), type_as=x_b) / x_b.shape[0]
    236 # ensure that same mass
--> 237 np.testing.assert_almost_equal(
    238     nx.to_numpy(nx.sum(a, axis=0)),
    239     nx.to_numpy(nx.sum(b, axis=0)),
    240     err_msg='a and b vector must have the same sum'
    241 )
    242 b = b * nx.sum(a) / nx.sum(b)
    244 x_a_1d = nx.reshape(x_a, (-1,))

File ~/miniconda3/envs/disentanglement_env/lib/python3.9/site-packages/numpy/testing/_private/utils.py:599, in assert_almost_equal(actual, desired, decimal, err_msg, verbose)
    597     pass
    598 if abs(desired - actual) >= 1.5 * 10.0**(-decimal):
--> 599     raise AssertionError(_build_err_msg())

AssertionError: 
Arrays are not almost equal to 7 decimals
a and b vector must have the same sum
 ACTUAL: 1.0000001
 DESIRED: 0.99999994
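For context, the check trips because np.testing.assert_almost_equal uses a tolerance of 1.5 * 10**-decimal with decimal=7 by default, and the two sums reported above differ by 0.00000016 = 1.6e-7, which just exceeds it. A minimal sketch reproducing the failing comparison with the exact values from the traceback:

```python
import numpy as np

# The two sums from the traceback; their difference (1.6e-7) is just
# over the default tolerance 1.5 * 10.0**-7 of assert_almost_equal.
actual = 1.0000001
desired = 0.99999994

try:
    np.testing.assert_almost_equal(actual, desired)  # decimal=7 by default
    raised = False
except AssertionError:
    raised = True

print(raised)  # True: the discrepancy exceeds the tolerance
```

With decimal=6 (tolerance 1.5e-6) the same comparison would pass, which is why this reads like a borderline precision issue rather than genuinely mismatched mass.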

Expected behavior

Is this the correct behaviour? If this 0.00000016 discrepancy changes the output of the function, then I will have to consider using higher-precision numbers. However, if it doesn't impact the result much, maybe changing this error to a warning would be good.
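As a workaround sketch (my own suggestion, not something POT documents): casting the float32 tensors to float64 and renormalizing the weight vectors before calling ot.emd2_1d makes both sums agree well within the 1.5e-7 tolerance. The weight vectors here are hypothetical stand-ins for the model outputs:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical float32 weight vectors, as a PyTorch model might produce.
a = rng.random(1000).astype(np.float32)
b = rng.random(1000).astype(np.float32)

# Workaround: cast to float64 and renormalize, so each vector sums to 1
# to within a few machine epsilons -- far inside the 1.5e-7 tolerance.
a64 = a.astype(np.float64)
b64 = b.astype(np.float64)
a64 /= a64.sum()
b64 /= b64.sum()

# The same check that emd_1d performs now passes.
np.testing.assert_almost_equal(a64.sum(), b64.sum(),
                               err_msg='a and b vector must have the same sum')
```

The renormalized float64 weights (and float64-cast positions, via tensor.double() in PyTorch) can then be passed as the a/b and x_a/x_b arguments of ot.emd2_1d.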

Environment (please complete the following information):

  • OS (e.g. MacOS, Windows, Linux): observed for same code both on Linux and MacOS
  • How was POT installed (source, pip, conda): conda-forge

Output of the following code snippet:

import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)

Output:

Linux-3.10.0-1160.80.1.el7.x86_64-x86_64-with-glibc2.17
Python 3.9.12 (main, Jun  1 2022, 11:38:51) 
[GCC 7.5.0]
NumPy 1.21.5
SciPy 1.7.3
POT 0.8.2
