From d46af1027d44945227ab6912e85d2a681e51e8d5 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Mon, 18 May 2026 16:15:18 -0400 Subject: [PATCH 1/5] Guard pullback implementations against empty `ind` --- src/pullbacks/eig.jl | 14 ++++++++------ src/pullbacks/eigh.jl | 14 ++++++++------ src/pullbacks/svd.jl | 9 ++++++--- 3 files changed, 22 insertions(+), 15 deletions(-) diff --git a/src/pullbacks/eig.jl b/src/pullbacks/eig.jl index 6ed023547..ab7680981 100755 --- a/src/pullbacks/eig.jl +++ b/src/pullbacks/eig.jl @@ -26,13 +26,15 @@ function check_and_prepare_eig_cotangents( ΔV₊ = nothing VᴴΔV₁ = zero!(similar(V, (p, p))) end - bc = Base.broadcasted(transpose(D), D, VᴴΔV₁) do d₁, d₂, v - return abs(d₁ - d₂) < degeneracy_atol ? v : zero(v) - end - Δgauge = norm(bc, Inf) - Δgauge ≤ gauge_atol || - @warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + if !isempty(D) # norm(bc, Inf) calls eltype on empty inputs + bc = Base.broadcasted(transpose(D), D, VᴴΔV₁) do d₁, d₂, v + return abs(d₁ - d₂) < degeneracy_atol ? v : zero(v) + end + Δgauge = norm(bc, Inf) + Δgauge ≤ gauge_atol || + @warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + end VᴴΔV₁ .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol)) VᴴAΔV = VᴴΔV₁ diff --git a/src/pullbacks/eigh.jl b/src/pullbacks/eigh.jl index e9ac87ae4..70635bb86 100755 --- a/src/pullbacks/eigh.jl +++ b/src/pullbacks/eigh.jl @@ -27,13 +27,15 @@ function check_and_prepare_eigh_cotangents( ΔV₊ = nothing aVᴴΔV₁ = zero!(similar(V, (p, p))) end - bc = Base.broadcasted(transpose(D), D, aVᴴΔV₁) do d₁, d₂, v - return abs(d₁ - d₂) < degeneracy_atol ? v : zero(v) - end - Δgauge = norm(bc, Inf) - Δgauge ≤ gauge_atol || - @warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + if !isempty(D) # norm(bc, Inf) calls eltype on empty inputs + bc = Base.broadcasted(transpose(D), D, aVᴴΔV₁) do d₁, d₂, v + return abs(d₁ - d₂) < degeneracy_atol ? v : zero(v) + end + Δgauge = norm(bc, Inf) + Δgauge ≤ gauge_atol || + @warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + end aVᴴΔV₁ .*= inv_safe.(D' .- D, degeneracy_atol) VᴴAΔV = aVᴴΔV₁ diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index 832d04a1d..407a38539 100755 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -81,10 +81,13 @@ function check_and_prepare_svd_cotangents( ΔV₊ᴴ = nothing aVᴴΔV₁ = zero!(similar(V₁ᴴ, (r, r))) end - bc = Base.broadcasted(S₁', S₁, aUᴴΔU₁, aVᴴΔV₁) do s₁, s₂, u, v - return abs(s₁ - s₂) < degeneracy_atol ? u + v : zero(u) + zero(v) + + if !isempty(S₁) # norm(bc, Inf) calls eltype for empty iterables + bc = Base.broadcasted(S₁', S₁, aUᴴΔU₁, aVᴴΔV₁) do s₁, s₂, u, v + return abs(s₁ - s₂) < degeneracy_atol ? u + v : zero(u) + zero(v) + end + Δgauge = max(Δgauge, norm(bc, Inf)) end - Δgauge = max(Δgauge, norm(bc, Inf)) if !iszerotangent(ΔSmat) ΔS = diagview(ΔSmat) From 96bc002af4765d074230b79609306b893442d842 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Mon, 18 May 2026 16:32:19 -0400 Subject: [PATCH 2/5] better guard against empty pullbacks --- src/pullbacks/eig.jl | 15 ++++++++------- src/pullbacks/eigh.jl | 15 ++++++++------- src/pullbacks/svd.jl | 10 +++++----- 3 files changed, 21 insertions(+), 19 deletions(-) diff --git a/src/pullbacks/eig.jl b/src/pullbacks/eig.jl index ab7680981..df8cf51a9 100755 --- a/src/pullbacks/eig.jl +++ b/src/pullbacks/eig.jl @@ -27,14 +27,13 @@ function check_and_prepare_eig_cotangents( VᴴΔV₁ = zero!(similar(V, (p, p))) end - if !isempty(D) # norm(bc, Inf) calls eltype on empty inputs - bc = Base.broadcasted(transpose(D), D, VᴴΔV₁) do d₁, d₂, v - return abs(d₁ - d₂) < degeneracy_atol ? v : zero(v) - end - Δgauge = norm(bc, Inf) - Δgauge ≤ gauge_atol || - @warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + bc = Base.broadcasted(transpose(D), D, VᴴΔV₁) do d₁, d₂, v + return abs(d₁ - d₂) < degeneracy_atol ? v : zero(v) end + Δgauge = norm(bc, Inf) + + Δgauge ≤ gauge_atol || + @warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" VᴴΔV₁ .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol)) VᴴAΔV = VᴴΔV₁ @@ -83,6 +82,7 @@ function eig_pullback!( D = diagview(Dmat) n == length(D) || throw(DimensionMismatch()) (n, n) == size(ΔA) || throw(DimensionMismatch()) + isempty(D) && return ΔA ViG = inv(V)' ΔDmat, ΔV = ΔDV @@ -146,6 +146,7 @@ function eig_trunc_pullback!( (n, n) == size(ΔA) || throw(DimensionMismatch()) D = diagview(Dmat) p == length(D) || throw(DimensionMismatch()) + isempty(D) && return ΔA G = V' * V ViG = V / LinearAlgebra.cholesky!(G) diff --git a/src/pullbacks/eigh.jl b/src/pullbacks/eigh.jl index 70635bb86..764080e42 100755 --- a/src/pullbacks/eigh.jl +++ b/src/pullbacks/eigh.jl @@ -28,14 +28,13 @@ function check_and_prepare_eigh_cotangents( aVᴴΔV₁ = zero!(similar(V, (p, p))) end - if !isempty(D) # norm(bc, Inf) calls eltype on empty inputs - bc = Base.broadcasted(transpose(D), D, aVᴴΔV₁) do d₁, d₂, v - return abs(d₁ - d₂) < degeneracy_atol ? v : zero(v) - end - Δgauge = norm(bc, Inf) - Δgauge ≤ gauge_atol || - @warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" + bc = Base.broadcasted(transpose(D), D, aVᴴΔV₁) do d₁, d₂, v + return abs(d₁ - d₂) < degeneracy_atol ? v : zero(v) end + Δgauge = norm(bc, Inf) + + Δgauge ≤ gauge_atol || + @warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" aVᴴΔV₁ .*= inv_safe.(D' .- D, degeneracy_atol) VᴴAΔV = aVᴴΔV₁ @@ -84,6 +83,7 @@ function eigh_pullback!( D = diagview(Dmat) n == length(D) || throw(DimensionMismatch()) (n, n) == size(ΔA) || throw(DimensionMismatch()) + isempty(D) && return ΔA ΔDmat, ΔV = ΔDV VᴴΔAV, = check_and_prepare_eigh_cotangents( @@ -139,6 +139,7 @@ function eigh_trunc_pullback!( D = diagview(Dmat) p == length(D) || throw(DimensionMismatch()) (n, n) == size(ΔA) || throw(DimensionMismatch()) + isempty(D) && return ΔA ΔDmat, ΔV = ΔDV VᴴΔAV, ΔV₊ = check_and_prepare_eigh_cotangents( diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index 407a38539..a6e27104a 100755 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -82,12 +82,10 @@ function check_and_prepare_svd_cotangents( aVᴴΔV₁ = zero!(similar(V₁ᴴ, (r, r))) end - if !isempty(S₁) # norm(bc, Inf) calls eltype for empty iterables - bc = Base.broadcasted(S₁', S₁, aUᴴΔU₁, aVᴴΔV₁) do s₁, s₂, u, v - return abs(s₁ - s₂) < degeneracy_atol ? u + v : zero(u) + zero(v) - end - Δgauge = max(Δgauge, norm(bc, Inf)) + bc = Base.broadcasted(S₁', S₁, aUᴴΔU₁, aVᴴΔV₁) do s₁, s₂, u, v + return abs(s₁ - s₂) < degeneracy_atol ? u + v : zero(u) + zero(v) end + Δgauge = max(Δgauge, norm(bc, Inf)) if !iszerotangent(ΔSmat) ΔS = diagview(ΔSmat) @@ -152,6 +150,7 @@ function svd_pullback!( (m, n) == size(ΔA) || throw(DimensionMismatch(lazy"size of ΔA ($(size(ΔA))) does not match size of USVᴴ ($m, $n)")) S = diagview(Smat) r = svd_rank(S; rank_atol) + iszero(r) && return ΔA U₁ = view(U, :, 1:r) V₁ᴴ = view(Vᴴ, 1:r, :) @@ -223,6 +222,7 @@ function svd_trunc_pullback!( p = length(S) p == size(U, 2) || throw(DimensionMismatch(lazy"U has $p columns but S has $(length(S)) singular values")) p == size(Vᴴ, 1) || throw(DimensionMismatch(lazy"Vᴴ has $p rows but S has $(length(S)) singular values")) + iszero(p) && return ΔA # Extract and check the cotangents ΔU, ΔSmat, ΔVᴴ = ΔUSVᴴ From 7640b21914717e7f4c750a1d396663bc56f1da42 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Mon, 18 May 2026 16:38:19 -0400 Subject: [PATCH 3/5] add test cases --- test/testsuite/chainrules.jl | 37 ++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/test/testsuite/chainrules.jl b/test/testsuite/chainrules.jl index af5604881..631d95576 100755 --- a/test/testsuite/chainrules.jl +++ b/test/testsuite/chainrules.jl @@ -329,6 +329,16 @@ function test_chainrules_eig( output_tangent = ΔDVtrunc, atol = atol, rtol = rtol ) @test isequal(ΔDVtrunc, ΔDVtrunc_copy) + @testset "empty truncation" begin + truncalg = TruncatedAlgorithm(alg, truncrank(0)) + DV, DVtrunc, _, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) + @test isempty(diagview(DVtrunc[1])) + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) + dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, DV, ΔDVtrunc, ind) + dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, DVtrunc, ΔDVtrunc) + @test iszero(dA1) + @test iszero(dA2) + end end end end @@ -473,6 +483,16 @@ function test_chainrules_eigh( atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) @test isequal(ΔDVtrunc, ΔDVtrunc_copy) + @testset "empty truncation" begin + truncalg = TruncatedAlgorithm(alg, truncrank(0)) + DV, DVtrunc, _, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) + @test isempty(diagview(DVtrunc[1])) + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) + dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, DV, ΔDVtrunc, ind) + dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, DVtrunc, ΔDVtrunc) + @test iszero(dA1) + @test iszero(dA2) + end end end end @@ -624,6 +644,23 @@ function test_chainrules_svd( atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) @test isequal(ΔUSVᴴtrunc, ΔUSVᴴtrunc_copy) + @testset "empty truncation / zero rank" begin + truncalg = TruncatedAlgorithm(alg, truncrank(0)) + USVᴴ, USVᴴtrunc, _, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) + @test isempty(diagview(USVᴴtrunc[2])) + ind = MatrixAlgebraKit.findtruncated(diagview(USVᴴ[2]), truncalg.trunc) + dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, USVᴴ, ΔUSVᴴtrunc, ind) + dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, USVᴴtrunc, ΔUSVᴴtrunc) + @test iszero(dA1) + @test iszero(dA2) + # svd_pullback! short-circuits when every singular value is below rank_atol + _, ΔUSVᴴ = ad_svd_compact_setup(A) + huge_atol = 2 * maximum(diagview(USVᴴ[2])) + dA3 = MatrixAlgebraKit.svd_pullback!( + zero(A), A, USVᴴ, ΔUSVᴴ; rank_atol = huge_atol + ) + @test iszero(dA3) + end end end end From de7df2aa5143594d0b9980a4682f5e53ba9b9576 Mon Sep 17 00:00:00 2001 From: lkdvos Date: Mon, 18 May 2026 17:04:27 -0400 Subject: [PATCH 4/5] update changelog --- docs/src/changelog.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/src/changelog.md b/docs/src/changelog.md index fe61f9df1..b9f38061a 100644 --- a/docs/src/changelog.md +++ b/docs/src/changelog.md @@ -30,6 +30,8 @@ When releasing a new version, move the "Unreleased" changes to a new version sec ### Fixed +- Pullbacks of `eig_trunc`, `eigh_trunc`, and `svd_trunc` no longer error when the truncation strategy keeps no values; `svd_pullback!` also handles the zero-rank case where every singular value falls below `rank_atol` ([#233](https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/233)). + ### Performance ## [0.6.7](https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/compare/v0.6.6...v0.6.7) - 2026-05-06 From fa59cc6b826ec1c1ac969c5ee97a47ff9a6b51d2 Mon Sep 17 00:00:00 2001 From: Lukas Devos Date: Tue, 19 May 2026 07:25:58 -0400 Subject: [PATCH 5/5] Apply suggestions from code review Co-authored-by: Jutho --- src/pullbacks/eig.jl | 4 ++-- src/pullbacks/eigh.jl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/pullbacks/eig.jl b/src/pullbacks/eig.jl index df8cf51a9..0d6c2cfde 100755 --- a/src/pullbacks/eig.jl +++ b/src/pullbacks/eig.jl @@ -82,7 +82,7 @@ function eig_pullback!( D = diagview(Dmat) n == length(D) || throw(DimensionMismatch()) (n, n) == size(ΔA) || throw(DimensionMismatch()) - isempty(D) && return ΔA + iszero(n) && return ΔA ViG = inv(V)' ΔDmat, ΔV = ΔDV @@ -146,7 +146,7 @@ function eig_trunc_pullback!( (n, n) == size(ΔA) || throw(DimensionMismatch()) D = diagview(Dmat) p == length(D) || throw(DimensionMismatch()) - isempty(D) && return ΔA + iszero(p) && return ΔA G = V' * V ViG = V / LinearAlgebra.cholesky!(G) diff --git a/src/pullbacks/eigh.jl b/src/pullbacks/eigh.jl index 764080e42..2aad27faa 100755 --- a/src/pullbacks/eigh.jl +++ b/src/pullbacks/eigh.jl @@ -83,7 +83,7 @@ function eigh_pullback!( D = diagview(Dmat) n == length(D) || throw(DimensionMismatch()) (n, n) == size(ΔA) || throw(DimensionMismatch()) - isempty(D) && return ΔA + iszero(n) && return ΔA ΔDmat, ΔV = ΔDV VᴴΔAV, = check_and_prepare_eigh_cotangents( @@ -139,7 +139,7 @@ function eigh_trunc_pullback!( D = diagview(Dmat) p == length(D) || throw(DimensionMismatch()) (n, n) == size(ΔA) || throw(DimensionMismatch()) - isempty(D) && return ΔA + iszero(p) && return ΔA ΔDmat, ΔV = ΔDV VᴴΔAV, ΔV₊ = check_and_prepare_eigh_cotangents(