In [80]:
using Pkg; Pkg.activate(".")
using ReverseDiff, ForwardDiff, Optimisers, Zygote, 
      LinearAlgebra, PrettyPrinting, NamedTupleTools

[32m[1m  Activating[22m[39m project at `~/gits/PACMAN_ADworkshop/Day1`


In [81]:
# define the sigmoid (logistic function) σ(x) = 1 / (1 + e⁻ˣ) and its derivative. 
# We could of course evaluate the derivative by hand, but we'll use ForwardDiff.
# It is convenient and there is no performance penalty.
σ(x) = 1 / (1 + exp(-x))
dσ(x) = ForwardDiff.derivative(σ, x)

"""
    mlp(x, p)

Evaluate a simple MLP with 2 hidden layers and a scalar output:

    mlp(x) = W₃ ⋅ σ.(W₂ * σ.(W₁ * x + b₁) + b₂) + b₃

where the sigmoid `σ` is applied elementwise. The weights and biases are read
from the fields `W1, b1, W2, b2, W3, b3` of `p` (e.g. as produced by `mlp_init`).
"""
function mlp(x, p) 
   x1 = σ.(p.W1 * x .+ p.b1)    # first hidden layer (fused broadcast, no temporary)
   x2 = σ.(p.W2 * x1 .+ p.b2)   # second hidden layer
   return dot(p.W3, x2) + p.b3  # linear scalar read-out
end


"""
    mlp_init(Nin, N1, N2)

A simple utility function to generate random inputs for `mlp`.

`Nin` is the input dimension and `N1`, `N2` are the widths of the two hidden
layers. Returns `(x, p)` where `x = randn(Nin)` and `p` is a named tuple with
`W1` (`N1 × Nin`), `b1` (`N1`), `W2` (`N2 × N1`), `b2` (`N2`), `W3` (`N2`) and
the scalar `b3`, all drawn from a standard normal distribution.
"""
function mlp_init(Nin, N1, N2) 
   x = randn(Nin)
   # the named tuple is built left to right, so the draw order is x, W1, b1, ..., b3
   p = (W1 = randn(N1, Nin), b1 = randn(N1), 
        W2 = randn(N2, N1),  b2 = randn(N2), 
        W3 = randn(N2),      b3 = randn() )
   return x, p        
end


mlp_init

In [82]:
# evaluating the mlp gives a scalar output. 
# (x and p are drawn with randn, so the value differs from run to run.)
x, p = mlp_init(3, 4, 2)
mlp(x, p)

-0.4519502898186545

In [83]:
# Reverse-mode AD "automatically" gives us the gradient of the scalar output
# with respect to the input and/or the parameters. `Zygote.gradient` returns
# a tuple with one gradient per argument; `only` unwraps the single entry.
∇x_ad = only(Zygote.gradient(ξ -> mlp(ξ, p), x))
∇p_ad = only(Zygote.gradient(θ -> mlp(x, θ), p))

print("∇x = ")
pprintln(∇x_ad)
println("-"^80)
print("∇p = ")
pprint(∇p_ad)

∇x = [0.19589648744066998, 0.13484948834357607, -0.1663725479925786]
--------------------------------------------------------------------------------
∇p = (W1 =
     [0.00890500867210859 -0.00552460882521072 0.005680826470775023; 0.07731720000921377 -0.04796708248576493 0.04932343275916715; -0.002426969567074131 0.0015056759608523473 -0.0015482514917231959; 0.24745625391346748 -0.15352023277706883 0.15786127665360114],
 b1 = [-0.005576385657806202,
       -0.048416631707897305,
       0.0015197872101071974,
       -0.15495902991981758],
 W2 =
     [-0.17448288553681046 -0.13282264611097902 -0.002058355355426552 -0.16621033460007367; 0.06773809898640472 0.05156467651380961 0.0007990988823122418 0.06452651251763021],
 b2 = [-0.2597753415777973, 0.1008505089074285],
 W3 = [0.6913163670183934, 0.06838282040409016],
 b3 = 1.0)

In [84]:
# What AD (Zygote in this case) does is generate a new function that 
# first executes the original function (forward pass) and then adds extra 
# steps - the backward pass - that accumulate the gradient information. 
# The following function is indicative of what this generated code might 
# look like if we were to inspect it. As in real reverse-mode AD, the forward 
# pass stores the intermediate values the backward pass needs (here the 
# pre-activations z1, z2) instead of recomputing them.

"""
    mlp_withgrad(x, p)

Evaluate `mlp(x, p)` together with its gradients, using hand-coded
reverse-mode (backward) differentiation.

Returns `(y, ∂y_∂x, ∂y_∂p)`: the scalar output `y`, the gradient with respect
to the input `x`, and a named tuple with fields `W1, b1, W2, b2, W3, b3`
holding the gradient with respect to each parameter in `p`.
"""
function mlp_withgrad(x, p)
   # unpack the parameters 
   W1 = p.W1; b1 = p.b1 
   W2 = p.W2; b2 = p.b2 
   W3 = p.W3; b3 = p.b3
   
   # Forward Pass (keep the pre-activations: the backward pass reuses them)
   # fwd stage 1
   z1 = W1 * x + b1           # N₁ - vector
   x1 = σ.(z1)
   # fwd stage 2 
   z2 = W2 * x1 + b2          # N₂ - vector
   x2 = σ.(z2)
   # fwd stage 3 
   y = dot(W3, x2) + b3    

   # Backward Pass (stages in reverse order, chain rule at each stage)
   # bwd stage 3:  y = W3 ⋅ x2 + b3
   ∂y_∂x2 = W3
   ∂y_∂W3 = x2
   ∂y_∂b3 = one(b3)
   # bwd stage 2:  x2 = σ.(z2),  z2 = W2 * x1 + b2
   t2 = ∂y_∂x2 .* dσ.(z2)     # ∂y/∂z2, N₂ - vector
   ∂y_∂x1 = W2' * t2
   ∂y_∂W2 = t2 * x1'          # outer product
   ∂y_∂b2 = t2 
   # bwd stage 1:  x1 = σ.(z1),  z1 = W1 * x + b1
   t1 = ∂y_∂x1 .* dσ.(z1)     # ∂y/∂z1, N₁ - vector
   ∂y_∂x = W1' * t1
   ∂y_∂W1 = t1 * x'           # outer product
   ∂y_∂b1 = t1 

   # pack the gradients into a named tuple (~ static Dict) with the same 
   # fields as p
   ∂y_∂p = ( W1 = ∂y_∂W1, b1 = ∂y_∂b1, 
             W2 = ∂y_∂W2, b2 = ∂y_∂b2, 
             W3 = ∂y_∂W3, b3 = ∂y_∂b3 )
            
   return y, ∂y_∂x, ∂y_∂p
end;

In [87]:
# We can evaluate the gradients with our hand-written implementation 
# and confirm that they agree with the AD results (up to round-off). 
# `keys` is the Base way to list the field names of a NamedTuple instance.
y, ∇x, ∇p = mlp_withgrad(x, p)
@show y ≈ mlp(x, p)
@show ∇x ≈ ∇x_ad 
@show all(∇p[s] ≈ ∇p_ad[s] for s ∈ keys(∇p));

y ≈ mlp(x, p) = true
∇x ≈ ∇x_ad = true
all((∇p[s] ≈ ∇p_ad[s] for s = fieldnames(∇p))) = true


In [90]:
# Finally, time both gradient implementations. Arguments are interpolated
# with `$` so that BenchmarkTools treats them as local values rather than
# untyped globals. Note the hand-coded version also returns ∇x, so it does
# slightly more work than the Zygote call, which only asks for ∇p.
using BenchmarkTools

println("Performance Zygote: ")
@btime Zygote.gradient(θ -> mlp($x, θ), $p);

println("Performance Hand-coded: ")
@btime mlp_withgrad($x, $p);

Performance Zygote: 


  2.199 μs (53 allocations: 7.16 KiB)
Performance Hand-coded: 


  635.355 ns (16 allocations: 1.48 KiB)


We shouldn't read too much into this performance comparison. There are some ways
to make Zygote a bit faster and there are MANY ways to make the hand-coded 
version MUCH faster. 

But there is still an important message here. Backward (reverse-mode) 
differentiation is more than just the engine inside AD tools. It is also a way 
for us to organize our own algorithms systematically to get the best possible 
performance. All AD tools have overheads: sometimes they are negligible, 
sometimes they dominate. Knowing how to hand-code the adjoints of chains/networks 
of function evaluations in a systematic way can therefore lead to significant 
performance gains. 