Use ProjectTo in broadcasting & gradient #1044

Merged Sep 22, 2021

Commits (43)
050ea52
use ProjectTo in broadcasting, etc
mcabbott Jul 27, 2021
a416263
separate methods for Params
mcabbott Jul 27, 2021
ac1281b
move after defn
mcabbott Jul 27, 2021
0bb31c2
better dims handling in unbroadcast
mcabbott Aug 1, 2021
d087bbe
tidier
mcabbott Aug 1, 2021
d7ce02f
tests
mcabbott Aug 2, 2021
f353ae2
more wrapping
mcabbott Aug 18, 2021
48fbfcc
fix a test
mcabbott Aug 18, 2021
a826092
handle a few nothings
mcabbott Aug 18, 2021
91fc91f
fix more, including FFT tests
mcabbott Aug 18, 2021
d905c3d
tests
mcabbott Aug 19, 2021
fbebbe9
one test
mcabbott Aug 19, 2021
502d85d
tests
mcabbott Aug 19, 2021
361d047
tests
mcabbott Aug 19, 2021
b621330
tests
mcabbott Aug 19, 2021
3e3e16e
these are fixed
mcabbott Aug 19, 2021
ea54df7
add Compat
mcabbott Aug 19, 2021
ff5f20e
tests
mcabbott Aug 19, 2021
8599e1b
add tests for issues closed
mcabbott Aug 19, 2021
27e52b2
simplify, some doctests
mcabbott Sep 5, 2021
ff9aacf
fix some tests
mcabbott Sep 5, 2021
5bf5342
less piracy
mcabbott Sep 5, 2021
e9ea88a
adjoint
mcabbott Sep 5, 2021
0013fd3
piracy
mcabbott Sep 5, 2021
c07ae9f
skip a test
mcabbott Sep 5, 2021
7ff1159
splat tests
mcabbott Sep 5, 2021
6549c57
skip on 1.3
mcabbott Sep 5, 2021
298f119
simplify _project
mcabbott Sep 5, 2021
e3922a9
a typo
mcabbott Sep 5, 2021
a2814ae
tweak
mcabbott Sep 5, 2021
08f8c46
broken GPU test, unrelated
mcabbott Sep 5, 2021
c8bc588
unexpected pass
mcabbott Sep 5, 2021
5080490
only broken on 1.6
mcabbott Sep 5, 2021
1b37161
let nothing through
mcabbott Sep 5, 2021
4c08118
rm some broken things
mcabbott Sep 5, 2021
7197491
target 1.3 fix
mcabbott Sep 5, 2021
dde922b
comments
mcabbott Sep 9, 2021
1c07a7c
update for ProjectTo(::Any)
mcabbott Sep 21, 2021
35280d5
fix a test
mcabbott Sep 21, 2021
80123a1
Update test/utils.jl
mcabbott Sep 21, 2021
3bc2e09
Update src/lib/broadcast.jl
mcabbott Sep 21, 2021
02397b5
cu tests
mcabbott Sep 22, 2021
a3e3a97
v0.6.22
mcabbott Sep 22, 2021
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.21"
version = "0.6.22"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
@@ -24,7 +24,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
AbstractFFTs = "0.5, 1.0"
ChainRules = "1.5"
ChainRulesCore = "1.1"
ChainRulesCore = "1.6"
ChainRulesTestUtils = "1"
DiffRules = "1.0"
FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12"
2 changes: 1 addition & 1 deletion README.md
@@ -18,7 +18,7 @@ julia> using Zygote
julia> f(x) = 5x + 3

julia> f(10), f'(10)
(53, 5)
(53, 5.0)

julia> @code_llvm f'(10)
define i64 @"julia_#625_38792"(i64) {
22 changes: 22 additions & 0 deletions src/compiler/chainrules.jl
@@ -123,11 +123,33 @@ Convert `x` from the format Zygote uses internally to differentials types ChainRules uses.
"""
@inline wrap_chainrules_input(x) = x
@inline wrap_chainrules_input(::Nothing) = ChainRules.ZeroTangent()
@inline wrap_chainrules_input(::AbstractArray{Nothing}) = ChainRules.ZeroTangent()
@inline function wrap_chainrules_input(xs::Union{Tuple, NamedTuple})
xp = map(wrap_chainrules_input, xs)
ChainRules.Tangent{Any, typeof(xp)}(xp)
end

"""
_project(x, dx)

Uses `ChainRulesCore.ProjectTo` to standardise the gradient `dx` for type & shape.
Also handles some Zygote-specific corrections, such as `x::Array, dx::Tuple`.
Safe to apply to arbitrary input.
"""
@inline function _project(x, dx)
wrap_chainrules_output(ProjectTo(x)(wrap_chainrules_input(dx)))
end

# Restore splatted arrays
_project(x::AbstractArray, dx::Tuple) = _project(x, reshape(collect(dx), axes(x)))

# Piracy:
# wrap_chainrules_input doesn't handle array of Union{Int,Nothing}
(::ChainRulesCore.ProjectTo)(::Nothing) = ChainRulesCore.NoTangent()

# CRC likes Tangent{<:Complex}, but Zygote makes Tangent{Any}
(project::ProjectTo{<:Complex})(dx::Tangent) = project(Complex(dx.re, dx.im))

"""
ZBack{F}(back) <: Function

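For orientation, here is a minimal sketch of the projection behaviour that `_project` delegates to. It uses only ChainRulesCore and LinearAlgebra, and the exact REPL printing may differ between versions:

julia> using ChainRulesCore, LinearAlgebra

julia> ProjectTo(1.0)(2 + 3im)  # a real x gets a real gradient: the imaginary part is dropped
2.0

julia> ProjectTo([1.0, 2.0])([3 + 4im, 5im])  # applied elementwise, eltype follows x
2-element Vector{Float64}:
 3.0
 0.0

julia> ProjectTo(Diagonal([1.0, 2.0]))(ones(2, 2))  # structured matrices are restored
2×2 Diagonal{Float64, Vector{Float64}}:
 1.0   ⋅
  ⋅   1.0

`_project(x, dx)` wraps these calls between `wrap_chainrules_input` and `wrap_chainrules_output`, so that Zygote-style `nothing`s and named tuples survive the round trip.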
23 changes: 16 additions & 7 deletions src/compiler/interface.jl
@@ -68,15 +68,20 @@ julia> gradient([7, 11], 0, 1) do x, y, d
p = size(x, d)
sum(x.^p .+ y)
end
([14.0, 22.0], 2, nothing)
([14.0, 22.0], 2.0, nothing)
```
"""
function gradient(f, args...)
y, back = pullback(f, args...)
return back(sensitivity(y))
grad = back(sensitivity(y))
isnothing(grad) ? nothing : map(_project, args, grad)
end
Review discussion on this change:

Member: You can add a method to _project and avoid this change.

Member Author: Can you write exactly what method that would be?

Member: Something like _project(x, ::Nothing) = nothing maybe.

Member Author: This is easy to try:

julia> _project(x, ::Nothing) = nothing
_project (generic function with 1 method)

julia> map(_project, (1,2,3), nothing)
ERROR: MethodError: no method matching length(::Nothing)


Base.adjoint(f::Function) = x -> gradient(f, x)[1]
# Base.adjoint(f::Function) = x -> gradient(f, x)[1] # piracy!
Base.adjoint(f::Function) = x -> begin # still piracy! avoids projection for legacy reasons
y, back = pullback(f, x)
back(sensitivity(y))[1]
end

"""
withgradient(f, args...)
@@ -95,7 +100,9 @@ true
"""
function withgradient(f, args...)
y, back = pullback(f, args...)
(val = y, grad = back(sensitivity(y)))
grad = back(sensitivity(y))
results = isnothing(grad) ? map(_ -> nothing, args) : map(_project, args, grad)
(val=y, grad=results)
end

# Param-style wrappers
@@ -115,9 +122,9 @@ julia> g = gradient(Params([x, y])) do
Grads(...)

julia> g[x]
2×3 Matrix{Int64}:
7 70 700
8 80 800
2×3 Matrix{Float64}:
7.0 70.0 700.0
8.0 80.0 800.0

julia> haskey(g, z) # only x and y are parameters
false
@@ -144,6 +151,8 @@ Params(xs::Tuple) = Params(collect(xs))
@forward Params.order Base.iterate, Base.length, Base.getindex
@forward Params.params Base.in

Base.map(::typeof(_project), args::Tuple{Params}, grad) = grad # skip _project in gradient(f, ::Params)

function Base.union!(ps::Params, itrs...)
foreach(itr -> foreach(x -> push!(ps, x), itr), itrs)
return ps
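A hedged sketch of the user-facing effect of these interface changes (the second output is taken from the updated tests below; printing may vary by Julia version):

julia> using Zygote

julia> gradient(abs2, 3)  # gradients at integer points are now projected to floats
(6.0,)

julia> gradient(x -> x.re, 2 + 3im)  # a Tangent{Complex} comes back as a Complex number
(1.0 + 0.0im,)

Note the `Base.map(::typeof(_project), args::Tuple{Params}, grad)` method above: gradients taken with respect to `Params` bypass projection, keeping the old behaviour for implicit parameters.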
2 changes: 1 addition & 1 deletion src/lib/array.jl
@@ -38,7 +38,7 @@ end
dxv = view(dx, inds...)
dxv .= accum.(dxv, _droplike(dy, dxv))
end
return (dx, map(_->nothing, inds)...)
return (_project(x, dx), map(_->nothing, inds)...)
end

"""
19 changes: 10 additions & 9 deletions src/lib/broadcast.jl
@@ -45,18 +45,19 @@ function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArray
Base.reducedim_initarray(A, region, nothing, Union{Nothing,eltype(A)})
end

trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
Review discussion on this change:

Member: I think doing this makes unbroadcast less generic; we don't need to define projections here afaict. Let's retain the current definition.

Member Author: What case exactly is not handled, if this is less generic?

Member: It restricts it to what can be handled by _project, as opposed to simple sizes and lengths of arrays.

Member (oxinabox, Sep 21, 2021): Those are broadly the same now, as of recent changes; _project will never method-error now.

Member Author: Note that before the CRC changes, _project had extra methods to handle other cases.
trim(x::Tuple, Δ) = NTuple{length(x)}(Δ)

unbroadcast(x::AbstractArray, x̄) =
size(x) == size(x̄) ? x̄ :
length(x) == length(x̄) ? trim(x, x̄) :
trim(x, accum_sum(x̄, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(x̄)+1, Val(ndims(x̄)))))

function unbroadcast(x::AbstractArray, x̄)
N = ndims(x̄)
if length(x) == length(x̄)
_project(x, x̄) # ProjectTo handles reshape, offsets, structured matrices, row vectors
else
dims = ntuple(d -> size(x, d) == 1 ? d : ndims(x̄)+1, ndims(x̄))
_project(x, accum_sum(x̄; dims = dims))
end
end
unbroadcast(x::Number, x̄) = accum_sum(x̄)
unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),)
unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),)
unbroadcast(x::Tuple, x̄) = trim(x, length(x) == length(x̄) ? x̄ : accum_sum(x̄; dims=2:ndims(x̄))) # case length(x) > 1
unbroadcast(x::Tuple, x̄) = NTuple{length(x)}(length(x) == length(x̄) ? x̄ : accum_sum(x̄; dims=2:ndims(x̄))) # case length(x) > 1

unbroadcast(x::AbstractArray, x̄::Nothing) = nothing

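To make the summing-and-projecting path concrete, a small sketch (assumes this PR; output printing may differ by Julia version). In the first call the second argument is a 1×2 row, so its size-1 first dimension is summed out of the 3×2 cotangent before projection back to the original shape:

julia> using Zygote

julia> gradient((x, y) -> sum(x .* y), [1.0, 2.0, 3.0], [4.0 5.0])
([9.0, 9.0, 9.0], [6.0 6.0])

julia> gradient(x -> sum(sqrt.(x)), [1.0, 4.0]')[1]  # ProjectTo now restores the Adjoint wrapper
1×2 adjoint(::Vector{Float64}) with eltype Float64:
 0.5  0.25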
33 changes: 31 additions & 2 deletions test/complex.jl
@@ -1,9 +1,13 @@
using Zygote, Test, LinearAlgebra

@testset "basic" begin

@test gradient(x -> real(abs(x)*exp(im*angle(x))), 10+20im)[1] ≈ 1
@test gradient(x -> imag(real(x)+0.3im), 0.3)[1] ≈ 0
@test gradient(x -> imag(conj(x)+0.3im), 0.3)[1] ≈ -1im
@test gradient(x -> abs((imag(x)+0.3)), 0.3)[1] == 1im
@test gradient(x -> imag(conj(x)+0.3im), 0.3 + 0im)[1] ≈ -1im
@test gradient(x -> imag(conj(x)+0.3im), 0.3)[1] ≈ 0 # projected to zero
@test gradient(x -> abs((imag(x)+0.3)), 0.3 + 0im)[1] ≈ 1im
@test gradient(x -> abs((imag(x)+0.3)), 0.3)[1] ≈ 0

@test gradient(a -> real((a*conj(a))), 0.3im)[1] == 0.6im
@test gradient(a -> real((a.*conj(a))), 0.3im)[1] == 0.6im
@@ -21,6 +25,8 @@ using Zygote, Test, LinearAlgebra
@test gradient(x -> imag(sum(exp, x)), [1,2,3])[1] ≈ real(im .* exp.(1:3))
@test gradient(x -> imag(sum(exp, x)), [1+0im,2,3])[1] ≈ im .* exp.(1:3)

end # @testset

fs_C_to_R = (real,
imag,
abs,
@@ -81,3 +87,26 @@ fs_C_to_C_non_holomorphic = (conj,
end
end
end

@testset "issue 342" begin
@test Zygote.gradient(x->real(x + 2.0*im), 3.0) == (1.0,)
@test Zygote.gradient(x->imag(x + 2.0*im), 3.0) == (0.0,)
end

@testset "issue 402" begin
A = [1,2,3.0]
y, B_getindex = Zygote.pullback(x->getindex(x,2,1),Diagonal(A))
bA = B_getindex(1)[1]
@test bA isa Diagonal
@test bA == [0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0]
end

@testset "issue #917" begin
function fun(v)
c = v[1:3] + v[4:6]*im
r = v[7:9]
sum(r .* abs2.(c)) # This would be calling my actual function depending on r and c
end
@test Zygote.hessian(fun, collect(1:9)) ≈ [14 0 0 0 0 0 2 0 0; 0 16 0 0 0 0 0 4 0; 0 0 18 0 0 0 0 0 6; 0 0 0 14 0 0 8 0 0; 0 0 0 0 16 0 0 10 0; 0 0 0 0 0 18 0 0 12; 2 0 0 8 0 0 0 0 0; 0 4 0 0 10 0 0 0 0; 0 0 6 0 0 12 0 0 0]
end

39 changes: 37 additions & 2 deletions test/cuda.jl
@@ -1,12 +1,20 @@
using CUDA
using Zygote: Grads
using LinearAlgebra
using Random: randn!
CUDA.allowscalar(false)

# Test GPU movement inside the call to `gradient`
@testset "GPU movement" begin
r = rand(Float32, 3,3)
@test gradient(x -> sum(cu(x)), r)[1] isa Array{Float32, 2}
@test gradient(x -> sum(cu(x)), r)[1] isa Matrix{Float32}
@test gradient(x -> sum(x->log(x), cu(x)), r)[1] isa Matrix
@test gradient((x,cy) -> sum(cu(x) * cy) + sum(cy'), r, cu(r))[2] isa CUDA.CuArray
@test_skip gradient((x,cy) -> sum(cu(x[:,1])' * cy), r, cu(r))[2] isa CUDA.CuArray # generic_matmatmul!

# Other direction:
@test_skip gradient(x -> sum(Array(x)), cu(r))[1] isa CUDA.CuArray
@test_skip gradient((x,cy) -> sum(x * Array(cy)) + sum(cy'), r, cu(r))[2] isa CUDA.CuArray
end

@testset "broadcasting" begin
@@ -31,17 +39,38 @@ end
g3 = gradient(x -> sum(x .^ 3) / count(x .> 3), a)[1] # was Can't differentiate gc_preserve_end expression
@test_skip cu(g3) ≈ gradient(x -> sum(x .^ 3) / sum(x .> 3), a_gpu)[1] # was KernelException -- not fixed by PR #1018
@test cu(g3) ≈ gradient(x -> sum(x .^ 3) / count(x .> 3), a_gpu)[1]

# Projection: eltype preservation:
@test gradient(x -> 2.3 * sum(x.^4), a_gpu)[1] isa CuArray{Float32}
@test_skip gradient(x -> sum(x .* 5.6), a_gpu)[1] isa CUDA.CuArray{Float32} # dot(x::CuArray{Float64}, y::CuArray{Float32}) fallback
# structure restoration:
@test gradient(x -> sum(sqrt.(x)), a_gpu')[1] isa Adjoint # previously a matrix
@test gradient(x -> sum(exp.(x)), Diagonal(a_gpu))[1] isa Diagonal
# non-differentiables
@test gradient((x,y) -> sum(x.^2 .+ y'), a_gpu, a_gpu .> 0)[2] === nothing
end

@testset "sum(f, x)" begin
a = Float32.([-1.5, -9.0, 2.4, -1.3, 0.01])
a = Float32[-1.5, -9.0, 2.4, -1.3, 0.01]
a_gpu = a |> cu

f(x) = sum(abs, x)
g = gradient(f, a)[1]
g_gpu = gradient(f, a_gpu)[1]
@test g_gpu isa CuArray
@test g_gpu |> collect ≈ g

f2(x) = sum(abs2, x) # sum(abs2, x) has its own rrule
g2 = gradient(f2, a)[1]
g2_gpu = gradient(f2, a_gpu)[1]
@test g2_gpu isa CuArray
@test g2_gpu |> collect ≈ g2

f3(x) = sum(y->y^3, x') # anonymous function
g3 = gradient(f3, a')[1]
g3_gpu = gradient(f3, a_gpu')[1]
@test g3_gpu isa Adjoint{Float32, <:CuArray{Float32, 1}} # preserves structure
@test g3_gpu |> collect ≈ g3
end

@testset "jacobian" begin
@@ -103,5 +132,11 @@ end
r = cu(rand(Float32, 3))
grads = (cu(ones(Float32, 3)), 1.f0)
@test gradient((x,y) -> sum(vcat(x,y)), r, 5) == grads

@test gradient((x,y) -> sum(vcat(x,y)), r, Float64(5))[1] isa CUDA.CuArray{Float32}
@test gradient((x,y) -> sum(vcat(x,y)), r, Float64(5))[2] isa Float64 # projection

@test_skip gradient((x,y) -> sum(vcat(x,y)), 5f0, r)[2] isa CUDA.CuArray{Float32} # wrong order
@test_skip gradient((x,y) -> sum(vcat(x,y)), 1f0, r, 2f0, r)[2] isa CUDA.CuArray{Float32}
end

24 changes: 22 additions & 2 deletions test/features.jl
@@ -176,9 +176,9 @@ end

@test gradient(t -> t[1]*t[2], (2, 3)) == ((3, 2),)

@test gradient(x -> x.re, 2+3im) == ((re = 1, im = nothing),)
@test gradient(x -> x.re, 2+3im) === (1.0 + 0.0im,)

@test gradient(x -> x.re*x.im, 2+3im) == ((re = 3, im = 2),)
@test gradient(x -> x.re*x.im, 2+3im) == (3.0 + 2.0im,)

struct Bar{T}
a::T
@@ -262,6 +262,7 @@ D(f, x) = grad(f, x)[1]
@test D(x -> x*D(y -> x+y, 1), 1) == 1
@test D(x -> x*D(y -> x*y, 1), 4) == 8

@test sin''(1.0) == -sin(1.0)
@test sin'''(1.0) == -cos(1.0)

f(x) = throw(DimensionMismatch("fubar"))
@@ -499,6 +500,25 @@ end
@test x[1] == x[2]
end

@testset "splats" begin
@test gradient(x -> max(x...), [1,2,3])[1] == [0,0,1]
@test gradient(x -> min(x...), (1,2,3))[1] === (1.0, 0.0, 0.0)

@test gradient(x -> max(x...), [1 2; 3 4])[1] == [0 0; 0 1]
@test gradient(x -> max(x...), [1,2,3]')[1] == [0 0 1]

# https://github.com/FluxML/Zygote.jl/issues/599
@test gradient(w -> sum([w...]), [1,1])[1] isa AbstractVector

# https://github.com/FluxML/Zygote.jl/issues/866
f866(x) = reshape(x, fill(2, 2)...)
@test gradient(x->sum(f866(x)), rand(4))[1] == [1,1,1,1]

# https://github.com/FluxML/Zygote.jl/issues/731
f731(x) = sum([x' * x, x...])
@test_broken gradient(f731, ones(3)) # MethodError: no method matching +(::Tuple{Float64, Float64, Float64}, ::Vector{Float64})
end

@testset "accumulation" begin
# from https://github.com/FluxML/Zygote.jl/issues/905
function net(x1)
3 changes: 2 additions & 1 deletion test/forward/forward.jl
@@ -36,7 +36,8 @@ end == 1
x
end == 0

@test D(x -> abs(x+2im), 1) == gradient(x -> abs(x+2im), 1)[1]
@test D(x -> abs(x+2im), 1) == gradient(x -> abs(x+2im), 1+0im)[1]
@test real(D(x -> abs(x+2im), 1)) == gradient(x -> abs(x+2im), 1)[1] # ProjectTo means gradient here is real

using LinearAlgebra
