Support Mutation #75

Draft
wants to merge 23 commits into
base: master

MikeInnes
Member

Specifically, mutation of arrays/values, as opposed to mutation of data structures, which we already support well.

This introduces some internal complexity and makes performance a little trickier, so it remains open whether we'll actually want to merge it in. The main goal right now is that people can play with this and test it, and in particular I'd like to get some nice benchmarks with differential equations (#37).

It is of course not ideal to maintain this separately. One option might be to make mutation optional and compile different code if it's enabled, though again this is a significant additional complexity in the system.

This was referenced Mar 22, 2019
@andreasnoack
Contributor

Tried this out with TransformVariables. The example below transforms back and forth between a Vector and a NamedTuple (i.e. the identity) and then sums the result.

julia> using TransformVariables

julia> t = as((μ = asℝ, σ = asℝ₊))
TransformVariables.TransformNamedTuple{(:μ, :σ),Tuple{TransformVariables.Identity,TransformVariables.ShiftedExp{true,Float64}}}((TransformVariables.Identity(), TransformVariables.ShiftedExp{true,Float64}(0.0)), 2)

julia> sum(inverse(t, transform(t, ones(2))))
2.0

julia> gradient(s -> sum(inverse(t, transform(t, s))), ones(2))
ERROR: Compiling Tuple{typeof(inverse!),Array{Float64,1},TransformVariables.TransformNamedTuple{(:μ, :σ),Tuple{TransformVariables.Identity,TransformVariables.ShiftedExp{true,Float64}}},NamedTuple{(:μ, :σ),Tuple{Float64,Float64}}}: MethodError: no method matching exprtype(::Core.Compiler.IRCode, ::ArgCheck.ArgCheckFlavor)
Closest candidates are:
  exprtype(::Core.Compiler.IRCode, ::Expr) at /Users/andreasnoack/.julia/dev/Zygote/src/tools/ir.jl:64
  exprtype(::Core.Compiler.IRCode, ::QuoteNode) at /Users/andreasnoack/.julia/dev/Zygote/src/tools/ir.jl:61
  exprtype(::Core.Compiler.IRCode, ::GlobalRef) at /Users/andreasnoack/.julia/dev/Zygote/src/tools/ir.jl:60
  ...
Stacktrace:
 [1] _broadcast_getindex_evalf at ./broadcast.jl:625 [inlined]
 [2] _broadcast_getindex at ./broadcast.jl:598 [inlined]
 [3] getindex at ./broadcast.jl:558 [inlined]
 [4] copyto_nonleaf!(::Array{DataType,1}, ::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Tuple{Base.OneTo{Int64}},typeof(Zygote.exprtype),Tuple{Base.RefValue{Core.Compiler.IRCode},Base.Broadcast.Extruded{Array{Any,1},Tuple{Bool},Tuple{Int64}}}}, ::Base.OneTo{Int64}, ::Int64, ::Int64) at ./broadcast.jl:982
 [5] copy at ./broadcast.jl:836 [inlined]

@andreasnoack
Contributor

Oh. Reading the actual error more carefully, I can see that it's related to ArgCheck.jl. I guess handling macros could be complicated. Any ideas on how to proceed?

@MikeInnes
Member Author

Small update. Maintaining this separately is kind of a pain. I think I'd like to merge this and keep it under a feature flag like ZYGOTE_TYPED, and then it'll get CI and so on. Still no immediate plan to make it default behaviour, though.

@Roger-luo
Contributor

Remember to add this (or something equivalent) later to make tuple-to-vector conversion work.

Zygote.@adjoint! function copyto!(xs::AbstractVector, ys::Tuple)
    xs_ = copy(xs)
    copyto!(xs, ys), function (dxs)
        copyto!(xs_, xs)
        return (nothing, Tuple(dxs))
    end
end

@cossio
Contributor

cossio commented Jan 15, 2020

Is there a way to support mutation of arrays whose gradient is dropped (Zygote.dropgrad)? That is, suppose I have a function:

A = zeros(2,2)
f(x) = (Zygote.dropgrad(A .= x); return x .+ A)
f'(x) # doesn't work

So here I want to treat A as a constant and mutate it.

@MikeInnes
Member Author

MikeInnes commented Jan 15, 2020

If you don't care about the gradient at all, you can define something like

ignore(f) = f()
@nograd ignore

and you can do the mutation inside a do block.
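
For illustration, here is a minimal sketch of how that suggestion might apply to the `A`/`f` example from the question. It assumes `@nograd` is used as `Zygote.@nograd`, and that the gradient is requested through `gradient` on a scalar-valued wrapper (`sum`) rather than through `f'`; both are assumptions of this sketch, not something confirmed in the thread.

```julia
using Zygote

# Helper from the comment above: it simply calls `f`, and Zygote is told
# that the call contributes no gradient.
ignore(f) = f()
Zygote.@nograd ignore

A = zeros(2, 2)

function f(x)
    # The in-place write to A happens inside `ignore`, so Zygote does not
    # try to differentiate through the mutation.
    ignore() do
        A .= x
    end
    # A is treated as a constant with respect to x here.
    return x .+ A
end

# Assumption: differentiate a scalar reduction of f, since f returns an array.
g = gradient(x -> sum(f(x)), ones(2, 2))[1]
```

Under these assumptions only the `x` term in `x .+ A` should contribute to `g`, since `A` is held constant.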

@maartenvd
Contributor

So what is the plan for supporting mutation?
