
Hotfix: numerical stability of non-log-stabilized Sinkhorn plan #531


Open
wants to merge 21 commits into main
Conversation

LarsKue
Contributor

@LarsKue LarsKue commented Jul 4, 2025

This is an important hotfix, so sending it directly to main. Thanks @daniel-habermann for the report!

I also did some benchmarking and it seems the current implementation is optimal with respect to performance.

Edit: I also fixed an issue with the convergence check in log_sinkhorn_plan which takes its runtime down from ~4s to ~0.02s.
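For readers following along, an early-exit convergence check of the kind described can be sketched in plain NumPy. This is an illustrative sketch, not the bayesflow implementation; the function names, uniform marginals, and tolerances here are assumptions:

```python
import numpy as np

def logsumexp(a, axis):
    # numerically stable log-sum-exp along one axis
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))

def log_sinkhorn_plan(cost, regularization=1.0, max_steps=1000, atol=1e-6):
    # Log-domain Sinkhorn with uniform marginals. The key point is the
    # convergence check: stop as soon as the log-potentials stabilize,
    # instead of always running the full max_steps iterations.
    n, m = cost.shape
    log_kernel = -cost / regularization
    f = np.zeros(n)            # log scaling for rows
    g = np.zeros(m)            # log scaling for columns
    log_mu = -np.log(n)        # uniform target marginals, in log-space
    log_nu = -np.log(m)
    for _ in range(max_steps):
        f_new = log_mu - logsumexp(log_kernel + g[None, :], axis=1)
        g_new = log_nu - logsumexp(log_kernel + f_new[:, None], axis=0)
        converged = (np.max(np.abs(f_new - f)) < atol
                     and np.max(np.abs(g_new - g)) < atol)
        f, g = f_new, g_new
        if converged:
            break
    return log_kernel + f[:, None] + g[None, :]
```

With a check like this, well-conditioned problems terminate after a handful of iterations, which is consistent with a runtime drop of the reported magnitude.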

vpratz and others added 8 commits June 22, 2025 04:02
* add log_gamma diagnostic

* add missing export for log_gamma

* add missing export for gamma_null_distribution, gamma_discrepancy

* fix broken unit tests

* rename log_gamma module to sbc

* add test_log_gamma unit test

* add return information to log_gamma doc string

* fix typo in docstring, use fixed-length np array to collect log_gammas instead of appending to an empty list
…525)

* standardization: add test for multi-input values (failing)

This test reveals two bugs in the standardization layer:

- count is updated multiple times
- batch_count is too small, as the sizes from reduce_axes have to be
  multiplied

* breaking: fix bugs regarding count in standardization layer

Fixes #524

This fixes the two bugs described in c4cc133:

- count was accidentally updated, leading to wrong values
- count was calculated wrongly, as only the batch size was used. The
  correct value is the product of all reduce dimensions. This led to
  wrong standard deviations

While the batch dimension is the same for all inputs, the size of the
second dimension might vary. For this reason, we need to introduce an
input-specific `count` variable. This breaks serialization.

* fix assert statement in test
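As a rough illustration of the fixed accounting, the following NumPy sketch updates running moments where `count` grows by the product of all reduced dimensions, not just the batch size. It uses Chan et al.'s parallel-moments update; the function name and state layout are assumptions, not the actual layer:

```python
import numpy as np

def update_moments(state, batch, reduce_axes=(0, 1)):
    # state = (count, mean, m2), where m2 is the running sum of squared
    # deviations. The batch contributes prod(sizes of reduce_axes)
    # observations, not just batch.shape[0].
    batch_count = int(np.prod([batch.shape[a] for a in reduce_axes]))
    batch_mean = batch.mean(axis=reduce_axes)
    batch_var = batch.var(axis=reduce_axes)
    count, mean, m2 = state
    total = count + batch_count
    delta = batch_mean - mean
    mean = mean + delta * batch_count / total
    m2 = m2 + batch_var * batch_count + delta**2 * count * batch_count / total
    return total, mean, m2
```

Using only `batch.shape[0]` for `batch_count` here would shrink the effective sample size and inflate the resulting standard deviations, which matches the symptom described above.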
@LarsKue LarsKue requested a review from Copilot July 4, 2025 09:24
@LarsKue LarsKue self-assigned this Jul 4, 2025
@LarsKue LarsKue added the fix Pull request that fixes a bug label Jul 4, 2025
Contributor

@Copilot Copilot AI left a comment


Pull Request Overview

This hotfix ensures numerical stability for the non-log-stabilized Sinkhorn implementation and aligns tests to cover both the new and existing normalization methods.

  • Updated the Sinkhorn plan initialization and normalization logic for better stability.
  • Changed max_steps default to None to run until convergence.
  • Extended existing tests to parameterize over "log_sinkhorn" and "sinkhorn" methods.
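The first two bullets can be sketched together in plain NumPy: sum-based row/column normalization and a `max_steps` of None that iterates until convergence. This is a sketch of the general technique under assumed names and defaults, not the bayesflow code:

```python
import numpy as np

def sinkhorn_plan(cost, regularization=1.0, max_steps=None, atol=1e-6):
    # Non-log Sinkhorn: alternately rescale rows and columns by their
    # sums (rather than applying a softmax). max_steps=None means the
    # loop only stops when the plan has converged.
    plan = np.exp(-cost / regularization)
    steps = 0
    while max_steps is None or steps < max_steps:
        prev = plan
        plan = plan / plan.sum(axis=1, keepdims=True)  # normalize rows
        plan = plan / plan.sum(axis=0, keepdims=True)  # normalize columns
        steps += 1
        if np.max(np.abs(plan - prev)) < atol:
            break
    return plan
```

For a square, strictly positive kernel this alternation converges to a doubly stochastic matrix (Sinkhorn-Knopp), which is why an unconditional iteration count is unnecessary once a convergence check is in place.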

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

  • tests/test_utils/test_optimal_transport.py: Parameterized two tests over both Sinkhorn variants and removed a skip decorator.
  • bayesflow/utils/optimal_transport/sinkhorn.py: Added numerical-stability steps, switched normalization from softmax to sum-based, and updated the max_steps default.
Comments suppressed due to low confidence (1)

bayesflow/utils/optimal_transport/sinkhorn.py:45

  • The default value of None conflicts with the int annotation. Consider changing the signature to max_steps: Optional[int] = None and importing Optional from typing for clarity.
    max_steps: int = None,
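The suggested signature change would look like the following stub; the surrounding parameter names are illustrative, not the actual function:

```python
from typing import Optional

def sinkhorn_plan(
    x1,
    x2,
    regularization: float = 1.0,
    max_steps: Optional[int] = None,  # None: iterate until convergence
    atol: float = 1e-6,
):
    ...
```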


codecov bot commented Jul 4, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Files with missing lines Coverage Δ
bayesflow/diagnostics/metrics/__init__.py 100.00% <100.00%> (ø)
...sflow/diagnostics/metrics/calibration_log_gamma.py 100.00% <100.00%> (ø)
bayesflow/distributions/diagonal_normal.py 95.34% <100.00%> (-0.11%) ⬇️
bayesflow/distributions/diagonal_student_t.py 95.91% <100.00%> (-0.09%) ⬇️
bayesflow/distributions/mixture.py 98.03% <ø> (ø)
...esflow/networks/standardization/standardization.py 95.45% <100.00%> (ø)
bayesflow/scores/multivariate_normal_score.py 97.72% <100.00%> (ø)
bayesflow/utils/optimal_transport/log_sinkhorn.py 100.00% <100.00%> (ø)
bayesflow/utils/optimal_transport/sinkhorn.py 100.00% <100.00%> (ø)

... and 2 files with indirect coverage changes

@LarsKue LarsKue requested a review from stefanradev93 July 4, 2025 10:14
@daniel-habermann
Contributor

daniel-habermann commented Jul 5, 2025

I have added some additional commits. Most notable changes are:

  • sinkhorn_plan and log_sinkhorn_plan now return proper transport plans, that is, the row and column sums match the marginals.
  • the assignments in sinkhorn and log_sinkhorn are now calculated as assignments = keras.ops.categorical(log_plan), because keras.ops.categorical expects log-probs instead of probabilities as inputs (which they confusingly call logits).
  • added some additional unit tests to ensure that the transport plans are correct.

@LarsKue could you please double check that the comment in this function definition is what you intended? (the "such that..." part)

def sinkhorn(x1: Tensor, x2: Tensor, seed: int = None, **kwargs) -> Tensor:
    """
    Matches elements from x2 onto x1 using the Sinkhorn-Knopp algorithm.

    Sinkhorn-Knopp is an iterative algorithm that repeatedly normalizes the cost matrix into a
    transport plan, containing assignment probabilities.
    The permutation is then sampled randomly according to the transport plan.

    :param x1: Tensor of shape (n, ...)
        Samples from the first distribution.

    :param x2: Tensor of shape (m, ...)
        Samples from the second distribution.

    :param seed: Random seed to use for sampling indices.
        Default: None, which means the seed will be auto-determined for non-compiled contexts.

    :param kwargs:
        Additional keyword arguments that are passed to :py:func:`sinkhorn_plan`.

    :return: Tensor of shape (n,)
        Assignment indices for x2.

    """
    plan = sinkhorn_plan(x1, x2, **kwargs)

    # we sample from log(plan) to receive assignments of length n, corresponding to indices of x2
    # such that x2[assignments] matches x1
    assignments = keras.random.categorical(keras.ops.log(plan), num_samples=1, seed=seed)
    assignments = keras.ops.squeeze(assignments, axis=1)

    return assignments
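For intuition on the log-probability point above: keras.random.categorical expects log-probabilities (which it calls logits), hence the keras.ops.log(plan) call. The equivalent step in plain NumPy, where np.random samples from probabilities directly, can be sketched as follows (a hypothetical helper, not part of bayesflow):

```python
import numpy as np

def sample_assignments(plan, rng):
    # Sample one index of x2 per row of x1 from the transport plan.
    # NumPy's choice takes probabilities, so no log transform is needed;
    # with keras.random.categorical the rows must be passed as log-probs.
    probs = plan / plan.sum(axis=1, keepdims=True)
    return np.array([rng.choice(plan.shape[1], p=row) for row in probs])
```

With a (near-)deterministic plan, each row of x1 maps to the column holding its probability mass, which is the behavior the unit tests above check.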

From my point of view this is ready to be merged now.

Labels: fix (Pull request that fixes a bug)
4 participants