
CUDA out of memory when using ot.sinkhorn2 as a loss function #565

@Perian-Yan


Hi, I'm trying to implement the EMD/Sinkhorn distance as a loss function for 2D matrices.
However, ot.sinkhorn2 raises a CUDA out-of-memory error during the computation:

(Screenshot of the CUDA out-of-memory error message.)
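For reference, here is a minimal NumPy sketch of the kind of computation involved. It assumes each 2D matrix is nonnegative and is treated as a mass distribution over its pixel grid with a squared Euclidean ground cost (the report does not say which cost was used). The helpers `grid_cost_matrix`, `to_histogram` and `sinkhorn_cost` are hypothetical names, and the plain Sinkhorn loop is not POT's implementation; it only computes the same kind of quantity that ot.sinkhorn2 returns by default, the linear cost ⟨P, M⟩ of the entropic transport plan P. The thing to notice is that an H×W matrix becomes a histogram of length H·W, so the cost matrix M has (H·W)² entries.

```python
import numpy as np


def grid_cost_matrix(h, w):
    """Squared Euclidean distance between every pair of pixel positions on an h x w grid."""
    ys, xs = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    coords = np.stack([ys.ravel(), xs.ravel()], axis=1).astype(np.float64)  # (h*w, 2)
    diff = coords[:, None, :] - coords[None, :, :]  # (h*w, h*w, 2)
    return (diff ** 2).sum(axis=-1)  # (h*w, h*w): quadratic in the pixel count


def to_histogram(mat):
    """Flatten a nonnegative 2D matrix into a vector that sums to 1 (assumes a positive total)."""
    v = np.asarray(mat, dtype=np.float64).ravel()
    return v / v.sum()


def sinkhorn_cost(a, b, M, reg, n_iter=1000):
    """Plain Sinkhorn-Knopp scaling; returns (<P, M>, P) for the entropic plan P."""
    K = np.exp(-M / reg)  # Gibbs kernel, same shape as M
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)  # rescale rows toward marginal a
        v = b / (K.T @ u)  # rescale columns to marginal b
    P = u[:, None] * K * v[None, :]  # transport plan, same shape as M
    return float((P * M).sum()), P


rng = np.random.default_rng(0)
X, Y = rng.random((4, 4)), rng.random((4, 4))
a, b = to_histogram(X), to_histogram(Y)
M = grid_cost_matrix(4, 4)
cost, P = sinkhorn_cost(a, b, M, reg=1.0)
print(M.shape)  # (16, 16): one row and one column per pixel
print(round(cost, 4))
```

Even in this toy loop, K and P are full (H·W)×(H·W) arrays alongside M; when the inputs are CUDA tensors, dense matrices of this shape are allocated on the GPU for every sample.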

ot.emd2 can also raise this error when I use a larger dataset.

Watching nvidia-smi while training on the same dataset, I see ot.emd2 use about 7 GB of the 12 GB of GPU memory (which is fine), while ot.sinkhorn2 uses about 11 GB of 12 GB and triggers the error above.
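A rough way to see how memory can run out is to count dense (H·W)×(H·W) matrices. The sketch below is back-of-the-envelope arithmetic, not a measurement: the 64×64 size, the batch of 32 and the four live copies are hypothetical values, since the report gives no shapes. The `copies` parameter stands for however many such tensors are alive at once; one plausible (unconfirmed) reason for the sinkhorn2/emd2 gap is that an entropic solver differentiated through autograd keeps several of them (for example exp(-M/reg), the plan, and M·P) for the backward pass.

```python
def dense_matrix_bytes(h, w, dtype_bytes=4, batch=1, copies=1):
    """Bytes for `copies` dense (h*w) x (h*w) matrices per sample, over `batch` samples.

    dtype_bytes=4 corresponds to float32; use 8 for float64.
    """
    n = h * w  # histogram length after flattening an h x w matrix
    return n * n * dtype_bytes * batch * copies


MiB, GiB = 2 ** 20, 2 ** 30

# One float32 cost matrix for a single 64 x 64 input.
print(dense_matrix_bytes(64, 64) / MiB)  # 64.0
# Hypothetical: batch of 32, four live n x n tensors per sample.
print(dense_matrix_bytes(64, 64, batch=32, copies=4) / GiB)  # 8.0
```

Because each matrix grows with the fourth power of the side length, halving H and W shrinks it by a factor of 16.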

Thank you!
