POT generating incorrect result for very simple OT problem #345

Closed
jstac opened this issue Feb 6, 2022 · 2 comments · Fixed by #343

Comments

@jstac

jstac commented Feb 6, 2022

Apologies if I'm doing something stupid; I don't think I am. The simple example

import numpy as np
import ot

phi = np.array((0.5, 0.5))   # distribution 1
psi = np.array((0.5, 0.5))   # distribution 2
c = ((2, 1),
     (1, 1))
c = np.array(c)

pi = ot.emd(phi, psi, c)

produces the incorrect result

array([[0, 0],
       [0, 0]])

(Clearly we should send all mass at 1 to 2 and all mass at 2 to 1.)
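To see why, note that every plan with these uniform marginals has the form pi = [[a, 0.5 - a], [0.5 - a, a]] for some 0 <= a <= 0.5. Its cost is therefore 2a + (0.5 - a) + (0.5 - a) + a = 1 + a, which is minimized at a = 0, i.e. the anti-diagonal plan with cost 1.0. The short numpy sketch below checks this on a grid. The helper `plan_cost` is hypothetical and used only for illustration; it does not call POT.

```python
import numpy as np

# Same cost matrix as in the report, stored as floats
c = np.array([[2.0, 1.0],
              [1.0, 1.0]])

def plan_cost(a):
    """Hypothetical helper: cost <c, pi(a)> of the feasible plan with uniform marginals."""
    pi = np.array([[a, 0.5 - a],
                   [0.5 - a, a]])
    return float(np.sum(c * pi))

# Scan the one-parameter family of feasible plans
grid = np.linspace(0.0, 0.5, 11)
costs = [plan_cost(a) for a in grid]
best_a = grid[int(np.argmin(costs))]
print(best_a, min(costs))  # 0.0 1.0
```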

Direct application of linear programming produces the correct result

array([[ 0. ,  0.5],
       [ 0.5, -0. ]])
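For reference, the linear program that the code below solves is the discrete Kantorovich problem. With $\operatorname{vec}$ denoting column stacking (which matches `order='F'` in the code), the two row blocks of $A$ are exactly the Kronecker products built as `A1` and `A2`:

```latex
\min_{\pi \ge 0} \; \sum_{i=1}^{m}\sum_{j=1}^{n} c_{ij}\,\pi_{ij}
\quad\text{s.t.}\quad
\sum_{j=1}^{n} \pi_{ij} = \varphi_i \;(i = 1,\dots,m), \qquad
\sum_{i=1}^{m} \pi_{ij} = \psi_j \;(j = 1,\dots,n),

A = \begin{bmatrix} \mathbf{1}_n^\top \otimes I_m \\ I_n \otimes \mathbf{1}_m^\top \end{bmatrix},
\qquad
A \,\operatorname{vec}(\pi) = \begin{bmatrix} \varphi \\ \psi \end{bmatrix}.
```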

Here's the direct linear programming code

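import numpy as np
from scipy.optimize import linprog

# Define parameters (phi, psi and c as defined in the snippet above)
m = n = 2

# Vectorize matrix C in column-major (Fortran) order
c_vec = c.reshape(m * n, order='F')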

# Construct matrix A by Kronecker product
A1 = np.kron(np.ones((1, n)), np.identity(m))
A2 = np.kron(np.identity(n), np.ones((1, m)))
A  = np.vstack([A1, A2])

# Construct vector b
b = np.hstack([phi, psi])

# Solve the primal problem
res = linprog(c_vec, A_eq=A, b_eq=b, method='highs-ipm')

# Reshape the solution vector back into an m x n plan and print it
pi = res.x.reshape((m, n), order='F')
print(pi)
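As a sanity check on the column-major vectorization, the numpy-only sketch below (no solver involved) rebuilds the same constraint matrix and confirms that the anti-diagonal plan satisfies A vec(pi) = b and has cost 1.0. The hard-coded plan is the one expected from the LP output above, not something computed here.

```python
import numpy as np

m = n = 2
phi = np.array([0.5, 0.5])
psi = np.array([0.5, 0.5])
c = np.array([[2.0, 1.0],
              [1.0, 1.0]])

# Same construction as above: first m rows give row sums, last n rows give column sums
A1 = np.kron(np.ones((1, n)), np.identity(m))
A2 = np.kron(np.identity(n), np.ones((1, m)))
A = np.vstack([A1, A2])
b = np.hstack([phi, psi])

# Candidate optimal plan: all mass at 1 goes to 2, all mass at 2 goes to 1
pi = np.array([[0.0, 0.5],
               [0.5, 0.0]])
pi_vec = pi.reshape(m * n, order='F')  # column-major vectorization

print(np.allclose(A @ pi_vec, b))                   # True
print(float(c.reshape(m * n, order='F') @ pi_vec))  # 1.0
```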

Environment:

Manjaro Linux, POT installed via pip in an Anaconda environment.

Output of the following code snippet:

import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)

Linux-5.13.19-2-MANJARO-x86_64-with-glibc2.17
Python 3.8.12 (default, Oct 12 2021, 13:49:34)
[GCC 7.5.0]
NumPy 1.20.3
SciPy 1.7.1
POT 0.8.0

@rflamary
Collaborator

rflamary commented Feb 7, 2022

Thank you for reporting this; I can indeed reproduce the bug.

This code here finds the correct OT plan:

import numpy as np
import ot

phi = np.array((0.5, 0.5))   # distribution 1
psi = np.array((0.5, 0.5))   # distribution 2
c = ((2, 1),
     (1, 1.0))
c = np.array(c)

pi = ot.emd(phi, psi, c)

Interestingly, if I put a single float in the tuple defining C, it works, so it seems to be a type problem: with integer values in C, no solution is found. Will look into it.
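The only difference between the two snippets is the dtype NumPy infers for the cost matrix. A tuple of Python ints yields an integer array, while a single float literal promotes the whole array to float64. A quick check (the names `c_int` and `c_float` are just for illustration):

```python
import numpy as np

c_int = np.array(((2, 1),
                  (1, 1)))
c_float = np.array(((2, 1),
                    (1, 1.0)))  # one float literal promotes every entry

# dtype.kind is 'i' for signed integers (int64 or int32 depending on platform)
print(c_int.dtype.kind, c_float.dtype)  # i float64
```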

@rflamary
Collaborator

rflamary commented Feb 7, 2022

Found it! The backend forces the output to have the same type as C, which in your case is an integer type, so every entry of the plan (0.5) becomes 0 when converted. This is done to ensure that emd can be solved across PyTorch or TensorFlow and returns the same type/device as the input.
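A minimal illustration of this effect, assuming the final step amounts to converting the float plan to the dtype of C. This mimics the behaviour with plain NumPy `astype`; it is not POT's backend code.

```python
import numpy as np

c = np.array(((2, 1),
              (1, 1)))          # integer cost matrix, as in the report
plan = np.array([[0.0, 0.5],
                 [0.5, 0.0]])   # correct float-valued OT plan

# Casting the plan to the cost matrix's integer dtype drops the fractional part
plan_cast = plan.astype(c.dtype)
print(plan_cast)
# [[0 0]
#  [0 0]]
```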

Maybe we should force the output type to match the type of phi/psi? I'm not sure that is a good idea; a warning when C has an integer type is probably better. In any case, we need to update the docs to make this clearer.
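Until the behaviour or docs change, a user-side workaround is to make sure the cost matrix is floating point before calling `ot.emd`. Below is a sketch of a hypothetical helper (`as_float_cost` is not part of POT) that does this and emits the kind of warning suggested above. Non-integer dtypes pass through unchanged.

```python
import warnings
import numpy as np

def as_float_cost(M):
    """Hypothetical helper: cast integer-typed cost matrices to float64 with a warning.

    Arrays with any other dtype are returned unchanged.
    """
    M = np.asarray(M)
    if np.issubdtype(M.dtype, np.integer):
        warnings.warn(
            "Integer cost matrix converted to float64; an integer-typed "
            "transport plan would be truncated to zeros.",
            UserWarning,
        )
        return M.astype(np.float64)
    return M

c = as_float_cost(((2, 1), (1, 1)))
print(c.dtype)  # float64
# The float cost matrix can then be passed on, e.g. ot.emd(phi, psi, c)
```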
