Support scalar (Number) arguments in the vcat/hcat/cat pullbacks, rename the
broadcast-local `_add_to_deriv!` helpers to `_br_add_to_deriv!`, include
derivatives/broadcast.jl after derivatives/propagation.jl, and compare the
cat-family gradients against ForwardDiff in the tests.

Review fixes folded into this revision of the patch:
  * cat pullback, Number branch: `start` is a tuple of per-dimension offsets,
    so `Δ[start+1]` cannot be evaluated (no `+` for Tuple and Int). Index with
    the splatted offsets and advance only along the concatenation dimensions.
  * test helper `unpack`: the `do` block ended with `start += ...`, so it
    returned offsets instead of the unpacked arguments. It now ends with `x`.

Note: this patch was reconstructed from a whitespace-collapsed copy, so the
indentation of context lines is a best guess (apply with whitespace-insensitive
matching if context fails). The `index` blob ids are those of the original
patch and are stale for the two files whose added lines changed.

diff --git a/src/ReverseDiff.jl b/src/ReverseDiff.jl
index 779038b..1b339c4 100644
--- a/src/ReverseDiff.jl
+++ b/src/ReverseDiff.jl
@@ -29,8 +29,8 @@ include("tape.jl")
 include("tracked.jl")
 include("macros.jl")
 include("derivatives/arrays.jl")
-include("derivatives/broadcast.jl")
 include("derivatives/propagation.jl")
+include("derivatives/broadcast.jl")
 include("derivatives/scalars.jl")
 include("derivatives/elementwise.jl")
 include("derivatives/linalg/arithmetic.jl")
diff --git a/src/derivatives/arrays.jl b/src/derivatives/arrays.jl
index 44ded91..5ff0213 100644
--- a/src/derivatives/arrays.jl
+++ b/src/derivatives/arrays.jl
@@ -62,9 +62,11 @@ end
     function back(Δ)
         start = 0
         Δs = map(xs) do xsi
-            x = map(_ -> :, size(xsi))
-            i = isempty(x) ? x : Base.tail(x)
-            d = Δ[start+1:start+size(xsi,1), i...]
+            if xsi isa Number
+                d = Δ[start+1]
+            else
+                d = Δ[start+1:start+size(xsi,1), :]
+            end
             start += size(xsi, 1)
             d
         end
@@ -75,11 +77,13 @@ end
 
 @grad function hcat(xs::Union{Number, AbstractVecOrMat}...)
     xs_value = value.(xs)
-    out_value = reduce(hcat,xs_value)
+    out_value = reduce(hcat, xs_value)
     function back(Δ)
         start = 0
         Δs = map(xs) do xsi
-            d = if ndims(xsi) == 1
+            d = if ndims(xsi) == 0
+                Δ[start+1]
+            elseif ndims(xsi) == 1
                 Δ[:, start+1]
             else
                 i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail
@@ -102,11 +106,18 @@ end
     return cat(Xs_value...; dims = dims), Δ -> begin
         start = ntuple(i -> 0, Val(ndims(Δ)))
         Δs = map(Xs) do xs
-            dim_xs = 1:ndims(xs)
-            till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val(ndims(Δ)))
-            xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val(ndims(Δ)))
-            d = reshape(Δ[xs_in_Δ...],size(xs))
-            start = start .+ till_xs
+            if xs isa Number
+                # `start` holds one offset per dimension of Δ, so splat it to index;
+                # a scalar occupies one slot along each concatenation dimension only.
+                d = Δ[(start .+ 1)...]
+                start = start .+ ntuple(i -> i in dims ? 1 : 0, Val(ndims(Δ)))
+            else
+                dim_xs = 1:ndims(xs)
+                till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val(ndims(Δ)))
+                xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val(ndims(Δ)))
+                d = reshape(Δ[xs_in_Δ...],size(xs))
+                start = start .+ till_xs
+            end
             d
         end
         return (Δs...,)
