From 3844639d3dd4e0dd360ebef34dd657d26664039e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Mon, 27 Jan 2020 09:35:19 +0100 Subject: [PATCH 1/5] add test for constraint viuolation of duals --- test/test_ot.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/test_ot.py b/test/test_ot.py index 18b629444..c756e5177 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -337,7 +337,10 @@ def test_dual_variables(): # Check that both cost computations are equivalent np.testing.assert_almost_equal(cost1, log['cost']) check_duality_gap(a, b, M, G, log['u'], log['v'], log['cost']) - + + viol=log['u'][:,None]+log['v'][None,:]-M + + assert viol.max()<1e-8 def check_duality_gap(a, b, M, G, u, v, cost): cost_dual = np.vdot(a, u) + np.vdot(b, v) From 30fc233f7f62d571a562971a945d68c3782f0780 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Mon, 27 Jan 2020 09:37:42 +0100 Subject: [PATCH 2/5] correct pep8 --- test/test_ot.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/test/test_ot.py b/test/test_ot.py index c756e5177..245a10712 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -337,10 +337,11 @@ def test_dual_variables(): # Check that both cost computations are equivalent np.testing.assert_almost_equal(cost1, log['cost']) check_duality_gap(a, b, M, G, log['u'], log['v'], log['cost']) - - viol=log['u'][:,None]+log['v'][None,:]-M - - assert viol.max()<1e-8 + + viol = log['u'][:, None] + log['v'][None, :] - M + + assert viol.max() < 1e-8 + def check_duality_gap(a, b, M, G, u, v, cost): cost_dual = np.vdot(a, u) + np.vdot(b, v) From e5196fa7a8c493b831fd5dac52a89bbf29e7b0e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Mon, 27 Jan 2020 10:40:05 +0100 Subject: [PATCH 3/5] correct bug in emd emd2 still todo --- ot/lp/__init__.py | 194 ++++++++++++++++++++++++++++++++++++++++----- ot/lp/emd_wrap.pyx | 2 + 2 files changed, 174 insertions(+), 22 deletions(-) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index eabdd3a4b..a771ce4cf 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -23,11 +23,150 @@ from .cvx import barycenter from ..utils import dist -__all__=['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', - 'emd_1d', 'emd2_1d', 'wasserstein_1d'] +__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', + 'emd_1d', 'emd2_1d', 'wasserstein_1d'] -def emd(a, b, M, numItermax=100000, log=False, dense=True): +def center_ot_dual(alpha0, beta0, a=None, b=None): + r"""Center dual OT potentials wrt theirs weights + + The main idea of this function is to find unique dual potentials + that ensure some kind of centering/fairness. It will help have + stability when multiple calling of the OT solver with small changes. + + Basically we add another constraint to the potential that will not + change the objective value but will ensure unicity. The constraint + is the following: + + .. math:: + \alpha^T a= \beta^T b + + in addition to the OT problem constraints. + + since :math:`\sum_i a_i=\sum_j b_j` this can be solved by adding/removing + a constant from both :math:`\alpha_0` and :math:`\beta_0`. + + .. math:: + c=\frac{\beta0^T b-\alpha_0^T a}{1^Tb+1^Ta} + + \alpha=\alpha_0+c + + \beta=\beta0+c + + Parameters + ---------- + alpha0 : (ns,) numpy.ndarray, float64 + Source dual potential + beta0 : (nt,) numpy.ndarray, float64 + Target dual potential + a : (ns,) numpy.ndarray, float64 + Source histogram (uniform weight if empty list) + b : (nt,) numpy.ndarray, float64 + Target histogram (uniform weight if empty list) + + Returns + ------- + alpha : (ns,) numpy.ndarray, float64 + Source centered dual potential + beta : (nt,) numpy.ndarray, float64 + Target centered dual potential + + """ + # if no weights are provided, use uniform + if a is None: + a = np.ones(alpha0.shape[0]) / alpha0.shape[0] + if b is None: + b = np.ones(beta0.shape[0]) / beta0.shape[0] + + # compute constant that balances the weighted sums of the duals + c = (b.dot(beta0) - a.dot(alpha0)) / (a.sum() + b.sum()) + + # update duals + alpha = alpha0 + c + beta = beta0 - c + + return alpha, beta + + +def estimate_dual_null_weights(alpha0, beta0, a, b, M): + r"""Estimate feasible values for 0-weighted dual potentials + + The feasible values are computed efficiently bjt rather coarsely. + First we compute the constraints violations: + + .. math:: + V=\alpha+\beta^T-M + + Next we compute the max amount of violation per row (alpha) and + columns (beta) + + .. math:: + v^a_i=\max_j V_{i,j} + + v^b_j=\max_i V_{i,j} + + Finally we update the dual potential with 0 weights if a + constraint is violated + + .. math:: + \alpha_i = \alpha_i -v^a_i \quad \text{ if } a_i=0 \text{ and } v^a_i>0 + + \beta_j = \beta_j -v^b_j \quad \text{ if } b_j=0 \text{ and } v^b_j>0 + + In the end the dual potential are centred using function + :ref:`center_ot_dual`. + + Note that all those updates do not change the objective value of the + solution but provide dual potential that do not violate the constraints. + + Parameters + ---------- + alpha0 : (ns,) numpy.ndarray, float64 + Source dual potential + beta0 : (nt,) numpy.ndarray, float64 + Target dual potential + alpha0 : (ns,) numpy.ndarray, float64 + Source dual potential + beta0 : (nt,) numpy.ndarray, float64 + Target dual potential + a : (ns,) numpy.ndarray, float64 + Source histogram (uniform weight if empty list) + b : (nt,) numpy.ndarray, float64 + Target histogram (uniform weight if empty list) + M : (ns,nt) numpy.ndarray, float64 + Loss matrix (c-order array with type float64) + + Returns + ------- + alpha : (ns,) numpy.ndarray, float64 + Source corrected dual potential + beta : (nt,) numpy.ndarray, float64 + Target corrected dual potential + + """ + + # binary indexing of non-zeros weights + asel = a != 0 + bsel = b != 0 + + # compute dual constraints violation + Viol = alpha0[:, None] + beta0[None, :] - M + + # Compute worst violation per line and columns + aviol = np.max(Viol, 1) + bviol = np.max(Viol, 0) + + # update corrects violation of + alpha_up = -1 * ~asel * np.maximum(aviol, 0) + beta_up = -1 * ~bsel * np.maximum(bviol, 0) + + alpha = alpha0 + alpha_up + beta = beta0 + beta_up + + return center_ot_dual(alpha, beta, a, b) + + +def emd(a, b, M, numItermax=100000, log=False, dense=True, center_dual=True): r"""Solves the Earth Movers distance problem and returns the OT matrix @@ -43,7 +182,7 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True): - a and b are the sample weights .. warning:: - Note that the M matrix needs to be a C-order numpy.array in float64 + Note that the M matrix needs to be a C-order numpy.array in float64 format. Uses the algorithm proposed in [1]_ @@ -66,6 +205,9 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True): If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt). Otherwise returns a sparse representation using scipy's `coo_matrix` format. + center_dual: boolean, optional (default=True) + If True, centers the dual potential using function + :ref:`center_ot_dual`. Returns ------- @@ -107,7 +249,6 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True): b = np.asarray(b, dtype=np.float64) M = np.asarray(M, dtype=np.float64) - # if empty array given then use uniform distributions if len(a) == 0: a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] @@ -117,11 +258,21 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True): assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \ "Dimension mismatch, check dimensions of M with a and b" + asel = a != 0 + bsel = b != 0 + if dense: - G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense) + G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense) + + if np.any(~asel) or np.any(~bsel): + u, v = estimate_dual_null_weights(u, v, a, b, M) + else: - Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense) - G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0])) + Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense) + G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0])) + + if np.any(~asel) or np.any(~bsel): + u, v = estimate_dual_null_weights(u, v, a, b, M) result_code_string = check_result(result_code) if log: @@ -151,7 +302,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), - a and b are the sample weights .. warning:: - Note that the M matrix needs to be a C-order numpy.array in float64 + Note that the M matrix needs to be a C-order numpy.array in float64 format. Uses the algorithm proposed in [1]_ @@ -177,7 +328,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), dense: boolean, optional (default=True) If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt). Otherwise returns a sparse representation using scipy's `coo_matrix` - format. + format. Returns ------- @@ -221,7 +372,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), # problem with pikling Forks if sys.platform.endswith('win32'): - processes=1 + processes = 1 # if empty array given then use uniform distributions if len(a) == 0: @@ -235,10 +386,10 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), if log or return_matrix: def f(b): if dense: - G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense) + G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense) else: - Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense) - G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0])) + Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense) + G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0])) result_code_string = check_result(result_code) log = {} @@ -252,10 +403,10 @@ def f(b): else: def f(b): if dense: - G, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense) + G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense) else: - Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax,dense) - G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0])) + Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense) + G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0])) result_code_string = check_result(result_code) check_result(result_code) @@ -265,7 +416,7 @@ def f(b): return f(b) nb = b.shape[1] - if processes>1: + if processes > 1: res = parmap(f, [b[:, i] for i in range(nb)], processes) else: res = list(map(f, [b[:, i].copy() for i in range(nb)])) @@ -273,7 +424,6 @@ def f(b): return res - def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None): """ Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance) @@ -326,7 +476,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None k = X_init.shape[0] d = X_init.shape[1] if b is None: - b = np.ones((k,))/k + b = np.ones((k,)) / k if weights is None: weights = np.ones((N,)) / N @@ -337,7 +487,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None displacement_square_norm = stopThr + 1. - while ( displacement_square_norm > stopThr and iter_count < numItermax ): + while (displacement_square_norm > stopThr and iter_count < numItermax): T_sum = np.zeros((k, d)) @@ -347,7 +497,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None T_i = emd(b, measure_weights_i, M_i) T_sum = T_sum + weight_i * np.reshape(1. / b, (-1, 1)) * np.matmul(T_i, measure_locations_i) - displacement_square_norm = np.sum(np.square(T_sum-X)) + displacement_square_norm = np.sum(np.square(T_sum - X)) if log: displacement_square_norms.append(displacement_square_norm) diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index c0d712843..a4987f41f 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -40,6 +40,8 @@ def check_result(result_code): return message + + @cython.boundscheck(False) @cython.wraparound(False) def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mode="c"] b, np.ndarray[double, ndim=2, mode="c"] M, int max_iter, bint dense): From 9a9b3547837eac56349ce8df92bb5b0565daa2d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Mon, 27 Jan 2020 10:59:58 +0100 Subject: [PATCH 4/5] correct emd2 and add centering for dual potentials --- ot/lp/__init__.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index a771ce4cf..aa3166f10 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -264,6 +264,9 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True, center_dual=True): if dense: G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense) + if center_dual: + u, v = center_ot_dual(u, v, a, b) + if np.any(~asel) or np.any(~bsel): u, v = estimate_dual_null_weights(u, v, a, b, M) @@ -271,6 +274,9 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True, center_dual=True): Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense) G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0])) + if center_dual: + u, v = center_ot_dual(u, v, a, b) + if np.any(~asel) or np.any(~bsel): u, v = estimate_dual_null_weights(u, v, a, b, M) @@ -287,7 +293,8 @@ def emd(a, b, M, numItermax=100000, log=False, dense=True, center_dual=True): def emd2(a, b, M, processes=multiprocessing.cpu_count(), - numItermax=100000, log=False, dense=True, return_matrix=False): + numItermax=100000, log=False, dense=True, return_matrix=False, + center_dual=True): r"""Solves the Earth Movers distance problem and returns the loss .. math:: @@ -329,6 +336,9 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), If True, returns math:`\gamma` as a dense ndarray of shape (ns, nt). Otherwise returns a sparse representation using scipy's `coo_matrix` format. + center_dual: boolean, optional (default=True) + If True, centers the dual potential using function + :ref:`center_ot_dual`. Returns ------- @@ -383,14 +393,23 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \ "Dimension mismatch, check dimensions of M with a and b" + asel = a != 0 + if log or return_matrix: def f(b): + bsel = b != 0 if dense: G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense) else: Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense) G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0])) + if center_dual: + u, v = center_ot_dual(u, v, a, b) + + if np.any(~asel) or np.any(~bsel): + u, v = estimate_dual_null_weights(u, v, a, b, M) + result_code_string = check_result(result_code) log = {} if return_matrix: @@ -402,12 +421,19 @@ def f(b): return [cost, log] else: def f(b): + bsel = b != 0 if dense: G, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense) else: Gv, iG, jG, cost, u, v, result_code = emd_c(a, b, M, numItermax, dense) G = coo_matrix((Gv, (iG, jG)), shape=(a.shape[0], b.shape[0])) + if center_dual: + u, v = center_ot_dual(u, v, a, b) + + if np.any(~asel) or np.any(~bsel): + u, v = estimate_dual_null_weights(u, v, a, b, M) + result_code_string = check_result(result_code) check_result(result_code) return cost From f65073faa73b36280a19ff8b9c383e66f8bdbd2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Thu, 30 Jan 2020 08:04:36 +0100 Subject: [PATCH 5/5] comlete documentation --- ot/lp/__init__.py | 30 +++++++++++++++++++----------- ot/lp/emd_wrap.pyx | 6 ++++++ test/test_ot.py | 4 ++-- 3 files changed, 27 insertions(+), 13 deletions(-) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index aa3166f10..cdd505dd6 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -28,10 +28,10 @@ def center_ot_dual(alpha0, beta0, a=None, b=None): - r"""Center dual OT potentials wrt theirs weights + r"""Center dual OT potentials w.r.t. theirs weights The main idea of this function is to find unique dual potentials - that ensure some kind of centering/fairness. It will help have + that ensure some kind of centering/fairness. The main idea is to find dual potentials that lead to the same final objective value for both source and targets (see below for more details). It will help having stability when multiple calling of the OT solver with small changes. Basically we add another constraint to the potential that will not @@ -91,7 +91,15 @@ def center_ot_dual(alpha0, beta0, a=None, b=None): def estimate_dual_null_weights(alpha0, beta0, a, b, M): r"""Estimate feasible values for 0-weighted dual potentials - The feasible values are computed efficiently bjt rather coarsely. + The feasible values are computed efficiently but rather coarsely. + + .. warning:: + This function is necessary because the C++ solver in emd_c + discards all samples in the distributions with + zeros weights. This means that while the primal variable (transport + matrix) is exact, the solver only returns feasible dual potentials + on the samples with weights different from zero. + First we compute the constraints violations: .. math:: @@ -113,11 +121,11 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M): \beta_j = \beta_j -v^b_j \quad \text{ if } b_j=0 \text{ and } v^b_j>0 - In the end the dual potential are centred using function + In the end the dual potentials are centered using function :ref:`center_ot_dual`. Note that all those updates do not change the objective value of the - solution but provide dual potential that do not violate the constraints. + solution but provide dual potentials that do not violate the constraints. Parameters ---------- @@ -130,9 +138,9 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M): beta0 : (nt,) numpy.ndarray, float64 Target dual potential a : (ns,) numpy.ndarray, float64 - Source histogram (uniform weight if empty list) + Source distribution (uniform weights if empty list) b : (nt,) numpy.ndarray, float64 - Target histogram (uniform weight if empty list) + Target distribution (uniform weights if empty list) M : (ns,nt) numpy.ndarray, float64 Loss matrix (c-order array with type float64) @@ -150,11 +158,11 @@ def estimate_dual_null_weights(alpha0, beta0, a, b, M): bsel = b != 0 # compute dual constraints violation - Viol = alpha0[:, None] + beta0[None, :] - M + constraint_violation = alpha0[:, None] + beta0[None, :] - M - # Compute worst violation per line and columns - aviol = np.max(Viol, 1) - bviol = np.max(Viol, 0) + # Compute largest violation per line and columns + aviol = np.max(constraint_violation, 1) + bviol = np.max(constraint_violation, 0) # update corrects violation of alpha_up = -1 * ~asel * np.maximum(aviol, 0) diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index a4987f41f..d345fd427 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -66,6 +66,12 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod .. warning:: Note that the M matrix needs to be a C-order :py.cls:`numpy.array` + .. warning:: + The C++ solver discards all samples in the distributions with + zeros weights. This means that while the primal variable (transport + matrix) is exact, the solver only returns feasible dual potentials + on the samples with weights different from zero. + Parameters ---------- a : (ns,) numpy.ndarray, float64 diff --git a/test/test_ot.py b/test/test_ot.py index 245a10712..47df94634 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -338,9 +338,9 @@ def test_dual_variables(): np.testing.assert_almost_equal(cost1, log['cost']) check_duality_gap(a, b, M, G, log['u'], log['v'], log['cost']) - viol = log['u'][:, None] + log['v'][None, :] - M + constraint_violation = log['u'][:, None] + log['v'][None, :] - M - assert viol.max() < 1e-8 + assert constraint_violation.max() < 1e-8 def check_duality_gap(a, b, M, G, u, v, cost):