
empirical_sinkhorn_divergence doesn't have a grad_fn #393

@gabrielsantosrv

Description


Describe the bug

I am trying to use empirical_sinkhorn_divergence as a loss function in PyTorch, but the returned tensor does not have a grad_fn, so gradients cannot be propagated.

Code sample

import ot
loss = ot.bregman.empirical_sinkhorn_divergence(source, target, 1)

Expected behavior

Return a tensor with a grad_fn.
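For context, a minimal PyTorch-only sketch of what "has a grad_fn" means: any differentiable torch operation on a tensor created with requires_grad=True produces a result carrying a grad_fn, which is what backward() needs. The squared-error loss below is a hypothetical stand-in for the Sinkhorn divergence, not POT's implementation:

```python
import torch

# Stand-in inputs, analogous to the source/target samples in the report
source = torch.randn(5, 2, requires_grad=True)
target = torch.randn(5, 2)

# A differentiable surrogate loss; a loss usable for training must,
# like this one, return a tensor whose grad_fn is not None
loss = ((source - target) ** 2).mean()
assert loss.grad_fn is not None

# With a grad_fn present, gradients propagate back to the inputs
loss.backward()
assert source.grad is not None
```

A tensor whose grad_fn is None (e.g. one built from intermediate NumPy conversions, which detach it from the autograd graph) would fail the first assertion, which matches the symptom reported here.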

Environment (please complete the following information):

  • OS (e.g. MacOS, Windows, Linux): Linux
  • Python version: 3.8.13
  • How was POT installed (source, pip, conda): pip
  • Build command you used (if compiling from source):
  • Only for GPU related bugs:
    • CUDA version: 11.2
    • GPU models and configuration: Quadro RTX 8000

Output of the following code snippet:

import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
Linux-5.4.0-73-generic-x86_64-with-glibc2.17
Python 3.8.13 (default, Mar 28 2022, 11:38:47) 
[GCC 7.5.0]
NumPy 1.21.6
SciPy 1.8.1
POT 0.8.2
