[MRG] Add implicit Sinkhorn gradients #605

Merged
merged 7 commits on Feb 20, 2024
62 changes: 29 additions & 33 deletions ot/backend.py
@@ -281,6 +281,19 @@ def set_gradients(self, val, inputs, grads):
"""Define the gradients for the value val wrt the inputs """
raise NotImplementedError()

def detach(self, *arrays):
"""Detach the tensors from the computation graph

See: https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html"""
if len(arrays) == 1:
return self._detach(arrays[0])
else:
return [self._detach(array) for array in arrays]

def _detach(self, a):
"""Detach the tensor from the computation graph"""
raise NotImplementedError()

def zeros(self, shape, type_as=None):
r"""
Creates a tensor full of zeros.
@@ -1027,14 +1040,6 @@ def transpose(self, a, axes=None):
"""
raise NotImplementedError()

def detach(self, *args):
r"""
Detach tensors in arguments from the current graph.

See: https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
"""
raise NotImplementedError()

def matmul(self, a, b):
r"""
Matrix product of two arrays.
@@ -1082,6 +1087,10 @@ def set_gradients(self, val, inputs, grads):
# No gradients for numpy
return val

def _detach(self, a):
# No gradients for numpy
return a

def zeros(self, shape, type_as=None):
if type_as is None:
return np.zeros(shape)
@@ -1392,11 +1401,6 @@ def atan2(self, a, b):
def transpose(self, a, axes=None):
return np.transpose(a, axes)

def detach(self, *args):
if len(args) == 1:
return args[0]
return args

def matmul(self, a, b):
return np.matmul(a, b)

@@ -1462,6 +1466,9 @@ def set_gradients(self, val, inputs, grads):
val, = jax.tree_map(lambda z: z + aux, (val,))
return val

def _detach(self, a):
return jax.lax.stop_gradient(a)

def zeros(self, shape, type_as=None):
if type_as is None:
return jnp.zeros(shape)
@@ -1765,11 +1772,6 @@ def atan2(self, a, b):
def transpose(self, a, axes=None):
return jnp.transpose(a, axes)

def detach(self, *args):
if len(args) == 1:
return jax.lax.stop_gradient((args[0],))[0]
return [jax.lax.stop_gradient((a,))[0] for a in args]

def matmul(self, a, b):
return jnp.matmul(a, b)

@@ -1851,6 +1853,9 @@ def set_gradients(self, val, inputs, grads):

return res

def _detach(self, a):
return a.detach()

def zeros(self, shape, type_as=None):
if isinstance(shape, int):
shape = (shape,)
@@ -2256,11 +2261,6 @@ def transpose(self, a, axes=None):
axes = tuple(range(a.ndim)[::-1])
return a.permute(axes)

def detach(self, *args):
if len(args) == 1:
return args[0].detach()
return [a.detach() for a in args]

def matmul(self, a, b):
return torch.matmul(a, b)

@@ -2312,6 +2312,9 @@ def set_gradients(self, val, inputs, grads):
# No gradients for cupy
return val

def _detach(self, a):
return a

def zeros(self, shape, type_as=None):
if isinstance(shape, (list, tuple)):
shape = tuple(int(i) for i in shape)
@@ -2657,11 +2660,6 @@ def atan2(self, a, b):
def transpose(self, a, axes=None):
return cp.transpose(a, axes)

def detach(self, *args):
if len(args) == 1:
return args[0]
return args

def matmul(self, a, b):
return cp.matmul(a, b)

@@ -2729,6 +2727,9 @@ def grad(upstream):
return val, grad
return tmp(inputs)

def _detach(self, a):
return tf.stop_gradient(a)

def zeros(self, shape, type_as=None):
if type_as is None:
return tnp.zeros(shape)
@@ -3083,11 +3084,6 @@ def atan2(self, a, b):
def transpose(self, a, axes=None):
return tf.transpose(a, perm=axes)

def detach(self, *args):
if len(args) == 1:
return tf.stop_gradient(args[0])
return [tf.stop_gradient(a) for a in args]

def matmul(self, a, b):
return tnp.matmul(a, b)

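The refactor above keeps a single public `detach(*arrays)` on the base `Backend` class, which handles the one-vs-many fan-out, and leaves each backend with a one-tensor `_detach` hook (a no-op for NumPy and CuPy, `Tensor.detach` for torch, `stop_gradient` for jax and tensorflow). A minimal usage sketch, assuming the `get_backend` and `ot.unif` helpers used elsewhere in POT:

import numpy as np
import ot
from ot.backend import get_backend

M = np.random.rand(5, 4)   # cost matrix
a = ot.unif(5)             # uniform source weights
b = ot.unif(4)             # uniform target weights

nx = get_backend(M, a, b)

# A single input returns a single array; several inputs return a list.
M_d = nx.detach(M)
M_d, a_d, b_d = nx.detach(M, a, b)
# With the NumPy backend this is a no-op; autodiff backends cut the
# returned tensors out of the computation graph.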
48 changes: 45 additions & 3 deletions ot/solvers.py
@@ -29,7 +29,7 @@

def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
unbalanced_type='KL', method=None, n_threads=1, max_iter=None, plan_init=None,
potentials_init=None, tol=None, verbose=False):
potentials_init=None, tol=None, verbose=False, grad='autodiff'):
r"""Solve the discrete optimal transport problem and return :any:`OTResult` object

The function solves the following general optimal transport problem
@@ -79,6 +79,12 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
Tolerance for solution precision, by default None (default values in each solvers)
verbose : bool, optional
Print information in the solver, by default False
grad : str, optional
Type of gradient computation, either 'autodiff' or 'implicit'; used only for the
Sinkhorn solver. By default, 'autodiff' provides gradients wrt all
outputs (`plan, value, value_linear`) but at a significant memory cost.
'implicit' provides gradients only for `value`; the other outputs are
detached. This is useful to save memory when only the gradient of the value is needed.

Returns
-------
@@ -134,6 +140,16 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
# or for original Sinkhorn paper formulation [2]
res = ot.solve(M, a, b, reg=1.0, reg_type='entropy')

# Use implicit differentiation for memory saving
res = ot.solve(M, a, b, reg=1.0, grad='implicit') # M, a, b are torch tensors
res.value.backward() # only the value is differentiable

Note that by default the Sinkhorn solver uses automatic differentiation to
compute the gradients of the value and the plan. This can be changed with the
`grad` parameter: the `implicit` mode computes the gradient of the value
implicitly and detaches the other outputs, which saves memory when only the
gradient of the value is needed.

- **Quadratic regularized OT [17]** (when ``reg!=None`` and ``reg_type="L2"``):

.. math::
@@ -297,6 +313,10 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,

if reg_type.lower() in ['entropy', 'kl']:

if grad == 'implicit': # if implicit then detach the input
M0, a0, b0 = M, a, b
M, a, b = nx.detach(M, a, b)

# default values for sinkhorn
if max_iter is None:
max_iter = 1000
@@ -316,6 +336,11 @@

potentials = (log['log_u'], log['log_v'])

if grad == 'implicit': # set the gradient at convergence

value = nx.set_gradients(value, (M0, a0, b0),
(plan, reg * (potentials[0] - potentials[0].mean()), reg * (potentials[1] - potentials[1].mean())))

elif reg_type.lower() == 'l2':

if max_iter is None:
@@ -869,7 +894,8 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None,
def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL",
unbalanced=None,
unbalanced_type='KL', lazy=False, batch_size=None, method=None, n_threads=1, max_iter=None, plan_init=None, rank=100, scaling=0.95,
potentials_init=None, X_init=None, tol=None, verbose=False):
potentials_init=None, X_init=None, tol=None, verbose=False,
grad='autodiff'):
r"""Solve the discrete optimal transport problem using the samples in the source and target domains.

The function solves the following general optimal transport problem
@@ -935,6 +961,12 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
Tolerance for solution precision, by default None (default values in each solvers)
verbose : bool, optional
Print information in the solver, by default False
grad : str, optional
Type of gradient computation, either 'autodiff' or 'implicit'; used only for the
Sinkhorn solver. By default, 'autodiff' provides gradients wrt all
outputs (`plan, value, value_linear`) but at a significant memory cost.
'implicit' provides gradients only for `value`; the other outputs are
detached. This is useful to save memory when only the gradient of the value is needed.

Returns
-------
@@ -1002,6 +1034,16 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
# lazy OT plan
lazy_plan = res.lazy_plan

# Use implicit differentiation for memory saving
res = ot.solve_sample(xa, xb, a, b, reg=1.0, grad='implicit')
res.value.backward() # only the value is differentiable

Note that by default the Sinkhorn solver uses automatic differentiation to
compute the gradients of the value and the plan. This can be changed with the
`grad` parameter: the `implicit` mode computes the gradient of the value
implicitly and detaches the other outputs, which saves memory when only the
gradient of the value is needed.

We also have a very efficient solver with compiled CPU/CUDA code using
geomloss/PyKeOps that can be used with the following code:

@@ -1189,7 +1231,7 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
# compute cost matrix M and use solve function
M = dist(X_a, X_b, metric)

res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, method, n_threads, max_iter, plan_init, potentials_init, tol, verbose)
res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, method, n_threads, max_iter, plan_init, potentials_init, tol, verbose, grad)

return res

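With `grad='implicit'`, `solve` runs Sinkhorn on detached copies of `M`, `a`, `b` and, at convergence, attaches the gradients of the value to the original tensors via `nx.set_gradients`: the optimal plan for `M`, and the centered dual potentials scaled by `reg` for `a` and `b`. Only `value` is then differentiable, so the autodiff graph of the Sinkhorn iterations never has to be stored. A hedged usage sketch with torch tensors (the sample coordinates `xa`, `xb` below are illustrative, not taken from the PR):

import torch
import ot

xa = torch.randn(50, 2, requires_grad=True)
xb = torch.randn(40, 2, requires_grad=True)

# 'implicit': plan and value_linear are detached, only the value carries gradients
res = ot.solve_sample(xa, xb, reg=1.0, grad='implicit')
res.value.backward()

print(xa.grad.shape, xb.grad.shape)  # gradients reach xa, xb through the cost matrix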
17 changes: 13 additions & 4 deletions test/test_backend.py
@@ -266,6 +266,14 @@ def test_empty_backend():
nx.matmul(M, M.T)
with pytest.raises(NotImplementedError):
nx.nan_to_num(M)
with pytest.raises(NotImplementedError):
nx.sign(M)
with pytest.raises(NotImplementedError):
nx.dtype_device(M)
with pytest.raises(NotImplementedError):
nx.assert_same_dtype_device(M, M)
with pytest.raises(NotImplementedError):
nx.eigh(M)


def test_func_backends(nx):
@@ -311,6 +319,11 @@ def test_func_backends(nx):
lst_b.append(nx.to_numpy(A))
lst_name.append('set_gradients')

A = nx.detach(Mb)
A, B = nx.detach(Mb, Mb)
lst_b.append(nx.to_numpy(A))
lst_name.append('detach')

A = nx.zeros((10, 3))
A = nx.zeros((10, 3), type_as=Mb)
lst_b.append(nx.to_numpy(A))
@@ -652,10 +665,6 @@ def test_func_backends(nx):
lst_b.append(nx.to_numpy(A))
lst_name.append("transpose")

A = nx.detach(Mb)
lst_b.append(nx.to_numpy(A))
lst_name.append("detach")

A, B = nx.detach(Mb, Mb)
lst_b.append(nx.to_numpy(A))
lst_name.append("detach A")
43 changes: 43 additions & 0 deletions test/test_solvers.py
@@ -12,6 +12,7 @@

import ot
from ot.bregman import geomloss
from ot.backend import torch

lst_reg = [None, 1]
lst_reg_type = ['KL', 'entropy', 'L2']
@@ -107,6 +108,48 @@ def test_solve(nx):
sol0 = ot.solve(M, reg=1, reg_type='cryptic divergence')


@pytest.mark.skipif(not torch, reason="torch not installed")
def test_solve_implicit():

n_samples_s = 10
n_samples_t = 7
n_features = 2
rng = np.random.RandomState(0)

x = rng.randn(n_samples_s, n_features)
y = rng.randn(n_samples_t, n_features)
a = ot.utils.unif(n_samples_s)
b = ot.utils.unif(n_samples_t)
M = ot.dist(x, y)

a = torch.tensor(a, requires_grad=True)
b = torch.tensor(b, requires_grad=True)
M = torch.tensor(M, requires_grad=True)

sol0 = ot.solve(M, a, b, reg=10, grad='implicit')
sol0.value.backward()

gM0 = M.grad.clone()
ga0 = a.grad.clone()
gb0 = b.grad.clone()

a = torch.tensor(a, requires_grad=True)
b = torch.tensor(b, requires_grad=True)
M = torch.tensor(M, requires_grad=True)

sol = ot.solve(M, a, b, reg=10, grad='autodiff')
sol.value.backward()

gM = M.grad.clone()
ga = a.grad.clone()
gb = b.grad.clone()

# Note: gradients are invariant to a constant shift, so we center them before comparing
assert torch.allclose(gM0, gM)
assert torch.allclose(ga0 - ga0.mean(), ga - ga.mean())
assert torch.allclose(gb0 - gb0.mean(), gb - gb.mean())


@pytest.mark.parametrize("reg,reg_type,unbalanced,unbalanced_type", itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type))
def test_solve_grid(nx, reg, reg_type, unbalanced, unbalanced_type):
n_samples_s = 10