diff --git a/.github/workflows/build_tests.yml b/.github/workflows/build_tests.yml
index b7ece015a..2ffc90786 100644
--- a/.github/workflows/build_tests.yml
+++ b/.github/workflows/build_tests.yml
@@ -18,7 +18,7 @@ jobs:
     strategy:
       max-parallel: 4
       matrix:
-        python-version: [3.5, 3.6, 3.7, 3.8]
+        python-version: [3.6, 3.7, 3.8]

     steps:
     - uses: actions/checkout@v1
diff --git a/ot/torch/lp.py b/ot/torch/lp.py
index 627a40294..c39df05f2 100644
--- a/ot/torch/lp.py
+++ b/ot/torch/lp.py
@@ -5,7 +5,9 @@
 import numpy as np
 import torch
 from torch.autograd import Function
-from .. import emd
+from ot import emd
+from torch.nn.functional import pad
+from ot.torch.utils import quantile_function

 # Author: Remi Flamary

@@ -20,7 +22,6 @@ class OptimalTransportLossFunction(Function):
     @staticmethod
     # bias is an optional argument
     def forward(ctx, a, b, M, num_iter_max=100000):
-
         # convert to numpy
         a2 = a.detach().cpu().numpy().astype(np.float64)
         b2 = b.detach().cpu().numpy().astype(np.float64)
@@ -42,7 +43,6 @@ def forward(ctx, a, b, M, num_iter_max=100000):

     @staticmethod
     def backward(ctx, grad_output):
-
         grad_a, grad_b, grad_M = ctx.saved_tensors
         print(grad_a)

@@ -56,7 +56,6 @@ def ot_loss(a, b, M, num_iter_max=100000):


 def ot_solve(a, b, M, num_iter_max=100000, log=False):
-
     a2 = a.detach().cpu().numpy().astype(np.float64)
     b2 = b.detach().cpu().numpy().astype(np.float64)
     M2 = M.detach().cpu().numpy().astype(np.float64)
@@ -79,3 +78,96 @@ def ot_solve(a, b, M, num_iter_max=100000, log=False):

         G = emd(a2, b2, M2, log=False, numItermax=num_iter_max)
         return torch.from_numpy(G).type_as(M)
+
+
+def ot_loss_1d(u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True):
+    r"""
+    Computes the 1 dimensional OT loss [2] between two (batched) empirical distributions
+
+    .. math::
+        ot_{loss} = \int_0^1 |cdf_u^{-1}(q) - cdf_v^{-1}(q)|^p dq
+
+    It is formally the p-Wasserstein distance raised to the power p.
+    We do so in a vectorized way by first building the individual quantile functions and then integrating them.
+    This has a theoretically higher complexity than the core OT implementation but behaves better with PyTorch.
+
+    Parameters
+    ----------
+    u_values: torch.Tensor (n, ...)
+        locations of the first empirical distribution
+    v_values: torch.Tensor (m, ...)
+        locations of the second empirical distribution
+    u_weights: torch.Tensor (n, ...), optional
+        weights of the first empirical distribution, if None then uniform weights are used
+    v_weights: torch.Tensor (m, ...), optional
+        weights of the second empirical distribution, if None then uniform weights are used
+    p: int, optional
+        order of the ground metric used, should be at least 1 (see [2, Chap. 2]), default is 1
+    require_sort: bool, optional
+        sort the distributions atoms locations, if False they are assumed to have been sorted prior to being
+        passed to the function, default is True
+
+    Returns
+    -------
+    cost: torch.Tensor (...,)
+        the batched EMD
+
+    Examples
+    --------
+    Simple example:
+
+    >>> import numpy as np
+    >>> import ot
+    >>> import torch
+    >>> np.random.seed(0)
+    >>> n_source = 7
+    >>> n_target = 100
+    >>> a = torch.tensor(ot.utils.unif(n_source), requires_grad=True)
+    >>> b = torch.tensor(ot.utils.unif(n_target))
+    >>> X_source = torch.tensor(np.random.randn(n_source,), requires_grad=True)
+    >>> Y_target = torch.tensor(np.random.randn(n_target,))
+    >>> loss = ot.torch.lp.ot_loss_1d(X_source, Y_target, a, b)
+    >>> torch.autograd.grad(loss, X_source)[0]
+    tensor([0.1429, 0.1429, 0.1429, 0.1229, 0.1429, 0.1429, 0.1429],
+           dtype=torch.float64)
+
+    References
+    ----------
+    .. [2] Cuturi, M. (2013). [Sinkhorn distances: Lightspeed computation of optimal transport](https://arxiv.org/pdf/1306.0895.pdf). In Advances in Neural Information Processing Systems (pp. 2292-2300).
+
+    """
+    assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p)
+    n = u_values.shape[0]
+    m = v_values.shape[0]
+
+    device = u_values.device
+    dtype = u_values.dtype
+
+    if u_weights is None:
+        u_weights = torch.full_like(u_values, 1 / n, dtype=dtype, device=device)
+    elif u_weights.ndim != u_values.ndim:
+        u_weights = torch.repeat_interleave(u_weights.unsqueeze(-1), u_values.shape[-1], -1)
+
+    if v_weights is None:
+        v_weights = torch.full_like(v_values, 1 / m, dtype=dtype, device=device)
+    elif v_weights.ndim != v_values.ndim:
+        v_weights = torch.repeat_interleave(v_weights.unsqueeze(-1), v_values.shape[-1], -1)
+
+    if require_sort:
+        u_values, u_sorter = torch.sort(u_values, 0)
+        v_values, v_sorter = torch.sort(v_values, 0)
+
+        u_weights = torch.gather(u_weights, 0, u_sorter)
+        v_weights = torch.gather(v_weights, 0, v_sorter)
+
+    u_cumweights = torch.cumsum(u_weights, 0)
+    v_cumweights = torch.cumsum(v_weights, 0)
+
+    qs, _ = torch.sort(torch.cat((u_cumweights, v_cumweights), 0), 0)
+    u_quantiles = quantile_function(qs, u_cumweights, u_values)
+    v_quantiles = quantile_function(qs, v_cumweights, v_values)
+
+    qs = pad(qs, (qs.ndim - 1) * (0, 0) + (1, 0))
+    delta = qs[1:, ...] - qs[:-1, ...]
+    diff_quantiles = torch.abs(u_quantiles - v_quantiles)
+
+    if p == 1:
+        return torch.sum(delta * torch.abs(diff_quantiles), dim=0)
+    return torch.sum(delta * torch.pow(diff_quantiles, p), dim=0)
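For illustration, a minimal sketch (not part of the patch; sizes and seed are made up) of how the new `ot_loss_1d` can be cross-checked against the exact per-column 1D solver `ot.lp.emd2_1d`, mirroring the comparison made in the tests further down:

    import numpy as np
    import torch
    import ot
    import ot.torch

    torch.manual_seed(0)
    x = torch.randn(50, 3, dtype=torch.float64)  # 50 atoms, 3 independent 1D problems
    y = torch.randn(80, 3, dtype=torch.float64)  # weights default to uniform when omitted

    # vectorized quantile-based loss: one W_2^2 value per trailing column
    w22 = ot.torch.lp.ot_loss_1d(x, y, p=2)

    # reference: solve each column separately with the exact 1D solver
    ref = np.array([ot.lp.emd2_1d(x[:, i].numpy(), y[:, i].numpy(), metric='minkowski', p=2)
                    for i in range(3)])
    assert np.allclose(w22.numpy(), ref, atol=1e-6)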
diff --git a/ot/torch/utils.py b/ot/torch/utils.py
index 4b349c08b..6225be5e4 100644
--- a/ot/torch/utils.py
+++ b/ot/torch/utils.py
@@ -58,7 +58,7 @@ def dist(x1, x2, metric="sqeuclidean"):
     if x2 is None:
         x2 = x1
     if metric == "sqeuclidean":
-        return torch.cdist(x1, x2, p=2)**2
+        return torch.cdist(x1, x2, p=2) ** 2
     elif metric == "euclidean":
         p = 2
     elif metric == "cityblock":
@@ -89,3 +89,29 @@ def proj_simplex(v, z=1):
         return w[:, 0]
     else:
         return w
+
+
+def quantile_function(qs, cws, xs):
+    # type: (torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
+    r""" Computes the quantile function of an empirical distribution
+
+    Parameters
+    ----------
+    qs: torch.tensor (k, ...)
+        quantile levels at which the quantile function is evaluated
+    cws: torch.tensor (n, ...)
+        cumulative weights of the 1D empirical distribution, if batched, the batch dimensions must match those of xs
+    xs: torch.tensor (n, ...)
+        locations of the 1D empirical distribution, batched against the last `xs.ndim - 1` dimensions
+
+    Returns
+    -------
+    q: torch.tensor (k, ...)
+        the quantiles of the distribution at the levels `qs`
+    """
+    n = xs.shape[0]
+
+    cws = cws.T.contiguous()
+    qs = qs.T.contiguous()
+    idx = torch.searchsorted(cws, qs).T
+    return torch.gather(xs, 0, idx.clip(0, n - 1))
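For illustration, a small worked example of `quantile_function` with made-up numbers: for each level q it returns the first atom whose cumulative weight reaches q, i.e. the generalized inverse CDF that `ot_loss_1d` integrates:

    import torch
    import ot.torch

    # 1D empirical distribution with 4 sorted atoms and weights 0.1, 0.2, 0.3, 0.4
    xs = torch.tensor([[-1.0], [0.0], [1.0], [2.0]])    # locations, shape (n, 1)
    cws = torch.tensor([[0.1], [0.3], [0.6], [1.0]])    # cumulative weights, shape (n, 1)
    qs = torch.tensor([[0.05], [0.25], [0.5], [0.95]])  # quantile levels, shape (k, 1)

    print(ot.torch.utils.quantile_function(qs, cws, xs).squeeze())
    # tensor([-1.,  0.,  1.,  2.])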
diff --git a/test/test_torch.py b/test/test_torch.py
index 25f9fa9ef..65bc6ee91 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -12,7 +12,6 @@
 import ot.torch
 import torch

-
 nogo = False

 lst_types = [torch.float32, torch.float64]
@@ -22,12 +21,10 @@ except BaseException:
-    nogo = True
+    pytest.skip("Missing pytorch", allow_module_level=True)


-@pytest.mark.skipif(nogo, reason="Missing pytorch")
 def test_dist():
-
     n = 200

     lst_metrics = ['sqeuclidean', 'euclidean', 'cityblock', 0, 0.5, 1, 2, 5]
@@ -39,16 +36,13 @@ def test_dist():
             y = torch.randn(n, 2, dtype=dtype, device=device)

             for metric in lst_metrics:
-
                 M = ot.torch.dist(x, y, metric)

                 assert M.shape[0] == n
                 assert M.shape[1] == n


-@pytest.mark.skipif(nogo, reason="Missing pytorch")
 def test_ot_loss():
-
     n = 10

     lst_metrics = ['sqeuclidean', 'euclidean', 'cityblock', 0, 0.5, 1, 2, 5]
@@ -63,21 +57,17 @@ def test_ot_loss():
             b = ot.torch.unif(n, dtype=dtype, device=device)

             for metric in lst_metrics:
-
                 M = ot.torch.dist(x, y, metric)

                 loss = ot.torch.ot_loss(a, b, M)

                 assert float(loss) >= 0


-@pytest.mark.skipif(nogo, reason="Missing pytorch")
 def test_proj_simplex():
-
     n = 10

     for dtype in lst_types:
        for device in lst_devices:
-
            x = torch.randn(n, dtype=dtype, device=device)
            xp = ot.torch.proj_simplex(x)
@@ -93,9 +83,7 @@ def test_proj_simplex():
         assert torch.allclose(xp.sum(0), torch.ones(3, dtype=dtype, device=device))


-@pytest.mark.skipif(nogo, reason="Missing pytorch")
 def test_ot_loss_grad():
-
     n = 10

     lst_metrics = ['sqeuclidean', 'euclidean', 'cityblock', 0, 0.5, 1, 2, 5]
@@ -104,7 +92,6 @@ def test_ot_loss_grad():
     for dtype in lst_types:
         for device in lst_devices:
             for metric in lst_metrics:
-
                 x = torch.randn(n, 2, dtype=dtype, device=device, requires_grad=True)
                 y = torch.randn(n, 2, dtype=dtype, device=device, requires_grad=True)
@@ -124,9 +111,7 @@ def test_ot_loss_grad():
                 assert float(loss) >= 0


-@pytest.mark.skipif(nogo, reason="Missing pytorch")
 def test_ot_solve():
-
     n = 10

     lst_metrics = ['sqeuclidean', 'euclidean', 'cityblock', 0, 0.5, 1, 2, 5]
@@ -141,9 +126,101 @@ def test_ot_solve():
             b = ot.torch.unif(n, dtype=dtype, device=device)

             for metric in lst_metrics:
-
                 M = ot.torch.dist(x, y, metric)
                 G = ot.torch.ot_solve(a, b, M)

                 np.testing.assert_allclose(ot.unif(n), G.sum(1).cpu().numpy())
                 np.testing.assert_allclose(ot.unif(n), G.sum(0).cpu().numpy())  # cf convergence sinkhorn
+
+
+@pytest.mark.parametrize("random_weights", [True, False])
+@pytest.mark.parametrize("batch_size", [0, 2, 10])
+def test_ot_loss_1d(random_weights, batch_size):
+    torch.random.manual_seed(42)
+    n = 300
+    m = 200
+    k = 5
+    ps = [1, 2, 3]
+
+    for dtype in lst_types:
+        for device in lst_devices:
+            if batch_size:
+                x = torch.randn(n, batch_size, k, dtype=dtype, device=device)
+                y = torch.randn(m, batch_size, k, dtype=dtype, device=device)
+            else:
+                x = torch.randn(n, k, dtype=dtype, device=device)
+                y = torch.randn(m, k, dtype=dtype, device=device)
+            if random_weights:
+                if batch_size:
+                    a = torch.rand(n, batch_size, dtype=dtype, device=device)
+                    b = torch.rand(m, batch_size, dtype=dtype, device=device)
+                else:
+                    a = torch.rand(n, dtype=dtype, device=device)
+                    b = torch.rand(m, dtype=dtype, device=device)
+                a = a / torch.sum(a, 0, keepdim=True)
+                b = b / torch.sum(b, 0, keepdim=True)
+                np_a = a.cpu().numpy()
+                np_b = b.cpu().numpy()
+            else:
+                a = b = np_a = np_b = None
+
+            for p in ps:
+                same_dist_cost = ot.torch.lp.ot_loss_1d(x, x, a, a, p)
+                assert np.allclose(same_dist_cost.cpu().numpy(), 0., atol=1e-5)
+                torch_cost = ot.torch.lp.ot_loss_1d(x, y, a, b, p)
+
+                if batch_size:
+                    cpu_cost = np.zeros((batch_size, k))
+                else:
+                    cpu_cost = np.zeros(k)
+
+                for i in range(k):
+                    if batch_size:
+                        for batch_num in range(batch_size):
+                            cpu_cost[batch_num, i] = ot.lp.emd2_1d(x[:, batch_num, i].cpu().numpy(),
+                                                                   y[:, batch_num, i].cpu().numpy(),
+                                                                   np_a if np_a is None else np_a[:, batch_num],
+                                                                   np_b if np_b is None else np_b[:, batch_num],
+                                                                   "minkowski", p=p)
+                    else:
+                        cpu_cost[i] = ot.lp.emd2_1d(x[:, i].cpu().numpy(), y[:, i].cpu().numpy(), np_a, np_b,
+                                                    "minkowski", p=p)
+
+                np.testing.assert_allclose(cpu_cost, torch_cost.cpu().numpy(), atol=1e-5)
+
+
+def test_ot_loss_1d_grad():
+    torch.random.manual_seed(42)
+    n = 10
+    m = 5
+    k = 5
+    ps = [1, 2, 3]
+
+    for dtype in lst_types:
+        for device in lst_devices:
+            x = torch.randn(n, k, dtype=dtype, device=device, requires_grad=True)
+            y = torch.randn(m, k, dtype=dtype, device=device, requires_grad=True)
+
+            a = torch.rand(n, dtype=dtype, device=device, requires_grad=True)
+            b = torch.rand(m, dtype=dtype, device=device, requires_grad=True)
+
+            for p in ps:
+                torch.autograd.gradcheck(lambda *inp: ot.torch.lp.ot_loss_1d(*inp, p=p), (x, y, a, b), eps=1e-3,
+                                         atol=1e-2, raise_exception=True)
+
+                res_equal = ot.torch.lp.ot_loss_1d(x, x, a, a, p=p).sum()
+                print(torch.autograd.grad(res_equal, (x, a)))
+
+
+@pytest.mark.filterwarnings("error")
+def test_quantile():
+    torch.random.manual_seed(42)
+    dims = (100, 5, 3)
+    cws = torch.rand(*dims)
+    cws = cws / cws.sum(0, keepdim=True)
+    cws = torch.cumsum(cws, 0)
+    qs, _ = torch.sort(torch.rand(*dims), dim=0)
+    xs = torch.randn(*dims)
+    res = ot.torch.utils.quantile_function(qs, cws, xs)
+    assert np.all(res.cpu().numpy() <= xs.max(0, keepdim=True)[0].cpu().numpy())
+    assert np.all(res.cpu().numpy() >= xs.min(0, keepdim=True)[0].cpu().numpy())
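For illustration, a hand-checkable instance (made-up values) of the quantile integral \int_0^1 |cdf_u^{-1}(q) - cdf_v^{-1}(q)|^p dq that `ot_loss_1d` implements:

    import torch
    import ot.torch

    # two uniform two-atom distributions on the line
    u = torch.tensor([[0.0], [1.0]], dtype=torch.float64)
    v = torch.tensor([[0.5], [2.0]], dtype=torch.float64)

    # both quantile functions are step functions jumping at q = 0.5, so
    # W_1 = 0.5 * |0.0 - 0.5| + 0.5 * |1.0 - 2.0| = 0.75
    print(ot.torch.lp.ot_loss_1d(u, v, p=1))  # tensor([0.7500], dtype=torch.float64)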