Hi, I'm trying to use the EMD/Sinkhorn distance as the loss function for 2D matrices.
However, `ot.sinkhorn2` raises a CUDA out-of-memory error when it is computed.
`ot.emd2` can also raise this error when I use a larger dataset.
Watching `nvidia-smi` while training on the same dataset, `ot.emd2` uses about 7 GB of the 12 GB of GPU memory (which is fine), while `ot.sinkhorn2` uses 11 GB of 12 GB and triggers the error above.
Thank you!
This probably comes from the fact that torch needs to keep the intermediate values of the Sinkhorn iterations in memory to allow for a backward(). I have been planning for a while to add a detach_iterations parameter that would run the whole algorithm with detached iterations and plug in the gradient assuming convergence (implicit differentiation). I will get to that when I have more time.
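To make the memory issue concrete, here is a minimal numpy sketch of the Sinkhorn iterations (a hypothetical re-implementation for illustration, not POT's actual code). Each iteration produces fresh scaling vectors `u` and `v`; an autograd framework like torch would record every one of these intermediates for backward(), so memory grows with the number of iterations, while a plain forward pass like this one needs only the current pair.

```python
import numpy as np

def sinkhorn(a, b, M, reg, n_iters=2000, tol=1e-9):
    """Minimal Sinkhorn iterations (illustrative sketch, not the POT internals).

    a, b : source/target marginals (1D, summing to 1)
    M    : cost matrix, shape (len(a), len(b))
    reg  : entropic regularization strength
    """
    K = np.exp(-M / reg)                 # Gibbs kernel
    u = np.ones_like(a)
    for _ in range(n_iters):
        v = b / (K.T @ u)                # each update creates a new array;
        u_new = a / (K @ v)              # autograd would retain all of them
        if np.max(np.abs(u_new - u)) < tol:
            u = u_new
            break
        u = u_new
    P = u[:, None] * K * v[None, :]      # transport plan
    return P, np.sum(P * M)              # plan and linear OT cost <P, M>

rng = np.random.default_rng(0)
a = np.ones(4) / 4
b = np.ones(5) / 5
M = rng.random((4, 5))
P, cost = sinkhorn(a, b, M, reg=0.1)
```

After the `u` update the row marginals of `P` match `a` exactly, and the column marginals converge to `b` as the iterations proceed.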
Could you please check if sinkhorn2 explodes in memory when you give it arrays with keep_gradient=False?
Hello, I have added an option to ot.solve that forces the use of implicit gradients and limits the memory use:
you can compute the sinkhorn loss with
loss = ot.solve(M, a, b, reg=1, grad='implicit').value
All iterations are detached and the gradient is set at the end with no memory overhead, but then it is differentiable only w.r.t. value (not value_linear or the OT plan). Could you tell me if it solves your problem? It is merged in the master branch.
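The reason the gradient can be attached after the fact, with no iteration history, is the envelope theorem: the gradient of the entropic OT value with respect to the cost matrix is simply the optimal transport plan at convergence. The numpy sketch below (a hypothetical illustration, not the POT internals) checks this numerically against a finite difference.

```python
import numpy as np

def sinkhorn_value(a, b, M, reg, n_iters=5000):
    """Illustrative Sinkhorn solver returning the plan and the full
    entropic objective <P, M> + reg * sum(P * (log P - 1))."""
    K = np.exp(-M / reg)
    u = np.ones_like(a)
    for _ in range(n_iters):
        v = b / (K.T @ u)
        u = a / (K @ v)
    P = u[:, None] * K * v[None, :]
    return P, np.sum(P * M) + reg * np.sum(P * (np.log(P) - 1.0))

rng = np.random.default_rng(0)
a = np.ones(3) / 3
b = np.ones(4) / 4
M = rng.random((3, 4))
reg = 0.1

P, val = sinkhorn_value(a, b, M, reg)

# Finite-difference check: d(value)/dM[0, 0] should equal P[0, 0]
# (envelope theorem), so no iteration history is needed for the gradient.
eps = 1e-5
E = np.zeros_like(M)
E[0, 0] = eps
_, vp = sinkhorn_value(a, b, M + E, reg)
_, vm = sinkhorn_value(a, b, M - E, reg)
fd = (vp - vm) / (2 * eps)
```

This is why the implicit mode is differentiable only w.r.t. the value: the plan itself would require differentiating through the iterations (or a linear-system solve), which is exactly the memory cost being avoided.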