Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gradient definitions & supertypes #168

Merged
merged 1 commit into from
Sep 30, 2021
Merged

Conversation

mcabbott
Copy link
Contributor

@mcabbott mcabbott commented Sep 29, 2021

This is inspired by SciML/SciMLSensitivity.jl#493 (comment) . I wrote many examples there, but the core surprise was that a VectorOfArrays{T} <: AbstractArray{T,3} has a gradient Vector{Matrix{T}} which is not a 3-array. If I understood correctly, it would desirable that the result be another VectorOfArray, i.e. something AbstractArray{T,3} which iterates like Vector{Matrix}.

One way to do that would be using ProjectTo. But in fact I believe the Vector{Matrix} is being explicitly generated by a rule defined here, so perhaps it is simpler just to alter that rule to wrap this?

RFC, WIP, needs tests... and needs thought about other cases:

julia> va = RecursiveArrayTools.VectorOfArray([rand(3,3), rand(3,3)]);

julia> gradient(va -> sum(va[1]), va)[1]
VectorOfArray{Float64,3}:
2-element Vector{AbstractMatrix{Float64}}:
 3×3 Fill{Float64}, with entries equal to 1.0
 [0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0]

julia> ans + va |> summary  # essential this work for a valid tangent
"VectorOfArray{Float64,3}"

julia> gradient((x,y) -> sum(x ./ y), va, rand(3))[1] |> summary  # should this also be changed?
"3×3×2 Array{Float64, 3}"

julia> gradient(va -> sum(va[:,1]), va)[1]  # this didn't work before, either. 
ERROR: BoundsError: attempt to access 3×3×2 VectorOfArray{Float64, 3, Vector{Matrix{Float64}}} at index [1:3, 1]
  • The broadcast case could apply something like VectorOfArray(collect(eachslice(Δ))), to return a comparable object, i.e. one which broadcasts one way and iterates a different way. This would need ProjectTo. It would also apply to things like gradient(va -> sum(abs, [1 2; 3 4] * va), VectorOfArray([[1,2], [3,4]])).
  • Is indexing va[:,1] expected to work? What should it return? The other way around should I think do this:
julia> gradient(va -> sum(va[1,:]), va)[1]
VectorOfArray{Float64,3}:
2-element Vector{Matrix{Float64}}:
 [1.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0]
 [1.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0]

@ChrisRackauckas ChrisRackauckas marked this pull request as ready for review September 29, 2021 19:32
@show Δ′
# (Δ′, i,map(_ -> nothing, j)...) # surely that i is a bug?
(Δ′, nothing, map(_ -> nothing, j)...)
# (VectorOfArray(Δ′), nothing, map(_ -> nothing, j)...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why would it not be this one?

@@ -55,11 +73,13 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,A
function AbstractVectorOfArray_getindex_adjoint(Δ)
Δ′ = zero(VA)
Δ′[i,j...] = Δ
(Δ′, i,map(_ -> nothing, j)...)
@show Δ′
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@show Δ′

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, very WIP!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh no worries, still worth the downstream tests

@ChrisRackauckas
Copy link
Member

Thanks! Running downstream tests to check it out.

@@ -55,11 +73,13 @@ ZygoteRules.@adjoint function getindex(VA::AbstractVectorOfArray, i::Union{Int,A
function AbstractVectorOfArray_getindex_adjoint(Δ)
Δ′ = zero(VA)
Δ′[i,j...] = Δ
(Δ′, i,map(_ -> nothing, j)...)
@show Δ′
# (Δ′, i,map(_ -> nothing, j)...) # surely that i is a bug?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah that's a bug

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants