Add many frules
#565
src/rulesets/Base/arraymath.jl
Outdated
frule((_, Adot, Bdot), ::typeof(*), A, B) = A * B, muladd(Adot, B, A * Bdot)

frule((_, Adot, Bdot, Cdot), ::typeof(*), A, B, C) = A*B*C, Adot*B*C + A*Bdot*C + A*B*Cdot
Will anything go wrong if these are left without types?
Hmm, it will mean the rule is hit rather than AD decomposing the call further.
But we can try without restrictions and, if we run into bad cases, add them later.
One bad thing about Adot*B*C is that, for 3 matrices, it will run the computation of how to group them 3 times. Without the rule, it would run once. But this ought to be tiny compared to matmul.
One good thing about it is that, for scalar-matrix-matrix, AD won't try to go inside the fused mul! implementation. Although perhaps nothing bad would happen there anyway.
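For reference, the three-argument tangent in the rule under discussion is just the product rule applied to each factor in turn. A minimal sketch that checks it against a central finite difference, assuming only Base and LinearAlgebra; the helper name triple_product_tangent is hypothetical and not part of the PR:

```julia
using LinearAlgebra

# Tangent of A*B*C by the product rule, one term per perturbed factor.
# (Hypothetical helper, written only to check the formula quoted in the diff.)
triple_product_tangent(A, B, C, Adot, Bdot, Cdot) =
    Adot*B*C + A*Bdot*C + A*B*Cdot

A, B, C = rand(3, 4), rand(4, 5), rand(5, 2)
Adot, Bdot, Cdot = rand(3, 4), rand(4, 5), rand(5, 2)

# Central finite difference along the direction (Adot, Bdot, Cdot).
h = 1e-6
fd = ((A + h*Adot)*(B + h*Bdot)*(C + h*Cdot) -
      (A - h*Adot)*(B - h*Bdot)*(C - h*Cdot)) / (2h)

# Discrepancy between the finite difference and the product-rule tangent.
err = norm(fd - triple_product_tangent(A, B, C, Adot, Bdot, Cdot))
```

Each of the three terms is itself a three-factor product, which is why the grouping heuristic mentioned above would run once per term.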
src/rulesets/LinearAlgebra/dense.jl
Outdated
function frule((_, ΔC, ΔA, ΔB), ::typeof(mul!), C::AbstractArray, A, B)
    mul!(C, A, B)
    mul!(ΔC, ΔA, B)
    mul!(ΔC, A, ΔB, true, true)
    return C, ΔC
end
Likewise few types, can this go wrong?
I have reviewed up to reshape
I will continue after my meeting today
src/rulesets/Base/array.jl
Outdated
@@ -4,6 +4,10 @@

ChainRules.@non_differentiable (::Type{T} where {T<:Array})(::UndefInitializer, args...)

function frule((_, xdot), ::Type{T}, x::AbstractArray) where {T<:Array}
Can we run a find + replace in files, and change names like xdot into unicode ẋ, for consistency with the rest of the project?
Ok, I've changed all the ones in array.jl to have dots. These are all linear rules, so all you have to see is that the dot equation has dots.
In arraymath.jl and dense.jl there is some precedent for Δx, which I've also followed. For rules where each term mixes up original and perturbation, I think it's a bit too subtle to put tiny dots on some factors.
The divide between array.jl and arraymath.jl is almost perfectly linear/nonlinear, in fact. Maybe the rules for array + & - should move to the linear file.
src/rulesets/Base/array.jl
Outdated
@@ -43,32 +51,81 @@ function rrule(::typeof(Base.vect), X::Vararg{Any,N}) where {N}
    return Base.vect(X...), vect_pullback
end

"""
    _make_real_zeros(xdots, xs)
Perhaps instantiate_zeros? make_real_zeros could be interpreted as being about Real vs. Complex.
Or materialize_zeros; we should be roughly consistent with what broadcasting has to say about the wording.
Just saw this. For now it's _instantiate_zeros, clearly internal, and unlike materialise I can reliably spell it...
src/rulesets/Base/array.jl
Outdated
"""
_make_real_zeros(xdots, xs) = map(_real_zero, xdots, xs)
_real_zero(xdot, x) = xdot
_real_zero(xdot::AbstractZero, x) = zero(x)
I wonder if this shouldn't be:
- _real_zero(xdot::AbstractZero, x) = zero(x)
+ _real_zero(xdot::ZeroTangent, x) = zero(x)
+ _real_zero(xdot::DoesNotExist, x) = isapplicable(zero, x) ? zero(x) : xdot
So that we never end up calling zero("abc")
So the example is something like gradient(x -> ["abc", x][end], 1). That does actually work in Zygote; it fails in Diffractor, seemingly before hitting _real_zero.
What's DoesNotExist? I think hasmethod is pretty slow; maybe better to leave an error until we have a better plan?
julia> @btime Base.hasmethod(zero, Tuple{typeof($([1,2,3]))})
min 293.075 ns, mean 306.452 ns (3 allocations, 144 bytes)
true
Oops, DoesNotExist was renamed to NoTangent.
We can push hasmethod to compile time with Tricks.jl, if we really want.
But also maybe we want to do:
- _real_zero(xdot::AbstractZero, x) = zero(x)
+ _real_zero(xdot::ZeroTangent, x) = zero(x)
+ _real_zero(xdot::NoTangent, x) = xdot
But yeah, I think it's fine to leave it as an error until it becomes a problem.
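The zero handling discussed in this thread can be sketched roughly as follows. This is a minimal sketch under stated assumptions: the three types are local stand-ins for ChainRulesCore's AbstractZero, ZeroTangent and NoTangent, defined here only so the snippet runs without the package, and instantiate_zero / instantiate_zeros are hypothetical names. It shows the NoTangent pass-through variant floated above, which is not necessarily what the PR ended up with.

```julia
# Local stand-ins for the ChainRulesCore zero tangent types (assumption:
# defined here only so the sketch is self-contained).
abstract type AbstractZero end
struct ZeroTangent <: AbstractZero end
struct NoTangent <: AbstractZero end

# Ordinary tangents pass through unchanged.
instantiate_zero(xdot, x) = xdot
# A structural zero becomes a concrete zero matching the primal value.
instantiate_zero(::ZeroTangent, x) = zero(x)
# Variant from the thread: leave NoTangent alone, so zero("abc") is never called.
instantiate_zero(xdot::NoTangent, x) = xdot

instantiate_zeros(xdots, xs) = map(instantiate_zero, xdots, xs)

xs = (1.0, 2.0, 3.0)
xdots = (4.0, ZeroTangent(), ZeroTangent())

# Collecting the instantiated tangents gives a concretely typed vector
# rather than a Vector{Any}.
tangent = collect(instantiate_zeros(xdots, xs))

# A non-differentiable element is passed through instead of erroring.
mixed = instantiate_zeros((NoTangent(), 1.0), ("abc", 2.0))
```

With the single AbstractZero method from the diff, the last call would instead attempt zero("abc") and throw, which is the error the thread agrees to leave in place for now.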
src/rulesets/Base/array.jl
Outdated
_make_real_zeros(xdots::NTuple{<:Any, <:Number}, xs) = xdots
_make_real_zeros(xdots::AbstractArray{<:Number}, xs) = xdots
_make_real_zeros(xdots::AbstractArray{<:AbstractArray}, xs) = xdots
Should we use eltype and HasEltype here, so that we are abstracted over collections?
Possibly not required right now and we can leave it?
My guess is that this is already into premature optimisation territory. A tuple of numbers for vect, and a tuple of arrays for vcat, are the cases that I actually managed to trigger; an array for reduce(vcat, xs) seemed easy to handle with the same machinery. Maybe there should be an xdots::NTuple{<:Any, <:AbstractArray} method. Although map on tuples is mostly free.
Nice. Huge work, well done.
Address the comments I've left as you see fit.
Check that the code coverage is good,
and then this should be good to merge.
    ax = axes(A)
    project = ProjectTo(A)  # Projection is here for e.g. reshape(::Diagonal, :)
    ∂dims = broadcast(Returns(NoTangent()), dims)
    reshape_pullback(Ȳ) = (NoTangent(), project(reshape(Ȳ, ax)), ∂dims...)
Am I correct in saying project will not do the reshaping for us, as it only handles cases with singleton dimensions?
Yes. It will accept arrays whose size is almost right, i.e. differing by trailing 1s only. The offsets can be wrong. Doing one reshape(... , axes) here should mean it never reshapes twice.
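The "reshape once, back to the input's axes" idea can be illustrated without ChainRules at all. A minimal sketch, assuming plain Base arrays; reshape_with_pullback is a hypothetical name, and the ProjectTo step and the NoTangent entries from the diff are deliberately omitted:

```julia
# Returns the reshaped array together with a closure that maps a cotangent
# of the output back to the shape of the input.
# (Hypothetical helper; the real rule also projects onto the input's structure.)
function reshape_with_pullback(A::AbstractArray, dims...)
    ax = axes(A)                        # remember the input's axes once
    Y = reshape(A, dims...)
    pullback(Ybar) = reshape(Ybar, ax)  # a single reshape back to those axes
    return Y, pullback
end

A = rand(2, 3)
Y, back = reshape_with_pullback(A, 3, 2)
Abar = back(ones(3, 2))
```

Because reshape only reinterprets the linear order of the elements, pulling the output itself back through the closure recovers the original array.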
@@ -42,6 +52,12 @@ end
##### `sortslices`
#####

function frule((_, ẋ), ::typeof(sortslices), x::AbstractArray; dims::Integer, kw...)
    p = sortperm(collect(eachslice(x; dims=dims)); kw...)
    inds = ntuple(d -> d == dims ? p : (:), ndims(x))
dims can also be a tuple.
But the rrule doesn't support that either, so probably fine?
Though it is a bit of a problem with it being a kwarg (for both forward and reverse modes), since it won't redispatch to using the AD to work it out.
Anyway, we should make an issue for this unless the solution is trivial.
E.g.:
- inds = ntuple(d -> d == dims ? p : (:), ndims(x))
+ inds = ntuple(d -> d in dims ? p : (:), ndims(x))
Here we assume the abstract array uses 1-based indexing.
Should we actually be doing something with eachindex/axes?
I did not think about offsets. They ought to work in some places:
julia> sortperm(OffsetArray(rand(3), 4))
3-element OffsetArray(::Vector{Int64}, 5:7) with eltype Int64 with indices 5:7:
7
6
5
but not with collect(eachslice).
- inds = ntuple(d -> d == dims ? p : (:), ndims(x))
+ firstindex(x, d) == 1 || throw(ArgumentError("The `rrule` for `sortslices` does not at present handle offset indices here."))
+ inds = ntuple(d -> d == dims ? p : (:), ndims(x))
I vote to kick dims::Tuple down the road. It might not be hard to handle, but would need a bit of thought, and some tests...
It will fail with an error now. I don't think Zygote et al. had any hope of getting through sortslices before the rule.
In fact it works fine with OffsetArrays. The generator used for eachslice propagates indices through correctly:
julia> x = OffsetMatrix(rand(2,3), 4, 5);
julia> sortslices(x; dims=2)
2×3 OffsetArray(::Matrix{Float64}, 5:6, 6:8) with eltype Float64 with indices 5:6×6:8:
0.689597 0.805156 0.995562
0.727762 0.320152 0.924272
julia> rrule(sortslices, x; dims=2)[1]
2×3 OffsetArray(::Matrix{Float64}, 5:6, 6:8) with eltype Float64 with indices 5:6×6:8:
0.689597 0.805156 0.995562
0.727762 0.320152 0.924272
julia> collect(eachslice(x; dims=2))
3-element OffsetArray(::Vector{SubArray{Float64, 1, OffsetMatrix{Float64, Matrix{Float64}}, Tuple{Base.Slice{OffsetArrays.IdOffsetRange{Int64, Base.OneTo{Int64}}}, Int64}, true}}, 6:8) with eltype SubArray{Float64, 1, OffsetMatrix{Float64, Matrix{Float64}}, Tuple{Base.Slice{OffsetArrays.IdOffsetRange{Int64, Base.OneTo{Int64}}}, Int64}, true} with indices 6:8:
[0.9955624968587956, 0.9242722045713299]
[0.8051558526354236, 0.32015211201093363]
[0.689597064698794, 0.7277619193702523]
I've removed the error message.
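Since sortslices only permutes slices, the forward rule amounts to applying the same permutation to the tangent. A minimal sketch along the lines of the three quoted lines of the frule, for ordinary 1-based arrays and integer dims only; sortslices_with_tangent is a hypothetical name, and indexing x[inds...] instead of calling sortslices for the primal is a simplification made here:

```julia
# Sort the slices of x along `dims` and carry the tangent xdot along
# by indexing with the same permutation.
# (Hypothetical helper; integer `dims` only, as in the rule under review.)
function sortslices_with_tangent(x::AbstractArray, xdot::AbstractArray; dims::Integer, kw...)
    p = sortperm(collect(eachslice(x; dims=dims)); kw...)
    inds = ntuple(d -> d == dims ? p : (:), ndims(x))
    return x[inds...], xdot[inds...]
end

x = [3.0 1.0 2.0;
     0.0 5.0 4.0]
xdot = 10 .* x   # an easily recognisable tangent

y, ydot = sortslices_with_tangent(x, xdot; dims=2)
```

Here the columns of ydot end up in the same order as the columns of y, so the scaling relationship between x and xdot is preserved after sorting.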
@@ -1,23 +1,36 @@
@testset "Array constructors" begin

    @testset "undef" begin
indenting of comment is now wrong?
It is. I moved the heading as I was confused for a minute about how much the comment applied to. But I thought preserving the blame for the comment might be helpful.
test/rulesets/Base/array.jl
Outdated
@test rrule(reshape, adjoint(rand(ComplexF64, 4)), :)[2](rand(4))[2] isa Adjoint{ComplexF64}
@test rrule(reshape, Diagonal(rand(4)), (2, :))[2](ones(2,8))[2] isa Diagonal
@test_skip test_rrule(reshape, Diagonal(rand(4)), 2, :) # DimensionMismatch("second dimension of A, 22, does not match length of x, 16")
@test_skip test_rrule(reshape, UpperTriangular(rand(4,4)), (8, 2))
can we open an issue/issues with a list of all things that are skipped and why and cross link the URL in the code where the skip happens?
Done.
It's a pity you can't put @test_broken instead, so that you find out when the bug is fixed.
test/rulesets/Base/indexing.jl
Outdated
@testset "forward mode" begin
    test_frule(getindex, x, 2)
    test_frule(getindex, x, 2, 1)
    test_frule(getindex, x, CartesianIndex(2, 3))

    test_rrule(getindex, x, 2:3)
    test_rrule(getindex, x, (:), 2:3)
end
These seem misorganized: several of them are for rrule (or are those a mistake?).
And I think we would in general rather organise by operands than by mode.
So we can probably just push them down into the single element/slice etc. tests?
Well spotted, that's just a typo, I meant to only test frules:
- @testset "forward mode" begin
-     test_frule(getindex, x, 2)
-     test_frule(getindex, x, 2, 1)
-     test_frule(getindex, x, CartesianIndex(2, 3))
-     test_rrule(getindex, x, 2:3)
-     test_rrule(getindex, x, (:), 2:3)
- end
+ @testset "forward mode" begin
+     test_frule(getindex, x, 2)
+     test_frule(getindex, x, 2, 1)
+     test_frule(getindex, x, CartesianIndex(2, 3))
+     test_frule(getindex, x, 2:3)
+     test_frule(getindex, x, (:), 2:3)
+ end
The argument for keeping them separate is that it makes sense to have far fewer. The frule is super-simple and has few edge cases. The rrule has many more things to think through, and will eventually need to handle arrays of arrays (and second derivatives of those...).
Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
Many of these are almost trivial, you almost want an @trivial_frule vcat(xs...).

One subtlety is that vcat(x...); x==(1,2,3) might get xdot = (4, ZeroTangent(), ZeroTangent()) which would give a Vector{Any}, whereas I think in this case you ought to make a Vector{Float64}, taking float(T) from the forward pass? I'm far from sure that all such cases are handled well.

Maybe they can be automated. Maybe this should do something sensible:

Edit: this is now handled by _make_real_zeros.

In adding a forward rule for reshape, I tidied up the reverse one, and made one for dropdims.

Completely fails on Julia 1.0, because syntax f((x...,), y) = x, y was added only in 1.6. Maybe we should drop 1.0? (#577)

Closes #406