230 changes: 207 additions & 23 deletions ot/lp/__init__.py
@@ -23,11 +23,158 @@
from .cvx import barycenter
from ..utils import dist

__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
'emd_1d', 'emd2_1d', 'wasserstein_1d']
__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx',
'emd_1d', 'emd2_1d', 'wasserstein_1d']


def emd(a, b, M, numItermax=100000, log=False, dense=True):
def center_ot_dual(alpha0, beta0, a=None, b=None):
r"""Center dual OT potentials w.r.t. their weights

The main idea of this function is to find unique dual potentials
that ensure some kind of centering/fairness. It does so by finding dual potentials that lead to the same final objective value for both source and target (see below for more details). This helps
stabilize the solution when the OT solver is called several times with small changes.

Basically we add another constraint to the potentials that does not
change the objective value but ensures uniqueness. The constraint
is the following:

.. math::
\alpha^T a= \beta^T b

in addition to the OT problem constraints.

Since :math:`\sum_i a_i = \sum_j b_j`, this can be done by adding a constant
:math:`c` to :math:`\alpha_0` and subtracting it from :math:`\beta_0`:

.. math::
c = \frac{\beta_0^T b - \alpha_0^T a}{1^T a + 1^T b}

\alpha = \alpha_0 + c

\beta = \beta_0 - c

Parameters
----------
alpha0 : (ns,) numpy.ndarray, float64
Source dual potential
beta0 : (nt,) numpy.ndarray, float64
Target dual potential
a : (ns,) numpy.ndarray, float64
Source histogram (uniform weights if None)
b : (nt,) numpy.ndarray, float64
Target histogram (uniform weights if None)

Returns
-------
alpha : (ns,) numpy.ndarray, float64
Source centered dual potential
beta : (nt,) numpy.ndarray, float64
Target centered dual potential

"""
# if no weights are provided, use uniform
if a is None:
a = np.ones(alpha0.shape[0]) / alpha0.shape[0]
if b is None:
b = np.ones(beta0.shape[0]) / beta0.shape[0]

# compute constant that balances the weighted sums of the duals
c = (b.dot(beta0) - a.dot(alpha0)) / (a.sum() + b.sum())

# update duals
alpha = alpha0 + c
beta = beta0 - c

return alpha, beta
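As a quick illustration (not part of the diff), here is a hedged sketch of the two properties claimed above, assuming `center_ot_dual` is importable from `ot.lp` once this change is in: the pairwise sums alpha_i + beta_j are unchanged, and the weighted sums of the two potentials become equal. The potentials and weights below are illustrative values only.

import numpy as np
from ot.lp import center_ot_dual  # assumes this PR's version of ot.lp

# toy potentials and weights (illustrative values only)
alpha0 = np.array([0.1, 0.5, -0.2])
beta0 = np.array([0.3, -0.1])
a = np.array([0.2, 0.3, 0.5])
b = np.array([0.6, 0.4])

alpha, beta = center_ot_dual(alpha0, beta0, a, b)

# the pairwise sums alpha_i + beta_j are unchanged, so feasibility is preserved
assert np.allclose(alpha[:, None] + beta[None, :], alpha0[:, None] + beta0[None, :])
# and the new centering constraint alpha^T a = beta^T b now holds
assert np.isclose(alpha.dot(a), beta.dot(b))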


def estimate_dual_null_weights(alpha0, beta0, a, b, M):
r"""Estimate feasible values for 0-weighted dual potentials

The feasible values are computed efficiently but rather coarsely.

.. warning::
This function is necessary because the C++ solver in emd_c
discards all samples in the distributions with
zero weights. This means that while the primal variable (transport
matrix) is exact, the solver only returns feasible dual potentials
on the samples with nonzero weights.

First we compute the constraint violations:

.. math::
V_{i,j} = \alpha_i + \beta_j - M_{i,j}

Next we compute the maximum amount of violation per row (alpha) and
per column (beta)

.. math::
v^a_i=\max_j V_{i,j}

v^b_j=\max_i V_{i,j}

Finally we update the dual potentials of the zero-weight samples whenever a
constraint is violated

.. math::
\alpha_i = \alpha_i -v^a_i \quad \text{ if } a_i=0 \text{ and } v^a_i>0

\beta_j = \beta_j -v^b_j \quad \text{ if } b_j=0 \text{ and } v^b_j>0

In the end the dual potentials are centered using function
:ref:`center_ot_dual`.

Note that all those updates do not change the objective value of the
solution but provide dual potentials that do not violate the constraints.

Parameters
----------
alpha0 : (ns,) numpy.ndarray, float64
Source dual potential
beta0 : (nt,) numpy.ndarray, float64
Target dual potential
a : (ns,) numpy.ndarray, float64
Source distribution
b : (nt,) numpy.ndarray, float64
Target distribution
M : (ns,nt) numpy.ndarray, float64
Loss matrix (c-order array with type float64)

Returns
-------
alpha : (ns,) numpy.ndarray, float64
Source corrected dual potential
beta : (nt,) numpy.ndarray, float64
Target corrected dual potential

"""

# binary indexing of non-zero weights
asel = a != 0
bsel = b != 0

# compute dual constraint violations
constraint_violation = alpha0[:, None] + beta0[None, :] - M

# compute the largest violation per row and per column
aviol = np.max(constraint_violation, 1)
bviol = np.max(constraint_violation, 0)

# correct the duals of zero-weight samples so that no constraint is violated
alpha_up = -1 * ~asel * np.maximum(aviol, 0)
beta_up = -1 * ~bsel * np.maximum(bviol, 0)

alpha = alpha0 + alpha_up
beta = beta0 + beta_up

return center_ot_dual(alpha, beta, a, b)
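For intuition (again outside the diff), a small sketch that mirrors the violation computation in `estimate_dual_null_weights` (without the final centering) on a toy problem with one zero-weight source sample; all values are illustrative and the dual on the dropped sample is deliberately infeasible.

import numpy as np

# toy data: source sample 0 has zero weight
a = np.array([0.0, 0.5, 0.5])
b = np.array([0.5, 0.5])
M = np.array([[0.0, 1.0],
              [1.0, 0.0],
              [0.5, 0.5]])
# duals as a solver might return them: arbitrary (infeasible) value on the dropped sample
alpha0 = np.array([10.0, 0.0, 0.0])
beta0 = np.array([0.0, 0.0])

# same violation computation as in estimate_dual_null_weights
V = alpha0[:, None] + beta0[None, :] - M
aviol = V.max(axis=1)
bviol = V.max(axis=0)
alpha = alpha0 - (a == 0) * np.maximum(aviol, 0)
beta = beta0 - (b == 0) * np.maximum(bviol, 0)

# the corrected duals satisfy alpha_i + beta_j <= M_ij everywhere
assert (alpha[:, None] + beta[None, :] - M).max() <= 1e-12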


def emd(a, b, M, numItermax=100000, log=False, dense=True, center_dual=True):
r"""Solves the Earth Mover's distance problem and returns the OT matrix


@@ -43,7 +190,7 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
- a and b are the sample weights

.. warning::
Note that the M matrix needs to be a C-order numpy.array in float64
Note that the M matrix needs to be a C-order numpy.array in float64
format.

Uses the algorithm proposed in [1]_
@@ -66,6 +213,9 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
Otherwise returns a sparse representation using scipy's `coo_matrix`
format.
center_dual: boolean, optional (default=True)
If True, centers the dual potentials using the function
:ref:`center_ot_dual`.

Returns
-------
@@ -107,7 +257,6 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
b = np.asarray(b, dtype=np.float64)
M = np.asarray(M, dtype=np.float64)


# if empty array given then use uniform distributions
if len(a) == 0:
a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
@@ -117,11 +266,27 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):
assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
"Dimension mismatch, check dimensions of M with a and b"

asel = a != 0
bsel = b != 0

if dense:
G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)

if center_dual:
u, v = center_ot_dual(u, v, a, b)

if np.any(~asel) or np.any(~bsel):
u, v = estimate_dual_null_weights(u, v, a, b, M)

else:
Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))

if center_dual:
u, v = center_ot_dual(u, v, a, b)

if np.any(~asel) or np.any(~bsel):
u, v = estimate_dual_null_weights(u, v, a, b, M)

result_code_string = check_result(result_code)
if log:
@@ -136,7 +301,8 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True):


def emd2(a, b, M, processes=multiprocessing.cpu_count(),
numItermax=100000, log=False, dense=True, return_matrix=False):
numItermax=100000, log=False, dense=True, return_matrix=False,
center_dual=True):
r"""Solves the Earth Mover's distance problem and returns the loss

.. math::
@@ -151,7 +317,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
- a and b are the sample weights

.. warning::
Note that the M matrix needs to be a C-order numpy.array in float64
Note that the M matrix needs to be a C-order numpy.array in float64
format.

Uses the algorithm proposed in [1]_
@@ -177,7 +343,10 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
dense: boolean, optional (default=True)
If True, returns :math:`\gamma` as a dense ndarray of shape (ns, nt).
Otherwise returns a sparse representation using scipy's `coo_matrix`
format.
format.
center_dual: boolean, optional (default=True)
If True, centers the dual potentials using the function
:ref:`center_ot_dual`.

Returns
-------
@@ -221,7 +390,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),

# problem with pickling Forks
if sys.platform.endswith('win32'):
processes=1
processes = 1

# if empty array given then use uniform distributions
if len(a) == 0:
@@ -232,13 +401,22 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(),
assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
"Dimension mismatch, check dimensions of M with a and b"

asel = a != 0

if log or return_matrix:
def f(b):
bsel = b != 0
if dense:
G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
else:
Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))

if center_dual:
u, v = center_ot_dual(u, v, a, b)

if np.any(~asel) or np.any(~bsel):
u, v = estimate_dual_null_weights(u, v, a, b, M)

result_code_string = check_result(result_code)
log = {}
@@ -251,11 +429,18 @@ def f(b):
return [cost, log]
else:
def f(b):
bsel = b != 0
if dense:
G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
else:
Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense)
G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))
Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense)
G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0]))

if center_dual:
u, v = center_ot_dual(u, v, a, b)

if np.any(~asel) or np.any(~bsel):
u, v = estimate_dual_null_weights(u, v, a, b, M)

result_code_string = check_result(result_code)
check_result(result_code)
@@ -265,15 +450,14 @@ def f(b):
return f(b)
nb = b.shape[1]

if processes>1:
if processes > 1:
res = parmap(f, [b[:, i] for i in range(nb)], processes)
else:
res = list(map(f, [b[:, i].copy() for i in range(nb)]))

return res



def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None):
"""
Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance)
@@ -326,7 +510,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
k = X_init.shape[0]
d = X_init.shape[1]
if b is None:
b = np.ones((k,))/k
b = np.ones((k,)) / k
if weights is None:
weights = np.ones((N,)) / N

@@ -337,7 +521,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None

displacement_square_norm = stopThr + 1.

while ( displacement_square_norm > stopThr and iter_count < numItermax ):
while (displacement_square_norm > stopThr and iter_count < numItermax):

T_sum = np.zeros((k, d))

@@ -347,7 +531,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None
T_i = emd(b, measure_weights_i, M_i)
T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i)

displacement_square_norm = np.sum(np.square(T_sum-X))
displacement_square_norm = np.sum(np.square(T_sum - X))
if log:
displacement_square_norms.append(displacement_square_norm)

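End-to-end, the new `center_dual` defaults in `emd`/`emd2` can be exercised as in this hedged sketch (toy data, assuming a POT build that includes this change). It checks at user level the same dual-feasibility property as the new assertion added to test_ot.py below.

import numpy as np
import ot

# toy histograms with one zero-weight target bin (illustrative values)
a = np.array([0.5, 0.5])
b = np.array([0.5, 0.0, 0.5])
M = ot.dist(np.array([[0.0], [1.0]]), np.array([[0.0], [0.5], [1.0]]))

G, log = ot.emd(a, b, M, log=True)   # duals are centered and corrected by default
cost = ot.emd2(a, b, M)              # same problem, loss only

# the duals should now be feasible everywhere, including on the zero-weight bin
violation = log['u'][:, None] + log['v'][None, :] - M
assert violation.max() < 1e-8
assert np.isclose(cost, np.sum(G * M))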
8 changes: 8 additions & 0 deletions ot/lp/emd_wrap.pyx
@@ -40,6 +40,8 @@ def check_result(result_code):
return message




@cython.boundscheck(False)
@cython.wraparound(False)
def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, bint dense):
@@ -64,6 +66,12 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
.. warning::
Note that the M matrix needs to be a C-order :py.cls:`numpy.array`

.. warning::
The C++ solver discards all samples in the distributions with
zero weights. This means that while the primal variable (transport
matrix) is exact, the solver only returns feasible dual potentials
on the samples with nonzero weights.

Parameters
----------
a : (ns,) numpy.ndarray, float64
4 changes: 4 additions & 0 deletions test/test_ot.py
@@ -338,6 +338,10 @@ def test_dual_variables():
np.testing.assert_almost_equal(cost1, log['cost'])
check_duality_gap(a, b, M, G, log['u'], log['v'], log['cost'])

constraint_violation = log['u'][:, None] + log['v'][None, :] - M

assert constraint_violation.max() < 1e-8
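
A possible companion test (hypothetical, not part of this PR) that exercises the zero-weight path directly might look like the sketch below; it assumes the module's existing `np` and `ot` imports and the helpers `ot.unif` and `ot.dist`.

def test_dual_variables_zero_weights():
    # hypothetical test sketch: a zero-mass target bin should still get feasible duals
    rng = np.random.RandomState(0)
    xs = rng.randn(5, 2)
    xt = rng.randn(6, 2)
    a = ot.unif(5)
    b = ot.unif(6)
    b[0] = 0.
    b /= b.sum()
    M = ot.dist(xs, xt)

    G, log = ot.emd(a, b, M, log=True)

    # transport plan carries all the mass, and the duals are feasible everywhere
    assert np.isclose(G.sum(), 1.)
    constraint_violation = log['u'][:, None] + log['v'][None, :] - M
    assert constraint_violation.max() < 1e-8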


def check_duality_gap(a, b, M, G, u, v, cost):
cost_dual = np.vdot(a, u) + np.vdot(b, v)