From eeaca579ed9842a4faa989b6394b72a543f9af11 Mon Sep 17 00:00:00 2001 From: liutianlin0121 Date: Fri, 14 Apr 2023 19:24:32 +0200 Subject: [PATCH 1/7] add sparsity-constrained ot functionality and example --- README.md | 2 + examples/plot_OT_1D_smooth.py | 2 +- examples/plot_OT_1D_sparsity_constrained.py | 94 ++++++++ ot/__init__.py | 1 + ot/sparse.py | 229 ++++++++++++++++++++ test/test_sparse.py | 62 ++++++ 6 files changed, 389 insertions(+), 1 deletion(-) create mode 100644 examples/plot_OT_1D_sparsity_constrained.py create mode 100644 ot/sparse.py create mode 100644 test/test_sparse.py diff --git a/README.md b/README.md index 2a81e95ab..f0fb4bd0e 100644 --- a/README.md +++ b/README.md @@ -308,3 +308,5 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty (2022). [Semi-relaxed Gromov-Wasserstein divergence and applications on graphs](https://openreview.net/pdf?id=RShaMexjc-x). International Conference on Learning Representations (ICLR), 2022. [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020). [CO-Optimal Transport](https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf). Advances in Neural Information Processing Systems, 33. + +[50] Liu, T., Puigcerver, J., & Blondel, M. (2023). [Sparsity-constrained optimal transport](https://openreview.net/forum?id=yHY9NbQJ5BP). Proceedings of the Eleventh International Conference on Learning Representations (ICLR). 
diff --git a/examples/plot_OT_1D_smooth.py b/examples/plot_OT_1D_smooth.py index 5415e4f6c..0ddc052dd 100644 --- a/examples/plot_OT_1D_smooth.py +++ b/examples/plot_OT_1D_smooth.py @@ -101,7 +101,7 @@ pl.show() -#%% Smooth OT with KL regularization +#%% Smooth OT with squared l2 regularization lambd = 1e-1 Gsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='l2') diff --git a/examples/plot_OT_1D_sparsity_constrained.py b/examples/plot_OT_1D_sparsity_constrained.py new file mode 100644 index 000000000..88a801b7f --- /dev/null +++ b/examples/plot_OT_1D_sparsity_constrained.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +""" +================================ +Sparsity-constrained optimal transport example +================================ + +This example illustrates EMD, squared l2 regularized OT, and sparsity-constrained OT plans. +The sparsity-constrained OT can be considered as a middle ground between EMD and squared l2 regularized OT. + +""" + +# Author: Tianlin Liu +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 5 + +import numpy as np +import matplotlib.pylab as pl +import ot +import ot.plot +from ot.datasets import make_1D_gauss as gauss + +############################################################################## +# Generate data +# ------------- + + +#%% parameters + +n = 100 # nb bins + +# bin positions +x = np.arange(n, dtype=np.float64) + +# Gaussian distributions +a = gauss(n, m=20, s=5) # m= mean, s= std +b = gauss(n, m=60, s=10) + +# loss matrix +M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1))) +M /= M.max() + + +############################################################################## +# Plot distributions and loss matrix +# ---------------------------------- + +#%% plot the distributions + +pl.figure(1, figsize=(6.4, 3)) +pl.plot(x, a, 'b', label='Source distribution') +pl.plot(x, b, 'r', label='Target distribution') +pl.legend() + +#%% plot distributions and loss matrix + +pl.figure(2, figsize=(5, 5)) 
+ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') + + +#%% EMD + +# use fast 1D solver +G0 = ot.emd_1d(x, x, a, b) + +# Equivalent to +# G0 = ot.emd(a, b, M) + +pl.figure(3, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0') + + +#%% Smooth OT with squared l2 regularization + +lambd = 1e-1 +Gsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='l2') + +pl.figure(4, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT l2 reg.') + +pl.show() + + +#%% Smooth OT with squared l2 regularization + +lambd = 1e-1 +Gsc = ot.sparse.sparsity_constrained_ot_dual(a, b, M, lambd, max_nz=2) +pl.figure(5, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, Gsc, 'Sparsity contrained OT matrix; k=2.') + +pl.show() + +# %% diff --git a/ot/__init__.py b/ot/__init__.py index 1a685b601..2c4cc6cf4 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -27,6 +27,7 @@ from . import gromov from . import smooth from . import stochastic +from . import sparse from . import unbalanced from . import partial from . import backend diff --git a/ot/sparse.py b/ot/sparse.py new file mode 100644 index 000000000..dd89ddbc1 --- /dev/null +++ b/ot/sparse.py @@ -0,0 +1,229 @@ +""" +Sparsity-constrained optimal transport solvers. + +Implementation of : +Sparsity-Constrained Optimal Transport. +Tianlin Liu, Joan Puigcerver, Mathieu Blondel. +In Proc. of AISTATS 2018. +https://arxiv.org/abs/1710.06276 + +[50] Liu, T., Puigcerver, J., & Blondel, M. (2023). +Sparsity-constrained optimal transport. +Proceedings of the Eleventh International Conference on +Learning Representations (ICLR). 
+""" + +# Author: Tianlin Liu +# +# License: MIT License + + +import numpy as np +import ot +from .backend import get_backend + + +class SparsityConstrained(ot.smooth.Regularization): + """ Squared L2 regularization with sparsity constraints """ + + def __init__(self, max_nz, gamma=1.0): + self.max_nz = max_nz + self.gamma = gamma + + def delta_Omega(self, X): + # For each column of X, find entries that are not among the top max_nz. + non_top_indices = np.argpartition( + -X, self.max_nz, axis=0)[self.max_nz:] + # Set these entries to -inf. + X[non_top_indices, np.arange(X.shape[1])] = -np.inf + max_X = np.maximum(X, 0) + val = np.sum(max_X ** 2, axis=0) / (2 * self.gamma) + G = max_X / self.gamma + return val, G + + def max_Omega(self, X, b): + # For each column of X, find top max_nz values and + # their corresponding indices. + max_nz_indices = np.argpartition( + X, + kth=-self.max_nz, + axis=0)[-self.max_nz:] + max_nz_values = X[max_nz_indices, np.arange(X.shape[1])] + + # Project the top max_nz values onto the simplex. + G_nz_values = ot.smooth.projection_simplex( + max_nz_values / (b * self.gamma), axis=0) + + # Put the projection of max_nz_values to their original indices + # and set all other values zero. + G = np.zeros_like(X) + G[max_nz_indices, np.arange(X.shape[1])] = G_nz_values + val = np.sum(X * G, axis=0) + val -= 0.5 * self.gamma * b * np.sum(G * G, axis=0) + return val, G + + def Omega(self, T): + return 0.5 * self.gamma * np.sum(T ** 2) + + +def sparsity_constrained_ot_dual( + a, b, M, reg, max_nz, + method="L-BFGS-B", stopThr=1e-9, + numItermax=500, verbose=False, log=False): + r""" + Solve the sparsity-constrained OT problem in the dual and return the OT matrix. + + The function solves the sparsity-contrained OT in dual formulation in + :ref:`[50] `. 
+ + + Parameters + ---------- + a : np.ndarray (ns,) + samples weights in the source domain + b : np.ndarray (nt,) or np.ndarray (nt,nbb) + samples in the target domain, compute sinkhorn with multiple targets + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix + (return OT loss + dual variables in log) + M : np.ndarray (ns,nt) + loss matrix + reg : float + Regularization term >0 + max_nz: int + Maximum number of non-zero entries permitted in each column of the + optimal transport matrix. + method : str + Solver to use for scipy.optimize.minimize + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + gamma : (ns, nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + + .. _references-sparsity-constrained-ot-dual: + References + ---------- + .. [50] Liu, T., Puigcerver, J., & Blondel, M. (2023). Sparsity-constrained optimal transport. Proceedings of the Eleventh International Conference on Learning Representations (ICLR). 
+ + See Also + -------- + ot.lp.emd : Unregularized OT + ot.sinhorn : Entropic regularized OT + ot.smooth : Entropic regularized and squared l2 regularized OT + ot.optim.cg : General regularized OT + + """ + + nx = get_backend(a, b, M) + max_nz = min(max_nz, M.shape[0]) + regul = SparsityConstrained(gamma=reg, max_nz=max_nz) + + a0, b0, M0 = a, b, M + + # convert to humpy + a, b, M = nx.to_numpy(a, b, M) + + # solve dual + alpha, beta, res = ot.smooth.solve_dual( + a, b, M, regul, + max_iter=numItermax, + tol=stopThr, verbose=verbose) + + # reconstruct transport matrix + G = nx.from_numpy(ot.smooth.get_plan_from_dual(alpha, beta, M, regul), + type_as=M0) + + if log: + log = {'alpha': nx.from_numpy(alpha, type_as=a0), + 'beta': nx.from_numpy(beta, type_as=b0), 'res': res} + return G, log + else: + return G + + +def sparsity_constrained_ot_semi_dual( + a, b, M, reg, max_nz, + method="L-BFGS-B", stopThr=1e-9, + numItermax=500, verbose=False, log=False): + r""" + Solve the regularized OT problem in the semi-dual and return the OT matrix + + The function solves the sparsity-contrained OT in semi-dual formulation in + :ref:`[50] `. + + + Parameters + ---------- + a : np.ndarray (ns,) + samples weights in the source domain + b : np.ndarray (nt,) or np.ndarray (nt,nbb) + samples in the target domain, compute sinkhorn with multiple targets + and fixed:math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix + (return OT loss + dual variables in log) + M : np.ndarray (ns,nt) + loss matrix + reg : float + Regularization term >0 + max_nz: int + Maximum number of non-zero entries permitted in each column of the optimal transport matrix. 
+ method : str + Solver to use for scipy.optimize.minimize + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + gamma : (ns, nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + + .. _references-sparsity-constrained-ot-semi-dual: + References + ---------- + .. [50] Liu, T., Puigcerver, J., & Blondel, M. (2023). Sparsity-constrained optimal transport. Proceedings of the Eleventh International Conference on Learning Representations (ICLR). + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.sinhorn : Entropic regularized OT + ot.smooth : Entropic regularized and squared l2 regularized OT + ot.optim.cg : General regularized OT + + """ + + max_nz = min(max_nz, M.shape[0]) + regul = SparsityConstrained(gamma=reg, max_nz=max_nz) + # solve dual + alpha, res = ot.smooth.solve_semi_dual( + a, b, M, regul, max_iter=numItermax, + tol=stopThr, verbose=verbose) + + # reconstruct transport matrix + G = ot.smooth.get_plan_from_semi_dual(alpha, b, M, regul) + + if log: + log = {'alpha': alpha, 'res': res} + return G, log + else: + return G diff --git a/test/test_sparse.py b/test/test_sparse.py new file mode 100644 index 000000000..f27c567cf --- /dev/null +++ b/test/test_sparse.py @@ -0,0 +1,62 @@ +"""Tests for ot.smooth model """ + +# Author: Tianlin Liu +# +# License: MIT License + +import numpy as np +import ot + + +def test_sparsity_constrained_ot_dual(): + + # get data + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + max_nz = 2 + + plan = ot.sparse.sparsity_constrained_ot_dual( + u, u, M, 1, max_nz=max_nz, stopThr=1e-10) + + # check marginal constraints + np.testing.assert_allclose(u, plan.sum(1), atol=1e-03) + 
np.testing.assert_allclose(u, plan.sum(0), atol=1e-03) + + # check sparsity constraint + np.testing.assert_array_less( + np.sum(plan > 0, axis=0), + np.ones(n) * max_nz + 1) + + +def test_sparsity_constrained_ot_semi_dual(): + + # get data + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + max_nz = 2 + plan, log = ot.sparse.sparsity_constrained_ot_semi_dual( + u, u, M, 1, + max_nz=max_nz, + log=True, + stopThr=1e-10) + + # check marginal constraints + np.testing.assert_allclose(u, plan.sum(1), atol=1e-03) + np.testing.assert_allclose(u, plan.sum(0), atol=1e-03) + + # check sparsity constraint + np.testing.assert_array_less( + np.sum(plan > 0, axis=0), + np.ones(n) * max_nz + 1) From e8bb4e0c61db57b4029f984a03422a0393bb884c Mon Sep 17 00:00:00 2001 From: liutianlin0121 Date: Sat, 15 Apr 2023 18:16:50 +0200 Subject: [PATCH 2/7] correct typos; add projection_sparse_simplex --- ot/sparse.py | 88 +++++++++++++++++++++++++++++++++++---------- test/test_sparse.py | 55 +++++++++++++++++++++++++++- 2 files changed, 123 insertions(+), 20 deletions(-) diff --git a/ot/sparse.py b/ot/sparse.py index dd89ddbc1..0f8fbe8a5 100644 --- a/ot/sparse.py +++ b/ot/sparse.py @@ -3,9 +3,11 @@ Implementation of : Sparsity-Constrained Optimal Transport. -Tianlin Liu, Joan Puigcerver, Mathieu Blondel. -In Proc. of AISTATS 2018. -https://arxiv.org/abs/1710.06276 +Liu, T., Puigcerver, J., & Blondel, M. (2023). +Sparsity-constrained optimal transport. +Proceedings of the Eleventh International Conference on +Learning Representations (ICLR). +https://arxiv.org/abs/2209.15466 [50] Liu, T., Puigcerver, J., & Blondel, M. (2023). Sparsity-constrained optimal transport. @@ -23,6 +25,67 @@ from .backend import get_backend +def projection_sparse_simplex(V, max_nz, z=1, axis=None): + r"""Projection of :math:`\mathbf{V}` onto the simplex with cardinality constraint (maximum number of non-zero elements) and then scaled by `z`. + + .. 
math:: + P\left(\mathbf{V}, max_nz, z\right) = \mathop{\arg \min}_{\substack{\mathbf{y} >= 0 \\ \sum_i \mathbf{y}_i = z} \\ ||p||_0 \le \text{max_nz}} \quad \|\mathbf{y} - \mathbf{V}\|^2 + + Parameters + ---------- + V: ndarray, rank 2 + z: float or array + If array, len(z) must be compatible with :math:`\mathbf{V}` + axis: None or int + - axis=None: project :math:`\mathbf{V}` by :math:`P(\mathbf{V}.\mathrm{ravel}(), max_nz, z)` + - axis=1: project each :math:`\mathbf{V}_i` by :math:`P(\mathbf{V}_i, max_nz, z_i)` + - axis=0: project each :math:`\mathbf{V}_{:, j}` by :math:`P(\mathbf{V}_{:, j}, max_nz, z_j)` + + Returns + ------- + projection: ndarray, shape :math:`\mathbf{V}`.shape + + References: + Sparse projections onto the simplex + Anastasios Kyrillidis, Stephen Becker, Volkan Cevher and, Christoph Koch + ICML 2013 + https://arxiv.org/abs/1206.1529 + """ + if axis == 1: + max_nz_indices = np.argpartition( + V, + kth=-max_nz, + axis=1)[:, -max_nz:] + # Record nonzero column indices in a descending order + max_nz_indices = max_nz_indices[:, ::-1] + + row_indices = np.arange(V.shape[0])[:, np.newaxis] + + # Extract the top max_nz values for each row + # and then project to simplex. + U = V[row_indices, max_nz_indices] + z = np.ones(len(U)) * z + cssv = np.cumsum(U, axis=1) - z[:, np.newaxis] + ind = np.arange(max_nz) + 1 + cond = U - cssv / ind > 0 + rho = np.count_nonzero(cond, axis=1) + theta = cssv[np.arange(len(U)), rho - 1] / rho + nz_projection = np.maximum(U - theta[:, np.newaxis], 0) + + # Put the projection of max_nz_values to their original column indices + # while keeping other values zero. 
+ sparse_projection = np.zeros_like(V) + sparse_projection[row_indices, max_nz_indices] = nz_projection + return sparse_projection + + elif axis == 0: + return projection_sparse_simplex(V.T, max_nz, z, axis=1).T + + else: + V = V.ravel().reshape(1, -1) + return projection_sparse_simplex(V, max_nz, z, axis=1).ravel() + + class SparsityConstrained(ot.smooth.Regularization): """ Squared L2 regularization with sparsity constraints """ @@ -42,22 +105,9 @@ def delta_Omega(self, X): return val, G def max_Omega(self, X, b): - # For each column of X, find top max_nz values and - # their corresponding indices. - max_nz_indices = np.argpartition( - X, - kth=-self.max_nz, - axis=0)[-self.max_nz:] - max_nz_values = X[max_nz_indices, np.arange(X.shape[1])] - - # Project the top max_nz values onto the simplex. - G_nz_values = ot.smooth.projection_simplex( - max_nz_values / (b * self.gamma), axis=0) - - # Put the projection of max_nz_values to their original indices - # and set all other values zero. - G = np.zeros_like(X) - G[max_nz_indices, np.arange(X.shape[1])] = G_nz_values + # Project the scaled X onto the simplex with sparsity constraint. + G = projection_sparse_simplex( + X / (b * self.gamma), self.max_nz, axis=0) val = np.sum(X * G, axis=0) val -= 0.5 * self.gamma * b * np.sum(G * G, axis=0) return val, G diff --git a/test/test_sparse.py b/test/test_sparse.py index f27c567cf..ecc04f745 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -1,4 +1,4 @@ -"""Tests for ot.smooth model """ +"""Tests for ot.sparse model """ # Author: Tianlin Liu # @@ -60,3 +60,56 @@ def test_sparsity_constrained_ot_semi_dual(): np.testing.assert_array_less( np.sum(plan > 0, axis=0), np.ones(n) * max_nz + 1) + + +def test_projection_sparse_simplex(): + + def double_sort_projection_sparse_simplex(X, max_nz, z=1, axis=None): + r"""This is an equivalent but less efficient version + of ot.sparse.projection_sparse_simplex, as it uses two + sorts instead of one. 
+ """ + + if axis == 0: + # For each column of X, find top max_nz values and + # their corresponding indices. This incurs a sort. + max_nz_indices = np.argpartition( + X, + kth=-max_nz, + axis=0)[-max_nz:] + + max_nz_values = X[max_nz_indices, np.arange(X.shape[1])] + + # Project the top max_nz values onto the simplex. + # This incurs a second sort. + G_nz_values = ot.smooth.projection_simplex( + max_nz_values, z=z, axis=0) + + # Put the projection of max_nz_values to their original indices + # and set all other values zero. + G = np.zeros_like(X) + G[max_nz_indices, np.arange(X.shape[1])] = G_nz_values + return G + elif axis == 1: + return double_sort_projection_sparse_simplex( + X.T, max_nz, z, axis=0).T + + else: + X = X.ravel().reshape(-1, 1) + return double_sort_projection_sparse_simplex( + X, max_nz, z, axis=0).ravel() + + m, n = 5, 10 + rng = np.random.RandomState(0) + X = rng.uniform(size=(m, n)) + max_nz = 3 + + for axis in [0, 1, None]: + slow_sparse_proj = double_sort_projection_sparse_simplex( + X, max_nz, axis=axis) + fast_sparse_proj = ot.sparse.projection_sparse_simplex( + X, max_nz, axis=axis) + + # check that two versions produce the same result + np.testing.assert_allclose( + slow_sparse_proj, fast_sparse_proj) From 99daa664165948c7f15a30dd55213177495e6310 Mon Sep 17 00:00:00 2001 From: liutianlin0121 Date: Tue, 18 Apr 2023 20:10:15 +0200 Subject: [PATCH 3/7] add gradcheck; merge ot.sparse into ot.smooth. 
--- examples/plot_OT_1D_smooth.py | 15 +- examples/plot_OT_1D_sparsity_constrained.py | 94 ------- ot/__init__.py | 1 - ot/smooth.py | 216 ++++++++++++++- ot/sparse.py | 279 -------------------- ot/utils.py | 68 +++++ test/test_smooth.py | 134 ++++++++++ test/test_sparse.py | 115 -------- 8 files changed, 428 insertions(+), 494 deletions(-) delete mode 100644 examples/plot_OT_1D_sparsity_constrained.py delete mode 100644 ot/sparse.py delete mode 100644 test/test_sparse.py diff --git a/examples/plot_OT_1D_smooth.py b/examples/plot_OT_1D_smooth.py index 0ddc052dd..517639277 100644 --- a/examples/plot_OT_1D_smooth.py +++ b/examples/plot_OT_1D_smooth.py @@ -4,8 +4,8 @@ Smooth optimal transport example ================================ -This example illustrates the computation of EMD, Sinkhorn and smooth OT plans -and their visualization. +This example illustrates the computation of EMD, Sinkhorn, smooth OT plans, +sparsity-constrained plans, and their visualization. """ @@ -110,3 +110,14 @@ ot.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT l2 reg.') pl.show() + +#%% Sparsity-constrained OT + +lambd = 1e-1 + +max_nz = 2 # two non-zero entries are permitted per column of the OT plan +Gsc = ot.smooth.sparsity_constrained_ot_dual(a, b, M, lambd, max_nz=max_nz) +pl.figure(5, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, Gsc, 'Sparsity contrained OT matrix; k=2.') + +pl.show() diff --git a/examples/plot_OT_1D_sparsity_constrained.py b/examples/plot_OT_1D_sparsity_constrained.py deleted file mode 100644 index 88a801b7f..000000000 --- a/examples/plot_OT_1D_sparsity_constrained.py +++ /dev/null @@ -1,94 +0,0 @@ -# -*- coding: utf-8 -*- -""" -================================ -Sparsity-constrained optimal transport example -================================ - -This example illustrates EMD, squared l2 regularized OT, and sparsity-constrained OT plans. -The sparsity-constrained OT can be considered as a middle ground between EMD and squared l2 regularized OT. 
- -""" - -# Author: Tianlin Liu -# -# License: MIT License - -# sphinx_gallery_thumbnail_number = 5 - -import numpy as np -import matplotlib.pylab as pl -import ot -import ot.plot -from ot.datasets import make_1D_gauss as gauss - -############################################################################## -# Generate data -# ------------- - - -#%% parameters - -n = 100 # nb bins - -# bin positions -x = np.arange(n, dtype=np.float64) - -# Gaussian distributions -a = gauss(n, m=20, s=5) # m= mean, s= std -b = gauss(n, m=60, s=10) - -# loss matrix -M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1))) -M /= M.max() - - -############################################################################## -# Plot distributions and loss matrix -# ---------------------------------- - -#%% plot the distributions - -pl.figure(1, figsize=(6.4, 3)) -pl.plot(x, a, 'b', label='Source distribution') -pl.plot(x, b, 'r', label='Target distribution') -pl.legend() - -#%% plot distributions and loss matrix - -pl.figure(2, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') - - -#%% EMD - -# use fast 1D solver -G0 = ot.emd_1d(x, x, a, b) - -# Equivalent to -# G0 = ot.emd(a, b, M) - -pl.figure(3, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0') - - -#%% Smooth OT with squared l2 regularization - -lambd = 1e-1 -Gsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='l2') - -pl.figure(4, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT l2 reg.') - -pl.show() - - -#%% Smooth OT with squared l2 regularization - -lambd = 1e-1 -Gsc = ot.sparse.sparsity_constrained_ot_dual(a, b, M, lambd, max_nz=2) -pl.figure(5, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, Gsc, 'Sparsity contrained OT matrix; k=2.') - -pl.show() - -# %% diff --git a/ot/__init__.py b/ot/__init__.py index 2c4cc6cf4..1a685b601 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -27,7 +27,6 @@ from . import gromov from . import smooth from . import stochastic -from . import sparse from . 
import unbalanced from . import partial from . import backend diff --git a/ot/smooth.py b/ot/smooth.py index 8e0ef38e9..8308ef8dd 100644 --- a/ot/smooth.py +++ b/ot/smooth.py @@ -26,7 +26,8 @@ # Remi Flamary """ -Smooth and Sparse Optimal Transport solvers (KL an L2 reg.) +Smooth and Sparse Optimal Transport solvers (KL an L2 reg.) [17] and sparsity +constrained Optimal Transport solvers [50]. Implementation of : Smooth and Sparse Optimal Transport. @@ -34,18 +35,31 @@ In Proc. of AISTATS 2018. https://arxiv.org/abs/1710.06276 +(Original code from https://github.com/mblondel/smooth-ot/) + +Sparsity-Constrained Optimal Transport. +Liu, T., Puigcerver, J., & Blondel, M. (2023). +Sparsity-constrained optimal transport. +Proceedings of the Eleventh International Conference on +Learning Representations (ICLR). +https://arxiv.org/abs/2209.15466 + + [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS). -Original code from https://github.com/mblondel/smooth-ot/ +[50] Liu, T., Puigcerver, J., & Blondel, M. (2023). +Sparsity-constrained optimal transport. +Proceedings of the Eleventh International Conference on +Learning Representations (ICLR). """ import numpy as np from scipy.optimize import minimize from .backend import get_backend - +import ot def projection_simplex(V, z=1, axis=None): r""" Projection of :math:`\mathbf{V}` onto the simplex, scaled by `z` @@ -209,6 +223,39 @@ def Omega(self, T): return 0.5 * self.gamma * np.sum(T ** 2) +class SparsityConstrained(Regularization): + """ Squared L2 regularization with sparsity constraints """ + + def __init__(self, max_nz, gamma=1.0): + self.max_nz = max_nz + self.gamma = gamma + + def delta_Omega(self, X): + # For each column of X, find entries that are not among the top max_nz. 
+ non_top_indices = np.argpartition( + -X, self.max_nz, axis=0)[self.max_nz:] + # Set these entries to -inf. + if X.ndim == 1: + X[non_top_indices] = 0.0 + else: + X[non_top_indices, np.arange(X.shape[1])] = 0.0 + max_X = np.maximum(X, 0) + val = np.sum(max_X ** 2, axis=0) / (2 * self.gamma) + G = max_X / self.gamma + return val, G + + def max_Omega(self, X, b): + # Project the scaled X onto the simplex with sparsity constraint. + G = ot.utils.projection_sparse_simplex( + X / (b * self.gamma), self.max_nz, axis=0) + val = np.sum(X * G, axis=0) + val -= 0.5 * self.gamma * b * np.sum(G * G, axis=0) + return val, G + + def Omega(self, T): + return 0.5 * self.gamma * np.sum(T ** 2) + + def dual_obj_grad(alpha, beta, a, b, C, regul): r""" Compute objective value and gradients of dual objective. @@ -636,3 +683,166 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr= return G, log else: return G + + +def sparsity_constrained_ot_dual( + a, b, M, reg, max_nz, + method="L-BFGS-B", stopThr=1e-9, + numItermax=500, verbose=False, log=False): + r""" + Solve the sparsity-constrained OT problem in the dual and return the OT matrix. + + The function solves the sparsity-contrained OT in dual formulation in + :ref:`[50] `. + + + Parameters + ---------- + a : np.ndarray (ns,) + samples weights in the source domain + b : np.ndarray (nt,) or np.ndarray (nt,nbb) + samples in the target domain, compute sinkhorn with multiple targets + and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix + (return OT loss + dual variables in log) + M : np.ndarray (ns,nt) + loss matrix + reg : float + Regularization term >0 + max_nz: int + Maximum number of non-zero entries permitted in each column of the + optimal transport matrix. 
+ method : str + Solver to use for scipy.optimize.minimize + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + gamma : (ns, nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + + .. _references-sparsity-constrained-ot-dual: + References + ---------- + .. [50] Liu, T., Puigcerver, J., & Blondel, M. (2023). Sparsity-constrained optimal transport. Proceedings of the Eleventh International Conference on Learning Representations (ICLR). + + See Also + -------- + ot.lp.emd : Unregularized OT + ot.sinhorn : Entropic regularized OT + ot.smooth : Entropic regularized and squared l2 regularized OT + ot.optim.cg : General regularized OT + + """ + + nx = get_backend(a, b, M) + max_nz = min(max_nz, M.shape[0]) + regul = SparsityConstrained(gamma=reg, max_nz=max_nz) + + a0, b0, M0 = a, b, M + + # convert to humpy + a, b, M = nx.to_numpy(a, b, M) + + # solve dual + alpha, beta, res = solve_dual( + a, b, M, regul, + max_iter=numItermax, + tol=stopThr, verbose=verbose) + + # reconstruct transport matrix + G = nx.from_numpy(get_plan_from_dual(alpha, beta, M, regul), + type_as=M0) + + if log: + log = {'alpha': nx.from_numpy(alpha, type_as=a0), + 'beta': nx.from_numpy(beta, type_as=b0), 'res': res} + return G, log + else: + return G + + +def sparsity_constrained_ot_semi_dual( + a, b, M, reg, max_nz, + method="L-BFGS-B", stopThr=1e-9, + numItermax=500, verbose=False, log=False): + r""" + Solve the regularized OT problem in the semi-dual and return the OT matrix + + The function solves the sparsity-contrained OT in semi-dual formulation in + :ref:`[50] `. 
+ + + Parameters + ---------- + a : np.ndarray (ns,) + samples weights in the source domain + b : np.ndarray (nt,) or np.ndarray (nt,nbb) + samples in the target domain, compute sinkhorn with multiple targets + and fixed:math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix + (return OT loss + dual variables in log) + M : np.ndarray (ns,nt) + loss matrix + reg : float + Regularization term >0 + max_nz: int + Maximum number of non-zero entries permitted in each column of the optimal transport matrix. + method : str + Solver to use for scipy.optimize.minimize + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + verbose : bool, optional + Print information along iterations + log : bool, optional + record log if True + + + Returns + ------- + gamma : (ns, nt) ndarray + Optimal transportation matrix for the given parameters + log : dict + log dictionary return only if log==True in parameters + + + .. _references-sparsity-constrained-ot-semi-dual: + References + ---------- + .. [50] Liu, T., Puigcerver, J., & Blondel, M. (2023). Sparsity-constrained optimal transport. Proceedings of the Eleventh International Conference on Learning Representations (ICLR). 
+ + See Also + -------- + ot.lp.emd : Unregularized OT + ot.sinhorn : Entropic regularized OT + ot.smooth : Entropic regularized and squared l2 regularized OT + ot.optim.cg : General regularized OT + + """ + + max_nz = min(max_nz, M.shape[0]) + regul = SparsityConstrained(gamma=reg, max_nz=max_nz) + # solve dual + alpha, res = solve_semi_dual( + a, b, M, regul, max_iter=numItermax, + tol=stopThr, verbose=verbose) + + # reconstruct transport matrix + G = get_plan_from_semi_dual(alpha, b, M, regul) + + if log: + log = {'alpha': alpha, 'res': res} + return G, log + else: + return G diff --git a/ot/sparse.py b/ot/sparse.py deleted file mode 100644 index 0f8fbe8a5..000000000 --- a/ot/sparse.py +++ /dev/null @@ -1,279 +0,0 @@ -""" -Sparsity-constrained optimal transport solvers. - -Implementation of : -Sparsity-Constrained Optimal Transport. -Liu, T., Puigcerver, J., & Blondel, M. (2023). -Sparsity-constrained optimal transport. -Proceedings of the Eleventh International Conference on -Learning Representations (ICLR). -https://arxiv.org/abs/2209.15466 - -[50] Liu, T., Puigcerver, J., & Blondel, M. (2023). -Sparsity-constrained optimal transport. -Proceedings of the Eleventh International Conference on -Learning Representations (ICLR). -""" - -# Author: Tianlin Liu -# -# License: MIT License - - -import numpy as np -import ot -from .backend import get_backend - - -def projection_sparse_simplex(V, max_nz, z=1, axis=None): - r"""Projection of :math:`\mathbf{V}` onto the simplex with cardinality constraint (maximum number of non-zero elements) and then scaled by `z`. - - .. 
math:: - P\left(\mathbf{V}, max_nz, z\right) = \mathop{\arg \min}_{\substack{\mathbf{y} >= 0 \\ \sum_i \mathbf{y}_i = z} \\ ||p||_0 \le \text{max_nz}} \quad \|\mathbf{y} - \mathbf{V}\|^2 - - Parameters - ---------- - V: ndarray, rank 2 - z: float or array - If array, len(z) must be compatible with :math:`\mathbf{V}` - axis: None or int - - axis=None: project :math:`\mathbf{V}` by :math:`P(\mathbf{V}.\mathrm{ravel}(), max_nz, z)` - - axis=1: project each :math:`\mathbf{V}_i` by :math:`P(\mathbf{V}_i, max_nz, z_i)` - - axis=0: project each :math:`\mathbf{V}_{:, j}` by :math:`P(\mathbf{V}_{:, j}, max_nz, z_j)` - - Returns - ------- - projection: ndarray, shape :math:`\mathbf{V}`.shape - - References: - Sparse projections onto the simplex - Anastasios Kyrillidis, Stephen Becker, Volkan Cevher and, Christoph Koch - ICML 2013 - https://arxiv.org/abs/1206.1529 - """ - if axis == 1: - max_nz_indices = np.argpartition( - V, - kth=-max_nz, - axis=1)[:, -max_nz:] - # Record nonzero column indices in a descending order - max_nz_indices = max_nz_indices[:, ::-1] - - row_indices = np.arange(V.shape[0])[:, np.newaxis] - - # Extract the top max_nz values for each row - # and then project to simplex. - U = V[row_indices, max_nz_indices] - z = np.ones(len(U)) * z - cssv = np.cumsum(U, axis=1) - z[:, np.newaxis] - ind = np.arange(max_nz) + 1 - cond = U - cssv / ind > 0 - rho = np.count_nonzero(cond, axis=1) - theta = cssv[np.arange(len(U)), rho - 1] / rho - nz_projection = np.maximum(U - theta[:, np.newaxis], 0) - - # Put the projection of max_nz_values to their original column indices - # while keeping other values zero. 
- sparse_projection = np.zeros_like(V) - sparse_projection[row_indices, max_nz_indices] = nz_projection - return sparse_projection - - elif axis == 0: - return projection_sparse_simplex(V.T, max_nz, z, axis=1).T - - else: - V = V.ravel().reshape(1, -1) - return projection_sparse_simplex(V, max_nz, z, axis=1).ravel() - - -class SparsityConstrained(ot.smooth.Regularization): - """ Squared L2 regularization with sparsity constraints """ - - def __init__(self, max_nz, gamma=1.0): - self.max_nz = max_nz - self.gamma = gamma - - def delta_Omega(self, X): - # For each column of X, find entries that are not among the top max_nz. - non_top_indices = np.argpartition( - -X, self.max_nz, axis=0)[self.max_nz:] - # Set these entries to -inf. - X[non_top_indices, np.arange(X.shape[1])] = -np.inf - max_X = np.maximum(X, 0) - val = np.sum(max_X ** 2, axis=0) / (2 * self.gamma) - G = max_X / self.gamma - return val, G - - def max_Omega(self, X, b): - # Project the scaled X onto the simplex with sparsity constraint. - G = projection_sparse_simplex( - X / (b * self.gamma), self.max_nz, axis=0) - val = np.sum(X * G, axis=0) - val -= 0.5 * self.gamma * b * np.sum(G * G, axis=0) - return val, G - - def Omega(self, T): - return 0.5 * self.gamma * np.sum(T ** 2) - - -def sparsity_constrained_ot_dual( - a, b, M, reg, max_nz, - method="L-BFGS-B", stopThr=1e-9, - numItermax=500, verbose=False, log=False): - r""" - Solve the sparsity-constrained OT problem in the dual and return the OT matrix. - - The function solves the sparsity-contrained OT in dual formulation in - :ref:`[50] `. 
- - - Parameters - ---------- - a : np.ndarray (ns,) - samples weights in the source domain - b : np.ndarray (nt,) or np.ndarray (nt,nbb) - samples in the target domain, compute sinkhorn with multiple targets - and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix - (return OT loss + dual variables in log) - M : np.ndarray (ns,nt) - loss matrix - reg : float - Regularization term >0 - max_nz: int - Maximum number of non-zero entries permitted in each column of the - optimal transport matrix. - method : str - Solver to use for scipy.optimize.minimize - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - - - Returns - ------- - gamma : (ns, nt) ndarray - Optimal transportation matrix for the given parameters - log : dict - log dictionary return only if log==True in parameters - - - .. _references-sparsity-constrained-ot-dual: - References - ---------- - .. [50] Liu, T., Puigcerver, J., & Blondel, M. (2023). Sparsity-constrained optimal transport. Proceedings of the Eleventh International Conference on Learning Representations (ICLR). 
- - See Also - -------- - ot.lp.emd : Unregularized OT - ot.sinhorn : Entropic regularized OT - ot.smooth : Entropic regularized and squared l2 regularized OT - ot.optim.cg : General regularized OT - - """ - - nx = get_backend(a, b, M) - max_nz = min(max_nz, M.shape[0]) - regul = SparsityConstrained(gamma=reg, max_nz=max_nz) - - a0, b0, M0 = a, b, M - - # convert to humpy - a, b, M = nx.to_numpy(a, b, M) - - # solve dual - alpha, beta, res = ot.smooth.solve_dual( - a, b, M, regul, - max_iter=numItermax, - tol=stopThr, verbose=verbose) - - # reconstruct transport matrix - G = nx.from_numpy(ot.smooth.get_plan_from_dual(alpha, beta, M, regul), - type_as=M0) - - if log: - log = {'alpha': nx.from_numpy(alpha, type_as=a0), - 'beta': nx.from_numpy(beta, type_as=b0), 'res': res} - return G, log - else: - return G - - -def sparsity_constrained_ot_semi_dual( - a, b, M, reg, max_nz, - method="L-BFGS-B", stopThr=1e-9, - numItermax=500, verbose=False, log=False): - r""" - Solve the regularized OT problem in the semi-dual and return the OT matrix - - The function solves the sparsity-contrained OT in semi-dual formulation in - :ref:`[50] `. - - - Parameters - ---------- - a : np.ndarray (ns,) - samples weights in the source domain - b : np.ndarray (nt,) or np.ndarray (nt,nbb) - samples in the target domain, compute sinkhorn with multiple targets - and fixed:math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix - (return OT loss + dual variables in log) - M : np.ndarray (ns,nt) - loss matrix - reg : float - Regularization term >0 - max_nz: int - Maximum number of non-zero entries permitted in each column of the optimal transport matrix. 
- method : str - Solver to use for scipy.optimize.minimize - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - - - Returns - ------- - gamma : (ns, nt) ndarray - Optimal transportation matrix for the given parameters - log : dict - log dictionary return only if log==True in parameters - - - .. _references-sparsity-constrained-ot-semi-dual: - References - ---------- - .. [50] Liu, T., Puigcerver, J., & Blondel, M. (2023). Sparsity-constrained optimal transport. Proceedings of the Eleventh International Conference on Learning Representations (ICLR). - - See Also - -------- - ot.lp.emd : Unregularized OT - ot.sinhorn : Entropic regularized OT - ot.smooth : Entropic regularized and squared l2 regularized OT - ot.optim.cg : General regularized OT - - """ - - max_nz = min(max_nz, M.shape[0]) - regul = SparsityConstrained(gamma=reg, max_nz=max_nz) - # solve dual - alpha, res = ot.smooth.solve_semi_dual( - a, b, M, regul, max_iter=numItermax, - tol=stopThr, verbose=verbose) - - # reconstruct transport matrix - G = ot.smooth.get_plan_from_semi_dual(alpha, b, M, regul) - - if log: - log = {'alpha': alpha, 'res': res} - return G, log - else: - return G diff --git a/ot/utils.py b/ot/utils.py index 3423a7e4d..873975bbf 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -117,6 +117,74 @@ def proj_simplex(v, z=1): return w +def projection_sparse_simplex(V, max_nz, z=1, axis=None): + r"""Projection of :math:`\mathbf{V}` onto the simplex with cardinality constraint (maximum number of non-zero elements) and then scaled by `z`. + + .. 
math:: + P\left(\mathbf{V}, max_nz, z\right) = \mathop{\arg \min}_{\substack{\mathbf{y} >= 0 \\ \sum_i \mathbf{y}_i = z} \\ ||p||_0 \le \text{max_nz}} \quad \|\mathbf{y} - \mathbf{V}\|^2 + + Parameters + ---------- + V: 1-dim or 2-dim ndarray + z: float or array + If array, len(z) must be compatible with :math:`\mathbf{V}` + axis: None or int + - axis=None: project :math:`\mathbf{V}` by :math:`P(\mathbf{V}.\mathrm{ravel}(), max_nz, z)` + - axis=1: project each :math:`\mathbf{V}_i` by :math:`P(\mathbf{V}_i, max_nz, z_i)` + - axis=0: project each :math:`\mathbf{V}_{:, j}` by :math:`P(\mathbf{V}_{:, j}, max_nz, z_j)` + + Returns + ------- + projection: ndarray, shape :math:`\mathbf{V}`.shape + + References: + Sparse projections onto the simplex + Anastasios Kyrillidis, Stephen Becker, Volkan Cevher and, Christoph Koch + ICML 2013 + https://arxiv.org/abs/1206.1529 + """ + + if V.ndim == 1: + return projection_sparse_simplex( + V[np.newaxis, :], max_nz, z, axis=1).ravel() + + if V.ndim > 2: + raise ValueError('V.ndim must be <= 2') + + if axis == 1: + # For each row of V, find top max_nz values; arrange the + # corresponding column indices such that their values are + # in a descending order. + max_nz_indices = np.argsort(V, axis=1)[:, -max_nz:] + max_nz_indices = max_nz_indices[:, ::-1] + + row_indices = np.arange(V.shape[0])[:, np.newaxis] + + # Extract the top max_nz values for each row + # and then project to simplex. + U = V[row_indices, max_nz_indices] + z = np.ones(len(U)) * z + cssv = np.cumsum(U, axis=1) - z[:, np.newaxis] + ind = np.arange(max_nz) + 1 + cond = U - cssv / ind > 0 + rho = np.count_nonzero(cond, axis=1) + theta = cssv[np.arange(len(U)), rho - 1] / rho + nz_projection = np.maximum(U - theta[:, np.newaxis], 0) + + # Put the projection of max_nz_values to their original column indices + # while keeping other values zero. 
+ sparse_projection = np.zeros_like(V) + sparse_projection[row_indices, max_nz_indices] = nz_projection + return sparse_projection + + elif axis == 0: + return projection_sparse_simplex(V.T, max_nz, z, axis=1).T + + else: + V = V.ravel().reshape(1, -1) + return projection_sparse_simplex(V, max_nz, z, axis=1).ravel() + + def unif(n, type_as=None): r""" Return a uniform histogram of length `n` (simplex). diff --git a/test/test_smooth.py b/test/test_smooth.py index 31e0b2e4b..9688937c0 100644 --- a/test/test_smooth.py +++ b/test/test_smooth.py @@ -7,6 +7,7 @@ import numpy as np import ot import pytest +from scipy.optimize import check_grad def test_smooth_ot_dual(): @@ -77,3 +78,136 @@ def test_smooth_ot_semi_dual(): G2 = ot.sinkhorn(u, u, M, 1, stopThr=1e-10) np.testing.assert_allclose(G, G2, atol=1e-05) + + +def test_sparsity_constrained_ot_dual(): + + # get data + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + max_nz = 2 + + plan = ot.smooth.sparsity_constrained_ot_dual( + u, u, M, 1, max_nz=max_nz, stopThr=1e-10) + + # check marginal constraints + np.testing.assert_allclose(u, plan.sum(1), atol=1e-03) + np.testing.assert_allclose(u, plan.sum(0), atol=1e-03) + + # check sparsity constraint + np.testing.assert_array_less( + np.sum(plan > 0, axis=0), + np.ones(n) * max_nz + 1) + + +def test_sparsity_constrained_ot_semi_dual(): + + # get data + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + M = ot.dist(x, x) + + max_nz = 5 + plan, log = ot.smooth.sparsity_constrained_ot_semi_dual( + u, u, M, 1, + max_nz=max_nz, + log=True, + stopThr=1e-10) + + # check marginal constraints + np.testing.assert_allclose(u, plan.sum(1), atol=1e-03) + np.testing.assert_allclose(u, plan.sum(0), atol=1e-03) + + # check sparsity constraint + np.testing.assert_array_less( + np.sum(plan > 0, axis=0), + np.ones(n) * max_nz + 1) + + +def test_projection_sparse_simplex(): + + def 
double_sort_projection_sparse_simplex(X, max_nz, z=1, axis=None): + r"""This is an equivalent but less efficient version + of ot.utils.projection_sparse_simplex, as it uses two + sorts instead of one. + """ + + if axis == 0: + # For each column of X, find top max_nz values and + # their corresponding indices. This incurs a sort. + max_nz_indices = np.argpartition( + X, + kth=-max_nz, + axis=0)[-max_nz:] + + max_nz_values = X[max_nz_indices, np.arange(X.shape[1])] + + # Project the top max_nz values onto the simplex. + # This incurs a second sort. + G_nz_values = ot.smooth.projection_simplex( + max_nz_values, z=z, axis=0) + + # Put the projection of max_nz_values to their original indices + # and set all other values zero. + G = np.zeros_like(X) + G[max_nz_indices, np.arange(X.shape[1])] = G_nz_values + return G + elif axis == 1: + return double_sort_projection_sparse_simplex( + X.T, max_nz, z, axis=0).T + + else: + X = X.ravel().reshape(-1, 1) + return double_sort_projection_sparse_simplex( + X, max_nz, z, axis=0).ravel() + + m, n = 5, 10 + rng = np.random.RandomState(0) + X = rng.uniform(size=(m, n)) + max_nz = 3 + + for axis in [0, 1, None]: + slow_sparse_proj = double_sort_projection_sparse_simplex( + X, max_nz, axis=axis) + fast_sparse_proj = ot.utils.projection_sparse_simplex( + X, max_nz, axis=axis) + + # check that two versions produce the same result + np.testing.assert_allclose( + slow_sparse_proj, fast_sparse_proj) + + +def test_sparsity_constrained_gradient(): + max_nz = 5 + regularizer = ot.smooth.SparsityConstrained(max_nz=max_nz) + rng = np.random.RandomState(0) + X = rng.randn(10,) + b = 0.5 + + def delta_omega_func(X): + return regularizer.delta_Omega(X)[0] + + def delta_omega_grad(X): + return regularizer.delta_Omega(X)[1] + + dual_grad_err = check_grad(delta_omega_func, delta_omega_grad, X) + np.testing.assert_allclose(dual_grad_err, 0.0, atol=1e-07) + + def max_omega_func(X, b): + return regularizer.max_Omega(X, b)[0] + + def max_omega_grad(X, 
b): + return regularizer.max_Omega(X, b)[1] + + semi_dual_grad_err = check_grad(max_omega_func, max_omega_grad, X, b) + np.testing.assert_allclose(semi_dual_grad_err, 0.0, atol=1e-07) diff --git a/test/test_sparse.py b/test/test_sparse.py deleted file mode 100644 index ecc04f745..000000000 --- a/test/test_sparse.py +++ /dev/null @@ -1,115 +0,0 @@ -"""Tests for ot.sparse model """ - -# Author: Tianlin Liu -# -# License: MIT License - -import numpy as np -import ot - - -def test_sparsity_constrained_ot_dual(): - - # get data - n = 100 - rng = np.random.RandomState(0) - - x = rng.randn(n, 2) - u = ot.utils.unif(n) - - M = ot.dist(x, x) - - max_nz = 2 - - plan = ot.sparse.sparsity_constrained_ot_dual( - u, u, M, 1, max_nz=max_nz, stopThr=1e-10) - - # check marginal constraints - np.testing.assert_allclose(u, plan.sum(1), atol=1e-03) - np.testing.assert_allclose(u, plan.sum(0), atol=1e-03) - - # check sparsity constraint - np.testing.assert_array_less( - np.sum(plan > 0, axis=0), - np.ones(n) * max_nz + 1) - - -def test_sparsity_constrained_ot_semi_dual(): - - # get data - n = 100 - rng = np.random.RandomState(0) - - x = rng.randn(n, 2) - u = ot.utils.unif(n) - - M = ot.dist(x, x) - - max_nz = 2 - plan, log = ot.sparse.sparsity_constrained_ot_semi_dual( - u, u, M, 1, - max_nz=max_nz, - log=True, - stopThr=1e-10) - - # check marginal constraints - np.testing.assert_allclose(u, plan.sum(1), atol=1e-03) - np.testing.assert_allclose(u, plan.sum(0), atol=1e-03) - - # check sparsity constraint - np.testing.assert_array_less( - np.sum(plan > 0, axis=0), - np.ones(n) * max_nz + 1) - - -def test_projection_sparse_simplex(): - - def double_sort_projection_sparse_simplex(X, max_nz, z=1, axis=None): - r"""This is an equivalent but less efficient version - of ot.sparse.projection_sparse_simplex, as it uses two - sorts instead of one. - """ - - if axis == 0: - # For each column of X, find top max_nz values and - # their corresponding indices. This incurs a sort. 
- max_nz_indices = np.argpartition( - X, - kth=-max_nz, - axis=0)[-max_nz:] - - max_nz_values = X[max_nz_indices, np.arange(X.shape[1])] - - # Project the top max_nz values onto the simplex. - # This incurs a second sort. - G_nz_values = ot.smooth.projection_simplex( - max_nz_values, z=z, axis=0) - - # Put the projection of max_nz_values to their original indices - # and set all other values zero. - G = np.zeros_like(X) - G[max_nz_indices, np.arange(X.shape[1])] = G_nz_values - return G - elif axis == 1: - return double_sort_projection_sparse_simplex( - X.T, max_nz, z, axis=0).T - - else: - X = X.ravel().reshape(-1, 1) - return double_sort_projection_sparse_simplex( - X, max_nz, z, axis=0).ravel() - - m, n = 5, 10 - rng = np.random.RandomState(0) - X = rng.uniform(size=(m, n)) - max_nz = 3 - - for axis in [0, 1, None]: - slow_sparse_proj = double_sort_projection_sparse_simplex( - X, max_nz, axis=axis) - fast_sparse_proj = ot.sparse.projection_sparse_simplex( - X, max_nz, axis=axis) - - # check that two versions produce the same result - np.testing.assert_allclose( - slow_sparse_proj, fast_sparse_proj) From bd347578914919fe0eea0fe3266cf8d76c95225d Mon Sep 17 00:00:00 2001 From: liutianlin0121 Date: Wed, 19 Apr 2023 09:48:28 +0200 Subject: [PATCH 4/7] reuse existing ot.smooth functions with a new 'sparsity_constrained' reg_type --- examples/plot_OT_1D_smooth.py | 3 +- ot/smooth.py | 190 +++++----------------------------- test/test_smooth.py | 131 ++++++----------------- test/test_utils.py | 53 ++++++++++ 4 files changed, 109 insertions(+), 268 deletions(-) diff --git a/examples/plot_OT_1D_smooth.py b/examples/plot_OT_1D_smooth.py index 517639277..548ddc81a 100644 --- a/examples/plot_OT_1D_smooth.py +++ b/examples/plot_OT_1D_smooth.py @@ -116,7 +116,8 @@ lambd = 1e-1 max_nz = 2 # two non-zero entries are permitted per column of the OT plan -Gsc = ot.smooth.sparsity_constrained_ot_dual(a, b, M, lambd, max_nz=max_nz) +Gsc = ot.smooth.smooth_ot_dual( + a, b, M, lambd, 
reg_type='sparsity_constrained', max_nz=max_nz) pl.figure(5, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, Gsc, 'Sparsity contrained OT matrix; k=2.') diff --git a/ot/smooth.py b/ot/smooth.py index 8308ef8dd..d9cf498e9 100644 --- a/ot/smooth.py +++ b/ot/smooth.py @@ -482,7 +482,8 @@ def get_plan_from_semi_dual(alpha, b, C, regul): return regul.max_Omega(X, b)[1] * b -def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, +def smooth_ot_dual(a, b, M, reg, reg_type='l2', max_nz=None, + method="L-BFGS-B", stopThr=1e-9, numItermax=500, verbose=False, log=False): r""" Solve the regularized OT problem in the dual and return the OT matrix @@ -524,6 +525,9 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, :ref:`[2] `) - 'l2' : Squared Euclidean regularization + - 'sparsity_constrained' : Sparsity-constrained regularization [50] + max_nz : int or None, optional. Used only in the case of reg_type = 'sparsity_constrained' to specify the maximum number of nonzeros per column of the optimal plan; + not used for other regularization types. method : str Solver to use for scipy.optimize.minimize numItermax : int, optional @@ -551,6 +555,8 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS). + .. [50] Liu, T., Puigcerver, J., & Blondel, M. (2023). Sparsity-constrained optimal transport. Proceedings of the Eleventh International Conference on Learning Representations (ICLR). 
+ See Also -------- ot.lp.emd : Unregularized OT @@ -565,6 +571,11 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, regul = SquaredL2(gamma=reg) elif reg_type.lower() in ['entropic', 'negentropy', 'kl']: regul = NegEntropy(gamma=reg) + elif reg_type.lower() in ['sparsity_constrained', 'sparsity-constrained']: + if not isinstance(max_nz, int): + raise ValueError( + f'max_nz {max_nz} must be an integer') + regul = SparsityConstrained(gamma=reg, max_nz=max_nz) else: raise NotImplementedError('Unknown regularization') @@ -586,7 +597,8 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, return G -def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, +def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', max_nz=None, + method="L-BFGS-B", stopThr=1e-9, numItermax=500, verbose=False, log=False): r""" Solve the regularized OT problem in the semi-dual and return the OT matrix @@ -630,6 +642,10 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr= :ref:`[2] `) - 'l2' : Squared Euclidean regularization + - 'sparsity_constrained' : Sparsity-constrained regularization [50] + + max_nz : int or None, optional. Used only in the case of reg_type = 'sparsity_constrained' to specify the maximum number of nonzeros per column of the optimal plan; + not used for other regularization types. method : str Solver to use for scipy.optimize.minimize numItermax : int, optional @@ -657,6 +673,8 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr= .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS). + .. [50] Liu, T., Puigcerver, J., & Blondel, M. (2023). Sparsity-constrained optimal transport. Proceedings of the Eleventh International Conference on Learning Representations (ICLR). 
+ See Also -------- ot.lp.emd : Unregularized OT @@ -668,6 +686,11 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr= regul = SquaredL2(gamma=reg) elif reg_type.lower() in ['entropic', 'negentropy', 'kl']: regul = NegEntropy(gamma=reg) + elif reg_type.lower() in ['sparsity_constrained', 'sparsity-constrained']: + if not isinstance(max_nz, int): + raise ValueError( + f'max_nz {max_nz} must be an integer') + regul = SparsityConstrained(gamma=reg, max_nz=max_nz) else: raise NotImplementedError('Unknown regularization') @@ -683,166 +706,3 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr= return G, log else: return G - - -def sparsity_constrained_ot_dual( - a, b, M, reg, max_nz, - method="L-BFGS-B", stopThr=1e-9, - numItermax=500, verbose=False, log=False): - r""" - Solve the sparsity-constrained OT problem in the dual and return the OT matrix. - - The function solves the sparsity-contrained OT in dual formulation in - :ref:`[50] `. - - - Parameters - ---------- - a : np.ndarray (ns,) - samples weights in the source domain - b : np.ndarray (nt,) or np.ndarray (nt,nbb) - samples in the target domain, compute sinkhorn with multiple targets - and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix - (return OT loss + dual variables in log) - M : np.ndarray (ns,nt) - loss matrix - reg : float - Regularization term >0 - max_nz: int - Maximum number of non-zero entries permitted in each column of the - optimal transport matrix. - method : str - Solver to use for scipy.optimize.minimize - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - - - Returns - ------- - gamma : (ns, nt) ndarray - Optimal transportation matrix for the given parameters - log : dict - log dictionary return only if log==True in parameters - - - .. 
_references-sparsity-constrained-ot-dual: - References - ---------- - .. [50] Liu, T., Puigcerver, J., & Blondel, M. (2023). Sparsity-constrained optimal transport. Proceedings of the Eleventh International Conference on Learning Representations (ICLR). - - See Also - -------- - ot.lp.emd : Unregularized OT - ot.sinhorn : Entropic regularized OT - ot.smooth : Entropic regularized and squared l2 regularized OT - ot.optim.cg : General regularized OT - - """ - - nx = get_backend(a, b, M) - max_nz = min(max_nz, M.shape[0]) - regul = SparsityConstrained(gamma=reg, max_nz=max_nz) - - a0, b0, M0 = a, b, M - - # convert to humpy - a, b, M = nx.to_numpy(a, b, M) - - # solve dual - alpha, beta, res = solve_dual( - a, b, M, regul, - max_iter=numItermax, - tol=stopThr, verbose=verbose) - - # reconstruct transport matrix - G = nx.from_numpy(get_plan_from_dual(alpha, beta, M, regul), - type_as=M0) - - if log: - log = {'alpha': nx.from_numpy(alpha, type_as=a0), - 'beta': nx.from_numpy(beta, type_as=b0), 'res': res} - return G, log - else: - return G - - -def sparsity_constrained_ot_semi_dual( - a, b, M, reg, max_nz, - method="L-BFGS-B", stopThr=1e-9, - numItermax=500, verbose=False, log=False): - r""" - Solve the regularized OT problem in the semi-dual and return the OT matrix - - The function solves the sparsity-contrained OT in semi-dual formulation in - :ref:`[50] `. - - - Parameters - ---------- - a : np.ndarray (ns,) - samples weights in the source domain - b : np.ndarray (nt,) or np.ndarray (nt,nbb) - samples in the target domain, compute sinkhorn with multiple targets - and fixed:math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix - (return OT loss + dual variables in log) - M : np.ndarray (ns,nt) - loss matrix - reg : float - Regularization term >0 - max_nz: int - Maximum number of non-zero entries permitted in each column of the optimal transport matrix. 
- method : str - Solver to use for scipy.optimize.minimize - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) - verbose : bool, optional - Print information along iterations - log : bool, optional - record log if True - - - Returns - ------- - gamma : (ns, nt) ndarray - Optimal transportation matrix for the given parameters - log : dict - log dictionary return only if log==True in parameters - - - .. _references-sparsity-constrained-ot-semi-dual: - References - ---------- - .. [50] Liu, T., Puigcerver, J., & Blondel, M. (2023). Sparsity-constrained optimal transport. Proceedings of the Eleventh International Conference on Learning Representations (ICLR). - - See Also - -------- - ot.lp.emd : Unregularized OT - ot.sinhorn : Entropic regularized OT - ot.smooth : Entropic regularized and squared l2 regularized OT - ot.optim.cg : General regularized OT - - """ - - max_nz = min(max_nz, M.shape[0]) - regul = SparsityConstrained(gamma=reg, max_nz=max_nz) - # solve dual - alpha, res = solve_semi_dual( - a, b, M, regul, max_iter=numItermax, - tol=stopThr, verbose=verbose) - - # reconstruct transport matrix - G = get_plan_from_semi_dual(alpha, b, M, regul) - - if log: - log = {'alpha': alpha, 'res': res} - return G, log - else: - return G diff --git a/test/test_smooth.py b/test/test_smooth.py index 9688937c0..dbdd40541 100644 --- a/test/test_smooth.py +++ b/test/test_smooth.py @@ -24,6 +24,7 @@ def test_smooth_ot_dual(): with pytest.raises(NotImplementedError): Gl2, log = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='none') + # squared l2 regularisation Gl2, log = ot.smooth.smooth_ot_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10) # check constraints @@ -44,6 +45,24 @@ def test_smooth_ot_dual(): G2 = ot.sinkhorn(u, u, M, 1, stopThr=1e-10) np.testing.assert_allclose(G, G2, atol=1e-05) + # sparsity-constrained regularisation + max_nz = 2 + Gsc, log = ot.smooth.smooth_ot_dual( + u, u, M, 1, + 
max_nz=max_nz, + log=True, + reg_type='sparsity_constrained', + stopThr=1e-10) + + # check marginal constraints + np.testing.assert_allclose(u, Gsc.sum(1), atol=1e-03) + np.testing.assert_allclose(u, Gsc.sum(0), atol=1e-03) + + # check sparsity constraints + np.testing.assert_array_less( + np.sum(Gsc > 0, axis=0), + np.ones(n) * max_nz + 1) + def test_smooth_ot_semi_dual(): @@ -59,6 +78,7 @@ def test_smooth_ot_semi_dual(): with pytest.raises(NotImplementedError): Gl2, log = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='none') + # squared l2 regularisation Gl2, log = ot.smooth.smooth_ot_semi_dual(u, u, M, 1, reg_type='l2', log=True, stopThr=1e-10) # check constraints @@ -79,112 +99,19 @@ def test_smooth_ot_semi_dual(): G2 = ot.sinkhorn(u, u, M, 1, stopThr=1e-10) np.testing.assert_allclose(G, G2, atol=1e-05) - -def test_sparsity_constrained_ot_dual(): - - # get data - n = 100 - rng = np.random.RandomState(0) - - x = rng.randn(n, 2) - u = ot.utils.unif(n) - - M = ot.dist(x, x) - + # sparsity-constrained regularisation max_nz = 2 - - plan = ot.smooth.sparsity_constrained_ot_dual( - u, u, M, 1, max_nz=max_nz, stopThr=1e-10) - - # check marginal constraints - np.testing.assert_allclose(u, plan.sum(1), atol=1e-03) - np.testing.assert_allclose(u, plan.sum(0), atol=1e-03) - - # check sparsity constraint - np.testing.assert_array_less( - np.sum(plan > 0, axis=0), - np.ones(n) * max_nz + 1) - - -def test_sparsity_constrained_ot_semi_dual(): - - # get data - n = 100 - rng = np.random.RandomState(0) - - x = rng.randn(n, 2) - u = ot.utils.unif(n) - - M = ot.dist(x, x) - - max_nz = 5 - plan, log = ot.smooth.sparsity_constrained_ot_semi_dual( - u, u, M, 1, - max_nz=max_nz, - log=True, - stopThr=1e-10) + Gsc = ot.smooth.smooth_ot_semi_dual( + u, u, M, 1, reg_type='sparsity_constrained', + max_nz=max_nz, stopThr=1e-10) # check marginal constraints - np.testing.assert_allclose(u, plan.sum(1), atol=1e-03) - np.testing.assert_allclose(u, plan.sum(0), atol=1e-03) - - # check 
sparsity constraint - np.testing.assert_array_less( - np.sum(plan > 0, axis=0), - np.ones(n) * max_nz + 1) - + np.testing.assert_allclose(u, Gsc.sum(1), atol=1e-03) + np.testing.assert_allclose(u, Gsc.sum(0), atol=1e-03) -def test_projection_sparse_simplex(): - - def double_sort_projection_sparse_simplex(X, max_nz, z=1, axis=None): - r"""This is an equivalent but less efficient version - of ot.utils.projection_sparse_simplex, as it uses two - sorts instead of one. - """ - - if axis == 0: - # For each column of X, find top max_nz values and - # their corresponding indices. This incurs a sort. - max_nz_indices = np.argpartition( - X, - kth=-max_nz, - axis=0)[-max_nz:] - - max_nz_values = X[max_nz_indices, np.arange(X.shape[1])] - - # Project the top max_nz values onto the simplex. - # This incurs a second sort. - G_nz_values = ot.smooth.projection_simplex( - max_nz_values, z=z, axis=0) - - # Put the projection of max_nz_values to their original indices - # and set all other values zero. - G = np.zeros_like(X) - G[max_nz_indices, np.arange(X.shape[1])] = G_nz_values - return G - elif axis == 1: - return double_sort_projection_sparse_simplex( - X.T, max_nz, z, axis=0).T - - else: - X = X.ravel().reshape(-1, 1) - return double_sort_projection_sparse_simplex( - X, max_nz, z, axis=0).ravel() - - m, n = 5, 10 - rng = np.random.RandomState(0) - X = rng.uniform(size=(m, n)) - max_nz = 3 - - for axis in [0, 1, None]: - slow_sparse_proj = double_sort_projection_sparse_simplex( - X, max_nz, axis=axis) - fast_sparse_proj = ot.utils.projection_sparse_simplex( - X, max_nz, axis=axis) - - # check that two versions produce the same result - np.testing.assert_allclose( - slow_sparse_proj, fast_sparse_proj) + # check sparsity constraints + np.testing.assert_array_less(np.sum(Gsc > 0, axis=0), + np.ones(n) * max_nz + 1) def test_sparsity_constrained_gradient(): diff --git a/test/test_utils.py b/test/test_utils.py index 31b12efeb..658214d21 100644 --- a/test/test_utils.py +++ 
b/test/test_utils.py @@ -41,6 +41,59 @@ def test_proj_simplex(nx): np.testing.assert_allclose(l1, l2, atol=1e-5) +def test_projection_sparse_simplex(): + + def double_sort_projection_sparse_simplex(X, max_nz, z=1, axis=None): + r"""This is an equivalent but less efficient version + of ot.utils.projection_sparse_simplex, as it uses two + sorts instead of one. + """ + + if axis == 0: + # For each column of X, find top max_nz values and + # their corresponding indices. This incurs a sort. + max_nz_indices = np.argpartition( + X, + kth=-max_nz, + axis=0)[-max_nz:] + + max_nz_values = X[max_nz_indices, np.arange(X.shape[1])] + + # Project the top max_nz values onto the simplex. + # This incurs a second sort. + G_nz_values = ot.smooth.projection_simplex( + max_nz_values, z=z, axis=0) + + # Put the projection of max_nz_values to their original indices + # and set all other values zero. + G = np.zeros_like(X) + G[max_nz_indices, np.arange(X.shape[1])] = G_nz_values + return G + elif axis == 1: + return double_sort_projection_sparse_simplex( + X.T, max_nz, z, axis=0).T + + else: + X = X.ravel().reshape(-1, 1) + return double_sort_projection_sparse_simplex( + X, max_nz, z, axis=0).ravel() + + m, n = 5, 10 + rng = np.random.RandomState(0) + X = rng.uniform(size=(m, n)) + max_nz = 3 + + for axis in [0, 1, None]: + slow_sparse_proj = double_sort_projection_sparse_simplex( + X, max_nz, axis=axis) + fast_sparse_proj = ot.utils.projection_sparse_simplex( + X, max_nz, axis=axis) + + # check that two versions produce consistent results + np.testing.assert_allclose( + slow_sparse_proj, fast_sparse_proj) + + def test_parmap(): n = 10 From d734fbb7051eb01cedfdf6c48fb841944644a04f Mon Sep 17 00:00:00 2001 From: liutianlin0121 Date: Wed, 19 Apr 2023 10:41:56 +0200 Subject: [PATCH 5/7] address pep8 error --- ot/smooth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/smooth.py b/ot/smooth.py index d9cf498e9..15029fcce 100644 --- a/ot/smooth.py +++ b/ot/smooth.py @@ 
-61,6 +61,7 @@ from .backend import get_backend import ot + def projection_simplex(V, z=1, axis=None): r""" Projection of :math:`\mathbf{V}` onto the simplex, scaled by `z` @@ -643,7 +644,6 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', max_nz=None, - 'l2' : Squared Euclidean regularization - 'sparsity_constrained' : Sparsity-constrained regularization [50] - max_nz : int or None, optional. Used only in the case of reg_type = 'sparsity_constrained' to specify the maximum number of nonzeros per column of the optimal plan; not used for other regularization types. method : str From 5484ed04979a3f1939f63391c172d1a557cfcbac Mon Sep 17 00:00:00 2001 From: liutianlin0121 Date: Mon, 24 Apr 2023 17:32:41 +0200 Subject: [PATCH 6/7] add backends for --- examples/plot_OT_1D_smooth.py | 37 +++++------------------------- ot/smooth.py | 8 +++---- ot/utils.py | 43 ++++++++++++++++++++++------------- 3 files changed, 37 insertions(+), 51 deletions(-) diff --git a/examples/plot_OT_1D_smooth.py b/examples/plot_OT_1D_smooth.py index 548ddc81a..ff51b8abf 100644 --- a/examples/plot_OT_1D_smooth.py +++ b/examples/plot_OT_1D_smooth.py @@ -1,11 +1,12 @@ # -*- coding: utf-8 -*- """ ================================ -Smooth optimal transport example +Smooth and sparse OT example ================================ -This example illustrates the computation of EMD, Sinkhorn, smooth OT plans, -sparsity-constrained plans, and their visualization. +This example illustrates the computation of +Smooth and Sparse (KL an L2 reg.) OT and +sparsity-constrained OT, together with their visualizations. 
""" @@ -58,32 +59,6 @@ pl.figure(2, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, M, 'Cost matrix M') -############################################################################## -# Solve EMD -# --------- - - -#%% EMD - -G0 = ot.emd(a, b, M) - -pl.figure(3, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0') - -############################################################################## -# Solve Sinkhorn -# -------------- - - -#%% Sinkhorn - -lambd = 2e-3 -Gs = ot.sinkhorn(a, b, M, lambd, verbose=True) - -pl.figure(4, figsize=(5, 5)) -ot.plot.plot1D_mat(a, b, Gs, 'OT matrix Sinkhorn') - -pl.show() ############################################################################## # Solve Smooth OT @@ -95,7 +70,7 @@ lambd = 2e-3 Gsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='kl') -pl.figure(5, figsize=(5, 5)) +pl.figure(3, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT KL reg.') pl.show() @@ -106,7 +81,7 @@ lambd = 1e-1 Gsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='l2') -pl.figure(6, figsize=(5, 5)) +pl.figure(4, figsize=(5, 5)) ot.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT l2 reg.') pl.show() diff --git a/ot/smooth.py b/ot/smooth.py index 15029fcce..331cfc04e 100644 --- a/ot/smooth.py +++ b/ot/smooth.py @@ -24,10 +24,10 @@ # Author: Mathieu Blondel # Remi Flamary +# Tianlin Liu """ -Smooth and Sparse Optimal Transport solvers (KL an L2 reg.) [17] and sparsity -constrained Optimal Transport solvers [50]. +Smooth and Sparse (KL an L2 reg.) and sparsity-constrained OT solvers. Implementation of : Smooth and Sparse Optimal Transport. 
@@ -483,9 +483,9 @@ def get_plan_from_semi_dual(alpha, b, C, regul): return regul.max_Omega(X, b)[1] * b -def smooth_ot_dual(a, b, M, reg, reg_type='l2', max_nz=None, +def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9, - numItermax=500, verbose=False, log=False): + numItermax=500, verbose=False, log=False, max_nz=None): r""" Solve the regularized OT problem in the dual and return the OT matrix diff --git a/ot/utils.py b/ot/utils.py index 873975bbf..3343028ec 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -15,7 +15,7 @@ import sys import warnings from inspect import signature -from .backend import get_backend, Backend, NumpyBackend +from .backend import get_backend, Backend, NumpyBackend, JaxBackend __time_tic_toc = time.time() @@ -117,7 +117,7 @@ def proj_simplex(v, z=1): return w -def projection_sparse_simplex(V, max_nz, z=1, axis=None): +def projection_sparse_simplex(V, max_nz, z=1, axis=None, nx=None): r"""Projection of :math:`\mathbf{V}` onto the simplex with cardinality constraint (maximum number of non-zero elements) and then scaled by `z`. .. math:: @@ -143,10 +143,12 @@ def projection_sparse_simplex(V, max_nz, z=1, axis=None): ICML 2013 https://arxiv.org/abs/1206.1529 """ - + if nx is None: + nx = get_backend(V) if V.ndim == 1: return projection_sparse_simplex( - V[np.newaxis, :], max_nz, z, axis=1).ravel() + # V[nx.newaxis, :], max_nz, z, axis=1).ravel() + V[None, :], max_nz, z, axis=1).ravel() if V.ndim > 2: raise ValueError('V.ndim must be <= 2') @@ -155,26 +157,35 @@ def projection_sparse_simplex(V, max_nz, z=1, axis=None): # For each row of V, find top max_nz values; arrange the # corresponding column indices such that their values are # in a descending order. 
- max_nz_indices = np.argsort(V, axis=1)[:, -max_nz:] - max_nz_indices = max_nz_indices[:, ::-1] - - row_indices = np.arange(V.shape[0])[:, np.newaxis] + max_nz_indices = nx.argsort(V, axis=1)[:, -max_nz:] + max_nz_indices = nx.flip(max_nz_indices, axis=1) + row_indices = nx.arange(V.shape[0]) + row_indices = row_indices.reshape(-1, 1) + print(row_indices.shape) # Extract the top max_nz values for each row # and then project to simplex. U = V[row_indices, max_nz_indices] - z = np.ones(len(U)) * z - cssv = np.cumsum(U, axis=1) - z[:, np.newaxis] - ind = np.arange(max_nz) + 1 + z = nx.ones(len(U)) * z + cssv = nx.cumsum(U, axis=1) - z[:, None] + ind = nx.arange(max_nz) + 1 cond = U - cssv / ind > 0 - rho = np.count_nonzero(cond, axis=1) - theta = cssv[np.arange(len(U)), rho - 1] / rho - nz_projection = np.maximum(U - theta[:, np.newaxis], 0) + # rho = nx.count_nonzero(cond, axis=1) + rho = nx.sum(cond, axis=1) + theta = cssv[nx.arange(len(U)), rho - 1] / rho + nz_projection = nx.maximum(U - theta[:, None], 0) # Put the projection of max_nz_values to their original column indices # while keeping other values zero. - sparse_projection = np.zeros_like(V) - sparse_projection[row_indices, max_nz_indices] = nz_projection + sparse_projection = nx.zeros(V.shape, type_as=nz_projection) + + if isinstance(nx, JaxBackend): + # in Jax, we need to use the `at` property of `jax.numpy.ndarray` + # to do in-place array modifications. 
+ sparse_projection = sparse_projection.at[ + row_indices, max_nz_indices].set(nz_projection) + else: + sparse_projection[row_indices, max_nz_indices] = nz_projection return sparse_projection elif axis == 0: From 9206c0063230df312a1da606fe81cc6fa3b94be8 Mon Sep 17 00:00:00 2001 From: liutianlin0121 Date: Tue, 25 Apr 2023 10:32:42 +0200 Subject: [PATCH 7/7] update releases --- RELEASES.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/RELEASES.md b/RELEASES.md index 214cc2a63..28635ff8c 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -3,7 +3,7 @@ ## 0.9.1dev #### New features - + - Added the sparsity-constrained OT solver to `ot.smooth` and added `projection_sparse_simplex` to `ot.utils` (PR #459) #### Closed issues - Fix circleci-redirector action and codecov (PR #460)