Skip to content

Commit

Permalink
Merge 136acea into 78a0028
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Jul 22, 2020
2 parents 78a0028 + 136acea commit d717bb4
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 177 deletions.
1 change: 1 addition & 0 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ end
include("rulesets/Base/utils.jl")
include("rulesets/Base/base.jl")
include("rulesets/Base/fastmath_able.jl")
include("rulesets/Base/evalpoly.jl")
include("rulesets/Base/array.jl")
include("rulesets/Base/mapreduce.jl")

Expand Down
142 changes: 0 additions & 142 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,145 +151,3 @@ function rrule(::typeof(identity), x)
return (x, identity_pullback)
end

#####
##### `evalpoly`
#####

if VERSION v"1.4"
function frule((_, Δx, Δp), ::typeof(evalpoly), x, p)
N = length(p)
@inbounds y = p[N]
Δy = Δp[N]
@inbounds for i in (N - 1):-1:1
Δy = muladd(Δx, y, muladd(x, Δy, Δp[i]))
y = muladd(x, y, p[i])
end
return y, Δy
end

function rrule(::typeof(evalpoly), x, p)
y, ys = _evalpoly_intermediates(x, p)
function evalpoly_pullback(Δy)
∂x, ∂p = _evalpoly_back(x, p, ys, Δy)
return NO_FIELDS, ∂x, ∂p
end
return y, evalpoly_pullback
end

# evalpoly but storing intermediates
function _evalpoly_intermediates(x, p::Tuple)
return if @generated
N = length(p.parameters)
exs = []
vars = []
ex = :(p[$N])
for i in 1:(N - 1)
yi = Symbol("y", i)
push!(vars, yi)
push!(exs, :($yi = $ex))
ex = :(muladd(x, $yi, p[$(N - i)]))
end
push!(exs, :(y = $ex))
Expr(:block, exs..., :(y, ($(vars...),)))
else
_evalpoly_intermediates_fallback(x, p)
end
end
function _evalpoly_intermediates_fallback(x, p::Tuple)
N = length(p)
y = p[N]
ys = (y, ntuple(N - 2) do i
return y = muladd(x, y, p[N - i])
end...)
y = muladd(x, y, p[1])
return y, ys
end
function _evalpoly_intermediates(x, p)
N = length(p)
@inbounds yn = one(x) * p[N]
ys = similar(p, typeof(yn), N - 1)
@inbounds ys[1] = yn
@inbounds for i in 2:(N - 1)
ys[i] = muladd(x, ys[i - 1], p[N - i + 1])
end
@inbounds y = muladd(x, ys[N - 1], p[1])
return y, ys
end

