Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/src/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ When releasing a new version, move the "Unreleased" changes to a new version sec

### Fixed

- Pullbacks of `eig_trunc`, `eigh_trunc`, and `svd_trunc` no longer error when the truncation strategy keeps no values; `svd_pullback!` also handles the zero-rank case where every singular value falls below `rank_atol` ([#233](https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/pull/233)).

### Performance

## [0.6.7](https://github.com/QuantumKitHub/MatrixAlgebraKit.jl/compare/v0.6.6...v0.6.7) - 2026-05-06
Expand Down
3 changes: 3 additions & 0 deletions src/pullbacks/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ function check_and_prepare_eig_cotangents(
ΔV₊ = nothing
VᴴΔV₁ = zero!(similar(V, (p, p)))
end

bc = Base.broadcasted(transpose(D), D, VᴴΔV₁) do d₁, d₂, v
return abs(d₁ - d₂) < degeneracy_atol ? v : zero(v)
end
Expand Down Expand Up @@ -81,6 +82,7 @@ function eig_pullback!(
D = diagview(Dmat)
n == length(D) || throw(DimensionMismatch())
(n, n) == size(ΔA) || throw(DimensionMismatch())
iszero(n) && return ΔA
ViG = inv(V)'

ΔDmat, ΔV = ΔDV
Expand Down Expand Up @@ -144,6 +146,7 @@ function eig_trunc_pullback!(
(n, n) == size(ΔA) || throw(DimensionMismatch())
D = diagview(Dmat)
p == length(D) || throw(DimensionMismatch())
iszero(p) && return ΔA
G = V' * V
ViG = V / LinearAlgebra.cholesky!(G)

Expand Down
3 changes: 3 additions & 0 deletions src/pullbacks/eigh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ function check_and_prepare_eigh_cotangents(
ΔV₊ = nothing
aVᴴΔV₁ = zero!(similar(V, (p, p)))
end

bc = Base.broadcasted(transpose(D), D, aVᴴΔV₁) do d₁, d₂, v
return abs(d₁ - d₂) < degeneracy_atol ? v : zero(v)
end
Expand Down Expand Up @@ -82,6 +83,7 @@ function eigh_pullback!(
D = diagview(Dmat)
n == length(D) || throw(DimensionMismatch())
(n, n) == size(ΔA) || throw(DimensionMismatch())
iszero(n) && return ΔA

ΔDmat, ΔV = ΔDV
VᴴΔAV, = check_and_prepare_eigh_cotangents(
Expand Down Expand Up @@ -137,6 +139,7 @@ function eigh_trunc_pullback!(
D = diagview(Dmat)
p == length(D) || throw(DimensionMismatch())
(n, n) == size(ΔA) || throw(DimensionMismatch())
iszero(p) && return ΔA

ΔDmat, ΔV = ΔDV
VᴴΔAV, ΔV₊ = check_and_prepare_eigh_cotangents(
Expand Down
3 changes: 3 additions & 0 deletions src/pullbacks/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ function check_and_prepare_svd_cotangents(
ΔV₊ᴴ = nothing
aVᴴΔV₁ = zero!(similar(V₁ᴴ, (r, r)))
end

bc = Base.broadcasted(S₁', S₁, aUᴴΔU₁, aVᴴΔV₁) do s₁, s₂, u, v
return abs(s₁ - s₂) < degeneracy_atol ? u + v : zero(u) + zero(v)
end
Expand Down Expand Up @@ -149,6 +150,7 @@ function svd_pullback!(
(m, n) == size(ΔA) || throw(DimensionMismatch(lazy"size of ΔA ($(size(ΔA))) does not match size of USVᴴ ($m, $n)"))
S = diagview(Smat)
r = svd_rank(S; rank_atol)
iszero(r) && return ΔA

U₁ = view(U, :, 1:r)
V₁ᴴ = view(Vᴴ, 1:r, :)
Expand Down Expand Up @@ -220,6 +222,7 @@ function svd_trunc_pullback!(
p = length(S)
p == size(U, 2) || throw(DimensionMismatch(lazy"U has $p columns but S has $(length(S)) singular values"))
p == size(Vᴴ, 1) || throw(DimensionMismatch(lazy"Vᴴ has $p rows but S has $(length(S)) singular values"))
iszero(p) && return ΔA

# Extract and check the cotangents
ΔU, ΔSmat, ΔVᴴ = ΔUSVᴴ
Expand Down
37 changes: 37 additions & 0 deletions test/testsuite/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,16 @@ function test_chainrules_eig(
output_tangent = ΔDVtrunc, atol = atol, rtol = rtol
)
@test isequal(ΔDVtrunc, ΔDVtrunc_copy)
@testset "empty truncation" begin
truncalg = TruncatedAlgorithm(alg, truncrank(0))
DV, DVtrunc, _, ΔDVtrunc = ad_eig_trunc_setup(A, truncalg)
@test isempty(diagview(DVtrunc[1]))
ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc)
dA1 = MatrixAlgebraKit.eig_pullback!(zero(A), A, DV, ΔDVtrunc, ind)
dA2 = MatrixAlgebraKit.eig_trunc_pullback!(zero(A), A, DVtrunc, ΔDVtrunc)
@test iszero(dA1)
@test iszero(dA2)
end
end
end
end
Expand Down Expand Up @@ -473,6 +483,16 @@ function test_chainrules_eigh(
atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
)
@test isequal(ΔDVtrunc, ΔDVtrunc_copy)
@testset "empty truncation" begin
truncalg = TruncatedAlgorithm(alg, truncrank(0))
DV, DVtrunc, _, ΔDVtrunc = ad_eigh_trunc_setup(A, truncalg)
@test isempty(diagview(DVtrunc[1]))
ind = MatrixAlgebraKit.findtruncated(diagview(DV[1]), truncalg.trunc)
dA1 = MatrixAlgebraKit.eigh_pullback!(zero(A), A, DV, ΔDVtrunc, ind)
dA2 = MatrixAlgebraKit.eigh_trunc_pullback!(zero(A), A, DVtrunc, ΔDVtrunc)
@test iszero(dA1)
@test iszero(dA2)
end
end
end
end
Expand Down Expand Up @@ -624,6 +644,23 @@ function test_chainrules_svd(
atol = atol, rtol = rtol, rrule_f = rrule_via_ad, check_inferred = false
)
@test isequal(ΔUSVᴴtrunc, ΔUSVᴴtrunc_copy)
@testset "empty truncation / zero rank" begin
truncalg = TruncatedAlgorithm(alg, truncrank(0))
USVᴴ, USVᴴtrunc, _, ΔUSVᴴtrunc = ad_svd_trunc_setup(A, truncalg)
@test isempty(diagview(USVᴴtrunc[2]))
ind = MatrixAlgebraKit.findtruncated(diagview(USVᴴ[2]), truncalg.trunc)
dA1 = MatrixAlgebraKit.svd_pullback!(zero(A), A, USVᴴ, ΔUSVᴴtrunc, ind)
dA2 = MatrixAlgebraKit.svd_trunc_pullback!(zero(A), A, USVᴴtrunc, ΔUSVᴴtrunc)
@test iszero(dA1)
@test iszero(dA2)
# svd_pullback! short-circuits when every singular value is below rank_atol
_, ΔUSVᴴ = ad_svd_compact_setup(A)
huge_atol = 2 * maximum(diagview(USVᴴ[2]))
dA3 = MatrixAlgebraKit.svd_pullback!(
zero(A), A, USVᴴ, ΔUSVᴴ; rank_atol = huge_atol
)
@test iszero(dA3)
end
end
end
end
Expand Down
Loading