diff --git a/RELEASES.md b/RELEASES.md
index cf1dc0c4c..f943886d5 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -12,6 +12,7 @@
 + New LP solvers from scipy used by default for LP barycenter (PR #537)
 + Update wheels to Python 3.12 and remove old i686 arch that do not have scipy wheels (PR #543)
 + Upgraded unbalanced OT solvers for more flexibility (PR #539)
++ Add LazyTensor for modeling plans and low rank tensors in large scale OT (PR #544)
 
 #### Closed issues
 - Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
diff --git a/ot/factored.py b/ot/factored.py
index 8d6615876..65613d328 100644
--- a/ot/factored.py
+++ b/ot/factored.py
@@ -7,7 +7,7 @@
 # License: MIT License
 
 from .backend import get_backend
-from .utils import dist
+from .utils import dist, get_lowrank_lazytensor
 from .lp import emd
 from .bregman import sinkhorn
 
@@ -139,6 +139,7 @@ def solve_ot(X1, X2, w1, w2):
         'vb': logb['v'],
         'costa': loga['cost'],
         'costb': logb['cost'],
+        'lazy_plan': get_lowrank_lazytensor(Ga * r, Gb.T, nx=nx),
     }
 
     return Ga, Gb, X, log_dic
diff --git a/ot/utils.py b/ot/utils.py
index 4efcb225e..0936648ca 100644
--- a/ot/utils.py
+++ b/ot/utils.py
@@ -492,6 +492,121 @@ def get_coordinate_circle(x):
     return x_t
 
 
+def reduce_lazytensor(a, func, axis=None, nx=None, batch_size=100):
+    """ Reduce a LazyTensor along an axis with function func using batches.
+
+    When axis=None, reduce the LazyTensor to a scalar as a sum of func over
+    batches taken along axis 0.
+
+    .. warning::
+        This function works for tensors of any order, but the reduction can be
+        done only along the first two axes (or globally). Also, in order to work,
+        it requires that a slice of size `batch_size` along the axis to reduce
+        (or axis 0 if `axis=None`) can be computed and fits in memory.
+
+
+    Parameters
+    ----------
+    a : LazyTensor
+        LazyTensor to reduce
+    func : callable
+        Function to apply to the LazyTensor
+    axis : int, optional
+        Axis along which to reduce the LazyTensor. If None, reduce the
+        LazyTensor to a scalar as a sum of func over batches taken along axis 0.
+        If 0 or 1, reduce the LazyTensor to a vector/matrix as a sum of func over
+        batches taken along axis.
+    nx : Backend, optional
+        Backend to use for the reduction
+    batch_size : int, optional
+        Size of the batches to use for the reduction (default=100)
+
+    Returns
+    -------
+    res : array-like
+        Result of the reduction
+
+    """
+
+    if nx is None:
+        nx = get_backend(a[0])
+
+    if axis is None:
+        res = 0.0
+        for i in range(0, a.shape[0], batch_size):
+            res += func(a[i:i + batch_size])
+        return res
+    elif axis == 0:
+        res = nx.zeros(a.shape[1:], type_as=a[0])
+        if nx.__name__ in ["jax", "tf"]:
+            lst = []
+            for j in range(0, a.shape[1], batch_size):
+                lst.append(func(a[:, j:j + batch_size], 0))
+            return nx.concatenate(lst, axis=0)
+        else:
+            for j in range(0, a.shape[1], batch_size):
+                res[j:j + batch_size] = func(a[:, j:j + batch_size], axis=0)
+            return res
+    elif axis == 1:
+        if len(a.shape) == 2:
+            shape = (a.shape[0],)
+        else:
+            shape = (a.shape[0], *a.shape[2:])
+        res = nx.zeros(shape, type_as=a[0])
+        if nx.__name__ in ["jax", "tf"]:
+            lst = []
+            for i in range(0, a.shape[0], batch_size):
+                lst.append(func(a[i:i + batch_size], 1))
+            return nx.concatenate(lst, axis=0)
+        else:
+            for i in range(0, a.shape[0], batch_size):
+                res[i:i + batch_size] = func(a[i:i + batch_size], axis=1)
+            return res
+
+    else:
+        raise (NotImplementedError("Only axis=None, 0 or 1 is implemented for now."))
+
+
+def get_lowrank_lazytensor(Q, R, d=None, nx=None):
+    """ Get a low rank LazyTensor T=Q@R^T or T=Q@diag(d)@R^T
+
+    Parameters
+    ----------
+    Q : ndarray, shape (n, r)
+        First factor of the lowrank tensor
+    R : ndarray, shape (m, r)
+        Second factor of the lowrank tensor
+    d : ndarray, shape (r,), optional
+        Diagonal of the lowrank tensor
+    nx : Backend, optional
+        Backend to use for the computation
+
+    Returns
+    -------
+    T : LazyTensor
+        Lowrank tensor T=Q@R^T or T=Q@diag(d)@R^T
+    """
+
+    if nx is None:
+        nx = get_backend(Q, R, d)
+
+    shape = (Q.shape[0], R.shape[0])
+
+    if d is None:
+
+        def func(i, j, Q, R):
+            return nx.dot(Q[i], R[j].T)
+
+        T = LazyTensor(shape, func, Q=Q, R=R)
+
+    else:
+
+        def func(i, j, Q, R, d):
+            return nx.dot(Q[i] * d[None, :], R[j].T)
+
+        T = LazyTensor(shape, func, Q=Q, R=R, d=d)
+
+    return T
+
+
 def get_parameter_pair(parameter):
     r"""Extract a pair of parameters from a given parameter
     Used in unbalanced OT and COOT solvers
@@ -761,7 +876,76 @@ class UndefinedParameter(Exception):
 
 
 class OTResult:
-    def __init__(self, potentials=None, value=None, value_linear=None, value_quad=None, plan=None, log=None, backend=None, sparse_plan=None, lazy_plan=None, status=None):
+    """ Base class for OT results.
+
+    Parameters
+    ----------
+
+    potentials : tuple of array-like, shape (`n1`, `n2`)
+        Dual potentials, i.e. Lagrange multipliers for the marginal constraints.
+        This pair of arrays has the same shape, numerical type
+        and properties as the input weights "a" and "b".
+    value : float, array-like
+        Full transport cost, including possible regularization terms and
+        quadratic term for Gromov Wasserstein solutions.
+    value_linear : float, array-like
+        The linear part of the transport cost, i.e. the product between the
+        transport plan and the cost.
+    value_quad : float, array-like
+        The quadratic part of the transport cost for Gromov-Wasserstein
+        solutions.
+    plan : array-like, shape (`n1`, `n2`)
+        Transport plan, encoded as a dense array.
+    log : dict
+        Dictionary containing potential information about the solver.
+    backend : Backend
+        Backend used to compute the results.
+    sparse_plan : array-like, shape (`n1`, `n2`)
+        Transport plan, encoded as a sparse array.
+    lazy_plan : LazyTensor
+        Transport plan, encoded as a symbolic POT or KeOps LazyTensor.
+    status : int or str
+        Status of the solver.
+    batch_size : int
+        Batch size used to compute the results/marginals for LazyTensor.
+
+    Attributes
+    ----------
+
+    potentials : tuple of array-like, shape (`n1`, `n2`)
+        Dual potentials, i.e. Lagrange multipliers for the marginal constraints.
+        This pair of arrays has the same shape, numerical type
+        and properties as the input weights "a" and "b".
+    potential_a : array-like, shape (`n1`,)
+        First dual potential, associated to the "source" measure "a".
+    potential_b : array-like, shape (`n2`,)
+        Second dual potential, associated to the "target" measure "b".
+    value : float, array-like
+        Full transport cost, including possible regularization terms and
+        quadratic term for Gromov Wasserstein solutions.
+    value_linear : float, array-like
+        The linear part of the transport cost, i.e. the product between the
+        transport plan and the cost.
+    value_quad : float, array-like
+        The quadratic part of the transport cost for Gromov-Wasserstein
+        solutions.
+    plan : array-like, shape (`n1`, `n2`)
+        Transport plan, encoded as a dense array.
+    sparse_plan : array-like, shape (`n1`, `n2`)
+        Transport plan, encoded as a sparse array.
+    lazy_plan : LazyTensor
+        Transport plan, encoded as a symbolic POT or KeOps LazyTensor.
+    marginals : tuple of array-like, shape (`n1`,), (`n2`,)
+        Marginals of the transport plan: should be very close to "a" and "b"
+        for balanced OT.
+    marginal_a : array-like, shape (`n1`,)
+        Marginal of the transport plan for the "source" measure "a".
+    marginal_b : array-like, shape (`n2`,)
+        Marginal of the transport plan for the "target" measure "b".
+
+    """
+
+    def __init__(self, potentials=None, value=None, value_linear=None, value_quad=None, plan=None, log=None, backend=None, sparse_plan=None, lazy_plan=None, status=None, batch_size=100):
 
         self._potentials = potentials
         self._value = value
@@ -773,6 +957,7 @@ def __init__(self, potentials=None, value=None, value_linear=None, value_quad=No
         self._lazy_plan = lazy_plan
         self._backend = backend if backend is not None else NumpyBackend()
         self._status = status
+        self._batch_size = batch_size
 
         # I assume that other solvers may return directly
         # some primal objects?
@@ -793,7 +978,8 @@ def __repr__(self):
             s += 'value_linear={},'.format(self._value_linear)
         if self._plan is not None:
             s += 'plan={}(shape={}),'.format(self._plan.__class__.__name__, self._plan.shape)
-
+        if self._lazy_plan is not None:
+            s += 'lazy_plan={}(shape={}),'.format(self._lazy_plan.__class__.__name__, self._lazy_plan.shape)
         if s[-1] != '(':
             s = s[:-1] + ')'
         else:
@@ -853,7 +1039,10 @@ def sparse_plan(self):
 
     @property
     def lazy_plan(self):
         """Transport plan, encoded as a symbolic KeOps LazyTensor."""
-        raise NotImplementedError()
+        if self._lazy_plan is not None:
+            return self._lazy_plan
+        else:
+            raise NotImplementedError()
 
     # Loss values --------------------------------
@@ -897,6 +1086,11 @@ def marginal_a(self):
         """First marginal of the transport plan, with the same shape as "a"."""
         if self._plan is not None:
             return self._backend.sum(self._plan, 1)
+        elif self._lazy_plan is not None:
+            lp = self._lazy_plan
+            bs = self._batch_size
+            nx = self._backend
+            return reduce_lazytensor(lp, nx.sum, axis=1, nx=nx, batch_size=bs)
         else:
             raise NotImplementedError()
 
@@ -905,6 +1099,11 @@ def marginal_b(self):
         """Second marginal of the transport plan, with the same shape as "b"."""
         if self._plan is not None:
             return self._backend.sum(self._plan, 0)
+        elif self._lazy_plan is not None:
+            lp = self._lazy_plan
+            bs = self._batch_size
+            nx = self._backend
+            return reduce_lazytensor(lp, nx.sum, axis=0, nx=nx, batch_size=bs)
         else:
             raise NotImplementedError()
 
@@ -968,3 +1167,70 @@ def citation(self):
             url = {http://jmlr.org/papers/v22/20-451.html}
         }
         """
+
+
+class LazyTensor(object):
+    """ A lazy tensor is a tensor that is not stored in memory. Instead, it is
+    defined by a function that computes its values on the fly from slices.
+
+    Parameters
+    ----------
+
+    shape : tuple
+        shape of the tensor
+    getitem : callable
+        function that computes the values of the tensor from the indices/slices
+        and the tensors passed as keyword arguments
+
+    kwargs : dict
+        named arguments for the function, those names will be used as attributes
+        of the LazyTensor object
+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> v = np.arange(5)
+    >>> def getitem(i,j, v):
+    ...     return v[i,None]+v[None,j]
+    >>> T = LazyTensor((5,5),getitem, v=v)
+    >>> T[1,2]
+    array([3])
+    >>> T[1,:]
+    array([[1, 2, 3, 4, 5]])
+    >>> T[:]
+    array([[0, 1, 2, 3, 4],
+           [1, 2, 3, 4, 5],
+           [2, 3, 4, 5, 6],
+           [3, 4, 5, 6, 7],
+           [4, 5, 6, 7, 8]])
+
+    """
+
+    def __init__(self, shape, getitem, **kwargs):
+
+        self._getitem = getitem
+        self.shape = shape
+        self.ndim = len(shape)
+        self.kwargs = kwargs
+
+        # set attributes for named arguments/arrays
+        for key, value in kwargs.items():
+            setattr(self, key, value)
+
+    def __getitem__(self, key):
+        k = []
+        if isinstance(key, int) or isinstance(key, slice):
+            k.append(key)
+            for i in range(self.ndim - 1):
+                k.append(slice(None))
+        elif isinstance(key, tuple):
+            k = list(key)
+            for i in range(self.ndim - len(key)):
+                k.append(slice(None))
+        else:
+            raise NotImplementedError("Only integer, slice, and tuple indexing is supported")
+
+        return self._getitem(*k, **self.kwargs)
+
+    def __repr__(self):
+        return "LazyTensor(shape={},attributes=({}))".format(self.shape, ','.join(self.kwargs.keys()))
diff --git a/test/test_factored.py b/test/test_factored.py
index fd2fd0133..5cfc997ef 100644
--- a/test/test_factored.py
+++ b/test/test_factored.py
@@ -28,6 +28,7 @@ def test_factored_ot():
     # check constraints
     np.testing.assert_allclose(u, Ga.sum(1))
     np.testing.assert_allclose(u, Gb.sum(0))
+    np.testing.assert_allclose(1, log['lazy_plan'][:].sum())
 
 
 def test_factored_ot_backends(nx):
diff --git a/test/test_utils.py b/test/test_utils.py
index 40324518e..3a9d590ab 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -10,6 +10,27 @@
 import pytest
 
 
+def get_LazyTensor(nx):
+    n1 = 100
+    n2 = 200
+
+    rng = np.random.RandomState(42)
+    a = rng.rand(n1)
+    a /= a.sum()
+    b = rng.rand(n2)
+    b /= b.sum()
+
+    a, b = nx.from_numpy(a, b)
+
+    def getitem(i, j, a, b):
+        return a[i, None] * b[None, j]
+
+    # create a lazy tensor
+    T = ot.utils.LazyTensor((n1, n2), getitem, a=a, b=b)
+
+    return T, a, b
+
+
 def test_proj_simplex(nx):
     n = 10
     rng = np.random.RandomState(0)
@@ -389,6 +410,7 @@ def test_OTResult():
                       'value_linear', 'value_quad']
 
     for at in lst_attributes:
+        print(at)
         with pytest.raises(NotImplementedError):
             getattr(res, at)
 
@@ -401,3 +423,149 @@ def test_get_coordinate_circle():
     x_p = ot.utils.get_coordinate_circle(x)
 
     np.testing.assert_allclose(u[0], x_p)
+
+
+def test_LazyTensor(nx):
+
+    n1 = 100
+    n2 = 200
+    shape = (n1, n2)
+
+    rng = np.random.RandomState(42)
+    x1 = rng.randn(n1, 2)
+    x2 = rng.randn(n2, 2)
+
+    x1, x2 = nx.from_numpy(x1, x2)
+
+    # i,j can be integers or slices, x1,x2 have to be passed as keyword arguments
+    def getitem(i, j, x1, x2):
+        return nx.dot(x1[i], x2[j].T)
+
+    # create a lazy tensor
+    T = ot.utils.LazyTensor((n1, n2), getitem, x1=x1, x2=x2)
+
+    assert T.shape == (n1, n2)
+    assert str(T) == "LazyTensor(shape=(100, 200),attributes=(x1,x2))"
+
+    assert T.x1 is x1
+    assert T.x2 is x2
+
+    # get the full tensor (not lazy)
+    assert T[:].shape == shape
+
+    # get one component
+    assert T[1, 1] == nx.dot(x1[1], x2[1].T)
+
+    # get one row
+    assert T[1].shape == (n2,)
+
+    # get one column with slices
+    assert T[::10, 5].shape == (10,)
+
+    with pytest.raises(NotImplementedError):
+        T["error"]
+
+
+def test_OTResult_LazyTensor(nx):
+
+    T, a, b = get_LazyTensor(nx)
+
+    res = ot.utils.OTResult(lazy_plan=T, batch_size=9, backend=nx)
+
+    np.testing.assert_allclose(nx.to_numpy(a), nx.to_numpy(res.marginal_a))
+    np.testing.assert_allclose(nx.to_numpy(b), nx.to_numpy(res.marginal_b))
+
+
+def test_LazyTensor_reduce(nx):
+
+    T, a, b = get_LazyTensor(nx)
+
+    T0 = T[:]
+    s0 = nx.sum(T0)
+
+    # total sum
+    s = ot.utils.reduce_lazytensor(T, nx.sum, nx=nx)
+    np.testing.assert_allclose(nx.to_numpy(s), 1)
+    np.testing.assert_allclose(nx.to_numpy(s), nx.to_numpy(s0))
+
+    s2 = ot.utils.reduce_lazytensor(T, nx.sum)
+    np.testing.assert_allclose(nx.to_numpy(s), nx.to_numpy(s2))
+
+    s2 = ot.utils.reduce_lazytensor(T, nx.sum, batch_size=500)
+    np.testing.assert_allclose(nx.to_numpy(s), nx.to_numpy(s2))
+
+    s2 = ot.utils.reduce_lazytensor(T, nx.sum, batch_size=11)
+    np.testing.assert_allclose(nx.to_numpy(s), nx.to_numpy(s2))
+
+    # sum over axis 0
+    s = ot.utils.reduce_lazytensor(T, nx.sum, axis=0, nx=nx)
+    np.testing.assert_allclose(nx.to_numpy(s), nx.to_numpy(b))
+
+    # sum over axis 1
+    s = ot.utils.reduce_lazytensor(T, nx.sum, axis=1, nx=nx)
+    np.testing.assert_allclose(nx.to_numpy(s), nx.to_numpy(a))
+
+    # test other reduction functions
+    s = ot.utils.reduce_lazytensor(T, nx.logsumexp, axis=1, nx=nx)
+    s2 = nx.logsumexp(T[:], axis=1)
+    np.testing.assert_allclose(nx.to_numpy(s), nx.to_numpy(s2))
+
+    # test 3D tensors
+    def getitem(i, j, k, a, b, c):
+        return a[i, None, None] * b[None, j, None] * c[None, None, k]
+
+    # create a lazy tensor
+    n = a.shape[0]
+    T = ot.utils.LazyTensor((n, n, n), getitem, a=a, b=a, c=a)
+
+    # sums over axis 0 and 1 (equal since the tensor is symmetric)
+    s1 = ot.utils.reduce_lazytensor(T, nx.sum, axis=0, nx=nx)
+    s2 = ot.utils.reduce_lazytensor(T, nx.sum, axis=1, nx=nx)
+
+    np.testing.assert_allclose(nx.to_numpy(s1), nx.to_numpy(s2))
+
+    with pytest.raises(NotImplementedError):
+        ot.utils.reduce_lazytensor(T, nx.sum, axis=2, nx=nx, batch_size=10)
+
+
+def test_lowrank_LazyTensor(nx):
+
+    p = 5
+    n1 = 100
+    n2 = 200
+
+    shape = (n1, n2)
+
+    rng = np.random.RandomState(42)
+    X1 = rng.randn(n1, p)
+    X2 = rng.randn(n2, p)
+    diag_d = rng.rand(p)
+
+    X1, X2, diag_d = nx.from_numpy(X1, X2, diag_d)
+
+    T0 = nx.dot(X1, X2.T)
+
+    T = ot.utils.get_lowrank_lazytensor(X1, X2)
+
+    np.testing.assert_allclose(nx.to_numpy(T[:]), nx.to_numpy(T0))
+
+    assert T.Q is X1
+    assert T.R is X2
+
+    # get the full tensor (not lazy)
+    assert T[:].shape == shape
+
+    # get one component
+    assert T[1, 1] == nx.dot(X1[1], X2[1].T)
+
+    # get one row
+    assert T[1].shape == (n2,)
+
+    # get one column with slices
+    assert T[::10, 5].shape == (10,)
+
+    T0 = nx.dot(X1 * diag_d[None, :], X2.T)
+
+    T = ot.utils.get_lowrank_lazytensor(X1, X2, diag_d, nx=nx)
+
+    np.testing.assert_allclose(nx.to_numpy(T[:]), nx.to_numpy(T0))
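
Usage note (not part of the patch): a minimal sketch of how the new `ot.utils.LazyTensor` and `ot.utils.reduce_lazytensor` are meant to be used. It mirrors `get_LazyTensor` and `test_LazyTensor_reduce` above and assumes the NumPy backend from `ot.backend`; the sizes, seed and the `getitem` closure are illustrative.

```python
# Sketch (not part of the diff): a lazy outer-product tensor whose entries are
# computed on demand from slices, then reduced in batches without densifying it.
import numpy as np
import ot
from ot.backend import NumpyBackend

nx = NumpyBackend()

n1, n2 = 100, 200
rng = np.random.RandomState(42)
a = rng.rand(n1)
a /= a.sum()
b = rng.rand(n2)
b /= b.sum()


def getitem(i, j, a, b):
    # i, j are integers or slices; a, b are the arrays registered below
    return a[i, None] * b[None, j]


T = ot.utils.LazyTensor((n1, n2), getitem, a=a, b=b)  # nothing materialized yet

total = ot.utils.reduce_lazytensor(T, nx.sum, nx=nx)             # global sum, ~1.0
row_sums = ot.utils.reduce_lazytensor(T, nx.sum, axis=1, nx=nx)  # ~a
col_sums = ot.utils.reduce_lazytensor(T, nx.sum, axis=0, nx=nx)  # ~b

np.testing.assert_allclose(row_sums, a)
np.testing.assert_allclose(col_sums, b)
```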
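The same lazy plan can be attached to an `OTResult`: the `marginal_a`/`marginal_b` properties now fall back to batched reductions when only `lazy_plan` is set (see the `marginal_a`/`marginal_b` hunks above), with `batch_size` controlling the slice size. A sketch continuing the variables from the previous snippet:

```python
# Sketch continuing the previous snippet: marginals from a result that only
# carries a lazy plan; no dense (n1, n2) plan is ever built.
res = ot.utils.OTResult(lazy_plan=T, batch_size=16, backend=nx)

np.testing.assert_allclose(res.marginal_a, a)  # computed via reduce_lazytensor(axis=1)
np.testing.assert_allclose(res.marginal_b, b)  # computed via reduce_lazytensor(axis=0)
```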
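Finally, `get_lowrank_lazytensor` wraps a pair of factors (optionally with a diagonal weight) as a lazy tensor; the `ot/factored.py` hunk above uses it to expose `log['lazy_plan']` without densifying the factored plan. A sketch mirroring `test_lowrank_LazyTensor`, with illustrative factor sizes:

```python
# Sketch (not part of the diff): low rank lazy tensors T = Q @ R^T and Q @ diag(d) @ R^T.
import numpy as np
import ot
from ot.backend import NumpyBackend

nx = NumpyBackend()

n1, n2, r = 100, 200, 5
rng = np.random.RandomState(0)
Q = rng.rand(n1, r)
R = rng.rand(n2, r)
d = rng.rand(r)

T = ot.utils.get_lowrank_lazytensor(Q, R, nx=nx)      # T = Q @ R^T, stored as factors only
assert T[1].shape == (n2,)                            # a single row, computed on the fly
np.testing.assert_allclose(T[:], Q @ R.T)             # dense check, fine at this size

Td = ot.utils.get_lowrank_lazytensor(Q, R, d, nx=nx)  # T = Q @ diag(d) @ R^T
np.testing.assert_allclose(Td[:], (Q * d[None, :]) @ R.T)
```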