From abb3e0c8d42a9a4d637422586b0f8efdf8172de1 Mon Sep 17 00:00:00 2001 From: Lucas Morton <23484003+lamorton@users.noreply.github.com> Date: Wed, 7 Jul 2021 18:39:27 -0700 Subject: [PATCH] Fix for another instance of the same kind of issue as #151. --- src/zygote.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/zygote.jl b/src/zygote.jl index dc45972c..047ae548 100644 --- a/src/zygote.jl +++ b/src/zygote.jl @@ -43,7 +43,7 @@ function ChainRulesCore.rrule(::typeof(getproperty),A::ArrayPartition, s::Symbol A.x,literal_ArrayPartition_x_adjoint end -ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i) +ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Int) function AbstractVectorOfArray_getindex_adjoint(Δ) Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))] (Δ′,nothing) @@ -51,7 +51,7 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i) VA[i],AbstractVectorOfArray_getindex_adjoint end -ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i, j...) +ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Int, j::Int...) function AbstractVectorOfArray_getindex_adjoint(Δ) Δ′ = zero(VA) Δ′[i,j...] = Δ