
Incorrect computation of cost_correction matrix in ot.da.EMDTransport #664

Open
martinrohbeck opened this issue Jul 23, 2024 · 2 comments

@martinrohbeck

Describe the bug

It seems that the cost_correction matrix is computed incorrectly. This is the current code that can be found here:

# labels_match is a (ns, nt) matrix of {True, False} such that
# the cells (i, j) has False if ys[i] != yt[i]
label_match = (ys[:, None] - yt[None, :]) != 0 
# cost correction is a (ns, nt) matrix of {-Inf, float, Inf} such
# that he cells (i, j) has -Inf where there's no correction necessary
# by 'correction' we mean setting cost to a large value when
# labels do not match
# ...
with warnings.catch_warnings():
    warnings.simplefilter('ignore', category=RuntimeWarning)
    cost_correction = label_match * missing_labels * self.limit_max
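To make the problem concrete, here is a minimal NumPy sketch that evaluates the quoted expression on toy labels. The helper name current_correction is hypothetical. Because the part of the snippet that builds missing_labels is elided, the sketch assumes that unlabeled samples are marked with -1 and that missing_labels[i, j] is True when ys[i] or yt[j] is unlabeled.

```python
import numpy as np


def current_correction(ys, yt, limit_max):
    """Evaluate the quoted (current) expression for 1-D label vectors ys, yt.

    Assumption: unlabeled samples are marked with -1, and missing_labels[i, j]
    is True when ys[i] or yt[j] is unlabeled.
    """
    missing_labels = (ys[:, None] == -1) | (yt[None, :] == -1)
    # Despite its name, this is True where the labels DIFFER.
    label_match = (ys[:, None] - yt[None, :]) != 0
    return label_match * missing_labels * limit_max


ys = np.array([0, 0, 1, 1])

# All target samples are labeled: missing_labels is False everywhere, so no
# cell is corrected, although half of the (i, j) pairs have different labels.
yt_full = np.array([0, 1, 0, 1])
corr_full = current_correction(ys, yt_full, limit_max=10.0)
print(corr_full.any())  # False

# Target 1 is unlabeled: the penalty lands exactly on that column, i.e. on the
# only pairs whose labels cannot be compared.
yt_partial = np.array([0, -1, 0, 1])
corr_partial = current_correction(ys, yt_partial, limit_max=10.0)
print(corr_partial[:, 1])  # [10. 10. 10. 10.]
```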

The issues:

  • First, the comment says that label_match is False if ys[i] != yt[i].
    However, if ys[i] != yt[j], entry (i, j) of (ys[:, None] - yt[None, :]) != 0 is True, so label_match is True even though the labels do not match (the naming is misleading). Therefore, either
    • the variable should be renamed to label_mismatch and the comment fixed, OR
    • we check for equality, label_match = (ys[:, None] - yt[None, :]) == 0, and flip the value in cost_correction, i.e. cost_correction = (1 - label_match) * ...
  • Second, cost_correction = label_match * missing_labels * self.limit_max applies a cost correction only where missing_labels is True. However, no correction should be applied where missing_labels is True (if a label is missing, we cannot tell whether the labels match), so the factor must be flipped to ... * (1 - missing_labels) * ...
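Both points can be checked with a few lines of NumPy. This is a sketch on random integer labels (the seed and sizes are arbitrary): the two options in the first bullet mark the same cells, and the second bullet comes down to which cells the missing_labels factor leaves nonzero.

```python
import numpy as np

rng = np.random.default_rng(0)
ys = rng.integers(0, 3, size=6)
yt = rng.integers(0, 3, size=5)
diff = ys[:, None] - yt[None, :]

# Option 1: keep the `!= 0` test and name it for what it is.
label_mismatch = diff != 0
# Option 2: test for equality, then flip it where the correction is built.
label_match = diff == 0
flipped = 1 - label_match  # 0/1 integer array

# Both options select the same (i, j) cells, which are exactly the cells
# where the labels differ.
assert np.array_equal(label_mismatch, flipped.astype(bool))
assert np.array_equal(label_mismatch, ys[:, None] != yt[None, :])

# Second point: the factor built from missing_labels decides where a
# correction can apply. Column 0 has both labels known, column 1 has one missing.
missing_labels = np.array([[False, True]])
print((missing_labels * 1).tolist())  # current factor:  [[0, 1]]
print((1 - missing_labels).tolist())  # proposed factor: [[1, 0]]
```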

Therefore, I'd propose the following change

# label_mismatch is a (ns, nt) matrix of {True, False} such that
# the cell (i, j) is True if ys[i] != yt[j]
label_mismatch = (ys[:, None] - yt[None, :]) != 0 
# cost correction is a (ns, nt) matrix of {-Inf, float, Inf} such
# that the cells (i, j) are -Inf where there's no correction necessary
# by 'correction' we mean setting cost to a large value when
# labels do not match
# ...
with warnings.catch_warnings():
    warnings.simplefilter('ignore', category=RuntimeWarning)
    cost_correction = label_mismatch * (1 - missing_labels) * self.limit_max
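For reference, here is one way the proposed rule could look as a standalone function. corrected_cost is a hypothetical name; marking unlabeled samples with -1 and merging the correction into the cost with an element-wise maximum are assumptions about the surrounding code, which the snippet elides. Instead of multiplying by limit_max, it uses np.where, which sidesteps 0 * inf = nan when limit_max is infinite, so no warning suppression is needed.

```python
import numpy as np


def corrected_cost(cost, ys, yt, limit_max=np.inf):
    """Sketch of the proposed rule: penalize (i, j) only when both labels are
    known and they differ.

    Assumptions (not taken from the quoted snippet): unlabeled samples are
    marked with -1, and the correction is merged into the cost with an
    element-wise maximum.
    """
    missing_labels = (ys[:, None] == -1) | (yt[None, :] == -1)
    label_mismatch = ys[:, None] != yt[None, :]  # same cells as `(ys - yt) != 0`
    penalize = label_mismatch & ~missing_labels
    # np.where avoids 0 * inf = nan; cells that are not penalized become -inf,
    # which is a no-op under np.maximum.
    cost_correction = np.where(penalize, limit_max, -np.inf)
    return np.maximum(cost, cost_correction)


cost = np.ones((3, 3))
ys = np.array([0, 1, -1])  # source 2 is unlabeled
yt = np.array([0, 1, 1])
print(corrected_cost(cost, ys, yt))
```

With this input, row 0 is penalized against both targets labeled 1, row 1 against the target labeled 0, and row 2 (the unlabeled source) is left untouched.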

Happy to send the corresponding PR if you agree.

Screenshots

The following screenshots show the effect of flipping the missing_labels factor. Here we map samples between multiple Gaussian distributions with 2 labels (p = 1 and p = 2), and all labels are given. Without the fix, the transport plans ignore the labels: since missing_labels is False everywhere, no cost correction is applied. With the fix, only samples with the same class label are linked.

[Screenshots: transport plans without and with the fix]
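A toy version of this experiment can be sketched with SciPy's linear_sum_assignment standing in for EMD: with equal numbers of source and target samples and uniform weights, an optimal transport plan can be taken to be a permutation, so the assignment solver returns an optimal plan. The 1-D data, the finite penalty 1e6, the max-merge step and the helper name correction are all assumptions for illustration, not the setup behind the screenshots.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def correction(ys, yt, limit_max, fixed):
    """Hypothetical helper: current (fixed=False) or proposed (fixed=True)
    correction for a finite limit_max.

    Assumption: unlabeled samples would be marked with -1 (none are here).
    """
    missing_labels = (ys[:, None] == -1) | (yt[None, :] == -1)
    label_mismatch = (ys[:, None] - yt[None, :]) != 0
    factor = ~missing_labels if fixed else missing_labels  # the fix flips this
    return label_mismatch * factor * limit_max


# 1-D toy data: each source class sits next to the *other* target class.
xs = np.array([0.0, 0.5, 10.0, 10.5])
ys = np.array([0, 0, 1, 1])
xt = np.array([10.2, 10.7, 0.2, 0.7])
yt = np.array([0, 0, 1, 1])
cost = (xs[:, None] - xt[None, :]) ** 2  # squared Euclidean cost

results = {}
for fixed in (False, True):
    c = np.maximum(cost, correction(ys, yt, limit_max=1e6, fixed=fixed))
    rows, cols = linear_sum_assignment(c)
    results[fixed] = np.mean(ys[rows] == yt[cols])
    print(f"fixed={fixed}: fraction of same-class links = {results[fixed]}")
```

Without the fix, every source is matched to its nearest target, which here always belongs to the other class; with the fix, cross-class pairs are priced out and every link stays within a class.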

Environment:

Linux-4.18.0-372.75.1.el8_6.x86_64-x86_64-with-glibc2.28
Python 3.12.4 | packaged by conda-forge | (main, Jun 17 2024, 10:23:07) [GCC 12.3.0]
NumPy 2.0.0
SciPy 1.14.0
POT 0.9.4 (pip installed)

@rflamary
Collaborator

Hello @martinrohbeck, and thanks for the issue. I think this is indeed a bug, especially the missing_labels weight. I am curious to hear the input of @kachayev, who wrote the original code if I remember correctly, but I think we would welcome a PR with your two fixes.

@kachayev
Collaborator

Hi @martinrohbeck,

Thanks for the report! Your suggestion sounds reasonable to me. There are a couple of test cases in the test suite designed to verify that the vectorized version of the algorithm produces the same results as the previous version of the code. If you find that these tests don’t fail while working on the PR, it would indicate that the discrepancy was introduced during the vectorization process. Otherwise, it would be worth revisiting the logic in the older code. Either way, I’ll be glad to review the PR.
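Along the lines of the tests mentioned above, a vectorized expression can be cross-checked against a plain loop that spells out the intended rule. This is only a sketch: correction_loop, correction_vectorized and correction_current are hypothetical helpers, not the tests or the pre-vectorization code in POT, and unlabeled samples are again assumed to be marked with -1.

```python
import numpy as np


def correction_loop(ys, yt, limit_max):
    """Hypothetical reference: penalize (i, j) only when both labels are known
    (not -1) and differ; every other cell gets 0 (no correction)."""
    out = np.zeros((len(ys), len(yt)))
    for i, a in enumerate(ys):
        for j, b in enumerate(yt):
            if a != -1 and b != -1 and a != b:
                out[i, j] = limit_max
    return out


def correction_vectorized(ys, yt, limit_max):
    """The proposed vectorized expression (finite limit_max)."""
    missing_labels = (ys[:, None] == -1) | (yt[None, :] == -1)
    label_mismatch = (ys[:, None] - yt[None, :]) != 0
    return label_mismatch * (1 - missing_labels) * limit_max


def correction_current(ys, yt, limit_max):
    """The expression currently in the code (finite limit_max)."""
    missing_labels = (ys[:, None] == -1) | (yt[None, :] == -1)
    label_match = (ys[:, None] - yt[None, :]) != 0
    return label_match * missing_labels * limit_max


rng = np.random.default_rng(42)
for _ in range(100):
    # Random label vectors with values in {-1, 0, 1, 2}; -1 means unlabeled.
    ys = rng.integers(-1, 3, size=rng.integers(1, 8))
    yt = rng.integers(-1, 3, size=rng.integers(1, 8))
    assert np.allclose(correction_loop(ys, yt, 5.0), correction_vectorized(ys, yt, 5.0))

# The current expression disagrees with the reference as soon as labels differ.
ys, yt = np.array([0, 1]), np.array([1, 0])
assert not np.allclose(correction_loop(ys, yt, 5.0), correction_current(ys, yt, 5.0))
print("proposed fix matches the loop reference")
```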
