Skip to content

Commit 1a3dc41

Browse files
committed
throw error when unsupported backend used for sparse
1 parent 1152398 commit 1a3dc41

File tree

1 file changed

+27
-4
lines changed

1 file changed

+27
-4
lines changed

ot/lp/_network_simplex.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import numpy as np
1212
import warnings
13+
from scipy.sparse import issparse as scipy_issparse
1314

1415
from ..utils import list_to_array, check_number_threads
1516
from ..backend import get_backend
@@ -298,10 +299,20 @@ def emd(
298299
a, b = list_to_array(a, b)
299300
nx = get_backend(a, b)
300301

301-
# Check if M is sparse using backend's issparse method
302-
is_sparse = nx.issparse(M)
302+
# Check if M is sparse (either backend sparse or scipy.sparse)
303+
is_sparse = nx.issparse(M) or scipy_issparse(M)
303304

304305
if is_sparse:
306+
# Check if backend supports sparse matrices
307+
backend_name = nx.__class__.__name__
308+
if backend_name in ["JaxBackend", "TensorflowBackend"]:
309+
raise NotImplementedError(
310+
f"Sparse optimal transport is not supported for {backend_name}. "
311+
"JAX does not have native sparse matrix support, and TensorFlow's "
312+
"sparse implementation is incomplete. Please convert your sparse "
313+
"matrix to dense format using M.toarray() or equivalent before calling emd()."
314+
)
315+
305316
# Extract COO data using backend method - returns numpy arrays
306317
edge_sources, edge_targets, edge_costs, (n1, n2) = nx.sparse_coo_data(M)
307318

@@ -572,10 +583,22 @@ def emd2(
572583
a, b = list_to_array(a, b)
573584
nx = get_backend(a, b)
574585

575-
# Check if M is sparse using backend's issparse method
576-
is_sparse = nx.issparse(M)
586+
# Check if M is sparse (either backend sparse or scipy.sparse)
587+
from scipy.sparse import issparse as scipy_issparse
588+
589+
is_sparse = nx.issparse(M) or scipy_issparse(M)
577590

578591
if is_sparse:
592+
# Check if backend supports sparse matrices
593+
backend_name = nx.__class__.__name__
594+
if backend_name in ["JaxBackend", "TensorflowBackend"]:
595+
raise NotImplementedError(
596+
f"Sparse optimal transport is not supported for {backend_name}. "
597+
"JAX does not have native sparse matrix support, and TensorFlow's "
598+
"sparse implementation is incomplete. Please convert your sparse "
599+
"matrix to dense format using M.toarray() or equivalent before calling emd2()."
600+
)
601+
579602
# Extract COO data using backend method - returns numpy arrays
580603
edge_sources, edge_targets, edge_costs, (n1, n2) = nx.sparse_coo_data(M)
581604

0 commit comments

Comments
 (0)