From f49f6b4f34ddd3a2313e1df00c487bd7f47df845 Mon Sep 17 00:00:00 2001 From: laudavid Date: Tue, 24 Oct 2023 14:54:42 +0200 Subject: [PATCH 01/23] new file for lr sinkhorn --- ot/lowrank.py | 171 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 171 insertions(+) create mode 100644 ot/lowrank.py diff --git a/ot/lowrank.py b/ot/lowrank.py new file mode 100644 index 000000000..ba46cd1ed --- /dev/null +++ b/ot/lowrank.py @@ -0,0 +1,171 @@ +################################################################################################################# +############################################## WORK IN PROGRESS ################################################# +################################################################################################################# + + +from ot.utils import unif, list_to_array +from ot.backend import get_backend +from ot.datasets import make_1D_gauss as gauss + + + +################################## LR-DYSKTRA ALGORITHM ########################################## + +def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_w): + """ + Implementation of the Dykstra algorithm for low rank Sinkhorn + """ + + # get dykstra parameters + q3_1, q3_2, v1_, v2_, q1, q2 = dykstra_w + + # POT backend + eps1, eps2, eps3, p1, p2 = list_to_array(eps1, eps2, eps3, p1, p2) + q3_1, q3_2, v1_, v2_, q1, q2 = list_to_array(q3_1, q3_2, v1_, v2_, q1, q2) + + nx = get_backend(eps1, eps2, eps3, p1, p2, q3_1, q3_2, v1_, v2_, q1, q2) + + # ------- Dykstra algorithm ------ + g_ = eps3 + + u1 = p1 / nx.dot(eps1, v1_) + u2 = p2 / nx.dot(eps2, v2_) + + g = nx.maximum(alpha, g_ * q3_1) + q3_1 = (g_ * q3_1) / g + g_ = g + + prod1 = ((v1_ * q1) * nx.dot(eps1.T, u1)) + prod2 = ((v2_ * q2) * nx.dot(eps2.T, u2)) + g = (g_ * q3_2 * prod1 * prod2)**(1/3) + + v1 = g / nx.dot(eps1.T,u1) + v2 = g / nx.dot(eps2.T,u2) + + q1 = (v1_ * q1) / v1 + q2 = (v2_ * q2) / v2 + q3_2 = (g_ * q3_2) / g + + v1_, v2_ = v1, v2 + g_ = g + + # Compute error + err1 = nx.sum(nx.abs(u1 * (eps1 @ v1) - p1)) + err2 = nx.sum(nx.abs(u2 * (eps2 @ v2) - p2)) + err = err1 + err2 + + # Compute low rank matrices Q, R + Q = u1[:,None] * eps1 * v1[None,:] + R = u2[:,None] * eps2 * v2[None,:] + + dykstra_w = [q3_1, q3_2, v1_, v2_, q1, q2] + + return Q, R, g, err, dykstra_w + + + +#################################### LOW RANK SINKHORN ALGORITHM ######################################### + + +def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', alpha=1e-10, numIterMax=10000, stopThr=1e-20): + r''' + Solve the entropic regularization optimal transport problem under low-nonnegative low rank constraints + + Parameters + ---------- + X_s : array-like, shape (n_samples_a, dim) + samples in the source domain + X_t : array-like, shape (n_samples_b, dim) + samples in the target domain + reg : float + Regularization term >0 + a : array-like, shape (n_samples_a,) + samples weights in the source domain + b : array-like, shape (n_samples_b,) + samples weights in the target domain + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + + Returns + ------- + Q : array-like, shape (n_samples_a, r) + First low-rank matrix decomposition of the OT plan + R: array-like, shape (n_samples_b, r) + Second low-rank matrix decomposition of the OT plan + g : array-like, shape (r, ) + ... + + ''' + + X_s, X_t = list_to_array(X_s, X_t) + nx = get_backend(X_s, X_t) + + ns, nt = X_s.shape[0], X_t.shape[0] + if a is None: + a = nx.from_numpy(unif(ns), type_as=X_s) + if b is None: + b = nx.from_numpy(unif(nt), type_as=X_s) + + M = ot.dist(X_s,X_t, metric=metric) + + # Compute rank + r = min(ns, nt, r) + + # Compute gamma + L = nx.sqrt((2/(alpha**4))*nx.norm(M)**2 + (reg + (2/(alpha**3))*nx.norm(M))**2) + gamma = 1/(2*L) + + # Initialisation + Q, R, g = nx.ones((ns,r)), nx.ones((nt,r)), nx.ones(r) + q3_1, q3_2 = nx.ones(r), nx.ones(r) + v1_, v2_ = nx.ones(r), nx.ones(r) + q1, q2 = nx.ones(r), nx.ones(r) + dykstra_w = [q3_1, q3_2, v1_, v2_, q1, q2] + n_iter = 0 + err = 1 + + while n_iter < numIterMax: + if err > stopThr: + n_iter = n_iter + 1 + + CR = nx.dot(M,R) + C_t_Q = nx.dot(M.T,Q) + diag_g = (1/g)[:,None] + + eps1 = nx.exp(-gamma*(nx.dot(CR,diag_g)) - ((gamma*reg)-1)*nx.log(Q)) + eps2 = nx.exp(-gamma*(nx.dot(C_t_Q,diag_g)) - ((gamma*reg)-1)*nx.log(R)) + omega = nx.diag(nx.dot(Q.T, CR)) + eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) + + Q, R, g, err, dykstra_w = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_w) + else: + break + + return Q, R, g + + + + + +############################################################################ +## Test with X_s, X_t from ot.datasets +############################################################################# + +import numpy as np +import ot + +Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) +Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) + + +Q, R, g = lowrank_sinkhorn(Xs,Xt,reg=0.1) +M = ot.dist(Xs,Xt) +P = np.dot(Q,np.dot(np.diag(1/g),R.T)) + +print(np.sum(P)) + + + + From 3c4b50fdb660f27cc080618edb664d17086d93a9 Mon Sep 17 00:00:00 2001 From: laudavid Date: Tue, 24 Oct 2023 16:47:21 +0200 Subject: [PATCH 02/23] lr sinkhorn, solve_sample, OTResultLazy --- ot/lowrank.py | 40 +++++++------ ot/solvers.py | 161 ++++++++++++++++++++++++++++++++++++++++++++++++++ ot/utils.py | 90 ++++++++++++++++++++++++++++ 3 files changed, 272 insertions(+), 19 deletions(-) diff --git a/ot/lowrank.py b/ot/lowrank.py index ba46cd1ed..a1c73bdf3 100644 --- a/ot/lowrank.py +++ b/ot/lowrank.py @@ -2,8 +2,10 @@ ############################################## WORK IN PROGRESS ################################################# ################################################################################################################# +## Implementation of the LR-Dykstra algorithm and low rank sinkhorn algorithms -from ot.utils import unif, list_to_array + +from ot.utils import unif, list_to_array, dist from ot.backend import get_backend from ot.datasets import make_1D_gauss as gauss @@ -11,13 +13,13 @@ ################################## LR-DYSKTRA ALGORITHM ########################################## -def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_w): +def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p): """ Implementation of the Dykstra algorithm for low rank Sinkhorn """ # get dykstra parameters - q3_1, q3_2, v1_, v2_, q1, q2 = dykstra_w + q3_1, q3_2, v1_, v2_, q1, q2 = dykstra_p # POT backend eps1, eps2, eps3, p1, p2 = list_to_array(eps1, eps2, eps3, p1, p2) @@ -58,18 +60,18 @@ def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_w): Q = u1[:,None] * eps1 * v1[None,:] R = u2[:,None] * eps2 * v2[None,:] - dykstra_w = [q3_1, q3_2, v1_, v2_, q1, q2] + dykstra_p = [q3_1, q3_2, v1_, v2_, q1, q2] - return Q, R, g, err, dykstra_w + return Q, R, g, err, dykstra_p #################################### LOW RANK SINKHORN ALGORITHM ######################################### -def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', alpha=1e-10, numIterMax=10000, stopThr=1e-20): +def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', alpha=1e-10, numIterMax=10000, stopThr=1e-20): r''' - Solve the entropic regularization optimal transport problem under low-nonnegative low rank constraints + Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints on the feasible couplings. Parameters ---------- @@ -95,7 +97,7 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', R: array-like, shape (n_samples_b, r) Second low-rank matrix decomposition of the OT plan g : array-like, shape (r, ) - ... + Third low-rank matrix decomposition of the OT plan ''' @@ -108,7 +110,7 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', if b is None: b = nx.from_numpy(unif(nt), type_as=X_s) - M = ot.dist(X_s,X_t, metric=metric) + M = dist(X_s,X_t, metric=metric) # Compute rank r = min(ns, nt, r) @@ -122,7 +124,7 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', q3_1, q3_2 = nx.ones(r), nx.ones(r) v1_, v2_ = nx.ones(r), nx.ones(r) q1, q2 = nx.ones(r), nx.ones(r) - dykstra_w = [q3_1, q3_2, v1_, v2_, q1, q2] + dykstra_p = [q3_1, q3_2, v1_, v2_, q1, q2] n_iter = 0 err = 1 @@ -139,7 +141,7 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', omega = nx.diag(nx.dot(Q.T, CR)) eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) - Q, R, g, err, dykstra_w = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_w) + Q, R, g, err, dykstra_p = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_p) else: break @@ -153,18 +155,18 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', ## Test with X_s, X_t from ot.datasets ############################################################################# -import numpy as np -import ot +# import numpy as np +# import ot -Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) -Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) +# Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) +# Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) -Q, R, g = lowrank_sinkhorn(Xs,Xt,reg=0.1) -M = ot.dist(Xs,Xt) -P = np.dot(Q,np.dot(np.diag(1/g),R.T)) +# Q, R, g = lowrank_sinkhorn(Xs,Xt,reg=0.1) +# M = ot.dist(Xs,Xt) +# P = np.dot(Q,np.dot(np.diag(1/g),R.T)) -print(np.sum(P)) +# print(np.sum(P)) diff --git a/ot/solvers.py b/ot/solvers.py index 0313cf588..9c2746c25 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -848,3 +848,164 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, value_linear=value_linear, value_quad=value_quad, plan=plan, status=status, backend=nx) return res + + + + + + +################################## WORK IN PROGRESS ##################################### + +## Implementation of the ot.solve_sample function +## Function isn't complete, still work in progress for reg == 0 / reg is None case (and unbalanced cases) + + +from .utils import unif, list_to_array, dist, OTResultLazy +from .bregman import empirical_sinkhorn + + +def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL", unbalanced=None, + unbalanced_type='KL', is_Lazy=False, batch_size=None, n_threads=1, max_iter=None, plan_init=None, + potentials_init=None, tol=None, verbose=False): + + r"""Solve the discrete optimal transport problem using the samples in the source and target domains. + It returns either a :any:`OTResult` or :any:`OTResultLazy` object. + + The function solves the following general optimal transport problem + + .. math:: + \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + + The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By + default ``reg=None`` and there is no regularization. The unbalanced marginal + penalization can be selected with `unbalanced` (:math:`\lambda_u`) and + `unbalanced_type`. By default ``unbalanced=None`` and the function + solves the exact optimal transport problem (respecting the marginals). + + Parameters + ---------- + X_s : array-like, shape (n_samples_a, dim) + samples in the source domain + X_t : array-like, shape (n_samples_b, dim) + samples in the target domain + a : array-like, shape (dim_a,), optional + Samples weights in the source domain (default is uniform) + b : array-like, shape (dim_b,), optional + Samples weights in the source domain (default is uniform) + reg : float, optional + Regularization weight :math:`\lambda_r`, by default None (no reg., exact + OT) + reg_type : str, optional + Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL" + unbalanced : float, optional + Unbalanced penalization weight :math:`\lambda_u`, by default None + (balanced OT) + unbalanced_type : str, optional + Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL" + is_Lazy : bool, optional + Return :any:`OTResultlazy` object to reduce memory cost when True, by default False + n_threads : int, optional + Number of OMP threads for exact OT solver, by default 1 + max_iter : int, optional + Maximum number of iteration, by default None (default values in each solvers) + plan_init : array_like, shape (dim_a, dim_b), optional + Initialization of the OT plan for iterative methods, by default None + potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional + Initialization of the OT dual potentials for iterative methods, by default None + tol : _type_, optional + Tolerance for solution precision, by default None (default values in each solvers) + verbose : bool, optional + Print information in the solver, by default False + + Returns + ------- + + res_lazy : OTResultLazy() + Result of the optimization problem. This class only returns a partial OT plan and the OT dual potentials to reduce memory costs. + The information can be obtained as follows: + + - res.lazy_plan : OT plan computed on a subsample of X_s and X_t :math:`\mathbf{T}` + - res.potentials : OT dual potentials + + See :any:`OTResultLazy` for more information. + + res : OTResult() + Result of the optimization problem. The information can be obtained as follows: + + - res.plan : OT plan :math:`\mathbf{T}` + - res.potentials : OT dual potentials + - res.value : Optimal value of the optimization problem + - res.value_linear : Linear OT loss with the optimal OT plan + + See :any:`OTResult` for more information. + + + """ + + X_s, X_t = list_to_array(X_s,X_t) + + # detect backend + arr = [X_s,X_t] + if a is not None: + arr.append(a) + if b is not None: + arr.append(b) + nx = get_backend(*arr) + + # create uniform weights if not given + ns, nt = X_s.shape[0], X_t.shape[0] + if a is None: + a = nx.from_numpy(unif(ns), type_as=X_s) + if b is None: + b = nx.from_numpy(unif(nt), type_as=X_s) + + # default values for solutions + potentials = None + lazy_plan = None + + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + if batch_size is None: + batch_size = 100 + + if is_Lazy: + ################# WIP #################### + if reg is None or reg == 0: # EMD solver for isLazy ? + if unbalanced is None: # not sure "unbalanced" parameter is needed here ? (since we won't compute value) + pass + elif unbalanced_type.lower() in ['kl', 'l2']: + pass + elif unbalanced_type.lower() == 'tv': + pass + pass + ############################################# + + else: + # compute potentials + u, v, log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric='sqeuclidean', numIterMax=max_iter, stopThr=tol, + isLazy=True, batchSize=batch_size, verbose=verbose, log=True) + potentials = (log["u"], log["v"]) + + # compute lazy_plan + ns_lazy, nt_lazy = 100, 100 # size of the lazy_plan (subplan) + M = dist(X_s[:ns_lazy,:], X_t[:nt_lazy,:], metric) + K = nx.exp(M / (-reg)) + lazy_plan = u[:ns_lazy].reshape((-1, 1)) * K * v[:nt_lazy].reshape((1, -1)) + + res_lazy = OTResultLazy(potentials=potentials, lazy_plan=lazy_plan, backend=nx) + return res_lazy + + else: + # compute cost matrix M and use solve function + M = dist(X_s, X_t, metric) + + res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, n_threads, max_iter, plan_init, potentials_init, tol, verbose) + return res + + + + diff --git a/ot/utils.py b/ot/utils.py index 8cbb0db25..d570b9f30 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -938,3 +938,93 @@ def citation(self): url = {http://jmlr.org/papers/v22/20-451.html} } """ + + + +############################## WORK IN PROGRESS #################################### + +## Implementation of the OTResultLazy class for ot.solve_sample() with potentials and lazy_plan as attributes + +class OTResultLazy: + def __init__(self, potentials=None, lazy_plan=None, backend=None): + + self._potentials = potentials + self._lazy_plan = lazy_plan + self._backend = backend if backend is not None else NumpyBackend() + + + # Dual potentials -------------------------------------------- + + def __repr__(self): + s = 'OTResultLazy(' + if self._lazy_plan is not None: + s += 'lazy_plan={}(shape={}),'.format(self._lazy_plan.__class__.__name__, self._lazy_plan.shape) + + if s[-1] != '(': + s = s[:-1] + ')' + else: + s = s + ')' + return s + + @property + def potentials(self): + """Dual potentials, i.e. Lagrange multipliers for the marginal constraints. + + This pair of arrays has the same shape, numerical type + and properties as the input weights "a" and "b". + """ + if self._potentials is not None: + return self._potentials + else: + raise NotImplementedError() + + @property + def potential_a(self): + """First dual potential, associated to the "source" measure "a".""" + if self._potentials is not None: + return self._potentials[0] + else: + raise NotImplementedError() + + @property + def potential_b(self): + """Second dual potential, associated to the "target" measure "b".""" + if self._potentials is not None: + return self._potentials[1] + else: + raise NotImplementedError() + + # Transport plan ------------------------------------------- + @property + def lazy_plan(self): + """A subset of the Transport plan, encoded as a dense array.""" + + if self._lazy_plan is not None: + return self._lazy_plan + else: + raise NotImplementedError() + + @property + def citation(self): + """Appropriate citation(s) for this result, in plain text and BibTex formats.""" + + # The string below refers to the POT library: + # successor methods may concatenate the relevant references + # to the original definitions, solvers and underlying numerical backends. + return """POT library: + + POT Python Optimal Transport library, Journal of Machine Learning Research, 22(78):1−8, 2021. + Website: https://pythonot.github.io/ + Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer; + + @article{flamary2021pot, + author = {R{\'e}mi Flamary and Nicolas Courty and Alexandre Gramfort and Mokhtar Z. Alaya and Aur{\'e}lie Boisbunon and Stanislas Chambon and Laetitia Chapel and Adrien Corenflos and Kilian Fatras and Nemo Fournier and L{\'e}o Gautheron and Nathalie T.H. Gayraud and Hicham Janati and Alain Rakotomamonjy and Ievgen Redko and Antoine Rolet and Antony Schutz and Vivien Seguy and Danica J. Sutherland and Romain Tavenard and Alexander Tong and Titouan Vayer}, + title = {{POT}: {Python} {Optimal} {Transport}}, + journal = {Journal of Machine Learning Research}, + year = {2021}, + volume = {22}, + number = {78}, + pages = {1-8}, + url = {http://jmlr.org/papers/v22/20-451.html} + } + """ \ No newline at end of file From 3034e575c55d2ce56499be6849e1906fe52f0573 Mon Sep 17 00:00:00 2001 From: laudavid Date: Wed, 25 Oct 2023 17:39:08 +0200 Subject: [PATCH 03/23] add test functions + small modif lr_sin/solve_sample --- ot/lowrank.py | 97 ++++++++++++++++++++++++++++------------- ot/solvers.py | 47 +++++++++++--------- test/test_lowrank.py | 84 ++++++++++++++++++++++++++++++++++++ test/test_solvers.py | 100 +++++++++++++++++++++++++++++++++++++++++++ test/test_utils.py | 27 ++++++++++++ 5 files changed, 304 insertions(+), 51 deletions(-) create mode 100644 test/test_lowrank.py diff --git a/ot/lowrank.py b/ot/lowrank.py index a1c73bdf3..22ff8b754 100644 --- a/ot/lowrank.py +++ b/ot/lowrank.py @@ -4,10 +4,9 @@ ## Implementation of the LR-Dykstra algorithm and low rank sinkhorn algorithms - -from ot.utils import unif, list_to_array, dist -from ot.backend import get_backend -from ot.datasets import make_1D_gauss as gauss +import warnings +from .utils import unif, list_to_array, dist +from .backend import get_backend @@ -15,7 +14,7 @@ def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p): """ - Implementation of the Dykstra algorithm for low rank Sinkhorn + Implementation of the Dykstra algorithm for low rank sinkhorn """ # get dykstra parameters @@ -69,9 +68,12 @@ def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p): #################################### LOW RANK SINKHORN ALGORITHM ######################################### -def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', alpha=1e-10, numIterMax=10000, stopThr=1e-20): +def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, metric='sqeuclidean', alpha="auto", + numItermax=10000, stopThr=1e-9, warn=True, verbose=False): r''' - Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints on the feasible couplings. + Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints. + + This function returns the two low-rank matrix decomposition of the OT plan (Q,R), as well as the weight vector g. Parameters ---------- @@ -79,17 +81,22 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', samples in the source domain X_t : array-like, shape (n_samples_b, dim) samples in the target domain - reg : float - Regularization term >0 a : array-like, shape (n_samples_a,) samples weights in the source domain b : array-like, shape (n_samples_b,) samples weights in the target domain + reg : float, optional + Regularization term >0 + rank: int, optional + Nonnegative rank of the OT plan + alpha: int, optional + Lower bound for the weight vector g (>0 and <1/r) numItermax : int, optional Max number of iterations stopThr : float, optional Stop threshold on error (>0) + Returns ------- Q : array-like, shape (n_samples_a, r) @@ -97,7 +104,14 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', R: array-like, shape (n_samples_b, r) Second low-rank matrix decomposition of the OT plan g : array-like, shape (r, ) - Third low-rank matrix decomposition of the OT plan + Weight vector for the low-rank decomposition of the OT plan + + + References + ---------- + + .. Scetbon, M., Cuturi, M., & Peyré, G (2021). + Low-Rank Sinkhorn Factorization. arXiv preprint arXiv:2103.04737. ''' @@ -110,13 +124,22 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', if b is None: b = nx.from_numpy(unif(nt), type_as=X_s) + # Compute cost matrix M = dist(X_s,X_t, metric=metric) - + # Compute rank - r = min(ns, nt, r) + rank = min(ns, nt, rank) + r = rank + + if alpha == 'auto': + alpha = 1.0 / (r + 1) + + if (1/r < alpha) or (alpha < 0): + warnings.warn("The provided alpha value might lead to instabilities.") + # Compute gamma - L = nx.sqrt((2/(alpha**4))*nx.norm(M)**2 + (reg + (2/(alpha**3))*nx.norm(M))**2) + L = nx.sqrt((2/(alpha**4))*(nx.norm(M)**2) + (reg + (2/(alpha**3))*(nx.norm(M))**2)) gamma = 1/(2*L) # Initialisation @@ -125,25 +148,34 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', v1_, v2_ = nx.ones(r), nx.ones(r) q1, q2 = nx.ones(r), nx.ones(r) dykstra_p = [q3_1, q3_2, v1_, v2_, q1, q2] - n_iter = 0 err = 1 - while n_iter < numIterMax: - if err > stopThr: - n_iter = n_iter + 1 - - CR = nx.dot(M,R) - C_t_Q = nx.dot(M.T,Q) - diag_g = (1/g)[:,None] - - eps1 = nx.exp(-gamma*(nx.dot(CR,diag_g)) - ((gamma*reg)-1)*nx.log(Q)) - eps2 = nx.exp(-gamma*(nx.dot(C_t_Q,diag_g)) - ((gamma*reg)-1)*nx.log(R)) - omega = nx.diag(nx.dot(Q.T, CR)) - eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) - - Q, R, g, err, dykstra_p = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_p) - else: + for ii in range(numItermax): + CR = nx.dot(M,R) + C_t_Q = nx.dot(M.T,Q) + diag_g = (1/g)[:,None] + + eps1 = nx.exp(-gamma*(nx.dot(CR,diag_g)) - ((gamma*reg)-1)*nx.log(Q)) + eps2 = nx.exp(-gamma*(nx.dot(C_t_Q,diag_g)) - ((gamma*reg)-1)*nx.log(R)) + omega = nx.diag(nx.dot(Q.T, CR)) + eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) + + Q, R, g, err, dykstra_p = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_p) + + if err < stopThr: break + + if verbose: + if ii % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") return Q, R, g @@ -161,8 +193,13 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', # Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) # Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) +# ns = Xs.shape[0] +# nt = Xt.shape[0] + +# a = unif(ns) +# b = unif(nt) -# Q, R, g = lowrank_sinkhorn(Xs,Xt,reg=0.1) +# Q, R, g = lowrank_sinkhorn(Xs, Xt, reg=0.1, metric='euclidean', verbose=True, numItermax=100) # M = ot.dist(Xs,Xt) # P = np.dot(Q,np.dot(np.diag(1/g),R.T)) diff --git a/ot/solvers.py b/ot/solvers.py index 9c2746c25..c176969ca 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -926,7 +926,7 @@ def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_t Result of the optimization problem. This class only returns a partial OT plan and the OT dual potentials to reduce memory costs. The information can be obtained as follows: - - res.lazy_plan : OT plan computed on a subsample of X_s and X_t :math:`\mathbf{T}` + - res.lazy_plan : OT plan computed on a subsample of X_s and X_t - res.potentials : OT dual potentials See :any:`OTResultLazy` for more information. @@ -975,29 +975,34 @@ def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_t if is_Lazy: ################# WIP #################### if reg is None or reg == 0: # EMD solver for isLazy ? - if unbalanced is None: # not sure "unbalanced" parameter is needed here ? (since we won't compute value) - pass - elif unbalanced_type.lower() in ['kl', 'l2']: - pass - elif unbalanced_type.lower() == 'tv': - pass - pass + + if unbalanced is None: # balanced EMD solver for isLazy ? + raise (NotImplementedError('Not implemented balanced with no regularization')) + + else: + raise (NotImplementedError('Not implemented unbalanced_type="{}" with no regularization'.format(unbalanced_type))) + + ############################################# else: - # compute potentials - u, v, log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric='sqeuclidean', numIterMax=max_iter, stopThr=tol, - isLazy=True, batchSize=batch_size, verbose=verbose, log=True) - potentials = (log["u"], log["v"]) - - # compute lazy_plan - ns_lazy, nt_lazy = 100, 100 # size of the lazy_plan (subplan) - M = dist(X_s[:ns_lazy,:], X_t[:nt_lazy,:], metric) - K = nx.exp(M / (-reg)) - lazy_plan = u[:ns_lazy].reshape((-1, 1)) * K * v[:nt_lazy].reshape((1, -1)) - - res_lazy = OTResultLazy(potentials=potentials, lazy_plan=lazy_plan, backend=nx) - return res_lazy + if unbalanced is None: + u, v, log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric='sqeuclidean', numIterMax=max_iter, stopThr=tol, + isLazy=True, batchSize=batch_size, verbose=verbose, log=True) + # compute potentials + potentials = (log["u"], log["v"]) + + # compute lazy_plan + ns_lazy, nt_lazy = 100, 100 # size of the lazy_plan (subplan) + M = dist(X_s[:ns_lazy,:], X_t[:nt_lazy,:], metric) + K = nx.exp(M / (-reg)) + lazy_plan = u[:ns_lazy].reshape((-1, 1)) * K * v[:nt_lazy].reshape((1, -1)) + + res_lazy = OTResultLazy(potentials=potentials, lazy_plan=lazy_plan, backend=nx) + return res_lazy + + else: + raise (NotImplementedError('Not implemented unbalanced_type="{}" with regularization'.format(unbalanced_type))) else: # compute cost matrix M and use solve function diff --git a/test/test_lowrank.py b/test/test_lowrank.py new file mode 100644 index 000000000..6e1f24067 --- /dev/null +++ b/test/test_lowrank.py @@ -0,0 +1,84 @@ +##################################################################################################### +####################################### WORK IN PROGRESS ############################################ +##################################################################################################### + + +""" Test for low rank sinkhorn solvers """ + +import ot +import numpy as np +import pytest +from itertools import product + + +def test_LR_Dykstra(): + # test for LR_Dykstra algorithm ? catch nan values ? + pass + + +@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False])) +def test_lowrank_sinkhorn(verbose, warn): + # test low rank sinkhorn + n = 100 + a = ot.unif(n) + b = ot.unif(n) + + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + Q_sqe, R_sqe, g_sqe = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1) + P_sqe = np.dot(Q_sqe,np.dot(np.diag(1/g_sqe),R_sqe.T)) + + Q_m, R_m, g_m = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1, metric='euclidean') + P_m = np.dot(Q_m,np.dot(np.diag(1/g_m),R_m.T)) + + # check constraints + np.testing.assert_allclose( + a, P_sqe.sum(1), atol=1e-05) # metric sqeuclidian + np.testing.assert_allclose( + b, P_sqe.sum(0), atol=1e-05) # metric sqeuclidian + np.testing.assert_allclose( + a, P_m.sum(1), atol=1e-05) # metric euclidian + np.testing.assert_allclose( + b, P_m.sum(0), atol=1e-05) # metric euclidian + + with pytest.warns(UserWarning): + ot.lowrank.lowrank_sinkhorn(X_s, X_t, 0.1, a, b, stopThr=0, numItermax=1) + + + +@pytest.mark.parametrize(("alpha, rank"),((0.8,2),(0.5,3),(0.2,4))) +def test_lowrank_sinkhorn_alpha_warning(alpha,rank): + # test warning for value of alpha + n = 100 + a = ot.unif(n) + b = ot.unif(n) + + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + with pytest.warns(UserWarning): + ot.lowrank.lowrank_sinkhorn(X_s, X_t, 0.1, a, b, r=rank, alpha=alpha, warn=False) + + + +def test_lowrank_sinkhorn_backends(nx): + # test low rank sinkhorn for different backends + n = 100 + a = ot.unif(n) + b = ot.unif(n) + + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + ab, bb, X_sb, X_tb = nx.from_numpy(a, b, X_s, X_t) + + Q, R, g = nx.to_numpy(ot.lowrank.lowrank_sinkhorn(X_sb, X_tb, ab, bb, 0.1)) + P = np.dot(Q,np.dot(np.diag(1/g),R.T)) + + np.testing.assert_allclose(a, P.sum(1), atol=1e-05) + np.testing.assert_allclose(b, P.sum(0), atol=1e-05) + + + + diff --git a/test/test_solvers.py b/test/test_solvers.py index f0f5b638f..5a05d54cf 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -255,3 +255,103 @@ def test_solve_gromov_not_implemented(nx): ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type='partial', unbalanced=1.5) with pytest.raises(NotImplementedError): ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type='partial', unbalanced=0.5, symmetric=False) + + + + +########################################################################################################### +############################################ WORK IN PROGRESS ############################################# +########################################################################################################### + +def assert_allclose_sol_sample(sol1, sol2): + # test attributes of OTResultLazy class + lst_attr = ['potentials','potential_a', 'potential_b', 'lazy_plan'] + + nx1 = sol1._backend if sol1._backend is not None else ot.backend.NumpyBackend() + nx2 = sol2._backend if sol2._backend is not None else ot.backend.NumpyBackend() + + for attr in lst_attr: + try: + np.allclose(nx1.to_numpy(getattr(sol1, attr)), nx2.to_numpy(getattr(sol2, attr))) + except NotImplementedError: + pass + + +@pytest.mark.parametrize("reg,reg_type,unbalanced,unbalanced_type", itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type)) +def test_solve_sample(nx): + # test solve_sample when is_Lazy = False + n = 100 + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + a = ot.utils.unif(X_s.shape[0]) + b = ot.utils.unif(X_t.shape[0]) + + # solve unif weights + sol0 = ot.solve_sample(X_s, X_t) + + # solve signe weights + sol = ot.solve_sample(X_s, X_t, a, b) + + # check some attributes + sol.potentials + sol.sparse_plan + sol.marginals + sol.status + + assert_allclose_sol(sol0, sol) + + # solve in backend + X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) + solb = ot.solve_sample(X_sb, X_tb, ab, bb) + + assert_allclose_sol(sol, solb) + + # test not implemented unbalanced and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, unbalanced=1, unbalanced_type='cryptic divergence') + + # test not implemented reg_type and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, reg=1, reg_type='cryptic divergence') + + + +def test_lazy_solve_sample(nx): + # test solve_sample when is_Lazy = True + n = 100 + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + a = ot.utils.unif(X_s.shape[0]) + b = ot.utils.unif(X_t.shape[0]) + + # solve unif weights + sol0 = ot.solve_sample(X_s, X_t, reg=0.1, is_Lazy=True) # reg != 0 or None since no implementation yet for is_Lazy=True + + # solve signe weights + sol = ot.solve_sample(X_s, X_t, a, b, reg=0.1, is_Lazy=True) + + # check some attributes + sol.potentials + sol.lazy_plan + + assert_allclose_sol_sample(sol0, sol) + + # solve in backend + X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) + solb = ot.solve_sample(X_sb, X_tb, ab, bb, reg=0.1, is_Lazy=True) + + assert_allclose_sol_sample(sol, solb) + + # test not implemented reg==0 (or None) + balanced and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, is_Lazy=True) # reg == 0 (or None) + unbalanced= None are default + + # test not implemented reg==0 (or None) + unbalanced_type and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, unbalanced_type="kl", is_Lazy=True) # reg == 0 (or None) is default + + # test not implemented reg != 0 + unbalanced_type and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, reg=0.1, unbalanced_type="kl", is_Lazy=True) \ No newline at end of file diff --git a/test/test_utils.py b/test/test_utils.py index 40324518e..a14be460e 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -401,3 +401,30 @@ def test_get_coordinate_circle(): x_p = ot.utils.get_coordinate_circle(x) np.testing.assert_allclose(u[0], x_p) + + + +############################################################################################## +##################################### WORK IN PROGRESS ####################################### +############################################################################################## + +# test function for OTResultLazy + +def test_OTResultLazy(): + + res_lazy = ot.utils.OTResultLazy() + + # test print + print(res_lazy) + + # tets get citation + print(res_lazy.citation) + + lst_attributes = ['lazy_plan', + 'potential_a', + 'potential_b', + 'potentials'] + + for at in lst_attributes: + with pytest.raises(NotImplementedError): + getattr(res_lazy, at) \ No newline at end of file From 085863aef96f0d19e740879dfae158a762275a67 Mon Sep 17 00:00:00 2001 From: laudavid Date: Thu, 26 Oct 2023 10:49:23 +0200 Subject: [PATCH 04/23] add import to __init__ --- ot/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ot/__init__.py b/ot/__init__.py index f16b6fcfc..cb00f4553 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -35,6 +35,7 @@ from . import factored from . import solvers from . import gaussian +from . import lowrank # OT functions from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d, @@ -50,7 +51,8 @@ gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) from .weak import weak_optimal_transport from .factored import factored_optimal_transport -from .solvers import solve, solve_gromov +from .solvers import solve, solve_gromov, solve_sample +from .lowrank import lowrank_sinkhorn # utils functions from .utils import dist, unif, tic, toc, toq From 9becafc305fd6b2cc5390b0de16bae015bd41121 Mon Sep 17 00:00:00 2001 From: laudavid Date: Fri, 3 Nov 2023 11:38:40 +0100 Subject: [PATCH 05/23] modify low rank, remove solve_sample,OTResultLazy --- ot/__init__.py | 4 +- ot/lowrank.py | 200 ++++++++++++++++++++++++------------------- ot/solvers.py | 160 ---------------------------------- ot/utils.py | 89 ------------------- test/test_lowrank.py | 54 ++++++------ test/test_solvers.py | 97 --------------------- test/test_utils.py | 25 ------ 7 files changed, 142 insertions(+), 487 deletions(-) diff --git a/ot/__init__.py b/ot/__init__.py index cb00f4553..4aba450af 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -51,7 +51,7 @@ gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) from .weak import weak_optimal_transport from .factored import factored_optimal_transport -from .solvers import solve, solve_gromov, solve_sample +from .solvers import solve, solve_gromov from .lowrank import lowrank_sinkhorn # utils functions @@ -70,4 +70,4 @@ 'factored_optimal_transport', 'solve', 'solve_gromov', 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers', 'binary_search_circle', 'wasserstein_circle', - 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif'] + 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif', 'lowrank_sinkhorn'] diff --git a/ot/lowrank.py b/ot/lowrank.py index 22ff8b754..b3fce8de0 100644 --- a/ot/lowrank.py +++ b/ot/lowrank.py @@ -1,78 +1,88 @@ +""" +Low rank OT solvers +""" + +# Author: Laurène David +# +# License: MIT License + + + ################################################################################################################# ############################################## WORK IN PROGRESS ################################################# ################################################################################################################# -## Implementation of the LR-Dykstra algorithm and low rank sinkhorn algorithms import warnings -from .utils import unif, list_to_array, dist -from .backend import get_backend +from ot.utils import unif +from ot.backend import get_backend ################################## LR-DYSKTRA ALGORITHM ########################################## -def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p): +def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p, stopThr, nx=None): """ - Implementation of the Dykstra algorithm for low rank sinkhorn + Implementation of the Dykstra algorithm for the Low rank sinkhorn solver + """ + # Get dykstra parameters + g, q3_1, q3_2, v1_, v2_, q1, q2, u1, u2, v1, v2 = dykstra_p + g_ = eps3.copy() + err = 1 - # get dykstra parameters - q3_1, q3_2, v1_, v2_, q1, q2 = dykstra_p + # POT backend if needed + if nx is None: + nx = get_backend(eps1, eps2, eps3, p1, p2, + g, q3_1, q3_2, v1_, v2_, q1, q2, u1, u2) - # POT backend - eps1, eps2, eps3, p1, p2 = list_to_array(eps1, eps2, eps3, p1, p2) - q3_1, q3_2, v1_, v2_, q1, q2 = list_to_array(q3_1, q3_2, v1_, v2_, q1, q2) - - nx = get_backend(eps1, eps2, eps3, p1, p2, q3_1, q3_2, v1_, v2_, q1, q2) - - # ------- Dykstra algorithm ------ - g_ = eps3 - u1 = p1 / nx.dot(eps1, v1_) - u2 = p2 / nx.dot(eps2, v2_) + # ------- Dykstra algorithm ------ + while err > stopThr : + u1 = p1 / nx.dot(eps1, v1_) + u2 = p2 / nx.dot(eps2, v2_) - g = nx.maximum(alpha, g_ * q3_1) - q3_1 = (g_ * q3_1) / g - g_ = g + g = nx.maximum(alpha, g_ * q3_1) + q3_1 = (g_ * q3_1) / g + g_ = g.copy() - prod1 = ((v1_ * q1) * nx.dot(eps1.T, u1)) - prod2 = ((v2_ * q2) * nx.dot(eps2.T, u2)) - g = (g_ * q3_2 * prod1 * prod2)**(1/3) + prod1 = ((v1_ * q1) * nx.dot(eps1.T, u1)) + prod2 = ((v2_ * q2) * nx.dot(eps2.T, u2)) + g = (g_ * q3_2 * prod1 * prod2)**(1/3) - v1 = g / nx.dot(eps1.T,u1) - v2 = g / nx.dot(eps2.T,u2) + v1 = g / nx.dot(eps1.T,u1) + v2 = g / nx.dot(eps2.T,u2) + q1 = (v1_ * q1) / v1 - q1 = (v1_ * q1) / v1 - q2 = (v2_ * q2) / v2 - q3_2 = (g_ * q3_2) / g - - v1_, v2_ = v1, v2 - g_ = g + q2 = (v2_ * q2) / v2 + q3_2 = (g_ * q3_2) / g + + v1_, v2_ = v1.copy(), v2.copy() + g_ = g.copy() - # Compute error - err1 = nx.sum(nx.abs(u1 * (eps1 @ v1) - p1)) - err2 = nx.sum(nx.abs(u2 * (eps2 @ v2) - p2)) - err = err1 + err2 + # Compute error + err1 = nx.sum(nx.abs(u1 * (eps1 @ v1) - p1)) + err2 = nx.sum(nx.abs(u2 * (eps2 @ v2) - p2)) + err = err1 + err2 # Compute low rank matrices Q, R Q = u1[:,None] * eps1 * v1[None,:] R = u2[:,None] * eps2 * v2[None,:] - dykstra_p = [q3_1, q3_2, v1_, v2_, q1, q2] + dykstra_p = [g, q3_1, q3_2, v1_, v2_, q1, q2, u1, u2, v1, v2] - return Q, R, g, err, dykstra_p + return Q, R, dykstra_p #################################### LOW RANK SINKHORN ALGORITHM ######################################### -def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, metric='sqeuclidean', alpha="auto", - numItermax=10000, stopThr=1e-9, warn=True, verbose=False): +def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, alpha="auto", + numItermax=1000, stopThr=1e-9, warn=True, verbose=False): #stopThr = 1e-9 + r''' Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints. - This function returns the two low-rank matrix decomposition of the OT plan (Q,R), as well as the weight vector g. Parameters @@ -95,6 +105,9 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, metric='sqeuclidea Max number of iterations stopThr : float, optional Stop threshold on error (>0) + warn: + + verbose: Returns @@ -109,73 +122,87 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, metric='sqeuclidea References ---------- - .. Scetbon, M., Cuturi, M., & Peyré, G (2021). Low-Rank Sinkhorn Factorization. arXiv preprint arXiv:2103.04737. ''' - X_s, X_t = list_to_array(X_s, X_t) nx = get_backend(X_s, X_t) - ns, nt = X_s.shape[0], X_t.shape[0] if a is None: - a = nx.from_numpy(unif(ns), type_as=X_s) + a = unif(ns, type_as=X_s) if b is None: - b = nx.from_numpy(unif(nt), type_as=X_s) + b = unif(nt, type_as=X_t) - # Compute cost matrix - M = dist(X_s,X_t, metric=metric) + d = X_s.shape[1] + + # First low rank decomposition of the cost matrix (A) + M1 = nx.zeros((ns,(d+2))) + M1[:,0] = [nx.norm(X_s[i,:])**2 for i in range(ns)] + M1[:,1] = nx.ones(ns) + M1[:,2:] = -2*X_s + + # Second low rank decomposition of the cost matrix (B) + M2 = nx.zeros((nt,(d+2))) + M2[:,0] = nx.ones(nt) + M2[:,1] = [nx.norm(X_t[i,:])**2 for i in range(nt)] + M2[:,2:] = X_t # Compute rank rank = min(ns, nt, rank) r = rank + # Alpha: lower bound for 1/rank if alpha == 'auto': - alpha = 1.0 / (r + 1) + alpha = 1e-3 # no convergence with alpha = 1 / (r+1) if (1/r < alpha) or (alpha < 0): - warnings.warn("The provided alpha value might lead to instabilities.") + warnings.warn("The provided alpha value might lead to instabilities.") - # Compute gamma - L = nx.sqrt((2/(alpha**4))*(nx.norm(M)**2) + (reg + (2/(alpha**3))*(nx.norm(M))**2)) + L = nx.sqrt(3*(2/(alpha**4))*((nx.norm(M1)*nx.norm(M2))**2) + (reg + (2/(alpha**3))*(nx.norm(M1)*nx.norm(M2)))**2) gamma = 1/(2*L) - # Initialisation + # Initialize the low rank matrices Q, R, g Q, R, g = nx.ones((ns,r)), nx.ones((nt,r)), nx.ones(r) + + # Initialize parameters for Dykstra algorithm q3_1, q3_2 = nx.ones(r), nx.ones(r) + u1, u2 = nx.ones(ns), nx.ones(nt) + v1, v2 = nx.ones(r), nx.ones(r) v1_, v2_ = nx.ones(r), nx.ones(r) q1, q2 = nx.ones(r), nx.ones(r) - dykstra_p = [q3_1, q3_2, v1_, v2_, q1, q2] - err = 1 + dykstra_p = [g, q3_1, q3_2, v1_, v2_, q1, q2, u1, u2, v1, v2] + - for ii in range(numItermax): - CR = nx.dot(M,R) - C_t_Q = nx.dot(M.T,Q) + for ii in range(numItermax): + CR_ = nx.dot(M2.T, R) + CR = nx.dot(M1, CR_) + + CQ_ = nx.dot(M1.T, Q) + CQ = nx.dot(M2, CQ_) + diag_g = (1/g)[:,None] eps1 = nx.exp(-gamma*(nx.dot(CR,diag_g)) - ((gamma*reg)-1)*nx.log(Q)) - eps2 = nx.exp(-gamma*(nx.dot(C_t_Q,diag_g)) - ((gamma*reg)-1)*nx.log(R)) + eps2 = nx.exp(-gamma*(nx.dot(CQ,diag_g)) - ((gamma*reg)-1)*nx.log(R)) omega = nx.diag(nx.dot(Q.T, CR)) eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) - Q, R, g, err, dykstra_p = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_p) - - if err < stopThr: - break - - if verbose: - if ii % 200 == 0: - print( - '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - print('{:5d}|{:8e}|'.format(ii, err)) - - else: - if warn: - warnings.warn("Sinkhorn did not converge. You might want to " - "increase the number of iterations `numItermax` " - "or the regularization parameter `reg`.") + Q, R, dykstra_p = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_p, stopThr, nx) + g = dykstra_p[0] + + # if verbose: + # if ii % 200 == 0: + # print( + # '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + # print('{:5d}|{:8e}|'.format(ii, err)) + + # else: + # if warn: + # warnings.warn("Sinkhorn did not converge. You might want to " + # "increase the number of iterations `numItermax` " + # "or the regularization parameter `reg`.") return Q, R, g @@ -187,24 +214,23 @@ def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, metric='sqeuclidea ## Test with X_s, X_t from ot.datasets ############################################################################# -# import numpy as np -# import ot - -# Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) -# Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) +import numpy as np +import ot -# ns = Xs.shape[0] -# nt = Xt.shape[0] +Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) +Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) -# a = unif(ns) -# b = unif(nt) +ns = Xs.shape[0] +nt = Xt.shape[0] -# Q, R, g = lowrank_sinkhorn(Xs, Xt, reg=0.1, metric='euclidean', verbose=True, numItermax=100) -# M = ot.dist(Xs,Xt) -# P = np.dot(Q,np.dot(np.diag(1/g),R.T)) +a = unif(ns) +b = unif(nt) -# print(np.sum(P)) +Q, R, g = lowrank_sinkhorn(Xs, Xt, reg=0.1, verbose=True, numItermax=20) +M = ot.dist(Xs,Xt) +P = np.dot(Q,np.dot(np.diag(1/g),R.T)) +print(np.sum(P)) diff --git a/ot/solvers.py b/ot/solvers.py index c176969ca..8d6e10a5f 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -854,163 +854,3 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, -################################## WORK IN PROGRESS ##################################### - -## Implementation of the ot.solve_sample function -## Function isn't complete, still work in progress for reg == 0 / reg is None case (and unbalanced cases) - - -from .utils import unif, list_to_array, dist, OTResultLazy -from .bregman import empirical_sinkhorn - - -def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL", unbalanced=None, - unbalanced_type='KL', is_Lazy=False, batch_size=None, n_threads=1, max_iter=None, plan_init=None, - potentials_init=None, tol=None, verbose=False): - - r"""Solve the discrete optimal transport problem using the samples in the source and target domains. - It returns either a :any:`OTResult` or :any:`OTResultLazy` object. - - The function solves the following general optimal transport problem - - .. math:: - \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + - \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + - \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) - - The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By - default ``reg=None`` and there is no regularization. The unbalanced marginal - penalization can be selected with `unbalanced` (:math:`\lambda_u`) and - `unbalanced_type`. By default ``unbalanced=None`` and the function - solves the exact optimal transport problem (respecting the marginals). - - Parameters - ---------- - X_s : array-like, shape (n_samples_a, dim) - samples in the source domain - X_t : array-like, shape (n_samples_b, dim) - samples in the target domain - a : array-like, shape (dim_a,), optional - Samples weights in the source domain (default is uniform) - b : array-like, shape (dim_b,), optional - Samples weights in the source domain (default is uniform) - reg : float, optional - Regularization weight :math:`\lambda_r`, by default None (no reg., exact - OT) - reg_type : str, optional - Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL" - unbalanced : float, optional - Unbalanced penalization weight :math:`\lambda_u`, by default None - (balanced OT) - unbalanced_type : str, optional - Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL" - is_Lazy : bool, optional - Return :any:`OTResultlazy` object to reduce memory cost when True, by default False - n_threads : int, optional - Number of OMP threads for exact OT solver, by default 1 - max_iter : int, optional - Maximum number of iteration, by default None (default values in each solvers) - plan_init : array_like, shape (dim_a, dim_b), optional - Initialization of the OT plan for iterative methods, by default None - potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional - Initialization of the OT dual potentials for iterative methods, by default None - tol : _type_, optional - Tolerance for solution precision, by default None (default values in each solvers) - verbose : bool, optional - Print information in the solver, by default False - - Returns - ------- - - res_lazy : OTResultLazy() - Result of the optimization problem. This class only returns a partial OT plan and the OT dual potentials to reduce memory costs. - The information can be obtained as follows: - - - res.lazy_plan : OT plan computed on a subsample of X_s and X_t - - res.potentials : OT dual potentials - - See :any:`OTResultLazy` for more information. - - res : OTResult() - Result of the optimization problem. The information can be obtained as follows: - - - res.plan : OT plan :math:`\mathbf{T}` - - res.potentials : OT dual potentials - - res.value : Optimal value of the optimization problem - - res.value_linear : Linear OT loss with the optimal OT plan - - See :any:`OTResult` for more information. - - - """ - - X_s, X_t = list_to_array(X_s,X_t) - - # detect backend - arr = [X_s,X_t] - if a is not None: - arr.append(a) - if b is not None: - arr.append(b) - nx = get_backend(*arr) - - # create uniform weights if not given - ns, nt = X_s.shape[0], X_t.shape[0] - if a is None: - a = nx.from_numpy(unif(ns), type_as=X_s) - if b is None: - b = nx.from_numpy(unif(nt), type_as=X_s) - - # default values for solutions - potentials = None - lazy_plan = None - - if max_iter is None: - max_iter = 1000 - if tol is None: - tol = 1e-9 - if batch_size is None: - batch_size = 100 - - if is_Lazy: - ################# WIP #################### - if reg is None or reg == 0: # EMD solver for isLazy ? - - if unbalanced is None: # balanced EMD solver for isLazy ? - raise (NotImplementedError('Not implemented balanced with no regularization')) - - else: - raise (NotImplementedError('Not implemented unbalanced_type="{}" with no regularization'.format(unbalanced_type))) - - - ############################################# - - else: - if unbalanced is None: - u, v, log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric='sqeuclidean', numIterMax=max_iter, stopThr=tol, - isLazy=True, batchSize=batch_size, verbose=verbose, log=True) - # compute potentials - potentials = (log["u"], log["v"]) - - # compute lazy_plan - ns_lazy, nt_lazy = 100, 100 # size of the lazy_plan (subplan) - M = dist(X_s[:ns_lazy,:], X_t[:nt_lazy,:], metric) - K = nx.exp(M / (-reg)) - lazy_plan = u[:ns_lazy].reshape((-1, 1)) * K * v[:nt_lazy].reshape((1, -1)) - - res_lazy = OTResultLazy(potentials=potentials, lazy_plan=lazy_plan, backend=nx) - return res_lazy - - else: - raise (NotImplementedError('Not implemented unbalanced_type="{}" with regularization'.format(unbalanced_type))) - - else: - # compute cost matrix M and use solve function - M = dist(X_s, X_t, metric) - - res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, n_threads, max_iter, plan_init, potentials_init, tol, verbose) - return res - - - - diff --git a/ot/utils.py b/ot/utils.py index d570b9f30..01944f56b 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -939,92 +939,3 @@ def citation(self): } """ - - -############################## WORK IN PROGRESS #################################### - -## Implementation of the OTResultLazy class for ot.solve_sample() with potentials and lazy_plan as attributes - -class OTResultLazy: - def __init__(self, potentials=None, lazy_plan=None, backend=None): - - self._potentials = potentials - self._lazy_plan = lazy_plan - self._backend = backend if backend is not None else NumpyBackend() - - - # Dual potentials -------------------------------------------- - - def __repr__(self): - s = 'OTResultLazy(' - if self._lazy_plan is not None: - s += 'lazy_plan={}(shape={}),'.format(self._lazy_plan.__class__.__name__, self._lazy_plan.shape) - - if s[-1] != '(': - s = s[:-1] + ')' - else: - s = s + ')' - return s - - @property - def potentials(self): - """Dual potentials, i.e. Lagrange multipliers for the marginal constraints. - - This pair of arrays has the same shape, numerical type - and properties as the input weights "a" and "b". - """ - if self._potentials is not None: - return self._potentials - else: - raise NotImplementedError() - - @property - def potential_a(self): - """First dual potential, associated to the "source" measure "a".""" - if self._potentials is not None: - return self._potentials[0] - else: - raise NotImplementedError() - - @property - def potential_b(self): - """Second dual potential, associated to the "target" measure "b".""" - if self._potentials is not None: - return self._potentials[1] - else: - raise NotImplementedError() - - # Transport plan ------------------------------------------- - @property - def lazy_plan(self): - """A subset of the Transport plan, encoded as a dense array.""" - - if self._lazy_plan is not None: - return self._lazy_plan - else: - raise NotImplementedError() - - @property - def citation(self): - """Appropriate citation(s) for this result, in plain text and BibTex formats.""" - - # The string below refers to the POT library: - # successor methods may concatenate the relevant references - # to the original definitions, solvers and underlying numerical backends. - return """POT library: - - POT Python Optimal Transport library, Journal of Machine Learning Research, 22(78):1−8, 2021. - Website: https://pythonot.github.io/ - Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer; - - @article{flamary2021pot, - author = {R{\'e}mi Flamary and Nicolas Courty and Alexandre Gramfort and Mokhtar Z. Alaya and Aur{\'e}lie Boisbunon and Stanislas Chambon and Laetitia Chapel and Adrien Corenflos and Kilian Fatras and Nemo Fournier and L{\'e}o Gautheron and Nathalie T.H. Gayraud and Hicham Janati and Alain Rakotomamonjy and Ievgen Redko and Antoine Rolet and Antony Schutz and Vivien Seguy and Danica J. Sutherland and Romain Tavenard and Alexander Tong and Titouan Vayer}, - title = {{POT}: {Python} {Optimal} {Transport}}, - journal = {Journal of Machine Learning Research}, - year = {2021}, - volume = {22}, - number = {78}, - pages = {1-8}, - url = {http://jmlr.org/papers/v22/20-451.html} - } - """ \ No newline at end of file diff --git a/test/test_lowrank.py b/test/test_lowrank.py index 6e1f24067..7d90ce9ef 100644 --- a/test/test_lowrank.py +++ b/test/test_lowrank.py @@ -16,34 +16,34 @@ def test_LR_Dykstra(): pass -@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False])) -def test_lowrank_sinkhorn(verbose, warn): - # test low rank sinkhorn - n = 100 - a = ot.unif(n) - b = ot.unif(n) - - X_s = np.reshape(1.0 * np.arange(n), (n, 1)) - X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) - - Q_sqe, R_sqe, g_sqe = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1) - P_sqe = np.dot(Q_sqe,np.dot(np.diag(1/g_sqe),R_sqe.T)) - - Q_m, R_m, g_m = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1, metric='euclidean') - P_m = np.dot(Q_m,np.dot(np.diag(1/g_m),R_m.T)) - - # check constraints - np.testing.assert_allclose( - a, P_sqe.sum(1), atol=1e-05) # metric sqeuclidian - np.testing.assert_allclose( - b, P_sqe.sum(0), atol=1e-05) # metric sqeuclidian - np.testing.assert_allclose( - a, P_m.sum(1), atol=1e-05) # metric euclidian - np.testing.assert_allclose( - b, P_m.sum(0), atol=1e-05) # metric euclidian +# @pytest.mark.parametrize("verbose, warn", product([True, False], [True, False])) +# def test_lowrank_sinkhorn(verbose, warn): +# # test low rank sinkhorn +# n = 100 +# a = ot.unif(n) +# b = ot.unif(n) + +# X_s = np.reshape(1.0 * np.arange(n), (n, 1)) +# X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + +# Q_sqe, R_sqe, g_sqe = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1) +# P_sqe = np.dot(Q_sqe,np.dot(np.diag(1/g_sqe),R_sqe.T)) + +# Q_m, R_m, g_m = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1, metric='euclidean') +# P_m = np.dot(Q_m,np.dot(np.diag(1/g_m),R_m.T)) + +# # check constraints +# np.testing.assert_allclose( +# a, P_sqe.sum(1), atol=1e-05) # metric sqeuclidian +# np.testing.assert_allclose( +# b, P_sqe.sum(0), atol=1e-05) # metric sqeuclidian +# np.testing.assert_allclose( +# a, P_m.sum(1), atol=1e-05) # metric euclidian +# np.testing.assert_allclose( +# b, P_m.sum(0), atol=1e-05) # metric euclidian - with pytest.warns(UserWarning): - ot.lowrank.lowrank_sinkhorn(X_s, X_t, 0.1, a, b, stopThr=0, numItermax=1) +# with pytest.warns(UserWarning): +# ot.lowrank.lowrank_sinkhorn(X_s, X_t, 0.1, a, b, stopThr=0, numItermax=1) diff --git a/test/test_solvers.py b/test/test_solvers.py index 5a05d54cf..e845ac7c2 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -258,100 +258,3 @@ def test_solve_gromov_not_implemented(nx): - -########################################################################################################### -############################################ WORK IN PROGRESS ############################################# -########################################################################################################### - -def assert_allclose_sol_sample(sol1, sol2): - # test attributes of OTResultLazy class - lst_attr = ['potentials','potential_a', 'potential_b', 'lazy_plan'] - - nx1 = sol1._backend if sol1._backend is not None else ot.backend.NumpyBackend() - nx2 = sol2._backend if sol2._backend is not None else ot.backend.NumpyBackend() - - for attr in lst_attr: - try: - np.allclose(nx1.to_numpy(getattr(sol1, attr)), nx2.to_numpy(getattr(sol2, attr))) - except NotImplementedError: - pass - - -@pytest.mark.parametrize("reg,reg_type,unbalanced,unbalanced_type", itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type)) -def test_solve_sample(nx): - # test solve_sample when is_Lazy = False - n = 100 - X_s = np.reshape(1.0 * np.arange(n), (n, 1)) - X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) - - a = ot.utils.unif(X_s.shape[0]) - b = ot.utils.unif(X_t.shape[0]) - - # solve unif weights - sol0 = ot.solve_sample(X_s, X_t) - - # solve signe weights - sol = ot.solve_sample(X_s, X_t, a, b) - - # check some attributes - sol.potentials - sol.sparse_plan - sol.marginals - sol.status - - assert_allclose_sol(sol0, sol) - - # solve in backend - X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) - solb = ot.solve_sample(X_sb, X_tb, ab, bb) - - assert_allclose_sol(sol, solb) - - # test not implemented unbalanced and check raise - with pytest.raises(NotImplementedError): - sol0 = ot.solve_sample(X_s, X_t, unbalanced=1, unbalanced_type='cryptic divergence') - - # test not implemented reg_type and check raise - with pytest.raises(NotImplementedError): - sol0 = ot.solve_sample(X_s, X_t, reg=1, reg_type='cryptic divergence') - - - -def test_lazy_solve_sample(nx): - # test solve_sample when is_Lazy = True - n = 100 - X_s = np.reshape(1.0 * np.arange(n), (n, 1)) - X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) - - a = ot.utils.unif(X_s.shape[0]) - b = ot.utils.unif(X_t.shape[0]) - - # solve unif weights - sol0 = ot.solve_sample(X_s, X_t, reg=0.1, is_Lazy=True) # reg != 0 or None since no implementation yet for is_Lazy=True - - # solve signe weights - sol = ot.solve_sample(X_s, X_t, a, b, reg=0.1, is_Lazy=True) - - # check some attributes - sol.potentials - sol.lazy_plan - - assert_allclose_sol_sample(sol0, sol) - - # solve in backend - X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) - solb = ot.solve_sample(X_sb, X_tb, ab, bb, reg=0.1, is_Lazy=True) - - assert_allclose_sol_sample(sol, solb) - - # test not implemented reg==0 (or None) + balanced and check raise - with pytest.raises(NotImplementedError): - sol0 = ot.solve_sample(X_s, X_t, is_Lazy=True) # reg == 0 (or None) + unbalanced= None are default - - # test not implemented reg==0 (or None) + unbalanced_type and check raise - with pytest.raises(NotImplementedError): - sol0 = ot.solve_sample(X_s, X_t, unbalanced_type="kl", is_Lazy=True) # reg == 0 (or None) is default - - # test not implemented reg != 0 + unbalanced_type and check raise - with pytest.raises(NotImplementedError): - sol0 = ot.solve_sample(X_s, X_t, reg=0.1, unbalanced_type="kl", is_Lazy=True) \ No newline at end of file diff --git a/test/test_utils.py b/test/test_utils.py index a14be460e..bbadec65a 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -403,28 +403,3 @@ def test_get_coordinate_circle(): np.testing.assert_allclose(u[0], x_p) - -############################################################################################## -##################################### WORK IN PROGRESS ####################################### -############################################################################################## - -# test function for OTResultLazy - -def test_OTResultLazy(): - - res_lazy = ot.utils.OTResultLazy() - - # test print - print(res_lazy) - - # tets get citation - print(res_lazy.citation) - - lst_attributes = ['lazy_plan', - 'potential_a', - 'potential_b', - 'potentials'] - - for at in lst_attributes: - with pytest.raises(NotImplementedError): - getattr(res_lazy, at) \ No newline at end of file From 58576a3f89ac3322c93588832bdb4daa380ce384 Mon Sep 17 00:00:00 2001 From: laudavid Date: Fri, 3 Nov 2023 13:59:01 +0100 Subject: [PATCH 06/23] solve_sample + test functions --- ot/solvers.py | 147 ++++++++++++++++++++++++++++++++++++++++++- test/test_solvers.py | 98 +++++++++++++++++++++++++++++ 2 files changed, 243 insertions(+), 2 deletions(-) diff --git a/ot/solvers.py b/ot/solvers.py index 8d6e10a5f..6eb19ea05 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -7,11 +7,11 @@ # # License: MIT License -from .utils import OTResult +from .utils import OTResult, unif, dist from .lp import emd2 from .backend import get_backend from .unbalanced import mm_unbalanced, sinkhorn_knopp_unbalanced, lbfgsb_unbalanced -from .bregman import sinkhorn_log +from .bregman import sinkhorn_log, empirical_sinkhorn from .partial import partial_wasserstein_lagrange from .smooth import smooth_ot_dual from .gromov import (gromov_wasserstein2, fused_gromov_wasserstein2, @@ -20,6 +20,8 @@ entropic_semirelaxed_fused_gromov_wasserstein2, entropic_semirelaxed_gromov_wasserstein2) from .partial import partial_gromov_wasserstein2, entropic_partial_gromov_wasserstein2 +from .bregman import empirical_sinkhorn + #, entropic_gromov_wasserstein2, entropic_fused_gromov_wasserstein2 @@ -851,6 +853,147 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, +########## ot.solve_sample function ########### + +from .bregman import empirical_sinkhorn +from .utils import unif, dist + + +def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL", unbalanced=None, + unbalanced_type='KL', is_Lazy=False, batch_size=None, n_threads=1, max_iter=None, plan_init=None, + potentials_init=None, tol=None, verbose=False): + + r"""Solve the discrete optimal transport problem using the samples in the source and target domains. + It returns either a :any:`OTResult` or :any:`OTResultLazy` object. + + The function solves the following general optimal transport problem + + .. math:: + \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + + The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By + default ``reg=None`` and there is no regularization. The unbalanced marginal + penalization can be selected with `unbalanced` (:math:`\lambda_u`) and + `unbalanced_type`. By default ``unbalanced=None`` and the function + solves the exact optimal transport problem (respecting the marginals). + + Parameters + ---------- + X_s : array-like, shape (n_samples_a, dim) + samples in the source domain + X_t : array-like, shape (n_samples_b, dim) + samples in the target domain + a : array-like, shape (dim_a,), optional + Samples weights in the source domain (default is uniform) + b : array-like, shape (dim_b,), optional + Samples weights in the source domain (default is uniform) + reg : float, optional + Regularization weight :math:`\lambda_r`, by default None (no reg., exact + OT) + reg_type : str, optional + Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL" + unbalanced : float, optional + Unbalanced penalization weight :math:`\lambda_u`, by default None + (balanced OT) + unbalanced_type : str, optional + Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL" + is_Lazy : bool, optional + Return :any:`OTResultlazy` object to reduce memory cost when True, by default False + n_threads : int, optional + Number of OMP threads for exact OT solver, by default 1 + max_iter : int, optional + Maximum number of iteration, by default None (default values in each solvers) + plan_init : array_like, shape (dim_a, dim_b), optional + Initialization of the OT plan for iterative methods, by default None + potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional + Initialization of the OT dual potentials for iterative methods, by default None + tol : _type_, optional + Tolerance for solution precision, by default None (default values in each solvers) + verbose : bool, optional + Print information in the solver, by default False + + Returns + ------- + + res : OTResult() + Result of the optimization problem. The information can be obtained as follows: + + - res.plan : OT plan :math:`\mathbf{T}` + - res.potentials : OT dual potentials + - res.value : Optimal value of the optimization problem + - res.value_linear : Linear OT loss with the optimal OT plan + + See :any:`OTResult` for more information. + + + """ + # Detect backend + arr = [X_s,X_t] + if a is not None: + arr.append(a) + if b is not None: + arr.append(b) + nx = get_backend(*arr) + + # Create uniform weights if not given + ns, nt = X_s.shape[0], X_t.shape[0] + if a is None: + a = nx.from_numpy(unif(ns), type_as=X_s) + if b is None: + b = nx.from_numpy(unif(nt), type_as=X_s) + + if metric is not 'sqeuclidean': + raise (NotImplementedError('Only implemented for sqeuclidean metric')) + + + # default values for solutions + potentials = None + lazy_plan = None + + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + if batch_size is None: + batch_size = 100 + + if is_Lazy: + ################# WIP #################### + if reg is None or reg == 0: # EMD solver for isLazy ? + + if unbalanced is None: # balanced EMD solver for isLazy ? + raise (NotImplementedError('Not implemented balanced with no regularization')) + + else: + raise (NotImplementedError('Not implemented unbalanced_type="{}" with no regularization'.format(unbalanced_type))) + + + ############################################# + + else: + if unbalanced is None: + u, v, log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric='sqeuclidean', numIterMax=max_iter, stopThr=tol, + isLazy=True, batchSize=batch_size, verbose=verbose, log=True) + # compute potentials + potentials = (log["u"], log["v"]) + + # compute lazy_plan + pass + + else: + raise (NotImplementedError('Not implemented unbalanced_type="{}" with regularization'.format(unbalanced_type))) + + else: + # compute cost matrix M and use solve function + M = dist(X_s, X_t, metric) + + res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, n_threads, max_iter, plan_init, potentials_init, tol, verbose) + return res + + + diff --git a/test/test_solvers.py b/test/test_solvers.py index e845ac7c2..77723ddd0 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -258,3 +258,101 @@ def test_solve_gromov_not_implemented(nx): + +######## Test functions for ot.solve_sample ######## + +def assert_allclose_sol_sample(sol1, sol2): + # test attributes of OTResultLazy class + lst_attr = ['potentials','potential_a', 'potential_b', 'lazy_plan'] + + nx1 = sol1._backend if sol1._backend is not None else ot.backend.NumpyBackend() + nx2 = sol2._backend if sol2._backend is not None else ot.backend.NumpyBackend() + + for attr in lst_attr: + try: + np.allclose(nx1.to_numpy(getattr(sol1, attr)), nx2.to_numpy(getattr(sol2, attr))) + except NotImplementedError: + pass + + +@pytest.mark.parametrize("reg,reg_type,unbalanced,unbalanced_type", itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type)) +def test_solve_sample(nx): + # test solve_sample when is_Lazy = False + n = 100 + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + a = ot.utils.unif(X_s.shape[0]) + b = ot.utils.unif(X_t.shape[0]) + + # solve unif weights + sol0 = ot.solve_sample(X_s, X_t) + + # solve signe weights + sol = ot.solve_sample(X_s, X_t, a, b) + + # check some attributes + sol.potentials + sol.sparse_plan + sol.marginals + sol.status + + assert_allclose_sol(sol0, sol) + + # solve in backend + X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) + solb = ot.solve_sample(X_sb, X_tb, ab, bb) + + assert_allclose_sol(sol, solb) + + # test not implemented unbalanced and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, unbalanced=1, unbalanced_type='cryptic divergence') + + # test not implemented reg_type and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, reg=1, reg_type='cryptic divergence') + + + +def test_lazy_solve_sample(nx): + # test solve_sample when is_Lazy = True + n = 100 + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + a = ot.utils.unif(X_s.shape[0]) + b = ot.utils.unif(X_t.shape[0]) + + # solve unif weights + sol0 = ot.solve_sample(X_s, X_t, reg=0.1, is_Lazy=True) # reg != 0 or None since no implementation yet for is_Lazy=True + + # solve signe weights + sol = ot.solve_sample(X_s, X_t, a, b, reg=0.1, is_Lazy=True) + + # check some attributes + sol.potentials + sol.lazy_plan + + assert_allclose_sol_sample(sol0, sol) + + # solve in backend + X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) + solb = ot.solve_sample(X_sb, X_tb, ab, bb, reg=0.1, is_Lazy=True) + + assert_allclose_sol_sample(sol, solb) + + # test not implemented reg==0 (or None) + balanced and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, is_Lazy=True) # reg == 0 (or None) + unbalanced= None are default + + # test not implemented reg==0 (or None) + unbalanced_type and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, unbalanced_type="kl", is_Lazy=True) # reg == 0 (or None) is default + + # test not implemented reg != 0 + unbalanced_type and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, reg=0.1, unbalanced_type="kl", is_Lazy=True) + + + From ed1b22d13da6082a69c099e59915fed3301ff208 Mon Sep 17 00:00:00 2001 From: laudavid Date: Fri, 3 Nov 2023 14:04:11 +0100 Subject: [PATCH 07/23] remove low rank from branch --- ot/__init__.py | 9 +- ot/lowrank.py | 247 ------------------------------------------- test/test_lowrank.py | 84 --------------- 3 files changed, 4 insertions(+), 336 deletions(-) delete mode 100644 ot/lowrank.py delete mode 100644 test/test_lowrank.py diff --git a/ot/__init__.py b/ot/__init__.py index 4aba450af..034875c55 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -35,7 +35,6 @@ from . import factored from . import solvers from . import gaussian -from . import lowrank # OT functions from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d, @@ -51,8 +50,8 @@ gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) from .weak import weak_optimal_transport from .factored import factored_optimal_transport -from .solvers import solve, solve_gromov -from .lowrank import lowrank_sinkhorn +from .solvers import solve, solve_gromov, solve_sample + # utils functions from .utils import dist, unif, tic, toc, toq @@ -67,7 +66,7 @@ 'sinkhorn_unbalanced2', 'sliced_wasserstein_distance', 'sliced_wasserstein_sphere', 'gromov_wasserstein', 'gromov_wasserstein2', 'gromov_barycenters', 'fused_gromov_wasserstein', 'fused_gromov_wasserstein2', 'max_sliced_wasserstein_distance', 'weak_optimal_transport', - 'factored_optimal_transport', 'solve', 'solve_gromov', + 'factored_optimal_transport', 'solve', 'solve_gromov', 'solve_sample' 'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers', 'binary_search_circle', 'wasserstein_circle', - 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif', 'lowrank_sinkhorn'] + 'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif'] diff --git a/ot/lowrank.py b/ot/lowrank.py deleted file mode 100644 index d583f4741..000000000 --- a/ot/lowrank.py +++ /dev/null @@ -1,247 +0,0 @@ -""" -Low rank OT solvers -""" - -# Author: Laurène David -# -# License: MIT License - - - -################################################################################################################# -############################################## WORK IN PROGRESS ################################################# -################################################################################################################# - - -import warnings -from ot.utils import unif -from ot.backend import get_backend - - - -################################## LR-DYSKTRA ALGORITHM ########################################## - -def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p, stopThr, nx=None): - """ - Implementation of the Dykstra algorithm for the Low rank sinkhorn solver - - """ - # Get dykstra parameters - g, q3_1, q3_2, v1_, v2_, q1, q2, u1, u2, v1, v2 = dykstra_p - g_ = eps3.copy() - err = 1 - - # POT backend if needed - if nx is None: - nx = get_backend(eps1, eps2, eps3, p1, p2, - g, q3_1, q3_2, v1_, v2_, q1, q2, u1, u2) - - - # ------- Dykstra algorithm ------ - while err > stopThr : - u1 = p1 / nx.dot(eps1, v1_) - u2 = p2 / nx.dot(eps2, v2_) - - g = nx.maximum(alpha, g_ * q3_1) - q3_1 = (g_ * q3_1) / g - g_ = g.copy() - - prod1 = ((v1_ * q1) * nx.dot(eps1.T, u1)) - prod2 = ((v2_ * q2) * nx.dot(eps2.T, u2)) - g = (g_ * q3_2 * prod1 * prod2)**(1/3) - - v1 = g / nx.dot(eps1.T,u1) - v2 = g / nx.dot(eps2.T,u2) - q1 = (v1_ * q1) / v1 - - q2 = (v2_ * q2) / v2 - q3_2 = (g_ * q3_2) / g - - v1_, v2_ = v1.copy(), v2.copy() - g_ = g.copy() - - # Compute error - err1 = nx.sum(nx.abs(u1 * (eps1 @ v1) - p1)) - err2 = nx.sum(nx.abs(u2 * (eps2 @ v2) - p2)) - err = err1 + err2 - - # Compute low rank matrices Q, R - Q = u1[:,None] * eps1 * v1[None,:] - R = u2[:,None] * eps2 * v2[None,:] - - dykstra_p = [g, q3_1, q3_2, v1_, v2_, q1, q2, u1, u2, v1, v2] - - return Q, R, dykstra_p - - - - -#################################### LOW RANK SINKHORN ALGORITHM ######################################### - - -def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, alpha="auto", - numItermax=1000, stopThr=1e-9, warn=True, verbose=False): #stopThr = 1e-9 - - r''' - Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints. - This function returns the two low-rank matrix decomposition of the OT plan (Q,R), as well as the weight vector g. - - Parameters - ---------- - X_s : array-like, shape (n_samples_a, dim) - samples in the source domain - X_t : array-like, shape (n_samples_b, dim) - samples in the target domain - a : array-like, shape (n_samples_a,) - samples weights in the source domain - b : array-like, shape (n_samples_b,) - samples weights in the target domain - reg : float, optional - Regularization term >0 - rank: int, optional - Nonnegative rank of the OT plan - alpha: int, optional - Lower bound for the weight vector g (>0 and <1/r) - numItermax : int, optional - Max number of iterations - stopThr : float, optional - Stop threshold on error (>0) - warn: - - verbose: - - - Returns - ------- - Q : array-like, shape (n_samples_a, r) - First low-rank matrix decomposition of the OT plan - R: array-like, shape (n_samples_b, r) - Second low-rank matrix decomposition of the OT plan - g : array-like, shape (r, ) - Weight vector for the low-rank decomposition of the OT plan - - - References - ---------- - .. Scetbon, M., Cuturi, M., & Peyré, G (2021). - Low-Rank Sinkhorn Factorization. arXiv preprint arXiv:2103.04737. - - ''' - # POT backend - nx = get_backend(X_s, X_t) - ns, nt = X_s.shape[0], X_t.shape[0] - if a is None: - a = unif(ns, type_as=X_s) - if b is None: - b = unif(nt, type_as=X_t) - - d = X_s.shape[1] - - # First low rank decomposition of the cost matrix (A) - M1 = nx.zeros((ns,(d+2))) - M1[:,0] = [nx.norm(X_s[i,:])**2 for i in range(ns)] - M1[:,1] = nx.ones(ns) - M1[:,2:] = -2*X_s - - # Second low rank decomposition of the cost matrix (B) - M2 = nx.zeros((nt,(d+2))) - M2[:,0] = nx.ones(nt) - M2[:,1] = [nx.norm(X_t[i,:])**2 for i in range(nt)] - M2[:,2:] = X_t - - # Compute rank - rank = min(ns, nt, rank) - r = rank - - # Alpha: lower bound for 1/rank - if alpha == 'auto': - alpha = 1e-3 # no convergence with alpha = 1 / (r+1) - - if (1/r < alpha) or (alpha < 0): - warnings.warn("The provided alpha value might lead to instabilities.") - - # Compute gamma - L = nx.sqrt(3*(2/(alpha**4))*((nx.norm(M1)*nx.norm(M2))**2) + (reg + (2/(alpha**3))*(nx.norm(M1)*nx.norm(M2)))**2) - gamma = 1/(2*L) - - # Initialize the low rank matrices Q, R, g - Q, R, g = nx.ones((ns,r)), nx.ones((nt,r)), nx.ones(r) - - # Initialize parameters for Dykstra algorithm - q3_1, q3_2 = nx.ones(r), nx.ones(r) - u1, u2 = nx.ones(ns), nx.ones(nt) - v1, v2 = nx.ones(r), nx.ones(r) - v1_, v2_ = nx.ones(r), nx.ones(r) - q1, q2 = nx.ones(r), nx.ones(r) - dykstra_p = [g, q3_1, q3_2, v1_, v2_, q1, q2, u1, u2, v1, v2] - - - for ii in range(numItermax): - CR_ = nx.dot(M2.T, R) - CR = nx.dot(M1, CR_) - - CQ_ = nx.dot(M1.T, Q) - CQ = nx.dot(M2, CQ_) - - diag_g = (1/g)[:,None] - - eps1 = nx.exp(-gamma*(nx.dot(CR,diag_g)) - ((gamma*reg)-1)*nx.log(Q)) - eps2 = nx.exp(-gamma*(nx.dot(CQ,diag_g)) - ((gamma*reg)-1)*nx.log(R)) - omega = nx.diag(nx.dot(Q.T, CR)) - eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) - - Q, R, dykstra_p = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_p, stopThr, nx) - g = dykstra_p[0] - - # if verbose: - # if ii % 200 == 0: - # print( - # '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) - # print('{:5d}|{:8e}|'.format(ii, err)) - - # else: - # if warn: - # warnings.warn("Sinkhorn did not converge. You might want to " - # "increase the number of iterations `numItermax` " - # "or the regularization parameter `reg`.") - - - # Compute OT value using trace formula for scalar product - v1 = nx.dot(Q.T,M1) - v2 = nx.dot(R,nx.dot(diag_g.T,v1)) - value_linear = nx.sum(nx.diag(nx.dot(v2,M2.T))) # compute Trace - - #value = value_linear + reg * nx.sum(plan * nx.log(plan + 1e-16)) - - #value - - return value_linear, Q, R, g - - - - - -############################################################################ -## Test with X_s, X_t from ot.datasets -############################################################################# - -import numpy as np -import ot - -Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) -Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) - -ns = Xs.shape[0] -nt = Xt.shape[0] - -a = unif(ns) -b = unif(nt) - -Q, R, g = lowrank_sinkhorn(Xs, Xt, reg=0.1, verbose=True, numItermax=20) -M = ot.dist(Xs,Xt) -P = np.dot(Q,np.dot(np.diag(1/g),R.T)) - -print(np.sum(P)) - - - diff --git a/test/test_lowrank.py b/test/test_lowrank.py deleted file mode 100644 index 7d90ce9ef..000000000 --- a/test/test_lowrank.py +++ /dev/null @@ -1,84 +0,0 @@ -##################################################################################################### -####################################### WORK IN PROGRESS ############################################ -##################################################################################################### - - -""" Test for low rank sinkhorn solvers """ - -import ot -import numpy as np -import pytest -from itertools import product - - -def test_LR_Dykstra(): - # test for LR_Dykstra algorithm ? catch nan values ? - pass - - -# @pytest.mark.parametrize("verbose, warn", product([True, False], [True, False])) -# def test_lowrank_sinkhorn(verbose, warn): -# # test low rank sinkhorn -# n = 100 -# a = ot.unif(n) -# b = ot.unif(n) - -# X_s = np.reshape(1.0 * np.arange(n), (n, 1)) -# X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) - -# Q_sqe, R_sqe, g_sqe = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1) -# P_sqe = np.dot(Q_sqe,np.dot(np.diag(1/g_sqe),R_sqe.T)) - -# Q_m, R_m, g_m = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1, metric='euclidean') -# P_m = np.dot(Q_m,np.dot(np.diag(1/g_m),R_m.T)) - -# # check constraints -# np.testing.assert_allclose( -# a, P_sqe.sum(1), atol=1e-05) # metric sqeuclidian -# np.testing.assert_allclose( -# b, P_sqe.sum(0), atol=1e-05) # metric sqeuclidian -# np.testing.assert_allclose( -# a, P_m.sum(1), atol=1e-05) # metric euclidian -# np.testing.assert_allclose( -# b, P_m.sum(0), atol=1e-05) # metric euclidian - -# with pytest.warns(UserWarning): -# ot.lowrank.lowrank_sinkhorn(X_s, X_t, 0.1, a, b, stopThr=0, numItermax=1) - - - -@pytest.mark.parametrize(("alpha, rank"),((0.8,2),(0.5,3),(0.2,4))) -def test_lowrank_sinkhorn_alpha_warning(alpha,rank): - # test warning for value of alpha - n = 100 - a = ot.unif(n) - b = ot.unif(n) - - X_s = np.reshape(1.0 * np.arange(n), (n, 1)) - X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) - - with pytest.warns(UserWarning): - ot.lowrank.lowrank_sinkhorn(X_s, X_t, 0.1, a, b, r=rank, alpha=alpha, warn=False) - - - -def test_lowrank_sinkhorn_backends(nx): - # test low rank sinkhorn for different backends - n = 100 - a = ot.unif(n) - b = ot.unif(n) - - X_s = np.reshape(1.0 * np.arange(n), (n, 1)) - X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) - - ab, bb, X_sb, X_tb = nx.from_numpy(a, b, X_s, X_t) - - Q, R, g = nx.to_numpy(ot.lowrank.lowrank_sinkhorn(X_sb, X_tb, ab, bb, 0.1)) - P = np.dot(Q,np.dot(np.diag(1/g),R.T)) - - np.testing.assert_allclose(a, P.sum(1), atol=1e-05) - np.testing.assert_allclose(b, P.sum(0), atol=1e-05) - - - - From 6ea251c89ecf52603eb81c798a0769e9a2cb9f54 Mon Sep 17 00:00:00 2001 From: laudavid Date: Tue, 24 Oct 2023 14:54:42 +0200 Subject: [PATCH 08/23] new file for lr sinkhorn --- ot/lowrank.py | 171 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 171 insertions(+) create mode 100644 ot/lowrank.py diff --git a/ot/lowrank.py b/ot/lowrank.py new file mode 100644 index 000000000..ba46cd1ed --- /dev/null +++ b/ot/lowrank.py @@ -0,0 +1,171 @@ +################################################################################################################# +############################################## WORK IN PROGRESS ################################################# +################################################################################################################# + + +from ot.utils import unif, list_to_array +from ot.backend import get_backend +from ot.datasets import make_1D_gauss as gauss + + + +################################## LR-DYSKTRA ALGORITHM ########################################## + +def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_w): + """ + Implementation of the Dykstra algorithm for low rank Sinkhorn + """ + + # get dykstra parameters + q3_1, q3_2, v1_, v2_, q1, q2 = dykstra_w + + # POT backend + eps1, eps2, eps3, p1, p2 = list_to_array(eps1, eps2, eps3, p1, p2) + q3_1, q3_2, v1_, v2_, q1, q2 = list_to_array(q3_1, q3_2, v1_, v2_, q1, q2) + + nx = get_backend(eps1, eps2, eps3, p1, p2, q3_1, q3_2, v1_, v2_, q1, q2) + + # ------- Dykstra algorithm ------ + g_ = eps3 + + u1 = p1 / nx.dot(eps1, v1_) + u2 = p2 / nx.dot(eps2, v2_) + + g = nx.maximum(alpha, g_ * q3_1) + q3_1 = (g_ * q3_1) / g + g_ = g + + prod1 = ((v1_ * q1) * nx.dot(eps1.T, u1)) + prod2 = ((v2_ * q2) * nx.dot(eps2.T, u2)) + g = (g_ * q3_2 * prod1 * prod2)**(1/3) + + v1 = g / nx.dot(eps1.T,u1) + v2 = g / nx.dot(eps2.T,u2) + + q1 = (v1_ * q1) / v1 + q2 = (v2_ * q2) / v2 + q3_2 = (g_ * q3_2) / g + + v1_, v2_ = v1, v2 + g_ = g + + # Compute error + err1 = nx.sum(nx.abs(u1 * (eps1 @ v1) - p1)) + err2 = nx.sum(nx.abs(u2 * (eps2 @ v2) - p2)) + err = err1 + err2 + + # Compute low rank matrices Q, R + Q = u1[:,None] * eps1 * v1[None,:] + R = u2[:,None] * eps2 * v2[None,:] + + dykstra_w = [q3_1, q3_2, v1_, v2_, q1, q2] + + return Q, R, g, err, dykstra_w + + + +#################################### LOW RANK SINKHORN ALGORITHM ######################################### + + +def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', alpha=1e-10, numIterMax=10000, stopThr=1e-20): + r''' + Solve the entropic regularization optimal transport problem under low-nonnegative low rank constraints + + Parameters + ---------- + X_s : array-like, shape (n_samples_a, dim) + samples in the source domain + X_t : array-like, shape (n_samples_b, dim) + samples in the target domain + reg : float + Regularization term >0 + a : array-like, shape (n_samples_a,) + samples weights in the source domain + b : array-like, shape (n_samples_b,) + samples weights in the target domain + numItermax : int, optional + Max number of iterations + stopThr : float, optional + Stop threshold on error (>0) + + Returns + ------- + Q : array-like, shape (n_samples_a, r) + First low-rank matrix decomposition of the OT plan + R: array-like, shape (n_samples_b, r) + Second low-rank matrix decomposition of the OT plan + g : array-like, shape (r, ) + ... + + ''' + + X_s, X_t = list_to_array(X_s, X_t) + nx = get_backend(X_s, X_t) + + ns, nt = X_s.shape[0], X_t.shape[0] + if a is None: + a = nx.from_numpy(unif(ns), type_as=X_s) + if b is None: + b = nx.from_numpy(unif(nt), type_as=X_s) + + M = ot.dist(X_s,X_t, metric=metric) + + # Compute rank + r = min(ns, nt, r) + + # Compute gamma + L = nx.sqrt((2/(alpha**4))*nx.norm(M)**2 + (reg + (2/(alpha**3))*nx.norm(M))**2) + gamma = 1/(2*L) + + # Initialisation + Q, R, g = nx.ones((ns,r)), nx.ones((nt,r)), nx.ones(r) + q3_1, q3_2 = nx.ones(r), nx.ones(r) + v1_, v2_ = nx.ones(r), nx.ones(r) + q1, q2 = nx.ones(r), nx.ones(r) + dykstra_w = [q3_1, q3_2, v1_, v2_, q1, q2] + n_iter = 0 + err = 1 + + while n_iter < numIterMax: + if err > stopThr: + n_iter = n_iter + 1 + + CR = nx.dot(M,R) + C_t_Q = nx.dot(M.T,Q) + diag_g = (1/g)[:,None] + + eps1 = nx.exp(-gamma*(nx.dot(CR,diag_g)) - ((gamma*reg)-1)*nx.log(Q)) + eps2 = nx.exp(-gamma*(nx.dot(C_t_Q,diag_g)) - ((gamma*reg)-1)*nx.log(R)) + omega = nx.diag(nx.dot(Q.T, CR)) + eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) + + Q, R, g, err, dykstra_w = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_w) + else: + break + + return Q, R, g + + + + + +############################################################################ +## Test with X_s, X_t from ot.datasets +############################################################################# + +import numpy as np +import ot + +Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) +Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) + + +Q, R, g = lowrank_sinkhorn(Xs,Xt,reg=0.1) +M = ot.dist(Xs,Xt) +P = np.dot(Q,np.dot(np.diag(1/g),R.T)) + +print(np.sum(P)) + + + + From 965e4d69113f6fe8eab106412b652dabdbc05712 Mon Sep 17 00:00:00 2001 From: laudavid Date: Tue, 24 Oct 2023 16:47:21 +0200 Subject: [PATCH 09/23] lr sinkhorn, solve_sample, OTResultLazy --- ot/lowrank.py | 40 +++++++------ ot/solvers.py | 161 ++++++++++++++++++++++++++++++++++++++++++++++++++ ot/utils.py | 3 +- 3 files changed, 183 insertions(+), 21 deletions(-) diff --git a/ot/lowrank.py b/ot/lowrank.py index ba46cd1ed..a1c73bdf3 100644 --- a/ot/lowrank.py +++ b/ot/lowrank.py @@ -2,8 +2,10 @@ ############################################## WORK IN PROGRESS ################################################# ################################################################################################################# +## Implementation of the LR-Dykstra algorithm and low rank sinkhorn algorithms -from ot.utils import unif, list_to_array + +from ot.utils import unif, list_to_array, dist from ot.backend import get_backend from ot.datasets import make_1D_gauss as gauss @@ -11,13 +13,13 @@ ################################## LR-DYSKTRA ALGORITHM ########################################## -def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_w): +def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p): """ Implementation of the Dykstra algorithm for low rank Sinkhorn """ # get dykstra parameters - q3_1, q3_2, v1_, v2_, q1, q2 = dykstra_w + q3_1, q3_2, v1_, v2_, q1, q2 = dykstra_p # POT backend eps1, eps2, eps3, p1, p2 = list_to_array(eps1, eps2, eps3, p1, p2) @@ -58,18 +60,18 @@ def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_w): Q = u1[:,None] * eps1 * v1[None,:] R = u2[:,None] * eps2 * v2[None,:] - dykstra_w = [q3_1, q3_2, v1_, v2_, q1, q2] + dykstra_p = [q3_1, q3_2, v1_, v2_, q1, q2] - return Q, R, g, err, dykstra_w + return Q, R, g, err, dykstra_p #################################### LOW RANK SINKHORN ALGORITHM ######################################### -def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', alpha=1e-10, numIterMax=10000, stopThr=1e-20): +def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', alpha=1e-10, numIterMax=10000, stopThr=1e-20): r''' - Solve the entropic regularization optimal transport problem under low-nonnegative low rank constraints + Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints on the feasible couplings. Parameters ---------- @@ -95,7 +97,7 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', R: array-like, shape (n_samples_b, r) Second low-rank matrix decomposition of the OT plan g : array-like, shape (r, ) - ... + Third low-rank matrix decomposition of the OT plan ''' @@ -108,7 +110,7 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', if b is None: b = nx.from_numpy(unif(nt), type_as=X_s) - M = ot.dist(X_s,X_t, metric=metric) + M = dist(X_s,X_t, metric=metric) # Compute rank r = min(ns, nt, r) @@ -122,7 +124,7 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', q3_1, q3_2 = nx.ones(r), nx.ones(r) v1_, v2_ = nx.ones(r), nx.ones(r) q1, q2 = nx.ones(r), nx.ones(r) - dykstra_w = [q3_1, q3_2, v1_, v2_, q1, q2] + dykstra_p = [q3_1, q3_2, v1_, v2_, q1, q2] n_iter = 0 err = 1 @@ -139,7 +141,7 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', omega = nx.diag(nx.dot(Q.T, CR)) eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) - Q, R, g, err, dykstra_w = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_w) + Q, R, g, err, dykstra_p = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_p) else: break @@ -153,18 +155,18 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=4, metric='sqeuclidean', ## Test with X_s, X_t from ot.datasets ############################################################################# -import numpy as np -import ot +# import numpy as np +# import ot -Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) -Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) +# Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) +# Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) -Q, R, g = lowrank_sinkhorn(Xs,Xt,reg=0.1) -M = ot.dist(Xs,Xt) -P = np.dot(Q,np.dot(np.diag(1/g),R.T)) +# Q, R, g = lowrank_sinkhorn(Xs,Xt,reg=0.1) +# M = ot.dist(Xs,Xt) +# P = np.dot(Q,np.dot(np.diag(1/g),R.T)) -print(np.sum(P)) +# print(np.sum(P)) diff --git a/ot/solvers.py b/ot/solvers.py index 0313cf588..9c2746c25 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -848,3 +848,164 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, value_linear=value_linear, value_quad=value_quad, plan=plan, status=status, backend=nx) return res + + + + + + +################################## WORK IN PROGRESS ##################################### + +## Implementation of the ot.solve_sample function +## Function isn't complete, still work in progress for reg == 0 / reg is None case (and unbalanced cases) + + +from .utils import unif, list_to_array, dist, OTResultLazy +from .bregman import empirical_sinkhorn + + +def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL", unbalanced=None, + unbalanced_type='KL', is_Lazy=False, batch_size=None, n_threads=1, max_iter=None, plan_init=None, + potentials_init=None, tol=None, verbose=False): + + r"""Solve the discrete optimal transport problem using the samples in the source and target domains. + It returns either a :any:`OTResult` or :any:`OTResultLazy` object. + + The function solves the following general optimal transport problem + + .. math:: + \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + + The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By + default ``reg=None`` and there is no regularization. The unbalanced marginal + penalization can be selected with `unbalanced` (:math:`\lambda_u`) and + `unbalanced_type`. By default ``unbalanced=None`` and the function + solves the exact optimal transport problem (respecting the marginals). + + Parameters + ---------- + X_s : array-like, shape (n_samples_a, dim) + samples in the source domain + X_t : array-like, shape (n_samples_b, dim) + samples in the target domain + a : array-like, shape (dim_a,), optional + Samples weights in the source domain (default is uniform) + b : array-like, shape (dim_b,), optional + Samples weights in the source domain (default is uniform) + reg : float, optional + Regularization weight :math:`\lambda_r`, by default None (no reg., exact + OT) + reg_type : str, optional + Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL" + unbalanced : float, optional + Unbalanced penalization weight :math:`\lambda_u`, by default None + (balanced OT) + unbalanced_type : str, optional + Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL" + is_Lazy : bool, optional + Return :any:`OTResultlazy` object to reduce memory cost when True, by default False + n_threads : int, optional + Number of OMP threads for exact OT solver, by default 1 + max_iter : int, optional + Maximum number of iteration, by default None (default values in each solvers) + plan_init : array_like, shape (dim_a, dim_b), optional + Initialization of the OT plan for iterative methods, by default None + potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional + Initialization of the OT dual potentials for iterative methods, by default None + tol : _type_, optional + Tolerance for solution precision, by default None (default values in each solvers) + verbose : bool, optional + Print information in the solver, by default False + + Returns + ------- + + res_lazy : OTResultLazy() + Result of the optimization problem. This class only returns a partial OT plan and the OT dual potentials to reduce memory costs. + The information can be obtained as follows: + + - res.lazy_plan : OT plan computed on a subsample of X_s and X_t :math:`\mathbf{T}` + - res.potentials : OT dual potentials + + See :any:`OTResultLazy` for more information. + + res : OTResult() + Result of the optimization problem. The information can be obtained as follows: + + - res.plan : OT plan :math:`\mathbf{T}` + - res.potentials : OT dual potentials + - res.value : Optimal value of the optimization problem + - res.value_linear : Linear OT loss with the optimal OT plan + + See :any:`OTResult` for more information. + + + """ + + X_s, X_t = list_to_array(X_s,X_t) + + # detect backend + arr = [X_s,X_t] + if a is not None: + arr.append(a) + if b is not None: + arr.append(b) + nx = get_backend(*arr) + + # create uniform weights if not given + ns, nt = X_s.shape[0], X_t.shape[0] + if a is None: + a = nx.from_numpy(unif(ns), type_as=X_s) + if b is None: + b = nx.from_numpy(unif(nt), type_as=X_s) + + # default values for solutions + potentials = None + lazy_plan = None + + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + if batch_size is None: + batch_size = 100 + + if is_Lazy: + ################# WIP #################### + if reg is None or reg == 0: # EMD solver for isLazy ? + if unbalanced is None: # not sure "unbalanced" parameter is needed here ? (since we won't compute value) + pass + elif unbalanced_type.lower() in ['kl', 'l2']: + pass + elif unbalanced_type.lower() == 'tv': + pass + pass + ############################################# + + else: + # compute potentials + u, v, log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric='sqeuclidean', numIterMax=max_iter, stopThr=tol, + isLazy=True, batchSize=batch_size, verbose=verbose, log=True) + potentials = (log["u"], log["v"]) + + # compute lazy_plan + ns_lazy, nt_lazy = 100, 100 # size of the lazy_plan (subplan) + M = dist(X_s[:ns_lazy,:], X_t[:nt_lazy,:], metric) + K = nx.exp(M / (-reg)) + lazy_plan = u[:ns_lazy].reshape((-1, 1)) * K * v[:nt_lazy].reshape((1, -1)) + + res_lazy = OTResultLazy(potentials=potentials, lazy_plan=lazy_plan, backend=nx) + return res_lazy + + else: + # compute cost matrix M and use solve function + M = dist(X_s, X_t, metric) + + res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, n_threads, max_iter, plan_init, potentials_init, tol, verbose) + return res + + + + diff --git a/ot/utils.py b/ot/utils.py index 0936648ca..2f4cfc9e7 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -1168,7 +1168,6 @@ def citation(self): } """ - class LazyTensor(object): """ A lazy tensor is a tensor that is not stored in memory. Instead, it is defined by a function that computes its values on the fly from slices. @@ -1233,4 +1232,4 @@ def __getitem__(self, key): return self._getitem(*k, **self.kwargs) def __repr__(self): - return "LazyTensor(shape={},attributes=({}))".format(self.shape, ','.join(self.kwargs.keys())) + return "LazyTensor(shape={},attributes=({}))".format(self.shape, ','.join(self.kwargs.keys())) \ No newline at end of file From fd5e26d86e484f310f55a792bca89de13bd7340f Mon Sep 17 00:00:00 2001 From: laudavid Date: Wed, 25 Oct 2023 17:39:08 +0200 Subject: [PATCH 10/23] add test functions + small modif lr_sin/solve_sample --- ot/lowrank.py | 97 ++++++++++++++++++++++++++++------------- ot/solvers.py | 47 +++++++++++--------- test/test_lowrank.py | 84 ++++++++++++++++++++++++++++++++++++ test/test_solvers.py | 100 +++++++++++++++++++++++++++++++++++++++++++ test/test_utils.py | 1 + 5 files changed, 278 insertions(+), 51 deletions(-) create mode 100644 test/test_lowrank.py diff --git a/ot/lowrank.py b/ot/lowrank.py index a1c73bdf3..22ff8b754 100644 --- a/ot/lowrank.py +++ b/ot/lowrank.py @@ -4,10 +4,9 @@ ## Implementation of the LR-Dykstra algorithm and low rank sinkhorn algorithms - -from ot.utils import unif, list_to_array, dist -from ot.backend import get_backend -from ot.datasets import make_1D_gauss as gauss +import warnings +from .utils import unif, list_to_array, dist +from .backend import get_backend @@ -15,7 +14,7 @@ def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p): """ - Implementation of the Dykstra algorithm for low rank Sinkhorn + Implementation of the Dykstra algorithm for low rank sinkhorn """ # get dykstra parameters @@ -69,9 +68,12 @@ def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, dykstra_p): #################################### LOW RANK SINKHORN ALGORITHM ######################################### -def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', alpha=1e-10, numIterMax=10000, stopThr=1e-20): +def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank=2, metric='sqeuclidean', alpha="auto", + numItermax=10000, stopThr=1e-9, warn=True, verbose=False): r''' - Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints on the feasible couplings. + Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints. + + This function returns the two low-rank matrix decomposition of the OT plan (Q,R), as well as the weight vector g. Parameters ---------- @@ -79,17 +81,22 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', samples in the source domain X_t : array-like, shape (n_samples_b, dim) samples in the target domain - reg : float - Regularization term >0 a : array-like, shape (n_samples_a,) samples weights in the source domain b : array-like, shape (n_samples_b,) samples weights in the target domain + reg : float, optional + Regularization term >0 + rank: int, optional + Nonnegative rank of the OT plan + alpha: int, optional + Lower bound for the weight vector g (>0 and <1/r) numItermax : int, optional Max number of iterations stopThr : float, optional Stop threshold on error (>0) + Returns ------- Q : array-like, shape (n_samples_a, r) @@ -97,7 +104,14 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', R: array-like, shape (n_samples_b, r) Second low-rank matrix decomposition of the OT plan g : array-like, shape (r, ) - Third low-rank matrix decomposition of the OT plan + Weight vector for the low-rank decomposition of the OT plan + + + References + ---------- + + .. Scetbon, M., Cuturi, M., & Peyré, G (2021). + Low-Rank Sinkhorn Factorization. arXiv preprint arXiv:2103.04737. ''' @@ -110,13 +124,22 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', if b is None: b = nx.from_numpy(unif(nt), type_as=X_s) + # Compute cost matrix M = dist(X_s,X_t, metric=metric) - + # Compute rank - r = min(ns, nt, r) + rank = min(ns, nt, rank) + r = rank + + if alpha == 'auto': + alpha = 1.0 / (r + 1) + + if (1/r < alpha) or (alpha < 0): + warnings.warn("The provided alpha value might lead to instabilities.") + # Compute gamma - L = nx.sqrt((2/(alpha**4))*nx.norm(M)**2 + (reg + (2/(alpha**3))*nx.norm(M))**2) + L = nx.sqrt((2/(alpha**4))*(nx.norm(M)**2) + (reg + (2/(alpha**3))*(nx.norm(M))**2)) gamma = 1/(2*L) # Initialisation @@ -125,25 +148,34 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', v1_, v2_ = nx.ones(r), nx.ones(r) q1, q2 = nx.ones(r), nx.ones(r) dykstra_p = [q3_1, q3_2, v1_, v2_, q1, q2] - n_iter = 0 err = 1 - while n_iter < numIterMax: - if err > stopThr: - n_iter = n_iter + 1 - - CR = nx.dot(M,R) - C_t_Q = nx.dot(M.T,Q) - diag_g = (1/g)[:,None] - - eps1 = nx.exp(-gamma*(nx.dot(CR,diag_g)) - ((gamma*reg)-1)*nx.log(Q)) - eps2 = nx.exp(-gamma*(nx.dot(C_t_Q,diag_g)) - ((gamma*reg)-1)*nx.log(R)) - omega = nx.diag(nx.dot(Q.T, CR)) - eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) - - Q, R, g, err, dykstra_p = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_p) - else: + for ii in range(numItermax): + CR = nx.dot(M,R) + C_t_Q = nx.dot(M.T,Q) + diag_g = (1/g)[:,None] + + eps1 = nx.exp(-gamma*(nx.dot(CR,diag_g)) - ((gamma*reg)-1)*nx.log(Q)) + eps2 = nx.exp(-gamma*(nx.dot(C_t_Q,diag_g)) - ((gamma*reg)-1)*nx.log(R)) + omega = nx.diag(nx.dot(Q.T, CR)) + eps3 = nx.exp(gamma*omega/(g**2) - (gamma*reg - 1)*nx.log(g)) + + Q, R, g, err, dykstra_p = LR_Dysktra(eps1, eps2, eps3, a, b, alpha, dykstra_p) + + if err < stopThr: break + + if verbose: + if ii % 200 == 0: + print( + '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19) + print('{:5d}|{:8e}|'.format(ii, err)) + + else: + if warn: + warnings.warn("Sinkhorn did not converge. You might want to " + "increase the number of iterations `numItermax` " + "or the regularization parameter `reg`.") return Q, R, g @@ -161,8 +193,13 @@ def lowrank_sinkhorn(X_s, X_t, reg=0, a=None, b=None, r=2, metric='sqeuclidean', # Xs, _ = ot.datasets.make_data_classif('3gauss', n=1000) # Xt, _ = ot.datasets.make_data_classif('3gauss2', n=1500) +# ns = Xs.shape[0] +# nt = Xt.shape[0] + +# a = unif(ns) +# b = unif(nt) -# Q, R, g = lowrank_sinkhorn(Xs,Xt,reg=0.1) +# Q, R, g = lowrank_sinkhorn(Xs, Xt, reg=0.1, metric='euclidean', verbose=True, numItermax=100) # M = ot.dist(Xs,Xt) # P = np.dot(Q,np.dot(np.diag(1/g),R.T)) diff --git a/ot/solvers.py b/ot/solvers.py index 9c2746c25..c176969ca 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -926,7 +926,7 @@ def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_t Result of the optimization problem. This class only returns a partial OT plan and the OT dual potentials to reduce memory costs. The information can be obtained as follows: - - res.lazy_plan : OT plan computed on a subsample of X_s and X_t :math:`\mathbf{T}` + - res.lazy_plan : OT plan computed on a subsample of X_s and X_t - res.potentials : OT dual potentials See :any:`OTResultLazy` for more information. @@ -975,29 +975,34 @@ def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_t if is_Lazy: ################# WIP #################### if reg is None or reg == 0: # EMD solver for isLazy ? - if unbalanced is None: # not sure "unbalanced" parameter is needed here ? (since we won't compute value) - pass - elif unbalanced_type.lower() in ['kl', 'l2']: - pass - elif unbalanced_type.lower() == 'tv': - pass - pass + + if unbalanced is None: # balanced EMD solver for isLazy ? + raise (NotImplementedError('Not implemented balanced with no regularization')) + + else: + raise (NotImplementedError('Not implemented unbalanced_type="{}" with no regularization'.format(unbalanced_type))) + + ############################################# else: - # compute potentials - u, v, log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric='sqeuclidean', numIterMax=max_iter, stopThr=tol, - isLazy=True, batchSize=batch_size, verbose=verbose, log=True) - potentials = (log["u"], log["v"]) - - # compute lazy_plan - ns_lazy, nt_lazy = 100, 100 # size of the lazy_plan (subplan) - M = dist(X_s[:ns_lazy,:], X_t[:nt_lazy,:], metric) - K = nx.exp(M / (-reg)) - lazy_plan = u[:ns_lazy].reshape((-1, 1)) * K * v[:nt_lazy].reshape((1, -1)) - - res_lazy = OTResultLazy(potentials=potentials, lazy_plan=lazy_plan, backend=nx) - return res_lazy + if unbalanced is None: + u, v, log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric='sqeuclidean', numIterMax=max_iter, stopThr=tol, + isLazy=True, batchSize=batch_size, verbose=verbose, log=True) + # compute potentials + potentials = (log["u"], log["v"]) + + # compute lazy_plan + ns_lazy, nt_lazy = 100, 100 # size of the lazy_plan (subplan) + M = dist(X_s[:ns_lazy,:], X_t[:nt_lazy,:], metric) + K = nx.exp(M / (-reg)) + lazy_plan = u[:ns_lazy].reshape((-1, 1)) * K * v[:nt_lazy].reshape((1, -1)) + + res_lazy = OTResultLazy(potentials=potentials, lazy_plan=lazy_plan, backend=nx) + return res_lazy + + else: + raise (NotImplementedError('Not implemented unbalanced_type="{}" with regularization'.format(unbalanced_type))) else: # compute cost matrix M and use solve function diff --git a/test/test_lowrank.py b/test/test_lowrank.py new file mode 100644 index 000000000..6e1f24067 --- /dev/null +++ b/test/test_lowrank.py @@ -0,0 +1,84 @@ +##################################################################################################### +####################################### WORK IN PROGRESS ############################################ +##################################################################################################### + + +""" Test for low rank sinkhorn solvers """ + +import ot +import numpy as np +import pytest +from itertools import product + + +def test_LR_Dykstra(): + # test for LR_Dykstra algorithm ? catch nan values ? + pass + + +@pytest.mark.parametrize("verbose, warn", product([True, False], [True, False])) +def test_lowrank_sinkhorn(verbose, warn): + # test low rank sinkhorn + n = 100 + a = ot.unif(n) + b = ot.unif(n) + + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + Q_sqe, R_sqe, g_sqe = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1) + P_sqe = np.dot(Q_sqe,np.dot(np.diag(1/g_sqe),R_sqe.T)) + + Q_m, R_m, g_m = ot.lowrank.lowrank_sinkhorn(X_s, X_t, a, b, 0.1, metric='euclidean') + P_m = np.dot(Q_m,np.dot(np.diag(1/g_m),R_m.T)) + + # check constraints + np.testing.assert_allclose( + a, P_sqe.sum(1), atol=1e-05) # metric sqeuclidian + np.testing.assert_allclose( + b, P_sqe.sum(0), atol=1e-05) # metric sqeuclidian + np.testing.assert_allclose( + a, P_m.sum(1), atol=1e-05) # metric euclidian + np.testing.assert_allclose( + b, P_m.sum(0), atol=1e-05) # metric euclidian + + with pytest.warns(UserWarning): + ot.lowrank.lowrank_sinkhorn(X_s, X_t, 0.1, a, b, stopThr=0, numItermax=1) + + + +@pytest.mark.parametrize(("alpha, rank"),((0.8,2),(0.5,3),(0.2,4))) +def test_lowrank_sinkhorn_alpha_warning(alpha,rank): + # test warning for value of alpha + n = 100 + a = ot.unif(n) + b = ot.unif(n) + + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + with pytest.warns(UserWarning): + ot.lowrank.lowrank_sinkhorn(X_s, X_t, 0.1, a, b, r=rank, alpha=alpha, warn=False) + + + +def test_lowrank_sinkhorn_backends(nx): + # test low rank sinkhorn for different backends + n = 100 + a = ot.unif(n) + b = ot.unif(n) + + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + ab, bb, X_sb, X_tb = nx.from_numpy(a, b, X_s, X_t) + + Q, R, g = nx.to_numpy(ot.lowrank.lowrank_sinkhorn(X_sb, X_tb, ab, bb, 0.1)) + P = np.dot(Q,np.dot(np.diag(1/g),R.T)) + + np.testing.assert_allclose(a, P.sum(1), atol=1e-05) + np.testing.assert_allclose(b, P.sum(0), atol=1e-05) + + + + diff --git a/test/test_solvers.py b/test/test_solvers.py index f0f5b638f..5a05d54cf 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -255,3 +255,103 @@ def test_solve_gromov_not_implemented(nx): ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type='partial', unbalanced=1.5) with pytest.raises(NotImplementedError): ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type='partial', unbalanced=0.5, symmetric=False) + + + + +########################################################################################################### +############################################ WORK IN PROGRESS ############################################# +########################################################################################################### + +def assert_allclose_sol_sample(sol1, sol2): + # test attributes of OTResultLazy class + lst_attr = ['potentials','potential_a', 'potential_b', 'lazy_plan'] + + nx1 = sol1._backend if sol1._backend is not None else ot.backend.NumpyBackend() + nx2 = sol2._backend if sol2._backend is not None else ot.backend.NumpyBackend() + + for attr in lst_attr: + try: + np.allclose(nx1.to_numpy(getattr(sol1, attr)), nx2.to_numpy(getattr(sol2, attr))) + except NotImplementedError: + pass + + +@pytest.mark.parametrize("reg,reg_type,unbalanced,unbalanced_type", itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type)) +def test_solve_sample(nx): + # test solve_sample when is_Lazy = False + n = 100 + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + a = ot.utils.unif(X_s.shape[0]) + b = ot.utils.unif(X_t.shape[0]) + + # solve unif weights + sol0 = ot.solve_sample(X_s, X_t) + + # solve signe weights + sol = ot.solve_sample(X_s, X_t, a, b) + + # check some attributes + sol.potentials + sol.sparse_plan + sol.marginals + sol.status + + assert_allclose_sol(sol0, sol) + + # solve in backend + X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) + solb = ot.solve_sample(X_sb, X_tb, ab, bb) + + assert_allclose_sol(sol, solb) + + # test not implemented unbalanced and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, unbalanced=1, unbalanced_type='cryptic divergence') + + # test not implemented reg_type and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, reg=1, reg_type='cryptic divergence') + + + +def test_lazy_solve_sample(nx): + # test solve_sample when is_Lazy = True + n = 100 + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + a = ot.utils.unif(X_s.shape[0]) + b = ot.utils.unif(X_t.shape[0]) + + # solve unif weights + sol0 = ot.solve_sample(X_s, X_t, reg=0.1, is_Lazy=True) # reg != 0 or None since no implementation yet for is_Lazy=True + + # solve signe weights + sol = ot.solve_sample(X_s, X_t, a, b, reg=0.1, is_Lazy=True) + + # check some attributes + sol.potentials + sol.lazy_plan + + assert_allclose_sol_sample(sol0, sol) + + # solve in backend + X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) + solb = ot.solve_sample(X_sb, X_tb, ab, bb, reg=0.1, is_Lazy=True) + + assert_allclose_sol_sample(sol, solb) + + # test not implemented reg==0 (or None) + balanced and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, is_Lazy=True) # reg == 0 (or None) + unbalanced= None are default + + # test not implemented reg==0 (or None) + unbalanced_type and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, unbalanced_type="kl", is_Lazy=True) # reg == 0 (or None) is default + + # test not implemented reg != 0 + unbalanced_type and check raise + with pytest.raises(NotImplementedError): + sol0 = ot.solve_sample(X_s, X_t, reg=0.1, unbalanced_type="kl", is_Lazy=True) \ No newline at end of file diff --git a/test/test_utils.py b/test/test_utils.py index 3a9d590ab..942f403ce 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -569,3 +569,4 @@ def test_lowrank_LazyTensor(nx): T = ot.utils.get_lowrank_lazytensor(X1, X2, diag_d, nx=nx) np.testing.assert_allclose(nx.to_numpy(T[:]), nx.to_numpy(T0)) + \ No newline at end of file From 3df3b77de2605d233224e0ccefa6ee127af9f040 Mon Sep 17 00:00:00 2001 From: laudavid Date: Thu, 26 Oct 2023 10:49:23 +0200 Subject: [PATCH 11/23] add import to __init__ --- ot/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ot/__init__.py b/ot/__init__.py index f16b6fcfc..cb00f4553 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -35,6 +35,7 @@ from . import factored from . import solvers from . import gaussian +from . import lowrank # OT functions from .lp import (emd, emd2, emd_1d, emd2_1d, wasserstein_1d, @@ -50,7 +51,8 @@ gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2) from .weak import weak_optimal_transport from .factored import factored_optimal_transport -from .solvers import solve, solve_gromov +from .solvers import solve, solve_gromov, solve_sample +from .lowrank import lowrank_sinkhorn # utils functions from .utils import dist, unif, tic, toc, toq From 9403851ec57d66e0179f7f178ce6397410e4b8c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 7 Nov 2023 09:53:03 +0100 Subject: [PATCH 12/23] clean ot.solve_sample and remve lazy test cause not ilplemented yet --- ot/solvers.py | 120 +++++++++++++++++++++---------------------- ot/utils.py | 3 +- test/test_solvers.py | 73 +++++++++++++------------- test/test_utils.py | 1 - 4 files changed, 95 insertions(+), 102 deletions(-) diff --git a/ot/solvers.py b/ot/solvers.py index 4b261b5cc..239ba91e0 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -22,11 +22,6 @@ from .partial import partial_gromov_wasserstein2, entropic_partial_gromov_wasserstein2 - - -#, entropic_gromov_wasserstein2, entropic_fused_gromov_wasserstein2 - - def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, unbalanced_type='KL', n_threads=1, max_iter=None, plan_init=None, potentials_init=None, tol=None, verbose=False): @@ -853,14 +848,9 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, return res - - -##### new ot.solve_sample function - -def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL", unbalanced=None, - unbalanced_type='KL', is_Lazy=False, batch_size=None, n_threads=1, max_iter=None, plan_init=None, - potentials_init=None, tol=None, verbose=False): - +def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL", unbalanced=None, + unbalanced_type='KL', lazy=False, batch_size=None, method=None, n_threads=1, max_iter=None, plan_init=None, + potentials_init=None, tol=None, verbose=False): r"""Solve the discrete optimal transport problem using the samples in the source and target domains. The function solves the following general optimal transport problem @@ -870,6 +860,10 @@ def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_t \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + where the cost matrix :math:`\mathbf{M}` is computed from the samples in the + source and target domains wuch that :math:`M_{i,j} = d(x_i,y_j)` where + :math:`d` is a metric (by default the squared Euclidean distance). + The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By default ``reg=None`` and there is no regularization. The unbalanced marginal penalization can be selected with `unbalanced` (:math:`\lambda_u`) and @@ -881,7 +875,7 @@ def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_t X_s : array-like, shape (n_samples_a, dim) samples in the source domain X_t : array-like, shape (n_samples_b, dim) - samples in the target domain + samples in the target domain a : array-like, shape (dim_a,), optional Samples weights in the source domain (default is uniform) b : array-like, shape (dim_b,), optional @@ -896,8 +890,16 @@ def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_t (balanced OT) unbalanced_type : str, optional Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL" - is_Lazy : bool, optional - Return :any:`OTResultlazy` object to reduce memory cost when True, by default False + lazy : bool, optional + Return :any:`OTResultlazy` object to reduce memory cost when True, by + default False + batch_size : int, optional + Batch size for lazy solver, by default None (default values in each + solvers) + method : str, optional + Method for solving the problem, this can be used to select the solver + for unalanced problems (see :any:`ot.solve`), or to select a specific + lazy large scale solver. n_threads : int, optional Number of OMP threads for exact OT solver, by default 1 max_iter : int, optional @@ -925,68 +927,62 @@ def solve_sample(X_s, X_t, a=None, b=None, metric='sqeuclidean', reg=None, reg_t See :any:`OTResult` for more information. """ - - # Detect backend - arr = [X_s,X_t] - if a is not None: - arr.append(a) - if b is not None: - arr.append(b) - nx = get_backend(*arr) - # create uniform weights if not given - ns, nt = X_s.shape[0], X_t.shape[0] - if a is None: - a = nx.from_numpy(unif(ns), type_as=X_s) - if b is None: - b = nx.from_numpy(unif(nt), type_as=X_s) + if method is not None and method.lower() in ['1d', 'sliced', 'lowrank', 'factored']: + lazy = True - if metric != 'sqeuclidean': - raise (NotImplementedError('Not implemented metric = {} (only sqeulidean)'.format(metric))) + if not lazy: # default non lazy solver calls ot.solve + # compute cost matrix M and use solve function + M = dist(X_a, X_b, metric) - # default values for solutions - potentials = None - lazy_plan = None + res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, n_threads, max_iter, plan_init, potentials_init, tol, verbose) + + return res + + else: - if max_iter is None: - max_iter = 1000 - if tol is None: - tol = 1e-9 - if batch_size is None: - batch_size = 100 + # Detect backend + nx = get_backend(X_a, X_b, a, b) + + # default values for solutions + potentials = None + value = None + value_linear = None + plan = None + lazy_plan = None + status = None - if is_Lazy: ################# WIP #################### - if reg is None or reg == 0: # EMD solver for isLazy ? - - if unbalanced is None: # balanced EMD solver for isLazy ? - raise (NotImplementedError('Not implemented balanced with no regularization')) - + if reg is None or reg == 0: # EMD solver for isLazy ? + + if unbalanced is None: # balanced EMD solver for isLazy ? + raise (NotImplementedError('Exact OT solver with lazy=True not implemented')) + else: - raise (NotImplementedError('Not implemented unbalanced_type="{}" with no regularization'.format(unbalanced_type))) - + raise (NotImplementedError('Non regularized solver with unbalanced_type="{}" not implemented'.format(unbalanced_type))) ############################################# - - else: + + else: if unbalanced is None: - u, v, log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric='sqeuclidean', numIterMax=max_iter, stopThr=tol, - isLazy=True, batchSize=batch_size, verbose=verbose, log=True) + + if max_iter is None: + max_iter = 1000 + if tol is None: + tol = 1e-9 + if batch_size is None: + batch_size = 100 + + u, v, log = empirical_sinkhorn(X_a, X_b, reg, a, b, metric=metric, numIterMax=max_iter, stopThr=tol, + isLazy=True, batchSize=batch_size, verbose=verbose, log=True) # compute potentials - potentials = (log["u"], log["v"]) + potentials = (u, v) # compute lazy_plan # ... raise (NotImplementedError('Not implemented balanced with regularization')) - + else: raise (NotImplementedError('Not implemented unbalanced_type="{}" with regularization'.format(unbalanced_type))) - - else: - # compute cost matrix M and use solve function - M = dist(X_s, X_t, metric) - - res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, n_threads, max_iter, plan_init, potentials_init, tol, verbose) - return res \ No newline at end of file diff --git a/ot/utils.py b/ot/utils.py index 2f4cfc9e7..0936648ca 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -1168,6 +1168,7 @@ def citation(self): } """ + class LazyTensor(object): """ A lazy tensor is a tensor that is not stored in memory. Instead, it is defined by a function that computes its values on the fly from slices. @@ -1232,4 +1233,4 @@ def __getitem__(self, key): return self._getitem(*k, **self.kwargs) def __repr__(self): - return "LazyTensor(shape={},attributes=({}))".format(self.shape, ','.join(self.kwargs.keys())) \ No newline at end of file + return "LazyTensor(shape={},attributes=({}))".format(self.shape, ','.join(self.kwargs.keys())) diff --git a/test/test_solvers.py b/test/test_solvers.py index 18572b90a..0a9ad6a93 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -257,12 +257,9 @@ def test_solve_gromov_not_implemented(nx): ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type='partial', unbalanced=0.5, symmetric=False) - - ######## Test functions for ot.solve_sample ######## -@pytest.mark.parametrize("reg,reg_type,unbalanced,unbalanced_type", itertools.product(lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type)) def test_solve_sample(nx): # test solve_sample when is_Lazy = False n = 100 @@ -272,8 +269,13 @@ def test_solve_sample(nx): a = ot.utils.unif(X_s.shape[0]) b = ot.utils.unif(X_t.shape[0]) + M = ot.dist(X_s, X_t) + + # solve with ot.solve + sol00 = ot.solve(M, a, b) + # solve unif weights - sol0 = ot.solve_sample(X_s, X_t) + sol0 = ot.solve_sample(X_s, X_t) # solve signe weights sol = ot.solve_sample(X_s, X_t, a, b) @@ -285,6 +287,7 @@ def test_solve_sample(nx): sol.status assert_allclose_sol(sol0, sol) + assert_allclose_sol(sol0, sol00) # solve in backend X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) @@ -301,47 +304,41 @@ def test_solve_sample(nx): sol0 = ot.solve_sample(X_s, X_t, reg=1, reg_type='cryptic divergence') +# def test_lazy_solve_sample(nx): +# # test solve_sample when is_Lazy = True +# n = 100 +# X_s = np.reshape(1.0 * np.arange(n), (n, 1)) +# X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) -def test_lazy_solve_sample(nx): - # test solve_sample when is_Lazy = True - n = 100 - X_s = np.reshape(1.0 * np.arange(n), (n, 1)) - X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) - - a = ot.utils.unif(X_s.shape[0]) - b = ot.utils.unif(X_t.shape[0]) - - # solve unif weights - sol0 = ot.solve_sample(X_s, X_t, reg=0.1, is_Lazy=True) # reg != 0 or None since no implementation yet for is_Lazy=True - - # solve signe weights - sol = ot.solve_sample(X_s, X_t, a, b, reg=0.1, is_Lazy=True) - - # check some attributes - sol.potentials - sol.lazy_plan +# a = ot.utils.unif(X_s.shape[0]) +# b = ot.utils.unif(X_t.shape[0]) - assert_allclose_sol(sol0, sol) +# # solve unif weights +# sol0 = ot.solve_sample(X_s, X_t, reg=0.1, lazy=True) # reg != 0 or None since no implementation yet for is_Lazy=True - # solve in backend - X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) - solb = ot.solve_sample(X_sb, X_tb, ab, bb, reg=0.1, is_Lazy=True) - - assert_allclose_sol(sol, solb) +# # solve signe weights +# sol = ot.solve_sample(X_s, X_t, a, b, reg=0.1, lazy=True) - # test not implemented reg==0 (or None) + balanced and check raise - with pytest.raises(NotImplementedError): - sol0 = ot.solve_sample(X_s, X_t, is_Lazy=True) # reg == 0 (or None) + unbalanced= None are default +# # check some attributes +# sol.potentials +# sol.lazy_plan - # test not implemented reg==0 (or None) + unbalanced_type and check raise - with pytest.raises(NotImplementedError): - sol0 = ot.solve_sample(X_s, X_t, unbalanced_type="kl", is_Lazy=True) # reg == 0 (or None) is default - - # test not implemented reg != 0 + unbalanced_type and check raise - with pytest.raises(NotImplementedError): - sol0 = ot.solve_sample(X_s, X_t, reg=0.1, unbalanced_type="kl", is_Lazy=True) +# assert_allclose_sol(sol0, sol) +# # solve in backend +# X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) +# solb = ot.solve_sample(X_sb, X_tb, ab, bb, reg=0.1, lazy=True) +# assert_allclose_sol(sol, solb) +# # test not implemented reg==0 (or None) + balanced and check raise +# with pytest.raises(NotImplementedError): +# sol0 = ot.solve_sample(X_s, X_t, lazy=True) # reg == 0 (or None) + unbalanced= None are default +# # test not implemented reg==0 (or None) + unbalanced_type and check raise +# with pytest.raises(NotImplementedError): +# sol0 = ot.solve_sample(X_s, X_t, unbalanced_type="kl", lazy=True) # reg == 0 (or None) is default +# # test not implemented reg != 0 + unbalanced_type and check raise +# with pytest.raises(NotImplementedError): +# sol0 = ot.solve_sample(X_s, X_t, reg=0.1, unbalanced_type="kl", lazy=True) diff --git a/test/test_utils.py b/test/test_utils.py index 942f403ce..3a9d590ab 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -569,4 +569,3 @@ def test_lowrank_LazyTensor(nx): T = ot.utils.get_lowrank_lazytensor(X1, X2, diag_d, nx=nx) np.testing.assert_allclose(nx.to_numpy(T[:]), nx.to_numpy(T0)) - \ No newline at end of file From 411f3023116ccd023b38bbd7d4e8e2f0a2fd745e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 7 Nov 2023 11:27:30 +0100 Subject: [PATCH 13/23] add factored and gaussian solvers --- ot/solvers.py | 75 ++++++++++++++++++++++++++++++++++++-------- ot/utils.py | 8 +++++ test/test_solvers.py | 3 -- 3 files changed, 70 insertions(+), 16 deletions(-) diff --git a/ot/solvers.py b/ot/solvers.py index 239ba91e0..c822c87b7 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -7,8 +7,8 @@ # # License: MIT License -from .utils import OTResult, unif, dist -from .lp import emd2 +from .utils import OTResult, dist +from .lp import emd2, wasserstein_1d from .backend import get_backend from .unbalanced import mm_unbalanced, sinkhorn_knopp_unbalanced, lbfgsb_unbalanced from .bregman import sinkhorn_log, empirical_sinkhorn @@ -20,6 +20,8 @@ entropic_semirelaxed_fused_gromov_wasserstein2, entropic_semirelaxed_gromov_wasserstein2) from .partial import partial_gromov_wasserstein2, entropic_partial_gromov_wasserstein2 +from .gaussian import empirical_bures_wasserstein_distance +from .factored import factored_optimal_transport def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, @@ -599,6 +601,7 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, value_quad = None plan = None status = None + log = None loss_dict = {'l2': 'square_loss', 'kl': 'kl_loss'} @@ -843,14 +846,14 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, raise (NotImplementedError('Not implemented reg_type="{}" and unbalanced_type="{}"'.format(reg_type, unbalanced_type))) res = OTResult(potentials=potentials, value=value, - value_linear=value_linear, value_quad=value_quad, plan=plan, status=status, backend=nx) + value_linear=value_linear, value_quad=value_quad, plan=plan, status=status, backend=nx, log=log) return res def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL", unbalanced=None, - unbalanced_type='KL', lazy=False, batch_size=None, method=None, n_threads=1, max_iter=None, plan_init=None, - potentials_init=None, tol=None, verbose=False): + unbalanced_type='KL', lazy=False, batch_size=None, method=None, n_threads=1, max_iter=None, plan_init=None, rank=100, + potentials_init=None, X_init=None, tol=None, verbose=False): r"""Solve the discrete optimal transport problem using the samples in the source and target domains. The function solves the following general optimal transport problem @@ -898,8 +901,8 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t solvers) method : str, optional Method for solving the problem, this can be used to select the solver - for unalanced problems (see :any:`ot.solve`), or to select a specific - lazy large scale solver. + for unbalanced problems (see :any:`ot.solve`), or to select a specific + large scale solver. n_threads : int, optional Number of OMP threads for exact OT solver, by default 1 max_iter : int, optional @@ -928,7 +931,8 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t """ - if method is not None and method.lower() in ['1d', 'sliced', 'lowrank', 'factored']: + if method is not None and method.lower() in ['1d', 'gaussian', 'lowrank', 'factored']: + lazy0 = lazy lazy = True if not lazy: # default non lazy solver calls ot.solve @@ -952,18 +956,59 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t plan = None lazy_plan = None status = None + log = None - ################# WIP #################### - if reg is None or reg == 0: # EMD solver for isLazy ? + if method.lower() == '1d': # Wasserstein 1d (parallel on all dimensions) + if metric == 'sqeuclidean': + p = 2 + elif metric in ['euclidean', 'cityblock']: + p = 1 + else: + raise (NotImplementedError('Not implemented metric="{}"'.format(metric))) + + value = wasserstein_1d(X_a, X_b, a, b, p=p) + value_linear = value + + elif method.lower() == 'gaussian': # Gaussian Bures-Wasserstein + + if not metric.lower() in ['sqeuclidean']: + raise (NotImplementedError('Not implemented metric="{}"'.format(metric))) + + if reg is None: + reg = 1e-6 - if unbalanced is None: # balanced EMD solver for isLazy ? + value, log = empirical_bures_wasserstein_distance(X_a, X_b, reg=reg, log=True) + value = value**2 # return the value (squared bures distance) + value_linear = value # return the value + + elif method.lower() == 'factored': # Factored OT + + if not metric.lower() in ['sqeuclidean']: + raise (NotImplementedError('Not implemented metric="{}"'.format(metric))) + + if max_iter is None: + max_iter = 100 + if tol is None: + tol = 1e-7 + if reg is None: + reg = 0 + + Q, R, X, log = factored_optimal_transport(X_a, X_b, reg=reg, r=rank, log=True, stopThr=tol, numItermax=max_iter, verbose=verbose) + log['X'] = X + + value_linear = log['costa'] + log['costb'] + lazy_plan = log['lazy_plan'] + if not lazy0: # store plan if not lazy + plan = lazy_plan[:] + + elif reg is None or reg == 0: # exact OT + + if unbalanced is None: # balanced EMD solver not available for lazy raise (NotImplementedError('Exact OT solver with lazy=True not implemented')) else: raise (NotImplementedError('Non regularized solver with unbalanced_type="{}" not implemented'.format(unbalanced_type))) - ############################################# - else: if unbalanced is None: @@ -986,3 +1031,7 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t else: raise (NotImplementedError('Not implemented unbalanced_type="{}" with regularization'.format(unbalanced_type))) + + res = OTResult(potentials=potentials, value=value, lazy_plan=lazy_plan, + value_linear=value_linear, plan=plan, status=status, backend=nx, log=log) + return res diff --git a/ot/utils.py b/ot/utils.py index 0936648ca..f64c2fea6 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -1115,6 +1115,14 @@ def status(self): else: raise NotImplementedError() + @property + def log(self): + """Dictionary containing potential information about the solver.""" + if self._log is not None: + return self._log + else: + raise NotImplementedError() + # Barycentric mappings ------------------------- # Return the displacement vectors as an array # that has the same shape as "xa"/"xb" (for samples) diff --git a/test/test_solvers.py b/test/test_solvers.py index 0a9ad6a93..c3db98b0c 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -257,9 +257,6 @@ def test_solve_gromov_not_implemented(nx): ot.solve_gromov(Ca, Cb, reg=1, unbalanced_type='partial', unbalanced=0.5, symmetric=False) -######## Test functions for ot.solve_sample ######## - - def test_solve_sample(nx): # test solve_sample when is_Lazy = False n = 100 From 5812de2fca942d5f196a3e2ea2a1bc56364b859b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Thu, 9 Nov 2023 10:07:15 +0100 Subject: [PATCH 14/23] workin lazy sinkhorn with lazy tensor returned --- ot/bregman.py | 45 +++++++++++++++++++++++++++++++++++++++++++- ot/gaussian.py | 2 +- ot/solvers.py | 23 +++++++++++----------- test/test_solvers.py | 35 ++++++++++++++++++++++++++++++++++ 4 files changed, 91 insertions(+), 14 deletions(-) diff --git a/ot/bregman.py b/ot/bregman.py index c90d89986..6861e9867 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -20,7 +20,7 @@ import numpy as np from scipy.optimize import fmin_l_bfgs_b -from ot.utils import dist, list_to_array, unif +from ot.utils import dist, list_to_array, unif, LazyTensor from .backend import get_backend @@ -3075,6 +3075,8 @@ def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', if log: dict_log["u"] = f dict_log["v"] = g + dict_log["niter"] = i_ot + dict_log["lazy_plan"] = get_sinkhorn_lazytensor(X_s, X_t, f, g, metric, reg) return (f, g, dict_log) else: return (f, g) @@ -3792,3 +3794,44 @@ def obj(theta): return gamma, log else: return gamma + + +def get_sinkhorn_lazytensor(X_a, X_b, f, g, metric='sqeuclidean', reg=1e-1, nx=None): + """ Get a LazyTensor of sinkhorn solution T = exp(f+g^T-C/reg) + + Parameters + ---------- + X_a : array-like, shape (n_samples_a, dim) + samples in the source domain + X_b : array-like, shape (n_samples_b, dim) + samples in the target domain + f : array-like, shape (n_samples_a,) + First dual potentials (log space) + g : array-like, shape (n_samples_b,) + Second dual potentials (log space) + metric : str, default='sqeuclidean' + Metric used for the cost matrix computation + reg : float, default=1e-1 + Regularization term >0 + nx : Backend(), default=None + Numerical backend used + + + Returns + ------- + T : LazyTensor + Lowrank tensor T = exp(f+g^T-C/reg) + """ + + if nx is None: + nx = get_backend(X_a, X_b, f, g) + + shape = (X_a.shape[0], X_b.shape[0]) + + def func(i, j, X_a, X_b, f, g, metric, reg): + C = dist(X_a[i], X_b[j], metric=metric) + return nx.exp(f[i, None] + g[None, j] - C / reg) + + T = LazyTensor(shape, func, X_a=X_a, X_b=X_b, f=f, g=g, metric=metric, reg=reg) + + return T diff --git a/ot/gaussian.py b/ot/gaussian.py index 708f9eb16..0ddb92013 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -249,7 +249,7 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False): Cs12 = nx.sqrtm(Cs) B = nx.trace(Cs + Ct - 2 * nx.sqrtm(dots(Cs12, Ct, Cs12))) - W = nx.sqrt(nx.norm(ms - mt)**2 + B) + W = nx.sqrt(nx.maximum(nx.norm(ms - mt)**2 + B, 0)) if log: log = {} diff --git a/ot/solvers.py b/ot/solvers.py index c822c87b7..a7e66931b 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -11,7 +11,7 @@ from .lp import emd2, wasserstein_1d from .backend import get_backend from .unbalanced import mm_unbalanced, sinkhorn_knopp_unbalanced, lbfgsb_unbalanced -from .bregman import sinkhorn_log, empirical_sinkhorn +from .bregman import sinkhorn_log, empirical_sinkhorn2 from .partial import partial_wasserstein_lagrange from .smooth import smooth_ot_dual from .gromov import (gromov_wasserstein2, fused_gromov_wasserstein2, @@ -958,7 +958,9 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t status = None log = None - if method.lower() == '1d': # Wasserstein 1d (parallel on all dimensions) + method = method.lower() if method is not None else '' + + if method == '1d': # Wasserstein 1d (parallel on all dimensions) if metric == 'sqeuclidean': p = 2 elif metric in ['euclidean', 'cityblock']: @@ -969,7 +971,7 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t value = wasserstein_1d(X_a, X_b, a, b, p=p) value_linear = value - elif method.lower() == 'gaussian': # Gaussian Bures-Wasserstein + elif method == 'gaussian': # Gaussian Bures-Wasserstein if not metric.lower() in ['sqeuclidean']: raise (NotImplementedError('Not implemented metric="{}"'.format(metric))) @@ -981,7 +983,7 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t value = value**2 # return the value (squared bures distance) value_linear = value # return the value - elif method.lower() == 'factored': # Factored OT + elif method == 'factored': # Factored OT if not metric.lower() in ['sqeuclidean']: raise (NotImplementedError('Not implemented metric="{}"'.format(metric))) @@ -997,6 +999,7 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t log['X'] = X value_linear = log['costa'] + log['costb'] + value = value_linear # TODO add reg term lazy_plan = log['lazy_plan'] if not lazy0: # store plan if not lazy plan = lazy_plan[:] @@ -1019,15 +1022,11 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t if batch_size is None: batch_size = 100 - u, v, log = empirical_sinkhorn(X_a, X_b, reg, a, b, metric=metric, numIterMax=max_iter, stopThr=tol, - isLazy=True, batchSize=batch_size, verbose=verbose, log=True) + value_linear, log = empirical_sinkhorn2(X_a, X_b, reg, a, b, metric=metric, numIterMax=max_iter, stopThr=tol, + isLazy=True, batchSize=batch_size, verbose=verbose, log=True) # compute potentials - potentials = (u, v) - - # compute lazy_plan - # ... - - raise (NotImplementedError('Not implemented balanced with regularization')) + potentials = (log["u"], log["v"]) + lazy_plan = log['lazy_plan'] else: raise (NotImplementedError('Not implemented unbalanced_type="{}" with regularization'.format(unbalanced_type))) diff --git a/test/test_solvers.py b/test/test_solvers.py index c3db98b0c..d11cd60eb 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -23,6 +23,15 @@ lst_unbalanced_gromov = [None, 0.9] lst_alpha = [0, 0.4, 0.9, 1] +lst_method_params_solve_sample = [ + {'method': '1d'}, + {'method': 'gaussian'}, + {'method': 'gaussian', 'reg': 1}, + {'method': 'factored', 'rank': 10}, +] +# set readable ids for each param +lst_method_params_solve_sample = [pytest.param(param, id=str(param)) for param in lst_method_params_solve_sample] + def assert_allclose_sol(sol1, sol2): @@ -301,6 +310,32 @@ def test_solve_sample(nx): sol0 = ot.solve_sample(X_s, X_t, reg=1, reg_type='cryptic divergence') +@pytest.mark.parametrize("method_params", lst_method_params_solve_sample) +def test_solve_sample_methods(nx, method_params): + + n_samples_s = 20 + n_samples_t = 7 + n_features = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n_samples_s, n_features) + y = rng.randn(n_samples_t, n_features) + a = ot.utils.unif(n_samples_s) + b = ot.utils.unif(n_samples_t) + + xb, yb, ab, bb = nx.from_numpy(x, y, a, b) + + sol = ot.solve_sample(x, y, **method_params) + solb = ot.solve_sample(xb, yb, ab, bb, **method_params) + + # check some attributes (no need ) + assert_allclose_sol(sol, solb) + + sol2 = ot.solve_sample(x, x, **method_params) + if method_params['method'] != 'factored': + np.testing.assert_allclose(sol2.value, 0) + + # def test_lazy_solve_sample(nx): # # test solve_sample when is_Lazy = True # n = 100 From 26c6f49ebe0f5de46ad8a051bd1f4b73161fcbee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 14 Nov 2023 08:29:35 +0100 Subject: [PATCH 15/23] stuff --- ot/solvers.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ot/solvers.py b/ot/solvers.py index a7e66931b..48a7f7401 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -25,7 +25,7 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, - unbalanced_type='KL', n_threads=1, max_iter=None, plan_init=None, + unbalanced_type='KL', method=None, n_threads=1, max_iter=None, plan_init=None, potentials_init=None, tol=None, verbose=False): r"""Solve the discrete optimal transport problem and return :any:`OTResult` object @@ -851,7 +851,8 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, return res -def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL", unbalanced=None, +def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL", + unbalanced=None, unbalanced_type='KL', lazy=False, batch_size=None, method=None, n_threads=1, max_iter=None, plan_init=None, rank=100, potentials_init=None, X_init=None, tol=None, verbose=False): r"""Solve the discrete optimal transport problem using the samples in the source and target domains. @@ -940,7 +941,7 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t # compute cost matrix M and use solve function M = dist(X_a, X_b, metric) - res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, n_threads, max_iter, plan_init, potentials_init, tol, verbose) + res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, method, n_threads, max_iter, plan_init, potentials_init, tol, verbose) return res From c0884d8ad6074b0758f54e9bc129b33d62a126c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 14 Nov 2023 08:35:22 +0100 Subject: [PATCH 16/23] merge master --- ot/bregman/_empirical.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ot/bregman/_empirical.py b/ot/bregman/_empirical.py index 378c20b98..2fc1756ec 100644 --- a/ot/bregman/_empirical.py +++ b/ot/bregman/_empirical.py @@ -58,7 +58,6 @@ def func(i, j, X_a, X_b, f, g, metric, reg): return T - def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean', numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False, log=False, warn=True, warmstart=None, **kwargs): From 4b889821ac9e2cd8d648905c634f9681bf65cf9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 14 Nov 2023 10:43:06 +0100 Subject: [PATCH 17/23] =?UTF-8?q?update=20docum=C3=B9entation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ot/solvers.py | 164 ++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 153 insertions(+), 11 deletions(-) diff --git a/ot/solvers.py b/ot/solvers.py index 48a7f7401..c156fbdc7 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -90,7 +90,7 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, The following methods are available for solving the OT problems: - - **Classical exact OT problem** (default parameters): + - **Classical exact OT problem [1]** (default parameters) : .. math:: \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F @@ -107,7 +107,7 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, res = ot.solve(M, a, b) - - **Entropic regularized OT** (when ``reg!=None``): + - **Entropic regularized OT [2]** (when ``reg!=None``): .. math:: \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T}) @@ -127,7 +127,7 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, # or for original Sinkhorn paper formulation [2] res = ot.solve(M, a, b, reg=1.0, reg_type='entropy') - - **Quadratic regularized OT** (when ``reg!=None`` and ``reg_type="L2"``): + - **Quadratic regularized OT [17]** (when ``reg!=None`` and ``reg_type="L2"``): .. math:: \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T}) @@ -144,7 +144,7 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, res = ot.solve(M,a,b,reg=1.0,reg_type='L2') - - **Unbalanced OT** (when ``unbalanced!=None``): + - **Unbalanced OT [41]** (when ``unbalanced!=None``): .. math:: \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) @@ -154,14 +154,14 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, .. code-block:: python # default is ``"KL"`` - res = ot.solve(M,a,b,reg=1.0,unbalanced=1.0) + res = ot.solve(M,a,b,unbalanced=1.0) # quadratic unbalanced OT - res = ot.solve(M,a,b,reg=1.0,unbalanced=1.0,unbalanced_type='L2') + res = ot.solve(M,a,b,unbalanced=1.0,unbalanced_type='L2') # TV = partial OT - res = ot.solve(M,a,b,reg=1.0,unbalanced=1.0,unbalanced_type='TV') + res = ot.solve(M,a,b,unbalanced=1.0,unbalanced_type='TV') - - **Regularized unbalanced regularized OT** (when ``unbalanced!=None`` and ``reg!=None``): + - **Regularized unbalanced regularized OT [34]** (when ``unbalanced!=None`` and ``reg!=None``): .. math:: \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) @@ -182,6 +182,11 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, References ---------- + .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. + (2011, December). Displacement interpolation using Lagrangian mass + transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. + 158). ACM. + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013 @@ -199,6 +204,10 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. NeurIPS. + """ # detect backend @@ -413,9 +422,8 @@ def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None, n_threads : int, optional Number of OMP threads for exact OT solver, by default 1 method : str, optional - Method for solving the problem, for entropic problems "PGD" is projected - gradient descent and "PPA" for proximal point, default None for - automatic selection ("PGD"). + Method for solving the problem when multiple algorithms are available, + default None for automatic selection. max_iter : int, optional Maximum number of iterations, by default None (default values in each solvers) @@ -930,6 +938,140 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t See :any:`OTResult` for more information. + Notes + ----- + + The following methods are available for solving the OT problems: + + - **Classical exact OT problem [1]** (default parameters) : + + .. math:: + \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T}^T \mathbf{1} = \mathbf{b} + + \mathbf{T} \geq 0, M_{i,j} = d(x_i,y_j) + + + + can be solved with the following code: + + .. code-block:: python + + res = ot.solve_sample(xa, xb, a, b) + + # for uniform weights + res = ot.solve_sample(xa, xb) + + - **Entropic regularized OT [2]** (when ``reg!=None``): + + .. math:: + \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T}) + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T}^T \mathbf{1} = \mathbf{b} + + \mathbf{T} \geq 0, M_{i,j} = d(x_i,y_j) + + can be solved with the following code: + + .. code-block:: python + + # default is ``"KL"`` regularization (``reg_type="KL"``) + res = ot.solve_sample(xa, xb, a, b, reg=1.0) + # or for original Sinkhorn paper formulation [2] + res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='entropy') + + - **Quadratic regularized OT [17]** (when ``reg!=None`` and ``reg_type="L2"``): + + .. math:: + \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T}) + + s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a} + + \mathbf{T}^T \mathbf{1} = \mathbf{b} + + \mathbf{T} \geq 0, M_{i,j} = d(x_i,y_j) + + can be solved with the following code: + + .. code-block:: python + + res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='L2') + + - **Unbalanced OT [41]** (when ``unbalanced!=None``): + + .. math:: + \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + + with M_{i,j} = d(x_i,y_j) + + can be solved with the following code: + + .. code-block:: python + + # default is ``"KL"`` + res = ot.solve_sample(xa, xb, a, b, unbalanced=1.0) + # quadratic unbalanced OT + res = ot.solve_sample(xa, xb, a, b, unbalanced=1.0,unbalanced_type='L2') + # TV = partial OT + res = ot.solve_sample(xa, xb, a, b, unbalanced=1.0,unbalanced_type='TV') + + + - **Regularized unbalanced regularized OT [34]** (when ``unbalanced!=None`` and ``reg!=None``): + + .. math:: + \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) + + with M_{i,j} = d(x_i,y_j) + + can be solved with the following code: + + .. code-block:: python + + # default is ``"KL"`` for both + res = ot.solve_sample(xa, xb, a, b, reg=1.0, unbalanced=1.0) + # quadratic unbalanced OT with KL regularization + res = ot.solve_sample(xa, xb, a, b, reg=1.0, unbalanced=1.0,unbalanced_type='L2') + # both quadratic + res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='L2', unbalanced=1.0, unbalanced_type='L2') + + + .. _references-solve-sample: + References + ---------- + + .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. + (2011, December). Displacement interpolation using Lagrangian mass + transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. + 158). ACM. + + .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation + of Optimal Transport, Advances in Neural Information Processing + Systems (NIPS) 26, 2013 + + .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016). + Scaling algorithms for unbalanced transport problems. + arXiv preprint arXiv:1607.05816. + + .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse + Optimal Transport. Proceedings of the Twenty-First International + Conference on Artificial Intelligence and Statistics (AISTATS). + + .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, + A., & Peyré, G. (2019, April). Interpolating between optimal transport + and MMD using Sinkhorn divergences. In The 22nd International Conference + on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. + + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). + Unbalanced optimal transport through non-negative penalized + linear regression. NeurIPS. + + + """ if method is not None and method.lower() in ['1d', 'gaussian', 'lowrank', 'factored']: From d4e60b08968a0f4acf444c6ce8ab30173bc5aaf3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 14 Nov 2023 14:57:31 +0100 Subject: [PATCH 18/23] beter documentation --- ot/solvers.py | 75 +++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 73 insertions(+), 2 deletions(-) diff --git a/ot/solvers.py b/ot/solvers.py index c156fbdc7..e36e2e1d2 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -873,7 +873,7 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b}) where the cost matrix :math:`\mathbf{M}` is computed from the samples in the - source and target domains wuch that :math:`M_{i,j} = d(x_i,y_j)` where + source and target domains such that :math:`M_{i,j} = d(x_i,y_j)` where :math:`d` is a metric (by default the squared Euclidean distance). The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By @@ -985,6 +985,11 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t # or for original Sinkhorn paper formulation [2] res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='entropy') + # lazy solver of memory complexity O(n) + res = ot.solve_sample(xa, xb, a, b, reg=1.0, lazy=True, batch_size=100) + # lazy OT plan + lazy_plan = res.lazy_plan + - **Quadratic regularized OT [17]** (when ``reg!=None`` and ``reg_type="L2"``): .. math:: @@ -1037,9 +1042,70 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t # quadratic unbalanced OT with KL regularization res = ot.solve_sample(xa, xb, a, b, reg=1.0, unbalanced=1.0,unbalanced_type='L2') # both quadratic - res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='L2', unbalanced=1.0, unbalanced_type='L2') + res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='L2', + unbalanced=1.0, unbalanced_type='L2') + + + - **Factored OT [2]** (when ``method='factored'``): + + This method solve the following OT problem [40]_ + + .. math:: + \mathop{\arg \min}_\mu \quad W_2^2(\mu_a,\mu)+ W_2^2(\mu,\mu_b) + + where $\mu$ is a uniform weighted empirical distribution of :math:`\mu_a` and :math:`\mu_b` are the empirical measures associated + to the samples in the source and target domains, and :math:`W_2` is the + Wasserstein distance. This problem is solved using exact OT solvers for + `reg=None` and the Sinkhorn solver for `reg!=None`. The solution provides + two transport plans that can be used to recover a low rank OT plan between + the two distributions. + + .. code-block:: python + + res = ot.solve_sample(xa, xb, method='factored', rank=10) + + # recover the lazy low rank plan + factored_solution_lazy = res.lazy_plan + + # recover the full low rank plan + factored_solution = factored_solution_lazy[:] + + - **Gaussian Bures-Wasserstein [2]** (when ``method='gaussian'``): + + This method computes the Gaussian Bures-Wasserstein distance between two + Gaussian distributions estimated from teh empirical distributions + + .. math:: + \mathcal{W}(\mu_s, \mu_t)_2^2= \left\lVert \mathbf{m}_s - \mathbf{m}_t \right\rVert^2 + \mathcal{B}(\Sigma_s, \Sigma_t)^{2} + where : + .. math:: + \mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s + \Sigma_t - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right) + + The covariances and means are estimated from the data. + + .. code-block:: python + + res = ot.solve_sample(xa, xb, method='gaussian') + + # recover the squared Gaussian Bures-Wasserstein distance + BW_dist = res.value + + - **Wasserstein 1d [1]** (when ``method='1D'``): + + This method computes the Wasserstein distance between two 1d distributions + estimated from the empirical distributions. For multivariate data the + distances are computed independently for each dimension. + + .. code-block:: python + + res = ot.solve_sample(xa, xb, method='1D') + + # recover the squared Wasserstein distances + W_dists = res.value + + .. _references-solve-sample: References ---------- @@ -1066,6 +1132,11 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t and MMD using Sinkhorn divergences. In The 22nd International Conference on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR. + .. [40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, + G., & Weed, J. (2019, April). Statistical optimal transport via factored + couplings. In The 22nd International Conference on Artificial + Intelligence and Statistics (pp. 2454-2465). PMLR. + .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021). Unbalanced optimal transport through non-negative penalized linear regression. NeurIPS. From b1182c8071e0a096aba72da9d6562c9ac93aa488 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 14 Nov 2023 15:33:15 +0100 Subject: [PATCH 19/23] pep8 --- ot/solvers.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/ot/solvers.py b/ot/solvers.py index e36e2e1d2..2234e7807 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -987,7 +987,7 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t # lazy solver of memory complexity O(n) res = ot.solve_sample(xa, xb, a, b, reg=1.0, lazy=True, batch_size=100) - # lazy OT plan + # lazy OT plan lazy_plan = res.lazy_plan - **Quadratic regularized OT [17]** (when ``reg!=None`` and ``reg_type="L2"``): @@ -1045,26 +1045,26 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='L2', unbalanced=1.0, unbalanced_type='L2') - + - **Factored OT [2]** (when ``method='factored'``): This method solve the following OT problem [40]_ .. math:: \mathop{\arg \min}_\mu \quad W_2^2(\mu_a,\mu)+ W_2^2(\mu,\mu_b) - + where $\mu$ is a uniform weighted empirical distribution of :math:`\mu_a` and :math:`\mu_b` are the empirical measures associated to the samples in the source and target domains, and :math:`W_2` is the - Wasserstein distance. This problem is solved using exact OT solvers for + Wasserstein distance. This problem is solved using exact OT solvers for `reg=None` and the Sinkhorn solver for `reg!=None`. The solution provides two transport plans that can be used to recover a low rank OT plan between - the two distributions. + the two distributions. .. code-block:: python res = ot.solve_sample(xa, xb, method='factored', rank=10) - # recover the lazy low rank plan + # recover the lazy low rank plan factored_solution_lazy = res.lazy_plan # recover the full low rank plan @@ -1072,7 +1072,7 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t - **Gaussian Bures-Wasserstein [2]** (when ``method='gaussian'``): - This method computes the Gaussian Bures-Wasserstein distance between two + This method computes the Gaussian Bures-Wasserstein distance between two Gaussian distributions estimated from teh empirical distributions .. math:: @@ -1105,7 +1105,7 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t # recover the squared Wasserstein distances W_dists = res.value - + .. _references-solve-sample: References ---------- @@ -1142,7 +1142,6 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t linear regression. NeurIPS. - """ if method is not None and method.lower() in ['1d', 'gaussian', 'lowrank', 'factored']: From f54ec5692c8940a697673f0393d25cc0e0be4fde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 15 Nov 2023 09:34:51 +0100 Subject: [PATCH 20/23] big update tests --- test/test_bregman.py | 13 +++++-- test/test_solvers.py | 87 ++++++++++++++++++++++++++------------------ test/test_utils.py | 16 +++++++- 3 files changed, 76 insertions(+), 40 deletions(-) diff --git a/test/test_bregman.py b/test/test_bregman.py index 8627df3c6..67257f899 100644 --- a/test/test_bregman.py +++ b/test/test_bregman.py @@ -1078,10 +1078,10 @@ def test_lazy_empirical_sinkhorn(nx): sinkhorn_sqe = nx.to_numpy(ot.sinkhorn(ab, bb, M_nx, 1)) f, g, log_es = ot.bregman.empirical_sinkhorn( - X_sb, X_tb, 0.1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) + X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=5, log=True) f, g = nx.to_numpy(f), nx.to_numpy(g) - G_log = np.exp(f[:, None] + g[None, :] - M / 0.1) - sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 0.1, log=True) + G_log = np.exp(f[:, None] + g[None, :] - M / 1) + sinkhorn_log, log_s = ot.sinkhorn(ab, bb, M_nx, 1, log=True) sinkhorn_log = nx.to_numpy(sinkhorn_log) f, g = ot.bregman.empirical_sinkhorn( @@ -1091,10 +1091,14 @@ def test_lazy_empirical_sinkhorn(nx): sinkhorn_m = nx.to_numpy(ot.sinkhorn(ab, bb, M_mb, 1)) loss_emp_sinkhorn, log = ot.bregman.empirical_sinkhorn2( - X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=True) + X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=5, log=True) + G_lazy = nx.to_numpy(log['lazy_plan'][:]) loss_emp_sinkhorn = nx.to_numpy(loss_emp_sinkhorn) loss_sinkhorn = nx.to_numpy(ot.sinkhorn2(ab, bb, M_nx, 1)) + loss_emp_sinkhorn = ot.bregman.empirical_sinkhorn2( + X_sb, X_tb, 1, numIterMax=numIterMax, isLazy=True, batchSize=1, log=False) + # check constraints np.testing.assert_allclose( sinkhorn_sqe.sum(1), G_sqe.sum(1), atol=1e-05) # metric sqeuclidian @@ -1109,6 +1113,7 @@ def test_lazy_empirical_sinkhorn(nx): np.testing.assert_allclose( sinkhorn_m.sum(0), G_m.sum(0), atol=1e-05) # metric euclidian np.testing.assert_allclose(loss_emp_sinkhorn, loss_sinkhorn, atol=1e-05) + np.testing.assert_allclose(G_log, G_lazy, atol=1e-05) def test_empirical_sinkhorn_divergence(nx): diff --git a/test/test_solvers.py b/test/test_solvers.py index d11cd60eb..c6e1a3770 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -25,12 +25,24 @@ lst_method_params_solve_sample = [ {'method': '1d'}, + {'method': '1d', 'metric': 'euclidean'}, {'method': 'gaussian'}, {'method': 'gaussian', 'reg': 1}, {'method': 'factored', 'rank': 10}, ] + +lst_parameters_solve_sample_NotImplemented = [ + {'method': '1d', 'metric': 'any other one'}, # fail 1d on weird metrics + {'method': 'gaussian', 'metric': 'euclidean'}, # fail gaussian on metric not euclidean + {'method': 'factored', 'metric': 'euclidean'}, # fail factored on metric not euclidean + {'lazy': True}, # fail lazy for non regularized + {'lazy': True, 'unbalanced': 1}, # fail lazy for non regularized unbalanced + {'lazy': True, 'reg': 1, 'unbalanced': 1}, # fail lazy for unbalanced and regularized +] + # set readable ids for each param lst_method_params_solve_sample = [pytest.param(param, id=str(param)) for param in lst_method_params_solve_sample] +lst_parameters_solve_sample_NotImplemented = [pytest.param(param, id=str(param)) for param in lst_parameters_solve_sample_NotImplemented] def assert_allclose_sol(sol1, sol2): @@ -268,7 +280,7 @@ def test_solve_gromov_not_implemented(nx): def test_solve_sample(nx): # test solve_sample when is_Lazy = False - n = 100 + n = 20 X_s = np.reshape(1.0 * np.arange(n), (n, 1)) X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) @@ -310,6 +322,32 @@ def test_solve_sample(nx): sol0 = ot.solve_sample(X_s, X_t, reg=1, reg_type='cryptic divergence') +def test_solve_sample_lazy(nx): + # test solve_sample when is_Lazy = False + n = 20 + X_s = np.reshape(1.0 * np.arange(n), (n, 1)) + X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) + + a = ot.utils.unif(X_s.shape[0]) + b = ot.utils.unif(X_t.shape[0]) + + X_s, X_t, a, b = nx.from_numpy(X_s, X_t, a, b) + + M = ot.dist(X_s, X_t) + + # solve with ot.solve + sol00 = ot.solve(M, a, b, reg=1) + + sol0 = ot.solve_sample(X_s, X_t, a, b, reg=1) + + # solve signe weights + sol = ot.solve_sample(X_s, X_t, a, b, reg=1, lazy=True) + + assert_allclose_sol(sol0, sol00) + + np.testing.assert_allclose(sol0.plan, sol.lazy_plan[:], rtol=1e-5, atol=1e-5) + + @pytest.mark.parametrize("method_params", lst_method_params_solve_sample) def test_solve_sample_methods(nx, method_params): @@ -336,41 +374,20 @@ def test_solve_sample_methods(nx, method_params): np.testing.assert_allclose(sol2.value, 0) -# def test_lazy_solve_sample(nx): -# # test solve_sample when is_Lazy = True -# n = 100 -# X_s = np.reshape(1.0 * np.arange(n), (n, 1)) -# X_t = np.reshape(1.0 * np.arange(0, n), (n, 1)) - -# a = ot.utils.unif(X_s.shape[0]) -# b = ot.utils.unif(X_t.shape[0]) - -# # solve unif weights -# sol0 = ot.solve_sample(X_s, X_t, reg=0.1, lazy=True) # reg != 0 or None since no implementation yet for is_Lazy=True +@pytest.mark.parametrize("method_params", lst_parameters_solve_sample_NotImplemented) +def test_solve_sample_NotImplemented(nx, method_params): -# # solve signe weights -# sol = ot.solve_sample(X_s, X_t, a, b, reg=0.1, lazy=True) - -# # check some attributes -# sol.potentials -# sol.lazy_plan - -# assert_allclose_sol(sol0, sol) - -# # solve in backend -# X_sb, X_tb, ab, bb = nx.from_numpy(X_s, X_t, a, b) -# solb = ot.solve_sample(X_sb, X_tb, ab, bb, reg=0.1, lazy=True) - -# assert_allclose_sol(sol, solb) + n_samples_s = 20 + n_samples_t = 7 + n_features = 2 + rng = np.random.RandomState(0) -# # test not implemented reg==0 (or None) + balanced and check raise -# with pytest.raises(NotImplementedError): -# sol0 = ot.solve_sample(X_s, X_t, lazy=True) # reg == 0 (or None) + unbalanced= None are default + x = rng.randn(n_samples_s, n_features) + y = rng.randn(n_samples_t, n_features) + a = ot.utils.unif(n_samples_s) + b = ot.utils.unif(n_samples_t) -# # test not implemented reg==0 (or None) + unbalanced_type and check raise -# with pytest.raises(NotImplementedError): -# sol0 = ot.solve_sample(X_s, X_t, unbalanced_type="kl", lazy=True) # reg == 0 (or None) is default + xb, yb, ab, bb = nx.from_numpy(x, y, a, b) -# # test not implemented reg != 0 + unbalanced_type and check raise -# with pytest.raises(NotImplementedError): -# sol0 = ot.solve_sample(X_s, X_t, reg=0.1, unbalanced_type="kl", lazy=True) + with pytest.raises(NotImplementedError): + ot.solve_sample(xb, yb, ab, bb, **method_params) diff --git a/test/test_utils.py b/test/test_utils.py index 3a9d590ab..e265295cf 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -318,6 +318,9 @@ def test_cost_normalization(nx): M1 = nx.to_numpy(M) np.testing.assert_allclose(M1.max(), np.log(1 + np.log(1 + C)).max()) + with pytest.raises(ValueError): + ot.utils.cost_normalization(C1, 'error') + def test_check_params(): @@ -328,6 +331,16 @@ def test_check_params(): assert res0 is False +def test_check_random_state_error(): + with pytest.raises(ValueError): + ot.utils.check_random_state('error') + + +def test_get_parameter_pairs_error(): + with pytest.raises(ValueError): + ot.utils.get_parameter_pairs((1, 2, 3)) # not pair ;) + + def test_deprecated_func(): @ot.utils.deprecated('deprecated text for fun') @@ -408,7 +421,8 @@ def test_OTResult(): 'status', 'value', 'value_linear', - 'value_quad'] + 'value_quad', + 'log'] for at in lst_attributes: print(at) with pytest.raises(NotImplementedError): From 129d905714fb896bc563110911db4d7f56f60bbf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 15 Nov 2023 09:39:40 +0100 Subject: [PATCH 21/23] debug small test --- test/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index e265295cf..258a1c742 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -336,9 +336,9 @@ def test_check_random_state_error(): ot.utils.check_random_state('error') -def test_get_parameter_pairs_error(): +def test_get_parameter_pair_error(): with pytest.raises(ValueError): - ot.utils.get_parameter_pairs((1, 2, 3)) # not pair ;) + ot.utils.get_parameter_pair((1, 2, 3)) # not pair ;) def test_deprecated_func(): From c098a15a275b561d3df3eb8a32d53591a0dee1db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 15 Nov 2023 12:40:10 +0100 Subject: [PATCH 22/23] =?UTF-8?q?remarques=20c=C3=A9dri?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- RELEASES.md | 1 + ot/bregman/_empirical.py | 7 +++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index 321d9a78b..915b5c34e 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -16,6 +16,7 @@ + Add exact line-search for `gromov_wasserstein` and `fused_gromov_wasserstein` with KL loss (PR #556) + Add KL loss to all semi-relaxed (Fused) Gromov-Wasserstein solvers (PR #559) + Further upgraded unbalanced OT solvers for more flexibility and future use (PR #551) ++ New API function `ot.solve_sample` for solving OT problems from empirical samples (PR #563) #### Closed issues - Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504) diff --git a/ot/bregman/_empirical.py b/ot/bregman/_empirical.py index 2fc1756ec..b84c3b389 100644 --- a/ot/bregman/_empirical.py +++ b/ot/bregman/_empirical.py @@ -18,7 +18,10 @@ def get_sinkhorn_lazytensor(X_a, X_b, f, g, metric='sqeuclidean', reg=1e-1, nx=None): - """ Get a LazyTensor of sinkhorn solution T = exp(f+g^T-C/reg) + r""" Get a LazyTensor of Sinkhorn solution from the dual potentials + + The returned LazyTensor is + :math:`\mathbf{T} = exp( \mathbf{f} \mathbf{1}_b^\top + \mathbf{1}_a \mathbf{g}^\top - \mathbf{C}/reg)`, where :math:`\mathbf{C}` is the pairwise metric matrix between samples :math:`\mathbf{X}_a` and :math:`\mathbf{X}_b`. Parameters ---------- @@ -41,7 +44,7 @@ def get_sinkhorn_lazytensor(X_a, X_b, f, g, metric='sqeuclidean', reg=1e-1, nx=N Returns ------- T : LazyTensor - Lowrank tensor T = exp(f+g^T-C/reg) + Sinkhorn solution tensor """ if nx is None: From d3f5bf3aab048ecf1c9bdaa375c77f90698a7e02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 15 Nov 2023 12:41:41 +0100 Subject: [PATCH 23/23] small stuff --- ot/solvers.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ot/solvers.py b/ot/solvers.py index 2234e7807..aed7e8ffe 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -59,7 +59,11 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None, Unbalanced penalization weight :math:`\lambda_u`, by default None (balanced OT) unbalanced_type : str, optional - Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL" + Type of unbalanced penalization function :math:`U` either "KL", "L2", + "TV", by default "KL" + method : str, optional + Method for solving the problem when multiple algorithms are available, + default None for automatic selection. n_threads : int, optional Number of OMP threads for exact OT solver, by default 1 max_iter : int, optional