diff --git a/README.md b/README.md index f8880a166..ed6f2a89c 100644 --- a/README.md +++ b/README.md @@ -72,6 +72,7 @@ POT provides the following generic OT solvers: * Fused unbalanced Gromov-Wasserstein [70]. * [Optimal Transport Barycenters for Generic Costs](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter_generic_cost.html) [77] * [Barycenters between Gaussian Mixture Models](https://pythonot.github.io/auto_examples/barycenters/plot_gmm_barycenter.html) [69, 77] +* [Sliced Optimal Transport Plans](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_sliced_plans.html) [82, 83, 84] POT provides the following Machine Learning related solvers: @@ -449,5 +450,8 @@ Artificial Intelligence. [81] Xu, H., Luo, D., & Carin, L. (2019). [Scalable Gromov-Wasserstein learning for graph partitioning and matching](https://proceedings.neurips.cc/paper/2019/hash/6e62a992c676f611616097dbea8ea030-Abstract.html). Neural Information Processing Systems (NeurIPS). +[82] Mahey, G., Chapel, L., Gasso, G., Bonet, C., & Courty, N. (2023). [Fast Optimal Transport through Sliced Generalized Wasserstein Geodesics](https://proceedings.neurips.cc/paper_files/paper/2023/hash/6f1346bac8b02f76a631400e2799b24b-Abstract-Conference.html). Advances in Neural Information Processing Systems, 36, 35350–35385. -``` +[83] Tanguy, E., Chapel, L., Delon, J. (2025). [Sliced Optimal Transport Plans](https://arxiv.org/abs/2508.01243) arXiv preprint 2506.03661. + +[84] Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M., Aldroubi A., Kolouri, S. (2024). [Expected Sliced Transport Plans](https://openreview.net/forum?id=P7O1Vt1BdU). International Conference on Learning Representations. diff --git a/RELEASES.md b/RELEASES.md index ccb9b97d2..eee4257af 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,5 +1,11 @@ # Releases +## 0.9.7dev + +#### New features + +- Added Sliced OT plans (PR #757) + ## 0.9.6.post1 *September 2025* diff --git a/examples/sliced-wasserstein/plot_sliced_plans.py b/examples/sliced-wasserstein/plot_sliced_plans.py new file mode 100644 index 000000000..ca7b35a3f --- /dev/null +++ b/examples/sliced-wasserstein/plot_sliced_plans.py @@ -0,0 +1,168 @@ +# -*- coding: utf-8 -*- +""" +=============== +Sliced OT Plans +=============== + +Compares different Sliced OT plans between two 2D point clouds. The min-Pivot +Sliced plan was introduced in [82], and the Expected Sliced plan in [84], both +were further studied theoretically in [83]. + +.. [82] Mahey, G., Chapel, L., Gasso, G., Bonet, C., & Courty, N. (2023). Fast Optimal Transport through Sliced Generalized Wasserstein Geodesics. Advances in Neural Information Processing Systems, 36, 35350–35385. + +.. [83] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport Plans. arXiv preprint 2506.03661. + +.. [84] Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M., Aldroubi A., Kolouri, S. (2024). Expected Sliced Transport Plans. International Conference on Learning Representations. +""" + +# Author: Eloi Tanguy +# License: MIT License + +# sphinx_gallery_thumbnail_number = 1 + +############################################################################## +# Setup data and imports +# ---------------------- +import numpy as np +import ot +import matplotlib.pyplot as plt +from ot.sliced import get_random_projections + +seed = 0 +np.random.seed(seed) +n = 10 +d = 2 +X = np.random.randn(n, 2) +Y = np.random.randn(n, 2) + np.array([5.0, 0.0])[None, :] +n_proj = 20 +thetas = get_random_projections(d, n_proj).T +alpha = 0.3 + +############################################################################## +# Compute min-Pivot Sliced permutation +# ------------------------------------ +min_perm, min_cost, log_min = ot.min_pivot_sliced(X, Y, thetas, log=True) +min_plan = np.zeros((n, n)) +min_plan[np.arange(n), min_perm] = 1 / n + +############################################################################## +# Compute Expected Sliced Plan +# ------------------------------------ +expected_plan, expected_cost, log_expected = ot.expected_sliced(X, Y, thetas, log=True) + +############################################################################## +# Compute 2-Wasserstein Plan +# ------------------------------------ +a = np.ones(n, device=X.device) / n +dists = ot.dist(X, Y) +W2 = ot.emd2(a, a, dists) +W2_plan = ot.emd(a, a, dists) + +############################################################################## +# Plot resulting assignments +# ------------------------------------ +fig, axs = plt.subplots(2, 3, figsize=(12, 4)) +fig.suptitle("Sliced plans comparison", y=0.95, fontsize=16) + +# draw min sliced permutation +axs[0, 0].set_title(f"Min Pivot Sliced: cost={min_cost:.2f}") +for i in range(n): + axs[0, 0].plot( + [X[i, 0], Y[min_perm[i], 0]], + [X[i, 1], Y[min_perm[i], 1]], + color="black", + alpha=alpha, + label="min-Sliced perm" if i == 0 else None, + ) +axs[1, 0].imshow(min_plan, interpolation="nearest", cmap="Blues") + +# draw expected sliced plan +axs[0, 1].set_title(f"Expected Sliced: cost={expected_cost:.2f}") +for i in range(n): + for j in range(n): + w = alpha * expected_plan[i, j].item() * n + axs[0, 1].plot( + [X[i, 0], Y[j, 0]], + [X[i, 1], Y[j, 1]], + color="black", + alpha=w, + label="Expected Sliced plan" if i == 0 and j == 0 else None, + ) +axs[1, 1].imshow(expected_plan, interpolation="nearest", cmap="Blues") + +# draw W2 plan +axs[0, 2].set_title(f"W2: cost={W2:.2f}") +for i in range(n): + for j in range(n): + w = alpha * W2_plan[i, j].item() * n + axs[0, 2].plot( + [X[i, 0], Y[j, 0]], + [X[i, 1], Y[j, 1]], + color="black", + alpha=w, + label="W2 plan" if i == 0 and j == 0 else None, + ) +axs[1, 2].imshow(W2_plan, interpolation="nearest", cmap="Blues") + +for ax in axs[0, :]: + ax.scatter(X[:, 0], X[:, 1], label="X") + ax.scatter(Y[:, 0], Y[:, 1], label="Y") + +for ax in axs.flatten(): + ax.set_aspect("equal") + ax.set_xticks([]) + ax.set_yticks([]) + +fig.tight_layout() + +############################################################################## +# Compare Expected Sliced plans with different inverse-temperatures beta +# ------------------------------------ +## As the temperature decreases, ES becomes sparser and approaches minPS +betas = [0.0, 5.0, 50.0] +n_plots = len(betas) + 1 +size = 4 +fig, axs = plt.subplots(2, n_plots, figsize=(size * n_plots, size)) +fig.suptitle( + "Expected Sliced plan varying beta (inverse temperature)", y=0.95, fontsize=16 +) +for beta_idx, beta in enumerate(betas): + expected_plan, expected_cost = ot.expected_sliced(X, Y, thetas, beta=beta) + print(f"beta={beta}: cost={expected_cost:.2f}") + + axs[0, beta_idx].set_title(f"beta={beta}: cost={expected_cost:.2f}") + for i in range(n): + for j in range(n): + w = alpha * expected_plan[i, j].item() * n + axs[0, beta_idx].plot( + [X[i, 0], Y[j, 0]], + [X[i, 1], Y[j, 1]], + color="black", + alpha=w, + label="Expected Sliced plan" if i == 0 and j == 0 else None, + ) + + axs[0, beta_idx].scatter(X[:, 0], X[:, 1], label="X") + axs[0, beta_idx].scatter(Y[:, 0], Y[:, 1], label="Y") + axs[1, beta_idx].imshow(expected_plan, interpolation="nearest", cmap="Blues") + +# draw min sliced permutation (limit when beta -> +inf) +axs[0, -1].set_title(f"Min Pivot Sliced: cost={min_cost:.2f}") +for i in range(n): + axs[0, -1].plot( + [X[i, 0], Y[min_perm[i], 0]], + [X[i, 1], Y[min_perm[i], 1]], + color="black", + alpha=alpha, + label="min-Sliced perm" if i == 0 else None, + ) +axs[0, -1].scatter(X[:, 0], X[:, 1], label="X") +axs[0, -1].scatter(Y[:, 0], Y[:, 1], label="Y") +axs[1, -1].imshow(min_plan, interpolation="nearest", cmap="Blues") + +for ax in axs.flatten(): + ax.set_aspect("equal") + ax.set_xticks([]) + ax.set_yticks([]) + +fig.tight_layout() diff --git a/ot/__init__.py b/ot/__init__.py index 235fb91b4..7f1d8152f 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -58,6 +58,8 @@ sliced_wasserstein_sphere, sliced_wasserstein_sphere_unif, linear_sliced_wasserstein_sphere, + min_pivot_sliced, + expected_sliced, ) from .gromov import ( gromov_wasserstein, @@ -109,6 +111,8 @@ "sliced_wasserstein_distance", "sliced_wasserstein_sphere", "linear_sliced_wasserstein_sphere", + "min_pivot_sliced", + "expected_sliced", "gromov_wasserstein", "gromov_wasserstein2", "gromov_barycenters", diff --git a/ot/sliced.py b/ot/sliced.py index 3cf2002e7..cc79fa968 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -1,17 +1,17 @@ """ Sliced OT Distances - """ # Author: Adrien Corenflos # Nicolas Courty # Rémi Flamary +# Eloi Tanguy # # License: MIT License import numpy as np from .backend import get_backend, NumpyBackend -from .utils import list_to_array, get_coordinate_circle +from .utils import list_to_array, get_coordinate_circle, dist from .lp import ( wasserstein_circle, semidiscrete_wasserstein2_unif_circle, @@ -674,3 +674,262 @@ def linear_sliced_wasserstein_sphere( if log: return res, {"projections": projections, "projected_emds": projected_lcot} return res + + +def sliced_permutations(X, Y, thetas=None, n_proj=None, log=False, backend=None): + r""" + Computes all the permutations that sort the projections of two `(n, d)` + datasets `X` and `Y` on the directions `thetas`. + Each permutation `perm[:, k]` is such that each `X[i, :]` is matched + to `Y[perm[i, k], :]` when projected on `thetas[k, :]`. + + Parameters + ---------- + X : array-like, shape (n, d) + The first set of vectors. + Y : array-like, shape (n, d) + The second set of vectors. + thetas : array-like, shape (n_proj, d), optional + The projection directions. If None, random directions will be generated. + Default is None. + n_proj : int, optional + The number of projection directions. Required if thetas is None. + log : bool, optional + If True, returns additional logging information. Default is False. + backend : ot.backend, optional + Backend to use for computations. If None, the backend is inferred from the input arrays. Default is None. + + Returns + ------- + perm : array-like, shape (n, n_proj) + All sliced permutations. + log_dict : dict, optional + A dictionary containing intermediate computations for logging purposes. + Returned only if `log` is True. + """ + assert ( + X.shape == Y.shape + ), f"X ({X.shape}) and Y ({Y.shape}) must have the same shape" + nx = get_backend(X, Y) if backend is None else backend + d = X.shape[1] + do_draw_thetas = thetas is None + if do_draw_thetas: # create thetas (n_proj, d) + thetas = get_random_projections(d, n_proj, backend=nx).T + + # project on each theta: (n, d) -> (n, n_proj) + X_theta = X @ thetas.T # shape (n, n_proj) + Y_theta = Y @ thetas.T # shape (n, n_proj) + + # sigma[:, i_proj] is a permutation sorting X_theta[:, i_proj] + sigma = nx.argsort(X_theta, axis=0) # (n, n_proj) + tau = nx.argsort(Y_theta, axis=0) # (n, n_proj) + + # perm[:, i_proj] is tau[:, i_proj] o sigma[:, i_proj]^{-1} + perm = nx.take_along_axis(tau, nx.argsort(sigma, axis=0), axis=0) # (n, n_proj) + + if log: + log_dict = { + "X_theta": X_theta, + "Y_theta": Y_theta, + "sigma": sigma, + "tau": tau, + "perm": perm, + } + if do_draw_thetas: + log_dict["thetas"] = thetas + return perm, log_dict + else: + return perm + + +def min_pivot_sliced( + X, Y, thetas=None, order=2, n_proj=None, log=False, warm_perm=None +): + r""" + Computes the cost and permutation associated to the min-Pivot Sliced + Discrepancy (introduced as SWGG in [82] and studied further in [83]). Given + the supports `X` and `Y` of two discrete uniform measures with `n` atoms in + dimension `d`, the min-Pivot Sliced Discrepancy goes through `n_proj` + different projections of the measures on random directions, and retains the + permutation that yields the lowest cost between `X` and `Y` (compared + in :math:`\mathbb{R}^d`). + + .. math:: + \mathrm{min\text{-}PS}_p^p(X, Y) \approx + \min_{k \in [1, n_{\mathrm{proj}}]} \left( + \frac{1}{n} \sum_{i=1}^n \|X_i - Y_{\sigma_k(i)}\|_2^p \right), + + where :math:`\sigma_k` is a permutation such that ordering the projections + on the axis `thetas[k, :]` matches `X[i, :]` to `Y[\sigma_k(i), :]`. + + .. note:: + The computation ignores potential ambiguities in the projections: if two points from a same measure have the same projection on a direction, then multiple sorting permutations are possible. To avoid combinatorial explosion, only one permutation is retained: this strays from theory in pathological cases. + + Parameters + ---------- + X : array-like, shape (n, d) + The first set of vectors. + Y : array-like, shape (n, d) + The second set of vectors. + thetas : array-like, shape (n_proj, d), optional + The projection directions. If None, random directions will be generated. Default is None. + order : int, optional + Power to elevate the norm. Default is 2. + n_proj : int, optional + The number of projection directions. Required if thetas is None. + log : bool, optional + If True, returns additional logging information. Default is False. + warm_perm : array-like, shape (n,), optional + A permutation to add to the permutation list. Default is None. + + Returns + ------- + perm : array-like, shape (n,) + The permutation that minimizes the cost. + min_cost : float + The minimum cost corresponding to the optimal permutation. + log_dict : dict, optional + A dictionary containing intermediate computations for logging purposes. + Returned only if `log` is True. + + References + ---------- + .. [82] Mahey, G., Chapel, L., Gasso, G., Bonet, C., & Courty, N. (2023). Fast Optimal Transport through Sliced Generalized Wasserstein Geodesics. Advances in Neural Information Processing Systems, 36, 35350–35385. + + .. [83] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport Plans. arXiv preprint 2506.03661. + """ + assert ( + X.shape == Y.shape + ), f"X ({X.shape}) and Y ({Y.shape}) must have the same shape" + n = X.shape[0] + nx = get_backend(X, Y) + log_dict = {} + + if log: + perm, log_dict = sliced_permutations( + X, Y, thetas=thetas, n_proj=n_proj, log=True, backend=nx + ) + else: + perm = sliced_permutations( + X, Y, thetas=thetas, n_proj=n_proj, log=False, backend=nx + ) + + # add the 'warm perm' to permutations to test + if warm_perm is not None: + perm = nx.concatenate([perm, warm_perm[:, None]], axis=1) + if log: + log_dict["perm"] = perm + + min_cost = None + idx_min_cost = None + costs = [] + + for k in range(perm.shape[-1]): + cost = nx.sum(nx.abs(X - Y[perm[:, k]]) ** order) / n + if min_cost is None or cost < min_cost: + min_cost = cost + idx_min_cost = k + if log: + costs.append(cost) + + min_perm = perm[:, idx_min_cost] + + if log: + log_dict["costs"] = costs + log_dict["idx_min_cost"] = idx_min_cost + return min_perm, min_cost, log_dict + else: + return min_perm, min_cost + + +def expected_sliced(X, Y, thetas=None, n_proj=None, order=2, log=False, beta=0.0): + r""" + Computes the Expected Sliced cost and plan between two `(n, d)` + datasets `X` and `Y`. Given a set of `n_proj` projection directions, + the expected sliced plan is obtained by averaging the `n_proj` 1d optimal + transport plans between the projections of `X` and `Y` on each direction. + Expected Sliced was introduced in [84] and further studied in [83]. + + .. note:: + The computation ignores potential ambiguities in the projections: if two points from a same measure have the same projection on a direction, then multiple sorting permutations are possible. To avoid combinatorial explosion, only one permutation is retained: this strays from theory in pathological cases. + + .. warning:: + The function runs on backend but tensorflow and jax are not supported due to array assignment. + + Parameters + ---------- + X : torch.Tensor + A tensor of shape (n, d) representing the first set of vectors. + Y : torch.Tensor + A tensor of shape (n, d) representing the second set of vectors. + thetas : torch.Tensor, optional + A tensor of shape (n_proj, d) representing the projection directions. + If None, random directions will be generated. Default is None. + n_proj : int, optional + The number of projection directions. Required if thetas is None. + order : int, optional + Power to elevate the norm. Default is 2. + log : bool, optional + If True, returns additional logging information. Default is False. + beta : float, optional + Inverse-temperature parameter which weights each projection's contribution to the expected plan. Default is 0 (uniform weighting). + + Returns + ------- + plan : torch.Tensor + A tensor of shape (n_proj, n, n) representing the expected sliced plan. + log_dict : dict, optional + A dictionary containing intermediate computations for logging purposes. + Returned only if `log` is True. + + References + ---------- + .. [83] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport Plans. arXiv preprint 2506.03661. + + .. [84] Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M., Aldroubi A., Kolouri, S. (2024). Expected Sliced Transport Plans. International Conference on Learning Representations. + """ + assert ( + X.shape == Y.shape + ), f"X ({X.shape}) and Y ({Y.shape}) must have the same shape" + + nx = get_backend(X, Y) + if str(nx) in ["tf", "jax"]: + raise NotImplementedError( + f"expected_sliced is not implemented for the {str(nx)} backend due" + "to array assignment." + ) + n = X.shape[0] + + log_dict = {} + if log: + perm, log_dict = sliced_permutations( + X, Y, thetas=thetas, n_proj=n_proj, log=log, backend=nx + ) + else: + perm = sliced_permutations( + X, Y, thetas=thetas, n_proj=n_proj, log=log, backend=nx + ) + plan = nx.zeros((n, n), type_as=X) + n_proj = perm.shape[1] + range_array = nx.arange(n, type_as=X) + + if beta != 0.0: # computing the temperature weighting + log_factors = nx.zeros(n_proj, type_as=X) # for beta weighting + for k in range(n_proj): + cost_k = nx.sum(nx.abs(X - Y[perm[:, k]]) ** order) / n + log_factors[k] = -beta * cost_k + weights = nx.exp(log_factors - nx.logsumexp(log_factors)) + + else: # uniform weights + weights = nx.ones(n_proj, type_as=X) / n_proj + + for k in range(n_proj): # populating the expected plan + # 1 / n is because is a permutation of [1, n] + plan[range_array, perm[:, k]] += (1 / n) * weights[k] + + cost = (dist(X, Y, p=order) * plan).sum() + + if log: + return plan, cost, log_dict + else: + return plan, cost diff --git a/test/test_sliced.py b/test/test_sliced.py index 05de13755..7f12d378a 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -2,6 +2,7 @@ # Author: Adrien Corenflos # Nicolas Courty +# Eloi Tanguy # # License: MIT License @@ -11,6 +12,7 @@ import ot from ot.sliced import get_random_projections from ot.backend import tf, torch +from contextlib import nullcontext def test_get_random_projections(): @@ -110,6 +112,14 @@ def test_max_sliced_different_dists(): assert res > 0.0 +def test_max_sliced_dim_check(): + n = 3 + x = np.zeros((n, 2)) + y = np.zeros((n + 1, 3)) + with pytest.raises(ValueError): + _ = ot.max_sliced_wasserstein_distance(x, y, n_projections=10) + + def test_sliced_same_proj(): n_projections = 10 seed = 12 @@ -152,6 +162,16 @@ def test_sliced_backend(nx): assert np.allclose(val0, valb) + a = rng.uniform(0, 1, n) + a /= a.sum() + b = rng.uniform(0, 1, 2 * n) + b /= b.sum() + a_b = nx.from_numpy(a) + b_b = nx.from_numpy(b) + val = ot.sliced_wasserstein_distance(x, y, a=a, b=b, projections=P) + val_b = ot.sliced_wasserstein_distance(xb, yb, a=a_b, b=b_b, projections=Pb) + np.testing.assert_almost_equal(val, nx.to_numpy(val_b)) + def test_sliced_backend_type_devices(nx): n = 100 @@ -227,6 +247,16 @@ def test_max_sliced_backend(nx): assert np.allclose(val0, valb) + a = rng.uniform(0, 1, n) + a /= a.sum() + b = rng.uniform(0, 1, 2 * n) + b /= b.sum() + a_b = nx.from_numpy(a) + b_b = nx.from_numpy(b) + val = ot.max_sliced_wasserstein_distance(x, y, a=a, b=b, projections=P) + val_b = ot.max_sliced_wasserstein_distance(xb, yb, a=a_b, b=b_b, projections=Pb) + np.testing.assert_almost_equal(val, nx.to_numpy(val_b)) + def test_max_sliced_backend_type_devices(nx): n = 100 @@ -697,3 +727,112 @@ def test_linear_sliced_sphere_backend_type_devices(nx): nx.assert_same_dtype_device(xb, valb) np.testing.assert_almost_equal(sw_np, nx.to_numpy(valb)) + + +def test_sliced_permutations(nx): + n = 4 + n_proj = 10 + d = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + y = rng.randn(n, 2) + + x_b, y_b = nx.from_numpy(x, y) + thetas = ot.sliced.get_random_projections(d, n_proj, seed=0).T + thetas_b = nx.from_numpy(thetas) + + perm = ot.sliced.sliced_permutations(x, y, thetas=thetas) + perm_b, _ = ot.sliced.sliced_permutations( + x_b, y_b, thetas=thetas_b, log=True, backend=nx + ) + + np.testing.assert_almost_equal(perm, nx.to_numpy(perm_b)) + + # test without provided thetas + perm = ot.sliced.sliced_permutations(x, y, n_proj=n_proj) + + # test with invalid shapes + with pytest.raises(AssertionError): + ot.sliced.sliced_permutations(x[1:, :], y, thetas=thetas) + + +def test_min_pivot_sliced(nx): + n = 10 + n_proj = 10 + d = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + y = rng.randn(n, 2) + + x_b, y_b = nx.from_numpy(x, y) + thetas = ot.sliced.get_random_projections(d, n_proj, seed=0).T + thetas_b = nx.from_numpy(thetas) + + min_perm, min_cost = ot.sliced.min_pivot_sliced(x, y, thetas=thetas) + min_perm_b, min_cost_b, _ = ot.sliced.min_pivot_sliced( + x_b, y_b, thetas=thetas_b, log=True + ) + + np.testing.assert_almost_equal(min_perm, nx.to_numpy(min_perm_b)) + np.testing.assert_almost_equal(min_cost, nx.to_numpy(min_cost_b)) + + # result should be an upper-bound of W2 and relatively close + w2 = ot.emd2(ot.unif(n), ot.unif(n), ot.dist(x, y)) + assert min_cost >= w2 + assert min_cost <= 1.5 * w2 + + # test without provided thetas and with a warm permutation + ot.sliced.min_pivot_sliced(x, y, n_proj=n_proj, warm_perm=np.arange(n), log=True) + + # test with invalid shapes + with pytest.raises(AssertionError): + ot.sliced.min_pivot_sliced(x[1:, :], y, thetas=thetas) + + +def test_expected_sliced(nx): + n = 10 + n_proj = 10 + d = 2 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + y = rng.randn(n, 2) + + x_b, y_b = nx.from_numpy(x, y) + thetas = ot.sliced.get_random_projections(d, n_proj, seed=0).T + thetas_b = nx.from_numpy(thetas) + + context = ( + nullcontext() + if str(nx) not in ["tf", "jax"] + else pytest.raises(NotImplementedError) + ) + + with context: + expected_plan, expected_cost = ot.sliced.expected_sliced(x, y, thetas=thetas) + expected_plan_b, expected_cost_b, _ = ot.sliced.expected_sliced( + x_b, y_b, thetas=thetas_b, log=True + ) + + np.testing.assert_almost_equal(expected_plan, nx.to_numpy(expected_plan_b)) + np.testing.assert_almost_equal(expected_cost, nx.to_numpy(expected_cost_b)) + + # result should be a coarse upper-bound of W2 + w2 = ot.emd2(ot.unif(n), ot.unif(n), ot.dist(x, y)) + assert expected_cost >= w2 + assert expected_cost <= 3 * w2 + + # test without provided thetas + ot.sliced.expected_sliced(x, y, n_proj=n_proj, log=True) + + # test with invalid shapes + with pytest.raises(AssertionError): + ot.sliced.min_pivot_sliced(x[1:, :], y, thetas=thetas) + + # with a small temperature (i.e. large beta), the cost should be close + # to min_pivot + _, expected_cost = ot.sliced.expected_sliced(x, y, thetas=thetas, beta=100.0) + _, min_cost = ot.sliced.min_pivot_sliced(x, y, thetas=thetas) + np.testing.assert_almost_equal(expected_cost, min_cost, decimal=3)