Reverse mode
julia> using Adapt, Enzyme, Reactant
julia> f(x,y) = sum(x .* y)
f (generic function with 1 method)
julia> a = rand(2)
2-element Vector{Float64}:
0.41323466569127665
0.4525169547907606
julia> b = rand(2)
2-element Vector{Float64}:
0.37009493514138525
0.818796709700481
julia> a_re = adapt(ConcreteRArray, a)
2-element ConcreteRArray{Float64, 1}:
0.41323466569127665
0.4525169547907606
julia> b_re = adapt(ConcreteRArray, b)
2-element ConcreteRArray{Float64, 1}:
0.37009493514138525
0.818796709700481
julia> f_grad(x,y) = Enzyme.gradient(Reverse, f, x, y)
f_grad (generic function with 1 method)
julia> @jit f_grad(a_re,b_re)
(ConcreteRArray{Float64, 1}([0.37009493514138525, 0.818796709700481]), ConcreteRArray{Float64, 1}([0.41323466569127665, 0.4525169547907606]))
julia> function f_grad(x,y)
           dx = Enzyme.make_zero(x)
           dy = Enzyme.make_zero(y)
           return Enzyme.autodiff(ReverseWithPrimal, f, Active, Duplicated(x,dx), Duplicated(y,dy))
       end
f_grad (generic function with 1 method)
julia> @jit f_grad(a_re,b_re)
((nothing, nothing), ConcreteRNumber{Float64}(0.5234554504635411))
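For context on the last output: with `Duplicated` arguments, `Enzyme.autodiff` reports `nothing` in the derivative slots and accumulates the gradients into the shadow buffers `dx` and `dy`, which the function above never returns. Below is a minimal sketch of a variant that returns them alongside the primal. It assumes plain CPU `Vector{Float64}` inputs without Reactant, and the helper name `f_grad_with_shadows` is hypothetical; whether the same pattern behaves identically under `@jit` is not something this sketch shows.

```julia
using Enzyme

f(x, y) = sum(x .* y)

# Hypothetical helper (not from the original report): returns the primal
# value together with the gradients stored in the shadow buffers.
function f_grad_with_shadows(x, y)
    dx = Enzyme.make_zero(x)  # shadow buffer that accumulates ∂f/∂x
    dy = Enzyme.make_zero(y)  # shadow buffer that accumulates ∂f/∂y
    # The first returned element is (nothing, nothing) for Duplicated
    # arguments, so only the primal is taken from the return value.
    _, primal = Enzyme.autodiff(ReverseWithPrimal, f, Active,
                                Duplicated(x, dx), Duplicated(y, dy))
    return primal, dx, dy
end

# Assumed example inputs, chosen so the result is easy to check by hand.
a = [1.0, 2.0]
b = [3.0, 4.0]
primal, da, db = f_grad_with_shadows(a, b)
```

Since `f(x, y) = sum(x .* y)`, the gradient with respect to `x` should equal `y` and vice versa, which agrees with the `Enzyme.gradient` output shown earlier.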