From 3a29218194b172c4dc11d3b06715b5b6bbb003d4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Flamary?=
Date: Mon, 19 Feb 2024 10:44:39 +0100
Subject: [PATCH 1/7] add detach function to backend

---
 ot/backend.py        | 27 +++++++++++++++++++++++++++
 test/test_backend.py |  7 +++++++
 2 files changed, 34 insertions(+)

diff --git a/ot/backend.py b/ot/backend.py
index 7645c4237..1214e3706 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -281,6 +281,17 @@ def set_gradients(self, val, inputs, grads):
         """Define the gradients for the value val wrt the inputs """
         raise NotImplementedError()
 
+    def detach(self, a):
+        """Detach the tensor from the computation graph"""
+        if len(arrays) == 1:
+            return self._detach(arrays[0], type_as=type_as)
+        else:
+            return [self._detach(array, type_as=type_as) for array in arrays]
+
+    def _detach(self, a):
+        """Detach the tensor from the computation graph"""
+        raise NotImplementedError()
+
     def zeros(self, shape, type_as=None):
         r"""
         Creates a tensor full of zeros.
@@ -1082,6 +1093,10 @@ def set_gradients(self, val, inputs, grads):
         # No gradients for numpy
         return val
 
+    def _detach(self, a):
+        # No gradients for numpy
+        return a
+
     def zeros(self, shape, type_as=None):
         if type_as is None:
             return np.zeros(shape)
@@ -1462,6 +1477,9 @@ def set_gradients(self, val, inputs, grads):
         val, = jax.tree_map(lambda z: z + aux, (val,))
         return val
 
+    def _detach(self, a):
+        return jax.lax.stop_gradient(a)
+
     def zeros(self, shape, type_as=None):
         if type_as is None:
             return jnp.zeros(shape)
@@ -1851,6 +1869,9 @@ def set_gradients(self, val, inputs, grads):
 
         return res
 
+    def _detach(self, a):
+        return a.detach()
+
     def zeros(self, shape, type_as=None):
         if isinstance(shape, int):
             shape = (shape,)
@@ -2312,6 +2333,9 @@ def set_gradients(self, val, inputs, grads):
         # No gradients for cupy
         return val
 
+    def _detach(self, a):
+        return a
+
     def zeros(self, shape, type_as=None):
         if isinstance(shape, (list, tuple)):
             shape = tuple(int(i) for i in shape)
@@ -2729,6 +2753,9 @@ def grad(upstream):
             return val, grad
         return tmp(inputs)
 
+    def _detach(self, a):
+        return tf.stop_gradient(a)
+
     def zeros(self, shape, type_as=None):
         if type_as is None:
             return tnp.zeros(shape)
diff --git a/test/test_backend.py b/test/test_backend.py
index 3bc1e5480..2cf66dcf8 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -266,6 +266,8 @@ def test_empty_backend():
         nx.matmul(M, M.T)
     with pytest.raises(NotImplementedError):
         nx.nan_to_num(M)
+    with pytest.raises(NotImplementedError):
+        nx.detach(M)
 
 
 def test_func_backends(nx):
@@ -311,6 +313,11 @@ def test_func_backends(nx):
         lst_b.append(nx.to_numpy(A))
         lst_name.append('set_gradients')
 
+        A = nx.detach(Mb)
+        A, B = nx.detach(Mb, Mb)
+        lst_b.append(nx.to_numpy(A))
+        lst_name.append('detach')
+
         A = nx.zeros((10, 3))
         A = nx.zeros((10, 3), type_as=Mb)
         lst_b.append(nx.to_numpy(A))

From 361c27b1120bb999aa591f6ca7305cd9335fae1a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Flamary?=
Date: Mon, 19 Feb 2024 14:02:34 +0100
Subject: [PATCH 2/7] debug function

---
 ot/backend.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/ot/backend.py b/ot/backend.py
index 1214e3706..f1b6c7e5d 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -281,12 +281,12 @@ def set_gradients(self, val, inputs, grads):
         """Define the gradients for the value val wrt the inputs """
         raise NotImplementedError()
 
-    def detach(self, a):
+    def detach(self, *arrays):
         """Detach the tensor from the computation graph"""
         if len(arrays) == 1:
-            return self._detach(arrays[0], type_as=type_as)
+            return self._detach(arrays[0])
         else:
-            return [self._detach(array, type_as=type_as) for array in arrays]
+            return [self._detach(array) for array in arrays]
 
     def _detach(self, a):
         """Detach the tensor from the computation graph"""

From 05346e9fd2d10e1f640ebad31c7c7b5ba7de5213 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Flamary?=
Date: Mon, 19 Feb 2024 14:16:12 +0100
Subject: [PATCH 3/7] better detach

---
 ot/backend.py        | 37 +++----------------------------------
 test/test_backend.py | 12 +++++++-----
 2 files changed, 10 insertions(+), 39 deletions(-)

diff --git a/ot/backend.py b/ot/backend.py
index f1b6c7e5d..9cc6446bf 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -282,7 +282,9 @@ def set_gradients(self, val, inputs, grads):
         raise NotImplementedError()
 
     def detach(self, *arrays):
-        """Detach the tensor from the computation graph"""
+        """Detach the tensors from the computation graph
+
+        See: https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html"""
         if len(arrays) == 1:
             return self._detach(arrays[0])
         else:
             return [self._detach(array) for array in arrays]
@@ -1038,14 +1040,6 @@ def transpose(self, a, axes=None):
         """
         raise NotImplementedError()
 
-    def detach(self, *args):
-        r"""
-        Detach tensors in arguments from the current graph.
-
-        See: https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
-        """
-        raise NotImplementedError()
-
     def matmul(self, a, b):
         r"""
         Matrix product of two arrays.
@@ -1407,11 +1401,6 @@ def atan2(self, a, b):
     def transpose(self, a, axes=None):
         return np.transpose(a, axes)
 
-    def detach(self, *args):
-        if len(args) == 1:
-            return args[0]
-        return args
-
     def matmul(self, a, b):
         return np.matmul(a, b)
 
@@ -1783,11 +1772,6 @@ def atan2(self, a, b):
     def transpose(self, a, axes=None):
         return jnp.transpose(a, axes)
 
-    def detach(self, *args):
-        if len(args) == 1:
-            return jax.lax.stop_gradient((args[0],))[0]
-        return [jax.lax.stop_gradient((a,))[0] for a in args]
-
     def matmul(self, a, b):
         return jnp.matmul(a, b)
 
@@ -2277,11 +2261,6 @@ def transpose(self, a, axes=None):
             axes = tuple(range(a.ndim)[::-1])
         return a.permute(axes)
 
-    def detach(self, *args):
-        if len(args) == 1:
-            return args[0].detach()
-        return [a.detach() for a in args]
-
     def matmul(self, a, b):
         return torch.matmul(a, b)
 
@@ -2681,11 +2660,6 @@ def atan2(self, a, b):
     def transpose(self, a, axes=None):
         return cp.transpose(a, axes)
 
-    def detach(self, *args):
-        if len(args) == 1:
-            return args[0]
-        return args
-
     def matmul(self, a, b):
         return cp.matmul(a, b)
 
@@ -3110,11 +3084,6 @@ def atan2(self, a, b):
     def transpose(self, a, axes=None):
         return tf.transpose(a, perm=axes)
 
-    def detach(self, *args):
-        if len(args) == 1:
-            return tf.stop_gradient(args[0])
-        return [tf.stop_gradient(a) for a in args]
-
     def matmul(self, a, b):
         return tnp.matmul(a, b)
 
diff --git a/test/test_backend.py b/test/test_backend.py
index 2cf66dcf8..da7293821 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -267,7 +267,13 @@ def test_empty_backend():
     with pytest.raises(NotImplementedError):
         nx.nan_to_num(M)
     with pytest.raises(NotImplementedError):
-        nx.detach(M)
+        nx.sign(M)
+    with pytest.raises(NotImplementedError):
+        nx.dtype_device(M)
+    with pytest.raises(NotImplementedError):
+        nx.assert_same_dtype_device(M, M)
+    with pytest.raises(NotImplementedError):
+        nx.eigh(M)
 
 
 def test_func_backends(nx):
@@ -659,10 +665,6 @@ def test_func_backends(nx):
         lst_b.append(nx.to_numpy(A))
         lst_name.append("transpose")
 
-        A = nx.detach(Mb)
-        lst_b.append(nx.to_numpy(A))
-        lst_name.append("detach")
-
         A, B = nx.detach(Mb, Mb)
         lst_b.append(nx.to_numpy(A))
         lst_name.append("detach A")
         lst_b.append(nx.to_numpy(B))
         lst_name.append("detach B")
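
For reference, a minimal sketch of how the `nx.detach` helper introduced by the three
patches above is meant to be called (illustrative only: the arrays and names below are
made up, and the backend is recovered with `ot.backend.get_backend` as elsewhere in POT):

    import numpy as np
    import ot

    a = np.random.rand(5, 3)
    b = np.random.rand(5, 3)
    nx = ot.backend.get_backend(a, b)  # NumpyBackend here, TorchBackend for torch tensors

    # a single input returns a single detached array,
    # several inputs return one detached array per input
    a_d = nx.detach(a)
    a_d, b_d = nx.detach(a, b)

Each backend maps `detach` onto its native mechanism: `Tensor.detach` for torch,
`jax.lax.stop_gradient` for jax, `tf.stop_gradient` for tensorflow, and a no-op for
numpy and cupy, which carry no computation graph.
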
From 6ba0f261492db3e75bdc9c33fc0be5ec19f347a7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Flamary?=
Date: Tue, 20 Feb 2024 10:00:41 +0100
Subject: [PATCH 4/7] new implementation

---
 ot/solvers.py | 22 +++++++++++++++++++---
 1 file changed, 19 insertions(+), 3 deletions(-)

diff --git a/ot/solvers.py b/ot/solvers.py
index e4eca9575..14908f0ab 100644
--- a/ot/solvers.py
+++ b/ot/solvers.py
@@ -29,7 +29,7 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
           unbalanced_type='KL', method=None, n_threads=1, max_iter=None, plan_init=None,
-          potentials_init=None, tol=None, verbose=False):
+          potentials_init=None, tol=None, verbose=False, grad='implicit'):
     r"""Solve the discrete optimal transport problem and return :any:`OTResult` object
 
     The function solves the following general optimal transport problem
@@ -79,6 +79,9 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
         Tolerance for solution precision, by default None (default values in each solvers)
     verbose : bool, optional
         Print information in the solver, by default False
+    grad : str, optional
+        Type of gradient computation, either 'implicit' or 'explicit' only for
+        inkhorn solver. By default 'implicit'.
 
     Returns
     -------
@@ -297,6 +300,10 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
 
         if reg_type.lower() in ['entropy', 'kl']:
 
+            if grad == 'implicit':  # if implicit then detach the input
+                M0, a0, b0 = M, a, b
+                M, a, b = nx.detach(M, a, b)
+
             # default values for sinkhorn
             if max_iter is None:
                 max_iter = 1000
@@ -316,6 +323,11 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
 
             potentials = (log['log_u'], log['log_v'])
 
+            if grad == 'implicit':  # set the gradient at convergence
+
+                value = nx.set_gradients(value, (M0, a0, b0),
+                                         (plan, potentials[0], potentials[1]))
+
         elif reg_type.lower() == 'l2':
 
             if max_iter is None:
@@ -869,7 +881,8 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None,
 
 def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL",
                  unbalanced=None, unbalanced_type='KL', lazy=False, batch_size=None,
                  method=None, n_threads=1, max_iter=None, plan_init=None, rank=100, scaling=0.95,
-                 potentials_init=None, X_init=None, tol=None, verbose=False):
+                 potentials_init=None, X_init=None, tol=None, verbose=False,
+                 grad='implicit'):
     r"""Solve the discrete optimal transport problem using the samples in the source and target domains.
 
     The function solves the following general optimal transport problem
@@ -935,6 +948,9 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
         Tolerance for solution precision, by default None (default values in each solvers)
     verbose : bool, optional
         Print information in the solver, by default False
+    grad : str, optional
+        Type of gradient computation, either 'implicit' or 'unroll' (only for
+        sinkhorn solver), by default 'implicit'.
 
     Returns
     -------
@@ -1189,7 +1205,7 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
 
         # compute cost matrix M and use solve function
         M = dist(X_a, X_b, metric)
 
-        res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, method, n_threads, max_iter, plan_init, potentials_init, tol, verbose)
+        res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, method, n_threads, max_iter, plan_init, potentials_init, tol, verbose, grad)
 
         return res

From 37f3eed58f4f491154cda0a4b204f34b06a8ea20 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Flamary?=
Date: Tue, 20 Feb 2024 10:58:01 +0100
Subject: [PATCH 5/7] add test for gradient

---
 ot/solvers.py        |  2 +-
 test/test_solvers.py | 42 ++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 43 insertions(+), 1 deletion(-)

diff --git a/ot/solvers.py b/ot/solvers.py
index 14908f0ab..700945e21 100644
--- a/ot/solvers.py
+++ b/ot/solvers.py
@@ -326,7 +326,7 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
             if grad == 'implicit':  # set the gradient at convergence
 
                 value = nx.set_gradients(value, (M0, a0, b0),
-                                         (plan, potentials[0], potentials[1]))
+                                         (plan, reg * (potentials[0] - potentials[0].mean()), reg * (potentials[1] - potentials[1].mean())))
 
         elif reg_type.lower() == 'l2':
 
diff --git a/test/test_solvers.py b/test/test_solvers.py
index 164989811..24fb304bb 100644
--- a/test/test_solvers.py
+++ b/test/test_solvers.py
@@ -12,6 +12,7 @@
 
 import ot
 from ot.bregman import geomloss
+from ot.backend import torch
 
 lst_reg = [None, 1]
 lst_reg_type = ['KL', 'entropy', 'L2']
@@ -107,6 +108,47 @@ def test_solve(nx):
         sol0 = ot.solve(M, reg=1, reg_type='cryptic divergence')
 
 
+@pytest.mark.skipif(not torch, reason="torch not installed")
+def test_solve_implicit():
+
+    n_samples_s = 10
+    n_samples_t = 7
+    n_features = 2
+    rng = np.random.RandomState(0)
+
+    x = rng.randn(n_samples_s, n_features)
+    y = rng.randn(n_samples_t, n_features)
+    a = ot.utils.unif(n_samples_s)
+    b = ot.utils.unif(n_samples_t)
+    M = ot.dist(x, y)
+
+    a = torch.tensor(a, requires_grad=True)
+    b = torch.tensor(b, requires_grad=True)
+    M = torch.tensor(M, requires_grad=True)
+
+    sol0 = ot.solve(M, a, b, reg=1)
+    sol0.value.backward()
+
+    gM0 = M.grad.clone()
+    ga0 = a.grad.clone()
+    gb0 = b.grad.clone()
+
+    a = torch.tensor(a, requires_grad=True)
+    b = torch.tensor(b, requires_grad=True)
+    M = torch.tensor(M, requires_grad=True)
+
+    sol = ot.solve(M, a, b, reg=1, grad='unroll')
+    sol.value.backward()
+
+    gM = M.grad.clone()
+    ga = a.grad.clone()
+    gb = b.grad.clone()
+
+    assert torch.allclose(gM0, gM)
+    assert torch.allclose(ga0 - ga0.mean(), ga - ga.mean())
+    assert torch.allclose(gb0 - gb0.mean(), gb - gb.mean())
+
+
 @pytest.mark.parametrize("reg,reg_type,unbalanced,unbalanced_type", itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type))
 def test_solve_grid(nx, reg, reg_type, unbalanced, unbalanced_type):
     n_samples_s = 10

From 2c27a430dd25a358e8c6f8fe3b3f2cfc9c6265c7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Flamary?=
Date: Tue, 20 Feb 2024 11:37:57 +0100
Subject: [PATCH 6/7] better default

---
 ot/solvers.py        | 18 ++++++++++++------
 test/test_solvers.py |  5 +++--
 2 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/ot/solvers.py b/ot/solvers.py
index 700945e21..689c8b5d2 100644
--- a/ot/solvers.py
+++ b/ot/solvers.py
@@ -29,7 +29,7 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
           unbalanced_type='KL', method=None, n_threads=1, max_iter=None, plan_init=None,
-          potentials_init=None, tol=None, verbose=False, grad='implicit'):
+          potentials_init=None, tol=None, verbose=False, grad='autodiff'):
     r"""Solve the discrete optimal transport problem and return :any:`OTResult` object
 
     The function solves the following general optimal transport problem
@@ -80,8 +80,11 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
     verbose : bool, optional
         Print information in the solver, by default False
     grad : str, optional
-        Type of gradient computation, either 'implicit' or 'explicit' only for
-        inkhorn solver. By default 'implicit'.
+        Type of gradient computation, either 'autodiff' or 'implicit', used only for the
+        Sinkhorn solver. By default 'autodiff' provides gradients wrt all
+        outputs (`plan, value, value_linear`) but with important memory cost.
+        'implicit' provides gradients only for `value` and the other outputs are
+        detached. This is useful for memory saving when only the value is needed.
 
     Returns
     -------
@@ -882,7 +885,7 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
                  unbalanced=None, unbalanced_type='KL', lazy=False, batch_size=None,
                  method=None, n_threads=1, max_iter=None, plan_init=None, rank=100, scaling=0.95,
                  potentials_init=None, X_init=None, tol=None, verbose=False,
-                 grad='implicit'):
+                 grad='autodiff'):
     r"""Solve the discrete optimal transport problem using the samples in the source and target domains.
 
     The function solves the following general optimal transport problem
@@ -949,8 +952,11 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
     verbose : bool, optional
         Print information in the solver, by default False
     grad : str, optional
-        Type of gradient computation, either 'implicit' or 'unroll' (only for
-        sinkhorn solver), by default 'implicit'.
+        Type of gradient computation, either 'autodiff' or 'implicit', used only for the
+        Sinkhorn solver. By default 'autodiff' provides gradients wrt all
+        outputs (`plan, value, value_linear`) but with important memory cost.
+        'implicit' provides gradients only for `value` and the other outputs are
+        detached. This is useful for memory saving when only the value is needed.
 
     Returns
     -------
diff --git a/test/test_solvers.py b/test/test_solvers.py
index 24fb304bb..168b111e4 100644
--- a/test/test_solvers.py
+++ b/test/test_solvers.py
@@ -126,7 +126,7 @@ def test_solve_implicit():
     b = torch.tensor(b, requires_grad=True)
     M = torch.tensor(M, requires_grad=True)
 
-    sol0 = ot.solve(M, a, b, reg=1)
+    sol0 = ot.solve(M, a, b, reg=10, grad='implicit')
     sol0.value.backward()
 
     gM0 = M.grad.clone()
@@ -137,13 +137,14 @@ def test_solve_implicit():
     b = torch.tensor(b, requires_grad=True)
     M = torch.tensor(M, requires_grad=True)
 
-    sol = ot.solve(M, a, b, reg=1, grad='unroll')
+    sol = ot.solve(M, a, b, reg=10, grad='autodiff')
     sol.value.backward()
 
     gM = M.grad.clone()
     ga = a.grad.clone()
     gb = b.grad.clone()
 
+    # Note: gradients are invariant to an additive constant, so we center them
     assert torch.allclose(gM0, gM)
     assert torch.allclose(ga0 - ga0.mean(), ga - ga.mean())
     assert torch.allclose(gb0 - gb0.mean(), gb - gb.mean())
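
A note on the quantities used above, which the patches apply without spelling out: at
convergence of Sinkhorn, the envelope theorem gives d value / d M equal to the optimal
plan, while the gradients with respect to the marginals a and b are the dual potentials,
which are only defined up to an additive constant. Since the Sinkhorn log stores
`log_u` and `log_v`, the potentials are `reg * log_u` and `reg * log_v`; this is the
reasoning behind the centered, `reg`-scaled terms passed to `nx.set_gradients` in
patch 5 and behind the centered comparisons in `test_solve_implicit`.
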
From 28fe8695c1dec4e191f16e84e923b5fc17ad542f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Flamary?=
Date: Tue, 20 Feb 2024 11:42:46 +0100
Subject: [PATCH 7/7] update documentation

---
 ot/solvers.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/ot/solvers.py b/ot/solvers.py
index 689c8b5d2..de817d7f7 100644
--- a/ot/solvers.py
+++ b/ot/solvers.py
@@ -140,6 +140,16 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
         # or for original Sinkhorn paper formulation [2]
         res = ot.solve(M, a, b, reg=1.0, reg_type='entropy')
 
+        # Use implicit differentiation for memory saving
+        res = ot.solve(M, a, b, reg=1.0, grad='implicit')  # M, a, b are torch tensors
+        res.value.backward()  # only the value is differentiable
+
+    Note that by default the Sinkhorn solver uses automatic differentiation to
+    compute the gradients of the values and plan. This can be changed with the
+    `grad` parameter. The `implicit` mode computes the implicit gradients only
+    for the value and the other outputs are detached. This is useful for
+    memory saving when only the gradient of value is needed.
+
     - **Quadratic regularized OT [17]** (when ``reg!=None`` and ``reg_type="L2"``):
 
     .. math::
@@ -1024,6 +1034,16 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
         # lazy OT plan
         lazy_plan = res.lazy_plan
 
+        # Use implicit differentiation for memory saving
+        res = ot.solve_sample(xa, xb, a, b, reg=1.0, grad='implicit')
+        res.value.backward()  # only the value is differentiable
+
+    Note that by default the Sinkhorn solver uses automatic differentiation to
+    compute the gradients of the values and plan. This can be changed with the
+    `grad` parameter. The `implicit` mode computes the implicit gradients only
+    for the value and the other outputs are detached. This is useful for
+    memory saving when only the gradient of value is needed.
+
     We also have a very efficient solver with compiled CPU/CUDA code using
     geomloss/PyKeOps that can be used with the following code:
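
Finally, a short end-to-end sketch of the feature added by this series (it assumes
PyTorch is installed; the sizes are made up and `reg=10` mirrors the updated test
rather than a recommended value):

    import numpy as np
    import torch
    import ot

    x, y = np.random.randn(10, 2), np.random.randn(7, 2)
    a, b = ot.utils.unif(10), ot.utils.unif(7)

    M = torch.tensor(ot.dist(x, y), requires_grad=True)
    a = torch.tensor(a, requires_grad=True)
    b = torch.tensor(b, requires_grad=True)

    # 'implicit': the inputs are detached during the Sinkhorn iterations and the
    # gradient of the value is set at convergence (plan and centered potentials)
    res = ot.solve(M, a, b, reg=10, grad='implicit')
    res.value.backward()  # only the value is differentiable in this mode

    # gradients wrt the marginals are defined up to an additive constant,
    # hence the centering when comparing against the default grad='autodiff'
    ga = a.grad - a.grad.mean()
    gb = b.grad - b.grad.mean()
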