diff --git a/src/derivatives/broadcast.jl b/src/derivatives/broadcast.jl
index 93f0023..3cd2b56 100644
--- a/src/derivatives/broadcast.jl
+++ b/src/derivatives/broadcast.jl
@@ -196,29 +196,29 @@ end
     results, _, bounds = instruction.cache
     N = length(input)
     if N == 1 || all(isequal(size(input[1])), size.(Base.tail(input)))
-        _add_to_deriv!(input, output_deriv, results)
+        _br_add_to_deriv!(input, output_deriv, results)
     else
-        _add_to_deriv!(input, output_deriv, results, bounds)
+        _br_add_to_deriv!(input, output_deriv, results, bounds)
     end
     unseed!(output)
     return nothing
 end
 
-@generated function _add_to_deriv!(xs::T, o, r) where {T <: Tuple}
+@generated function _br_add_to_deriv!(xs::T, o, r) where {T <: Tuple}
     N = length(T.types)
-    return Expr(:block, [:(_add_to_deriv!(xs[$i], o, r, Val($i))) for i in 1:N]...)
+    return Expr(:block, [:(_br_add_to_deriv!(xs[$i], o, r, Val($i))) for i in 1:N]...)
 end
-_add_to_deriv!(_, _, _, _) = nothing
-function _add_to_deriv!(x::Union{TrackedReal, TrackedArray}, out_deriv, results, ::Val{i}) where {i}
+_br_add_to_deriv!(_, _, _, _) = nothing
+function _br_add_to_deriv!(x::Union{TrackedReal, TrackedArray}, out_deriv, results, ::Val{i}) where {i}
     return istracked(x) && diffresult_increment_deriv!(x, out_deriv, results, i)
 end
 
-@generated function _add_to_deriv!(xs::T, o, r, bounds) where {T <: Tuple}
+@generated function _br_add_to_deriv!(xs::T, o, r, bounds) where {T <: Tuple}
     N = length(T.types)
-    return Expr(:block, [:(_add_to_deriv!(xs[$i], o, r, Val($i), bounds[$i])) for i in 1:N]...)
+    return Expr(:block, [:(_br_add_to_deriv!(xs[$i], o, r, Val($i), bounds[$i])) for i in 1:N]...)
 end
-_add_to_deriv!(_, _, _, _, _) = nothing
-function _add_to_deriv!(x::Union{TrackedReal,TrackedArray}, out_deriv, results, ::Val{i}, bound) where {i}
+_br_add_to_deriv!(_, _, _, _, _) = nothing
+function _br_add_to_deriv!(x::Union{TrackedReal,TrackedArray}, out_deriv, results, ::Val{i}, bound) where {i}
     return istracked(x) && diffresult_increment_deriv!(x, out_deriv, results, i, bound)
 end
 
diff --git a/test/derivatives/ArrayFunctionTests.jl b/test/derivatives/ArrayFunctionTests.jl
index a0e8dfe..236cf1e 100644
--- a/test/derivatives/ArrayFunctionTests.jl
+++ b/test/derivatives/ArrayFunctionTests.jl
@@ -1,3 +1,4 @@
+using ForwardDiff
 using ReverseDiff: track, value, gradient, TrackedVector, TrackedMatrix, TrackedArray
 using Test
 
@@ -32,6 +33,31 @@ function testcat(f, args::Tuple{Any, Any}, type, kwargs=NamedTuple())
     x = f(track.(args)...; kwargs...)
     @test x isa type
     @test value(x) == f(args...; kwargs...)
+
+    sizes = size.(args)
+    F = vecx -> sum(f(unpack(sizes, vecx)...; kwargs...))
+    X = pack(args)
+    @test ForwardDiff.gradient(F, X) == gradient(F, X)
+end
+function pack(xs)
+    return mapreduce(vcat, xs) do x
+        x isa Number ? x : vec(x)
+    end
+end
+function unpack(sizes, vecx)
+    start = 0
+    out = map(sizes) do s
+        if s === ()
+            x = vecx[start+1]
+            start += 1
+        else
+            x = reshape(vecx[start+1:start+prod(s)], s)
+            start += prod(s)
+        end
+        # The block's value must be the unpacked argument, not the updated offset.
+        x
+    end
+    return out
 end
 
 @testset "cat" begin