Constants do not backpropagate through function emd2 in torch #309

@oshillou

Describe the bug

Backpropagating through POT's emd2 (earth mover's distance) function ignores any constant factor applied to its output afterwards.

The gradient of $a \times \text{emd}(...)$ is the same as that of $\text{emd}(...)$, for any constant $a$.
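
For comparison, this is what linearity of differentiation predicts. The constant is written $c$ here only to avoid confusion with the weight vector a used in the code sample:

$$\nabla_{a}\big(c \times \text{emd}(a, b, M)\big) = c \times \nabla_{a}\,\text{emd}(a, b, M), \qquad \nabla_{b}\big(c \times \text{emd}(a, b, M)\big) = c \times \nabla_{b}\,\text{emd}(a, b, M)$$

So with $c = -1$ every gradient entry should flip sign, and with $c = 100$ every entry should be 100 times larger.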

To Reproduce

Here is a minimal code sample that reproduces the unexpected backpropagation behaviour.

Code sample

import torch
import ot

# Fix the seed 
torch.manual_seed(0)

# Number of samples / dimension of data
N, d = 5, 2

# Generate random dummy data
X = torch.randn(N, d).double()

# Distance matrix
M = ot.dist(X)

# Create random coefficient vectors (normalised) that require gradients
a = torch.abs(torch.randn(N, dtype=torch.float64))
b = torch.abs(torch.randn(N, dtype=torch.float64))

a = a / torch.sum(a)
b = b / torch.sum(b)

a.requires_grad = True
b.requires_grad = True

# Compute Earth Mover distance
emd = ot.emd2(a, b, M)

# Backprop
emd.backward()

# Print gradients
print(a.grad,b.grad)

tensor([ 4.0011, -4.1943, 0.0042, 1.6815, -1.1197], dtype=torch.float64) tensor([-4.0011, 4.1943, -0.0042, -1.6815, 1.1197], dtype=torch.float64)

# Now, do all the same operations, but seek to maximise the loss instead of minimise

a.grad,b.grad=None,None

# Adding a - sign here
emd=-ot.emd2(a,b,M)

emd.backward()

# Print gradients: only the sign should change, but the values are identical
print(a.grad,b.grad)

tensor([ 4.0011, -4.1943, 0.0042, 1.6815, -1.1197], dtype=torch.float64) tensor([-4.0011, 4.1943, -0.0042, -1.6815, 1.1197], dtype=torch.float64)

# The same applies in fact to other constants

a.grad,b.grad=None,None

emd=100*ot.emd2(a,b,M)
# Backprop
emd.backward()

print(a.grad,b.grad)

tensor([ 4.0011, -4.1943, 0.0042, 1.6815, -1.1197], dtype=torch.float64) tensor([-4.0011, 4.1943, -0.0042, -1.6815, 1.1197], dtype=torch.float64)

Expected behavior

If we multiply the result of the emd2 function by a constant $a$, then the gradient after backpropagation should also be scaled by that constant.
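
The expected behaviour can be phrased as a small check. The sketch below is only an illustration: `grad_scales_with_constant` and `stand_in` are hypothetical names, not part of POT or torch, and a simple bilinear function stands in for emd2 so the snippet runs without POT. Passing `lambda a, b: ot.emd2(a, b, M)` instead is assumed to make the check fail on POT 0.8.0, given the output shown above.

```python
import torch


def grad_scales_with_constant(fn, a, b, c):
    """Return True if grad(c * fn(a, b)) equals c * grad(fn(a, b)) for a and b."""
    # Reference gradients of the unscaled value
    ga, gb = torch.autograd.grad(fn(a, b), (a, b))
    # Gradients of the scaled value
    ga_c, gb_c = torch.autograd.grad(c * fn(a, b), (a, b))
    return torch.allclose(ga_c, c * ga) and torch.allclose(gb_c, c * gb)


# Hypothetical stand-in for ot.emd2: any differentiable scalar function of (a, b)
torch.manual_seed(0)
M = torch.rand(5, 5, dtype=torch.float64)


def stand_in(a, b):
    return a @ M @ b


a = torch.rand(5, dtype=torch.float64, requires_grad=True)
b = torch.rand(5, dtype=torch.float64, requires_grad=True)

# A correctly differentiated function passes for any constant
print(grad_scales_with_constant(stand_in, a, b, -1.0))
print(grad_scales_with_constant(stand_in, a, b, 100.0))
```

A check of this shape could serve as a regression test once the scaling is handled correctly.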

Environment:

  • OS: Linux
  • Python version: 3.8.8
  • How was POT installed: pip

Output of the following code snippet:

import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
import torch; print("Torch", torch.__version__)

Linux-5.11.0-40-generic-x86_64-with-glibc2.10
Python 3.8.8 (default, Feb 24 2021, 21:46:12)
[GCC 7.3.0]
NumPy 1.20.2
SciPy 1.6.1
POT 0.8.0
Torch 1.8.1+cu102

Additional context

Although it is not shown in the example, running on a GPU device with torch yields the same result.
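
Since the result does not depend on the device, one plausible cause is the way gradients are attached to the autograd graph. The sketch below shows how this symptom can arise in general: a custom `torch.autograd.Function` that returns precomputed gradients without multiplying them by the incoming `grad_output`. Whether POT's torch backend does this is an assumption, not something verified here; `IgnoresUpstream`, `UsesUpstream` and `scaled_grad` are hypothetical names used only for illustration.

```python
import torch


class IgnoresUpstream(torch.autograd.Function):
    """Hypothetical: attaches a precomputed gradient but drops grad_output."""

    @staticmethod
    def forward(ctx, value, grad_x, x):
        ctx.save_for_backward(grad_x)
        return value.clone()

    @staticmethod
    def backward(ctx, grad_output):
        (grad_x,) = ctx.saved_tensors
        # Upstream factor is ignored, so later scaling is lost
        return None, None, grad_x


class UsesUpstream(torch.autograd.Function):
    """Hypothetical: same as above, but applies the chain rule."""

    @staticmethod
    def forward(ctx, value, grad_x, x):
        ctx.save_for_backward(grad_x)
        return value.clone()

    @staticmethod
    def backward(ctx, grad_output):
        (grad_x,) = ctx.saved_tensors
        # Multiply by the upstream gradient so constants propagate
        return None, None, grad_output * grad_x


def scaled_grad(fn_cls, c):
    """Gradient of c * f(x) for f(x) = sum(x**2), computed outside autograd."""
    x = torch.tensor([1.0, 2.0], dtype=torch.float64, requires_grad=True)
    value = (x.detach() ** 2).sum()
    grad_x = 2 * x.detach()
    (c * fn_cls.apply(value, grad_x, x)).backward()
    return x.grad


print(scaled_grad(IgnoresUpstream, 100.0))  # unscaled: [2, 4]
print(scaled_grad(UsesUpstream, 100.0))     # scaled: [200, 400]
```

The first variant reproduces the pattern reported above: the gradient is the same whatever constant multiplies the output.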
