diff --git a/docs/src/changelog.md b/docs/src/changelog.md index fe61f9df1..b9f38061a 100644 --- a/docs/src/changelog.md +++ b/docs/src/changelog.md @@ -30,6 +30,8 @@ When releasing a new version, move the "Unreleased" changes to a new version sec ### Fixed +- Pullbacks of `eig_trunc`, `eigh_trunc`, and `svd_trunc` no longer error when the truncation strategy keeps no values; `svd_pullback!` also handles the zero-rank case where every singular value falls below `rank_atol` ([#233](https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/233)). + ### Performance ## [0.6.7](https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/compare/v0.6.6...v0.6.7) - 2026-05-06 diff --git a/src/pullbacks/eig.jl b/src/pullbacks/eig.jl index 6ed023547..0d6c2cfde 100755 --- a/src/pullbacks/eig.jl +++ b/src/pullbacks/eig.jl @@ -26,6 +26,7 @@ function check_and_prepare_eig_cotangents( ΔV₊ = nothing VᴴΔV₁ = zero!(similar(V, (p, p))) end + bc = Base.broadcasted(transpose(D), D, VᴴΔV₁) do d₁, d₂, v return abs(d₁ - d₂) < degeneracy_atol ? v : zero(v) end @@ -81,6 +82,7 @@ function eig_pullback!( D = diagview(Dmat) n == length(D) || throw(DimensionMismatch()) (n, n) == size(ΔA) || throw(DimensionMismatch()) + iszero(n) && return ΔA ViG = inv(V)' ΔDmat, ΔV = ΔDV @@ -144,6 +146,7 @@ function eig_trunc_pullback!( (n, n) == size(ΔA) || throw(DimensionMismatch()) D = diagview(Dmat) p == length(D) || throw(DimensionMismatch()) + iszero(p) && return ΔA G = V' * V ViG = V / LinearAlgebra.cholesky!(G) diff --git a/src/pullbacks/eigh.jl b/src/pullbacks/eigh.jl index e9ac87ae4..2aad27faa 100755 --- a/src/pullbacks/eigh.jl +++ b/src/pullbacks/eigh.jl @@ -27,6 +27,7 @@ function check_and_prepare_eigh_cotangents( ΔV₊ = nothing aVᴴΔV₁ = zero!(similar(V, (p, p))) end + bc = Base.broadcasted(transpose(D), D, aVᴴΔV₁) do d₁, d₂, v return abs(d₁ - d₂) < degeneracy_atol ? v : zero(v) end @@ -82,6 +83,7 @@ function eigh_pullback!( D = diagview(Dmat) n == length(D) || throw(DimensionMismatch()) (n, n) == size(ΔA) || throw(DimensionMismatch()) + iszero(n) && return ΔA ΔDmat, ΔV = ΔDV VᴴΔAV, = check_and_prepare_eigh_cotangents( @@ -137,6 +139,7 @@ function eigh_trunc_pullback!( D = diagview(Dmat) p == length(D) || throw(DimensionMismatch()) (n, n) == size(ΔA) || throw(DimensionMismatch()) + iszero(p) && return ΔA ΔDmat, ΔV = ΔDV VᴴΔAV, ΔV₊ = check_and_prepare_eigh_cotangents( diff --git a/src/pullbacks/svd.jl b/src/pullbacks/svd.jl index 832d04a1d..a6e27104a 100755 --- a/src/pullbacks/svd.jl +++ b/src/pullbacks/svd.jl @@ -81,6 +81,7 @@ function check_and_prepare_svd_cotangents( ΔV₊ᴴ = nothing aVᴴΔV₁ = zero!(similar(V₁ᴴ, (r, r))) end + bc = Base.broadcasted(S₁', S₁, aUᴴΔU₁, aVᴴΔV₁) do s₁, s₂, u, v return abs(s₁ - s₂) < degeneracy_atol ? u + v : zero(u) + zero(v) end @@ -149,6 +150,7 @@ function svd_pullback!( (m, n) == size(ΔA) || throw(DimensionMismatch(lazy"size of ΔA ($(size(ΔA))) does not match size of USVᴴ ($m, $n)")) S = diagview(Smat) r = svd_rank(S; rank_atol) + iszero(r) && return ΔA U₁ = view(U, :, 1:r) V₁ᴴ = view(Vᴴ, 1:r, :) @@ -220,6 +222,7 @@ function svd_trunc_pullback!( p = length(S) p == size(U, 2) || throw(DimensionMismatch(lazy"U has $p columns but S has $(length(S)) singular values")) p == size(Vᴴ, 1) || throw(DimensionMismatch(lazy"Vᴴ has $p rows but S has $(length(S)) singular values")) + iszero(p) && return ΔA # Extract and check the cotangents ΔU, ΔSmat, ΔVᴴ = ΔUSVᴴ diff --git a/test/testsuite/chainrules.jl b/test/testsuite/chainrules.jl index af5604881..631d95576 100755 --- a/test/testsuite/chainrules.jl +++ b/test/testsuite/chainrules.jl @@ -329,6 +329,16 @@ function test_chainrules_eig( output_tangent = ΔDVtrunc, atol = atol, rtol = rtol ) @test isequal(ΔDVtrunc, ΔDVtrunc_copy) + @testset "empty truncation" begin + truncalg = TruncatedAlgorithm(alg, truncrank(0)) + DV, DVtrunc, _, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg) + @test isempty(diagview(DVtrunc[1])) + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) + dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, DV, ΔDVtrunc, ind) + dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, DVtrunc, ΔDVtrunc) + @test iszero(dA1) + @test iszero(dA2) + end end end end @@ -473,6 +483,16 @@ function test_chainrules_eigh( atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) @test isequal(ΔDVtrunc, ΔDVtrunc_copy) + @testset "empty truncation" begin + truncalg = TruncatedAlgorithm(alg, truncrank(0)) + DV, DVtrunc, _, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg) + @test isempty(diagview(DVtrunc[1])) + ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc) + dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, DV, ΔDVtrunc, ind) + dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, DVtrunc, ΔDVtrunc) + @test iszero(dA1) + @test iszero(dA2) + end end end end @@ -624,6 +644,23 @@ function test_chainrules_svd( atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false ) @test isequal(ΔUSVᴴtrunc, ΔUSVᴴtrunc_copy) + @testset "empty truncation / zero rank" begin + truncalg = TruncatedAlgorithm(alg, truncrank(0)) + USVᴴ, USVᴴtrunc, _, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg) + @test isempty(diagview(USVᴴtrunc[2])) + ind = MatrixAlgebraKit.findtruncated(diagview(USVᴴ[2]), truncalg.trunc) + dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, USVᴴ, ΔUSVᴴtrunc, ind) + dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, USVᴴtrunc, ΔUSVᴴtrunc) + @test iszero(dA1) + @test iszero(dA2) + # svd_pullback! short-circuits when every singular value is below rank_atol + _, ΔUSVᴴ = ad_svd_compact_setup(A) + huge_atol = 2 * maximum(diagview(USVᴴ[2])) + dA3 = MatrixAlgebraKit.svd_pullback!( + zero(A), A, USVᴴ, ΔUSVᴴ; rank_atol = huge_atol + ) + @test iszero(dA3) + end end end end