CUDA out of memory when using ot.sinkhorn2 as a loss function #565

Closed
Perian-Yan opened this issue Nov 7, 2023 · 4 comments

@Perian-Yan

Hi, I'm trying to use the EMD/Sinkhorn distance as a loss function for 2D matrices.
However, ot.sinkhorn2 raises a CUDA out-of-memory error while it is being computed:

[screenshot: CUDA out-of-memory traceback]

ot.emd2 can also raise this error when I use a larger dataset.

Watching nvidia-smi while training on the same dataset, ot.emd2 uses about 7 GB of the 12 GB of GPU memory (which is fine), while ot.sinkhorn2 uses 11 GB of 12 GB and triggers the error above.
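
Roughly, my setup looks like the sketch below (the sizes, feature dimension, and reg value here are placeholders, not my actual ones):

```python
import torch
import ot

device = "cuda"
n, d = 5000, 128  # placeholder problem size

x = torch.randn(n, d, device=device, requires_grad=True)  # predictions
y = torch.randn(n, d, device=device)                      # targets

a = torch.full((n,), 1.0 / n, device=device)  # uniform source weights
b = torch.full((n,), 1.0 / n, device=device)  # uniform target weights

M = ot.dist(x, y)                      # (n, n) squared-Euclidean cost matrix
loss = ot.sinkhorn2(a, b, M, reg=1.0)  # entropic OT value used as the loss
loss.backward()                        # this is where memory blows up
```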

Thank you!

@rflamary
Collaborator

rflamary commented Nov 9, 2023

Hello,

This probably comes from the fact that torch needs to keep the intermediate values of the Sinkhorn iterations in memory to allow for a backward(). I have been planning for a while to add a detach_iterations parameter that would run the whole algorithm with detached iterations and plug in the gradient assuming convergence (implicit differentiation). I will get to it when I have more time.
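
As an illustration of the idea, here is a rough sketch (a simplified "last step" variant written for clarity, not POT's actual implementation):

```python
import torch

def sinkhorn_value_detached(a, b, M, reg, n_iter=100):
    # Run the Sinkhorn fixed-point loop without recording an autograd
    # graph, then do one differentiable recomputation at the (assumed)
    # converged point so the gradient is plugged in only at the end.
    K = torch.exp(-M / reg)  # Gibbs kernel, differentiable wrt M
    with torch.no_grad():
        u = torch.ones_like(a)
        for _ in range(n_iter):
            v = b / (K.T @ u)
            u = a / (K @ v)
    v = b / (K.T @ u)  # single differentiable step from the fixed point
    u = a / (K @ v)
    P = u[:, None] * K * v[None, :]  # transport plan
    return (P * M).sum()             # linear OT cost, differentiable wrt M
```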

Could you please check whether sinkhorn2 still explodes in memory when you give it tensors with requires_grad=False?

@rflamary
Collaborator

rflamary commented Feb 20, 2024

Hello, I have added an option to ot.solve that forces the use of implicit gradients and limits memory use:

You can compute the Sinkhorn loss with:

loss = ot.solve(M, a, b, reg=1, grad='implicit').value

All iterations are detached and the gradient is set at the end with no memory overhead, but the result is then differentiable only through value (not value_linear or the OT plan). Could you tell me if it solves your problem? It is merged in the master branch.
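
For instance, a minimal sketch of how this fits into a backward pass (shapes and the reg value are illustrative):

```python
import torch
import ot

n, d = 2000, 64
x = torch.randn(n, d, device="cuda", requires_grad=True)
y = torch.randn(n, d, device="cuda")
a = torch.full((n,), 1.0 / n, device="cuda")  # uniform weights
b = torch.full((n,), 1.0 / n, device="cuda")

M = ot.dist(x, y)  # cost matrix, differentiable wrt x
# iterations are detached internally; the gradient is attached at the end
loss = ot.solve(M, a, b, reg=1, grad='implicit').value
loss.backward()    # memory no longer grows with the number of iterations
```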


@cedricvincentcuaz
Collaborator

This new functionality is now available in ot.solve and ot.solve_sample thanks to PR #605.
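
For example, with ot.solve_sample the cost matrix is built internally from the samples (a sketch assuming default uniform weights; shapes are illustrative):

```python
import torch
import ot

xs = torch.randn(3000, 64, device="cuda", requires_grad=True)
xt = torch.randn(3000, 64, device="cuda")

res = ot.solve_sample(xs, xt, reg=1, grad='implicit')
res.value.backward()  # gradient reaches xs through the implicit scheme
```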
