|
10 | 10 |
|
11 | 11 | import numpy as np |
12 | 12 | import warnings |
| 13 | +from scipy.sparse import issparse as scipy_issparse |
13 | 14 |
|
14 | 15 | from ..utils import list_to_array, check_number_threads |
15 | 16 | from ..backend import get_backend |
@@ -298,10 +299,20 @@ def emd( |
298 | 299 | a, b = list_to_array(a, b) |
299 | 300 | nx = get_backend(a, b) |
300 | 301 |
|
301 | | - # Check if M is sparse using backend's issparse method |
302 | | - is_sparse = nx.issparse(M) |
| 302 | + # Check if M is sparse (either backend sparse or scipy.sparse) |
| 303 | + is_sparse = nx.issparse(M) or scipy_issparse(M) |
303 | 304 |
|
304 | 305 | if is_sparse: |
| 306 | + # Check if backend supports sparse matrices |
| 307 | + backend_name = nx.__class__.__name__ |
| 308 | + if backend_name in ["JaxBackend", "TensorflowBackend"]: |
| 309 | + raise NotImplementedError( |
| 310 | + f"Sparse optimal transport is not supported for {backend_name}. " |
| 311 | + "JAX does not have native sparse matrix support, and TensorFlow's " |
| 312 | + "sparse implementation is incomplete. Please convert your sparse " |
| 313 | + "matrix to dense format using M.toarray() or equivalent before calling emd()." |
| 314 | + ) |
| 315 | + |
305 | 316 | # Extract COO data using backend method - returns numpy arrays |
306 | 317 | edge_sources, edge_targets, edge_costs, (n1, n2) = nx.sparse_coo_data(M) |
307 | 318 |
|
@@ -572,10 +583,22 @@ def emd2( |
572 | 583 | a, b = list_to_array(a, b) |
573 | 584 | nx = get_backend(a, b) |
574 | 585 |
|
575 | | - # Check if M is sparse using backend's issparse method |
576 | | - is_sparse = nx.issparse(M) |
| 586 | + # Check if M is sparse (either backend sparse or scipy.sparse) |
| 587 | + from scipy.sparse import issparse as scipy_issparse |
| 588 | + |
| 589 | + is_sparse = nx.issparse(M) or scipy_issparse(M) |
577 | 590 |
|
578 | 591 | if is_sparse: |
| 592 | + # Check if backend supports sparse matrices |
| 593 | + backend_name = nx.__class__.__name__ |
| 594 | + if backend_name in ["JaxBackend", "TensorflowBackend"]: |
| 595 | + raise NotImplementedError( |
| 596 | + f"Sparse optimal transport is not supported for {backend_name}. " |
| 597 | + "JAX does not have native sparse matrix support, and TensorFlow's " |
| 598 | + "sparse implementation is incomplete. Please convert your sparse " |
| 599 | + "matrix to dense format using M.toarray() or equivalent before calling emd2()." |
| 600 | + ) |
| 601 | + |
579 | 602 | # Extract COO data using backend method - returns numpy arrays |
580 | 603 | edge_sources, edge_targets, edge_costs, (n1, n2) = nx.sparse_coo_data(M) |
581 | 604 |
|
|
0 commit comments