diff --git a/examples/zygote.jl b/examples/zygote.jl new file mode 100644 index 00000000..bfb49efd --- /dev/null +++ b/examples/zygote.jl @@ -0,0 +1,21 @@ +using BangBang +using LinearAlgebra + +function rnn(n, J, x) + y = similar(x) + local result + for _ in 1:n + @! y = mul!(y, J, x) + @! y .= tanh.(y) + result = y + x, y = y, x + end + return result === x ? (@! y .= result) : result +end + +using Zygote +d = 10 +J = randn(d, d) +x0 = randn(d) +y_target = randn(d) +g, = Zygote.gradient(J -> sum((rnn(20, J, x0) .- y_target) .^ 2), J) diff --git a/src/BangBang.jl b/src/BangBang.jl index 7c58fe6a..1d5920a6 100644 --- a/src/BangBang.jl +++ b/src/BangBang.jl @@ -38,6 +38,9 @@ function __init__() @require StaticArrays="90137ffa-7385-5640-81b9-e52037218182" begin include("staticarrays.jl") end + @require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin + include("zygote.jl") + end end end # module diff --git a/src/zygote.jl b/src/zygote.jl new file mode 100644 index 00000000..843a9748 --- /dev/null +++ b/src/zygote.jl @@ -0,0 +1,4 @@ +# https://fluxml.ai/Zygote.jl/dev/adjoints/#Gradient-Reflection-1 + +# Treat everything immutable during differentiation: +Zygote.@adjoint possible(_args...) = false, _ -> nothing