Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Mutation #75

Draft
wants to merge 23 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using LinearAlgebra: copytri!
# has several caveats. Recursion will cause inference to stack overflow.
# Gradient redefinitions may result in ugly type errors. And Jameson *will* know.
const usetyped = get(ENV, "ZYGOTE_TYPED", false) == "true"
const mutate = true

using IRTools
using MacroTools, Requires
Expand Down
22 changes: 15 additions & 7 deletions src/compiler/interface.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
# Lightweight dictionary key that identifies a value by object identity:
# only the `objectid` of the wrapped value is stored, not the value itself.
struct Key
  id::UInt64

  function Key(x)
    return new(objectid(x))
  end
end

# Shaves some time on dict lookups (which is all we use this for).
# The stored object id is returned directly as the hash value, with no further mixing.
# NOTE(review): only the one-argument `hash` is overloaded here; confirm that this is the
# method the `Dict` lookup path calls on every supported Julia version.
Base.hash(k::Key) = k.id

mutable struct Context
cache::Union{IdDict{Any,Any},Nothing}
cache::Union{Dict{Key,Any},Nothing}
globals::Union{Dict{GlobalRef,Any},Nothing}
end

Context() = Context(nothing, nothing)

cache(cx::Context) = cx.cache === nothing ? (cx.cache = IdDict()) : cx.cache
cache(cx::Context) = cx.cache === nothing ? (cx.cache = Dict{Key,Any}()) : cx.cache
globals(cx::Context) = cx.globals === nothing ? (cx.globals = Dict{GlobalRef,Any}()) : cx.globals

struct Pullback{S,T}
Expand Down Expand Up @@ -79,24 +87,24 @@ function Base.show(io::IO, ps::Params)
end

struct Grads
grads::IdDict{Any,Any}
grads::Dict{Key,Any}
end

Base.show(io::IO, ps::Grads) = print(io, "Grads(...)")

@forward Grads.grads Base.getindex, Base.haskey

function Base.getindex(gs::Grads, x)
isbits(x) && error("Only reference types can be differentiated with `Params`.")
return gs.grads[x]
return gs.grads[Key(x)]
end

Base.haskey(gs::Grads, x) = haskey(gs.grads, Key(x))

function forward(f, ps::Params)
cx = Context()
y, back = _forward(cx, f)
y, function (Δ)
for p in ps
cache(cx)[p] = nothing
cache(cx)[Key(p)] = ismutvalue(p) ? grad_mut(p) : nothing
end
back(Δ)
Grads(cx.cache) # TODO make a copy
Expand Down
100 changes: 92 additions & 8 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
@@ -1,29 +1,108 @@
ismutvalue(x::AbstractArray) = !isimmutable(x)

# Build an "empty gradient" with the same shape as `xs`:
#   * integer arrays   -> zero-filled array with the floating-point version of the eltype
#   * other numbers    -> `zero(xs)`
#   * everything else  -> `Any` array filled with `nothing`
function _zero(xs::AbstractArray{<:Integer})
  out = similar(xs, float(eltype(xs)))
  return fill!(out, false)
end

function _zero(xs::AbstractArray{<:Number})
  return zero(xs)
end

function _zero(xs::AbstractArray)
  return Any[nothing for _ in xs]
end

grad_mut(xs::AbstractArray) = _zero(xs)

# Combine two array gradients.
# Immutable gradients are accumulated elementwise into a new array. If either
# side is a mutable gradient, both are asserted to be the very same object and
# that object is returned unchanged.
function accum(x::AbstractArray, y::AbstractArray)
  ismutvalue(x) || ismutvalue(y) || return accum.(x, y)
  @assert x === y
  return x
end

# In-place accumulation of `y` into the array `x`; always returns `x`.
# When `y` is `x` itself, nothing is added.
function accum!(x::AbstractArray, y)
  if x !== y
    x .= accum.(x, y)
  end
  return x
end

@adjoint (::Type{T})(::UndefInitializer, args...) where T<:Array = T(undef, args...), Δ -> nothing

@nograd size, length, eachindex, Colon(), findfirst, randn, ones, zeros, one, zero,
print, println


@adjoint Base.vect(xs...) = Base.vect(xs...), Δ -> (Δ...,)

@adjoint copy(x::AbstractArray) = copy(x), ȳ -> (ȳ,)

@adjoint (::Type{T})(x::T) where T<:Array = T(x), ȳ -> (ȳ,)

