[MRG] Batch/Lazy Log Sinkhorn Knopp on samples (#259)
* Add batch implementation of Sinkhorn

* Reformat to pep8 and modify parameter

* Fix error in batch size

* Code review and add test

* Fix accidental typo in test_empirical_sinkhorn

* Remove whitespace

* Edit config.yml
6Ulm committed Jun 14, 2021
1 parent 982510e commit 2dbeeda
Showing 3 changed files with 158 additions and 21 deletions.
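For context, the lazy path added in this commit runs the standard log-domain Sinkhorn iteration with the dual potentials already scaled by 1/reg. In that convention, for marginals a, b and cost matrix M, each sweep computes

$$f_i \leftarrow \log a_i - \log\sum_j \exp\big(g_j - M_{ij}/\mathrm{reg}\big), \qquad g_j \leftarrow \log b_j - \log\sum_i \exp\big(f_i - M_{ij}/\mathrm{reg}\big)$$

and the transport plan is recovered as $\pi_{ij} = \exp(f_i + g_j - M_{ij}/\mathrm{reg})$. Each log-sum-exp only ever needs one block of rows or columns of M, so the full ns x nt cost matrix is never materialized.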
1 change: 1 addition & 0 deletions .circleci/config.yml
@@ -73,6 +73,7 @@ jobs:
           command: |
             cd docs;
             make html;
+          no_output_timeout: 30m
 
       # Save the outputs
       - store_artifacts:
134 changes: 113 additions & 21 deletions ot/bregman.py
@@ -11,13 +11,15 @@
 #         Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com>
 #         Alexander Tong <alexander.tong@yale.edu>
 #         Ievgen Redko <ievgen.redko@univ-st-etienne.fr>
+#         Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
 #
 # License: MIT License
 
 import warnings
 
 import numpy as np
 from scipy.optimize import fmin_l_bfgs_b
+from scipy.special import logsumexp
 
 from ot.utils import unif, dist, list_to_array
 from .backend import get_backend
@@ -1684,7 +1686,7 @@ def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
 
 
 def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
-                       numIterMax=10000, stopThr=1e-9, verbose=False,
+                       numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False,
                        log=False, **kwargs):
     r'''
     Solve the entropic regularization optimal transport problem and return the
@@ -1723,6 +1725,12 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
         Max number of iterations
     stopThr : float, optional
         Stop threshold on error (>0)
+    isLazy: boolean, optional
+        If True, compute the cost matrix block by block and return only the dual potentials (to save memory).
+        If False, compute the full cost matrix and return the outputs of the sinkhorn function.
+    batchSize: int or tuple of 2 int, optional
+        Size of the batches used to compute the sinkhorn update without memory overhead.
+        When a tuple is provided, it sets the sizes of the left/right batches.
     verbose : bool, optional
         Print information along iterations
     log : bool, optional
@@ -1758,24 +1766,78 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
     .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
     '''
 
+    ns, nt = X_s.shape[0], X_t.shape[0]
     if a is None:
-        a = unif(np.shape(X_s)[0])
+        a = unif(ns)
     if b is None:
-        b = unif(np.shape(X_t)[0])
+        b = unif(nt)
 
+    if isLazy:
+        if log:
+            dict_log = {"err": []}
+
-    M = dist(X_s, X_t, metric=metric)
+        log_a, log_b = np.log(a), np.log(b)
+        f, g = np.zeros(ns), np.zeros(nt)
+
+        if isinstance(batchSize, int):
+            bs, bt = batchSize, batchSize
+        elif isinstance(batchSize, tuple) and len(batchSize) == 2:
+            bs, bt = batchSize[0], batchSize[1]
+        else:
+            raise ValueError("Batch size must be an integer or a tuple of two integers")
+
+        range_s, range_t = range(0, ns, bs), range(0, nt, bt)
+
+        lse_f = np.zeros(ns)
+        lse_g = np.zeros(nt)
+
+        for i_ot in range(numIterMax):
+
+            # update f: each row block needs only a (bs, nt) slice of the cost
+            for i in range_s:
+                M = dist(X_s[i:i + bs, :], X_t, metric=metric)
+                lse_f[i:i + bs] = logsumexp(g[None, :] - M / reg, axis=1)
+            f = log_a - lse_f
+
+            # update g: each column block needs only an (ns, bt) slice of the cost
+            for j in range_t:
+                M = dist(X_s, X_t[j:j + bt, :], metric=metric)
+                lse_g[j:j + bt] = logsumexp(f[:, None] - M / reg, axis=0)
+            g = log_b - lse_g
+
+            # every 10 iterations, measure the violation of the source marginal
+            if (i_ot + 1) % 10 == 0:
+                m1 = np.zeros_like(a)
+                for i in range_s:
+                    M = dist(X_s[i:i + bs, :], X_t, metric=metric)
+                    m1[i:i + bs] = np.exp(f[i:i + bs, None] + g[None, :] - M / reg).sum(1)
+                err = np.abs(m1 - a).sum()
+                if log:
+                    dict_log["err"].append(err)
+
+                if verbose and (i_ot + 1) % 100 == 0:
+                    print("Error in marginal at iteration {} = {}".format(i_ot + 1, err))
+
+                if err <= stopThr:
+                    break
+
+        if log:
+            dict_log["u"] = f
+            dict_log["v"] = g
+            return (f, g, dict_log)
+        else:
+            return (f, g)
 
-    if log:
-        pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
-        return pi, log
-    else:
-        pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs)
-        return pi
+    else:
+        M = dist(X_s, X_t, metric=metric)
+
+        if log:
+            pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=True, **kwargs)
+            return pi, log
+        else:
+            pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=False, **kwargs)
+            return pi


 def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9,
-                        verbose=False, log=False, **kwargs):
+                        isLazy=False, batchSize=100, verbose=False, log=False, **kwargs):
     r'''
     Solve the entropic regularization optimal transport problem from empirical
     data and return the OT loss
@@ -1814,6 +1876,12 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
         Max number of iterations
     stopThr : float, optional
         Stop threshold on error (>0)
+    isLazy: boolean, optional
+        If True, compute the cost matrix block by block and return only the dual potentials (to save memory).
+        If False, compute the full cost matrix and return the outputs of the sinkhorn function.
+    batchSize: int or tuple of 2 int, optional
+        Size of the batches used to compute the sinkhorn update without memory overhead.
+        When a tuple is provided, it sets the sizes of the left/right batches.
     verbose : bool, optional
         Print information along iterations
     log : bool, optional
@@ -1850,21 +1918,45 @@ def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', num
     .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
     '''
 
+    ns, nt = X_s.shape[0], X_t.shape[0]
     if a is None:
-        a = unif(np.shape(X_s)[0])
+        a = unif(ns)
     if b is None:
-        b = unif(np.shape(X_t)[0])
+        b = unif(nt)
 
-    M = dist(X_s, X_t, metric=metric)
+    if isLazy:
+        if log:
+            f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr,
+                                                isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log)
+        else:
+            f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric, numIterMax=numIterMax, stopThr=stopThr,
+                                      isLazy=isLazy, batchSize=batchSize, verbose=verbose, log=log)
+
+        bs = batchSize if isinstance(batchSize, int) else batchSize[0]
+        range_s = range(0, ns, bs)
+
+        # accumulate the loss <M, pi> block by block; the full plan is never stored
+        loss = 0
+        for i in range_s:
+            M_block = dist(X_s[i:i + bs, :], X_t, metric=metric)
+            pi_block = np.exp(f[i:i + bs, None] + g[None, :] - M_block / reg)
+            loss += np.sum(M_block * pi_block)
+
+        if log:
+            return loss, dict_log
+        else:
+            return loss
 
-    if log:
-        sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
-                                       **kwargs)
-        return sinkhorn_loss, log
-    else:
-        sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
-                                  **kwargs)
-        return sinkhorn_loss
+    else:
+        M = dist(X_s, X_t, metric=metric)
+
+        if log:
+            sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
+                                           **kwargs)
+            return sinkhorn_loss, log
+        else:
+            sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr, verbose=verbose, log=log,
+                                      **kwargs)
+            return sinkhorn_loss


 def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9,
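To see the whole lazy pipeline in one place, here is a minimal self-contained sketch of the same scheme. It is not the POT code: it assumes NumPy inputs, uniform weights, squared Euclidean cost, a single batch size for rows and columns, and a fixed iteration count with no convergence check; the function names sq_dist_block, lazy_sinkhorn_potentials, and lazy_ot_loss are illustrative.

import numpy as np
from scipy.special import logsumexp


def sq_dist_block(A, B):
    # squared Euclidean cost for one block, built on the fly
    return ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)


def lazy_sinkhorn_potentials(X_s, X_t, reg, batch=100, n_iter=1000):
    # batched log-domain Sinkhorn: only (batch x n) cost blocks in memory
    ns, nt = X_s.shape[0], X_t.shape[0]
    log_a, log_b = np.full(ns, -np.log(ns)), np.full(nt, -np.log(nt))
    f, g = np.zeros(ns), np.zeros(nt)
    for _ in range(n_iter):
        for i in range(0, ns, batch):  # row blocks update f
            M = sq_dist_block(X_s[i:i + batch], X_t)
            f[i:i + batch] = log_a[i:i + batch] - logsumexp(g[None, :] - M / reg, axis=1)
        for j in range(0, nt, batch):  # column blocks update g
            M = sq_dist_block(X_s, X_t[j:j + batch])
            g[j:j + batch] = log_b[j:j + batch] - logsumexp(f[:, None] - M / reg, axis=0)
    return f, g


def lazy_ot_loss(f, g, X_s, X_t, reg, batch=100):
    # accumulate <M, pi> block by block from the dual potentials
    loss = 0.0
    for i in range(0, X_s.shape[0], batch):
        M = sq_dist_block(X_s[i:i + batch], X_t)
        pi = np.exp(f[i:i + batch, None] + g[None, :] - M / reg)
        loss += np.sum(M * pi)
    return loss

With batch equal to ns this reduces to dense log-domain Sinkhorn; smaller batches trade compute (cost blocks are rebuilt on every sweep) for O(batch * n) memory.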
44 changes: 44 additions & 0 deletions test/test_bregman.py
@@ -2,6 +2,7 @@

 # Author: Remi Flamary <remi.flamary@unice.fr>
 #         Kilian Fatras <kilian.fatras@irisa.fr>
+#         Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
 #
 # License: MIT License

@@ -329,6 +330,49 @@ def test_empirical_sinkhorn():
     np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
 
 
+def test_lazy_empirical_sinkhorn():
+    # test sinkhorn
+    n = 100
+    a = ot.unif(n)
+    b = ot.unif(n)
+    numIterMax = 1000
+
+    X_s = np.reshape(np.arange(n), (n, 1))
+    X_t = np.reshape(np.arange(0, n), (n, 1))
+    M = ot.dist(X_s, X_t)
+    M_m = ot.dist(X_s, X_t, metric='minkowski')
+
+    f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=(1, 1), verbose=True)
+    G_sqe = np.exp(f[:, None] + g[None, :] - M / 1)
+    sinkhorn_sqe = ot.sinkhorn(a, b, M, 1)
+
+    f, g, log_es = ot.bregman.empirical_sinkhorn(X_s, X_t, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+    G_log = np.exp(f[:, None] + g[None, :] - M / 0.1)
+    sinkhorn_log, log_s = ot.sinkhorn(a, b, M, 0.1, log=True)
+
+    f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, 1, metric='minkowski', numIterMax=numIterMax, isLazy=True, batchSize=1)
+    G_m = np.exp(f[:, None] + g[None, :] - M_m / 1)
+    sinkhorn_m = ot.sinkhorn(a, b, M_m, 1)
+
+    loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2(X_s, X_t, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True)
+    loss_sinkhorn = ot.sinkhorn2(a, b, M, 1)
+
+    # check constraints
+    np.testing.assert_allclose(
+        sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05)  # metric sqeuclidean
+    np.testing.assert_allclose(
+        sinkhorn_sqe.sum(0), G_sqe.sum(0), atol=1e-05)  # metric sqeuclidean
+    np.testing.assert_allclose(
+        sinkhorn_log.sum(1), G_log.sum(1), atol=1e-05)  # log
+    np.testing.assert_allclose(
+        sinkhorn_log.sum(0), G_log.sum(0), atol=1e-05)  # log
+    np.testing.assert_allclose(
+        sinkhorn_m.sum(1), G_m.sum(1), atol=1e-05)  # metric euclidean
+    np.testing.assert_allclose(
+        sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05)  # metric euclidean
+    np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05)
+
+
 def test_empirical_sinkhorn_divergence():
     # Test sinkhorn divergence
     n = 10
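Calling the new options from user code mirrors the test above. A small sketch (data sizes, reg, and batchSize values are arbitrary, and it assumes a POT build that includes this commit):

import numpy as np
import ot

X_s = np.random.randn(1000, 2)
X_t = np.random.randn(1000, 2)

# dual potentials only; cost blocks are built on the fly
f, g = ot.bregman.empirical_sinkhorn(X_s, X_t, reg=1.0, isLazy=True, batchSize=128)

# entropic OT loss computed with the same lazy scheme
loss = ot.bregman.empirical_sinkhorn2(X_s, X_t, reg=1.0, isLazy=True, batchSize=128)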
