From f08e5e178225852c3af03409274e0ffc3d6e31a4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Samuel=20Bo=C3=AFt=C3=A9?=
Date: Thu, 23 Oct 2025 18:46:54 +0200
Subject: [PATCH 1/4] stable matrix sqrt using closed-form diff

---
 ot/backend.py | 32 ++++++++++++++++++++++++++------
 1 file changed, 26 insertions(+), 6 deletions(-)

diff --git a/ot/backend.py b/ot/backend.py
index f14da588b..678189992 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -1938,6 +1938,7 @@ def __init__(self):
         self.rng_cuda_ = torch.Generator("cpu")
 
         from torch.autograd import Function
+        from torch.autograd.function import once_differentiable
 
         # define a function that takes inputs val and grads
         # ad returns a val tensor with proper gradients
@@ -1951,8 +1952,32 @@ def forward(ctx, val, grads, *inputs):
             def backward(ctx, grad_output):
                 # the gradients are grad
                 return (None, None) + tuple(g * grad_output for g in ctx.grads)
+        
+        # define a differentiable SPD matrix sqrt
+        # with closed-form VJP
+        class MatrixSqrtFunction(Function):
+            @staticmethod
+            def forward(ctx, a):
+                a_sym = .5 * (a + a.transpose(-2, -1))
+                L, V = torch.linalg.eigh(a_sym)
+                s = L.clamp_min(0).sqrt()
+                y = (V * s.unsqueeze(-2)) @ V.transpose(-2, -1)
+                ctx.save_for_backward(s, V)
+                return y
+
+            @staticmethod
+            @once_differentiable
+            def backward(ctx, g):
+                s, V = ctx.saved_tensors
+                g_sym = .5 * (g + g.transpose(-2, -1))
+                ghat = V.transpose(-2, -1) @ g_sym @ V
+                d = s.unsqueeze(-1) + s.unsqueeze(-2)
+                xhat = ghat / d
+                xhat = xhat.masked_fill(d == 0, 0)
+                return V @ xhat @ V.transpose(-2, -1)
 
         self.ValFunction = ValFunction
+        self.MatrixSqrtFunction = MatrixSqrtFunction
 
     def _to_numpy(self, a):
         if isinstance(a, float) or isinstance(a, int) or isinstance(a, np.ndarray):
@@ -2395,12 +2420,7 @@ def pinv(self, a, hermitian=False):
         return torch.linalg.pinv(a, hermitian=hermitian)
 
     def sqrtm(self, a):
-        L, V = torch.linalg.eigh(a)
-        L = torch.sqrt(L)
-        # Q[...] = V[...] @ diag(L[...])
-        Q = torch.einsum("...jk,...k->...jk", V, L)
-        # R[...] = Q[...] @ V[...].T
-        return torch.einsum("...jk,...kl->...jl", Q, torch.transpose(V, -1, -2))
+        return self.MatrixSqrtFunction.apply(a)
 
     def eigh(self, a):
         return torch.linalg.eigh(a)
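A sketch of the reasoning behind `MatrixSqrtFunction.backward` above (the notation is illustrative and not part of the patch): for an SPD matrix $A$ with square root $Y = A^{1/2}$,

$$A = Y^2 \quad\Longrightarrow\quad dA = Y\,dY + dY\,Y,$$

so for an upstream cotangent $G$ of $Y$, the gradient $X$ with respect to $A$ solves the Sylvester equation $YX + XY = \tfrac12(G + G^\top)$. In the eigenbasis $A = V\,\mathrm{diag}(\lambda)\,V^\top$ with $s_i = \sqrt{\max(\lambda_i, 0)}$ this is solved entrywise,

$$\hat G = V^\top\,\tfrac12(G + G^\top)\,V, \qquad \hat X_{ij} = \frac{\hat G_{ij}}{s_i + s_j}, \qquad X = V \hat X V^\top,$$

which is what `backward` computes. The denominators $s_i + s_j$ vanish only when both eigenvalues are zero (the `masked_fill` case), so repeated eigenvalues are harmless, whereas the previous `sqrtm` differentiated through `torch.linalg.eigh`, whose autograd rule divides by eigenvalue gaps $\lambda_i - \lambda_j$.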
From e69f98995e8eccb34fea3c91c1ddd5aae6ade7e5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Samuel=20Bo=C3=AFt=C3=A9?=
Date: Thu, 23 Oct 2025 18:49:07 +0200
Subject: [PATCH 2/4] torch matrix sqrt gradcheck

---
 test/test_backend.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/test/test_backend.py b/test/test_backend.py
index 994895fda..c1dce2af2 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -822,6 +822,19 @@ def fun(a, b, d):
     assert nx.allclose(dl_db, b)
 
 
+def test_sqrtm_backward_torch():
+        if not torch:
+                pytest.skip("Torch not available")
+        nx = ot.backend.TorchBackend()
+        torch.manual_seed(42)
+        d = 5
+        A = torch.randn(d, d, dtype=torch.float64, device="cpu")
+        A = A @ A.T
+        A.requires_grad_(True)
+        func = lambda x: nx.sqrtm(x).sum()
+        assert torch.autograd.gradcheck(func, (A,), atol=1e-4, rtol=1e-4)
+
+
 def test_get_backend_none():
     a, b = np.zeros((2, 3)), None
     nx = get_backend(a, b)

From a20b69f4a17ffcb2974fbafa04b664753bb90c79 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Samuel=20Bo=C3=AFt=C3=A9?=
Date: Thu, 23 Oct 2025 19:06:51 +0200
Subject: [PATCH 3/4] pre commit

---
 ot/backend.py        |  6 +++---
 test/test_backend.py | 20 ++++++++++----------
 2 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/ot/backend.py b/ot/backend.py
index 678189992..a11c78209 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -1952,13 +1952,13 @@ def forward(ctx, val, grads, *inputs):
             def backward(ctx, grad_output):
                 # the gradients are grad
                 return (None, None) + tuple(g * grad_output for g in ctx.grads)
-        
+
         # define a differentiable SPD matrix sqrt
         # with closed-form VJP
         class MatrixSqrtFunction(Function):
             @staticmethod
             def forward(ctx, a):
-                a_sym = .5 * (a + a.transpose(-2, -1))
+                a_sym = 0.5 * (a + a.transpose(-2, -1))
                 L, V = torch.linalg.eigh(a_sym)
                 s = L.clamp_min(0).sqrt()
                 y = (V * s.unsqueeze(-2)) @ V.transpose(-2, -1)
@@ -1969,7 +1969,7 @@ def forward(ctx, a):
             @once_differentiable
             def backward(ctx, g):
                 s, V = ctx.saved_tensors
-                g_sym = .5 * (g + g.transpose(-2, -1))
+                g_sym = 0.5 * (g + g.transpose(-2, -1))
                 ghat = V.transpose(-2, -1) @ g_sym @ V
                 d = s.unsqueeze(-1) + s.unsqueeze(-2)
                 xhat = ghat / d
diff --git a/test/test_backend.py b/test/test_backend.py
index c1dce2af2..2a0fc9a48 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -823,16 +823,16 @@ def fun(a, b, d):
 
 
 def test_sqrtm_backward_torch():
-        if not torch:
-                pytest.skip("Torch not available")
-        nx = ot.backend.TorchBackend()
-        torch.manual_seed(42)
-        d = 5
-        A = torch.randn(d, d, dtype=torch.float64, device="cpu")
-        A = A @ A.T
-        A.requires_grad_(True)
-        func = lambda x: nx.sqrtm(x).sum()
-        assert torch.autograd.gradcheck(func, (A,), atol=1e-4, rtol=1e-4)
+    if not torch:
+        pytest.skip("Torch not available")
+    nx = ot.backend.TorchBackend()
+    torch.manual_seed(42)
+    d = 5
+    A = torch.randn(d, d, dtype=torch.float64, device="cpu")
+    A = A @ A.T
+    A.requires_grad_(True)
+    func = lambda x: nx.sqrtm(x).sum()
+    assert torch.autograd.gradcheck(func, (A,), atol=1e-4, rtol=1e-4)
 
 
 def test_get_backend_none():
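The gradcheck above compares the closed-form VJP against torch's numerical Jacobian on a generic SPD matrix. The same identity can be checked outside torch; a minimal NumPy/SciPy sketch (illustrative only, not part of the test suite, and assuming SciPy is available):

```python
# Check <G, d sqrtm(A)[E]> == <X, E>, where X is the closed-form VJP
# used by MatrixSqrtFunction.backward: Xhat_ij = Ghat_ij / (s_i + s_j).
import numpy as np
from scipy.linalg import sqrtm

rng = np.random.default_rng(0)
d = 5
A = rng.standard_normal((d, d))
A = A @ A.T + d * np.eye(d)        # SPD input
G = rng.standard_normal((d, d))    # upstream cotangent
E = rng.standard_normal((d, d))
E = 0.5 * (E + E.T)                # symmetric perturbation direction

# closed-form VJP in the eigenbasis of A
lam, V = np.linalg.eigh(A)
s = np.sqrt(np.clip(lam, 0.0, None))
Ghat = V.T @ (0.5 * (G + G.T)) @ V
X = V @ (Ghat / (s[:, None] + s[None, :])) @ V.T

# central finite difference of the matrix square root along E
eps = 1e-6
dY = (sqrtm(A + eps * E) - sqrtm(A - eps * E)).real / (2 * eps)

print(np.sum(G * dY), np.sum(X * E))  # the two pairings should match closely
```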
From 51b98ffbc37826d64e8e6c5236e576074d3b8779 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Samuel=20Bo=C3=AFt=C3=A9?=
Date: Thu, 23 Oct 2025 19:08:20 +0200
Subject: [PATCH 4/4] edit releases md

---
 RELEASES.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/RELEASES.md b/RELEASES.md
index ae4c3fadc..ff9fe13f5 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -5,6 +5,7 @@
 #### Closed issues
 - Fix deprecated JAX function in `ot.backend.JaxBackend` (PR #771, Issue #770)
 - Add test for build from source (PR #772, Issue #764)
+- Stable `ot.TorchBackend.sqrtm` around repeated eigvals (PR #774, Issue #773)
 
 ## 0.9.6.post1
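A minimal sketch of the case this changelog entry refers to, assuming a build with this series applied: at a matrix with repeated eigenvalues (here the identity), the old eigh/einsum-based `sqrtm` backward divides by eigenvalue gaps and produces non-finite gradients, while the closed-form VJP stays finite.

```python
# Hypothetical reproduction, not part of the PR: gradients of sqrtm at a
# matrix whose eigenvalues are all equal.
import torch
import ot

nx = ot.backend.TorchBackend()

A = torch.eye(3, dtype=torch.float64, requires_grad=True)  # eigenvalues 1, 1, 1
loss = nx.sqrtm(A).sum()
(grad,) = torch.autograd.grad(loss, A)

print(grad)                        # 0.5 everywhere for this loss
assert torch.isfinite(grad).all()  # fails with the previous eigh-based backward
```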