_zero(xs::AbstractArray{<:Integer}) = fill!(similar(xs, float(eltype(xs))), false)
_zero(xs::AbstractArray{<:Number}) = zero(xs)
_zero(xs::AbstractArray) = Any[nothing for x in xs]

@adjoint function getindex(xs::Array, i...)
# TODO a smarter implementation for mutable arrays
# we should just grab `dxs` and mutate it
@adjoint function getindex(xs::AbstractArray, i...)
xs[i...], function (Δ)
Δ′ = _zero(xs)
Δ′[i...] = Δ
(Δ′, map(_ -> nothing, i)...)
end
end

@adjoint! setindex!(xs::AbstractArray, x...) = setindex!(xs, x...),
_ -> error("Mutating arrays is not supported")
# Adjoint of `setindex!`. The value being overwritten is saved on the forward
# pass so that the pullback can put it back after reading the gradient.
@adjoint! function setindex!(x::AbstractArray, v, i...)
  prev = x[i...]
  setindex!(x, v, i...), function (dx)
    dv = nothing
    if dx !== nothing
      # Gradient of `v` is whatever sits at the written location.
      dv = dx[i...]
      # For a mutable gradient, zero out that location after reading it.
      ismutvalue(dx) && (view(dx, i...) .= 0)
    end
    # Undo the forward mutation.
    x[i...] = prev
    return (dx, dv, map(_ -> nothing, i)...)
  end
end

# Special case for potentially-undef arrays
# TODO: clean this up and fold it into the normal version
# Mirrors the generic `setindex!` adjoint, but the previous element is only
# saved (and later restored) when the target slot was already assigned.
@adjoint! function setindex!(x::AbstractArray, v, i::Int...)
# Check before reading so an unassigned slot is never accessed.
isdef = isassigned(x, i...)
isdef && (old = x[i...])
setindex!(x, v, i...), function (dx)
if dx !== nothing
# Gradient of `v` is the gradient stored at the written location.
dv = dx[i...]
# For a mutable gradient, zero out that location after reading it.
ismutvalue(dx) && (view(dx, i...) .= 0)
else
dv = nothing
end
# Undo the forward write, but only if there was a previous value to restore.
isdef && (x[i...] = old)
# No gradient for the integer indices.
return (dx, dv, map(_ -> nothing, i)...)
end
end

# Adjoint of `push!` on vectors.
# The pullback splits the incoming gradient `dxs` into the gradient of the
# vector before the push and the gradient of the pushed element `x`, then
# undoes the forward mutation by popping the element off `xs` again.
@adjoint! function push!(xs::AbstractVector, x)
  push!(xs, x), function (dxs)
    if dxs === nothing
      # No gradient for the vector: nothing to split. Indexing `nothing`
      # would throw, so skip straight to undoing the mutation.
      dx = nothing
    else
      dx = dxs[end]
      if ismutvalue(dxs)
        # Mutable gradient: drop the last slot in place.
        pop!(dxs)
      else
        dxs = dxs[1:end-1]
      end
    end
    pop!(xs)
    return (dxs, dx)
  end
end

# Adjoint of `pop!` on vectors.
# The pullback builds a gradient for the vector by appending `dx` to an empty
# gradient of the (already popped) vector, and re-pushes the popped element so
# that `xs` is back in its pre-`pop!` state.
@adjoint! function pop!(xs::AbstractVector)
  popped = pop!(xs)
  popped, function (dx)
    dxs = push!(_zero(xs), dx)
    push!(xs, popped)
    return (dxs,)
  end
end

# Adjoint of `copyto!(xs, ys)`.
# A copy of `xs` is taken before it is overwritten so that the pullback can
# restore the original contents. The incoming gradient is passed on to `ys`,
# and `nothing` is returned for `xs`.
@adjoint! function copyto!(xs::AbstractArray, ys::AbstractArray)
  xs_ = copy(xs)
  copyto!(xs, ys), function (dxs)
    # Restore `xs` from the backup (`copyto!` takes the destination first).
    copyto!(xs, xs_)
    (nothing, dxs)
  end
end

# General

Expand Down Expand Up @@ -159,12 +238,17 @@ _backmean(xs, Δ, dims) = zero(xs) .+ Δ ./ mapreduce(i -> size(xs,i),*,dims)
# LinAlg
# ======

ismutvalue(x::Transpose) = ismutvalue(x.parent)
ismutvalue(x::LinearAlgebra.Adjoint) = ismutvalue(x.parent)

