Hotfix: numerical stability of non-log-stabilized sinkhorn plan #531
base: main
Conversation
* add log_gamma diagnostic
* add missing export for log_gamma
* add missing export for gamma_null_distribution, gamma_discrepancy
* fix broken unit tests
* rename log_gamma module to sbc
* add test_log_gamma unit test
* add return information to log_gamma doc string
* fix typo in docstring, use fixed-length np array to collect log_gammas instead of appending to an empty list
…525)

* standardization: add test for multi-input values (failing)

  This test reveals two bugs in the standardization layer:
  - count is updated multiple times
  - batch_count is too small, as the sizes from reduce_axes have to be multiplied

* breaking: fix bugs regarding count in standardization layer

  Fixes #524

  This fixes the two bugs described in c4cc133:
  - count was accidentally updated multiple times, leading to wrong values
  - count was calculated wrongly, as only the batch size was used. Correct is the product of all reduce dimensions. This led to wrong standard deviations.

  While the batch dimension is the same for all inputs, the size of the second dimension might vary. For this reason, we need to introduce an input-specific `count` variable. This breaks serialization.

* fix assert statement in test
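The count fix described above amounts to incrementing the running count by the product of the sizes of all reduced axes, not just the batch size. A minimal standalone sketch of that idea (hypothetical helper, not the actual layer code):

```python
import math

import numpy as np


def count_increment(batch: np.ndarray, reduce_axes=(0, 1)) -> int:
    """How many scalar observations a batch contributes to running moment
    statistics: the product of all reduced axis sizes, not only the batch
    size (axis 0)."""
    return math.prod(batch.shape[axis] for axis in reduce_axes)
```

For a batch of shape `(4, 3, 2)` reduced over the first two axes, the count grows by 12 per batch rather than by 4.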
Pull Request Overview
This hotfix ensures numerical stability for the non-log-stabilized Sinkhorn implementation and aligns tests to cover both the new and existing normalization methods.
- Updated the Sinkhorn plan initialization and normalization logic for better stability.
- Changed `max_steps` default to `None` to run until convergence.
- Extended existing tests to parameterize over the `"log_sinkhorn"` and `"sinkhorn"` methods.
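For context, the sum-based normalization mentioned above can be sketched as follows. This is a minimal standalone illustration with hypothetical names, not the BayesFlow implementation: rows and columns of the plan are alternately divided by their sums until the matrix is (approximately) doubly stochastic.

```python
import numpy as np


def sinkhorn_plan_sketch(cost: np.ndarray, regularization: float = 1.0, steps: int = 200) -> np.ndarray:
    """Minimal Sinkhorn-Knopp sketch: turn a cost matrix into a transport
    plan by alternately normalizing rows and columns by their sums."""
    plan = np.exp(-cost / regularization)
    for _ in range(steps):
        plan = plan / plan.sum(axis=1, keepdims=True)  # normalize rows
        plan = plan / plan.sum(axis=0, keepdims=True)  # normalize columns
    return plan
```

Dividing by sums avoids the exponentials of a softmax-style renormalization inside the loop, which is where the non-log-stabilized variant previously ran into trouble.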
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| tests/test_utils/test_optimal_transport.py | Parameterized two tests over both Sinkhorn variants and removed a skip decorator |
| bayesflow/utils/optimal_transport/sinkhorn.py | Added numerical-stability steps, switched normalization from softmax to sum-based, and updated max_steps default |
Comments suppressed due to low confidence (1)
bayesflow/utils/optimal_transport/sinkhorn.py:45
The default value of `None` conflicts with the `int` annotation. Consider changing the signature to `max_steps: Optional[int] = None` and importing `Optional` from `typing` for clarity.

`max_steps: int = None,`
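The suggested annotation change could look like the following hypothetical sketch; the loop body is a stand-in illustrating the `None`-means-run-until-convergence semantics, not the actual sinkhorn.py code:

```python
from typing import Optional


def iterate(max_steps: Optional[int] = None) -> int:
    """With max_steps=None the loop runs until convergence;
    an integer caps the number of iterations."""
    steps = 0
    converged = False
    while not converged and (max_steps is None or steps < max_steps):
        steps += 1
        converged = steps >= 5  # stand-in convergence criterion
    return steps
```

`Optional[int]` documents that `None` is a legal value, which a bare `int` annotation does not.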
Codecov Report: All modified and coverable lines are covered by tests ✅
…w into fix-sinkhorn-plan
… the plan, instead of the plan directly
…ch that x2[assignments] matches x1
…_plan), log_sinkhorn_plan returns logits of the transport plan
I have added some additional commits. Most notable changes are:
@LarsKue could you please double check that the comment in this function definition is what you intended? (the "such that…" part)

```python
def sinkhorn(x1: Tensor, x2: Tensor, seed: int = None, **kwargs) -> (Tensor, Tensor):
    """
    Matches elements from x2 onto x1 using the Sinkhorn-Knopp algorithm.

    Sinkhorn-Knopp is an iterative algorithm that repeatedly normalizes the cost matrix into a
    transport plan, containing assignment probabilities.
    The permutation is then sampled randomly according to the transport plan.

    :param x1: Tensor of shape (n, ...)
        Samples from the first distribution.

    :param x2: Tensor of shape (m, ...)
        Samples from the second distribution.

    :param kwargs:
        Additional keyword arguments that are passed to :py:func:`sinkhorn_plan`.

    :param seed: Random seed to use for sampling indices.
        Default: None, which means the seed will be auto-determined for non-compiled contexts.

    :return: Tensor of shape (n,)
        Assignment indices for x2.
    """
    plan = sinkhorn_plan(x1, x2, **kwargs)

    # we sample from log(plan) to receive assignments of length n, corresponding to indices of x2
    # such that x2[assignments] matches x1
    assignments = keras.random.categorical(keras.ops.log(plan), num_samples=1, seed=seed)
    assignments = keras.ops.squeeze(assignments, axis=1)

    return assignments
```

From my point of view this is ready to be merged now.
This is an important hotfix, so sending it directly to main. Thanks @daniel-habermann for the report!
I also did some benchmarking and it seems the current implementation is optimal with respect to performance.
Edit: I also fixed an issue with the convergence check in log_sinkhorn_plan which takes its runtime down from ~4s to ~0.02s.
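An early-exit convergence check of the kind that produces such a speedup can be sketched like this. This is a hypothetical standalone illustration in log space (hand-rolled `logsumexp`, not the actual `log_sinkhorn_plan` code): the loop breaks as soon as the row log-marginals are already normalized within a tolerance, instead of always running all iterations.

```python
import numpy as np


def logsumexp(x: np.ndarray, axis: int) -> np.ndarray:
    """Numerically stable log-sum-exp along an axis, keeping dims."""
    m = np.max(x, axis=axis, keepdims=True)
    return m + np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True))


def log_sinkhorn_plan_sketch(cost: np.ndarray, regularization: float = 1.0,
                             max_steps: int = 1000, atol: float = 1e-6) -> np.ndarray:
    """Log-space Sinkhorn sketch returning logits of the transport plan,
    with an early exit once the marginals stop changing."""
    log_plan = -cost / regularization
    for _ in range(max_steps):
        log_plan = log_plan - logsumexp(log_plan, axis=1)  # normalize rows
        log_plan = log_plan - logsumexp(log_plan, axis=0)  # normalize columns
        # converged: column normalization left the rows (almost) normalized
        if np.allclose(logsumexp(log_plan, axis=1), 0.0, atol=atol):
            break
    return log_plan
```

Because each iteration already computes the marginals, checking them against zero (in log space) costs almost nothing, while skipping the remaining iterations saves the bulk of the runtime.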