Skip to content

Commit

Permalink
Merge #902
Browse files Browse the repository at this point in the history
902: add map/broadcast/algebra/iteration/dict interface for Grads r=CarloLucibello a=CarloLucibello

Fix FluxML/Flux.jl#707

Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>
Co-authored-by: Carlo Lucibello <carlo.lucibello@unibocconi.it>
  • Loading branch information
3 people committed Feb 20, 2021
2 parents a4da332 + 8d7c0fe commit 6b89a06
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 8 deletions.
37 changes: 37 additions & 0 deletions docs/src/utils.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,40 @@ Zygote.checkpointed
```

`Params` and `Grads` can be copied to and from arrays using the `copy!` function.

## Working with Grads

Map, broadcast, and iteration are supported for the dictionary-like `Grads` objects.
These operations are value based and preserve the keys.

```julia
using Zygote, Test

w, x1, x2, b = rand(2), rand(2), rand(2), rand(2)

gs1 = gradient(() -> sum(tanh.(w .* x1 .+ b)), Params([w, b]))
gs2 = gradient(() -> sum(tanh.(w .* x2 .+ b)), Params([w, b]))

# accumulate gradients
gs = gs1 .+ gs2
@test gs[w] gs1[w] + gs2[w]
@test gs[b] gs1[b] + gs2[b]

# gradients and dictionaries interact nicely
gs .+= Dict(p => randn(size(p)) for p in keys(gs))

# clip gradients
map(x -> clamp.(x, -0.1, 0.1), gs)

# clip gradients in-place
foreach(x -> clamp!(x, -0.1, 0.1), gs)

for (p, g) in pairs(gs)
# do something with parameter `p` and corresponding gradient `g`
end

