Commit 54479d5

Fix sparse tensor gradients and add backend checks
- Preserve PyTorch sparse tensors through the numpy conversion so autograd still works
- Verify that the gradient w.r.t. M equals the transport plan
- Add sparse backend compatibility checks and tests; raise an error when an unsupported backend is used with sparse inputs
1 parent 1a3dc41 commit 54479d5
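
In practice, the change means ot.emd2 can take a torch.sparse_coo_tensor cost matrix and backpropagate through it. A minimal usage sketch, mirroring the new tests below (assumes PyTorch is installed; the stored entries of M are the usable edges, chosen here so the problem stays feasible):

import torch
import ot

n = 4
a = torch.tensor(ot.utils.unif(n), dtype=torch.float64, requires_grad=True)
b = torch.tensor(ot.utils.unif(n), dtype=torch.float64, requires_grad=True)

# Sparse cost matrix: the diagonal plus the first upper off-diagonal
rows = torch.tensor([0, 1, 2, 3, 0, 1, 2], dtype=torch.int64)
cols = torch.tensor([0, 1, 2, 3, 1, 2, 3], dtype=torch.int64)
vals = torch.tensor([0.1, 0.1, 0.1, 0.1, 1.0, 1.0, 1.0], dtype=torch.float64)
M = torch.sparse_coo_tensor(
    torch.stack([rows, cols]), vals, (n, n), dtype=torch.float64, requires_grad=True
)

cost = ot.emd2(a, b, M)  # EMD value as a differentiable torch scalar
cost.backward()          # per this commit, M.grad is sparse and equals the optimal plan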

File tree: 4 files changed (+185, −15 lines)

ot/backend.py

Lines changed: 10 additions & 1 deletion
@@ -178,7 +178,16 @@ def _get_backend_instance(backend_impl):
 
 
 def _check_args_backend(backend_impl, args):
-    is_instance = set(isinstance(arg, backend_impl.__type__) for arg in args)
+    # Get backend instance to use issparse method
+    backend = _get_backend_instance(backend_impl)
+
+    # Check if each arg is either:
+    # 1. An instance of backend.__type__ (e.g., np.ndarray for NumPy)
+    # 2. A sparse matrix recognized by backend.issparse() (e.g., scipy.sparse for NumPy)
+    is_instance = set(
+        isinstance(arg, backend_impl.__type__) or backend.issparse(arg) for arg in args
+    )
+
     # check that all arguments matched or not the type
     if len(is_instance) == 1:
         return is_instance.pop()
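
The check above is deliberately all-or-nothing: it collects one boolean per argument and only accepts the call when the set collapses to a single value, so a mix of arrays from different backends falls through and is rejected further up. A standalone sketch of that pattern for the NumPy case (illustrative helper names, not the library's actual code):

import numpy as np
from scipy.sparse import issparse


def _is_numpy_like(arg):
    # Hypothetical helper: NumPy arrays and scipy.sparse matrices both count
    return isinstance(arg, np.ndarray) or issparse(arg)


def _matches_numpy_backend(args):
    # Mirrors the set-based test in _check_args_backend: one flag per argument;
    # a single True means "all match", a single False means "none match",
    # and a mixed set means the inputs are inconsistent.
    flags = set(_is_numpy_like(arg) for arg in args)
    if len(flags) == 1:
        return flags.pop()
    raise ValueError("All array arguments must come from the same backend")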

ot/lp/_network_simplex.py

Lines changed: 37 additions & 14 deletions
@@ -10,7 +10,6 @@
 
 import numpy as np
 import warnings
-from scipy.sparse import issparse as scipy_issparse
 
 from ..utils import list_to_array, check_number_threads
 from ..backend import get_backend
@@ -295,12 +294,14 @@ def emd(
     edge_costs = None
     n1, n2 = None, None
 
-    # Get backend to check if M is sparse
-    a, b = list_to_array(a, b)
-    nx = get_backend(a, b)
+    # Get backend from M first, then use it for list_to_array
+    # This ensures empty lists [] are converted to arrays in the correct backend
+    nx_M = get_backend(M)
+    a, b = list_to_array(a, b, nx=nx_M)
+    nx = get_backend(a, b, M)
 
-    # Check if M is sparse (either backend sparse or scipy.sparse)
-    is_sparse = nx.issparse(M) or scipy_issparse(M)
+    # Check if M is sparse using backend's issparse method
+    is_sparse = nx.issparse(M)
 
     if is_sparse:
         # Check if backend supports sparse matrices
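
The reordering in this hunk matters mostly for the empty-list shorthand: ot.emd accepts a = b = [] to request uniform weights, and resolving the backend from M first means those weights are materialized in M's backend rather than defaulting to NumPy. A minimal sketch of the call it is meant to support (assumes PyTorch and this commit's sparse support):

import torch
import ot

rows = torch.tensor([0, 0, 1, 1], dtype=torch.int64)
cols = torch.tensor([0, 1, 0, 1], dtype=torch.int64)
vals = torch.tensor([0.1, 1.0, 1.0, 0.1], dtype=torch.float64)
M = torch.sparse_coo_tensor(torch.stack([rows, cols]), vals, (2, 2))

# Empty lists mean "uniform weights"; with backend resolution from M first,
# they are created as torch tensors so get_backend(a, b, M) sees one backend.
G = ot.emd([], [], M)  # sparse transport plan in the torch backend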
@@ -579,14 +580,17 @@ def emd2(
     edge_costs = None
     n1, n2 = None, None
 
-    # Get backend to check if M is sparse
-    a, b = list_to_array(a, b)
-    nx = get_backend(a, b)
+    # Get backend from M first, then use it for list_to_array
+    # This ensures empty lists [] are converted to arrays in the correct backend
+    nx_M = get_backend(M)
+    a, b = list_to_array(a, b, nx=nx_M)
+    nx = get_backend(a, b, M)
 
-    # Check if M is sparse (either backend sparse or scipy.sparse)
-    from scipy.sparse import issparse as scipy_issparse
+    # Check if M is sparse using backend's issparse method
+    is_sparse = nx.issparse(M)
 
-    is_sparse = nx.issparse(M) or scipy_issparse(M)
+    # Save original sparse tensor for gradient tracking (before conversion to numpy)
+    M_original_sparse = None
 
     if is_sparse:
         # Check if backend supports sparse matrices

@@ -599,6 +603,9 @@
                 "matrix to dense format using M.toarray() or equivalent before calling emd2()."
             )
 
+        # Save original M for gradient tracking (before numpy conversion)
+        M_original_sparse = M
+
         # Extract COO data using backend method - returns numpy arrays
         edge_sources, edge_targets, edge_costs, (n1, n2) = nx.sparse_coo_data(M)
 
@@ -641,7 +648,9 @@
     M0 = None if is_sparse else M
 
     if is_sparse:
-        edge_costs_original = nx.from_numpy(edge_costs, type_as=type_as)
+        # Use the original sparse tensor (preserves gradients for PyTorch)
+        # instead of converting from numpy
+        edge_costs_original = M_original_sparse
     else:
         edge_costs_original = None
 
@@ -713,13 +722,27 @@ def f(b):
                if edge_idx >= 0:
                    grad_edge_costs[edge_idx] = flow
 
+            # Convert gradient to sparse format matching edge_costs_original
+            grad_edge_costs_backend = nx.from_numpy(grad_edge_costs, type_as=type_as)
+            if nx.issparse(edge_costs_original):
+                # Reconstruct sparse gradient tensor with same structure as original
+                grad_M_sparse = nx.coo_matrix(
+                    grad_edge_costs_backend,
+                    nx.from_numpy(edge_sources.astype(np.int64), type_as=type_as),
+                    nx.from_numpy(edge_targets.astype(np.int64), type_as=type_as),
+                    shape=(n1, n2),
+                    type_as=type_as,
+                )
+            else:
+                grad_M_sparse = grad_edge_costs_backend
+
            cost = nx.set_gradients(
                nx.from_numpy(cost, type_as=type_as),
                (a0, b0, edge_costs_original),
                (
                    nx.from_numpy(u - np.mean(u), type_as=type_as),
                    nx.from_numpy(v - np.mean(v), type_as=type_as),
-                    nx.from_numpy(grad_edge_costs, type_as=type_as),
+                    grad_M_sparse,
                ),
            )
        else:
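
Two things are going on in the emd2 changes. First, edge_costs_original must be the caller's own sparse tensor: rebuilding it with nx.from_numpy would round-trip through NumPy and detach it from PyTorch's autograd graph, leaving set_gradients nothing to attach to. A tiny illustration of that pitfall (generic PyTorch, not POT code):

import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
y = torch.as_tensor(x.detach().numpy())  # round-tripping through NumPy drops the graph
print(y.requires_grad)                   # False: nothing can flow back to x through y

Second, the gradients that set_gradients attaches follow the usual sensitivities of the EMD linear program, which the new code simply reassembles in sparse COO form on the stored edges: with W = emd2(a, b, M), optimal plan gamma* and dual potentials (u, v),

    dW/dM_ij = gamma*_ij,    dW/da_i = u_i - mean(u),    dW/db_j = v_j - mean(v),

which is what the mean-centred u and v and the reconstructed grad_M_sparse express, and what test_emd2_sparse_vs_dense_gradients verifies below.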

test/test_backend.py

Lines changed: 42 additions & 0 deletions
@@ -75,6 +75,48 @@ class nx_subclass(nx.__type__):
     assert effective_nx.__name__ == nx.__name__
 
 
+def test_get_backend_sparse_matrix():
+    """Test that get_backend correctly handles sparse matrices and rejects mixed backends."""
+    from scipy.sparse import coo_matrix
+
+    a_np = np.array([0.5, 0.5])
+    b_np = np.array([0.5, 0.5])
+    M_scipy = coo_matrix(([1.0, 2.0], ([0, 1], [0, 1])), shape=(2, 2))
+
+    nx = get_backend(a_np, b_np, M_scipy)
+    assert nx.__name__ == "numpy", "NumPy backend should accept scipy.sparse matrices"
+
+    nx = get_backend(M_scipy)
+    assert nx.__name__ == "numpy", "scipy.sparse should use NumPy backend"
+
+    if torch:
+        a_torch = torch.tensor([0.5, 0.5])
+        b_torch = torch.tensor([0.5, 0.5])
+        M_torch_sparse = torch.sparse_coo_tensor(
+            torch.tensor([[0, 1], [0, 1]]), torch.tensor([1.0, 2.0]), (2, 2)
+        )
+
+        nx = get_backend(a_torch, b_torch, M_torch_sparse)
+        assert (
+            nx.__name__ == "torch"
+        ), "PyTorch backend should accept torch.sparse tensors"
+
+        nx = get_backend(M_torch_sparse)
+        assert nx.__name__ == "torch", "torch.sparse should use PyTorch backend"
+
+        # Case 1: PyTorch dense + scipy.sparse (incompatible)
+        with pytest.raises(ValueError):
+            get_backend(a_torch, b_torch, M_scipy)
+
+        # Case 2: NumPy dense + torch.sparse (incompatible)
+        with pytest.raises(ValueError):
+            get_backend(a_np, b_np, M_torch_sparse)
+
+        # Case 3: scipy.sparse + torch.sparse (incompatible)
+        with pytest.raises(ValueError):
+            get_backend(M_scipy, M_torch_sparse)
+
+
 def test_convert_between_backends(nx):
     A = np.zeros((3, 2))
     B = np.zeros((3, 1))

test/test_ot.py

Lines changed: 96 additions & 0 deletions
@@ -1083,6 +1083,102 @@ def test_emd2_sparse_vs_dense():
     np.testing.assert_allclose(cost_dense, cost_sparse, rtol=1e-5, atol=1e-7)
 
 
+def test_emd2_sparse_gradients():
+    """Test that PyTorch sparse tensors support gradient computation."""
+    if not torch:
+        pytest.skip("PyTorch not available")
+
+    n = 10
+    a = torch.tensor(ot.utils.unif(n), requires_grad=True, dtype=torch.float64)
+    b = torch.tensor(ot.utils.unif(n), requires_grad=True, dtype=torch.float64)
+
+    rows, cols, costs = [], [], []
+    for i in range(n):
+        rows.append(i)
+        cols.append(i)
+        costs.append(0.1)
+        for offset in [1, 2]:
+            j = (i + offset) % n
+            rows.append(i)
+            cols.append(j)
+            costs.append(float(offset))
+
+    indices = torch.tensor(
+        np.vstack([np.array(rows), np.array(cols)]), dtype=torch.int64
+    )
+    values = torch.tensor(costs, dtype=torch.float64)
+    M_sparse = torch.sparse_coo_tensor(indices, values, (n, n), dtype=torch.float64)
+
+    cost = ot.emd2(a, b, M_sparse)
+    cost.backward()
+
+    assert a.grad is not None
+    assert b.grad is not None
+    np.testing.assert_allclose(
+        a.grad.sum().item(), -b.grad.sum().item(), rtol=1e-5, atol=1e-7
+    )
+
+
+def test_emd2_sparse_vs_dense_gradients():
+    """Verify gradient w.r.t. cost matrix M equals transport plan G."""
+    if not torch:
+        pytest.skip("PyTorch not available")
+
+    n = 4
+    a = torch.tensor([0.25, 0.25, 0.25, 0.25], requires_grad=True, dtype=torch.float64)
+    b = torch.tensor([0.25, 0.25, 0.25, 0.25], requires_grad=True, dtype=torch.float64)
+
+    M_full = torch.tensor(
+        [
+            [0.1, 1.0, 2.0, 3.0],
+            [1.0, 0.1, 1.0, 2.0],
+            [2.0, 1.0, 0.1, 1.0],
+            [3.0, 2.0, 1.0, 0.1],
+        ],
+        dtype=torch.float64,
+        requires_grad=True,
+    )
+
+    cost_dense = ot.emd2(a, b, M_full)
+    cost_dense.backward()
+    G_dense = ot.emd(a.detach(), b.detach(), M_full.detach())
+
+    np.testing.assert_allclose(
+        M_full.grad.numpy(), G_dense.numpy(), rtol=1e-7, atol=1e-10
+    )
+
+    a.grad = None
+    b.grad = None
+
+    rows, cols, costs = [], [], []
+    for i in range(n):
+        for j in range(max(0, i - 1), min(n, i + 2)):
+            rows.append(i)
+            cols.append(j)
+            costs.append(M_full[i, j].item())
+
+    rows_t = torch.tensor(rows, dtype=torch.int64)
+    cols_t = torch.tensor(cols, dtype=torch.int64)
+    M_sparse = torch.sparse_coo_tensor(
+        torch.stack([rows_t, cols_t]),
+        torch.tensor(costs, dtype=torch.float64),
+        (n, n),
+        dtype=torch.float64,
+        requires_grad=True,
+    )
+
+    cost_sparse = ot.emd2(a, b, M_sparse)
+    cost_sparse.backward()
+    G_sparse = ot.emd(a.detach(), b.detach(), M_sparse.detach()).to_dense()
+
+    grad_values = M_sparse.grad.coalesce().values().numpy()
+    G_values = G_sparse[rows_t, cols_t].numpy()
+
+    np.testing.assert_allclose(grad_values, G_values, rtol=1e-7, atol=1e-10)
+    assert grad_values.sum() > 0
+    assert np.abs(grad_values.sum() - 1.0) < 1e-7
+
+
 def test_emd_sparse_backends(nx):
     """Test that sparse EMD works with different backends for weights a and b.
 