[MRG] add the sparsity-constrained optimal transport functionality and example #459

Merged
merged 14 commits into from
Apr 25, 2023
2 changes: 2 additions & 0 deletions README.md
@@ -308,3 +308,5 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer
[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty (2022). [Semi-relaxed Gromov-Wasserstein divergence and applications on graphs](https://openreview.net/pdf?id=RShaMexjc-x). International Conference on Learning Representations (ICLR), 2022.

[49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020). [CO-Optimal Transport](https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf). Advances in Neural Information Processing Systems, 33.

[50] Liu, T., Puigcerver, J., & Blondel, M. (2023). [Sparsity-constrained optimal transport](https://openreview.net/forum?id=yHY9NbQJ5BP). Proceedings of the Eleventh International Conference on Learning Representations (ICLR).
3 changes: 1 addition & 2 deletions RELEASES.md
@@ -3,9 +3,8 @@
## 0.9.1dev

#### New features

- Make alpha parameter in Fused Gromov Wasserstein differentiable (PR #463)

- Added the sparsity-constrained OT solver to `ot.smooth` and added `projection_sparse_simplex` to `ot.utils` (PR #459)
#### Closed issues

- Fix circleci-redirector action and codecov (PR #460)
51 changes: 19 additions & 32 deletions examples/plot_OT_1D_smooth.py
@@ -1,11 +1,12 @@
# -*- coding: utf-8 -*-
"""
================================
Smooth optimal transport example
Smooth and sparse OT example
================================

This example illustrates the computation of EMD, Sinkhorn and smooth OT plans
and their visualization.
This example illustrates the computation of
Smooth and Sparse (KL and L2 reg.) OT and
sparsity-constrained OT, together with their visualizations.

"""

@@ -58,32 +59,6 @@
pl.figure(2, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')

##############################################################################
# Solve EMD
# ---------


#%% EMD

G0 = ot.emd(a, b, M)

pl.figure(3, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0')

##############################################################################
# Solve Sinkhorn
# --------------


#%% Sinkhorn

lambd = 2e-3
Gs = ot.sinkhorn(a, b, M, lambd, verbose=True)

pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, Gs, 'OT matrix Sinkhorn')

pl.show()

##############################################################################
# Solve Smooth OT
@@ -95,18 +70,30 @@
lambd = 2e-3
Gsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='kl')

pl.figure(5, figsize=(5, 5))
pl.figure(3, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT KL reg.')

pl.show()


#%% Smooth OT with KL regularization
#%% Smooth OT with squared l2 regularization

lambd = 1e-1
Gsm = ot.smooth.smooth_ot_dual(a, b, M, lambd, reg_type='l2')

pl.figure(6, figsize=(5, 5))
pl.figure(4, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, Gsm, 'OT matrix Smooth OT l2 reg.')

pl.show()

#%% Sparsity-constrained OT

lambd = 1e-1

max_nz = 2  # two non-zero entries are permitted per column of the OT plan
Gsc = ot.smooth.smooth_ot_dual(
    a, b, M, lambd, reg_type='sparsity_constrained', max_nz=max_nz)
pl.figure(5, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, Gsc, 'Sparsity-constrained OT matrix; k=2.')

pl.show()
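
For readers of the example, the problem behind the `'sparsity_constrained'` option can be sketched as below. This is a reading of reference [50] together with the `Omega` term added in `ot/smooth.py`, not a statement taken from this PR: here $\gamma$ stands for `lambd`, $k$ for `max_nz`, and $\mathcal{U}(a, b)$ for the set of non-negative plans with marginals `a` and `b` (these symbols are assumptions introduced for illustration). The code approaches this problem through its dual (`smooth_ot_dual`) or semi-dual (`smooth_ot_semi_dual`).

```latex
\min_{T \in \mathcal{U}(a, b)} \ \langle T, M \rangle + \frac{\gamma}{2} \|T\|_F^2
\quad \text{s.t.} \quad \|T_{:, j}\|_0 \le k \ \ \text{for every column } j,
\qquad
\mathcal{U}(a, b) = \{ T \ge 0 : T \mathbf{1} = a, \ T^\top \mathbf{1} = b \}
```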
80 changes: 75 additions & 5 deletions ot/smooth.py
@@ -24,27 +24,42 @@

# Author: Mathieu Blondel
# Remi Flamary <remi.flamary@unice.fr>
# Tianlin Liu <t.liu@unibas.ch>

"""
Smooth and Sparse Optimal Transport solvers (KL an L2 reg.)
Smooth and Sparse (KL and L2 reg.) and sparsity-constrained OT solvers.

Implementation of :
Smooth and Sparse Optimal Transport.
Mathieu Blondel, Vivien Seguy, Antoine Rolet.
In Proc. of AISTATS 2018.
https://arxiv.org/abs/1710.06276

(Original code from https://github.com/mblondel/smooth-ot/)

Sparsity-Constrained Optimal Transport.
Liu, T., Puigcerver, J., & Blondel, M. (2023).
Sparsity-constrained optimal transport.
Proceedings of the Eleventh International Conference on
Learning Representations (ICLR).
https://arxiv.org/abs/2209.15466


[17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal
Transport. Proceedings of the Twenty-First International Conference on
Artificial Intelligence and Statistics (AISTATS).

Original code from https://github.com/mblondel/smooth-ot/
[50] Liu, T., Puigcerver, J., & Blondel, M. (2023).
Sparsity-constrained optimal transport.
Proceedings of the Eleventh International Conference on
Learning Representations (ICLR).

"""

import numpy as np
from scipy.optimize import minimize
from .backend import get_backend
import ot


def projection_simplex(V, z=1, axis=None):
@@ -209,6 +224,39 @@ def Omega(self, T):
return 0.5 * self.gamma * np.sum(T ** 2)


class SparsityConstrained(Regularization):
    """ Squared L2 regularization with sparsity constraints """

    def __init__(self, max_nz, gamma=1.0):
        self.max_nz = max_nz
        self.gamma = gamma

    def delta_Omega(self, X):
        # For each column of X, find entries that are not among the top max_nz.
        non_top_indices = np.argpartition(
            -X, self.max_nz, axis=0)[self.max_nz:]
        # Zero these entries in place; after the clipping at 0 below they
        # contribute nothing to the value or the gradient.
        if X.ndim == 1:
            X[non_top_indices] = 0.0
        else:
            X[non_top_indices, np.arange(X.shape[1])] = 0.0
        max_X = np.maximum(X, 0)
        val = np.sum(max_X ** 2, axis=0) / (2 * self.gamma)
        G = max_X / self.gamma
        return val, G

    def max_Omega(self, X, b):
        # Project the scaled X onto the simplex with sparsity constraint.
        G = ot.utils.projection_sparse_simplex(
            X / (b * self.gamma), self.max_nz, axis=0)
        val = np.sum(X * G, axis=0)
        val -= 0.5 * self.gamma * b * np.sum(G * G, axis=0)
        return val, G

    def Omega(self, T):
        return 0.5 * self.gamma * np.sum(T ** 2)
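
To make the top-`max_nz` step in `delta_Omega` easier to follow, here is a minimal NumPy-only sketch of the same idea on a small matrix. The function name `topk_relu_grad` and the sample values are hypothetical and chosen for illustration; unlike the method above, the sketch copies its input instead of modifying it in place.

```python
import numpy as np


def topk_relu_grad(X, max_nz, gamma=1.0):
    """Hypothetical helper mirroring the logic of delta_Omega on a 2-D array."""
    X = np.array(X, dtype=float)  # copy, so the caller's array is untouched
    # Row indices of the entries that are NOT among the max_nz largest
    # of each column (requires max_nz < X.shape[0]).
    drop = np.argpartition(-X, max_nz, axis=0)[max_nz:]
    X[drop, np.arange(X.shape[1])] = 0.0
    # Clip at zero, then form the value and the gradient per column.
    pos = np.maximum(X, 0.0)
    val = np.sum(pos ** 2, axis=0) / (2 * gamma)
    return val, pos / gamma


X = np.array([[3.0, -1.0],
              [1.0, 2.0],
              [2.0, 0.5]])
val, G = topk_relu_grad(X, max_nz=2, gamma=1.0)
print(val)  # per-column values
print(G)    # at most max_nz non-zeros in each column
```

Note that `np.argpartition` needs `max_nz` to be strictly smaller than the number of rows, so this sketch (like the method it mirrors) assumes that condition holds.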


def dual_obj_grad(alpha, beta, a, b, C, regul):
r"""
Compute objective value and gradients of dual objective.
@@ -435,8 +483,9 @@ def get_plan_from_semi_dual(alpha, b, C, regul):
return regul.max_Omega(X, b)[1] * b


def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
numItermax=500, verbose=False, log=False):
def smooth_ot_dual(a, b, M, reg, reg_type='l2',
method="L-BFGS-B", stopThr=1e-9,
numItermax=500, verbose=False, log=False, max_nz=None):
r"""
Solve the regularized OT problem in the dual and return the OT matrix

@@ -477,6 +526,9 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
:ref:`[2] <references-smooth-ot-dual>`)

- 'l2' : Squared Euclidean regularization
- 'sparsity_constrained' : Sparsity-constrained regularization [50]
max_nz : int or None, optional
Maximum number of nonzero entries allowed per column of the optimal plan.
Used only when reg_type = 'sparsity_constrained', in which case it must be an int;
ignored for other regularization types.
method : str
Solver to use for scipy.optimize.minimize
numItermax : int, optional
@@ -504,6 +556,8 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,

.. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).

.. [50] Liu, T., Puigcerver, J., & Blondel, M. (2023). Sparsity-constrained optimal transport. Proceedings of the Eleventh International Conference on Learning Representations (ICLR).

See Also
--------
ot.lp.emd : Unregularized OT
@@ -518,6 +572,11 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
regul = SquaredL2(gamma=reg)
elif reg_type.lower() in ['entropic', 'negentropy', 'kl']:
regul = NegEntropy(gamma=reg)
elif reg_type.lower() in ['sparsity_constrained', 'sparsity-constrained']:
if not isinstance(max_nz, int):
raise ValueError(
f'max_nz {max_nz} must be an integer')
regul = SparsityConstrained(gamma=reg, max_nz=max_nz)
else:
raise NotImplementedError('Unknown regularization')

@@ -539,7 +598,8 @@ def smooth_ot_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
return G


def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=1e-9,
def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', max_nz=None,
method="L-BFGS-B", stopThr=1e-9,
numItermax=500, verbose=False, log=False):
r"""
Solve the regularized OT problem in the semi-dual and return the OT matrix
@@ -583,6 +643,9 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=
:ref:`[2] <references-smooth-ot-semi-dual>`)

- 'l2' : Squared Euclidean regularization
- 'sparsity_constrained' : Sparsity-constrained regularization [50]
max_nz : int or None, optional
Maximum number of nonzero entries allowed per column of the optimal plan.
Used only when reg_type = 'sparsity_constrained', in which case it must be an int;
ignored for other regularization types.
method : str
Solver to use for scipy.optimize.minimize
numItermax : int, optional
@@ -610,6 +673,8 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=

.. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse Optimal Transport. Proceedings of the Twenty-First International Conference on Artificial Intelligence and Statistics (AISTATS).

.. [50] Liu, T., Puigcerver, J., & Blondel, M. (2023). Sparsity-constrained optimal transport. Proceedings of the Eleventh International Conference on Learning Representations (ICLR).

See Also
--------
ot.lp.emd : Unregularized OT
@@ -621,6 +686,11 @@ def smooth_ot_semi_dual(a, b, M, reg, reg_type='l2', method="L-BFGS-B", stopThr=
regul = SquaredL2(gamma=reg)
elif reg_type.lower() in ['entropic', 'negentropy', 'kl']:
regul = NegEntropy(gamma=reg)
elif reg_type.lower() in ['sparsity_constrained', 'sparsity-constrained']:
if not isinstance(max_nz, int):
raise ValueError(
f'max_nz {max_nz} must be an integer')
regul = SparsityConstrained(gamma=reg, max_nz=max_nz)
else:
raise NotImplementedError('Unknown regularization')
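
One thing to keep in mind about the `isinstance(max_nz, int)` check above: on Python 3 it rejects NumPy integer scalars such as `np.int64(2)` and accepts `True`. The sketch below shows one possible stricter-but-friendlier validation using only the standard library. The helper name `check_max_nz` and the lower bound of 1 are assumptions made for illustration; neither is part of this PR.

```python
import numbers


def check_max_nz(max_nz):
    """Hypothetical validator: accept any integral value >= 1, reject bools."""
    # bool is a subclass of int, so it has to be excluded explicitly.
    if isinstance(max_nz, bool) or not isinstance(max_nz, numbers.Integral):
        raise ValueError(f'max_nz {max_nz} must be an integer')
    if max_nz < 1:
        raise ValueError(f'max_nz {max_nz} must be at least 1')
    return int(max_nz)


print(check_max_nz(2))
```

NumPy registers its integer scalar types with `numbers.Integral`, so a value such as `np.int64(2)` would pass this check even though it fails `isinstance(..., int)`.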

81 changes: 80 additions & 1 deletion ot/utils.py
@@ -15,7 +15,7 @@
import sys
import warnings
from inspect import signature
from .backend import get_backend, Backend, NumpyBackend
from .backend import get_backend, Backend, NumpyBackend, JaxBackend

__time_tic_toc = time.time()

@@ -117,6 +117,85 @@ def proj_simplex(v, z=1):
return w


def projection_sparse_simplex(V, max_nz, z=1, axis=None, nx=None):
    r"""Projection of :math:`\mathbf{V}` onto the simplex scaled by `z`, with a cardinality constraint (at most `max_nz` non-zero elements).

    .. math::
        P\left(\mathbf{V}, \text{max_nz}, z\right) = \mathop{\arg \min}_{\substack{\mathbf{y} \geq 0 \\ \sum_i \mathbf{y}_i = z \\ \|\mathbf{y}\|_0 \le \text{max_nz}}} \quad \|\mathbf{y} - \mathbf{V}\|^2

    Parameters
    ----------
    V: 1-dim or 2-dim ndarray
    max_nz: int
        Maximum number of non-zero elements kept in each projected vector
    z: float or array
        If array, len(z) must be compatible with :math:`\mathbf{V}`
    axis: None or int
        - axis=None: project :math:`\mathbf{V}` by :math:`P(\mathbf{V}.\mathrm{ravel}(), \text{max_nz}, z)`
        - axis=1: project each :math:`\mathbf{V}_i` by :math:`P(\mathbf{V}_i, \text{max_nz}, z_i)`
        - axis=0: project each :math:`\mathbf{V}_{:, j}` by :math:`P(\mathbf{V}_{:, j}, \text{max_nz}, z_j)`
    nx: Backend, optional
        Backend to use; if None, it is inferred from `V` with `get_backend`

    Returns
    -------
    projection: ndarray, shape :math:`\mathbf{V}`.shape

    References:
        Sparse projections onto the simplex
        Anastasios Kyrillidis, Stephen Becker, Volkan Cevher, and Christoph Koch
        ICML 2013
        https://arxiv.org/abs/1206.1529
    """
    if nx is None:
        nx = get_backend(V)
    if V.ndim == 1:
        return projection_sparse_simplex(
            V[None, :], max_nz, z, axis=1).ravel()

    if V.ndim > 2:
        raise ValueError('V.ndim must be <= 2')

    if axis == 1:
        # For each row of V, find top max_nz values; arrange the
        # corresponding column indices such that their values are
        # in a descending order.
        max_nz_indices = nx.argsort(V, axis=1)[:, -max_nz:]
        max_nz_indices = nx.flip(max_nz_indices, axis=1)

        row_indices = nx.arange(V.shape[0])
        row_indices = row_indices.reshape(-1, 1)
        # Extract the top max_nz values for each row
        # and then project to simplex.
        U = V[row_indices, max_nz_indices]
        z = nx.ones(len(U)) * z
        cssv = nx.cumsum(U, axis=1) - z[:, None]
        ind = nx.arange(max_nz) + 1
        cond = U - cssv / ind > 0
        # rho = nx.count_nonzero(cond, axis=1)
        rho = nx.sum(cond, axis=1)
        theta = cssv[nx.arange(len(U)), rho - 1] / rho
        nz_projection = nx.maximum(U - theta[:, None], 0)

        # Put the projection of max_nz_values to their original column indices
        # while keeping other values zero.
        sparse_projection = nx.zeros(V.shape, type_as=nz_projection)

        if isinstance(nx, JaxBackend):
            # JAX arrays are immutable, so use the `at` property of
            # `jax.numpy.ndarray`, which returns an updated copy.
            sparse_projection = sparse_projection.at[
                row_indices, max_nz_indices].set(nz_projection)
        else:
            sparse_projection[row_indices, max_nz_indices] = nz_projection
        return sparse_projection

    elif axis == 0:
        return projection_sparse_simplex(V.T, max_nz, z, axis=1).T

    else:
        V = V.ravel().reshape(1, -1)
        return projection_sparse_simplex(V, max_nz, z, axis=1).ravel()
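
The `axis == 1` branch above can be easier to review on a single vector. The following NumPy-only sketch follows the same steps (select the `max_nz` largest entries, run the sort-based simplex projection on them, scatter the result back). The function name `sparse_simplex_projection_1d` and the sample input are hypothetical, and the sketch deliberately skips the backend abstraction (`nx`) used in the PR.

```python
import numpy as np


def sparse_simplex_projection_1d(v, max_nz, z=1.0):
    """Hypothetical 1-D sketch of the sparse simplex projection (z > 0)."""
    v = np.asarray(v, dtype=float)
    # Indices of the max_nz largest entries, in descending order of value.
    top = np.argsort(v)[-max_nz:][::-1]
    u = v[top]
    # Sort-based simplex projection restricted to the selected entries.
    cssv = np.cumsum(u) - z
    ind = np.arange(1, max_nz + 1)
    rho = np.count_nonzero(u - cssv / ind > 0)
    theta = cssv[rho - 1] / rho
    # Scatter the projected values back; every other entry stays zero.
    out = np.zeros_like(v)
    out[top] = np.maximum(u - theta, 0.0)
    return out


p = sparse_simplex_projection_1d([0.5, 0.1, 0.4, 0.3], max_nz=2)
print(p, p.sum())
```

With `max_nz=2`, only the two largest inputs (0.5 and 0.4) can stay non-zero, and the result sums to `z`.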


def unif(n, type_as=None):
r"""
Return a uniform histogram of length `n` (simplex).