
Latest commit

1130: Fix incorrect `@forward`ing of `Base.in` on `Params` r=DhairyaLGandhi a=ToucheSir

Because `MacroTools.@forward` makes the forwarded object the first argument of every method it generates, the generated `in` method had its arguments backwards: it matched `ps in x` rather than `x in ps`. Compare:
```
julia> x = rand(10);

julia> ps = Params([x]);

julia> @which x in ps # was falling back to iterating over Params.order (Buffer) and not Params.params (IdSet)!
in(x, itr) in Base at operators.jl:1280

julia> @which ps in x # this method generated by `@forward`
in(x::Params, args...; kwargs...) in Zygote at /home/brianc/.julia/packages/MacroTools/PP9IQ/src/examples/forward.jl:17
```
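The mismatch, and the shape of a fix, can be reproduced in plain Julia without Zygote or MacroTools. This is a minimal sketch under stated assumptions: `ToyParams` is a hypothetical stand-in (not Zygote's `Params`), `fallback_in` is a hypothetical helper mirroring Base's generic `in(x, itr)` scan, and the `Base.in(x, ps::ToyParams)` method is one plausible form of the corrected definition rather than a copy of the patch.

```julia
# Hypothetical stand-in for Zygote's `Params` (not the real type): an ordered
# buffer of tracked arrays plus an identity-based set for membership checks.
struct ToyParams
    order::Vector{Any}
    params::Base.IdSet{Any}
end
ToyParams(xs) = ToyParams(collect(Any, xs), union!(Base.IdSet{Any}(), xs))

# Iteration walks the ordered buffer, which is what a generic scan would use.
Base.iterate(ps::ToyParams, state...) = iterate(ps.order, state...)
Base.length(ps::ToyParams) = length(ps.order)

# Hypothetical helper mirroring the generic fallback: a linear scan with `==`.
fallback_in(x, itr) = any(y -> y == x, itr)

# In `in(item, collection)` the collection comes *second*, so the method must
# specialise on the second argument. It checks object identity via the IdSet.
Base.in(x, ps::ToyParams) = x in ps.params

x = rand(10)
ps = ToyParams([x])

x in ps                   # true: this exact array is tracked
copy(x) in ps             # false: equal contents, but a different object
fallback_in(copy(x), ps)  # true: the `==` scan cannot tell them apart
```

Beyond argument order, the sketch shows why the fallback was wrong and not merely slower: the generic scan compares with `==`, so a copy of a tracked array counts as a member, whereas the `IdSet` lookup checks identity.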
Beyond being incorrect/vestigial, this was also causing invalidations at import time:
```
 inserting in(x::Params, args...; kwargs...) in Zygote at /home/brianc/.julia/packages/MacroTools/PP9IQ/src/examples/forward.jl:17 invalidated:
   backedges: 1: superseding in(x, itr) in Base at operators.jl:1280 with MethodInstance for in(::Any, ::Tuple{String, String}) (1 children)
              2: superseding in(x, s::Set) in Base at set.jl:58 with MethodInstance for in(::Any, ::Set{Symbol}) (3 children)
              3: superseding in(x, itr) in Base at operators.jl:1280 with MethodInstance for in(::Any, ::Tuple{Char, Char}) (4 children)
              4: superseding in(x, s::Set) in Base at set.jl:58 with MethodInstance for in(::Any, ::Set{Any}) (5 children)
              5: superseding in(x, s::Set) in Base at set.jl:58 with MethodInstance for in(::Any, ::Set{String}) (9 children)
              6: superseding in(k, v::Base.KeySet{<:Any, <:IdDict}) in Base at iddict.jl:189 with MethodInstance for in(::Any, ::Base.KeySet{Any, IdDict{Any, Any}}) (10 children)
              7: superseding in(x, itr) in Base at operators.jl:1280 with MethodInstance for in(::Any, ::Vector{String}) (11 children)
              8: superseding in(x, itr) in Base at operators.jl:1280 with MethodInstance for in(::Any, ::NTuple{8, String}) (67 children)
   6 mt_cache
```
Credit to `SnoopCompile.@snoopr` for helping catch this!

Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>
601f888


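Install Zygote from Julia's package manager (press `]` at the `julia>` prompt to enter Pkg mode):

```
] add Zygote
```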

Zygote provides source-to-source automatic differentiation (AD) in Julia, and is the next-gen AD system for the Flux differentiable programming framework. For more details and benchmarks of Zygote's technique, see our paper. You may want to check out Flux for more interesting examples of Zygote usage; the documentation here focuses on internals and advanced AD usage.

Zygote supports Julia 1.0 onwards, but we highly recommend using Julia 1.3 or later.

```
julia> using Zygote

julia> f(x) = 5x + 3

julia> f(10), f'(10)
(53, 5.0)

julia> @code_llvm f'(10)
define i64 @"julia_#625_38792"(i64) {
top:
  ret i64 5
}
```

"Source-to-source" means that Zygote hooks into Julia's compiler, and generates the backwards pass for you – as if you had written it by hand.
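To make "as if you had written it by hand" concrete, here is one way the backward pass for `f(x) = 5x + 3` could be written manually in plain Julia, without Zygote. `f_pullback` is a hypothetical name; the value-plus-closure shape follows the pullback convention Zygote is built around.

```julia
# The function from the example above.
f(x) = 5x + 3

# Hypothetical hand-written backward pass: return the primal value together
# with a pullback closure that maps the output sensitivity Δ to the input
# sensitivity. Since d(5x + 3)/dx = 5, the pullback scales Δ by 5.
function f_pullback(x)
    y = f(x)
    back(Δ) = 5Δ
    return y, back
end

y, back = f_pullback(10)
y        # 53
back(1)  # 5, the same slope that f'(10) reports
```

Zygote exposes the same value-and-pullback pair through `Zygote.pullback(f, 10)`; there, the pullback returns a tuple with one entry per argument.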

Zygote supports the full flexibility and dynamism of the Julia language, including control flow, recursion, closures, structs, dictionaries, and more.
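A small sketch, assuming Zygote is installed, that differentiates through recursion and a value-dependent branch. The names `pow` and `piecewise` are illustrative, not part of Zygote.

```julia
using Zygote

# Recursion: x^n by repeated multiplication.
pow(x, n) = n <= 0 ? one(x) : x * pow(x, n - 1)

# Control flow: which branch runs depends on the input value.
piecewise(x) = x > 0 ? x^2 : -x

gradient(x -> pow(x, 3), 2.0)  # (12.0,), i.e. 3 * 2.0^2
gradient(piecewise, 3.0)       # (6.0,), from the x^2 branch
gradient(piecewise, -1.0)      # (-1.0,), from the -x branch
```

`gradient` returns a tuple with one entry per argument, so single-argument functions produce 1-tuples.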

```
julia> fs = Dict("sin" => sin, "cos" => cos, "tan" => tan);

julia> gradient(x -> fs[readline()](x), 1)  # readline() waits for a key; "sin" is typed below
sin
0.5403023058681398
```
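Because `readline()` blocks for input, the example above is awkward to rerun. A non-interactive variant, assuming Zygote is installed, passes the key as an ordinary argument instead; `apply_named` is an illustrative name.

```julia
using Zygote

fs = Dict("sin" => sin, "cos" => cos, "tan" => tan)

# Same pattern as above, but the key is an argument rather than coming from
# readline(), so the function is still chosen at runtime by a Dict lookup.
apply_named(name, x) = fs[name](x)

gradient(x -> apply_named("cos", x), 1.0)  # ≈ (-sin(1.0),)
gradient(x -> apply_named("sin", x), 1.0)  # ≈ (cos(1.0),)
```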

Defining custom gradients is a cinch, and errors have good stacktraces.

```
julia> using Zygote: @adjoint

julia> add(a, b) = a + b

julia> @adjoint add(a, b) = add(a, b), Δ -> (Δ, Δ)
```
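With that rule in place, `gradient` calls the custom pullback instead of differentiating `a + b` itself. A sketch assuming Zygote is installed; `mul` is a second, hypothetical rule added only to show a pullback that treats its two arguments differently.

```julia
using Zygote
using Zygote: @adjoint

add(a, b) = a + b
@adjoint add(a, b) = add(a, b), Δ -> (Δ, Δ)

# The custom pullback hands the incoming sensitivity to both arguments.
gradient(add, 2.0, 3.0)   # (1.0, 1.0)

# Hypothetical second rule: each argument's gradient is Δ times the other one.
mul(a, b) = a * b
@adjoint mul(a, b) = mul(a, b), Δ -> (Δ * b, Δ * a)

gradient(mul, 2.0, 3.0)   # (3.0, 2.0)
```

The pullback must return one entry per argument of the original call, in the same order.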

To support large machine learning models with many parameters, Zygote can differentiate implicitly-used parameters, as opposed to just function arguments.

```
julia> W, b = rand(2, 3), rand(2);

julia> predict(x) = W*x .+ b;

julia> g = gradient(Params([W, b])) do
         sum(predict([1,2,3]))
       end
Grads(...)

julia> g[W], g[b]
([1.0 2.0 3.0; 1.0 2.0 3.0], [1.0, 1.0])
```
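Because `predict` is affine in `W` and `b`, the gradients printed above can be checked by hand, and they do not depend on the random values of `W` and `b`. A short pure-Julia sketch (no Zygote; the names `Δ`, `dW` and `db` are illustrative) reproduces them:

```julia
# Hand derivation for s = sum(W*x .+ b) with x = [1, 2, 3]:
# s = sum over i of (sum over j of W[i,j] * x[j] + b[i]),
# so ds/dW[i,j] = x[j] and ds/db[i] = 1.
x = [1, 2, 3]
Δ = ones(2)    # ds/d(each output entry) = 1, since s is a plain sum

dW = Δ * x'    # outer product: dW[i,j] = Δ[i] * x[j]
db = Δ         # b enters additively, so it receives Δ unchanged

dW  # [1.0 2.0 3.0; 1.0 2.0 3.0], matching g[W]
db  # [1.0, 1.0], matching g[b]
```

With explicit arguments instead of implicit `Params`, the same values should come from `gradient((W, b) -> sum(W*[1,2,3] .+ b), W, b)`.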