Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add map/broadcast/algebra/iteration/dict interface for Grads #902

Merged
merged 17 commits into from
Feb 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions docs/src/utils.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,40 @@ Zygote.checkpointed
```

`Params` and `Grads` can be copied to and from arrays using the `copy!` function.

## Working with Grads

Map, broadcast, and iteration are supported for the dictionary-like `Grads` objects.
These operations are value based and preserve the keys.

```julia
using Zygote, Test

w, x1, x2, b = rand(2), rand(2), rand(2), rand(2)

gs1 = gradient(() -> sum(tanh.(w .* x1 .+ b)), Params([w, b]))
gs2 = gradient(() -> sum(tanh.(w .* x2 .+ b)), Params([w, b]))

# accumulate gradients
gs = gs1 .+ gs2
@test gs[w] ≈ gs1[w] + gs2[w]
@test gs[b] ≈ gs1[b] + gs2[b]

# gradients and dictionaries interact nicely
gs .+= Dict(p => randn(size(p)) for p in keys(gs))

# clip gradients
map(x -> clamp.(x, -0.1, 0.1), gs)

# clip gradients in-place
foreach(x -> clamp!(x, -0.1, 0.1), gs)

for (p, g) in pairs(gs)
# do something with parameter `p` and corresponding gradient `g`
end
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved

