It seems that the cost_correction matrix is computed incorrectly. This is the current code that can be found here:
    # labels_match is a (ns, nt) matrix of {True, False} such that
    # the cells (i, j) has False if ys[i] != yt[i]
    label_match = (ys[:, None] - yt[None, :]) != 0
    # cost correction is a (ns, nt) matrix of {-Inf, float, Inf} such
    # that he cells (i, j) has -Inf where there's no correction necessary
    # by 'correction' we mean setting cost to a large value when
    # labels do not match
    # ...
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', category=RuntimeWarning)
        cost_correction = label_match * missing_labels * self.limit_max
The issues:
First, the comment says that label_match is False if ys[i] != yt[i].
However, if the labels differ, then (ys[:, None] - yt[None, :]) != 0 evaluates to True, so label_match is True exactly where the labels do not match. The name is therefore misleading. Either
- the variable should be renamed to label_mismatch and the comment fixed, or
- the code should check for equality, label_match = (ys[:, None] - yt[None, :]) == 0, and flip the value in cost_correction, i.e. cost_correction = (1 - label_match) * ...
Second, cost_correction = label_match * missing_labels * self.limit_max applies a cost correction only where missing_labels is True. However, no correction must be applied where missing_labels is True, so that factor needs to be flipped to ... * (1 - missing_labels) * ...
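The two points above can be checked on a toy input. The following is a minimal NumPy sketch, not POT code. It assumes that a missing label is encoded as -1 and that missing_labels[i, j] is 1 when either ys[i] or yt[j] is missing; both are assumptions made for illustration. A finite limit_max stands in for self.limit_max so that no 0 * Inf products occur.

```python
import numpy as np

ys = np.array([1, 1, 2])   # source labels, all known
yt = np.array([1, 2])      # target labels, all known
limit_max = 10.0           # finite stand-in for self.limit_max

# Expression from the current code: True where the labels DIFFER.
label_match = (ys[:, None] - yt[None, :]) != 0

# Assumed definition: 1 where either label is missing (encoded as -1).
missing_labels = ((ys[:, None] == -1) | (yt[None, :] == -1)).astype(int)

# Current formula: every label is known, so the product is zero everywhere
# and mismatching pairs receive no penalty.
current_correction = label_match * missing_labels * limit_max

print(label_match)
print(current_correction)
```

Under these assumptions, label_match is True for the pairs with different labels, and the current formula yields a zero correction even though every label is known.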
Therefore, I'd propose the following change:
    # label_mismatch is a (ns, nt) matrix of {True, False} such that
    # the cells (i, j) has True if ys[i] != yt[i]
    label_mismatch = (ys[:, None] - yt[None, :]) != 0
    # cost correction is a (ns, nt) matrix of {-Inf, float, Inf} such
    # that he cells (i, j) has -Inf where there's no correction necessary
    # by 'correction' we mean setting cost to a large value when
    # labels do not match
    # ...
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', category=RuntimeWarning)
        cost_correction = label_mismatch * (1 - missing_labels) * self.limit_max
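To illustrate the intended effect of the proposed formula, here is a small standalone sketch. The helper proposed_cost_correction, the -1 encoding of missing labels, and the "either label is missing" definition of missing_labels are hypothetical choices for this example, not part of POT. It is limited to a finite limit_max, and how the correction is then combined with the cost matrix is outside its scope.

```python
import numpy as np

def proposed_cost_correction(ys, yt, limit_max, missing=-1):
    """Hypothetical helper mirroring the proposed formula (finite limit_max only)."""
    # True where the source and target labels differ.
    label_mismatch = (ys[:, None] - yt[None, :]) != 0
    # Assumed definition: 1 where either label is missing.
    missing_labels = ((ys[:, None] == missing) | (yt[None, :] == missing)).astype(int)
    # Penalize only pairs whose labels are both known and different.
    return label_mismatch * (1 - missing_labels) * limit_max

ys = np.array([1, 2, 2])
yt = np.array([1, 2, -1])   # the last target sample is unlabeled
correction = proposed_cost_correction(ys, yt, limit_max=10.0)
print(correction)
```

In this example, the pairs with known, differing labels receive the penalty, while the column of the unlabeled target sample stays at zero.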
Happy to send the corresponding PR if you agree.
Screenshots
The following screenshots show the effect of flipping the missing_labels value. Here we map samples across multiple Gaussian distributions with 2 labels (p = 1 and p = 2). All labels are given. Without the fix, the transport plans are not computed correctly. With the fix, only samples from the same target class are linked.
Environment (please complete the following information):
Linux-4.18.0-372.75.1.el8_6.x86_64-x86_64-with-glibc2.28
Python 3.12.4 | packaged by conda-forge | (main, Jun 17 2024, 10:23:07) [GCC 12.3.0]
NumPy 2.0.0
SciPy 1.14.0
POT 0.9.4 (pip installed)
Hello @martinrohbeck, and thanks for the issue. I think this is indeed a bug, especially the missing_labels weight. I am curious about the input of @kachayev, who wrote the original code if I remember correctly, but I think we would welcome a PR with your two fixes.
Thanks for the report! Your suggestion sounds reasonable to me. There are a couple of test cases in the test suite designed to verify that the vectorized version of the algorithm produces the same results as the previous version of the code. If you find that these tests don’t fail while working on the PR, it would indicate that the discrepancy was introduced during the vectorization process. Otherwise, it would be worth revisiting the logic in the older code. Either way, I’ll be glad to review the PR.
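As a rough illustration of that kind of equivalence check, the sketch below compares a plain double loop with the vectorized formula proposed in the issue. It is neither the actual POT test suite nor the pre-vectorization code. Both functions, the -1 encoding of missing labels, and the finite limit_max are assumptions made for this example.

```python
import numpy as np

def correction_loop(ys, yt, limit_max, missing=-1):
    """Hypothetical loop reference: penalize known, differing label pairs."""
    out = np.zeros((len(ys), len(yt)))
    for i, s in enumerate(ys):
        for j, t in enumerate(yt):
            if s != missing and t != missing and s != t:
                out[i, j] = limit_max
    return out

def correction_vectorized(ys, yt, limit_max, missing=-1):
    """Vectorized form following the formula proposed in the issue."""
    label_mismatch = (ys[:, None] - yt[None, :]) != 0
    missing_labels = ((ys[:, None] == missing) | (yt[None, :] == missing)).astype(int)
    return label_mismatch * (1 - missing_labels) * limit_max

# Random labels in {-1, 0, 1, 2}, where -1 marks a missing label.
rng = np.random.default_rng(0)
ys = rng.integers(-1, 3, size=20)
yt = rng.integers(-1, 3, size=15)

loop_result = correction_loop(ys, yt, 10.0)
vectorized_result = correction_vectorized(ys, yt, 10.0)
assert np.array_equal(loop_result, vectorized_result)
```

A check of this shape only shows that the two implementations agree with each other; whether either matches the intended behavior still has to be judged against the algorithm's definition.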