Skip to content

Commit 1152398

Browse files
committed
feat: Add backend-agnostic sparse EMD support
Refactor sparse optimal transport implementation to work across different backends (NumPy/scipy.sparse, PyTorch/torch.sparse). Key changes: - Add `sparse_coo_data()` method to backend layer for uniform sparse matrix handling across NumPy, PyTorch, JAX, and TensorFlow backends - Update `emd()` and `emd2()` to return transport plans in backend-native sparse format (scipy.sparse for NumPy, torch.sparse for PyTorch) - Refactor `plot2D_samples_mat()` to efficiently visualize both dense and sparse transport plans by detecting format and iterating only over non-zero entries for sparse matrices - Update `plot_sparse_emd.py` example to use new plotting function - Add comprehensive tests for sparse EMD across backends - Update documentation to reflect backend-agnostic sparse support
1 parent b184cd4 commit 1152398

File tree

5 files changed

+433
-397
lines changed

5 files changed

+433
-397
lines changed

examples/plot_sparse_emd.py

Lines changed: 99 additions & 180 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,16 @@
44
Sparse Optimal Transport
55
============================================
66
7-
In many real-world optimal transport (OT) problems, the transport plan is naturally sparse: only a small fraction of all possible source-target pairs actually exchange mass. In such cases, using sparse OT solvers can provide significant computational speedups and memory savings compared to dense solvers, which compute and store the full transport matrix.
8-
9-
The figure below illustrates the advantages of sparse OT solvers over dense ones in terms of speed and memory usage for different sparsity levels of the transport plan.
10-
11-
.. image:: /_static/images/comparison.png
12-
:align: center
13-
:width: 80%
14-
:alt: Dense vs Sparse OT: Speed and Memory Advantages
7+
In many real-world optimal transport (OT) problems, the transport plan is
8+
naturally sparse: only a small fraction of all possible source-target pairs
9+
actually exchange mass. Using sparse OT solvers can provide significant
10+
computational speedups and memory savings compared to dense solvers.
11+
12+
This example demonstrates how to use sparse cost matrices with POT's EMD solver,
13+
comparing sparse and dense formulations on both a minimal example and a larger
14+
concentric circles dataset.
1515
"""
1616

17-
1817
# Author: Nathan Neike <nathan.neike@example.com>
1918
# License: MIT License
2019
# sphinx_gallery_thumbnail_number = 2
@@ -26,279 +25,199 @@
2625

2726

2827
##############################################################################
29-
# Generate minimal example data
28+
# Minimal example with 4 points
3029
# ------------------------------
31-
#
32-
# We create a simple example with 2 source points and 2 target points to
33-
# illustrate the concept of sparse optimal transport.
3430

3531
# %%
3632

37-
X = np.array([[0, 0], [1, 0]])
38-
Y = np.array([[0, 1], [1, 1]])
39-
a = np.array([0.5, 0.5])
40-
b = np.array([0.5, 0.5])
41-
42-
43-
##############################################################################
44-
# Build sparse cost matrix
45-
# -------------------------
46-
#
47-
# Instead of allowing all possible edges (dense OT), we only allow two edges:
48-
# source 0 -> target 0 and source 1 -> target 1. This is specified using a
49-
# sparse matrix format (COO).
50-
51-
# %%
33+
X = np.array([[0, 0], [1, 0], [0.5, 0], [1.5, 0]])
34+
Y = np.array([[0, 1], [1, 1], [0.5, 1], [1.5, 1]])
35+
a = np.array([0.25, 0.25, 0.25, 0.25])
36+
b = np.array([0.25, 0.25, 0.25, 0.25])
5237

53-
# Only allow two edges: source 0 -> target 0, source 1 -> target 1
54-
rows = [0, 1]
55-
cols = [0, 1]
56-
vals = [np.linalg.norm(X[0] - Y[0]), np.linalg.norm(X[1] - Y[1])]
57-
M_sparse = coo_matrix((vals, (rows, cols)), shape=(2, 2))
38+
# Build sparse cost matrix allowing only selected edges
39+
rows = [0, 1, 2, 3]
40+
cols = [0, 1, 2, 3]
41+
vals = [np.linalg.norm(X[i] - Y[j]) for i, j in zip(rows, cols)]
42+
M_sparse = coo_matrix((vals, (rows, cols)), shape=(4, 4))
5843

5944

6045
##############################################################################
61-
# Solve sparse OT problem
62-
# ------------------------
63-
#
64-
# When passing a sparse cost matrix to ot.emd with log=True, the solution
65-
# is returned in the log dictionary with fields 'flow_sources', 'flow_targets',
66-
# and 'flow_values' containing the edge information.
46+
# Solve and display sparse OT solution
47+
# -------------------------------------
6748

6849
# %%
6950

7051
G, log = ot.emd(a, b, M_sparse, log=True)
7152

7253
print("Sparse OT cost:", log["cost"])
73-
print("Edges:")
74-
for i, j, v in zip(log["flow_sources"], log["flow_targets"], log["flow_values"]):
75-
print(f" source {i} -> target {j}, flow={v:.3f}")
54+
print("Solution format:", type(G))
55+
print("Non-zero edges:", G.nnz)
56+
print("\nEdges:")
57+
G_coo = G if isinstance(G, coo_matrix) else G.tocoo()
58+
for i, j, v in zip(G_coo.row, G_coo.col, G_coo.data):
59+
if v > 1e-10:
60+
print(f" source {i} -> target {j}, flow={v:.3f}")
7661

7762

7863
##############################################################################
79-
# Visualize allowed edges
80-
# ---------------------------------
81-
#
82-
# The sparse cost matrix only allows transport along specific edges.
64+
# Visualize sparse vs dense edge structure
65+
# -----------------------------------------
8366

8467
# %%
8568

86-
8769
plt.figure(figsize=(8, 4))
8870

89-
# Sparse OT: allowed edges only
9071
plt.subplot(1, 2, 1)
9172
plt.scatter(X[:, 0], X[:, 1], c="r", marker="o", s=100, zorder=3)
9273
plt.scatter(Y[:, 0], Y[:, 1], c="b", marker="x", s=100, zorder=3)
9374
for i, j in zip(rows, cols):
94-
plt.plot([X[i, 0], Y[j, 0]], [X[i, 1], Y[j, 1]], "b-", linewidth=2, alpha=0.6)
75+
plt.plot([X[i, 0], Y[j, 0]], [X[i, 1], Y[j, 1]], "b-", linewidth=1, alpha=0.6)
9576
plt.title("Sparse OT: Allowed Edges Only")
96-
97-
plt.xlim(-0.5, 1.5)
77+
plt.xlim(-0.5, 2.0)
9878
plt.ylim(-0.5, 1.5)
99-
plt.xticks([0, 1])
100-
plt.yticks([0, 1])
10179

102-
# Dense OT: all possible edges
10380
plt.subplot(1, 2, 2)
10481
plt.scatter(X[:, 0], X[:, 1], c="r", marker="o", s=100, zorder=3)
10582
plt.scatter(Y[:, 0], Y[:, 1], c="b", marker="x", s=100, zorder=3)
106-
for i in range(2):
107-
for j in range(2):
108-
plt.plot([X[i, 0], Y[j, 0]], [X[i, 1], Y[j, 1]], "b-", linewidth=2, alpha=0.3)
83+
for i in range(len(X)):
84+
for j in range(len(Y)):
85+
plt.plot([X[i, 0], Y[j, 0]], [X[i, 1], Y[j, 1]], "b-", linewidth=1, alpha=0.3)
10986
plt.title("Dense OT: All Possible Edges")
110-
plt.xlim(-0.5, 1.5)
87+
plt.xlim(-0.5, 2.0)
11188
plt.ylim(-0.5, 1.5)
112-
plt.xticks([0, 1])
113-
plt.yticks([0, 1])
11489

11590
plt.tight_layout()
11691

11792

11893
##############################################################################
119-
# Larger example with clusters
120-
# --------------------------------------
121-
#
122-
# Now we create a more realistic example with multiple clusters of sources
123-
# and targets, where transport is only allowed within each cluster.
94+
# Larger example: concentric circles
95+
# -----------------------------------
12496

12597
# %%
12698

127-
grid_size = 4
128-
n_clusters = grid_size * grid_size
129-
points_per_cluster = 2
130-
cluster_spacing = 15.0
131-
intra_cluster_spacing = 1.5
132-
cluster_centers = (
133-
np.array([[i, j] for i in range(grid_size) for j in range(grid_size)])
134-
* cluster_spacing
135-
)
99+
n_clusters = 8
100+
points_per_cluster = 25
101+
n = n_clusters * points_per_cluster
102+
k_neighbors = 8
103+
rng = np.random.default_rng(0)
136104

137-
X_large = []
138-
Y_large = []
139-
a_large = []
140-
b_large = []
105+
r_source = 1.0
106+
r_target = 2.0
107+
noise_scale = 0.06
141108

142-
for idx, (cx, cy) in enumerate(cluster_centers):
143-
for i in range(points_per_cluster):
144-
X_large.append(
145-
[cx + intra_cluster_spacing * (i - 1), cy - intra_cluster_spacing]
146-
)
147-
a_large.append(1.0 / (n_clusters * points_per_cluster))
148-
149-
for i in range(points_per_cluster):
150-
Y_large.append(
151-
[cx + intra_cluster_spacing * (i - 1), cy + intra_cluster_spacing]
152-
)
153-
b_large.append(1.0 / (n_clusters * points_per_cluster))
109+
theta = np.linspace(0.0, 2.0 * np.pi, n, endpoint=False)
110+
cluster_labels = np.repeat(np.arange(n_clusters), points_per_cluster)
154111

155-
X_large = np.array(X_large)
156-
Y_large = np.array(Y_large)
157-
a_large = np.array(a_large)
158-
b_large = np.array(b_large)
159-
160-
nA = nB = n_clusters * points_per_cluster
161-
source_labels = np.repeat(np.arange(n_clusters), points_per_cluster)
162-
sink_labels = np.repeat(np.arange(n_clusters), points_per_cluster)
163-
164-
165-
##############################################################################
166-
# Build sparse cost matrix (intra-cluster only)
167-
# ----------------------------------------------
168-
#
169-
# We construct a sparse cost matrix that only includes edges within each cluster.
112+
X_large = np.column_stack(
113+
[r_source * np.cos(theta), r_source * np.sin(theta)]
114+
) + rng.normal(scale=noise_scale, size=(n, 2))
115+
Y_large = np.column_stack(
116+
[r_target * np.cos(theta), r_target * np.sin(theta)]
117+
) + rng.normal(scale=noise_scale, size=(n, 2))
170118

171-
# %%
119+
a_large = np.zeros(n)
120+
b_large = np.zeros(n)
121+
for k in range(n_clusters):
122+
idx = np.where(cluster_labels == k)[0]
123+
a_large[idx] = 1.0 / n_clusters / points_per_cluster
124+
b_large[idx] = 1.0 / n_clusters / points_per_cluster
172125

173126
M_full = ot.dist(X_large, Y_large, metric="euclidean")
174127

128+
# Build sparse cost matrix: intra-cluster k-nearest neighbors
129+
angles_X = np.arctan2(X_large[:, 1], X_large[:, 0])
130+
angles_Y = np.arctan2(Y_large[:, 1], Y_large[:, 0])
131+
175132
rows = []
176133
cols = []
177134
vals = []
178135
for k in range(n_clusters):
179-
src_idx = np.where(source_labels == k)[0]
180-
sink_idx = np.where(sink_labels == k)[0]
136+
src_idx = np.where(cluster_labels == k)[0]
137+
tgt_idx = np.where(cluster_labels == k)[0]
181138
for i in src_idx:
182-
for j in sink_idx:
139+
diff = np.angle(np.exp(1j * (angles_Y[tgt_idx] - angles_X[i])))
140+
idx = np.argsort(np.abs(diff))[:k_neighbors]
141+
for j_local in idx:
142+
j = tgt_idx[j_local]
183143
rows.append(i)
184144
cols.append(j)
185145
vals.append(M_full[i, j])
186-
M_sparse_large = coo_matrix((vals, (rows, cols)), shape=(nA, nB))
187146

147+
M_sparse_large = coo_matrix((vals, (rows, cols)), shape=(n, n))
148+
allowed_sparse = set(zip(rows, cols))
188149

189150
##############################################################################
190-
# Visualize allowed edges structure
191-
# ----------------------------------
192-
#
193-
# Dense OT allows all connections, while sparse OT restricts to intra-cluster edges.
151+
# Visualize edge structures
152+
# --------------------------
194153

195154
# %%
196155

197156
plt.figure(figsize=(16, 6))
198157

199-
# Dense OT: all possible edges
200158
plt.subplot(1, 2, 1)
201-
for i in range(nA):
202-
for j in range(nB):
159+
for i in range(n):
160+
for j in range(n):
203161
plt.plot(
204162
[X_large[i, 0], Y_large[j, 0]],
205163
[X_large[i, 1], Y_large[j, 1]],
206164
color="blue",
207-
alpha=0.1,
208-
linewidth=0.7,
165+
alpha=0.2,
166+
linewidth=0.05,
209167
)
210168
plt.scatter(X_large[:, 0], X_large[:, 1], c="r", marker="o", s=20)
211169
plt.scatter(Y_large[:, 0], Y_large[:, 1], c="b", marker="x", s=20)
212170
plt.axis("equal")
213171
plt.title("Dense OT: All Possible Edges")
214172

215-
# Sparse OT: only intra-cluster edges
216173
plt.subplot(1, 2, 2)
217-
for k in range(n_clusters):
218-
src_idx = np.where(source_labels == k)[0]
219-
sink_idx = np.where(sink_labels == k)[0]
220-
for i in src_idx:
221-
for j in sink_idx:
222-
plt.plot(
223-
[X_large[i, 0], Y_large[j, 0]],
224-
[X_large[i, 1], Y_large[j, 1]],
225-
color="blue",
226-
alpha=0.7,
227-
linewidth=1.5,
228-
)
174+
for i, j in allowed_sparse:
175+
plt.plot(
176+
[X_large[i, 0], Y_large[j, 0]],
177+
[X_large[i, 1], Y_large[j, 1]],
178+
color="blue",
179+
alpha=1,
180+
linewidth=0.05,
181+
)
229182
plt.scatter(X_large[:, 0], X_large[:, 1], c="r", marker="o", s=20)
230183
plt.scatter(Y_large[:, 0], Y_large[:, 1], c="b", marker="x", s=20)
231184
plt.axis("equal")
232-
plt.title("Sparse OT: Only Intra-Cluster Edges")
185+
plt.title("Sparse OT: Intra-Cluster k-NN Edges")
233186

234187
plt.tight_layout()
235-
188+
plt.show()
236189

237190
##############################################################################
238-
# Solve and compare sparse vs dense OT
239-
# -------------------------------------
240-
#
241-
# We solve both dense and sparse OT problems and verify that they produce
242-
# the same optimal solution when the sparse edges include the optimal paths.
191+
# Solve and visualize transport plans
192+
# ------------------------------------
243193

244194
# %%
245195

246-
# Solve dense OT (full cost matrix)
247196
G_dense = ot.emd(a_large, b_large, M_full)
248197
cost_dense = np.sum(G_dense * M_full)
249198
print(f"Dense OT cost: {cost_dense:.6f}")
250199

251-
# Solve sparse OT (intra-cluster only)
252200
G_sparse, log_sparse = ot.emd(a_large, b_large, M_sparse_large, log=True)
253201
cost_sparse = log_sparse["cost"]
254202
print(f"Sparse OT cost: {cost_sparse:.6f}")
255203

256-
257-
##############################################################################
258-
# Visualize optimal transport plans
259-
# ----------------------------------
260-
#
261-
# Plot the edges that carry flow in the optimal solutions.
262-
263-
# %%
264-
265204
plt.figure(figsize=(16, 6))
266205

267-
# Dense OT
268206
plt.subplot(1, 2, 1)
269-
for i in range(nA):
270-
for j in range(nB):
271-
if G_dense[i, j] > 1e-10:
272-
plt.plot(
273-
[X_large[i, 0], Y_large[j, 0]],
274-
[X_large[i, 1], Y_large[j, 1]],
275-
color="blue",
276-
alpha=0.7,
277-
linewidth=1.5,
278-
)
279-
plt.scatter(X_large[:, 0], X_large[:, 1], c="r", marker="o", s=20)
280-
plt.scatter(Y_large[:, 0], Y_large[:, 1], c="b", marker="x", s=20)
207+
ot.plot.plot2D_samples_mat(
208+
X_large, Y_large, G_dense, thr=1e-10, c=[0.5, 0.5, 1], alpha=0.5
209+
)
210+
plt.scatter(X_large[:, 0], X_large[:, 1], c="r", marker="o", s=20, zorder=3)
211+
plt.scatter(Y_large[:, 0], Y_large[:, 1], c="b", marker="x", s=20, zorder=3)
281212
plt.axis("equal")
282213
plt.title("Dense OT: Optimal Transport Plan")
283214

284-
# Sparse OT
285215
plt.subplot(1, 2, 2)
286-
if log_sparse["flow_sources"] is not None:
287-
for i, j, v in zip(
288-
log_sparse["flow_sources"],
289-
log_sparse["flow_targets"],
290-
log_sparse["flow_values"],
291-
):
292-
if v > 1e-10:
293-
plt.plot(
294-
[X_large[i, 0], Y_large[j, 0]],
295-
[X_large[i, 1], Y_large[j, 1]],
296-
color="blue",
297-
alpha=0.7,
298-
linewidth=1.5,
299-
)
300-
plt.scatter(X_large[:, 0], X_large[:, 1], c="r", marker="o", s=20)
301-
plt.scatter(Y_large[:, 0], Y_large[:, 1], c="b", marker="x", s=20)
216+
ot.plot.plot2D_samples_mat(
217+
X_large, Y_large, G_sparse, thr=1e-10, c=[0.5, 0.5, 1], alpha=0.5
218+
)
219+
plt.scatter(X_large[:, 0], X_large[:, 1], c="r", marker="o", s=20, zorder=3)
220+
plt.scatter(Y_large[:, 0], Y_large[:, 1], c="b", marker="x", s=20, zorder=3)
302221
plt.axis("equal")
303222
plt.title("Sparse OT: Optimal Transport Plan")
304223

0 commit comments

Comments
 (0)