Closed
Description
Describe the bug
I am trying to use ot.bregman.empirical_sinkhorn_divergence as a loss function in PyTorch, but the returned tensor has no grad_fn, so gradients cannot be backpropagated through it.
Code sample
import ot
loss = ot.bregman.empirical_sinkhorn_divergence(source, target, 1)  # source, target: torch tensors
Expected behavior
Return a tensor with a grad_fn.
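As a possible stopgap while the POT call does not propagate gradients, the divergence can be sketched directly in PyTorch so that autograd tracks every step. This is an illustration only, not POT's implementation: sinkhorn_cost and sinkhorn_divergence are hypothetical helpers, and the uniform sample weights, squared Euclidean cost, and fixed iteration count are assumptions. The values it returns may differ from those of ot.bregman.empirical_sinkhorn_divergence.

```python
import math

import torch


def sinkhorn_cost(x, y, reg, n_iter=100):
    """Entropic OT transport cost between uniform empirical measures.

    Hypothetical helper (not part of POT). Runs log-domain Sinkhorn
    iterations with plain torch ops, so the result keeps a grad_fn.
    """
    n, m = x.shape[0], y.shape[0]
    # Squared Euclidean cost matrix, built without sqrt to keep gradients finite.
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(dim=-1)
    # Log of the uniform weights 1/n and 1/m (an assumption of this sketch).
    log_a = torch.full((n,), -math.log(n), dtype=x.dtype, device=x.device)
    log_b = torch.full((m,), -math.log(m), dtype=x.dtype, device=x.device)
    f = torch.zeros(n, dtype=x.dtype, device=x.device)
    g = torch.zeros(m, dtype=x.dtype, device=x.device)
    for _ in range(n_iter):
        # Alternating dual potential updates in the log domain.
        f = -reg * torch.logsumexp((g[None, :] - cost) / reg + log_b[None, :], dim=1)
        g = -reg * torch.logsumexp((f[:, None] - cost) / reg + log_a[:, None], dim=0)
    # Transport plan P_ij = a_i * b_j * exp((f_i + g_j - C_ij) / reg).
    log_plan = (f[:, None] + g[None, :] - cost) / reg + log_a[:, None] + log_b[None, :]
    return torch.sum(torch.exp(log_plan) * cost)


def sinkhorn_divergence(x, y, reg, n_iter=100):
    """Debiased cost: OT(x, y) - (OT(x, x) + OT(y, y)) / 2."""
    return sinkhorn_cost(x, y, reg, n_iter) - 0.5 * (
        sinkhorn_cost(x, x, reg, n_iter) + sinkhorn_cost(y, y, reg, n_iter)
    )


# Usage: the loss is attached to the autograd graph, so backward() works.
torch.manual_seed(0)
source = torch.randn(8, 2, requires_grad=True)
target = torch.randn(10, 2) + 1.0
loss = sinkhorn_divergence(source, target, reg=1.0)
loss.backward()
print(loss.grad_fn is not None, source.grad.shape)
```

Unrolling the iterations through autograd costs memory proportional to n_iter, so a smaller iteration count may be preferable for large point clouds.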
Environment (please complete the following information):
- OS (e.g. MacOS, Windows, Linux): Linux
- Python version: 3.8.13
- How was POT installed (source, pip, conda): pip
- Build command you used (if compiling from source):
- Only for GPU related bugs:
  - CUDA version: 11.2
  - GPU models and configuration: Quadro RTX 8000
Output of the following code snippet:
import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
Linux-5.4.0-73-generic-x86_64-with-glibc2.17
Python 3.8.13 (default, Mar 28 2022, 11:38:47)
[GCC 7.5.0]
NumPy 1.21.6
SciPy 1.8.1
POT 0.8.2