Skip to content
24 changes: 24 additions & 0 deletions ext/RecursiveArrayToolsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,30 @@ end
end
end

Zygote.@adjoint function Zygote.literal_getproperty(A::RecursiveArrayTools.AbstractVectorOfArray, ::Val{:u})
function literal_AbstractVofA_u_adjoint(d)
dA = vofa_u_adjoint(d, A)
(dA, nothing)
end
A.u, literal_AbstractVofA_u_adjoint
end

function vofa_u_adjoint(d, A::RecursiveArrayTools.AbstractVectorOfArray)
m = map(enumerate(d)) do (idx, d_i)
isnothing(d_i) && return zero(A.u[idx])
d_i
end
VectorOfArray(m)
end

function vofa_u_adjoint(d, A::RecursiveArrayTools.AbstractDiffEqArray)
m = map(enumerate(d)) do (idx, d_i)
isnothing(d_i) && return zero(A.u[idx])
d_i
end
DiffEqArray(m, A.t)
end

@adjoint function literal_getproperty(A::ArrayPartition, ::Val{:x})
function literal_ArrayPartition_x_adjoint(d)
(ArrayPartition((isnothing(d[i]) ? zero(A.x[i]) : d[i] for i in 1:length(d))...),)
Expand Down
6 changes: 6 additions & 0 deletions test/adjoints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,9 @@ loss(x)
VectorOfArray([collect((3i):(3i + 3)) for i in 1:5])
@test Zygote.gradient(loss10, x)[1] == ForwardDiff.gradient(loss10, x)
@test Zygote.gradient(loss11, x)[1] == ForwardDiff.gradient(loss11, x)

voa = RecursiveArrayTools.VectorOfArray(fill(rand(3), 3))
voa_gs, = Zygote.gradient(voa) do x
sum(sum.(x.u))
end
@test voa_gs isa RecursiveArrayTools.VectorOfArray
Loading