Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] Add implicit Sinkhorn gradients #605

Merged
merged 7 commits into from
Feb 20, 2024
Merged

[MRG] Add implicit Sinkhorn gradients #605

merged 7 commits into from
Feb 20, 2024

Conversation

rflamary
Copy link
Collaborator

@rflamary rflamary commented Feb 19, 2024

Types of changes

This PR aims at

  • implementing the detach function in the backend to allow speedup on CPU/GPU in some solvers (which was already done in a previous PR but with limited doc).
  • Implement variants of Sinkhorn where computations are detached and gradients at convergence is returned instead

This PR should solve #565 and greatly limit memory for sinkhorn when computing gradienst wrt the value.

In order to use implicit diffeerntiation one needs to set the grad parameter in ot.solveand ot.solve_sampleas such

sol = ot.solve(M, a, b, reg=10, grad='implicit')
sol.value.backward()
# beware with  grad='implicit', sol.value_linear and sol.plan are not differentiable (not implemented yet).

On a simple example with pytorch arrays with required gradients, I has a 1000x gain in memory for solving the problem when a large number of sinkhorn operations are needed.

Motivation and context / Related issue

How has this been tested (if it applies)

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.

Copy link

codecov bot commented Feb 19, 2024

Codecov Report

Merging #605 (28fe869) into master (c84ef33) will increase coverage by 0.03%.
The diff coverage is 100.00%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #605      +/-   ##
==========================================
+ Coverage   96.75%   96.78%   +0.03%     
==========================================
  Files          77       77              
  Lines       15961    16002      +41     
==========================================
+ Hits        15443    15488      +45     
+ Misses        518      514       -4     

@rflamary rflamary changed the title [WIP] Add detach function to backend [WIP] Add implicit Sinkhorn gradients Feb 20, 2024
@rflamary rflamary changed the title [WIP] Add implicit Sinkhorn gradients [MRG] Add implicit Sinkhorn gradients Feb 20, 2024
Copy link
Collaborator

@cedricvincentcuaz cedricvincentcuaz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello Rémi ! Ready to merge ;)

@rflamary rflamary merged commit 6f35804 into master Feb 20, 2024
17 of 18 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants