|
4 | 4 | Sparse Optimal Transport |
5 | 5 | ============================================ |
6 | 6 |
|
7 | | -In many real-world optimal transport (OT) problems, the transport plan is naturally sparse: only a small fraction of all possible source-target pairs actually exchange mass. In such cases, using sparse OT solvers can provide significant computational speedups and memory savings compared to dense solvers, which compute and store the full transport matrix. |
8 | | -
|
9 | | -The figure below illustrates the advantages of sparse OT solvers over dense ones in terms of speed and memory usage for different sparsity levels of the transport plan. |
10 | | -
|
11 | | -.. image:: /_static/images/comparison.png |
12 | | - :align: center |
13 | | - :width: 80% |
14 | | - :alt: Dense vs Sparse OT: Speed and Memory Advantages |
| 7 | +In many real-world optimal transport (OT) problems, the transport plan is |
| 8 | +naturally sparse: only a small fraction of all possible source-target pairs |
| 9 | +actually exchange mass. Using sparse OT solvers can provide significant |
| 10 | +computational speedups and memory savings compared to dense solvers. |
| 11 | +
|
| 12 | +This example demonstrates how to use sparse cost matrices with POT's EMD solver, |
| 13 | +comparing sparse and dense formulations on both a minimal example and a larger |
| 14 | +concentric circles dataset. |
15 | 15 | """ |
16 | 16 |
|
17 | | - |
18 | 17 | # Author: Nathan Neike <nathan.neike@example.com> |
19 | 18 | # License: MIT License |
20 | 19 | # sphinx_gallery_thumbnail_number = 2 |
|
26 | 25 |
|
27 | 26 |
|
28 | 27 | ############################################################################## |
29 | | -# Generate minimal example data |
| 28 | +# Minimal example with 4 points |
30 | 29 | # ------------------------------ |
31 | | -# |
32 | | -# We create a simple example with 2 source points and 2 target points to |
33 | | -# illustrate the concept of sparse optimal transport. |
34 | 30 |
|
35 | 31 | # %% |
36 | 32 |
|
37 | | -X = np.array([[0, 0], [1, 0]]) |
38 | | -Y = np.array([[0, 1], [1, 1]]) |
39 | | -a = np.array([0.5, 0.5]) |
40 | | -b = np.array([0.5, 0.5]) |
41 | | - |
42 | | - |
43 | | -############################################################################## |
44 | | -# Build sparse cost matrix |
45 | | -# ------------------------- |
46 | | -# |
47 | | -# Instead of allowing all possible edges (dense OT), we only allow two edges: |
48 | | -# source 0 -> target 0 and source 1 -> target 1. This is specified using a |
49 | | -# sparse matrix format (COO). |
50 | | - |
51 | | -# %% |
| 33 | +X = np.array([[0, 0], [1, 0], [0.5, 0], [1.5, 0]]) |
| 34 | +Y = np.array([[0, 1], [1, 1], [0.5, 1], [1.5, 1]]) |
| 35 | +a = np.array([0.25, 0.25, 0.25, 0.25]) |
| 36 | +b = np.array([0.25, 0.25, 0.25, 0.25]) |
52 | 37 |
|
53 | | -# Only allow two edges: source 0 -> target 0, source 1 -> target 1 |
54 | | -rows = [0, 1] |
55 | | -cols = [0, 1] |
56 | | -vals = [np.linalg.norm(X[0] - Y[0]), np.linalg.norm(X[1] - Y[1])] |
57 | | -M_sparse = coo_matrix((vals, (rows, cols)), shape=(2, 2)) |
| 38 | +# Build a sparse cost matrix that allows only the matched edges (source i -> target i) |
| 39 | +rows = [0, 1, 2, 3] |
| 40 | +cols = [0, 1, 2, 3] |
| 41 | +vals = [np.linalg.norm(X[i] - Y[j]) for i, j in zip(rows, cols)] |
| 42 | +M_sparse = coo_matrix((vals, (rows, cols)), shape=(4, 4)) |
58 | 43 |
|
59 | 44 |
|
60 | 45 | ############################################################################## |
61 | | -# Solve sparse OT problem |
62 | | -# ------------------------ |
63 | | -# |
64 | | -# When passing a sparse cost matrix to ot.emd with log=True, the solution |
65 | | -# is returned in the log dictionary with fields 'flow_sources', 'flow_targets', |
66 | | -# and 'flow_values' containing the edge information. |
| 46 | +# Solve and display sparse OT solution |
| 47 | +# ------------------------------------- |
67 | 48 |
|
68 | 49 | # %% |
69 | 50 |
|
70 | 51 | G, log = ot.emd(a, b, M_sparse, log=True) |
71 | 52 |
|
72 | 53 | print("Sparse OT cost:", log["cost"]) |
73 | | -print("Edges:") |
74 | | -for i, j, v in zip(log["flow_sources"], log["flow_targets"], log["flow_values"]): |
75 | | - print(f" source {i} -> target {j}, flow={v:.3f}") |
| 54 | +print("Solution format:", type(G)) |
| 55 | +print("Non-zero edges:", G.nnz) |
| 56 | +print("\nEdges:") |
| 57 | +G_coo = G if isinstance(G, coo_matrix) else G.tocoo() |
| 58 | +for i, j, v in zip(G_coo.row, G_coo.col, G_coo.data): |
| 59 | + if v > 1e-10: |
| 60 | + print(f" source {i} -> target {j}, flow={v:.3f}") |
76 | 61 |
|
77 | 62 |
|
78 | 63 | ############################################################################## |
79 | | -# Visualize allowed edges |
80 | | -# --------------------------------- |
81 | | -# |
82 | | -# The sparse cost matrix only allows transport along specific edges. |
| 64 | +# Visualize sparse vs dense edge structure |
| 65 | +# ----------------------------------------- |
83 | 66 |
|
84 | 67 | # %% |
85 | 68 |
|
86 | | - |
87 | 69 | plt.figure(figsize=(8, 4)) |
88 | 70 |
|
89 | | -# Sparse OT: allowed edges only |
90 | 71 | plt.subplot(1, 2, 1) |
91 | 72 | plt.scatter(X[:, 0], X[:, 1], c="r", marker="o", s=100, zorder=3) |
92 | 73 | plt.scatter(Y[:, 0], Y[:, 1], c="b", marker="x", s=100, zorder=3) |
93 | 74 | for i, j in zip(rows, cols): |
94 | | - plt.plot([X[i, 0], Y[j, 0]], [X[i, 1], Y[j, 1]], "b-", linewidth=2, alpha=0.6) |
| 75 | + plt.plot([X[i, 0], Y[j, 0]], [X[i, 1], Y[j, 1]], "b-", linewidth=1, alpha=0.6) |
95 | 76 | plt.title("Sparse OT: Allowed Edges Only") |
96 | | - |
97 | | -plt.xlim(-0.5, 1.5) |
| 77 | +plt.xlim(-0.5, 2.0) |
98 | 78 | plt.ylim(-0.5, 1.5) |
99 | | -plt.xticks([0, 1]) |
100 | | -plt.yticks([0, 1]) |
101 | 79 |
|
102 | | -# Dense OT: all possible edges |
103 | 80 | plt.subplot(1, 2, 2) |
104 | 81 | plt.scatter(X[:, 0], X[:, 1], c="r", marker="o", s=100, zorder=3) |
105 | 82 | plt.scatter(Y[:, 0], Y[:, 1], c="b", marker="x", s=100, zorder=3) |
106 | | -for i in range(2): |
107 | | - for j in range(2): |
108 | | - plt.plot([X[i, 0], Y[j, 0]], [X[i, 1], Y[j, 1]], "b-", linewidth=2, alpha=0.3) |
| 83 | +for i in range(len(X)): |
| 84 | + for j in range(len(Y)): |
| 85 | + plt.plot([X[i, 0], Y[j, 0]], [X[i, 1], Y[j, 1]], "b-", linewidth=1, alpha=0.3) |
109 | 86 | plt.title("Dense OT: All Possible Edges") |
110 | | -plt.xlim(-0.5, 1.5) |
| 87 | +plt.xlim(-0.5, 2.0) |
111 | 88 | plt.ylim(-0.5, 1.5) |
112 | | -plt.xticks([0, 1]) |
113 | | -plt.yticks([0, 1]) |
114 | 89 |
|
115 | 90 | plt.tight_layout() |
116 | 91 |
|
117 | 92 |
|
118 | 93 | ############################################################################## |
119 | | -# Larger example with clusters |
120 | | -# -------------------------------------- |
121 | | -# |
122 | | -# Now we create a more realistic example with multiple clusters of sources |
123 | | -# and targets, where transport is only allowed within each cluster. |
| 94 | +# Larger example: concentric circles |
| 95 | +# ----------------------------------- |
124 | 96 |
|
125 | 97 | # %% |
126 | 98 |
|
127 | | -grid_size = 4 |
128 | | -n_clusters = grid_size * grid_size |
129 | | -points_per_cluster = 2 |
130 | | -cluster_spacing = 15.0 |
131 | | -intra_cluster_spacing = 1.5 |
132 | | -cluster_centers = ( |
133 | | - np.array([[i, j] for i in range(grid_size) for j in range(grid_size)]) |
134 | | - * cluster_spacing |
135 | | -) |
| 99 | +n_clusters = 8 |
| 100 | +points_per_cluster = 25 |
| 101 | +n = n_clusters * points_per_cluster |
| 102 | +k_neighbors = 8 |
| 103 | +rng = np.random.default_rng(0) |
136 | 104 |
|
137 | | -X_large = [] |
138 | | -Y_large = [] |
139 | | -a_large = [] |
140 | | -b_large = [] |
| 105 | +r_source = 1.0 |
| 106 | +r_target = 2.0 |
| 107 | +noise_scale = 0.06 |
141 | 108 |
|
142 | | -for idx, (cx, cy) in enumerate(cluster_centers): |
143 | | - for i in range(points_per_cluster): |
144 | | - X_large.append( |
145 | | - [cx + intra_cluster_spacing * (i - 1), cy - intra_cluster_spacing] |
146 | | - ) |
147 | | - a_large.append(1.0 / (n_clusters * points_per_cluster)) |
148 | | - |
149 | | - for i in range(points_per_cluster): |
150 | | - Y_large.append( |
151 | | - [cx + intra_cluster_spacing * (i - 1), cy + intra_cluster_spacing] |
152 | | - ) |
153 | | - b_large.append(1.0 / (n_clusters * points_per_cluster)) |
| 109 | +theta = np.linspace(0.0, 2.0 * np.pi, n, endpoint=False) |
| 110 | +cluster_labels = np.repeat(np.arange(n_clusters), points_per_cluster) |
154 | 111 |
|
155 | | -X_large = np.array(X_large) |
156 | | -Y_large = np.array(Y_large) |
157 | | -a_large = np.array(a_large) |
158 | | -b_large = np.array(b_large) |
159 | | - |
160 | | -nA = nB = n_clusters * points_per_cluster |
161 | | -source_labels = np.repeat(np.arange(n_clusters), points_per_cluster) |
162 | | -sink_labels = np.repeat(np.arange(n_clusters), points_per_cluster) |
163 | | - |
164 | | - |
165 | | -############################################################################## |
166 | | -# Build sparse cost matrix (intra-cluster only) |
167 | | -# ---------------------------------------------- |
168 | | -# |
169 | | -# We construct a sparse cost matrix that only includes edges within each cluster. |
| 112 | +X_large = np.column_stack( |
| 113 | + [r_source * np.cos(theta), r_source * np.sin(theta)] |
| 114 | +) + rng.normal(scale=noise_scale, size=(n, 2)) |
| 115 | +Y_large = np.column_stack( |
| 116 | + [r_target * np.cos(theta), r_target * np.sin(theta)] |
| 117 | +) + rng.normal(scale=noise_scale, size=(n, 2)) |
170 | 118 |
|
171 | | -# %% |
| 119 | +a_large = np.zeros(n) |
| 120 | +b_large = np.zeros(n) |
| 121 | +for k in range(n_clusters): |
| 122 | + idx = np.where(cluster_labels == k)[0] |
| 123 | + a_large[idx] = 1.0 / n_clusters / points_per_cluster |
| 124 | + b_large[idx] = 1.0 / n_clusters / points_per_cluster |
172 | 125 |
|
173 | 126 | M_full = ot.dist(X_large, Y_large, metric="euclidean") |
174 | 127 |
|
| 128 | +# Build sparse cost matrix: each source keeps its k_neighbors angularly nearest same-cluster targets |
| 129 | +angles_X = np.arctan2(X_large[:, 1], X_large[:, 0]) |
| 130 | +angles_Y = np.arctan2(Y_large[:, 1], Y_large[:, 0]) |
| 131 | + |
175 | 132 | rows = [] |
176 | 133 | cols = [] |
177 | 134 | vals = [] |
178 | 135 | for k in range(n_clusters): |
179 | | - src_idx = np.where(source_labels == k)[0] |
180 | | - sink_idx = np.where(sink_labels == k)[0] |
| 136 | + src_idx = np.where(cluster_labels == k)[0] |
| 137 | + tgt_idx = np.where(cluster_labels == k)[0] |
181 | 138 | for i in src_idx: |
182 | | - for j in sink_idx: |
| 139 | + diff = np.angle(np.exp(1j * (angles_Y[tgt_idx] - angles_X[i]))) |
| 140 | + idx = np.argsort(np.abs(diff))[:k_neighbors] |
| 141 | + for j_local in idx: |
| 142 | + j = tgt_idx[j_local] |
183 | 143 | rows.append(i) |
184 | 144 | cols.append(j) |
185 | 145 | vals.append(M_full[i, j]) |
186 | | -M_sparse_large = coo_matrix((vals, (rows, cols)), shape=(nA, nB)) |
187 | 146 |
|
| 147 | +M_sparse_large = coo_matrix((vals, (rows, cols)), shape=(n, n)) |
| 148 | +allowed_sparse = set(zip(rows, cols)) |
188 | 149 |
|
189 | 150 | ############################################################################## |
190 | | -# Visualize allowed edges structure |
191 | | -# ---------------------------------- |
192 | | -# |
193 | | -# Dense OT allows all connections, while sparse OT restricts to intra-cluster edges. |
| 151 | +# Visualize edge structures |
| 152 | +# -------------------------- |
194 | 153 |
|
195 | 154 | # %% |
196 | 155 |
|
197 | 156 | plt.figure(figsize=(16, 6)) |
198 | 157 |
|
199 | | -# Dense OT: all possible edges |
200 | 158 | plt.subplot(1, 2, 1) |
201 | | -for i in range(nA): |
202 | | - for j in range(nB): |
| 159 | +for i in range(n): |
| 160 | + for j in range(n): |
203 | 161 | plt.plot( |
204 | 162 | [X_large[i, 0], Y_large[j, 0]], |
205 | 163 | [X_large[i, 1], Y_large[j, 1]], |
206 | 164 | color="blue", |
207 | | - alpha=0.1, |
208 | | - linewidth=0.7, |
| 165 | + alpha=0.2, |
| 166 | + linewidth=0.05, |
209 | 167 | ) |
210 | 168 | plt.scatter(X_large[:, 0], X_large[:, 1], c="r", marker="o", s=20) |
211 | 169 | plt.scatter(Y_large[:, 0], Y_large[:, 1], c="b", marker="x", s=20) |
212 | 170 | plt.axis("equal") |
213 | 171 | plt.title("Dense OT: All Possible Edges") |
214 | 172 |
|
215 | | -# Sparse OT: only intra-cluster edges |
216 | 173 | plt.subplot(1, 2, 2) |
217 | | -for k in range(n_clusters): |
218 | | - src_idx = np.where(source_labels == k)[0] |
219 | | - sink_idx = np.where(sink_labels == k)[0] |
220 | | - for i in src_idx: |
221 | | - for j in sink_idx: |
222 | | - plt.plot( |
223 | | - [X_large[i, 0], Y_large[j, 0]], |
224 | | - [X_large[i, 1], Y_large[j, 1]], |
225 | | - color="blue", |
226 | | - alpha=0.7, |
227 | | - linewidth=1.5, |
228 | | - ) |
| 174 | +for i, j in allowed_sparse: |
| 175 | + plt.plot( |
| 176 | + [X_large[i, 0], Y_large[j, 0]], |
| 177 | + [X_large[i, 1], Y_large[j, 1]], |
| 178 | + color="blue", |
| 179 | + alpha=1, |
| 180 | + linewidth=0.05, |
| 181 | + ) |
229 | 182 | plt.scatter(X_large[:, 0], X_large[:, 1], c="r", marker="o", s=20) |
230 | 183 | plt.scatter(Y_large[:, 0], Y_large[:, 1], c="b", marker="x", s=20) |
231 | 184 | plt.axis("equal") |
232 | | -plt.title("Sparse OT: Only Intra-Cluster Edges") |
| 185 | +plt.title("Sparse OT: Intra-Cluster k-NN Edges") |
233 | 186 |
|
234 | 187 | plt.tight_layout() |
235 | | - |
| 188 | +plt.show() |
236 | 189 |
|
237 | 190 | ############################################################################## |
238 | | -# Solve and compare sparse vs dense OT |
239 | | -# ------------------------------------- |
240 | | -# |
241 | | -# We solve both dense and sparse OT problems and verify that they produce |
242 | | -# the same optimal solution when the sparse edges include the optimal paths. |
| 191 | +# Solve and visualize transport plans |
| 192 | +# ------------------------------------ |
243 | 193 |
|
244 | 194 | # %% |
245 | 195 |
|
246 | | -# Solve dense OT (full cost matrix) |
247 | 196 | G_dense = ot.emd(a_large, b_large, M_full) |
248 | 197 | cost_dense = np.sum(G_dense * M_full) |
249 | 198 | print(f"Dense OT cost: {cost_dense:.6f}") |
250 | 199 |
|
251 | | -# Solve sparse OT (intra-cluster only) |
252 | 200 | G_sparse, log_sparse = ot.emd(a_large, b_large, M_sparse_large, log=True) |
253 | 201 | cost_sparse = log_sparse["cost"] |
254 | 202 | print(f"Sparse OT cost: {cost_sparse:.6f}") |
255 | 203 |
|
256 | | - |
257 | | -############################################################################## |
258 | | -# Visualize optimal transport plans |
259 | | -# ---------------------------------- |
260 | | -# |
261 | | -# Plot the edges that carry flow in the optimal solutions. |
262 | | - |
263 | | -# %% |
264 | | - |
265 | 204 | plt.figure(figsize=(16, 6)) |
266 | 205 |
|
267 | | -# Dense OT |
268 | 206 | plt.subplot(1, 2, 1) |
269 | | -for i in range(nA): |
270 | | - for j in range(nB): |
271 | | - if G_dense[i, j] > 1e-10: |
272 | | - plt.plot( |
273 | | - [X_large[i, 0], Y_large[j, 0]], |
274 | | - [X_large[i, 1], Y_large[j, 1]], |
275 | | - color="blue", |
276 | | - alpha=0.7, |
277 | | - linewidth=1.5, |
278 | | - ) |
279 | | -plt.scatter(X_large[:, 0], X_large[:, 1], c="r", marker="o", s=20) |
280 | | -plt.scatter(Y_large[:, 0], Y_large[:, 1], c="b", marker="x", s=20) |
| 207 | +ot.plot.plot2D_samples_mat( |
| 208 | + X_large, Y_large, G_dense, thr=1e-10, c=[0.5, 0.5, 1], alpha=0.5 |
| 209 | +) |
| 210 | +plt.scatter(X_large[:, 0], X_large[:, 1], c="r", marker="o", s=20, zorder=3) |
| 211 | +plt.scatter(Y_large[:, 0], Y_large[:, 1], c="b", marker="x", s=20, zorder=3) |
281 | 212 | plt.axis("equal") |
282 | 213 | plt.title("Dense OT: Optimal Transport Plan") |
283 | 214 |
|
284 | | -# Sparse OT |
285 | 215 | plt.subplot(1, 2, 2) |
286 | | -if log_sparse["flow_sources"] is not None: |
287 | | - for i, j, v in zip( |
288 | | - log_sparse["flow_sources"], |
289 | | - log_sparse["flow_targets"], |
290 | | - log_sparse["flow_values"], |
291 | | - ): |
292 | | - if v > 1e-10: |
293 | | - plt.plot( |
294 | | - [X_large[i, 0], Y_large[j, 0]], |
295 | | - [X_large[i, 1], Y_large[j, 1]], |
296 | | - color="blue", |
297 | | - alpha=0.7, |
298 | | - linewidth=1.5, |
299 | | - ) |
300 | | -plt.scatter(X_large[:, 0], X_large[:, 1], c="r", marker="o", s=20) |
301 | | -plt.scatter(Y_large[:, 0], Y_large[:, 1], c="b", marker="x", s=20) |
| 216 | +ot.plot.plot2D_samples_mat( |
| 217 | + X_large, Y_large, G_sparse, thr=1e-10, c=[0.5, 0.5, 1], alpha=0.5 |
| 218 | +) |
| 219 | +plt.scatter(X_large[:, 0], X_large[:, 1], c="r", marker="o", s=20, zorder=3) |
| 220 | +plt.scatter(Y_large[:, 0], Y_large[:, 1], c="b", marker="x", s=20, zorder=3) |
302 | 221 | plt.axis("equal") |
303 | 222 | plt.title("Sparse OT: Optimal Transport Plan") |
304 | 223 |
|
|
0 commit comments