diff --git a/ext/TensorKitChainRulesCoreExt.jl b/ext/TensorKitChainRulesCoreExt.jl index 61f4d849..5d7881f9 100644 --- a/ext/TensorKitChainRulesCoreExt.jl +++ b/ext/TensorKitChainRulesCoreExt.jl @@ -536,9 +536,8 @@ function qr_pullback!(ΔA::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix, atol::Real=0, rtol::Real=atol > 0 ? 0 : eps(real(eltype(R)))^(3 / 4)) Rd = view(R, diagind(R)) - p = let tol = atol > 0 ? atol : rtol * maximum(abs, Rd) - findlast(x -> abs(x) >= tol, Rd) - end + tol = atol > 0 ? atol : rtol * maximum(abs, Rd) + p = findlast(>=(tol) ∘ abs, Rd) m, n = size(R) Q1 = view(Q, :, 1:p) @@ -548,7 +547,6 @@ function qr_pullback!(ΔA::AbstractMatrix, Q::AbstractMatrix, R::AbstractMatrix, ΔA1 = view(ΔA, :, 1:p) ΔQ1 = view(ΔQ, :, 1:p) ΔR1 = view(ΔR, 1:p, :) - ΔR11 = view(ΔR, 1:p, 1:p) M = similar(R, (p, p)) ΔR isa AbstractZero || mul!(M, ΔR1, R1') @@ -591,9 +589,8 @@ function lq_pullback!(ΔA::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix, atol::Real=0, rtol::Real=atol > 0 ? 0 : eps(real(eltype(L)))^(3 / 4)) Ld = view(L, diagind(L)) - p = let tol = atol > 0 ? atol : rtol * maximum(abs, Ld) - findlast(x -> abs(x) >= tol, Ld) - end + tol = atol > 0 ? atol : rtol * maximum(abs, Ld) + p = findlast(>=(tol) ∘ abs, Ld) m, n = size(L) L1 = view(L, :, 1:p) @@ -603,7 +600,6 @@ function lq_pullback!(ΔA::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix, ΔA1 = view(ΔA, 1:p, :) ΔQ1 = view(ΔQ, 1:p, :) ΔL1 = view(ΔL, :, 1:p) - ΔR11 = view(ΔL, 1:p, 1:p) M = similar(L, (p, p)) ΔL isa AbstractZero || mul!(M, L1', ΔL1)