Numerical issue of ot.emd with large entries #229

@tlacombe

Describe the bug

ot.emd appears to fail to return an optimal plan (up to numerical precision) when the cost matrix contains very large entries, even when the optimal plan puts zero mass on those entries.
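As an illustration of why entries of this size can be delicate in double precision, here is a minimal sketch. It is only a plausible contributing factor, not a confirmed diagnosis of what happens inside ot.emd. The variable names are illustrative; the two values are taken from the cost matrix in the reproduction.

```python
import numpy as np

big = 1e32              # the large "forbidden" cost used in M
small = 2.50275352e02   # a typical finite entry of M

# Distance between 1e32 and the next representable float64.
gap = np.spacing(big)

# Adding a cost of order 1e2 to 1e32 leaves the value unchanged,
# so the small cost cannot be recovered by subtracting 1e32 again.
absorbed = (big + small) == big
recovered = (big + small) - big

print("float64 spacing at 1e32:", gap)
print("small cost absorbed:", absorbed)
print("recovered small cost:", recovered)
```

If a solver forms sums or differences that mix the 1e32 entries with the entries of order 1e2, the smaller costs may be lost in this way. Whether ot.emd does so is an assumption here.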

To Reproduce

import numpy as np
import ot

M = np.array(
    [
        [2.50275352e02, 3.74653218e02, 2.41352736e03, 1.00000000e32, 1.51751540e-03],
        [2.13082030e02, 3.28812836e02, 2.29487946e03, 1.00000000e32, 1.37109800e-01],
        [1.97333083e02, 3.09175848e02, 2.24250550e03, 1.00000000e32, 2.46506283e00],
        [1.00000000e32, 1.00000000e32, 1.00000000e32, 5.26223432e00, 2.50000000e31],
        [3.84690152e01, 8.09465684e01, 3.33064175e02, 2.50000000e31, 0.00000000e00],
    ]
)

a = np.array([0.125, 0.125, 0.125, 0.125, 0.5])
b = np.array([0.125, 0.125, 0.125, 0.125, 0.5])
P = ot.emd(a=a, b=b, M=M, numItermax=2000000)
Q = np.array(
    [
        [0, 0, 0, 0, 0.125],
        [0, 0, 0, 0, 0.125],
        [0, 0, 0, 0, 0.125],
        [0, 0, 0, 0.125, 0],
        [0.125, 0.125, 0.125, 0, 0.125],
    ]
)
# Row sums (axis=1) must match a; column sums (axis=0) must match b.
assert np.allclose(P.sum(axis=1), a)
assert np.allclose(P.sum(axis=0), b)
assert np.allclose(Q.sum(axis=1), a)
assert np.allclose(Q.sum(axis=0), b)
print("my cost matrix:\n", Q)
print("POT matrix:\n", P)
print("POT cost:", np.sum(np.multiply(P, M)))
print("my cost:", np.sum(np.multiply(Q, M)))

returns:

my cost matrix:
 [[0.    0.    0.    0.    0.125]
 [0.    0.    0.    0.    0.125]
 [0.    0.    0.    0.    0.125]
 [0.    0.    0.    0.125 0.   ]
 [0.125 0.125 0.125 0.    0.125]]
POT matrix:
 [[0.    0.125 0.    0.    0.   ]
 [0.125 0.    0.    0.    0.   ]
 [0.    0.    0.125 0.    0.   ]
 [0.    0.    0.    0.125 0.   ]
 [0.    0.    0.    0.    0.5  ]]
POT cost: 354.43787279000003
my cost: 57.54321038317501

Expected behavior

ot.emd should return, up to numerical precision, a transport plan whose cost is no higher than that of the plan Q proposed above.
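The comparison can be checked without calling POT at all. The sketch below uses only NumPy: it copies M, Q, and the plan printed under "POT matrix" (named P_reported here), then verifies that both plans are feasible and compares their costs. The helper names is_feasible and transport_cost are illustrative and not part of POT.

```python
import numpy as np

M = np.array([
    [2.50275352e02, 3.74653218e02, 2.41352736e03, 1.00000000e32, 1.51751540e-03],
    [2.13082030e02, 3.28812836e02, 2.29487946e03, 1.00000000e32, 1.37109800e-01],
    [1.97333083e02, 3.09175848e02, 2.24250550e03, 1.00000000e32, 2.46506283e00],
    [1.00000000e32, 1.00000000e32, 1.00000000e32, 5.26223432e00, 2.50000000e31],
    [3.84690152e01, 8.09465684e01, 3.33064175e02, 2.50000000e31, 0.00000000e00],
])
a = np.array([0.125, 0.125, 0.125, 0.125, 0.5])
b = np.array([0.125, 0.125, 0.125, 0.125, 0.5])

# Plan proposed by hand in this report.
Q = np.array([
    [0, 0, 0, 0, 0.125],
    [0, 0, 0, 0, 0.125],
    [0, 0, 0, 0, 0.125],
    [0, 0, 0, 0.125, 0],
    [0.125, 0.125, 0.125, 0, 0.125],
])

# Plan copied from the "POT matrix" output in this report.
P_reported = np.array([
    [0, 0.125, 0, 0, 0],
    [0.125, 0, 0, 0, 0],
    [0, 0, 0.125, 0, 0],
    [0, 0, 0, 0.125, 0],
    [0, 0, 0, 0, 0.5],
])


def is_feasible(plan, a, b):
    """Non-negative entries, row sums equal to a, column sums equal to b."""
    return bool(
        np.all(plan >= 0)
        and np.allclose(plan.sum(axis=1), a)
        and np.allclose(plan.sum(axis=0), b)
    )


def transport_cost(plan, M):
    """Total cost <plan, M> of a transport plan."""
    return float(np.sum(plan * M))


cost_Q = transport_cost(Q, M)
cost_P = transport_cost(P_reported, M)
print("both feasible:", is_feasible(Q, a, b) and is_feasible(P_reported, a, b))
print("Q strictly cheaper:", cost_Q < cost_P)
```

Since both plans satisfy the marginal constraints and Q is cheaper, the plan returned by ot.emd cannot be optimal.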

Environment:

  • OS: Ubuntu 20.04
  • Python version: 3.7
  • POT installed via: conda

Output of the following code snippet:

>>> import platform; print(platform.platform())
Linux-5.4.0-70-generic-x86_64-with-debian-bullseye-sid
>>> import sys; print("Python", sys.version)
Python 3.7.4 (default, Aug 13 2019, 20:35:49) 
[GCC 7.3.0]
>>> import numpy; print("NumPy", numpy.__version__)
NumPy 1.16.4
>>> import scipy; print("SciPy", scipy.__version__)
SciPy 1.3.1
>>> import ot; print("POT", ot.__version__)
POT 0.7.0

Additional context

As shown, I set numItermax to 2000000 and did not get any warning, and the code ran fast, so the solver appears to have terminated normally rather than hitting the iteration limit.
