Description
Describe the bug
Backpropagating through emd2 (earth mover's distance) in POT does not take into account scaling factors applied to its output afterwards. The gradient of c * emd2(a, b, M) with respect to a and b should be c times the gradient of emd2(a, b, M), but the backward pass returns the same gradient for any value of c.
To Reproduce
Here is a minimal code sample reproducing the unexpected backpropagation behaviour.
Code sample
import torch
import ot

# Fix the seed
torch.manual_seed(0)

# Number of samples / dimension of the data
N, d = 5, 2

# Generate random dummy data
X = torch.randn(N, d).double()

# Distance matrix
M = ot.dist(X)

# Create random normalised coefficient vectors that require gradients
a = torch.abs(torch.randn(N, dtype=torch.float64))
b = torch.abs(torch.randn(N, dtype=torch.float64))
a = a / torch.sum(a)
b = b / torch.sum(b)
a.requires_grad = True
b.requires_grad = True

# Compute the earth mover's distance and backpropagate
emd = ot.emd2(a, b, M)
emd.backward()
print(a.grad, b.grad)
# tensor([ 4.0011, -4.1943,  0.0042,  1.6815, -1.1197], dtype=torch.float64)
# tensor([-4.0011,  4.1943, -0.0042, -1.6815,  1.1197], dtype=torch.float64)

# Now do all the same operations, but seek to maximise the loss
# instead of minimising it
a.grad, b.grad = None, None
emd = -ot.emd2(a, b, M)  # note the added minus sign
emd.backward()
# Only the sign of the gradients should change, but it does not
print(a.grad, b.grad)
# tensor([ 4.0011, -4.1943,  0.0042,  1.6815, -1.1197], dtype=torch.float64)
# tensor([-4.0011,  4.1943, -0.0042, -1.6815,  1.1197], dtype=torch.float64)

# The same in fact applies to any other constant
a.grad, b.grad = None, None
emd = 100 * ot.emd2(a, b, M)
emd.backward()
print(a.grad, b.grad)
# tensor([ 4.0011, -4.1943,  0.0042,  1.6815, -1.1197], dtype=torch.float64)
# tensor([-4.0011,  4.1943, -0.0042, -1.6815,  1.1197], dtype=torch.float64)
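The underlying issue is the chain rule: when a scalar loss is scaled by a constant c, autograd calls the backward pass with grad_output = c, and a custom backward must multiply its local gradient by that value. A minimal sketch of this in numpy, where local_grad is a hypothetical stand-in for the dual potentials that emd2 uses as its gradient:

```python
import numpy as np

# Minimal sketch of the chain rule a custom backward must respect.
# local_grad is a hypothetical stand-in for the gradient of the
# unscaled loss; the point is only the grad_output scaling.

def local_grad(a):
    # gradient of the unscaled loss w.r.t. a (stand-in)
    return 2.0 * a

def backward(a, grad_output=1.0):
    # Correct backward: multiply the local gradient by the upstream
    # gradient. Omitting this factor reproduces the reported behaviour,
    # where the gradient is identical for any scaling of the loss.
    return grad_output * local_grad(a)

a = np.array([1.0, -2.0, 3.0])
g = backward(a)
assert np.allclose(backward(a, grad_output=-1.0), -g)       # negated loss
assert np.allclose(backward(a, grad_output=100.0), 100 * g)  # scaled loss
```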
Expected behavior
If we multiply the result of the emd2 function by a constant c, the gradients of a and b should also be multiplied by c. In particular, negating the loss should flip the sign of the gradients, and multiplying it by 100 should scale them by 100.
Environment (please complete the following information):
- OS: Linux
- Python version: 3.8.8
- How was POT installed: pip
Output of the following code snippet:
import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
import torch; print("Torch", torch.__version__)
Linux-5.11.0-40-generic-x86_64-with-glibc2.10
Python 3.8.8 (default, Feb 24 2021, 21:46:12)
[GCC 7.3.0]
NumPy 1.20.2
SciPy 1.6.1
POT 0.8.0
Torch 1.8.1+cu102
Additional context
Although not shown in the example above, running on a GPU device with torch yields the same (incorrect) result.