Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add many frules #565

Merged
merged 18 commits into from
Jan 25, 2022
Merged

Add many frules #565

merged 18 commits into from
Jan 25, 2022

Conversation

mcabbott
Copy link
Member

@mcabbott mcabbott commented Jan 14, 2022

Many of these are almost trivial, you almost want an @trivial_frule vcat(xs...).

One subtlety is that vcat(x...); x==(1,2,3) might get xdot = (4, ZeroTangent(), ZeroTangent()) which would give a Vector{Any}, whereas I think in this case you ought to make a Vector{Float64}, taking float(T) from the forward pass? I'm far from sure that all such cases are handled well.

Maybe they can be automated. Maybe this should do something sensible:

julia> promote(4, ZeroTangent(), ZeroTangent())
ERROR: promotion of types Int64, ZeroTangent and ZeroTangent failed to change any arguments

Edit: this is now handled by _make_real_zeros.

In adding a forward rule for reshape, I tidied up the reverse one, and made one for dropdims.

Completely fails on Julia 1.0, because syntax f((x...,), y) = x, y was adde d only in 1.6. Maybe we should drop 1.0? (#577)

Closes #406

@mcabbott mcabbott marked this pull request as draft January 14, 2022 21:51
@mcabbott mcabbott force-pushed the forward branch 7 times, most recently from ef971c1 to 5714d78 Compare January 20, 2022 04:08
Comment on lines 22 to 24
frule((_, Adot, Bdot), ::typeof(*), A, B) = A * B, muladd(Adot, B, A * Bdot)

frule((_, Adot, Bdot, Cdot), ::typeof(*), A, B, C) = A*B*C, Adot*B*C + A*Bdot*C + A*B*Cdot
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will anything go wrong if these are left without types?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, I mean it will mean the rule is hit rather than decomposing further.
But we can try without restrictions then if we run into bad cases add them later

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One bad thing about Adot*B*C is that, for 3 matrices, it will run the computation of how to group them 3 times. Without the rule, it would run once. But this ought to be tiny compared to matmul.

One good thing about it is that, for scalar-matrix-matrix, AD won't try to go inside the fused mul! implementation. Although perhaps nothing bad would happen there anyway.

src/rulesets/Base/indexing.jl Outdated Show resolved Hide resolved
Comment on lines 66 to 71
function frule((_, ΔC, ΔA, ΔB), ::typeof(mul!), C::AbstractArray, A, B)
mul!(C, A, B)
mul!(ΔC, ΔA, B)
mul!(ΔC, A, ΔB, true, true)
return C, ΔC
end
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Likewise few types, can this go wrong?

@mcabbott mcabbott marked this pull request as ready for review January 20, 2022 04:25
src/rulesets/Base/array.jl Outdated Show resolved Hide resolved
Copy link
Member

@oxinabox oxinabox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have reviewed up to reshape
I will continue after my meeting today

src/rulesets/Base/array.jl Outdated Show resolved Hide resolved
@@ -4,6 +4,10 @@

ChainRules.@non_differentiable (::Type{T} where {T<:Array})(::UndefInitializer, args...)

function frule((_, xdot), ::Type{T}, x::AbstractArray) where {T<:Array}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we run a find + replace in files, and fix names like xdot into unicode
for consistency with the rest of the project

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I've changed all the ones in array.jl to have dots. These are all linear rules, so all you have to see is that the dot equation has dots.

In arraymath.jl and dense.jl there is some precedent for Δx which I've also followed. For rules where each term mixes up original and perturbation, I think it's a bit too subtle to put tiny dots on some factors.

The divide between array.jl and arraymath.jl is almost perfectly linear/nonlinear, in fact. Maybe the rules for array + & - should move to the linear file.

@@ -43,32 +51,81 @@ function rrule(::typeof(Base.vect), X::Vararg{Any,N}) where {N}
return Base.vect(X...), vect_pullback
end

"""
_make_real_zeros(xdots, xs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perhaps instantiate_zeros
make_real_zeros could be interpretted as a Real vs Complex.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or materialize_zeros
we should be roughly consistent with what broadcasting has to say about the wording

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just saw this. For now it's _instantiate_zeros, clearly internal, and unlike materialise I can reliably spell it...

src/rulesets/Base/array.jl Outdated Show resolved Hide resolved
"""
_make_real_zeros(xdots, xs) = map(_real_zero, xdots, xs)
_real_zero(xdot, x) = xdot
_real_zero(xdot::AbstractZero, x) = zero(x)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if this shouldn't be:

Suggested change
_real_zero(xdot::AbstractZero, x) = zero(x)
_real_zero(xdot::ZeroTangent, x) = zero(x)
_real_zero(xdot::DoesNotExist, x) = isapplicable(zero, x) ? zero(x) : xdot

So that we never end up calling zero("abc")

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the example is something like gradient(x -> ["abc", x][end], 1). That does actually work in Zygote; it fails in Diffractor seemingly before hitting _real_zero.

What's DoesNotExist? I think hasmethod is pretty slow; maybe better to leave an error until we have a better plan?

julia> @btime Base.hasmethod(zero, Tuple{typeof($([1,2,3]))})
  min 293.075 ns, mean 306.452 ns (3 allocations, 144 bytes)
true

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oops DoesNotExist was renamed to NoTangent

We can push hasmethod to compile-time if we want with Tricks.jl.
If we really want.

but also maybe we want to do:

Suggested change
_real_zero(xdot::AbstractZero, x) = zero(x)
_real_zero(xdot::ZeroTangent, x) = zero(x)
_real_zero(xdot::NoTangent, x) = xdot

but yeah i think fine enough to leave it as an error til it becomes a problem.

Comment on lines 66 to 68
_make_real_zeros(xdots::NTuple{<:Any, <:Number}, xs) = xdots
_make_real_zeros(xdots::AbstractArray{<:Number}, xs) = xdots
_make_real_zeros(xdots::AbstractArray{<:AbstractArray}, xs) = xdots
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we use
eltype and HasEltype here? so that we are abstracted over collections?
Possibly not required right now and we can leave it?

Copy link
Member Author

@mcabbott mcabbott Jan 24, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My guess is that this is already into premature optimisation territory. A tuple of numbers for vect, and a tuple of arrays for vcat, are the cases that I actually managed to trigger; an array for reduce(vcat, xs) seemed easy to handle with the same machine. Maybe there should be a xdots::NTuple{<:Any, <:AbstractArray} method. Although map on tuples is mostly free.

src/rulesets/Base/array.jl Outdated Show resolved Hide resolved
Copy link
Member

@oxinabox oxinabox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice. Huge work well done.

Address the comments i leave as you deem good.
Check the code coverage is good,
and then this should be good to merged

@@ -43,32 +51,81 @@ function rrule(::typeof(Base.vect), X::Vararg{Any,N}) where {N}
return Base.vect(X...), vect_pullback
end

"""
_make_real_zeros(xdots, xs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or materialize_zeros
we should be roughly consistent with what broadcasting has to say about the wording

ax = axes(A)
project = ProjectTo(A) # Projection is here for e.g. reshape(::Diagonal, :)
∂dims = broadcast(Returns(NoTangent()), dims)
reshape_pullback(Ȳ) = (NoTangent(), project(reshape(Ȳ, ax)), ∂dims...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Am I correct in saying project will not do the reshaping for us, as it only handles cases with singleton dimensions?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. It will accept arrays whose size is almost right, i.e. differing by trailing 1s only. The offsets can be wrong. Doing one reshape(... , axes) here should mean it never reshapes twice.

src/rulesets/Base/array.jl Outdated Show resolved Hide resolved
src/rulesets/Base/array.jl Outdated Show resolved Hide resolved
@@ -42,6 +52,12 @@ end
##### `sortslices`
#####

function frule((_, ẋ), ::typeof(sortslices), x::AbstractArray; dims::Integer, kw...)
p = sortperm(collect(eachslice(x; dims=dims)); kw...)
inds = ntuple(d -> d == dims ? p : (:), ndims(x))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dims can also be a tuple.
But the rrule doesn't support that either so probably fine?

Though it is a bif of a problem with it being a kwarg (for both forwards and reverse modes)
since won't redispatch to using the AD to work it out.

Anyway we should make an issue for this unless the solution is trivial.
E.g.:

Suggested change
inds = ntuple(d -> d == dims ? p : (:), ndims(x))
inds = ntuple(d -> d in dims ? p : (:), ndims(x))

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we assume the abstract array uses 1 based indexing.
should we actualy be doing
something with eachindex/axes ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not think about offsets. They ought to work in some places:

julia> sortperm(OffsetArray(rand(3), 4))
3-element OffsetArray(::Vector{Int64}, 5:7) with eltype Int64 with indices 5:7:
 7
 6
 5

but not with collect(eachslice).

Suggested change
inds = ntuple(d -> d == dims ? p : (:), ndims(x))
firstindex(x, d) == 1 || throw(ArgumentError("The `rrule` for `sortslices` does not at present handle offset indices here."))
inds = ntuple(d -> d == dims ? p : (:), ndims(x))

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I vote to kick dims::Tuple down the road. It might not be hard to handle, but would need a bit of thought, and some tests...

It will fail with an error now. I don't think Zygote et al. had any hope of getting through sortslices before the rule.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact works fine with OffsetArrays. The generator used for eachslice propagates indices through correctly:

julia> x = OffsetMatrix(rand(2,3), 4, 5);

julia> sortslices(x; dims=2)
2×3 OffsetArray(::Matrix{Float64}, 5:6, 6:8) with eltype Float64 with indices 5:6×6:8:
 0.689597  0.805156  0.995562
 0.727762  0.320152  0.924272

julia> rrule(sortslices, x; dims=2)[1]
2×3 OffsetArray(::Matrix{Float64}, 5:6, 6:8) with eltype Float64 with indices 5:6×6:8:
 0.689597  0.805156  0.995562
 0.727762  0.320152  0.924272

julia> collect(eachslice(x; dims=2))
3-element OffsetArray(::Vector{SubArray{Float64, 1, OffsetMatrix{Float64, Matrix{Float64}}, Tuple{Base.Slice{OffsetArrays.IdOffsetRange{Int64, Base.OneTo{Int64}}}, Int64}, true}}, 6:8) with eltype SubArray{Float64, 1, OffsetMatrix{Float64, Matrix{Float64}}, Tuple{Base.Slice{OffsetArrays.IdOffsetRange{Int64, Base.OneTo{Int64}}}, Int64}, true} with indices 6:8:
 [0.9955624968587956, 0.9242722045713299]
 [0.8051558526354236, 0.32015211201093363]
 [0.689597064698794, 0.7277619193702523]

I've removed the error message.

@@ -1,23 +1,36 @@
@testset "Array constructors" begin

@testset "undef" begin
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indenting of comment is now wrong?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is. I moved the heading as I was confused for a minute about how much the comment applied to. But I thought preserving the blame for the comment might be helpful.

@test rrule(reshape, adjoint(rand(ComplexF64, 4)), :)[2](rand(4))[2] isa Adjoint{ComplexF64}
@test rrule(reshape, Diagonal(rand(4)), (2, :))[2](ones(2,8))[2] isa Diagonal
@test_skip test_rrule(reshape, Diagonal(rand(4)), 2, :) # DimensionMismatch("second dimension of A, 22, does not match length of x, 16")
@test_skip test_rrule(reshape, UpperTriangular(rand(4,4)), (8, 2))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we open an issue/issues with a list of all things that are skipped and why and cross link the URL in the code where the skip happens?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a pity you can't put @test_broken instead, so that you find out when the bug is fixed.

Comment on lines 5 to 12
@testset "forward mode" begin
test_frule(getindex, x, 2)
test_frule(getindex, x, 2, 1)
test_frule(getindex, x, CartesianIndex(2, 3))

test_rrule(getindex, x, 2:3)
test_rrule(getindex, x, (:), 2:3)
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these seem misorganized, several of these are for rrule (or are those a mistake?)
And i think we would in general rather organise by operands than by mode.
So we can probably just push them down into the signle element/slice etc tests?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well spotted, that's just a typo, I meant to only test frules:

Suggested change
@testset "forward mode" begin
test_frule(getindex, x, 2)
test_frule(getindex, x, 2, 1)
test_frule(getindex, x, CartesianIndex(2, 3))
test_rrule(getindex, x, 2:3)
test_rrule(getindex, x, (:), 2:3)
end
@testset "forward mode" begin
test_frule(getindex, x, 2)
test_frule(getindex, x, 2, 1)
test_frule(getindex, x, CartesianIndex(2, 3))
test_frule(getindex, x, 2:3)
test_frule(getindex, x, (:), 2:3)
end

The argument for keeping them separate is that it makes sense to have far fewer. The frule is super-simple and has few edge cases. The rrule has many more things to think through, will eventually need to handle arrays of arrays (and second derivatives of those...).

@mcabbott mcabbott merged commit 8c34f19 into JuliaDiff:main Jan 25, 2022
@mcabbott mcabbott deleted the forward branch January 25, 2022 22:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants