
Improve Bures-Wasserstein distance #468

Merged · 5 commits · May 4, 2023

Conversation

@francois-rozet (Contributor) commented May 2, 2023

Types of changes

  • Fixed typo in the documentation of the Bures-Wasserstein distance ($\Sigma_s$ instead of $\Sigma_s^{1/2}$).
  • Faster way of computing the trace of the square-root of the product of $\Sigma_s$ and $\Sigma_t$.

The implementation is based on two facts:

  1. The trace of $A$ equals the sum of its eigenvalues.
  2. The eigenvalues of $\sqrt{A}$ are the square-roots of the eigenvalues of $A$.

Then, $\mathrm{tr}(\sqrt{A})$ is the sum of the square-roots of the eigenvalues of $A$.
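As a sketch of the idea (hypothetical variable names, NumPy assumed, not the actual PR code), the two facts above give $\mathrm{tr}(\sqrt{\Sigma_s \Sigma_t})$ without ever forming a matrix square root, which can be checked against the explicit scipy.linalg.sqrtm computation:

```python
import numpy as np
import scipy.linalg as sl

rng = np.random.default_rng(0)

def random_spd(n):
    # random symmetric positive definite matrix
    A = rng.standard_normal((n, n)) / n ** 0.5
    return A @ A.T + 1e-6 * np.eye(n)

Sigma_s, Sigma_t = random_spd(64), random_spd(64)
M = Sigma_s @ Sigma_t  # not symmetric in general, but has real positive eigenvalues

# tr(sqrt(M)) = sum of the square roots of the eigenvalues of M
trace_sqrt = np.sqrt(np.maximum(np.linalg.eigvals(M).real, 0.0)).sum()

# reference: trace of the explicit matrix square root
trace_sqrt_ref = np.trace(sl.sqrtm(M)).real
```

The product of two SPD matrices is similar to an SPD matrix, so its eigenvalues are real and positive, which is what makes the shortcut valid here.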

See Lightning-AI/torchmetrics#1705.

Motivation and context / Related issue

Computing the square-root of a matrix is slow and unstable.

How has this been tested (if it applies)

The new implementation still passes the tests (at least with NumPy backend).

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.

@francois-rozet (Contributor, Author) commented May 2, 2023

It seems like the TensorFlow backend does not support .real on a complex tensor. I am not sure how to solve that. There are also issues with older torch versions.

@rflamary (Collaborator) commented May 3, 2023

Hello @francois-rozet, when a specific backend misses a feature, we usually add it to the backend. In this case, maybe we should add nx.real and nx.imag to the backend functions.

@rflamary (Collaborator) commented May 3, 2023

Also @francois-rozet, could you please do a quick timing test before and after your speedup for different backends (at least NumPy and PyTorch) and put it in the text description of the PR?

I like to have quantified performance gains in the history/GitHub so we know why we changed things. Also, maybe add a quick test that checks that the new function returns the same result as np.trace() up to numerical precision. It seems right, but such a test will help detect potential problems in the future.

@francois-rozet (Contributor, Author)

After some tests, I found that this implementation is only faster for NumPy. Computing the square root of a general matrix is indeed slower than computing its eigenvalues. However, computing the square root of a symmetric matrix takes more or less the same time as computing its eigenvalues. In fact, the PyTorch backend uses torch.linalg.eigh to implement sqrtm. So I think instead of modifying the algorithm, we can simply replace the sqrtm of the NumPy backend.
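For reference, a minimal sketch (assumed names, not the actual POT backend code) of how a symmetric-matrix square root can be built from eigh, mirroring what the comment above says the PyTorch backend does with torch.linalg.eigh:

```python
import numpy as np

def sqrtm_sym(C):
    """Square root of a symmetric positive semi-definite matrix via eigh.

    Sketch: C = V diag(w) V^T  =>  sqrt(C) = V diag(sqrt(w)) V^T.
    """
    w, V = np.linalg.eigh(C)
    return (V * np.sqrt(np.maximum(w, 0.0))) @ V.T

# sanity check against the defining property sqrt(C) @ sqrt(C) = C
rng = np.random.default_rng(0)
B = rng.standard_normal((32, 32))
C = B @ B.T
S = sqrtm_sym(C)
```

Clipping the eigenvalues at zero guards against tiny negative values introduced by floating-point round-off in nearly singular matrices.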

@rflamary (Collaborator) commented May 3, 2023

OK, it makes sense; happy I pushed you to investigate. You can leave the eigvals function in the backend, it can be useful in the future (trace norm regularization, for instance).

@francois-rozet (Contributor, Author) commented May 3, 2023

You can leave the eigvals function in the backend

Oops, I already removed it. Are the eigenvalues or the singular values necessary for trace norm regularization? And is it for a symmetric matrix? It turns out that eigh is much faster than eig for a symmetric matrix.
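A quick sanity check (hypothetical sizes, NumPy assumed) that the general and symmetric eigensolvers agree on the spectrum of a symmetric matrix, so swapping in the specialized solver is safe:

```python
import numpy as np

rng = np.random.default_rng(0)
B = rng.standard_normal((64, 64))
A = B @ B.T  # symmetric

w_sym = np.linalg.eigh(A)[0]               # ascending order, guaranteed real
w_gen = np.sort(np.linalg.eig(A)[0].real)  # general solver, sorted to compare

# the spectra match; eigh exploits symmetry and is typically much faster
```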

@rflamary (Collaborator) commented May 3, 2023

OK, no worries, we can add them back later (properly, depending on symmetry or not).

I'm nearly OK for a merge, but please add a short description of the PR to the RELEASES.md file.

@francois-rozet (Contributor, Author) commented May 3, 2023

I added a line to the RELEASES.md file. Also, here is a small notebook demo to show that np.linalg.eigh is much faster than scipy.linalg.sqrtm:

import numpy as np
import scipy.linalg as sl

A = np.random.rand(512, 512) / 512 ** 0.5
A = A @ A.T + np.eye(512) * 1e-6  # positive definite

%timeit np.linalg.eigh(A)
%timeit sl.sqrtm(A)

returns

19.5 ms ± 2.27 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
166 ms ± 15.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

The time gap increases for larger matrices.

@francois-rozet (Contributor, Author)

All tests have passed except the CircleCI one.

@rflamary merged commit 83dc498 into PythonOT:master on May 4, 2023
17 of 18 checks passed
@francois-rozet deleted the patch branch on May 4, 2023 at 06:33