From 7f4c63a351a24baae4e7cd18073e2cf87671f812 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Sun, 5 Jan 2020 19:27:58 +0000 Subject: [PATCH] Fix Fill ctor --- src/lib/array.jl | 7 ++++++- test/gradcheck.jl | 2 ++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index cc09f8828..7231faae9 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -18,7 +18,12 @@ using Base.Broadcast: broadcasted, broadcast_shape # Array Constructors @adjoint (::Type{T})(x::T) where T<:Array = T(x), ȳ -> (ȳ,) -@adjoint (::Type{T})(x::Number, sz) where {T <: Fill} = Fill(x, sz), Δ -> (sum(Δ), nothing) +@adjoint function (::Type{T})(x::Number, sz) where {T <: Fill} + back(Δ::AbstractArray) = (sum(Δ), nothing) + back(Δ::NamedTuple) = (Δ.value, nothing) + return Fill(x, sz), back +end + @adjoint (::Type{T})(sz) where {T<:Zeros} = Zeros(sz), Δ->(nothing,) @adjoint (::Type{T})(sz) where {T<:Ones} = Ones(sz), Δ->(nothing,) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 36669c54f..f464f8577 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -1116,6 +1116,8 @@ end @test gradcheck(x->sum(Fill(x[], (2, 2))), [0.1]) @test first(Zygote.gradient(sz->sum(Ones(sz)), 6)) === nothing @test first(Zygote.gradient(sz->sum(Zeros(sz)), 6)) === nothing + @test gradcheck(x->Fill(x[], 5).value, [0.1]) + @test gradcheck(x->FillArrays.getindex_value(Fill(x[], 5)), [0.1]) end @testset "AbstractArray Addition / Subtraction / Negation" begin