From 1e1ff3f17cee7469a0b587a917133f286d5d0c3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 7 Oct 2025 09:27:21 +0200 Subject: [PATCH] Revert "[WIP] Sliced OT Plans (#757)" This reverts commit 07d6bfa74d39e2e784552a7b8ac6593f1331b430. --- README.md | 6 +- RELEASES.md | 6 - .../sliced-wasserstein/plot_sliced_plans.py | 168 ----------- ot/__init__.py | 4 - ot/sliced.py | 263 +----------------- test/test_sliced.py | 139 --------- 6 files changed, 3 insertions(+), 583 deletions(-) delete mode 100644 examples/sliced-wasserstein/plot_sliced_plans.py diff --git a/README.md b/README.md index ed6f2a89c..f8880a166 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,6 @@ POT provides the following generic OT solvers: * Fused unbalanced Gromov-Wasserstein [70]. * [Optimal Transport Barycenters for Generic Costs](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter_generic_cost.html) [77] * [Barycenters between Gaussian Mixture Models](https://pythonot.github.io/auto_examples/barycenters/plot_gmm_barycenter.html) [69, 77] -* [Sliced Optimal Transport Plans](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_sliced_plans.html) [82, 83, 84] POT provides the following Machine Learning related solvers: @@ -450,8 +449,5 @@ Artificial Intelligence. [81] Xu, H., Luo, D., & Carin, L. (2019). [Scalable Gromov-Wasserstein learning for graph partitioning and matching](https://proceedings.neurips.cc/paper/2019/hash/6e62a992c676f611616097dbea8ea030-Abstract.html). Neural Information Processing Systems (NeurIPS). -[82] Mahey, G., Chapel, L., Gasso, G., Bonet, C., & Courty, N. (2023). [Fast Optimal Transport through Sliced Generalized Wasserstein Geodesics](https://proceedings.neurips.cc/paper_files/paper/2023/hash/6f1346bac8b02f76a631400e2799b24b-Abstract-Conference.html). Advances in Neural Information Processing Systems, 36, 35350–35385. -[83] Tanguy, E., Chapel, L., Delon, J. (2025). [Sliced Optimal Transport Plans](https://arxiv.org/abs/2508.01243) arXiv preprint 2506.03661. - -[84] Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M., Aldroubi A., Kolouri, S. (2024). [Expected Sliced Transport Plans](https://openreview.net/forum?id=P7O1Vt1BdU). International Conference on Learning Representations. +``` diff --git a/RELEASES.md b/RELEASES.md index eee4257af..ccb9b97d2 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,11 +1,5 @@ # Releases -## 0.9.7dev - -#### New features - -- Added Sliced OT plans (PR #757) - ## 0.9.6.post1 *September 2025* diff --git a/examples/sliced-wasserstein/plot_sliced_plans.py b/examples/sliced-wasserstein/plot_sliced_plans.py deleted file mode 100644 index ca7b35a3f..000000000 --- a/examples/sliced-wasserstein/plot_sliced_plans.py +++ /dev/null @@ -1,168 +0,0 @@ -# -*- coding: utf-8 -*- -""" -=============== -Sliced OT Plans -=============== - -Compares different Sliced OT plans between two 2D point clouds. The min-Pivot -Sliced plan was introduced in [82], and the Expected Sliced plan in [84], both -were further studied theoretically in [83]. - -.. [82] Mahey, G., Chapel, L., Gasso, G., Bonet, C., & Courty, N. (2023). Fast Optimal Transport through Sliced Generalized Wasserstein Geodesics. Advances in Neural Information Processing Systems, 36, 35350–35385. - -.. [83] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport Plans. arXiv preprint 2506.03661. - -.. [84] Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M., Aldroubi A., Kolouri, S. (2024). Expected Sliced Transport Plans. International Conference on Learning Representations. -""" - -# Author: Eloi Tanguy -# License: MIT License - -# sphinx_gallery_thumbnail_number = 1 - -############################################################################## -# Setup data and imports -# ---------------------- -import numpy as np -import ot -import matplotlib.pyplot as plt -from ot.sliced import get_random_projections - -seed = 0 -np.random.seed(seed) -n = 10 -d = 2 -X = np.random.randn(n, 2) -Y = np.random.randn(n, 2) + np.array([5.0, 0.0])[None, :] -n_proj = 20 -thetas = get_random_projections(d, n_proj).T -alpha = 0.3 - -############################################################################## -# Compute min-Pivot Sliced permutation -# ------------------------------------ -min_perm, min_cost, log_min = ot.min_pivot_sliced(X, Y, thetas, log=True) -min_plan = np.zeros((n, n)) -min_plan[np.arange(n), min_perm] = 1 / n - -############################################################################## -# Compute Expected Sliced Plan -# ------------------------------------ -expected_plan, expected_cost, log_expected = ot.expected_sliced(X, Y, thetas, log=True) - -############################################################################## -# Compute 2-Wasserstein Plan -# ------------------------------------ -a = np.ones(n, device=X.device) / n -dists = ot.dist(X, Y) -W2 = ot.emd2(a, a, dists) -W2_plan = ot.emd(a, a, dists) - -############################################################################## -# Plot resulting assignments -# ------------------------------------ -fig, axs = plt.subplots(2, 3, figsize=(12, 4)) -fig.suptitle("Sliced plans comparison", y=0.95, fontsize=16) - -# draw min sliced permutation -axs[0, 0].set_title(f"Min Pivot Sliced: cost={min_cost:.2f}") -for i in range(n): - axs[0, 0].plot( - [X[i, 0], Y[min_perm[i], 0]], - [X[i, 1], Y[min_perm[i], 1]], - color="black", - alpha=alpha, - label="min-Sliced perm" if i == 0 else None, - ) -axs[1, 0].imshow(min_plan, interpolation="nearest", cmap="Blues") - -# draw expected sliced plan -axs[0, 1].set_title(f"Expected Sliced: cost={expected_cost:.2f}") -for i in range(n): - for j in range(n): - w = alpha * expected_plan[i, j].item() * n - axs[0, 1].plot( - [X[i, 0], Y[j, 0]], - [X[i, 1], Y[j, 1]], - color="black", - alpha=w, - label="Expected Sliced plan" if i == 0 and j == 0 else None, - ) -axs[1, 1].imshow(expected_plan, interpolation="nearest", cmap="Blues") - -# draw W2 plan -axs[0, 2].set_title(f"W2: cost={W2:.2f}") -for i in range(n): - for j in range(n): - w = alpha * W2_plan[i, j].item() * n - axs[0, 2].plot( - [X[i, 0], Y[j, 0]], - [X[i, 1], Y[j, 1]], - color="black", - alpha=w, - label="W2 plan" if i == 0 and j == 0 else None, - ) -axs[1, 2].imshow(W2_plan, interpolation="nearest", cmap="Blues") - -for ax in axs[0, :]: - ax.scatter(X[:, 0], X[:, 1], label="X") - ax.scatter(Y[:, 0], Y[:, 1], label="Y") - -for ax in axs.flatten(): - ax.set_aspect("equal") - ax.set_xticks([]) - ax.set_yticks([]) - -fig.tight_layout() - -############################################################################## -# Compare Expected Sliced plans with different inverse-temperatures beta -# ------------------------------------ -## As the temperature decreases, ES becomes sparser and approaches minPS -betas = [0.0, 5.0, 50.0] -n_plots = len(betas) + 1 -size = 4 -fig, axs = plt.subplots(2, n_plots, figsize=(size * n_plots, size)) -fig.suptitle( - "Expected Sliced plan varying beta (inverse temperature)", y=0.95, fontsize=16 -) -for beta_idx, beta in enumerate(betas): - expected_plan, expected_cost = ot.expected_sliced(X, Y, thetas, beta=beta) - print(f"beta={beta}: cost={expected_cost:.2f}") - - axs[0, beta_idx].set_title(f"beta={beta}: cost={expected_cost:.2f}") - for i in range(n): - for j in range(n): - w = alpha * expected_plan[i, j].item() * n - axs[0, beta_idx].plot( - [X[i, 0], Y[j, 0]], - [X[i, 1], Y[j, 1]], - color="black", - alpha=w, - label="Expected Sliced plan" if i == 0 and j == 0 else None, - ) - - axs[0, beta_idx].scatter(X[:, 0], X[:, 1], label="X") - axs[0, beta_idx].scatter(Y[:, 0], Y[:, 1], label="Y") - axs[1, beta_idx].imshow(expected_plan, interpolation="nearest", cmap="Blues") - -# draw min sliced permutation (limit when beta -> +inf) -axs[0, -1].set_title(f"Min Pivot Sliced: cost={min_cost:.2f}") -for i in range(n): - axs[0, -1].plot( - [X[i, 0], Y[min_perm[i], 0]], - [X[i, 1], Y[min_perm[i], 1]], - color="black", - alpha=alpha, - label="min-Sliced perm" if i == 0 else None, - ) -axs[0, -1].scatter(X[:, 0], X[:, 1], label="X") -axs[0, -1].scatter(Y[:, 0], Y[:, 1], label="Y") -axs[1, -1].imshow(min_plan, interpolation="nearest", cmap="Blues") - -for ax in axs.flatten(): - ax.set_aspect("equal") - ax.set_xticks([]) - ax.set_yticks([]) - -fig.tight_layout() diff --git a/ot/__init__.py b/ot/__init__.py index 7f1d8152f..235fb91b4 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -58,8 +58,6 @@ sliced_wasserstein_sphere, sliced_wasserstein_sphere_unif, linear_sliced_wasserstein_sphere, - min_pivot_sliced, - expected_sliced, ) from .gromov import ( gromov_wasserstein, @@ -111,8 +109,6 @@ "sliced_wasserstein_distance", "sliced_wasserstein_sphere", "linear_sliced_wasserstein_sphere", - "min_pivot_sliced", - "expected_sliced", "gromov_wasserstein", "gromov_wasserstein2", "gromov_barycenters", diff --git a/ot/sliced.py b/ot/sliced.py index cc79fa968..3cf2002e7 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -1,17 +1,17 @@ """ Sliced OT Distances + """ # Author: Adrien Corenflos # Nicolas Courty # Rémi Flamary -# Eloi Tanguy # # License: MIT License import numpy as np from .backend import get_backend, NumpyBackend -from .utils import list_to_array, get_coordinate_circle, dist +from .utils import list_to_array, get_coordinate_circle from .lp import ( wasserstein_circle, semidiscrete_wasserstein2_unif_circle, @@ -674,262 +674,3 @@ def linear_sliced_wasserstein_sphere( if log: return res, {"projections": projections, "projected_emds": projected_lcot} return res - - -def sliced_permutations(X, Y, thetas=None, n_proj=None, log=False, backend=None): - r""" - Computes all the permutations that sort the projections of two `(n, d)` - datasets `X` and `Y` on the directions `thetas`. - Each permutation `perm[:, k]` is such that each `X[i, :]` is matched - to `Y[perm[i, k], :]` when projected on `thetas[k, :]`. - - Parameters - ---------- - X : array-like, shape (n, d) - The first set of vectors. - Y : array-like, shape (n, d) - The second set of vectors. - thetas : array-like, shape (n_proj, d), optional - The projection directions. If None, random directions will be generated. - Default is None. - n_proj : int, optional - The number of projection directions. Required if thetas is None. - log : bool, optional - If True, returns additional logging information. Default is False. - backend : ot.backend, optional - Backend to use for computations. If None, the backend is inferred from the input arrays. Default is None. - - Returns - ------- - perm : array-like, shape (n, n_proj) - All sliced permutations. - log_dict : dict, optional - A dictionary containing intermediate computations for logging purposes. - Returned only if `log` is True. - """ - assert ( - X.shape == Y.shape - ), f"X ({X.shape}) and Y ({Y.shape}) must have the same shape" - nx = get_backend(X, Y) if backend is None else backend - d = X.shape[1] - do_draw_thetas = thetas is None - if do_draw_thetas: # create thetas (n_proj, d) - thetas = get_random_projections(d, n_proj, backend=nx).T - - # project on each theta: (n, d) -> (n, n_proj) - X_theta = X @ thetas.T # shape (n, n_proj) - Y_theta = Y @ thetas.T # shape (n, n_proj) - - # sigma[:, i_proj] is a permutation sorting X_theta[:, i_proj] - sigma = nx.argsort(X_theta, axis=0) # (n, n_proj) - tau = nx.argsort(Y_theta, axis=0) # (n, n_proj) - - # perm[:, i_proj] is tau[:, i_proj] o sigma[:, i_proj]^{-1} - perm = nx.take_along_axis(tau, nx.argsort(sigma, axis=0), axis=0) # (n, n_proj) - - if log: - log_dict = { - "X_theta": X_theta, - "Y_theta": Y_theta, - "sigma": sigma, - "tau": tau, - "perm": perm, - } - if do_draw_thetas: - log_dict["thetas"] = thetas - return perm, log_dict - else: - return perm - - -def min_pivot_sliced( - X, Y, thetas=None, order=2, n_proj=None, log=False, warm_perm=None -): - r""" - Computes the cost and permutation associated to the min-Pivot Sliced - Discrepancy (introduced as SWGG in [82] and studied further in [83]). Given - the supports `X` and `Y` of two discrete uniform measures with `n` atoms in - dimension `d`, the min-Pivot Sliced Discrepancy goes through `n_proj` - different projections of the measures on random directions, and retains the - permutation that yields the lowest cost between `X` and `Y` (compared - in :math:`\mathbb{R}^d`). - - .. math:: - \mathrm{min\text{-}PS}_p^p(X, Y) \approx - \min_{k \in [1, n_{\mathrm{proj}}]} \left( - \frac{1}{n} \sum_{i=1}^n \|X_i - Y_{\sigma_k(i)}\|_2^p \right), - - where :math:`\sigma_k` is a permutation such that ordering the projections - on the axis `thetas[k, :]` matches `X[i, :]` to `Y[\sigma_k(i), :]`. - - .. note:: - The computation ignores potential ambiguities in the projections: if two points from a same measure have the same projection on a direction, then multiple sorting permutations are possible. To avoid combinatorial explosion, only one permutation is retained: this strays from theory in pathological cases. - - Parameters - ---------- - X : array-like, shape (n, d) - The first set of vectors. - Y : array-like, shape (n, d) - The second set of vectors. - thetas : array-like, shape (n_proj, d), optional - The projection directions. If None, random directions will be generated. Default is None. - order : int, optional - Power to elevate the norm. Default is 2. - n_proj : int, optional - The number of projection directions. Required if thetas is None. - log : bool, optional - If True, returns additional logging information. Default is False. - warm_perm : array-like, shape (n,), optional - A permutation to add to the permutation list. Default is None. - - Returns - ------- - perm : array-like, shape (n,) - The permutation that minimizes the cost. - min_cost : float - The minimum cost corresponding to the optimal permutation. - log_dict : dict, optional - A dictionary containing intermediate computations for logging purposes. - Returned only if `log` is True. - - References - ---------- - .. [82] Mahey, G., Chapel, L., Gasso, G., Bonet, C., & Courty, N. (2023). Fast Optimal Transport through Sliced Generalized Wasserstein Geodesics. Advances in Neural Information Processing Systems, 36, 35350–35385. - - .. [83] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport Plans. arXiv preprint 2506.03661. - """ - assert ( - X.shape == Y.shape - ), f"X ({X.shape}) and Y ({Y.shape}) must have the same shape" - n = X.shape[0] - nx = get_backend(X, Y) - log_dict = {} - - if log: - perm, log_dict = sliced_permutations( - X, Y, thetas=thetas, n_proj=n_proj, log=True, backend=nx - ) - else: - perm = sliced_permutations( - X, Y, thetas=thetas, n_proj=n_proj, log=False, backend=nx - ) - - # add the 'warm perm' to permutations to test - if warm_perm is not None: - perm = nx.concatenate([perm, warm_perm[:, None]], axis=1) - if log: - log_dict["perm"] = perm - - min_cost = None - idx_min_cost = None - costs = [] - - for k in range(perm.shape[-1]): - cost = nx.sum(nx.abs(X - Y[perm[:, k]]) ** order) / n - if min_cost is None or cost < min_cost: - min_cost = cost - idx_min_cost = k - if log: - costs.append(cost) - - min_perm = perm[:, idx_min_cost] - - if log: - log_dict["costs"] = costs - log_dict["idx_min_cost"] = idx_min_cost - return min_perm, min_cost, log_dict - else: - return min_perm, min_cost - - -def expected_sliced(X, Y, thetas=None, n_proj=None, order=2, log=False, beta=0.0): - r""" - Computes the Expected Sliced cost and plan between two `(n, d)` - datasets `X` and `Y`. Given a set of `n_proj` projection directions, - the expected sliced plan is obtained by averaging the `n_proj` 1d optimal - transport plans between the projections of `X` and `Y` on each direction. - Expected Sliced was introduced in [84] and further studied in [83]. - - .. note:: - The computation ignores potential ambiguities in the projections: if two points from a same measure have the same projection on a direction, then multiple sorting permutations are possible. To avoid combinatorial explosion, only one permutation is retained: this strays from theory in pathological cases. - - .. warning:: - The function runs on backend but tensorflow and jax are not supported due to array assignment. - - Parameters - ---------- - X : torch.Tensor - A tensor of shape (n, d) representing the first set of vectors. - Y : torch.Tensor - A tensor of shape (n, d) representing the second set of vectors. - thetas : torch.Tensor, optional - A tensor of shape (n_proj, d) representing the projection directions. - If None, random directions will be generated. Default is None. - n_proj : int, optional - The number of projection directions. Required if thetas is None. - order : int, optional - Power to elevate the norm. Default is 2. - log : bool, optional - If True, returns additional logging information. Default is False. - beta : float, optional - Inverse-temperature parameter which weights each projection's contribution to the expected plan. Default is 0 (uniform weighting). - - Returns - ------- - plan : torch.Tensor - A tensor of shape (n_proj, n, n) representing the expected sliced plan. - log_dict : dict, optional - A dictionary containing intermediate computations for logging purposes. - Returned only if `log` is True. - - References - ---------- - .. [83] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport Plans. arXiv preprint 2506.03661. - - .. [84] Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M., Aldroubi A., Kolouri, S. (2024). Expected Sliced Transport Plans. International Conference on Learning Representations. - """ - assert ( - X.shape == Y.shape - ), f"X ({X.shape}) and Y ({Y.shape}) must have the same shape" - - nx = get_backend(X, Y) - if str(nx) in ["tf", "jax"]: - raise NotImplementedError( - f"expected_sliced is not implemented for the {str(nx)} backend due" - "to array assignment." - ) - n = X.shape[0] - - log_dict = {} - if log: - perm, log_dict = sliced_permutations( - X, Y, thetas=thetas, n_proj=n_proj, log=log, backend=nx - ) - else: - perm = sliced_permutations( - X, Y, thetas=thetas, n_proj=n_proj, log=log, backend=nx - ) - plan = nx.zeros((n, n), type_as=X) - n_proj = perm.shape[1] - range_array = nx.arange(n, type_as=X) - - if beta != 0.0: # computing the temperature weighting - log_factors = nx.zeros(n_proj, type_as=X) # for beta weighting - for k in range(n_proj): - cost_k = nx.sum(nx.abs(X - Y[perm[:, k]]) ** order) / n - log_factors[k] = -beta * cost_k - weights = nx.exp(log_factors - nx.logsumexp(log_factors)) - - else: # uniform weights - weights = nx.ones(n_proj, type_as=X) / n_proj - - for k in range(n_proj): # populating the expected plan - # 1 / n is because is a permutation of [1, n] - plan[range_array, perm[:, k]] += (1 / n) * weights[k] - - cost = (dist(X, Y, p=order) * plan).sum() - - if log: - return plan, cost, log_dict - else: - return plan, cost diff --git a/test/test_sliced.py b/test/test_sliced.py index 7f12d378a..05de13755 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -2,7 +2,6 @@ # Author: Adrien Corenflos # Nicolas Courty -# Eloi Tanguy # # License: MIT License @@ -12,7 +11,6 @@ import ot from ot.sliced import get_random_projections from ot.backend import tf, torch -from contextlib import nullcontext def test_get_random_projections(): @@ -112,14 +110,6 @@ def test_max_sliced_different_dists(): assert res > 0.0 -def test_max_sliced_dim_check(): - n = 3 - x = np.zeros((n, 2)) - y = np.zeros((n + 1, 3)) - with pytest.raises(ValueError): - _ = ot.max_sliced_wasserstein_distance(x, y, n_projections=10) - - def test_sliced_same_proj(): n_projections = 10 seed = 12 @@ -162,16 +152,6 @@ def test_sliced_backend(nx): assert np.allclose(val0, valb) - a = rng.uniform(0, 1, n) - a /= a.sum() - b = rng.uniform(0, 1, 2 * n) - b /= b.sum() - a_b = nx.from_numpy(a) - b_b = nx.from_numpy(b) - val = ot.sliced_wasserstein_distance(x, y, a=a, b=b, projections=P) - val_b = ot.sliced_wasserstein_distance(xb, yb, a=a_b, b=b_b, projections=Pb) - np.testing.assert_almost_equal(val, nx.to_numpy(val_b)) - def test_sliced_backend_type_devices(nx): n = 100 @@ -247,16 +227,6 @@ def test_max_sliced_backend(nx): assert np.allclose(val0, valb) - a = rng.uniform(0, 1, n) - a /= a.sum() - b = rng.uniform(0, 1, 2 * n) - b /= b.sum() - a_b = nx.from_numpy(a) - b_b = nx.from_numpy(b) - val = ot.max_sliced_wasserstein_distance(x, y, a=a, b=b, projections=P) - val_b = ot.max_sliced_wasserstein_distance(xb, yb, a=a_b, b=b_b, projections=Pb) - np.testing.assert_almost_equal(val, nx.to_numpy(val_b)) - def test_max_sliced_backend_type_devices(nx): n = 100 @@ -727,112 +697,3 @@ def test_linear_sliced_sphere_backend_type_devices(nx): nx.assert_same_dtype_device(xb, valb) np.testing.assert_almost_equal(sw_np, nx.to_numpy(valb)) - - -def test_sliced_permutations(nx): - n = 4 - n_proj = 10 - d = 2 - rng = np.random.RandomState(0) - - x = rng.randn(n, 2) - y = rng.randn(n, 2) - - x_b, y_b = nx.from_numpy(x, y) - thetas = ot.sliced.get_random_projections(d, n_proj, seed=0).T - thetas_b = nx.from_numpy(thetas) - - perm = ot.sliced.sliced_permutations(x, y, thetas=thetas) - perm_b, _ = ot.sliced.sliced_permutations( - x_b, y_b, thetas=thetas_b, log=True, backend=nx - ) - - np.testing.assert_almost_equal(perm, nx.to_numpy(perm_b)) - - # test without provided thetas - perm = ot.sliced.sliced_permutations(x, y, n_proj=n_proj) - - # test with invalid shapes - with pytest.raises(AssertionError): - ot.sliced.sliced_permutations(x[1:, :], y, thetas=thetas) - - -def test_min_pivot_sliced(nx): - n = 10 - n_proj = 10 - d = 2 - rng = np.random.RandomState(0) - - x = rng.randn(n, 2) - y = rng.randn(n, 2) - - x_b, y_b = nx.from_numpy(x, y) - thetas = ot.sliced.get_random_projections(d, n_proj, seed=0).T - thetas_b = nx.from_numpy(thetas) - - min_perm, min_cost = ot.sliced.min_pivot_sliced(x, y, thetas=thetas) - min_perm_b, min_cost_b, _ = ot.sliced.min_pivot_sliced( - x_b, y_b, thetas=thetas_b, log=True - ) - - np.testing.assert_almost_equal(min_perm, nx.to_numpy(min_perm_b)) - np.testing.assert_almost_equal(min_cost, nx.to_numpy(min_cost_b)) - - # result should be an upper-bound of W2 and relatively close - w2 = ot.emd2(ot.unif(n), ot.unif(n), ot.dist(x, y)) - assert min_cost >= w2 - assert min_cost <= 1.5 * w2 - - # test without provided thetas and with a warm permutation - ot.sliced.min_pivot_sliced(x, y, n_proj=n_proj, warm_perm=np.arange(n), log=True) - - # test with invalid shapes - with pytest.raises(AssertionError): - ot.sliced.min_pivot_sliced(x[1:, :], y, thetas=thetas) - - -def test_expected_sliced(nx): - n = 10 - n_proj = 10 - d = 2 - rng = np.random.RandomState(0) - - x = rng.randn(n, 2) - y = rng.randn(n, 2) - - x_b, y_b = nx.from_numpy(x, y) - thetas = ot.sliced.get_random_projections(d, n_proj, seed=0).T - thetas_b = nx.from_numpy(thetas) - - context = ( - nullcontext() - if str(nx) not in ["tf", "jax"] - else pytest.raises(NotImplementedError) - ) - - with context: - expected_plan, expected_cost = ot.sliced.expected_sliced(x, y, thetas=thetas) - expected_plan_b, expected_cost_b, _ = ot.sliced.expected_sliced( - x_b, y_b, thetas=thetas_b, log=True - ) - - np.testing.assert_almost_equal(expected_plan, nx.to_numpy(expected_plan_b)) - np.testing.assert_almost_equal(expected_cost, nx.to_numpy(expected_cost_b)) - - # result should be a coarse upper-bound of W2 - w2 = ot.emd2(ot.unif(n), ot.unif(n), ot.dist(x, y)) - assert expected_cost >= w2 - assert expected_cost <= 3 * w2 - - # test without provided thetas - ot.sliced.expected_sliced(x, y, n_proj=n_proj, log=True) - - # test with invalid shapes - with pytest.raises(AssertionError): - ot.sliced.min_pivot_sliced(x[1:, :], y, thetas=thetas) - - # with a small temperature (i.e. large beta), the cost should be close - # to min_pivot - _, expected_cost = ot.sliced.expected_sliced(x, y, thetas=thetas, beta=100.0) - _, min_cost = ot.sliced.min_pivot_sliced(x, y, thetas=thetas) - np.testing.assert_almost_equal(expected_cost, min_cost, decimal=3)