
Use ProjectTo in broadcasting & gradient #1044

Merged: 43 commits into FluxML:master from the projectto branch, Sep 22, 2021

Conversation

mcabbott (Member) commented Jul 27, 2021

This starts building ChainRulesCore's type-projection story into how Zygote handles broadcasting, and into its user-facing functions. Projection is already called in some rules handled by ChainRules; this PR applies it a bit more broadly.
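
For readers new to it, ProjectTo is ChainRulesCore's way of standardising a gradient onto the type and structure of the primal. A minimal sketch (output from ChainRulesCore 1.x; details may vary):

julia> using ChainRulesCore

julia> p = ProjectTo(3.0)   # remembers that the primal was a Float64
ProjectTo{Float64}()

julia> p(1.0 + 2.0im)       # a complex cotangent is projected back to the reals
1.0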

After:

julia> gradient(x->imag(x + 2.0*im), 3.0)  # https://github.com/FluxML/Zygote.jl/issues/342
(0.0,)

julia> gradient(x -> getindex(x,2,1), Diagonal(rand(3,3)))[1]  # https://github.com/FluxML/Zygote.jl/issues/402
3×3 Diagonal{Float64, Vector{Float64}}:
 0.0   ⋅    ⋅
  ⋅   0.0   ⋅
  ⋅    ⋅   0.0

julia> gradient(x -> sum(sqrt.(x .+ 1)), [1,2,3]')[1]  # previously became a matrix
1×3 adjoint(::Vector{Float64}) with eltype Float64:
 0.353553  0.288675  0.25

Before:

julia> gradient(x->imag(x + 2.0*im), 3.0) 
(0.0 + 1.0im,)

julia> gradient(x -> getindex(x,2,1), Diagonal(rand(3,3)))[1]
3×3 Zygote.OneElement{Float64, 2, Tuple{Int64, Int64}, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}:
 0.0  0.0  0.0
 1.0  0.0  0.0
 0.0  0.0  0.0

julia> gradient(x -> sum(sqrt.(x .+ 1)), [1,2,3]')[1]
1×3 Matrix{Float64}:
 0.353553  0.288675  0.25

Replaces #965, or most of it.

Many tests will fail, including I think most of the FFT tests, since those tend to return a complex gradient for a real input. (Update: the FFT tests themselves are unchanged.)

Closes #342, closes #402. Fixes #917, fixes #431.

Closes FluxML/Flux.jl#886

src/compiler/interface.jl
@@ -73,7 +73,8 @@ julia> gradient([7, 11], 0, 1) do x, y, d
 """
 function gradient(f, args...)
   y, back = pullback(f, args...)
-  return back(sensitivity(y))
+  grad = back(sensitivity(y))
+  map(_project, args, grad)
Collaborator:

do we want this at the gradient or the pullback level?

Member Author (mcabbott):

My thinking was to start small! Applying it to gradient applies it to the user-facing calls, once. Applying it to pullback or _pullback inserts it into many more places internally... maybe it'll make sin'''(1.0) unhappy.
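
(For reference, sin''' uses Zygote's adjoint-operator syntax for nested derivatives; a quick sketch of the call in question:)

julia> using Zygote

julia> sin'''(1.0)   # third derivative via nested reverse mode, equal to -cos(1.0)
-0.5403023058681398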

mcabbott (Member Author) commented Aug 1, 2021

One side-effect of this is that it makes this wrong answer into an error:

julia> gradient((x,y) -> sum(map(+,x,y)), [1,2], [3,4,5,6])  # before
([1, 1], [1, 1])

julia> gradient((x,y) -> sum(map(+,x,y)), [1,2], [3,4,5,6])  # after
ERROR: DimensionMismatch("variable with size(x) == (4,) cannot have a gradient with size(dx) == (2,)")
Stacktrace:
 [1] (::ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}})(dx::Vector{Int64})
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/ySyqy/src/projection.jl:197

DhairyaLGandhi (Member) commented Aug 19, 2021

That error seems awkward to me. Previously the gradient simply mirrored the Julia behaviour of map on mismatched lengths. Presumably the resulting gradient should be sized appropriately, not error.
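
(For context, plain Julia's map stops at the shorter argument, which is where the old ([1, 1], [1, 1]) answer came from:)

julia> map(+, [1,2], [3,4,5,6])   # truncates to the shortest collection
2-element Vector{Int64}:
 4
 6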

@@ -95,11 +97,32 @@ true
 """
 function withgradient(f, args...)
   y, back = pullback(f, args...)
-  (val = y, grad = back(sensitivity(y)))
+  grad = back(sensitivity(y))
+  isnothing(grad) && return (val=y, grad=nothing)
Member:

Why is this check necessary?

mcabbott (Member Author), Aug 19, 2021:

You can't map over nothing.

src/compiler/chainrules.jl
@@ -45,18 +45,20 @@ function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArr
   Base.reducedim_initarray(A, region, nothing, Union{Nothing,eltype(A)})
 end

-trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
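
For context, the trim helper being removed reshapes the cotangent down to the primal's number of dimensions, assuming the extra trailing dimensions are singletons. A small self-contained sketch:

julia> trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))));

julia> x = rand(3);        # primal is a vector, so ndims(x) == 1

julia> Δ = ones(3, 1);     # cotangent picked up a trailing singleton dimension

julia> trim(x, Δ)          # reshaped to size (3,)
3-element Vector{Float64}:
 1.0
 1.0
 1.0
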
Member:

I think doing this makes unbroadcast less generic, we don't need to define projections here afaict. Let's retain the current definition.

Member Author (mcabbott):

What case exactly is not handled, if this is less generic?

Member:

It restricts it to what can be handled by _project as opposed to simple sizes and lengths of arrays.

oxinabox (Member), Sep 21, 2021:

those are broadly the same now, as of recent changes. _project will never method error now.

Member Author (mcabbott):

Note that before CRC changes, _project had extra methods to handle other cases.

src/lib/broadcast.jl
@mcabbott force-pushed the projectto branch 2 times, most recently from 2c1252b to 09a0ed6, on September 5, 2021 at 13:15
@mcabbott marked this pull request as ready for review on September 5, 2021 at 16:02
@inline function _project(x::Union{Numeric, Ref{<:Numeric}}, dx)
  wrap_chainrules_output(ProjectTo(x)(wrap_chainrules_input(dx)))
end
_project(x::AbstractArray, dx) = dx isa AbstractArray ? reshape(dx, axes(x)) : dx
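
As a sketch of that AbstractArray fallback: when the gradient is an array, it is simply reshaped onto the primal's axes (example values are mine, not from the PR):

julia> x = [1 2; 3 4];              # primal, a 2×2 matrix

julia> dx = [1.0, 2.0, 3.0, 4.0];   # gradient that arrived flattened

julia> reshape(dx, axes(x))         # column-major reshape onto x's axes
2×2 Matrix{Float64}:
 1.0  3.0
 2.0  4.0
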
Member:

This can be broken down into a different method

Member Author (mcabbott):

Can you write exactly what method you prefer? There are obviously always other ways to write things.

 """
 function gradient(f, args...)
   y, back = pullback(f, args...)
-  return back(sensitivity(y))
+  grad = back(sensitivity(y))
+  isnothing(grad) ? nothing : map(_project, args, grad)
 end
Member:

You can add a method to _project and avoid this change

Member Author (mcabbott):

> You can add a method to _project and avoid this change

Can you write exactly what method that would be?

Member:

Something like _project(x, ::Nothing) = nothing maybe

Member Author (mcabbott):

This is easy to try:

julia> _project(x, ::Nothing) = nothing
_project (generic function with 1 method)

julia> map(_project, (1,2,3), nothing)
ERROR: MethodError: no method matching length(::Nothing)

@mcabbott changed the title from "Use ProjectTo in broadcasting, etc." to "Use ProjectTo in broadcasting & gradient" on Sep 6, 2021
oxinabox (Member) commented Sep 6, 2021

Is there a reason not to pull all of broadcasting down into ChainRules.jl?
Probably combined with setting up Zygote to claim ForwardDiff as its forward-mode AD?

DhairyaLGandhi (Member):

Why don't we give Zygote.Forward more love? It's better for neural networks.

mcabbott (Member Author) commented Sep 6, 2021

> Is there a reason not to pull all of broadcasting down into ChainRules.jl?

One reason not to is that Zygote's un-fused broadcast might not be the last word here. Maybe you can write a fused forward broadcast in Diffractor which would be hopelessly slow here. I think there's a lot of exploring left to be done, unlike the basic rules in ChainRules, where we can write a pretty-close-to-optimal rule once and let everything use it.

Anyway this PR has much more modest goals. In the linked Flux issues it comes pretty close to entirely removing the penalty for mixing up your eltypes. And it fixes a lot of Zygote issues about real/complex.
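
(A sketch of the eltype point: with projection, the gradient of a Float32 array stays Float32 even when the loss mixes in Float64 constants. Output assumes this PR's projection step:)

julia> gradient(x -> sum(x .* 2.0), Float32[1, 2, 3])[1]
3-element Vector{Float32}:
 2.0
 2.0
 2.0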

DhairyaLGandhi (Member) commented Sep 7, 2021

Mixing eltypes is going to get really important with low-precision work picking up the pace. We shouldn't have to write custom passes for every operation related to 16-bit floats.

Besides, it's good not to be opinionated and to guide users to be type-stable. Wouldn't we expect complex numbers to have gradients with complex types? Changing that seems like a bug.

mcabbott (Member Author) commented Sep 7, 2021

Yes there have been rumours of mixed-precision training for ages. I don't see any obvious problem though. It does not involve randomly mixing types and hoping that Julia's promotion will figure it out.

Complex/real has been discussed at great length. This PR really isn't the place to argue it; if you think it's wrong you should open an issue on ChainRulesCore and make your case.

> Wouldn't we expect complex numbers to have gradients with complex types?

Err, they do? There would be a lot of broken tests if that were altered. I think you may have misunderstood what problem this projection solves. The first message has examples, and links to issues closed.
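
(A quick check of that claim, matching the convention in Zygote's documentation for complex inputs:)

julia> gradient(abs2, 1.0 + 2.0im)   # complex input still gets a complex gradient
(2.0 + 4.0im,)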

@mcabbott merged commit 528e0be into FluxML:master on Sep 22, 2021
@mcabbott deleted the projectto branch on September 22, 2021 at 03:04
DhairyaLGandhi (Member):

I did respond on Slack, where I'd mentioned wanting to take a look at it today.