# note that gradients must be w.r.t. to the same parameter key set
gs3 = gradient(() -> sum(tanh.(w .* x2)), Params([w]))
# gs3 does not have the key b
@test_throws ArgumentError gs1 .+ gs3
```
2 changes: 2 additions & 0 deletions src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ using MacroTools: @forward
import Distributed: pmap, CachingPool, workers
export Params, gradient, jacobian, hessian, pullback, pushforward, @code_adjoint

const Numeric{T<:Number} = Union{T, AbstractArray{<:T}}

include("tools/idset.jl")
include("tools/buffer.jl")
include("tools/builtins.jl")
Expand Down
53 changes: 52 additions & 1 deletion src/compiler/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using InteractiveUtils
using InteractiveUtils: typesof
using Core: Typeof
import Base: copy!
import Base.Broadcast: broadcasted, materialize!

mutable struct Context <: AContext
cache::Union{IdDict{Any,Any},Nothing}
Expand Down Expand Up @@ -139,7 +140,23 @@ end

Base.show(io::IO, ps::Grads) = print(io, "Grads(...)")

@forward Grads.grads Base.getindex, Base.haskey
@forward Grads.grads Base.setindex!
@forward Grads.params Base.length

const ADictOrGrads = Union{AbstractDict, Grads}

# Dictionary interface.
# Don't use the IdDict directly since it may contain some spurious pairs.
Base.haskey(gs::Grads, x) = x gs.params
Base.keys(gs::Grads) = gs.params
Base.values(gs::Grads) = (gs.grads[p] for p in gs.params)

function Base.iterate(gs::Grads, state...)
res = iterate(gs.params, state...)
isnothing(res) && return nothing
p, next_state = res
return gs[p], next_state
end

function Base.getindex(gs::Grads, x)
isbits(x) && error("Only reference types can be differentiated with `Params`.")
Expand Down Expand Up @@ -171,6 +188,40 @@ function copy!(x::AbstractVector, gs::Grads)
x
end

broadcasted(f, gs::Grads, gss::ADictOrGrads...) = map(f, gs, gss...)

broadcasted(f, a::Numeric, gs::Grads) = map(x -> f(a, x), gs)
broadcasted(f, gs::Grads, a::Numeric) = map(x -> f(x, a), gs)

function materialize!(gs1::Grads, gs2::Grads)
issetequal(gs1.params, gs2.params) ||
throw(ArgumentError("Expected Grads objects with the same Params."))
for p in gs1.params
gs1[p] = gs2[p]
end
return gs1
end


function Base.map(f, gs1::Grads, gss::ADictOrGrads...)
gsout = Grads(IdDict{Any,Any}(), Params(gs1.params))
return map!(f, gsout, gs1, gss...)
end

function Base.map!(f, gsout::Grads, gss::ADictOrGrads...)
all(issetequal(gsout.params, keys(gs)) for gs in gss) ||
throw(ArgumentError("map! expects Grads objects with the same Params."))
for p in gsout.params
gsout[p] = f((_getformap(gs, p) for gs in gss)...)
end
return gsout
end

function _getformap(gs, p)
g = gs[p]
isnothing(g) ? fill!(similar(p), 0) : g
end

function pullback(f, ps::Params)
cx = Context()
y, back = _pullback(cx, f)
Expand Down
2 changes: 1 addition & 1 deletion src/forward/Forward.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module Forward

import ..Zygote
import ..Zygote: __new__, __splatnew__
import ..Zygote: __new__, __splatnew__, Numeric

export pushforward

Expand Down
2 changes: 0 additions & 2 deletions src/forward/broadcast.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
using Base.Broadcast: AbstractArrayStyle, broadcasted

Numeric{T<:Number} = Union{T,AbstractArray{<:T}}

@tangent Broadcast.preprocess(dest, bc) =
Broadcast.preprocess(dest, bc), (ddest, dbc) -> dbc

Expand Down
2 changes: 0 additions & 2 deletions src/lib/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,6 @@ unbroadcast(x::AbstractArray, x̄::Nothing) = nothing
# to do CSE, then broadcast-ify the expression so that the closure captures the
# right arrays.

Numeric{T<:Number} = Union{T,AbstractArray{<:T}}

@adjoint broadcasted(::typeof(+), xs::Numeric...) =
broadcast(+, xs...), ȳ -> (nothing, map(x -> unbroadcast(x, ȳ), xs)...)

Expand Down
90 changes: 88 additions & 2 deletions test/interface.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
@testset "Parmas" begin
using Zygote: Grads

@testset "Params" begin
@testset "delete!" begin
w = rand(2,3)
b = rand(2)
Expand Down Expand Up @@ -55,4 +57,88 @@
@test ps1 == ps2
@test ps1 != ps3 # comparison is order dependent
end
end
end

@testset "Grads" begin
@testset "algebra" begin
w, b = rand(2), rand(2)
x1, x2 = rand(2), rand(2)

gs1 = gradient(() -> sum(w .* x1), Params([w]))
gs2 = gradient(() -> sum(w .* x2), Params([w]))

@test .- gs1 isa Grads
@test gs1 .- gs2 isa Grads
@test .+ gs1 isa Grads
@test gs1 .+ gs2 isa Grads
@test 2 .* gs1 isa Grads
@test (2 .* gs1)[w] 2 * gs1[w]
@test gs1 .* 2 isa Grads
@test gs1 ./ 2 isa Grads
@test (gs1 .+ gs2)[w] gs1[w] .+ gs2[w]

gs12 = gs1 .+ gs2
gs1 .+= gs2
@test gs12[w] gs1[w]

gs3 = gradient(() -> sum(w .* x1), Params([w, b])) # grad nothing with respect to b
gs4 = gradient(() -> sum(w .* x2 .+ b), Params([w, b]))

@test .- gs3 isa Grads
@test gs3 .- gs4 isa Grads
@test .+ gs3 isa Grads
@test gs3 .+ gs4 isa Grads
@test 2 .* gs3 isa Grads
@test gs3 .* 2 isa Grads
@test gs3 ./ 2 isa Grads
@test (gs3 .+ gs4)[w] gs3[w] .+ gs4[w]
@test (gs3 .+ gs4)[b] gs4[b]

@test gs3 .+ Dict(w => similar(w), b => similar(b)) isa Grads
gs3 .+= Dict(p => randn(size(p)) for p in keys(gs3))
@test gs3 isa Grads

@test_throws ArgumentError gs1 .+ gs4
end

@testset "map and broadcast" begin
w = rand(2)
x1 = rand(2)
x2 = rand(2)

gs1 = gradient(() -> sum(w .* x1), Params([w]))
gs2 = gradient(() -> sum(w .* x2), Params([w]))

@test map(x -> zeros(2), gs1) isa Grads

gs11 = map(x -> clamp.(x, -1e-5, 1e-5), gs1)
@test gs11 isa Grads
@test all(abs.(gs11[w]) .<= 1e-5)

@test (x -> zeros(2)).(gs1) isa Grads
end

@testset "dictionary interface" begin
w, b, x = rand(2), rand(2), rand(2)
ps = Params([w, b])
gs = gradient(() -> sum(tanh.(w .* x .+ b)), ps)

@test issetequal(keys(gs), ps)
@test length(values(gs)) == 2
@test length(pairs(gs)) == 2
k, v = first(pairs(gs))
@test k === first(ps)
@test v === gs[first(ps)]
end

@testset "iteration" begin
w, b, x = rand(2), rand(2), rand(2)
ps = Params([w, b])
gs = gradient(() -> sum(tanh.(w .* x .+ b)), ps)

# value-based iteration
foreach(x -> clamp!(x, -1e-5, 1e-5), gs)
@test all(abs.(gs[w]) .<= 1e-5)
@test all(abs.(gs[b]) .<= 1e-5)
end
end

0 comments on commit 6b89a06

Please sign in to comment.