# note that gradients must be w.r.t. to the same parameter key set
gs3 = gradient(() -> sum(tanh.(w .* x2)), Params([w]))
# gs3 does not have the key b
@test_throws ArgumentError gs1 .+ gs3
```
2 changes: 2 additions & 0 deletions src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ using MacroTools: @forward
import Distributed: pmap, CachingPool, workers
export Params, gradient, jacobian, hessian, pullback, pushforward, @code_adjoint

const Numeric{T<:Number} = Union{T, AbstractArray{<:T}}

include("tools/idset.jl")
include("tools/buffer.jl")
include("tools/builtins.jl")
Expand Down
53 changes: 52 additions & 1 deletion src/compiler/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using InteractiveUtils
using InteractiveUtils: typesof
using Core: Typeof
import Base: copy!
import Base.Broadcast: broadcasted, materialize!

mutable struct Context <: AContext
cache::Union{IdDict{Any,Any},Nothing}
Expand Down Expand Up @@ -139,7 +140,23 @@ end

Base.show(io::IO, ps::Grads) = print(io, "Grads(...)")

@forward Grads.grads Base.getindex, Base.haskey
@forward Grads.grads Base.setindex!
@forward Grads.params Base.length

const ADictOrGrads = Union{AbstractDict, Grads}

# Dictionary interface.
# Don't use the IdDict directly since it may contain some spurious pairs.
Base.haskey(gs::Grads, x) = x ∈ gs.params
Base.keys(gs::Grads) = gs.params
Base.values(gs::Grads) = (gs.grads[p] for p in gs.params)

function Base.iterate(gs::Grads, state...)
res = iterate(gs.params, state...)
isnothing(res) && return nothing
p, next_state = res
return gs[p], next_state
end

function Base.getindex(gs::Grads, x)
isbits(x) && error("Only reference types can be differentiated with `Params`.")
Expand Down Expand Up @@ -171,6 +188,40 @@ function copy!(x::AbstractVector, gs::Grads)
x
end

broadcasted(f, gs::Grads, gss::ADictOrGrads...) = map(f, gs, gss...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would want this

Suggested change
broadcasted(f, gs::Grads, gss::ADictOrGrads...) = map(f, gs, gss...)
broadcasted(f, gs::ADictOrGrads, gss::ADictOrGrads...) = map(f, gs, gss...)

But that would overload broadcasting for the case where gs and gss are just Dicts (no Grads). Don't know a quick way around it but unfortunate that the first arg must be a Grad.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep, I don't know either, unless we start adding a bunch of definitions like

broadcasted(f, gs1::AbstractDict, gs2::Grads, gss::ADictOrGrads...) = map(f, gs1, gs2, gss...)

and the corresponding ones for map, but this seems pretty annoying and I would leave it to future PRs

Copy link
Member

@darsnack darsnack Feb 20, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is broadcasted(f, gss::Vararg{Union{AbstractDict, Grads}}) different from broadcasted(f, gss::Vararg{<:Union{AbstractDict, Grads}})? I think the first forces the union type. Let me double check.

It does not 😞. Yeah I think leave this for another PR unless someone knows a quick fix.


broadcasted(f, a::Numeric, gs::Grads) = map(x -> f(a, x), gs)
broadcasted(f, gs::Grads, a::Numeric) = map(x -> f(x, a), gs)

function materialize!(gs1::Grads, gs2::Grads)
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved
issetequal(gs1.params, gs2.params) ||
throw(ArgumentError("Expected Grads objects with the same Params."))
for p in gs1.params
gs1[p] = gs2[p]
end
return gs1
end


function Base.map(f, gs1::Grads, gss::ADictOrGrads...)
gsout = Grads(IdDict{Any,Any}(), Params(gs1.params))
return map!(f, gsout, gs1, gss...)
end

function Base.map!(f, gsout::Grads, gss::ADictOrGrads...)
all(issetequal(gsout.params, keys(gs)) for gs in gss) ||
throw(ArgumentError("map! expects Grads objects with the same Params."))
for p in gsout.params
gsout[p] = f((_getformap(gs, p) for gs in gss)...)
end
return gsout
end

function _getformap(gs, p)
g = gs[p]
isnothing(g) ? fill!(similar(p), 0) : g
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved
end

function pullback(f, ps::Params)
cx = Context()
y, back = _pullback(cx, f)
Expand Down
2 changes: 1 addition & 1 deletion src/forward/Forward.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module Forward

import ..Zygote
import ..Zygote: __new__, __splatnew__
import ..Zygote: __new__, __splatnew__, Numeric

export pushforward

Expand Down
2 changes: 0 additions & 2 deletions src/forward/broadcast.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
using Base.Broadcast: AbstractArrayStyle, broadcasted

Numeric{T<:Number} = Union{T,AbstractArray{<:T}}

@tangent Broadcast.preprocess(dest, bc) =
Broadcast.preprocess(dest, bc), (ddest, dbc) -> dbc

Expand Down
2 changes: 0 additions & 2 deletions src/lib/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,6 @@ unbroadcast(x::AbstractArray, x̄::Nothing) = nothing
# to do CSE, then broadcast-ify the expression so that the closure captures the
# right arrays.

Numeric{T<:Number} = Union{T,AbstractArray{<:T}}

@adjoint broadcasted(::typeof(+), xs::Numeric...) =
broadcast(+, xs...), ȳ -> (nothing, map(x -> unbroadcast(x, ȳ), xs)...)

Expand Down
90 changes: 88 additions & 2 deletions test/interface.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
@testset "Parmas" begin
using Zygote: Grads

@testset "Params" begin
@testset "delete!" begin
w = rand(2,3)
b = rand(2)
Expand Down Expand Up @@ -55,4 +57,88 @@
@test ps1 == ps2
@test ps1 != ps3 # comparison is order dependent
end
end
end

@testset "Grads" begin
@testset "algebra" begin
w, b = rand(2), rand(2)
x1, x2 = rand(2), rand(2)

gs1 = gradient(() -> sum(w .* x1), Params([w]))
gs2 = gradient(() -> sum(w .* x2), Params([w]))

@test .- gs1 isa Grads
@test gs1 .- gs2 isa Grads
@test .+ gs1 isa Grads
@test gs1 .+ gs2 isa Grads
@test 2 .* gs1 isa Grads
@test (2 .* gs1)[w] ≈ 2 * gs1[w]
@test gs1 .* 2 isa Grads
@test gs1 ./ 2 isa Grads
@test (gs1 .+ gs2)[w] ≈ gs1[w] .+ gs2[w]

gs12 = gs1 .+ gs2
gs1 .+= gs2
@test gs12[w] ≈ gs1[w]

gs3 = gradient(() -> sum(w .* x1), Params([w, b])) # grad nothing with respect to b
gs4 = gradient(() -> sum(w .* x2 .+ b), Params([w, b]))

@test .- gs3 isa Grads
@test gs3 .- gs4 isa Grads
@test .+ gs3 isa Grads
@test gs3 .+ gs4 isa Grads
@test 2 .* gs3 isa Grads
@test gs3 .* 2 isa Grads
@test gs3 ./ 2 isa Grads
@test (gs3 .+ gs4)[w] ≈ gs3[w] .+ gs4[w]
@test (gs3 .+ gs4)[b] ≈ gs4[b]
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved

@test gs3 .+ Dict(w => similar(w), b => similar(b)) isa Grads
gs3 .+= Dict(p => randn(size(p)) for p in keys(gs3))
@test gs3 isa Grads

@test_throws ArgumentError gs1 .+ gs4
end
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved

@testset "map and broadcast" begin
w = rand(2)
x1 = rand(2)
x2 = rand(2)

gs1 = gradient(() -> sum(w .* x1), Params([w]))
gs2 = gradient(() -> sum(w .* x2), Params([w]))

@test map(x -> zeros(2), gs1) isa Grads

gs11 = map(x -> clamp.(x, -1e-5, 1e-5), gs1)
@test gs11 isa Grads
@test all(abs.(gs11[w]) .<= 1e-5)

@test (x -> zeros(2)).(gs1) isa Grads
end

@testset "dictionary interface" begin
w, b, x = rand(2), rand(2), rand(2)
ps = Params([w, b])
gs = gradient(() -> sum(tanh.(w .* x .+ b)), ps)

@test issetequal(keys(gs), ps)
@test length(values(gs)) == 2
@test length(pairs(gs)) == 2
k, v = first(pairs(gs))
@test k === first(ps)
@test v === gs[first(ps)]
end

@testset "iteration" begin
w, b, x = rand(2), rand(2), rand(2)
ps = Params([w, b])
gs = gradient(() -> sum(tanh.(w .* x .+ b)), ps)

# value-based iteration
foreach(x -> clamp!(x, -1e-5, 1e-5), gs)
@test all(abs.(gs[w]) .<= 1e-5)
@test all(abs.(gs[b]) .<= 1e-5)
end
end