Skip to content

[MRG] raise error if mass mismatch in emd2 #386

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jun 21, 2022
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ The contributors to this library are:
* [Nathan Cassereau](https://github.com/ncassereau-idris) (Backends)
* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning)
* [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters)
* [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug)

## Acknowledgments

Expand Down
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
- Fixed an issue where pointers would overflow in the EMD solver, returning an
incomplete transport plan above a certain size (slightly above 46k, its square being
roughly 2^31) (PR #381)
- Error raised when mass mismatch in emd2 (PR #386)


## 0.8.2
Expand Down
9 changes: 9 additions & 0 deletions ot/lp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1):
If this behaviour is unwanted, please make sure to provide a
floating point input.

.. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value.

Uses the algorithm proposed in :ref:`[1] <references-emd>`.

Parameters
Expand Down Expand Up @@ -389,6 +391,8 @@ def emd2(a, b, M, processes=1,
If this behaviour is unwanted, please make sure to provide a
floating point input.

.. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value.

Uses the algorithm proposed in :ref:`[1] <references-emd2>`.

Parameters
Expand Down Expand Up @@ -481,6 +485,11 @@ def emd2(a, b, M, processes=1,
assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
"Dimension mismatch, check dimensions of M with a and b"

# ensure that same mass
np.testing.assert_almost_equal(a.sum(0),
b.sum(0,keepdims=True), err_msg='a and b vector must have the same sum')
b = b * a.sum(0) / b.sum(0,keepdims=True)

asel = a != 0

numThreads = check_number_threads(numThreads)
Expand Down
3 changes: 3 additions & 0 deletions test/test_ot.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@ def test_emd_dimension_and_mass_mismatch():

np.testing.assert_raises(AssertionError, ot.emd2, a, a, M)

# test emd and emd2 for mass mismatch
a = ot.utils.unif(n_samples)
b = a.copy()
a[0] = 100
np.testing.assert_raises(AssertionError, ot.emd, a, b, M)
np.testing.assert_raises(AssertionError, ot.emd2, a, b, M)


def test_emd_backends(nx):
Expand Down