# Adjoint of vector/matrix multiplication. Each gradient is reshaped to the
# size of the corresponding input.
@adjoint function(a::AbstractVecOrMat * b::AbstractVecOrMat)
  y = a * b
  return y, function(Δ)
    ā = reshape(Δ * b', size(a))
    b̄ = reshape(a' * Δ, size(b))
    return (ā, b̄)
  end
end

@adjoint dot(xs, ys) = dot(xs, ys), Δ -> (Δ .* ys, Δ .* xs)

@adjoint transpose(x) = transpose(x), Δ -> (transpose(Δ),)
@adjoint Base.adjoint(x) = x', Δ -> (Δ',)
@adjoint parent(x::LinearAlgebra.Adjoint) = parent(x), ȳ -> (LinearAlgebra.Adjoint(ȳ),)
Expand Down
5 changes: 2 additions & 3 deletions src/lib/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ using Base: @get!

# Gradient of AD stacks

grad_mut(::AbstractVector) = []

@adjoint! function _push!(a::Vector, x)
_push!(a, x), function (y)
dstk = grad_mut(__context__, a)
Expand All @@ -14,9 +12,10 @@ grad_mut(::AbstractVector) = []
end

@adjoint! function pop!(stk::Stack)
i = stk.idx
pop!(stk), function (Δ)
dstk = grad_mut(__context__, stk.data)
push!(dstk, Δ)
dstk[i] = Δ
return
end
end
Expand Down
5 changes: 3 additions & 2 deletions src/lib/grad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ function gradm(ex, mut = false)
!isempty(args) && isvararg(args[end]) && (argnames[end] = :($(argnames[end])...,))
args = esc.(args)
argnames = esc.(argnames)
∂argnames = isclosure ? [f, argnames...] : argnames
Ts = esc.(Ts)
cx = :($(esc(:__context__))::Context)
fargs = kw == nothing ? [cx, :($f::$T), args...] : [kw, cx, :($f::$T), args...]
Expand All @@ -44,13 +45,13 @@ function gradm(ex, mut = false)
y, _back = adjoint(__context__, $f, $(argnames...))
$(mut ? nothing : :(back(::Nothing) = nothing))
back(Δ) = $gradtuple(_back(Δ))
return y, back
return y, mutback($cx, $gradtuple(($(∂argnames...),)), y, back)
end
@inline function Zygote._forward($cx, ::$kT, kw, $f::$T, $(args...)) where $(Ts...)
y, _back = adjoint(__context__, $f, $(argnames...); kw...)
$(mut ? nothing : :(back(::Nothing) = nothing))
back(Δ) = $gradtuplekw(_back(Δ))
return y, back
return y, mutback($cx, $gradtuplekw(($(∂argnames...),)), y, back)
end
nothing
end
Expand Down
51 changes: 44 additions & 7 deletions src/lib/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ accum(x, y) =
accum(x, y, zs...) = accum(accum(x, y), zs...)

accum(x::Tuple, y::Tuple) = accum.(x, y)
accum(x::AbstractArray, y::AbstractArray) = accum.(x, y)

@generated function accum(x::NamedTuple, y::NamedTuple)
grad(x) = x in fieldnames(y) ? :(y.$x) : :nothing
Expand Down Expand Up @@ -42,9 +41,11 @@ end
@adjoint Base.typeassert(x, T) = Base.typeassert(x, T), Δ -> (Δ, nothing)

@generated function accum_param(cx::Context, x, Δ)
isbitstype(x) && return
isbitstype(x) && return
quote
haskey(cache(cx), x) && (cache(cx)[x] = accum(cache(cx)[x],Δ))
ismutvalue(x) && return accum!(grad_mut(cx, x), Δ)
k = Key(x)
haskey(cache(cx), k) && (cache(cx)[k] = accum(cache(cx)[k],Δ))
return
end
end
Expand All @@ -57,7 +58,7 @@ end

unwrap(x) = x

@adjoint unwrap(x) = unwrap(x), x̄ -> (accum_param(__context__, x, x̄); (x̄,))
@adjoint unwrap(x) = unwrap(x), x̄ -> (accum_param(__context__, x, x̄),)

unwrap(ref, x) = x

Expand All @@ -66,6 +67,7 @@ unwrap(ref, x) = x
@adjoint unwrap(ref, x) = unwrap(x), function (x̄)
accum_global(__context__, ref, x̄)
accum_param(__context__, x, x̄)
return
end

# Tuples
Expand Down Expand Up @@ -161,11 +163,12 @@ end
grad_mut(x) = Ref{Any}(nt_nothing(x))

function grad_mut(cx::Context, x)
k = Key(x)
ch = cache(cx)
if haskey(ch, x)
ch[x]
if haskey(ch, k)
ch[k]
else
ch[x] = grad_mut(x)
ch[k] = grad_mut(x)
end
end

Expand Down Expand Up @@ -231,3 +234,37 @@ end
(nothing, ($(map(f -> :(x̄.$f), fieldnames(T))...),))
end
end

# Mutable Primitives (e.g. arrays)

# Fallback: values are not considered mutable unless a more specific
# `ismutvalue` method says otherwise.
ismutvalue(x) = false

# `Key` of `x` when it is a mutable value, `nothing` otherwise.
function mutkey(x)
  return ismutvalue(x) ? Key(x) : nothing
end

# Apply `mutkey` to every argument.
function mutkeys(xs...)
  return map(mutkey, xs)
end

# Gradient handed to a pullback for an output `x`.
# For a mutable `x`, the incoming gradient `x̄` is accumulated into the
# context's gradient for `x`, and that gradient object is returned.
# Otherwise `x̄` is passed through unchanged.
function out_grad_mut(cx, x, x̄)
  if ismutvalue(x)
    acc = grad_mut(cx, x)
    accum!(acc, x̄)
    return acc
  end
  return x̄
end

# Tuples are handled elementwise; a `nothing` gradient stays `nothing`.
function out_grad_mut(cx, xs::Tuple, dxs)
  return map(xs, dxs) do x, dx
    out_grad_mut(cx, x, dx)
  end
end

function out_grad_mut(cx, xs::Tuple, ::Nothing)
  return nothing
end

# Gradient returned from a pullback for an input `x`.
# For a mutable `x`, `x̄` is accumulated into the context's gradient for `x`
# and the result of that accumulation is returned. Otherwise `x̄` is returned.
function in_grad_mut(cx, x, x̄)
  ismutvalue(x) && return accum!(grad_mut(cx, x), x̄)
  return x̄
end

# Tuples are handled elementwise; a `nothing` gradient stays `nothing`.
function in_grad_mut(cx, xs::Tuple, dxs)
  return map(xs, dxs) do x, dx
    in_grad_mut(cx, x, dx)
  end
end

function in_grad_mut(cx, ::Tuple, ::Nothing)
  return nothing
end

# When every entry of `ks` is `nothing` and the output slot is `nothing`,
# the pullback is returned unwrapped.
function mutback(cache, ks::NTuple{<:Any,Nothing}, ::Nothing, back)
  return back
end

# Wrap `back` so that the output gradient goes through `out_grad_mut` before
# the call and the input gradients go through `in_grad_mut` after it.
function mutback(cx, xs, y, back::F) where F
  return ȳ -> in_grad_mut(cx, xs, back(out_grad_mut(cx, y, ȳ)))
end
28 changes: 22 additions & 6 deletions test/features.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ if !Zygote.usetyped
end

y, back = Zygote.forward(x->tuple(x...), [1, 2, 3])
@test back((1, 1, 1)) == ((1,1,1),)
@test back((1, 1, 1))[1] in ((1,1,1),[1,1,1])

# Test for some compiler errors on complex CFGs
function f(x)
Expand All @@ -234,11 +234,6 @@ end

@test Zygote.@code_adjoint(f(1)) isa Zygote.Adjoint

@test_throws ErrorException Zygote.gradient(1) do x
push!([], x)
return x
end

@test gradient(1) do x
stk = []
Zygote._push!(stk, x)
Expand All @@ -262,6 +257,27 @@ end == (10,)
x + x
end == (2,)

# Mutation

# `x[1]` is overwritten with `y`, so the result is `y * x[2]`: the original
# `x[1]` gets gradient 0, `x[2]` gets `y` (3) and `y` gets `x[2]` (2).
@test gradient([1, 2],3) do x, y
x[1] = y
x[1] * x[2]
end == ([0, 3], 2)

using LinearAlgebra

# The dot product is taken before `x[1]` is overwritten, so the expected
# gradient is `2x` evaluated at the original values `[1, 2]`.
@test gradient([1, 2]) do x
y = x ⋅ x
x[1] = 3
y
end == ([2, 4],)

# Pushing `x` and popping it again yields `x`, so the result is `x * x`
# with gradient `2x` (4 at `x = 2`).
@test gradient(2) do x
xs = []
push!(xs, x)
pop!(xs)*x
end == (4,)

global_param = 3

@testset "Global Params" begin
Expand Down