Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.66"
version = "0.7.67"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
10 changes: 5 additions & 5 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

@scalar_rule one(x) zero(x)
@scalar_rule zero(x) zero(x)
@scalar_rule transpose(x) One()
@scalar_rule transpose(x) true

# `adjoint`

Expand All @@ -16,7 +16,7 @@ end

# `real`

@scalar_rule real(x::Real) One()
@scalar_rule real(x::Real) true

frule((_, Δz), ::typeof(real), z::Number) = (real(z), real(Δz))

Expand Down Expand Up @@ -75,9 +75,9 @@ function rrule(::typeof(hypot), z::Complex)
return (Ω, hypot_pullback)
end

@scalar_rule fma(x, y, z) (y, x, One())
@scalar_rule muladd(x, y, z) (y, x, One())
@scalar_rule rem2pi(x, r::RoundingMode) (One(), NoTangent())
@scalar_rule fma(x, y, z) (y, x, true)
@scalar_rule muladd(x, y, z) (y, x, true)
@scalar_rule rem2pi(x, r::RoundingMode) (true, NoTangent())
@scalar_rule(
mod(x, y),
@setup((u, nan) = promote(x / y, NaN16), isint = isinteger(x / y)),
Expand Down
6 changes: 3 additions & 3 deletions src/rulesets/Base/fastmath_able.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ let
return (Ω, hypot_pullback)
end

@scalar_rule x + y (One(), One())
@scalar_rule x - y (One(), -1)
@scalar_rule x + y (true, true)
@scalar_rule x - y (true, -1)
@scalar_rule x / y (one(x) / y, -(Ω / y))
#log(complex(x)) is required so it gives correct complex answer for x<0
@scalar_rule(x ^ y,
Expand All @@ -181,7 +181,7 @@ let
@scalar_rule min(x, y) @setup(gt = x > y) (!gt, gt)

# Unary functions
@scalar_rule +x One()
@scalar_rule +x true
@scalar_rule -x -1

# `sign`
Expand Down
34 changes: 17 additions & 17 deletions src/rulesets/LinearAlgebra/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,19 +142,19 @@ end
# ∂U is overwritten if not an `AbstractZero`
function eigen_rev!(A::LinearAlgebra.RealHermSymComplexHerm, λ, U, ∂λ, ∂U)
∂λ isa AbstractZero && ∂U isa AbstractZero && return ∂λ + ∂U
= similar(parent(A), eltype(U))
Ā = similar(parent(A), eltype(U))
tmp = ∂U
if ∂U isa AbstractZero
mul!(, U, real.(∂λ) .* U')
mul!(Ā, U, real.(∂λ) .* U')
else
_eigen_norm_phase_rev!(∂U, A, U)
∂K = mul!(, U', ∂U)
∂K = mul!(Ā, U', ∂U)
∂K ./= λ' .- λ
∂K[diagind(∂K)] .= real.(∂λ)
mul!(tmp, ∂K, U')
mul!(, U, tmp)
mul!(Ā, U, tmp)
end
∂A = _hermitrizelike!(, A)
∂A = _hermitrizelike!(Ā, A)
return ∂A
end

Expand Down Expand Up @@ -296,7 +296,7 @@ end
##### matrix functions
#####

# Formula for frule (Fréchet derivative) from Daleckiĭ-Kreĭn theorem given in Theorem 3.11 of
# Formula for frule (Fréchet derivative) from Daleckiĭ-Kreĭn theorem given in Theorem 3.11 of
# Higham N.J. Functions of Matrices: Theory and Computation. 2008. ISBN: 978-0-898716-46-7.
# rrule is derived from frule. These rules are more stable for degenerate matrices than
# applying the chain rule to the rules for `eigen`.
Expand All @@ -305,12 +305,12 @@ for func in (:exp, :log, :sqrt, :cos, :sin, :tan, :cosh, :sinh, :tanh, :acos, :a
@eval begin
function frule((_, ΔA), ::typeof($func), A::LinearAlgebra.RealHermSymComplexHerm)
Y, intermediates = _matfun($func, A)
= _matfun_frechet($func, ΔA, A, Y, intermediates)
Ȳ = _matfun_frechet($func, ΔA, A, Y, intermediates)
# If ΔA was hermitian, then ∂Y has the same structure as Y
∂Y = if ishermitian(ΔA) && (isa(Y, Symmetric) || isa(Y, Hermitian))
_symhermlike!(, Y)
_symhermlike!(Ȳ, Y)
else
Ȳ
end
return Y, ∂Y
end
Expand All @@ -321,9 +321,9 @@ for func in (:exp, :log, :sqrt, :cos, :sin, :tan, :cosh, :sinh, :tanh, :acos, :a
# for Hermitian Y, we don't need to realify the diagonal of ΔY, since the
# effect is the same as applying _hermitrizelike! at the end
∂Y = eltype(Y) <: Real ? real(ΔY) : ΔY
= _matfun_frechet_adjoint($func, ∂Y, A, Y, intermediates)
Ā = _matfun_frechet_adjoint($func, ∂Y, A, Y, intermediates)
# the cotangent of Hermitian A should be Hermitian
∂A = _hermitrizelike!(, A)
∂A = _hermitrizelike!(Ā, A)
return NO_FIELDS, ∂A
end
return Y, $(Symbol(func, :_pullback))
Expand Down Expand Up @@ -356,19 +356,19 @@ function rrule(::typeof(sincos), A::LinearAlgebra.RealHermSymComplexHerm)
ΔsinA, ΔcosA = real(ΔsinA), real(ΔcosA)
end
if ΔcosA isa AbstractZero
= _matfun_frechet_adjoint(sin, ΔsinA, A, sinA, (λ, U, sinλ, cosλ))
Ā = _matfun_frechet_adjoint(sin, ΔsinA, A, sinA, (λ, U, sinλ, cosλ))
elseif ΔsinA isa AbstractZero
= _matfun_frechet_adjoint(cos, ΔcosA, A, cosA, (λ, U, cosλ, -sinλ))
Ā = _matfun_frechet_adjoint(cos, ΔcosA, A, cosA, (λ, U, cosλ, -sinλ))
else
# we will overwrite tmp with various temporary values during this computation
tmp = mul!(similar(U, Base.promote_eltype(U, ΔsinA, ΔcosA)), ΔsinA, U)
∂sinΛ = mul!(similar(tmp), U', tmp)
∂cosΛ = U' * mul!(tmp, ΔcosA, U)
∂Λ = _muldiffquotmat!!(∂sinΛ, sin, λ, sinλ, cosλ, ∂sinΛ)
∂Λ = _muldiffquotmat!!(∂Λ, cos, λ, cosλ, -sinλ, ∂cosΛ, true)
= mul!(∂Λ, U, mul!(tmp, ∂Λ, U'))
Ā = mul!(∂Λ, U, mul!(tmp, ∂Λ, U'))
end
∂A = _hermitrizelike!(, A)
∂A = _hermitrizelike!(Ā, A)
return NO_FIELDS, ∂A
end
return Y, sincos_pullback
Expand All @@ -386,9 +386,9 @@ Note any function `f` used with this **must** have a `frule` defined on it.
function _matfun(f, A::LinearAlgebra.RealHermSymComplexHerm)
λ, U = eigen(A)
if all(λi -> _isindomain(f, λi), λ)
fλ_df_dλ = map(λi -> frule((ZeroTangent(), One()), f, λi), λ)
fλ_df_dλ = map(λi -> frule((ZeroTangent(), true), f, λi), λ)
else # promote to complex if necessary
fλ_df_dλ = map(λi -> frule((ZeroTangent(), One()), f, complex(λi)), λ)
fλ_df_dλ = map(λi -> frule((ZeroTangent(), true), f, complex(λi)), λ)
end
fλ = first.(fλ_df_dλ)
df_dλ = last.(unthunk.(fλ_df_dλ))
Expand Down
16 changes: 8 additions & 8 deletions src/rulesets/packages/NaNMath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,23 @@
NaNMath.max(x, y),
(ifelse(
(y > x) | (signbit(y) < signbit(x)),
ifelse(isnan(y), One(), ZeroTangent()),
ifelse(isnan(x), ZeroTangent(), One())),
ifelse(isnan(y), true, ZeroTangent()),
ifelse(isnan(x), ZeroTangent(), true)),
ifelse(
(y > x) | (signbit(y) < signbit(x)),
ifelse(isnan(y), ZeroTangent(), One()),
ifelse(isnan(x), One(), ZeroTangent())),
ifelse(isnan(y), ZeroTangent(), true),
ifelse(isnan(x), true, ZeroTangent())),
)
)
@scalar_rule(
NaNMath.min(x, y),
(ifelse(
(y < x) | (signbit(y) > signbit(x)),
ifelse(isnan(y), One(), ZeroTangent()),
ifelse(isnan(x), ZeroTangent(), One())),
ifelse(isnan(y), true, ZeroTangent()),
ifelse(isnan(x), ZeroTangent(), true)),
ifelse(
(y < x) | (signbit(y) > signbit(x)),
ifelse(isnan(y), ZeroTangent(), One()),
ifelse(isnan(x), One(), ZeroTangent())),
ifelse(isnan(y), ZeroTangent(), true),
ifelse(isnan(x), true, ZeroTangent())),
)
)