From 9f51c14e7cbc103494697b6fdd08cfeb847c51b6 Mon Sep 17 00:00:00 2001 From: AdrienCorenflos Date: Fri, 8 May 2020 20:24:25 +0100 Subject: [PATCH 01/14] example for log treatment in bregman.py --- ot/bregman.py | 122 ++++++++++++++++++++++++++++++-------------------- 1 file changed, 73 insertions(+), 49 deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index f1f843775..eefb75c08 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -14,10 +14,22 @@ # # License: MIT License -import numpy as np +import math import warnings -from .utils import unif, dist + +import numpy as np from scipy.optimize import fmin_l_bfgs_b +from scipy.special import logsumexp + +from .utils import unif, dist + + +def log_matvec(matrix, u, out): + max_matrix = np.max(matrix) + max_u = np.max(u) + np.dot(np.exp(matrix - max_matrix), np.exp(u - max_u), out=out) + np.log(out, out=out) + out += max_matrix + max_u def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, @@ -311,61 +323,68 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, ot.optim.cg : General regularized OT """ - a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) + M = np.asarray(M, dtype=np.float64) if len(a) == 0: - a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] + a = np.ones((M.shape[0], 1), dtype=np.float64) / M.shape[0] if len(b) == 0: - b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] - - # init data - dim_a = len(a) - dim_b = len(b) + b = np.ones((M.shape[1], 1), dtype=np.float64) / M.shape[1] if len(b.shape) > 1: n_hists = b.shape[1] else: n_hists = 0 + if len(a.shape) == 1: + a = a[:, None] + + if len(b.shape) == 1: + b = b[:, None] + + log_threshold = math.log(stopThr) + is_logweight = kwargs.get('is_logweight', False) + + if not is_logweight: + a = np.log(a) + b = np.log(b) + + # init data + dim_a = len(a) + dim_b = len(b) + if log: log = {'err': []} # we assume that no distances are null except those of the diagonal of # distances if n_hists: - u = np.ones((dim_a, 
n_hists)) / dim_a - v = np.ones((dim_b, n_hists)) / dim_b + u = np.zeros((dim_a, n_hists)) - math.log(dim_a) + v = np.zeros((dim_b, n_hists)) - math.log(dim_b) else: - u = np.ones(dim_a) / dim_a - v = np.ones(dim_b) / dim_b - - # print(reg) - - # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute - K = np.empty(M.shape, dtype=M.dtype) - np.divide(M, -reg, out=K) - np.exp(K, out=K) + u = np.zeros((dim_a, 1)) - math.log(dim_a) + v = np.zeros((dim_b, 1)) - math.log(dim_b) - # print(np.min(K)) - tmp2 = np.empty(b.shape, dtype=M.dtype) + log_K = -M / reg - Kp = (1 / a).reshape(-1, 1) * K + log_Kp = -a.reshape(-1, 1) + log_K + log_K_T = log_K.T cpt = 0 - err = 1 - while (err > stopThr and cpt < numItermax): + log_err = 0.5 * log_threshold + + while log_err > log_threshold and cpt < numItermax: uprev = u vprev = v - KtransposeU = np.dot(K.T, u) - v = np.divide(b, KtransposeU) - u = 1. / np.dot(Kp, v) + log_matvec(log_K_T, u, v) + v *= -1 + v += b + log_matvec(log_Kp, v, u) + u *= -1 - if (np.any(KtransposeU == 0) - or np.any(np.isnan(u)) or np.any(np.isnan(v)) - or np.any(np.isinf(u)) or np.any(np.isinf(v))): + if np.any(~np.isfinite(u)) or np.any(~np.isfinite(v)): # we have reached the machine precision # come back to previous solution and quit loop print('Warning: numerical errors at iteration', cpt) @@ -375,27 +394,32 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, if cpt % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations - if n_hists: - np.einsum('ik,ij,jk->jk', u, K, v, out=tmp2) - else: - # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 - np.einsum('i,ij,j->j', u, K, v, out=tmp2) - err = np.linalg.norm(tmp2 - b) # violation of marginal + temp2 = u + log_K + v.T + temp2 = logsumexp(temp2, axis=0, keepdims=True).T + # noinspection PyTypeChecker + log_err = 0.5 * np.sum(np.exp(2 * temp2) - np.exp(2 * b)) # violation of marginal + # would be more efficient with a check on stability of dual vectors 
if log: - log['err'].append(err) + log['err'].append(math.exp(log_err)) if verbose: if cpt % 200 == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, err)) + print('{:5d}|{:8e}|'.format(cpt, np.exp(log_err))) cpt = cpt + 1 if log: - log['u'] = u - log['v'] = v - + log['u'] = np.exp(u) if not is_logweight else u + log['v'] = np.exp(v) if not is_logweight else v + + gamma = u + log_K + v.T + res = logsumexp(gamma, axis=(0, 1), b=M) + if not is_logweight: + gamma = np.exp(gamma) + res = np.exp(res) + if log: + log['cost'] = res if n_hists: # return only loss - res = np.einsum('ik,ij,jk,ij->k', u, K, v, M) if log: return res, log else: @@ -404,9 +428,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, else: # return OT matrix if log: - return u.reshape((-1, 1)) * K * v.reshape((1, -1)), log + return gamma.squeeze(), log else: - return u.reshape((-1, 1)) * K * v.reshape((1, -1)) + return gamma.squeeze() def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, @@ -716,7 +740,7 @@ def get_Gamma(alpha, beta, u, v): if np.abs(u).max() > tau or np.abs(v).max() > tau: if n_hists: alpha, beta = alpha + reg * \ - np.max(np.log(u), 1), beta + reg * np.max(np.log(v)) + np.max(np.log(u), 1), beta + reg * np.max(np.log(v)) else: alpha, beta = alpha + reg * np.log(u), beta + reg * np.log(v) if n_hists: @@ -2182,11 +2206,11 @@ def projection(u, epsilon): # box constraints in L-BFGS-B (see Proposition 1 in [26]) bounds_u = [(max(a_I_min / ((nt - nt_budget) * epsilon + nt_budget * (b_J_max / ( - ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget + ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget bounds_v = [( - max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))), - epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget + max(b_J_min / ((ns - ns_budget) * epsilon + 
ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))), + epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget # pre-calculated constants for the objective vec_eps_IJc = epsilon * kappa * (K_IJc * np.ones(nt - nt_budget).reshape((1, -1))).sum(axis=1) From a07330c683ef6732d9b817d1a89f88f6ac226fe2 Mon Sep 17 00:00:00 2001 From: AdrienCorenflos Date: Tue, 14 Jul 2020 16:36:39 +0300 Subject: [PATCH 02/14] Improve doc --- ot/lp/__init__.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 514a6071d..2f0cc1393 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -443,18 +443,30 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None - we do not optimize over the weights - we do not do line search for the locations updates, we use i.e. theta = 1 in [1] (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of [2] proposed in the continuous setting. + Formally we solve the following problem: + + .. 
math:: + \min_X \sum_{i=1}^N w_i W_2^2(b, X, a_i, X_i) + + where : + + - :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one + - the :math:`a_i \in \mathbb{R}^{k_i}` are the empirical measures weights and sum to one for each :math:`i` + - the :math:`X_i \in \mathbb{R}^{k_i, d}` are the empirical measures atoms locations + - :math:`b \in \mathbb{R}^{k}` is the desired weights vector of the barycenter + Parameters ---------- - measures_locations : list of (k_i,d) numpy.ndarray + measures_locations : list of N (k_i,d) numpy.ndarray The discrete support of a measure supported on k_i locations of a d-dimensional space (k_i can be different for each element of the list) - measures_weights : list of (k_i,) numpy.ndarray + measures_weights : list of N (k_i,) numpy.ndarray Numpy arrays where each numpy array has k_i non-negatives values summing to one representing the weights of each discrete input measure X_init : (k,d) np.ndarray Initialization of the support locations (on k atoms) of the barycenter b : (k,) np.ndarray Initialization of the weights of the barycenter (non-negatives, sum to 1) - weights : (k,) np.ndarray + weights : (N,) np.ndarray Initialization of the coefficients of the barycenter (non-negatives, sum to 1) numItermax : int, optional From dfa2c9d3618625cbb97d115fda7ed28fdbfc7a1a Mon Sep 17 00:00:00 2001 From: AdrienCorenflos Date: Tue, 14 Jul 2020 16:37:19 +0300 Subject: [PATCH 03/14] Revert "example for log treatment in bregman.py" This reverts commit 9f51c14e --- ot/bregman.py | 122 ++++++++++++++++++++------------------------------ 1 file changed, 49 insertions(+), 73 deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index eefb75c08..f1f843775 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -14,22 +14,10 @@ # # License: MIT License -import math -import warnings - import numpy as np -from scipy.optimize import fmin_l_bfgs_b -from scipy.special import logsumexp - +import warnings from .utils import unif, dist - - -def 
log_matvec(matrix, u, out): - max_matrix = np.max(matrix) - max_u = np.max(u) - np.dot(np.exp(matrix - max_matrix), np.exp(u - max_u), out=out) - np.log(out, out=out) - out += max_matrix + max_u +from scipy.optimize import fmin_l_bfgs_b def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, @@ -323,68 +311,61 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, ot.optim.cg : General regularized OT """ + a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64) if len(a) == 0: - a = np.ones((M.shape[0], 1), dtype=np.float64) / M.shape[0] + a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] if len(b) == 0: - b = np.ones((M.shape[1], 1), dtype=np.float64) / M.shape[1] + b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] + + # init data + dim_a = len(a) + dim_b = len(b) if len(b.shape) > 1: n_hists = b.shape[1] else: n_hists = 0 - if len(a.shape) == 1: - a = a[:, None] - - if len(b.shape) == 1: - b = b[:, None] - - log_threshold = math.log(stopThr) - is_logweight = kwargs.get('is_logweight', False) - - if not is_logweight: - a = np.log(a) - b = np.log(b) - - # init data - dim_a = len(a) - dim_b = len(b) - if log: log = {'err': []} # we assume that no distances are null except those of the diagonal of # distances if n_hists: - u = np.zeros((dim_a, n_hists)) - math.log(dim_a) - v = np.zeros((dim_b, n_hists)) - math.log(dim_b) + u = np.ones((dim_a, n_hists)) / dim_a + v = np.ones((dim_b, n_hists)) / dim_b else: - u = np.zeros((dim_a, 1)) - math.log(dim_a) - v = np.zeros((dim_b, 1)) - math.log(dim_b) + u = np.ones(dim_a) / dim_a + v = np.ones(dim_b) / dim_b - log_K = -M / reg + # print(reg) - log_Kp = -a.reshape(-1, 1) + log_K - log_K_T = log_K.T - cpt = 0 - log_err = 0.5 * log_threshold + # Next 3 lines equivalent to K= np.exp(-M/reg), but faster to compute + K = np.empty(M.shape, dtype=M.dtype) + np.divide(M, -reg, out=K) + np.exp(K, out=K) - while log_err > log_threshold and cpt < numItermax: + # 
print(np.min(K)) + tmp2 = np.empty(b.shape, dtype=M.dtype) + + Kp = (1 / a).reshape(-1, 1) * K + cpt = 0 + err = 1 + while (err > stopThr and cpt < numItermax): uprev = u vprev = v - log_matvec(log_K_T, u, v) - v *= -1 - v += b - log_matvec(log_Kp, v, u) - u *= -1 + KtransposeU = np.dot(K.T, u) + v = np.divide(b, KtransposeU) + u = 1. / np.dot(Kp, v) - if np.any(~np.isfinite(u)) or np.any(~np.isfinite(v)): + if (np.any(KtransposeU == 0) + or np.any(np.isnan(u)) or np.any(np.isnan(v)) + or np.any(np.isinf(u)) or np.any(np.isinf(v))): # we have reached the machine precision # come back to previous solution and quit loop print('Warning: numerical errors at iteration', cpt) @@ -394,32 +375,27 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, if cpt % 10 == 0: # we can speed up the process by checking for the error only all # the 10th iterations - temp2 = u + log_K + v.T - temp2 = logsumexp(temp2, axis=0, keepdims=True).T - # noinspection PyTypeChecker - log_err = 0.5 * np.sum(np.exp(2 * temp2) - np.exp(2 * b)) # violation of marginal - # would be more efficient with a check on stability of dual vectors + if n_hists: + np.einsum('ik,ij,jk->jk', u, K, v, out=tmp2) + else: + # compute right marginal tmp2= (diag(u)Kdiag(v))^T1 + np.einsum('i,ij,j->j', u, K, v, out=tmp2) + err = np.linalg.norm(tmp2 - b) # violation of marginal if log: - log['err'].append(math.exp(log_err)) + log['err'].append(err) if verbose: if cpt % 200 == 0: print( '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(cpt, np.exp(log_err))) + print('{:5d}|{:8e}|'.format(cpt, err)) cpt = cpt + 1 if log: - log['u'] = np.exp(u) if not is_logweight else u - log['v'] = np.exp(v) if not is_logweight else v - - gamma = u + log_K + v.T - res = logsumexp(gamma, axis=(0, 1), b=M) - if not is_logweight: - gamma = np.exp(gamma) - res = np.exp(res) - if log: - log['cost'] = res + log['u'] = u + log['v'] = v + if n_hists: # return only loss + res = np.einsum('ik,ij,jk,ij->k', u, K, v, 
M) if log: return res, log else: @@ -428,9 +404,9 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, else: # return OT matrix if log: - return gamma.squeeze(), log + return u.reshape((-1, 1)) * K * v.reshape((1, -1)), log else: - return gamma.squeeze() + return u.reshape((-1, 1)) * K * v.reshape((1, -1)) def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False, @@ -740,7 +716,7 @@ def get_Gamma(alpha, beta, u, v): if np.abs(u).max() > tau or np.abs(v).max() > tau: if n_hists: alpha, beta = alpha + reg * \ - np.max(np.log(u), 1), beta + reg * np.max(np.log(v)) + np.max(np.log(u), 1), beta + reg * np.max(np.log(v)) else: alpha, beta = alpha + reg * np.log(u), beta + reg * np.log(v) if n_hists: @@ -2206,11 +2182,11 @@ def projection(u, epsilon): # box constraints in L-BFGS-B (see Proposition 1 in [26]) bounds_u = [(max(a_I_min / ((nt - nt_budget) * epsilon + nt_budget * (b_J_max / ( - ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget + ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget bounds_v = [( - max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))), - epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget + max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))), + epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget # pre-calculated constants for the objective vec_eps_IJc = epsilon * kappa * (K_IJc * np.ones(nt - nt_budget).reshape((1, -1))).sum(axis=1) From 36377ccb1d011fce081aea64a90f8f83710c9be5 Mon Sep 17 00:00:00 2001 From: AdrienCorenflos Date: Mon, 20 Jul 2020 13:40:55 +0300 Subject: [PATCH 04/14] Add comments by Flamary --- ot/lp/__init__.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 2f0cc1393..233573033 100644 --- a/ot/lp/__init__.py +++ 
b/ot/lp/__init__.py @@ -272,7 +272,7 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True): if np.any(~asel) or np.any(~bsel): u, v = estimate_dual_null_weights(u, v, a, b, M) - + result_code_string = check_result(result_code) if log: log = {} @@ -389,7 +389,7 @@ def emd2(a, b, M, processes=multiprocessing.cpu_count(), if log or return_matrix: def f(b): bsel = b != 0 - + G, cost, u, v, result_code = emd_c(a, b, M, numItermax) if center_dual: @@ -438,12 +438,7 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None """ Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance) - The function solves the Wasserstein barycenter problem when the barycenter measure is constrained to be supported on k atoms. - This problem is considered in [1] (Algorithm 2). There are two differences with the following codes: - - we do not optimize over the weights - - we do not do line search for the locations updates, we use i.e. theta = 1 in [1] (Algorithm 2). This can be seen as a discrete implementation of the fixed-point algorithm of [2] proposed in the continuous setting. - - Formally we solve the following problem: + The function solves the Wasserstein barycenter problem when the barycenter measure is constrained to be supported on k atoms, formally: .. math:: \min_X \sum_{i=1}^N w_i W_2^2(b, X, a_i, X_i) @@ -455,6 +450,11 @@ def free_support_barycenter(measures_locations, measures_weights, X_init, b=None - the :math:`X_i \in \mathbb{R}^{k_i, d}` are the empirical measures atoms locations - :math:`b \in \mathbb{R}^{k}` is the desired weights vector of the barycenter + This problem is considered in [1] (Algorithm 2). There are two differences with the following codes: + + - we do not optimize over the weights + - we do not do line search for the locations updates, we use i.e. theta = 1 in [1] (Algorithm 2). 
This can be seen as a discrete implementation of the fixed-point algorithm of [2] proposed in the continuous setting. + Parameters ---------- measures_locations : list of N (k_i,d) numpy.ndarray From 110f382bd5dd60303e95e2e516880f77297a7ef5 Mon Sep 17 00:00:00 2001 From: AdrienCorenflos Date: Mon, 20 Jul 2020 13:41:49 +0300 Subject: [PATCH 05/14] Delete repetitive description --- ot/lp/__init__.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 233573033..36cf838af 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -436,9 +436,7 @@ def f(b): def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None): """ - Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance) - - The function solves the Wasserstein barycenter problem when the barycenter measure is constrained to be supported on k atoms, formally: + Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance), formally: .. 
math:: \min_X \sum_{i=1}^N w_i W_2^2(b, X, a_i, X_i) From cbf6bf5fe80aac2b6eac09385421554b1cf0ab36 Mon Sep 17 00:00:00 2001 From: AdrienCorenflos Date: Mon, 20 Jul 2020 13:44:54 +0300 Subject: [PATCH 06/14] Added raw string to avoid pbs with backslashes --- ot/lp/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 36cf838af..2a1b082f5 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -435,7 +435,7 @@ def f(b): def free_support_barycenter(measures_locations, measures_weights, X_init, b=None, weights=None, numItermax=100, stopThr=1e-7, verbose=False, log=None): - """ + r""" Solves the free support (locations of the barycenters are optimized, not the weights) Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Wasserstein distance), formally: .. math:: From 22e7f6b1a0313c78b7049fbd0b63e532b5450e1d Mon Sep 17 00:00:00 2001 From: AdrienCorenflos Date: Mon, 20 Jul 2020 14:40:02 +0300 Subject: [PATCH 07/14] Implements sliced wasserstein --- ot/__init__.py | 1 + ot/sliced.py | 99 +++++++++++++++++++++++++++++++++++++++++++++ test/test_sliced.py | 59 +++++++++++++++++++++++++++ 3 files changed, 159 insertions(+) create mode 100644 ot/sliced.py create mode 100644 test/test_sliced.py diff --git a/ot/__init__.py b/ot/__init__.py index 0e6e2e25b..79c87cbaa 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -39,6 +39,7 @@ from .bregman import sinkhorn, sinkhorn2, barycenter from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2 from .da import sinkhorn_lpl1_mm +from .sliced import sliced # utils functions from .utils import dist, unif, tic, toc, toq diff --git a/ot/sliced.py b/ot/sliced.py new file mode 100644 index 000000000..ee0fdd9ac --- /dev/null +++ b/ot/sliced.py @@ -0,0 +1,99 @@ +""" +Sliced Wasserstein Distance. 
+ +""" + + +# Author: Adrien Corenflos +# +# License: MIT License + + +import numpy as np + + +def _random_projections(n_projections, dimension, random_state): + """Samples n_projections times dimension normal distributions""" + projections = random_state.normal(0., 1., [n_projections, dimension]) + norm = np.linalg.norm(projections, ord=2, axis=1, keepdims=True) + projections = projections / norm + return projections + + +def sliced(X_s, X_t, a=None, b=None, n_projections=50, seed=None): + r""" + Computes a Monte-Carlo approximation of the 2-Sliced Wasserstein distance + .. math:: + \mathcal{SWD}_2(\mu, \nu) = \underset{\theta \sim \mathcal{U}(\mathbb{S}^{d-1})}{\mathbb{E}}[\mathcal{W}_2^2(\theta_\# \mu, \theta_\# \nu)]^{\frac{1}{2}} + + where : + + - :math:`\theta_\# \mu` stands for the pushforward of the projection :math:`\mathbb{R}^d \ni X \mapsto \langle \theta, X \rangle` + + + Parameters + ---------- + X_s : ndarray, shape (n_samples_a, dim) + samples in the source domain + X_t : ndarray, shape (n_samples_b, dim) + samples in the target domain + a : ndarray, shape (n_samples_a,) + samples weights in the source domain + b : ndarray, shape (n_samples_b,) + samples weights in the target domain + n_projections : int + Number of projections used for the Monte-Carlo approximation + seed: int or RandomState or None + Seed used for numpy random number generator + + Returns + ------- + cost: float + Sliced Wasserstein Cost + log : dict, optional + log dictionary return only if log==True in parameters + + Examples + -------- + + >>> n_samples_a = 20 + >>> reg = 0.1 + >>> X = np.random.normal(0., 1., (n_samples_a, 5)) + >>> sliced(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE + 0. + + References + ---------- + + .. [1] S.
Kolouri et al., Generalized Sliced Wasserstein Distances, Advances in Neural Information Processing Systems (NIPS) 33, 2019 + """ + from .lp import emd2_1d + + X_s = np.asanyarray(X_s) + X_t = np.asanyarray(X_t) + + n = X_s.shape[0] + m = X_t.shape[0] + + if X_s.shape[1] != X_t.shape[1]: + raise ValueError( + f"X_s and X_t must have the same number of dimensions {X_s.shape[1]} and {X_t.shape[1]} respectively given") + + if a is None: + a = np.full(n, 1 / n) + if b is None: + b = np.full(m, 1 / m) + + d = X_s.shape[1] + + if not isinstance(seed, np.random.RandomState): + random_state = np.random.RandomState(seed) + projections = _random_projections(n_projections, d, random_state) + else: + projections = _random_projections(n_projections, d, seed) + + res = 0. + for projection in projections: + X_s_proj = X_s @ projection + X_t_proj = X_t @ projection + res += emd2_1d(X_s_proj, X_t_proj, a, b, log=False, dense=False) + return (res / n_projections) ** 0.5 diff --git a/test/test_sliced.py b/test/test_sliced.py new file mode 100644 index 000000000..a3272ba11 --- /dev/null +++ b/test/test_sliced.py @@ -0,0 +1,59 @@ +"""Tests for module sliced""" + +# Author: Adrien Corenflos +# +# License: MIT License + +import numpy as np +import ot +import pytest + + +def test_sliced_same_dist(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + res = ot.sliced(x, x, u, u, 10, seed=rng) + np.testing.assert_almost_equal(res, 0.) + + +def test_sliced_bad_shapes(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + y = rng.randn(n, 4) + u = ot.utils.unif(n) + + with pytest.raises(ValueError): + _ = ot.sliced(x, y, u, u, 10, seed=rng) + + +def test_sliced_different_dists(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + y = rng.randn(n, 2) + + res = ot.sliced(x, y, u, u, 10, seed=rng) + assert res > 0. 
+ + +def test_1d_sliced_equals_emd(): + n = 100 + m = 120 + rng = np.random.RandomState(0) + + x = rng.randn(n, 1) + a = rng.uniform(0, 1, n) + a /= a.sum() + y = rng.randn(m, 1) + u = ot.utils.unif(m) + res = ot.sliced(x, y, a, u, 10, seed=42) + expected = ot.emd2_1d(x.squeeze(), y.squeeze(), a, u) + np.testing.assert_almost_equal(res ** 2, expected) From ba04ed644c6a07559dc16e151d11a0555a4e0f98 Mon Sep 17 00:00:00 2001 From: AdrienCorenflos Date: Mon, 20 Jul 2020 15:03:46 +0300 Subject: [PATCH 08/14] Changed formatting of string for py3.5 support --- ot/sliced.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ot/sliced.py b/ot/sliced.py index ee0fdd9ac..b589cadd8 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -76,7 +76,8 @@ def sliced(X_s, X_t, a=None, b=None, n_projections=50, seed=None): if X_s.shape[1] != X_t.shape[1]: raise ValueError( - f"X_s and X_t must have the same number of dimensions {X_s.shape[1]} and {X_t.shape[1]} respectively given") + "X_s and X_t must have the same number of dimensions {} and {} respectively given".format(X_s.shape[1], + X_t.shape[1])) if a is None: a = np.full(n, 1 / n) From ca8364c514bfa162e2106f8a41b1e1ac00546017 Mon Sep 17 00:00:00 2001 From: AdrienCorenflos Date: Mon, 20 Jul 2020 15:08:13 +0300 Subject: [PATCH 09/14] Docstest, expected 0.0 and not 0. --- ot/sliced.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/sliced.py b/ot/sliced.py index b589cadd8..8857ee5fd 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -59,7 +59,7 @@ def sliced(X_s, X_t, a=None, b=None, n_projections=50, seed=None): >>> reg = 0.1 >>> X = np.random.normal(0., 1., (n_samples_a, 5)) >>> sliced(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE - 0. 
+ 0.0 References ---------- From 2d893f29709dc21f947dac98fed91679aa5eaccd Mon Sep 17 00:00:00 2001 From: AdrienCorenflos Date: Tue, 4 Aug 2020 09:47:16 +0300 Subject: [PATCH 10/14] Adressed comments by @rflamary --- README.md | 6 ++ examples/sliced-wasserstein/README.txt | 4 + examples/sliced-wasserstein/plot_variance.py | 86 ++++++++++++++++++++ ot/__init__.py | 4 +- ot/sliced.py | 84 ++++++++++++++----- test/test_sliced.py | 38 +++++++-- 6 files changed, 194 insertions(+), 28 deletions(-) create mode 100644 examples/sliced-wasserstein/README.txt create mode 100644 examples/sliced-wasserstein/plot_variance.py diff --git a/README.md b/README.md index e3598f11c..150c10632 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,7 @@ POT provides the following generic OT solvers (links to examples): * [Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3] formulations). +* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32]. POT provides the following Machine Learning related solvers: @@ -180,6 +181,7 @@ The contributors to this library are * [Romain Tavenard](https://rtavenar.github.io/) (1d Wasserstein) * [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn) * [Ievgen Redko](https://ievred.github.io/) (Laplacian DA, JCPOT) +* [Adrien Corenflos](https://adriencorenflos.github.io/) (Sliced Wasserstein Distance) This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages): @@ -263,3 +265,7 @@ You can also post bug reports and feature requests in Github issues. 
Make sure t [29] Chapel, L., Alaya, M., Gasso, G. (2019). [Partial Gromov-Wasserstein with Applications on Positive-Unlabeled Learning](https://arxiv.org/abs/2002.08276), arXiv preprint arXiv:2002.08276. [30] Flamary R., Courty N., Tuia D., Rakotomamonjy A. (2014). [Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching](https://remi.flamary.com/biblio/flamary2014optlaplace.pdf), NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014. + +[31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 + +[32] S. Kolouri et al., Generalized Sliced Wasserstein Distances, Advances in Neural Information Processing Systems (NIPS) 33, 2019 \ No newline at end of file diff --git a/examples/sliced-wasserstein/README.txt b/examples/sliced-wasserstein/README.txt new file mode 100644 index 000000000..a57534591 --- /dev/null +++ b/examples/sliced-wasserstein/README.txt @@ -0,0 +1,4 @@ + + +Sliced Wasserstein Distance +--------------------------- \ No newline at end of file diff --git a/examples/sliced-wasserstein/plot_variance.py b/examples/sliced-wasserstein/plot_variance.py new file mode 100644 index 000000000..703e72e56 --- /dev/null +++ b/examples/sliced-wasserstein/plot_variance.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- +""" +============================== +2D Sliced Wasserstein Distance +============================== + +This example illustrates the computation of the sliced Wasserstein Distance as proposed in [31]. + +[31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." 
Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 + +""" + +# Author: Adrien Corenflos +# +# License: MIT License + +import matplotlib.pylab as pl +import numpy as np +# necessary for 3d plot even if not used +from mpl_toolkits.mplot3d import Axes3D # noqa + +import ot + +############################################################################## +# Generate data +# ------------- + +# %% parameters and data generation + +n = 500 # nb samples + +mu_s = np.array([0, 0]) +cov_s = np.array([[1, 0], [0, 1]]) + +mu_t = np.array([4, 4]) +cov_t = np.array([[1, -.8], [-.8, 1]]) + +xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s) +xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t) + +a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples + +############################################################################## +# Plot data +# --------- + +# %% plot samples + +pl.figure(1) +pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples') +pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples') +pl.legend(loc=0) +pl.title('Source and target distributions') + +################################################################################### +# Compute Sliced Wasserstein distance for different seeds and number of projections +# ----------- + +n_seed = 50 +n_projections_arr = np.logspace(0, 3, 25, dtype=int) +res = np.empty((n_seed, 25)) + +# %% Compute statistics +for seed in range(n_seed): + for i, n_projections in enumerate(n_projections_arr): + res[seed, i] = ot.sliced_wasserstein_distance(xs, xt, a, b, n_projections, seed) + +res_mean = np.mean(res, axis=0) +res_std = np.std(res, axis=0) + +################################################################################### +# Plot Sliced Wasserstein Distance +# ----------- + +pl.figure(2) +pl.plot(n_projections_arr, res_mean, label="SWD") +pl.fill_between(n_projections_arr, res_mean - 2 * res_std, res_mean + 2 * res_std, alpha=0.5) + +pl.legend() +pl.xscale('log') + 
+pl.xlabel("Number of projections") +pl.ylabel("Distance") +pl.title('Sliced Wasserstein Distance with 95% confidence inverval') + +pl.show() diff --git a/ot/__init__.py b/ot/__init__.py index 79c87cbaa..ec3ede257 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -39,7 +39,7 @@ from .bregman import sinkhorn, sinkhorn2, barycenter from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2 from .da import sinkhorn_lpl1_mm -from .sliced import sliced +from .sliced import sliced_wasserstein_distance # utils functions from .utils import dist, unif, tic, toc, toq @@ -51,4 +51,4 @@ 'emd_1d', 'emd2_1d', 'wasserstein_1d', 'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim', 'sinkhorn_unbalanced', 'barycenter_unbalanced', - 'sinkhorn_unbalanced2'] + 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance'] diff --git a/ot/sliced.py b/ot/sliced.py index 8857ee5fd..66728f50e 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -3,7 +3,7 @@ """ -# Author: Adrien Corenflos +# Author: Adrien Corenflos # # License: MIT License @@ -11,15 +11,46 @@ import numpy as np -def _random_projections(n_projections, dimension, random_state): - """Samples n_projections times dimension normal distributions""" - projections = random_state.normal(0., 1., [n_projections, dimension]) +def get_random_projections(n_projections, d, seed=None): + r""" + Generates n_projections samples from the uniform on the unit sphere of dimension d-1: :math:`\mathcal{U}(\mathcal{S}^{d-1})` + + Parameters + ---------- + n_projections : int + number of samples requested + d : int + dimension of the space + seed: int or RandomState, optional + Seed used for numpy random number generator + + Returns + ------- + out: ndarray, shape (n_projections, d) + The uniform unit vectors on the sphere + + Examples + -------- + >>> n_projections = 100 + >>> d = 5 + >>> projs = get_random_projections(n_projections, d) + >>> np.allclose(np.sum(np.square(projs), 1), 1.) 
# doctest: +NORMALIZE_WHITESPACE + True + + """ + + if not isinstance(seed, np.random.RandomState): + random_state = np.random.RandomState(seed) + else: + random_state = seed + + projections = random_state.normal(0., 1., [n_projections, d]) norm = np.linalg.norm(projections, ord=2, axis=1, keepdims=True) projections = projections / norm return projections -def sliced(X_s, X_t, a=None, b=None, n_projections=50, seed=None): +def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed=None, log=False): r""" Computes a Monte-Carlo approximation of the 2-Sliced Wasserstein distance .. math:: @@ -36,14 +67,16 @@ def sliced(X_s, X_t, a=None, b=None, n_projections=50, seed=None): samples in the source domain X_t : ndarray, shape (n_samples_b, dim) samples in the target domain - a : ndarray, shape (n_samples_a,) + a : ndarray, shape (n_samples_a,), optional samples weights in the source domain - b : ndarray, shape (n_samples_b,) + b : ndarray, shape (n_samples_b,), optional samples weights in the target domain - n_projections : int + n_projections : int, optional Number of projections used for the Monte-Carlo approximation - seed: int or RandomState or None + seed: int or RandomState or None, optional Seed used for numpy random number generator + log: bool, optional + if True, sliced_wasserstein_distance returns the projections used and their associated EMD. Returns ------- @@ -58,13 +91,14 @@ def sliced(X_s, X_t, a=None, b=None, n_projections=50, seed=None): >>> n_samples_a = 20 >>> reg = 0.1 >>> X = np.random.normal(0., 1., (n_samples_a, 5)) - >>> sliced(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE + >>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE 0.0 References ---------- - .. [1] S. Kolouri et al., Generalized Sliced Wasserstein Distances, Advances in Neural Information Processing Systems (NIPS) 33, 2019 + .. [31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." 
Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 + .. [32] S. Kolouri et al., Generalized Sliced Wasserstein Distances, Advances in Neural Information Processing Systems (NIPS) 33, 2019 """ from .lp import emd2_1d @@ -86,15 +120,25 @@ def sliced(X_s, X_t, a=None, b=None, n_projections=50, seed=None): d = X_s.shape[1] - if not isinstance(seed, np.random.RandomState): - random_state = np.random.RandomState(seed) - projections = _random_projections(n_projections, d, random_state) + projections = get_random_projections(n_projections, d, seed) + + X_s_projections = np.dot(projections, X_s.T) + X_t_projections = np.dot(projections, X_t.T) + + if log: + projected_emd = [] else: - projections = _random_projections(n_projections, d, seed) + projected_emd = None res = 0. - for projection in projections: - X_s_proj = X_s @ projection - X_t_proj = X_t @ projection - res += emd2_1d(X_s_proj, X_t_proj, a, b, log=False, dense=False) - return (res / n_projections) ** 0.5 + + for X_s_proj, X_t_proj in zip(X_s_projections, X_t_projections): + emd = emd2_1d(X_s_proj, X_t_proj, a, b, log=False, dense=False) + if projected_emd is not None: + projected_emd.append(emd) + res += emd + + res = (res / n_projections) ** 0.5 + if log: + return res, {"projections": projections.tolist(), "projected_emds": projected_emd} + return res diff --git a/test/test_sliced.py b/test/test_sliced.py index a3272ba11..a07d97500 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -1,13 +1,21 @@ """Tests for module sliced""" -# Author: Adrien Corenflos +# Author: Adrien Corenflos # # License: MIT License import numpy as np -import ot import pytest +import ot +from ot.sliced import get_random_projections + + +def test_get_random_projections(): + rng = np.random.RandomState(0) + projections = get_random_projections(1000, 50, rng) + np.testing.assert_almost_equal(np.sum(projections ** 2, 1), 1.) 
+ def test_sliced_same_dist(): n = 100 @@ -16,7 +24,7 @@ def test_sliced_same_dist(): x = rng.randn(n, 2) u = ot.utils.unif(n) - res = ot.sliced(x, x, u, u, 10, seed=rng) + res = ot.sliced_wasserstein_distance(x, x, u, u, 10, seed=rng) np.testing.assert_almost_equal(res, 0.) @@ -29,7 +37,25 @@ def test_sliced_bad_shapes(): u = ot.utils.unif(n) with pytest.raises(ValueError): - _ = ot.sliced(x, y, u, u, 10, seed=rng) + _ = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng) + + +def test_sliced_log(): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 4) + y = rng.randn(n, 4) + u = ot.utils.unif(n) + + res, log = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng, log=True) + assert len(log) == 2 + projections = log["projections"] + projected_emds = log["projected_emds"] + + assert len(projections) == len(projected_emds) == 10 + for emd in projected_emds: + assert emd > 0 def test_sliced_different_dists(): @@ -40,7 +66,7 @@ def test_sliced_different_dists(): u = ot.utils.unif(n) y = rng.randn(n, 2) - res = ot.sliced(x, y, u, u, 10, seed=rng) + res = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng) assert res > 0. 
@@ -54,6 +80,6 @@ def test_1d_sliced_equals_emd(): a /= a.sum() y = rng.randn(m, 1) u = ot.utils.unif(m) - res = ot.sliced(x, y, a, u, 10, seed=42) + res = ot.sliced_wasserstein_distance(x, y, a, u, 10, seed=42) expected = ot.emd2_1d(x.squeeze(), y.squeeze(), a, u) np.testing.assert_almost_equal(res ** 2, expected) From 7d9b9207da796922042689ae9c44be80a872f484 Mon Sep 17 00:00:00 2001 From: AdrienCorenflos Date: Tue, 4 Aug 2020 09:57:04 +0300 Subject: [PATCH 11/14] No 3d plot here --- examples/sliced-wasserstein/plot_variance.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/sliced-wasserstein/plot_variance.py b/examples/sliced-wasserstein/plot_variance.py index 703e72e56..f3deeff78 100644 --- a/examples/sliced-wasserstein/plot_variance.py +++ b/examples/sliced-wasserstein/plot_variance.py @@ -16,8 +16,6 @@ import matplotlib.pylab as pl import numpy as np -# necessary for 3d plot even if not used -from mpl_toolkits.mplot3d import Axes3D # noqa import ot From b68e2c2a783e1488e8acc989ff98f76ed7241e7a Mon Sep 17 00:00:00 2001 From: AdrienCorenflos Date: Tue, 4 Aug 2020 10:29:35 +0300 Subject: [PATCH 12/14] add sliced to the docs --- docs/source/all.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/all.rst b/docs/source/all.rst index d7b878f37..f1f707505 100644 --- a/docs/source/all.rst +++ b/docs/source/all.rst @@ -27,6 +27,7 @@ API and modules stochastic unbalanced partial + sliced .. autosummary:: :toctree: ../modules/generated/ From 9a8edb56461b8f89bd45fe20814ca9605d8d7720 Mon Sep 17 00:00:00 2001 From: AdrienCorenflos Date: Mon, 31 Aug 2020 18:03:27 +0300 Subject: [PATCH 13/14] Incorporate comments by @rflamary --- README.md | 2 -- ot/sliced.py | 10 +++++----- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 150c10632..d49588768 100644 --- a/README.md +++ b/README.md @@ -267,5 +267,3 @@ You can also post bug reports and feature requests in Github issues. 
Make sure t [30] Flamary R., Courty N., Tuia D., Rakotomamonjy A. (2014). [Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching](https://remi.flamary.com/biblio/flamary2014optlaplace.pdf), NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014. [31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 - -[32] S. Kolouri et al., Generalized Sliced Wasserstein Distances, Advances in Neural Information Processing Systems (NIPS) 33, 2019 \ No newline at end of file diff --git a/ot/sliced.py b/ot/sliced.py index 66728f50e..479257608 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -53,6 +53,7 @@ def get_random_projections(n_projections, d, seed=None): def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed=None, log=False): r""" Computes a Monte-Carlo approximation of the 2-Sliced Wasserstein distance + .. math:: \mathcal{SWD}_2(\mu, \nu) = \underset{\theta \sim \mathcal{U}(\mathbb{S}^{d-1})}{\mathbb{E}}[\mathcal{W}_2^2(\theta_\# \mu, \theta_\# \nu)]^{\frac{1}{2}} @@ -98,7 +99,6 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed ---------- .. [31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 - .. [32] S. Kolouri et al., Generalized Sliced Wasserstein Distances, Advances in Neural Information Processing Systems (NIPS) 33, 2019 """ from .lp import emd2_1d @@ -126,19 +126,19 @@ def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed X_t_projections = np.dot(projections, X_t.T) if log: - projected_emd = [] + projected_emd = np.empty(n_projections) else: projected_emd = None res = 0. 
- for X_s_proj, X_t_proj in zip(X_s_projections, X_t_projections): + for i, (X_s_proj, X_t_proj) in enumerate(zip(X_s_projections, X_t_projections)): emd = emd2_1d(X_s_proj, X_t_proj, a, b, log=False, dense=False) if projected_emd is not None: - projected_emd.append(emd) + projected_emd[i] = emd res += emd res = (res / n_projections) ** 0.5 if log: - return res, {"projections": projections.tolist(), "projected_emds": projected_emd} + return res, {"projections": projections, "projected_emds": projected_emd} return res From 5590a794836aeffb9ca5a2df1cd03c3e0360fc38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Fri, 4 Sep 2020 14:09:45 +0200 Subject: [PATCH 14/14] add link to pdf --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d49588768..6fe528a2e 100644 --- a/README.md +++ b/README.md @@ -266,4 +266,4 @@ You can also post bug reports and feature requests in Github issues. Make sure t [30] Flamary R., Courty N., Tuia D., Rakotomamonjy A. (2014). [Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching](https://remi.flamary.com/biblio/flamary2014optlaplace.pdf), NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014. -[31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45 +[31] Bonneel, Nicolas, et al. [Sliced and radon wasserstein barycenters of measures](https://perso.liris.cnrs.fr/nicolas.bonneel/WassersteinSliced-JMIV.pdf), Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45