Skip to content
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ This open source Python library provide several solvers for optimization problem
It provides the following solvers:

* OT Network Flow solver for the linear program/ Earth Movers Distance [1].
* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] with optional GPU implementation (requires cudamat).
* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] and greedy SInkhorn [22] with optional GPU implementation (requires cudamat).
* Smooth optimal transport solvers (dual and semi-dual) for KL and squared L2 regularizations [17].
* Non regularized Wasserstein barycenters [16] with LP solver (only small scale).
* Bregman projections for Wasserstein barycenter [3] and unmixing [4].
* Bregman projections for Wasserstein barycenter [3], convolutional barycenter [21] and unmixing [4].
* Optimal transport for domain adaptation with group lasso regularization [5]
* Conditional gradient [6] and Generalized conditional gradient for regularized OT [7].
* Linear OT [14] and Joint OT matrix and mapping estimation [8].
Expand Down Expand Up @@ -165,6 +165,7 @@ The contributors to this library are:
* [Antoine Rolet](https://arolet.github.io/)
* Erwan Vautier (Gromov-Wasserstein)
* [Kilian Fatras](https://kilianfatras.github.io/)
* [Alain Rakotomamonjy](https://sites.google.com/site/alainrakotomamonjy/home)

This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):

Expand Down Expand Up @@ -230,3 +231,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[20] Cuturi, M. and Doucet, A. (2014) [Fast Computation of Wasserstein Barycenters](http://proceedings.mlr.press/v32/cuturi14.html). International Conference in Machine Learning

[21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). [Convolutional wasserstein distances: Efficient optimal transportation on geometric domains](https://dl.acm.org/citation.cfm?id=2766963). ACM Transactions on Graphics (TOG), 34(4), 66.

[22] J. Altschuler, J.Weed, P. Rigollet, (2017) [Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration](https://papers.nips.cc/paper/6792-near-linear-time-approximation-algorithms-for-optimal-transport-via-sinkhorn-iteration.pdf), Advances in Neural Information Processing Systems (NIPS) 31
38 changes: 35 additions & 3 deletions docs/source/readme.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@ It provides the following solvers:
- OT Network Flow solver for the linear program/ Earth Movers Distance
[1].
- Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2]
and stabilized version [9][10] with optional GPU implementation
(requires cudamat).
and stabilized version [9][10] and greedy SInkhorn [22] with optional
GPU implementation (requires cudamat).
- Smooth optimal transport solvers (dual and semi-dual) for KL and
squared L2 regularizations [17].
- Non regularized Wasserstein barycenters [16] with LP solver (only
small scale).
- Bregman projections for Wasserstein barycenter [3] and unmixing [4].
- Bregman projections for Wasserstein barycenter [3], convolutional
barycenter [21] and unmixing [4].
- Optimal transport for domain adaptation with group lasso
regularization [5]
- Conditional gradient [6] and Generalized conditional gradient for
Expand All @@ -29,6 +30,9 @@ It provides the following solvers:
pymanopt).
- Gromov-Wasserstein distances and barycenters ([13] and regularized
[12])
- Stochastic Optimization for Large-scale Optimal Transport (semi-dual
problem [18] and dual problem [19])
- Non regularized free support Wasserstein barycenters [20].

Some demonstrations (both in Python and Jupyter Notebook format) are
available in the examples folder.
Expand Down Expand Up @@ -219,6 +223,9 @@ The contributors to this library are:
- `Stanislas Chambon <https://slasnista.github.io/>`__
- `Antoine Rolet <https://arolet.github.io/>`__
- Erwan Vautier (Gromov-Wasserstein)
- `Kilian Fatras <https://kilianfatras.github.io/>`__
- `Alain
Rakotomamonjy <https://sites.google.com/site/alainrakotomamonjy/home>`__

This toolbox benefit a lot from open source research and we would like
to thank the following persons for providing some code (in various
Expand Down Expand Up @@ -334,6 +341,31 @@ Optimal Transport <https://arxiv.org/abs/1710.06276>`__. Proceedings of
the Twenty-First International Conference on Artificial Intelligence and
Statistics (AISTATS).

[18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) `Stochastic
Optimization for Large-scale Optimal
Transport <https://arxiv.org/abs/1605.08527>`__. Advances in Neural
Information Processing Systems (2016).

[19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet,
A.& Blondel, M. `Large-scale Optimal Transport and Mapping
Estimation <https://arxiv.org/pdf/1711.02283.pdf>`__. International
Conference on Learning Representation (2018)

[20] Cuturi, M. and Doucet, A. (2014) `Fast Computation of Wasserstein
Barycenters <http://proceedings.mlr.press/v32/cuturi14.html>`__.
International Conference in Machine Learning

[21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A.,
Nguyen, A. & Guibas, L. (2015). `Convolutional wasserstein distances:
Efficient optimal transportation on geometric
domains <https://dl.acm.org/citation.cfm?id=2766963>`__. ACM
Transactions on Graphics (TOG), 34(4), 66.

[22] J. Altschuler, J.Weed, P. Rigollet, (2017) `Near-linear time
approximation algorithms for optimal transport via Sinkhorn
iteration <https://papers.nips.cc/paper/6792-near-linear-time-approximation-algorithms-for-optimal-transport-via-sinkhorn-iteration.pdf>`__,
Advances in Neural Information Processing Systems (NIPS) 31

.. |PyPI version| image:: https://badge.fury.io/py/POT.svg
:target: https://badge.fury.io/py/POT
.. |Anaconda Cloud| image:: https://anaconda.org/conda-forge/pot/badges/version.svg
Expand Down
151 changes: 150 additions & 1 deletion ot/bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
reg : float
Regularization term >0
method : str
method used for the solver either 'sinkhorn', 'sinkhorn_stabilized' or
method used for the solver either 'sinkhorn', 'greenkhorn', 'sinkhorn_stabilized' or
'sinkhorn_epsilon_scaling', see those function for specific parameters
numItermax : int, optional
Max number of iterations
Expand Down Expand Up @@ -103,6 +103,10 @@ def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000,
def sink():
return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log, **kwargs)
elif method.lower() == 'greenkhorn':
def sink():
return greenkhorn(a, b, M, reg, numItermax=numItermax,
stopThr=stopThr, verbose=verbose, log=log)
elif method.lower() == 'sinkhorn_stabilized':
def sink():
return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
Expand Down Expand Up @@ -197,13 +201,16 @@ def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,

.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.

[21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017



See Also
--------
ot.lp.emd : Unregularized OT
ot.optim.cg : General regularized OT
ot.bregman.sinkhorn_knopp : Classic Sinkhorn [2]
ot.bregman.greenkhorn : Greenkhorn [21]
ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn [9][10]
ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epslilon scaling [9][10]

Expand Down Expand Up @@ -410,6 +417,148 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000,
return u.reshape((-1, 1)) * K * v.reshape((1, -1))


def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, log=False):
"""
Solve the entropic regularization optimal transport problem and return the OT matrix

The algorithm used is based on the paper

Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration
by Jason Altschuler, Jonathan Weed, Philippe Rigollet
appeared at NIPS 2017

which is a stochastic version of the Sinkhorn-Knopp algorithm [2].

The function solves the following optimization problem:

.. math::
\gamma = arg\min_\gamma <\gamma,M>_F + reg\cdot\Omega(\gamma)

s.t. \gamma 1 = a

\gamma^T 1= b

\gamma\geq 0
where :

- M is the (ns,nt) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- a and b are source and target weights (sum to 1)



Parameters
----------
a : np.ndarray (ns,)
samples weights in the source domain
b : np.ndarray (nt,) or np.ndarray (nt,nbb)
samples in the target domain, compute sinkhorn with multiple targets
and fixed M if b is a matrix (return OT loss + dual variables in log)
M : np.ndarray (ns,nt)
loss matrix
reg : float
Regularization term >0
numItermax : int, optional
Max number of iterations
stopThr : float, optional
Stop threshol on error (>0)
log : bool, optional
record log if True


Returns
-------
gamma : (ns x nt) ndarray
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters

Examples
--------

>>> import ot
>>> a=[.5,.5]
>>> b=[.5,.5]
>>> M=[[0.,1.],[1.,0.]]
>>> ot.bregman.greenkhorn(a,b,M,1)
array([[ 0.36552929, 0.13447071],
[ 0.13447071, 0.36552929]])


References
----------

.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
[22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time approximation algorithms for optimal transport via Sinkhorn iteration, Advances in Neural Information Processing Systems (NIPS) 31, 2017


See Also
--------
ot.lp.emd : Unregularized OT
ot.optim.cg : General regularized OT

"""

n = a.shape[0]
m = b.shape[0]

# Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute
K = np.empty_like(M)
np.divide(M, -reg, out=K)
np.exp(K, out=K)

u = np.full(n, 1. / n)
v = np.full(m, 1. / m)
G = u[:, np.newaxis] * K * v[np.newaxis, :]

viol = G.sum(1) - a
viol_2 = G.sum(0) - b
stopThr_val = 1
if log:
log['u'] = u
log['v'] = v

for i in range(numItermax):
i_1 = np.argmax(np.abs(viol))
i_2 = np.argmax(np.abs(viol_2))
m_viol_1 = np.abs(viol[i_1])
m_viol_2 = np.abs(viol_2[i_2])
stopThr_val = np.maximum(m_viol_1, m_viol_2)

if m_viol_1 > m_viol_2:
old_u = u[i_1]
u[i_1] = a[i_1] / (K[i_1, :].dot(v))
G[i_1, :] = u[i_1] * K[i_1, :] * v

viol[i_1] = u[i_1] * K[i_1, :].dot(v) - a[i_1]
viol_2 += (K[i_1, :].T * (u[i_1] - old_u) * v)

else:
old_v = v[i_2]
v[i_2] = b[i_2] / (K[:, i_2].T.dot(u))
G[:, i_2] = u * K[:, i_2] * v[i_2]
#aviol = (G@one_m - a)
#aviol_2 = (G.T@one_n - b)
viol += (-old_v + v[i_2]) * K[:, i_2] * u
viol_2[i_2] = v[i_2] * K[:, i_2].dot(u) - b[i_2]

#print('b',np.max(abs(aviol -viol)),np.max(abs(aviol_2 - viol_2)))

if stopThr_val <= stopThr:
break
else:
print('Warning: Algorithm did not converge')

if log:
log['u'] = u
log['v'] = v

if log:
return G, log
else:
return G


def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
warmstart=None, verbose=False, print_period=20, log=False, **kwargs):
"""
Expand Down
3 changes: 3 additions & 0 deletions test/test_bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,14 @@ def test_sinkhorn_variants():
Ges = ot.sinkhorn(
u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)
Gerr = ot.sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10)
G_green = ot.sinkhorn(u, u, M, 1, method='greenkhorn', stopThr=1e-10)

# check values
np.testing.assert_allclose(G0, Gs, atol=1e-05)
np.testing.assert_allclose(G0, Ges, atol=1e-05)
np.testing.assert_allclose(G0, Gerr)
np.testing.assert_allclose(G0, G_green, atol=1e-5)
print(G0, G_green)


def test_bary():
Expand Down