Skip to content

Commit

Permalink
Merge 2902091 into b1daa7a
Browse files Browse the repository at this point in the history
  • Loading branch information
sethaxen committed Mar 7, 2022
2 parents b1daa7a + 2902091 commit 945b07e
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.27.0"
version = "1.27.1"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
5 changes: 4 additions & 1 deletion src/rulesets/LinearAlgebra/matfun.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,9 @@ function _matfun!(::typeof(exp), A::StridedMatrix{T}) where {T<:BlasFloat}
X *= X
push!(Xpows, X)
end
else
# Xpows[1] must remain balanced for computing the Fréchet derivative
X = copy(X)
end

_unbalance!(X, ilo, ihi, scale, n)
Expand Down Expand Up @@ -247,7 +250,7 @@ function _matfun_frechet!(
∂P = copy(∂A2)
∂W = C[4] * ∂P
∂V = C[3] * ∂P
for k in 2:(length(Apows) - 1)
for k in 2:length(Apows)
k2 = 2 * k
P = Apows[k - 1]
∂P, ∂temp = mul!(mul!(∂temp, ∂P, A2), P, ∂A2, true, true), ∂P
Expand Down
36 changes: 36 additions & 0 deletions test/rulesets/LinearAlgebra/matfun.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,24 @@
A = Float64[0 10 0 0; -1 0 0 0; 0 0 0 0; -2 0 0 0]
test_frule(LinearAlgebra.exp!, A)
end
@testset "imbalanced A with no squaring" begin
# https://github.com/JuliaDiff/ChainRules.jl/issues/595
A = [
-0.007623430669065629 -0.567237096385192 0.4419041897734335;
2.090838913114862 -1.254084243281689 -0.04145771190198238;
2.3397892123412833 -0.6650489083959324 0.6387266010923911
]
test_frule(LinearAlgebra.exp!, A)
end
@testset "exhaustive test" begin
# added to ensure we never hit truncation error
# https://github.com/JuliaDiff/ChainRules.jl/issues/595
rng = MersenneTwister(1)
for _ in 1:100
A = randn(rng, 3, 3)
test_frule(LinearAlgebra.exp!, A)
end
end
@testset "hermitian A, T=$T" for T in (Float64, ComplexF64)
A = Matrix(Hermitian(randn(T, n, n)))
test_frule(LinearAlgebra.exp!, A)
Expand Down Expand Up @@ -48,6 +66,24 @@
A = Float64[0 10 0 0; -1 0 0 0; 0 0 0 0; -2 0 0 0]
test_rrule(exp, A; check_inferred=false)
end
@testset "imbalanced A with no squaring" begin
# https://github.com/JuliaDiff/ChainRules.jl/issues/595
A = [
-0.007623430669065629 -0.567237096385192 0.4419041897734335;
2.090838913114862 -1.254084243281689 -0.04145771190198238;
2.3397892123412833 -0.6650489083959324 0.6387266010923911
]
test_rrule(LinearAlgebra.exp, A; check_inferred=false)
end
@testset "exhaustive test" begin
# added to ensure we never hit truncation error
# https://github.com/JuliaDiff/ChainRules.jl/issues/595
rng = MersenneTwister(1)
for _ in 1:100
A = randn(rng, 3, 3)
test_rrule(LinearAlgebra.exp, A; check_inferred=false)
end
end
@testset "hermitian A, T=$T" for T in (Float64, ComplexF64)
A = Matrix(Hermitian(randn(T, n, n)))
test_rrule(exp, A; check_inferred=false)
Expand Down

0 comments on commit 945b07e

Please sign in to comment.