diff --git a/src/zygote.jl b/src/zygote.jl index dc45972c..047ae548 100644 --- a/src/zygote.jl +++ b/src/zygote.jl @@ -43,7 +43,7 @@ function ChainRulesCore.rrule(::typeof(getproperty),A::ArrayPartition, s::Symbol A.x,literal_ArrayPartition_x_adjoint end -ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i) +ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Int) function AbstractVectorOfArray_getindex_adjoint(Δ) Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))] (Δ′,nothing) @@ -51,7 +51,7 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i) VA[i],AbstractVectorOfArray_getindex_adjoint end -ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i, j...) +ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Int, j::Int...) function AbstractVectorOfArray_getindex_adjoint(Δ) Δ′ = zero(VA) Δ′[i,j...] = Δ