Skip to content

Commit

Permalink
Zygote interop: treat everything immutable during differentiation
Browse files Browse the repository at this point in the history
  • Loading branch information
tkf committed Aug 29, 2019
1 parent da5a9aa commit 8d457a5
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 0 deletions.
21 changes: 21 additions & 0 deletions examples/zygote.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using BangBang
using LinearAlgebra

function rnn(n, J, x)
y = similar(x)
local result
for _ in 1:n
@! y = mul!(y, J, x)
@! y .= tanh.(y)
result = y
x, y = y, x
end
return result === x ? (@! y .= result) : result
end

using Zygote
d = 10
J = randn(d, d)
x0 = randn(d)
y_target = randn(d)
g, = Zygote.gradient(J -> sum((rnn(20, J, x0) .- y_target) .^ 2), J)
3 changes: 3 additions & 0 deletions src/BangBang.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ function __init__()
@require StaticArrays="90137ffa-7385-5640-81b9-e52037218182" begin
include("staticarrays.jl")
end
@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin
include("zygote.jl")
end
end

end # module
4 changes: 4 additions & 0 deletions src/zygote.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# https://fluxml.ai/Zygote.jl/dev/adjoints/#Gradient-Reflection-1

# Treat everything immutable during differentiation:
Zygote.@adjoint possible(_args...) = false, _ -> nothing

0 comments on commit 8d457a5

Please sign in to comment.