# TODO: Handle following cases
# 1) x is a UniformScaling, pᵢ is a matrix
# 2) x is a matrix, pᵢ is a UniformScaling
@inline _evalpoly_backx(x, yi, ∂yi) = ∂yi * yi'
@inline _evalpoly_backx(x, yi, ∂x, ∂yi) = muladd(∂yi, yi', ∂x)
@inline _evalpoly_backx(x::Number, yi, ∂yi) = conj(dot(∂yi, yi))
@inline _evalpoly_backx(x::Number, yi, ∂x, ∂yi) = _evalpoly_backx(x, yi, ∂yi) + ∂x

@inline _evalpoly_backp(pi, ∂yi) = ∂yi

function _evalpoly_back(x, p::Tuple, ys, Δy)
return if @generated
exs = []
vars = []
N = length(p.parameters)
for i in 2:(N - 1)
∂pi = Symbol("∂p", i)
push!(vars, ∂pi)
push!(exs, :(∂x = _evalpoly_backx(x, ys[$(N - i)], ∂x, ∂yi)))
push!(exs, :($∂pi = _evalpoly_backp(p[$i], ∂yi)))
push!(exs, :(∂yi = x′ * ∂yi))
end
push!(vars, :(_evalpoly_backp(p[$N], ∂yi))) # ∂pN
Expr(
:block,
:(x′ = x'),
:(∂yi = Δy),
:(∂p1 = _evalpoly_backp(p[1], ∂yi)),
:(∂x = _evalpoly_backx(x, ys[$(N - 1)], ∂yi)),
:(∂yi = x′ * ∂yi),
exs...,
:(∂p = (∂p1, $(vars...))),
:(∂x, Composite{typeof(p),typeof(∂p)}(∂p)),
)
else
_evalpoly_back_fallback(x, p, ys, Δy)
end
end
function _evalpoly_back_fallback(x, p::Tuple, ys, Δy)
x′ = x'
∂yi = Δy
N = length(p)
∂p1 = _evalpoly_backp(p[1], ∂yi)
∂x = _evalpoly_backx(x, ys[N - 1], ∂yi)
∂yi = x′ * ∂yi
∂p = (
∂p1,
ntuple(N - 2) do i
∂x = _evalpoly_backx(x, ys[N-i-1], ∂x, ∂yi)
∂pi = _evalpoly_backp(p[i+1], ∂yi)
∂yi = x′ * ∂yi
return ∂pi
end...,
_evalpoly_backp(p[N], ∂yi), # ∂pN
)
return ∂x, Composite{typeof(p),typeof(∂p)}(∂p)
end
function _evalpoly_back(x, p, ys, Δy)
x′ = x'
∂yi = one(x′) * Δy
N = length(p)
@inbounds ∂p1 = _evalpoly_backp(p[1], ∂yi)
∂p = similar(p, typeof(∂p1))
@inbounds begin
∂x = _evalpoly_backx(x, ys[N - 1], ∂yi)
∂yi = x′ * ∂yi
∂p[1] = ∂p1
for i in 2:(N - 1)
∂x = _evalpoly_backx(x, ys[N - i], ∂x, ∂yi)
∂p[i] = _evalpoly_backp(p[i], ∂yi)
∂yi = x′ * ∂yi
end
∂p[N] = _evalpoly_backp(p[N], ∂yi)
end
return ∂x, ∂p
end
end
139 changes: 139 additions & 0 deletions src/rulesets/Base/evalpoly.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@

if VERSION v"1.4"
function frule((_, Δx, Δp), ::typeof(evalpoly), x, p)
N = length(p)
@inbounds y = p[N]
Δy = Δp[N]
@inbounds for i in (N - 1):-1:1
Δy = muladd(Δx, y, muladd(x, Δy, Δp[i]))
y = muladd(x, y, p[i])
end
return y, Δy
end

function rrule(::typeof(evalpoly), x, p)
y, ys = _evalpoly_intermediates(x, p)
function evalpoly_pullback(Δy)
∂x, ∂p = _evalpoly_back(x, p, ys, Δy)
return NO_FIELDS, ∂x, ∂p
end
return y, evalpoly_pullback
end

# evalpoly but storing intermediates
function _evalpoly_intermediates(x, p::Tuple)
return if @generated
N = length(p.parameters)
exs = []
vars = []
ex = :(p[$N])
for i in 1:(N - 1)
yi = Symbol("y", i)
push!(vars, yi)
push!(exs, :($yi = $ex))
ex = :(muladd(x, $yi, p[$(N - i)]))
end
push!(exs, :(y = $ex))
Expr(:block, exs..., :(y, ($(vars...),)))
else
_evalpoly_intermediates_fallback(x, p)
end
end
function _evalpoly_intermediates_fallback(x, p::Tuple)
N = length(p)
y = p[N]
ys = (y, ntuple(N - 2) do i
return y = muladd(x, y, p[N - i])
end...)
y = muladd(x, y, p[1])
return y, ys
end
function _evalpoly_intermediates(x, p)
N = length(p)
@inbounds yn = one(x) * p[N]
ys = similar(p, typeof(yn), N - 1)
@inbounds ys[1] = yn
@inbounds for i in 2:(N - 1)
ys[i] = muladd(x, ys[i - 1], p[N - i + 1])
end
@inbounds y = muladd(x, ys[N - 1], p[1])
return y, ys
end

# TODO: Handle following cases
# 1) x is a UniformScaling, pᵢ is a matrix
# 2) x is a matrix, pᵢ is a UniformScaling
@inline _evalpoly_backx(x, yi, ∂yi) = ∂yi * yi'
@inline _evalpoly_backx(x, yi, ∂x, ∂yi) = muladd(∂yi, yi', ∂x)
@inline _evalpoly_backx(x::Number, yi, ∂yi) = conj(dot(∂yi, yi))
@inline _evalpoly_backx(x::Number, yi, ∂x, ∂yi) = _evalpoly_backx(x, yi, ∂yi) + ∂x

@inline _evalpoly_backp(pi, ∂yi) = ∂yi

function _evalpoly_back(x, p::Tuple, ys, Δy)
return if @generated
exs = []
vars = []
N = length(p.parameters)
for i in 2:(N - 1)
∂pi = Symbol("∂p", i)
push!(vars, ∂pi)
push!(exs, :(∂x = _evalpoly_backx(x, ys[$(N - i)], ∂x, ∂yi)))
push!(exs, :($∂pi = _evalpoly_backp(p[$i], ∂yi)))
push!(exs, :(∂yi = x′ * ∂yi))
end
push!(vars, :(_evalpoly_backp(p[$N], ∂yi))) # ∂pN
Expr(
:block,
:(x′ = x'),
:(∂yi = Δy),
:(∂p1 = _evalpoly_backp(p[1], ∂yi)),
:(∂x = _evalpoly_backx(x, ys[$(N - 1)], ∂yi)),
:(∂yi = x′ * ∂yi),
exs...,
:(∂p = (∂p1, $(vars...))),
:(∂x, Composite{typeof(p),typeof(∂p)}(∂p)),
)
else
_evalpoly_back_fallback(x, p, ys, Δy)
end
end
function _evalpoly_back_fallback(x, p::Tuple, ys, Δy)
x′ = x'
∂yi = Δy
N = length(p)
∂p1 = _evalpoly_backp(p[1], ∂yi)
∂x = _evalpoly_backx(x, ys[N - 1], ∂yi)
∂yi = x′ * ∂yi
∂p = (
∂p1,
ntuple(N - 2) do i
∂x = _evalpoly_backx(x, ys[N-i-1], ∂x, ∂yi)
∂pi = _evalpoly_backp(p[i+1], ∂yi)
∂yi = x′ * ∂yi
return ∂pi
end...,
_evalpoly_backp(p[N], ∂yi), # ∂pN
)
return ∂x, Composite{typeof(p),typeof(∂p)}(∂p)
end
function _evalpoly_back(x, p, ys, Δy)
x′ = x'
∂yi = one(x′) * Δy
N = length(p)
@inbounds ∂p1 = _evalpoly_backp(p[1], ∂yi)
∂p = similar(p, typeof(∂p1))
@inbounds begin
∂x = _evalpoly_backx(x, ys[N - 1], ∂yi)
∂yi = x′ * ∂yi
∂p[1] = ∂p1
for i in 2:(N - 1)
∂x = _evalpoly_backx(x, ys[N - i], ∂x, ∂yi)
∂p[i] = _evalpoly_backp(p[i], ∂yi)
∂yi = x′ * ∂yi
end
∂p[N] = _evalpoly_backp(p[N], ∂yi)
end
return ∂x, ∂p
end
end
35 changes: 0 additions & 35 deletions test/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,41 +133,6 @@
)
end

VERSION v"1.4" && @testset "evalpoly" begin
# test fallbacks for when code generation fails
@testset "fallbacks for $T" for T in (Float64, ComplexF64)
x, p = randn(T), Tuple(randn(T, 10))
y_fb, ys_fb = ChainRules._evalpoly_intermediates_fallback(x, p)
y, ys = ChainRules._evalpoly_intermediates(x, p)
@test y_fb y
@test collect(ys_fb) collect(ys)

Δy, ys = randn(T), Tuple(randn(T, 9))
∂x_fb, ∂p_fb = ChainRules._evalpoly_back_fallback(x, p, ys, Δy)
∂x, ∂p = ChainRules._evalpoly_back(x, p, ys, Δy)
@test ∂x_fb ∂x
@test collect(∂p_fb) collect(∂p)
end

@testset "x dim: $(nx), pi dim: $(np), type: $T" for T in (Float64, ComplexF64), nx in (tuple(), 3), np in (tuple(), 3)
# skip x::Matrix, pi::Number case, which is not supported by evalpoly
isempty(np) && !isempty(nx) && continue
m = 5
sx = (nx..., nx...)
sp = (np..., np...)
x, ẋ, x̄ = randn(T, sx...), randn(T, sx...), randn(T, sx...)
p = [randn(T, sp...) for _ in 1:m]
= [randn(T, sp...) for _ in 1:m]
= [randn(T, sp...) for _ in 1:m]
Ω = evalpoly(x, p)
Ω̄ = randn(T, size(Ω)...)
frule_test(evalpoly, (x, ẋ), (p, ṗ))
frule_test(evalpoly, (x, ẋ), (Tuple(p), Tuple(ṗ)))
rrule_test(evalpoly, Ω̄, (x, x̄), (p, p̄))
rrule_test(evalpoly, Ω̄, (x, x̄), (Tuple(p), Tuple(p̄)))
end
end

@testset "Constants" for x in (-0.1, 6.4, 1.0+0.5im, -10.0+0im, 0.0+200im)
test_scalar(one, x)
test_scalar(zero, x)
Expand Down
35 changes: 35 additions & 0 deletions test/rulesets/Base/evalpoly.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
VERSION v"1.4" && @testset "evalpoly" begin
# test fallbacks for when code generation fails
@testset "fallbacks for $T" for T in (Float64, ComplexF64)
x, p = randn(T), Tuple(randn(T, 10))
y_fb, ys_fb = ChainRules._evalpoly_intermediates_fallback(x, p)
y, ys = ChainRules._evalpoly_intermediates(x, p)
@test y_fb y
@test collect(ys_fb) collect(ys)

Δy, ys = randn(T), Tuple(randn(T, 9))
∂x_fb, ∂p_fb = ChainRules._evalpoly_back_fallback(x, p, ys, Δy)
∂x, ∂p = ChainRules._evalpoly_back(x, p, ys, Δy)
@test ∂x_fb ∂x
@test collect(∂p_fb) collect(∂p)
end

@testset "x dim: $(nx), pi dim: $(np), type: $T" for T in (Float64, ComplexF64), nx in (tuple(), 3), np in (tuple(), 3)
# skip x::Matrix, pi::Number case, which is not supported by evalpoly
isempty(np) && !isempty(nx) && continue
m = 5
sx = (nx..., nx...)
sp = (np..., np...)
x, ẋ, x̄ = randn(T, sx...), randn(T, sx...), randn(T, sx...)
p = [randn(T, sp...) for _ in 1:m]
= [randn(T, sp...) for _ in 1:m]
= [randn(T, sp...) for _ in 1:m]
Ω = evalpoly(x, p)
Ω̄ = randn(T, size(Ω)...)
frule_test(evalpoly, (x, ẋ), (p, ṗ))
frule_test(evalpoly, (x, ẋ), (Tuple(p), Tuple(ṗ)))
rrule_test(evalpoly, Ω̄, (x, x̄), (p, p̄))
rrule_test(evalpoly, Ω̄, (x, x̄), (Tuple(p), Tuple(p̄)))
end
end

1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ println("Testing ChainRules.jl")
@testset "Base" begin
include(joinpath("rulesets", "Base", "base.jl"))
include(joinpath("rulesets", "Base", "fastmath_able.jl"))
include(joinpath("rulesets", "Base", "evalpoly.jl"))
include(joinpath("rulesets", "Base", "array.jl"))
include(joinpath("rulesets", "Base", "mapreduce.jl"))
end
Expand Down

0 comments on commit d717bb4

Please sign in to comment.