Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ POT provides the following generic OT solvers:
* Fused unbalanced Gromov-Wasserstein [70].
* [Optimal Transport Barycenters for Generic Costs](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter_generic_cost.html) [77]
* [Barycenters between Gaussian Mixture Models](https://pythonot.github.io/auto_examples/barycenters/plot_gmm_barycenter.html) [69, 77]
* [Sliced Optimal Transport Plans](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_sliced_plans.html) [82, 83, 84]

POT provides the following Machine Learning related solvers:

Expand Down Expand Up @@ -449,5 +450,8 @@ Artificial Intelligence.

[81] Xu, H., Luo, D., & Carin, L. (2019). [Scalable Gromov-Wasserstein learning for graph partitioning and matching](https://proceedings.neurips.cc/paper/2019/hash/6e62a992c676f611616097dbea8ea030-Abstract.html). Neural Information Processing Systems (NeurIPS).

[82] Mahey, G., Chapel, L., Gasso, G., Bonet, C., & Courty, N. (2023). [Fast Optimal Transport through Sliced Generalized Wasserstein Geodesics](https://proceedings.neurips.cc/paper_files/paper/2023/hash/6f1346bac8b02f76a631400e2799b24b-Abstract-Conference.html). Advances in Neural Information Processing Systems, 36, 35350–35385.

```
[83] Tanguy, E., Chapel, L., Delon, J. (2025). [Sliced Optimal Transport Plans](https://arxiv.org/abs/2508.01243) arXiv preprint 2506.03661.

[84] Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M., Aldroubi A., Kolouri, S. (2024). [Expected Sliced Transport Plans](https://openreview.net/forum?id=P7O1Vt1BdU). International Conference on Learning Representations.
6 changes: 6 additions & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Releases

## 0.9.7dev

#### New features

- Added Sliced OT plans (PR #757)

## 0.9.6.post1

*September 2025*
Expand Down
168 changes: 168 additions & 0 deletions examples/sliced-wasserstein/plot_sliced_plans.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
# -*- coding: utf-8 -*-
"""
===============
Sliced OT Plans
===============

Compares different Sliced OT plans between two 2D point clouds. The min-Pivot
Sliced plan was introduced in [82], and the Expected Sliced plan in [84], both
were further studied theoretically in [83].

.. [82] Mahey, G., Chapel, L., Gasso, G., Bonet, C., & Courty, N. (2023). Fast Optimal Transport through Sliced Generalized Wasserstein Geodesics. Advances in Neural Information Processing Systems, 36, 35350–35385.

.. [83] Tanguy, E., Chapel, L., Delon, J. (2025). Sliced Optimal Transport Plans. arXiv preprint 2506.03661.

.. [84] Liu, X., Diaz Martin, R., Bai Y., Shahbazi A., Thorpe M., Aldroubi A., Kolouri, S. (2024). Expected Sliced Transport Plans. International Conference on Learning Representations.
"""

# Author: Eloi Tanguy <eloi.tanguy@math.cnrs.fr>
# License: MIT License

# sphinx_gallery_thumbnail_number = 1

##############################################################################
# Setup data and imports
# ----------------------
import numpy as np
import ot
import matplotlib.pyplot as plt
from ot.sliced import get_random_projections

seed = 0
np.random.seed(seed)
n = 10
d = 2
X = np.random.randn(n, 2)
Y = np.random.randn(n, 2) + np.array([5.0, 0.0])[None, :]
n_proj = 20
thetas = get_random_projections(d, n_proj).T
alpha = 0.3

##############################################################################
# Compute min-Pivot Sliced permutation
# ------------------------------------
min_perm, min_cost, log_min = ot.min_pivot_sliced(X, Y, thetas, log=True)
min_plan = np.zeros((n, n))
min_plan[np.arange(n), min_perm] = 1 / n

##############################################################################
# Compute Expected Sliced Plan
# ------------------------------------
expected_plan, expected_cost, log_expected = ot.expected_sliced(X, Y, thetas, log=True)

##############################################################################
# Compute 2-Wasserstein Plan
# ------------------------------------
a = np.ones(n, device=X.device) / n
dists = ot.dist(X, Y)
W2 = ot.emd2(a, a, dists)
W2_plan = ot.emd(a, a, dists)

##############################################################################
# Plot resulting assignments
# ------------------------------------
fig, axs = plt.subplots(2, 3, figsize=(12, 4))
fig.suptitle("Sliced plans comparison", y=0.95, fontsize=16)

# draw min sliced permutation
axs[0, 0].set_title(f"Min Pivot Sliced: cost={min_cost:.2f}")
for i in range(n):
axs[0, 0].plot(
[X[i, 0], Y[min_perm[i], 0]],
[X[i, 1], Y[min_perm[i], 1]],
color="black",
alpha=alpha,
label="min-Sliced perm" if i == 0 else None,
)
axs[1, 0].imshow(min_plan, interpolation="nearest", cmap="Blues")

# draw expected sliced plan
axs[0, 1].set_title(f"Expected Sliced: cost={expected_cost:.2f}")
for i in range(n):
for j in range(n):
w = alpha * expected_plan[i, j].item() * n
axs[0, 1].plot(
[X[i, 0], Y[j, 0]],
[X[i, 1], Y[j, 1]],
color="black",
alpha=w,
label="Expected Sliced plan" if i == 0 and j == 0 else None,
)
axs[1, 1].imshow(expected_plan, interpolation="nearest", cmap="Blues")

# draw W2 plan
axs[0, 2].set_title(f"W2: cost={W2:.2f}")
for i in range(n):
for j in range(n):
w = alpha * W2_plan[i, j].item() * n
axs[0, 2].plot(
[X[i, 0], Y[j, 0]],
[X[i, 1], Y[j, 1]],
color="black",
alpha=w,
label="W2 plan" if i == 0 and j == 0 else None,
)
axs[1, 2].imshow(W2_plan, interpolation="nearest", cmap="Blues")

for ax in axs[0, :]:
ax.scatter(X[:, 0], X[:, 1], label="X")
ax.scatter(Y[:, 0], Y[:, 1], label="Y")

for ax in axs.flatten():
ax.set_aspect("equal")
ax.set_xticks([])
ax.set_yticks([])

fig.tight_layout()

##############################################################################
# Compare Expected Sliced plans with different inverse-temperatures beta
# ------------------------------------
## As the temperature decreases, ES becomes sparser and approaches minPS
betas = [0.0, 5.0, 50.0]
n_plots = len(betas) + 1
size = 4
fig, axs = plt.subplots(2, n_plots, figsize=(size * n_plots, size))
fig.suptitle(
"Expected Sliced plan varying beta (inverse temperature)", y=0.95, fontsize=16
)
for beta_idx, beta in enumerate(betas):
expected_plan, expected_cost = ot.expected_sliced(X, Y, thetas, beta=beta)
print(f"beta={beta}: cost={expected_cost:.2f}")

axs[0, beta_idx].set_title(f"beta={beta}: cost={expected_cost:.2f}")
for i in range(n):
for j in range(n):
w = alpha * expected_plan[i, j].item() * n
axs[0, beta_idx].plot(
[X[i, 0], Y[j, 0]],
[X[i, 1], Y[j, 1]],
color="black",
alpha=w,
label="Expected Sliced plan" if i == 0 and j == 0 else None,
)

axs[0, beta_idx].scatter(X[:, 0], X[:, 1], label="X")
axs[0, beta_idx].scatter(Y[:, 0], Y[:, 1], label="Y")
axs[1, beta_idx].imshow(expected_plan, interpolation="nearest", cmap="Blues")

# draw min sliced permutation (limit when beta -> +inf)
axs[0, -1].set_title(f"Min Pivot Sliced: cost={min_cost:.2f}")
for i in range(n):
axs[0, -1].plot(
[X[i, 0], Y[min_perm[i], 0]],
[X[i, 1], Y[min_perm[i], 1]],
color="black",
alpha=alpha,
label="min-Sliced perm" if i == 0 else None,
)
axs[0, -1].scatter(X[:, 0], X[:, 1], label="X")
axs[0, -1].scatter(Y[:, 0], Y[:, 1], label="Y")
axs[1, -1].imshow(min_plan, interpolation="nearest", cmap="Blues")

for ax in axs.flatten():
ax.set_aspect("equal")
ax.set_xticks([])
ax.set_yticks([])

fig.tight_layout()
4 changes: 4 additions & 0 deletions ot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@
sliced_wasserstein_sphere,
sliced_wasserstein_sphere_unif,
linear_sliced_wasserstein_sphere,
min_pivot_sliced,
expected_sliced,
)
from .gromov import (
gromov_wasserstein,
Expand Down Expand Up @@ -109,6 +111,8 @@
"sliced_wasserstein_distance",
"sliced_wasserstein_sphere",
"linear_sliced_wasserstein_sphere",
"min_pivot_sliced",
"expected_sliced",
"gromov_wasserstein",
"gromov_wasserstein2",
"gromov_barycenters",
Expand Down
Loading
Loading