From b18d10aafcf368057025efc03ce586b9371ae4ee Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sun, 25 Jul 2021 23:21:42 -0400 Subject: [PATCH 1/2] catch adjoint dispatch --- src/zygote.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/zygote.jl b/src/zygote.jl index 047ae548..e21aa3e9 100644 --- a/src/zygote.jl +++ b/src/zygote.jl @@ -1,4 +1,4 @@ -function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i::Int) +function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}) function AbstractVectorOfArray_getindex_adjoint(Δ) Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))] (NoTangent(),Δ′,NoTangent()) @@ -6,7 +6,7 @@ function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, i::I VA[i],AbstractVectorOfArray_getindex_adjoint end -function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, indices::Vararg{Int,N}) where {N} +function ChainRulesCore.rrule(::typeof(getindex),VA::AbstractVectorOfArray, indices::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}...) function AbstractVectorOfArray_getindex_adjoint(Δ) Δ′ = zero(VA) Δ′[indices...] = Δ @@ -43,15 +43,16 @@ function ChainRulesCore.rrule(::typeof(getproperty),A::ArrayPartition, s::Symbol A.x,literal_ArrayPartition_x_adjoint end -ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Int) +ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}) function AbstractVectorOfArray_getindex_adjoint(Δ) Δ′ = [ (i == j ? Δ : zero(x)) for (x,j) in zip(VA.u, 1:length(VA))] (Δ′,nothing) end + @show VA[i] VA[i],AbstractVectorOfArray_getindex_adjoint end -ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Int, j::Int...) +ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}, j::Union{Int,AbstractArray{Int},CartesianIndex,Colon,BitArray,AbstractArray{Bool}}...) function AbstractVectorOfArray_getindex_adjoint(Δ) Δ′ = zero(VA) Δ′[i,j...] = Δ From 8124486685dc862355b34529dcf44e1c59fc2516 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sun, 25 Jul 2021 23:30:57 -0400 Subject: [PATCH 2/2] don't allow chainrulescore 1 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index a4361b59..6b23e199 100644 --- a/Project.toml +++ b/Project.toml @@ -16,7 +16,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] ArrayInterface = "2.7, 3.0" -ChainRulesCore = "0.10.7, 1" +ChainRulesCore = "0.10.7" DocStringExtensions = "0.8" RecipesBase = "0.7, 0.8, 1.0" Requires = "0.5, 1.0"