Skip to content

[MRG] Add normalize parameter to sliced_wasserstein_distance#808

Merged
rflamary merged 8 commits into
PythonOT:masterfrom
Harguna:feature/normalize-sliced-wasserstein
May 29, 2026
Merged

[MRG] Add normalize parameter to sliced_wasserstein_distance#808
rflamary merged 8 commits into
PythonOT:masterfrom
Harguna:feature/normalize-sliced-wasserstein

Conversation

@Harguna

@Harguna Harguna commented Apr 29, 2026

Copy link
Copy Markdown
Contributor

Types of changes

  • New Feature

Motivation and context / Related issue

Addresses #807.

Sliced Wasserstein Distance is sensitive to feature scale: features with larger numerical ranges dominate the random projections, drowning out meaningful differences in smaller-scale features. Users often don't realize this is happening and, when they do, the manual fix (preprocessing inputs with a scaler) is verbose and easy to get wrong — fitting each distribution independently silently corrupts the distance.

This PR adds optional normalize and normalize_mode parameters to sliced_wasserstein_distance and max_sliced_wasserstein_distance to handle this cleanly inside the function. Default behavior (normalize=None) is unchanged, so the change is fully backward-compatible.

This is a [WIP] skeleton PR - it establishes the API surface, signatures, docstrings, and a helper function so the design can be reviewed before the full implementation lands. The actual normalization math, edge case handling, behavioral tests, and example script will follow in subsequent commits on this same branch.

How has this been tested (if it applies)

In this skeleton:

  • Existing test/test_sliced.py test suite continues to pass (verifies the new keyword parameters didn't break anything).
  • pre-commit run --all-files passes locally.

Tests related to the new feature will be added with the full implementation in the subsequent commits.

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.

@rflamary

Copy link
Copy Markdown
Collaborator

Thanks for this PR, we are a bit busy at te moement and will have more time to give some feedback after the neurips deadline in two weeks.

@Harguna

Harguna commented May 3, 2026

Copy link
Copy Markdown
Contributor Author

Sounds good, best of luck with the NeurIPS.

@rflamary

rflamary commented May 12, 2026

Copy link
Copy Markdown
Collaborator

Hello @Harguna , thanks for the PR.

We had a look with @clbonet and we are not really comfortable with having a normalization inside the sliced wasserstein function. While this might make sens in some applications it also means that for instance when optimizing the SWD, the loss between two optimization steps or minibtach is not comparable (since normalized locally) which poses a practical problem because it is an intuitive behavior and leads to different minimizers.

But we agree with you that normalization should be easier to handle. So we propose to handle it in a slightly different way as follows :

scaler = ot.utils.DataScaler(norm='standard').fit([X_s,X_t]) # can take a tensor or a list for joint normalization
swd = ot.sliced_wasserstein_distance(X_s, X_t, scaler=scaler)

this means that the normalization is fitted outside on a class (compatible with sklearn with a fit and transform function but that handles backends). The scaler parameter should also accept a function (detected with __call__ and can apply it so this would allow pytorch pre-processing pipeline or models). I thinks we need a helper function

def apply_scaler(X_s, X_t, scaler=None)

that handles the preprocessing of the data (or not if scaler=None) so that we can add this API to other functions in POT such as ot.solve_sample.

Would you be OK with implementing our suggestions?

@Harguna

Harguna commented May 13, 2026

Copy link
Copy Markdown
Contributor Author

Hello @rflamary,

Thanks for the detailed feedback, this makes sense. I had accounted for the relative shift between X_s and X_t between optimization steps for the same batch by fitting normalization statistics jointly on concat(X_s, X_t), but you're right that this doesn't address inter-batch variations, which would destabilize the objective during gradient-based training.

I agree with your suggested design which decouples the fitting step from the distance computation. I'm happy to implement DataScaler with backend compatibility and apply_scaler as a standalone helper so it can be reused across other POT functions. I will get started on that.

@Harguna

Harguna commented May 22, 2026

Copy link
Copy Markdown
Contributor Author

Hello @rflamary,

I've implemented the changes based on your feedback. The DataScaler class and apply_scaler helper have been added to ot/utils.py, and both sliced_wasserstein_distance and max_sliced_wasserstein_distance now accept a scaler parameter. Tests are included for all new functionality. Happy to update the documentation and example gallery once you've had a chance to review.

Thanks,
Harguna

@codecov

codecov Bot commented May 22, 2026

Copy link
Copy Markdown

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 96.78%. Comparing base (048f3ae) to head (297a14a).

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #808      +/-   ##
==========================================
+ Coverage   96.74%   96.78%   +0.03%     
==========================================
  Files         118      118              
  Lines       23466    23721     +255     
==========================================
+ Hits        22703    22958     +255     
  Misses        763      763              
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@rflamary rflamary left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello @Harguna thanks for the modifications.

We are nearly there I just need a little more consistence between transform and fit_transform so that they can both accept lists. Also we need test of the new classes across backends . See comments below.

Comment thread ot/utils.py
Comment thread ot/utils.py Outdated
Comment thread test/test_sliced.py
@Harguna

Harguna commented May 28, 2026

Copy link
Copy Markdown
Contributor Author

Hey @rflamary,

Thanks for the review and for catching these gaps. I've addressed your comments and made a few minor improvements. Happy to iterate further if anything needs changing.

Best,
Harguna

@rflamary rflamary changed the title WIP: Add normalize parameter to sliced_wasserstein_distance [MRG] Add normalize parameter to sliced_wasserstein_distance May 29, 2026
@rflamary rflamary merged commit 6e21613 into PythonOT:master May 29, 2026
18 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants