# Forward AD 

In [478]:
"""
    Dual{T<:Number} <: Number

Dual number for forward-mode differentiation: `v` holds the value and `dv`
the derivative part. Both fields share the element type `T`.
"""
struct Dual{T<:Number} <: Number
    v::T
    dv::T
end

import Base: +, -, *, /
# Arithmetic on dual numbers. The value part follows ordinary arithmetic; the
# derivative part follows the sum, product and quotient rules.
Base.:-(x::Dual) = Dual(-x.v, -x.dv)

function Base.:+(x::Dual, y::Dual)
    return Dual(x.v + y.v, x.dv + y.dv)
end

function Base.:-(x::Dual, y::Dual)
    return Dual(x.v - y.v, x.dv - y.dv)
end

# Product rule: (x*y)' = x'*y + x*y'
function Base.:*(x::Dual, y::Dual)
    return Dual(x.v * y.v, x.dv * y.v + x.v * y.dv)
end

# Quotient rule: (x/y)' = (x'*y - x*y') / y^2
function Base.:/(x::Dual, y::Dual)
    numerator = x.dv * y.v - x.v * y.dv
    return Dual(x.v / y.v, numerator / y.v^2)
end

import Base: abs, sin, cos, tan, exp, sqrt, isless
# Elementary functions on dual numbers: each returns f(v) together with
# f'(v) * dv (chain rule).
function abs(x::Dual)
    return Dual(abs(x.v), sign(x.v) * x.dv)
end

function sin(x::Dual)
    return Dual(sin(x.v), cos(x.v) * x.dv)
end

function cos(x::Dual)
    return Dual(cos(x.v), -sin(x.v) * x.dv)
end

# tan' = 1 + tan^2
function tan(x::Dual)
    t = tan(x.v)
    return Dual(t, one(x.v) * x.dv + t^2 * x.dv)
end

function exp(x::Dual)
    e = exp(x.v)
    return Dual(e, e * x.dv)
end

function sqrt(x::Dual)
    root = sqrt(x.v)
    return Dual(root, 0.5 / root * x.dv)
end

# Ordering compares the value parts only; derivative parts are ignored.
function isless(x::Dual, y::Dual)
    return x.v < y.v
end

import Base: convert, promote_rule

# Dual -> Dual{T}: convert both the value and the derivative part to `T`.
convert(::Type{Dual{T}}, x::Dual) where T = Dual(convert(T, x.v), convert(T, x.dv))
@show Dual{Float64}[Dual(1,2), Dual(3,0)];

# Plain number -> Dual{T}: a constant, so its derivative part is zero.
convert(::Type{Dual{T}}, x::Number) where T = Dual(convert(T, x), zero(T))
@show Dual{Float64}[1, 2, 3];

# Dual{T} mixed with a plain number type R promotes to a Dual over the promoted
# component type.
promote_rule(::Type{Dual{T}}, ::Type{R}) where {T,R} = Dual{promote_type(T,R)}
# Two Dual types promote on their component types. Without this more specific
# method the generic rule above matches with R = Dual{S} and yields a nested
# Dual{Dual{...}} (e.g. for promote_type(Dual{Int}, Dual{Float64})).
promote_rule(::Type{Dual{T}}, ::Type{Dual{R}}) where {T,R} = Dual{promote_type(T,R)}
@show Dual(1,2) * 3;

import Base: show
# Print a dual number as "(value) + [derivativeϵ]".
function show(io::IO, x::Dual)
    return print(io, "(", x.v, ") + [", x.dv, "ϵ]")
end

# Accessor for the value part of a dual number.
function value(x::Dual)
    return x.v
end

# Accessor for the derivative part of a dual number.
function partials(x::Dual)
    return x.dv
end

# jacobian(f, args): Jacobian of `f` at `args` by forward-mode AD.
# `f` is called once per input with a vector of `Dual`s in which only that
# input carries a unit derivative; the derivative parts of the result form one
# Jacobian column. The function is also bound to the global `J`.
J = function jacobian(f, args::Vector{T}) where {T <:Number}
    columns = Matrix{T}[]

    for seeded in eachindex(args)
        # Unit derivative on the seeded input, zero derivative on all others.
        duals = Dual{T}[
            Dual(a, k == seeded ? one(a) : zero(a)) for (k, a) in enumerate(args)
        ]
        # Splatting lets `f` return either a scalar or a collection.
        derivs = partials.([f(duals)...])
        # `[:, :]` turns the vector into an n×1 matrix so the columns can be hcat'ed.
        push!(columns, derivs[:, :])
    end
    return hcat(columns...)
end

Dual{Float64}[Dual(1, 2), Dual(3, 0)] = Dual{Float64}[(1.0) + [2.0ϵ], (3.0) + [0.0ϵ]]
Dual{Float64}[1, 2, 3] = Dual{Float64}[(1.0) + [0.0ϵ], (2.0) + [0.0ϵ], (3.0) + [0.0ϵ]]
Dual(1, 2) * 3 = (3) + [6ϵ]


jacobian (generic function with 1 method)

# Backward AD

In [479]:
# Root of the computational-graph node hierarchy.
abstract type GraphNode end

# Graph nodes whose output is computed from input nodes.
abstract type Operator <: GraphNode end

"""
    Constant{T} <: GraphNode

Immutable graph node wrapping a fixed value `output`; it carries no gradient.
"""
struct Constant{T} <: GraphNode
    output::T
end

"""
    Variable(output; name="?")

Mutable graph leaf holding a value in `output` and an accumulated `gradient`
(initially `nothing`). `name` is only used for display.
"""
mutable struct Variable <: GraphNode
    output::Any
    gradient::Any
    name::String

    function Variable(output; name="?")
        return new(output, nothing, name)
    end
end

"""
    ScalarOperator(fun, inputs...; name="?")

Operator node tagged with the type of `fun` (type parameter `F`). `inputs` is
the tuple of input nodes; `output` and `gradient` start as `nothing`.
"""
mutable struct ScalarOperator{F} <: Operator
    inputs::Any
    output::Any
    gradient::Any
    name::String

    function ScalarOperator(fun, inputs...; name="?")
        return new{typeof(fun)}(inputs, nothing, nothing, name)
    end
end

"""
    BroadcastedOperator(fun, inputs...; name="?")

Operator node tagged with the type of `fun` (type parameter `F`). `inputs` is
the tuple of input nodes; `output` and `gradient` start as `nothing`.
"""
mutable struct BroadcastedOperator{F} <: Operator
    inputs::Any
    output::Any
    gradient::Any
    name::String

    function BroadcastedOperator(fun, inputs...; name="?")
        return new{typeof(fun)}(inputs, nothing, nothing, name)
    end
end

import Base: show, summary
# Compact display of graph nodes. Operators print their name and function type.
function show(io::IO, x::ScalarOperator{F}) where {F}
    return print(io, "op ", x.name, "(", F, ")")
end

function show(io::IO, x::BroadcastedOperator{F}) where {F}
    return print(io, "op.", x.name, "(", F, ")")
end

function show(io::IO, x::Constant)
    return print(io, "const ", x.output)
end

# Variables print their name followed by summaries of the value and gradient.
function show(io::IO, x::Variable)
    print(io, "var ", x.name, "\n ┣━ ^ ")
    summary(io, x.output)
    print(io, "\n ┗━ ∇ ")
    return summary(io, x.gradient)
end

# Depth-first visit of a leaf node: record it once in `visited` and append it
# to `order`. Already-visited nodes are skipped.
function visit(node::GraphNode, visited, order)
    node ∈ visited && return nothing
    push!(visited, node)
    push!(order, node)
    return nothing
end
    
# Depth-first visit of an operator: all inputs are visited (and therefore
# appended to `order`) before the operator itself, giving a post-order.
function visit(node::Operator, visited, order)
    node ∈ visited && return nothing
    push!(visited, node)
    foreach(node.inputs) do input
        visit(input, visited, order)
    end
    push!(order, node)
    return nothing
end

# Return the nodes reachable from `head`, ordered so that every node appears
# after all of its inputs; `head` is the last element.
function topological_sort(head::GraphNode)
    seen, sorted = Set(), Vector()
    visit(head, seen, sorted)
    return sorted
end

# reset!: clear the accumulated gradient of a node (constants have none).
reset!(::Constant) = nothing

function reset!(node::Variable)
    return node.gradient = nothing
end

function reset!(node::Operator)
    return node.gradient = nothing
end

# compute!: evaluate a node's output. Leaves already hold their value; an
# operator applies `forward` to the current outputs of its inputs.
compute!(::Constant) = nothing
compute!(::Variable) = nothing

function compute!(node::Operator)
    operands = [operand.output for operand in node.inputs]
    return node.output = forward(node, operands...)
end

# Forward pass over a topologically sorted graph: compute every node's output
# and clear its gradient, then return the output of the last node.
function forward!(order::Vector)
    foreach(order) do node
        compute!(node)
        reset!(node)
    end
    return last(order).output
end

# Constants do not accumulate gradients.
update!(node::Constant, gradient) = nothing

# Accumulate `gradient` into `node.gradient`.
# The sum is written to a fresh value instead of using in-place `.+=`:
#   * the stored gradient may be the very same array that another node holds
#     (e.g. a backward rule returning `(g, g)`), so mutating it in place would
#     silently corrupt that other node's gradient;
#   * in-place broadcast assignment fails when the stored gradient is a scalar.
# Returns the node's updated gradient.
function update!(node::GraphNode, gradient)
    if isnothing(node.gradient)
        node.gradient = gradient
    else
        node.gradient = node.gradient .+ gradient
    end
    return node.gradient
end

# Backward pass over a topologically sorted graph. The last node is seeded with
# `seed` and gradients are propagated from the result back towards the leaves.
# The result node must hold a single-element output.
function backward!(order::Vector; seed=1.0)
    head = last(order)
    head.gradient = seed
    @assert length(head.output) == 1 "Gradient is defined only for scalar functions"
    foreach(backward!, Iterators.reverse(order))
    return nothing
end

# Leaves have nothing to propagate.
backward!(::Constant) = nothing
backward!(::Variable) = nothing

# Propagate an operator's gradient to its inputs: `backward` returns one
# gradient per input, which is accumulated into that input via `update!`.
function backward!(node::Operator)
    operands = node.inputs
    upstream = backward(node, [operand.output for operand in operands]..., node.gradient)
    for (operand, grad) in zip(operands, upstream)
        update!(operand, grad)
    end
    return nothing
end


backward! (generic function with 4 methods)

In [480]:
import Base: ^
# Power node x^n: graph constructor, forward rule and backward rule.
function ^(x::GraphNode, n::GraphNode)
    return ScalarOperator(^, x, n)
end

forward(::ScalarOperator{typeof(^)}, x, n) = x^n

# d/dx x^n = n * x^(n-1);  d/dn x^n = log|x| * x^n
function backward(::ScalarOperator{typeof(^)}, x, n, g)
    ∂x = g * n * x ^ (n-1)
    ∂n = g * log(abs(x)) * x ^ n
    return (∂x, ∂n)
end

import Base: sin
# Sine node: graph constructor, forward rule and backward rule (d/dx sin = cos).
function sin(x::GraphNode)
    return ScalarOperator(sin, x)
end

forward(::ScalarOperator{typeof(sin)}, x) = sin(x)

function backward(::ScalarOperator{typeof(sin)}, x, g)
    return (g * cos(x),)
end

import Base: *
import LinearAlgebra: mul!
# Product node A * x (tagged with `mul!`): graph constructor, forward rule and
# backward rule. The gradients are g * x' for A and A' * g for x.
function *(A::GraphNode, x::GraphNode)
    return BroadcastedOperator(mul!, A, x)
end

forward(::BroadcastedOperator{typeof(mul!)}, A, x) = A * x

function backward(::BroadcastedOperator{typeof(mul!)}, A, x, g)
    ∂A = g * x'
    ∂x = A' * g
    return (∂A, ∂x)
end

# Element-wise product node x .* y.
# The first argument is constrained with `::typeof(*)`; a bare `*` in that
# position is only an argument name, which would make this a catch-all method
# for every broadcast function applied to two graph nodes.
Base.Broadcast.broadcasted(::typeof(*), x::GraphNode, y::GraphNode) =
    BroadcastedOperator(*, x, y)

forward(::BroadcastedOperator{typeof(*)}, x, y) = x .* y

# The Jacobians are diagonal: d(x.*y)/dx = diag(y), d(x.*y)/dy = diag(x).
# Multiplying by a vector of ones expands a scalar operand to the output length.
function backward(node::BroadcastedOperator{typeof(*)}, x, y, g)
    𝟏 = ones(length(node.output))
    Jx = diagm(y .* 𝟏)
    Jy = diagm(x .* 𝟏)
    return tuple(Jx' * g, Jy' * g)
end

# Element-wise difference node x .- y.
# The first argument is constrained with `::typeof(-)`; a bare `-` in that
# position is only an argument name, which would make this a catch-all method
# for every broadcast function applied to two graph nodes.
Base.Broadcast.broadcasted(::typeof(-), x::GraphNode, y::GraphNode) =
    BroadcastedOperator(-, x, y)

forward(::BroadcastedOperator{typeof(-)}, x, y) = x .- y

# d(x-y)/dx = +1, d(x-y)/dy = -1
backward(::BroadcastedOperator{typeof(-)}, x, y, g) = tuple(g, -g)

# Element-wise negation node .-x.
# The first argument is constrained with `::typeof(-)`; a bare `-` in that
# position is only an argument name, which would make this a catch-all method
# for every broadcast function applied to a single graph node.
Base.Broadcast.broadcasted(::typeof(-), x::GraphNode) = BroadcastedOperator(-, x)

forward(::BroadcastedOperator{typeof(-)}, x) = -x

# d(-x)/dx = -1
backward(::BroadcastedOperator{typeof(-)}, x, g) = tuple(-g)

# Element-wise sum node x .+ y.
# The first argument is constrained with `::typeof(+)`; a bare `+` in that
# position is only an argument name, which would make this a catch-all method
# for every broadcast function applied to two graph nodes.
Base.Broadcast.broadcasted(::typeof(+), x::GraphNode, y::GraphNode) =
    BroadcastedOperator(+, x, y)

forward(::BroadcastedOperator{typeof(+)}, x, y) = x .+ y

# Both inputs receive the incoming gradient unchanged.
backward(::BroadcastedOperator{typeof(+)}, x, y, g) = tuple(g, g)

import Base: sum
# Summation node: graph constructor, forward rule and backward rule.
function sum(x::GraphNode)
    return BroadcastedOperator(sum, x)
end

forward(::BroadcastedOperator{typeof(sum)}, x) = sum(x)

# The Jacobian of sum is a row of ones, so its transpose spreads the incoming
# gradient over all `length(x)` elements.
function backward(::BroadcastedOperator{typeof(sum)}, x, g)
    jac = ones(length(x))'
    return (jac' * g,)
end

# Element-wise quotient node x ./ y.
# The first argument is constrained with `::typeof(/)`; a bare `/` in that
# position is only an argument name, which would make this a catch-all method
# for every broadcast function applied to two graph nodes.
Base.Broadcast.broadcasted(::typeof(/), x::GraphNode, y::GraphNode) =
    BroadcastedOperator(/, x, y)

forward(::BroadcastedOperator{typeof(/)}, x, y) = x ./ y

# The Jacobians are diagonal: d(x./y)/dx = diag(1 ./ y),
# d(x./y)/dy = diag(-x ./ y.^2).
# `Jy` must be a diagonal matrix like `Jx`; with a plain vector, `Jy' * g`
# collapses to a scalar (dot product) instead of a per-element gradient.
# Multiplying by a vector of ones expands scalar operands to the output length.
function backward(node::BroadcastedOperator{typeof(/)}, x, y, g)
    𝟏 = ones(length(node.output))
    Jx = diagm(𝟏 ./ y)
    Jy = diagm((-x ./ y .^ 2) .* 𝟏)
    return tuple(Jx' * g, Jy' * g)
end

import Base: max
# Element-wise maximum node max.(x, y).
# The first argument is constrained with `::typeof(max)`; a bare `max` in that
# position is only an argument name, which would make this a catch-all method
# for every broadcast function applied to two graph nodes.
Base.Broadcast.broadcasted(::typeof(max), x::GraphNode, y::GraphNode) =
    BroadcastedOperator(max, x, y)

forward(::BroadcastedOperator{typeof(max)}, x, y) = max.(x, y)

# The gradient flows to whichever operand is strictly larger; where the
# operands are equal, neither receives a gradient.
function backward(::BroadcastedOperator{typeof(max)}, x, y, g)
    Jx = diagm(isless.(y, x))
    Jy = diagm(isless.(x, y))
    return tuple(Jx' * g, Jy' * g)
end

import Base: min
# Element-wise minimum node min.(x, y).
# The first argument is constrained with `::typeof(min)`; a bare `min` in that
# position is only an argument name, which would make this a catch-all method
# for every broadcast function applied to two graph nodes.
Base.Broadcast.broadcasted(::typeof(min), x::GraphNode, y::GraphNode) =
    BroadcastedOperator(min, x, y)

forward(::BroadcastedOperator{typeof(min)}, x, y) = min.(x, y)

# The gradient flows to whichever operand is strictly smaller; where the
# operands are equal, neither receives a gradient.
function backward(::BroadcastedOperator{typeof(min)}, x, y, g)
    Jx = diagm(isless.(x, y))
    Jy = diagm(isless.(y, x))
    return tuple(Jx' * g, Jy' * g)
end

import Base: exp 
# Element-wise exponential node exp.(x).
# The first argument is constrained with `::typeof(exp)`; a bare `exp` in that
# position is only an argument name, which would make this a catch-all method
# for every broadcast function applied to a single graph node.
Base.Broadcast.broadcasted(::typeof(exp), x::GraphNode) = BroadcastedOperator(exp, x)

forward(::BroadcastedOperator{typeof(exp)}, x) = exp.(x)

# d/dx exp(x) = exp(x), applied element-wise to the incoming gradient.
function backward(::BroadcastedOperator{typeof(exp)}, x, g)
    return tuple(exp.(x) .* g)
end

backward (generic function with 13 methods)

# MNIST

In [481]:
using ImageCore
using MLDatasets
# Convert the tensor returned by `MNIST.traintensor(1)` into an image
# (presumably the first training sample — confirm against the MLDatasets
# docs). The trailing `;` suppresses the value in the notebook output.
MNIST.convert2image(MNIST.traintensor(1));

# NET

In [509]:
using LinearAlgebra
"""
    vdigit(y::Int)

Return a 10-element `Float64` vector that is zero everywhere except for a one
at position `y + 1`.
"""
function vdigit(y::Int)
    return setindex!(fill(0.0, 10), 1, y + 1)
end

# Forward mode AD
# Rectified linear unit: max(0, x).
function ReLU_f(x)
    return max(zero(x), x)
end

# Leaky variant: the negative part is kept, scaled by 0.01.
function ReLU2_f(x)
    positive_part = max(zero(x), x)
    negative_part = min(zero(x), x)
    return positive_part + negative_part * 0.01
end

# Logistic sigmoid 1 / (1 + exp(-x)).
function σ_f(x)
    return one(x) / (one(x) + exp(-x))
end

# Half of the summed squared differences between `y` and `ŷ`.
function mse_f(y::Vector, ŷ::Vector)
    return sum(0.5 * (y - ŷ) .^ 2)
end

# Fully connected layer: the flat weight vector `w` is reshaped to an n×m
# matrix, multiplied by `v`, and `activation` is applied element-wise.
function fullyconnected(w::Vector, n::Number, m::Number, v::Vector, activation::Function)
    weights = reshape(w, n, m)
    return activation.(weights * v)
end
# Randomly initialised weights: hidden layer (15×784) and output layer (10×15).
Wh_f  = randn(15,784)
Wo_f  = randn(10,15)
# Uninitialised buffers with the weights' shapes; they receive the gradients.
dWh_f = similar(Wh_f)
dWo_f = similar(Wo_f)
# Input flattened to a 784-element vector and its label as a 10-element vector
# (presumably the first MNIST training sample — confirm against MLDatasets).
x_f = reshape(MNIST.traintensor(Float64,1),784)
y_f = vdigit(MNIST.trainlabels(1))
# Loss history of the forward-mode gradient steps.
e_f = Float64[]

# Two-layer network evaluated on plain (or dual) numbers: a 15×784 hidden layer
# with the leaky ReLU, a 10×15 linear output layer, and the `mse_f` loss
# against `y`. `wh` and `wo` are flat weight vectors. Returns the loss.
function net_f(x, wh, wo, y)
    hidden = fullyconnected(wh, 15, 784, x, ReLU2_f)
    prediction = fullyconnected(wo, 10, 15, hidden, identity)
    return mse_f(y, prediction)
end

# Jacobian of the loss with respect to the hidden-layer weights `wh`.
function dnet_Wh(x, wh, wo, y)
    loss_of_wh(w) = net_f(x, w, wo, y)
    return J(loss_of_wh, wh)
end

# Jacobian of the loss with respect to the output-layer weights `wo`.
function dnet_Wo(x, wh, wo, y)
    loss_of_wo(w) = net_f(x, wh, w, y)
    return J(loss_of_wo, wo)
end

# Reverse mode AD
# Graph version of the leaky ReLU: max(0, x) + 0.01 * min(0, x), built from
# broadcast nodes. The zero constants are fixed to length 15.
function ReLU_r(x)
    positive_part = max.(Constant(zeros(15)), x)
    negative_part = min.(Constant(zeros(15)), x) .* Constant(0.01)
    return positive_part .+ negative_part
end

# Graph version of the logistic sigmoid 1 ./ (1 .+ exp.(-x)), built from
# broadcast nodes. The one constants are fixed to length 15.
function σ_r(x)
    numerator = Constant(ones(15))
    denominator = Constant(ones(15)) .+ exp.(.-x)
    return numerator ./ denominator
end

# Graph variables for the reverse-mode network. `Variable` stores the array it
# is given without copying, so `Wh_r.output` / `Wo_r.output` initially refer to
# the same array objects as `Wh_f` / `Wo_f`.
Wh_r  = Variable(Wh_f, name="wh")
Wo_r  = Variable(Wo_f, name="wo")
# Input and target wrapped as graph variables.
x_r = Variable(x_f, name="x")
y_r = Variable(y_f, name="y")
# Loss history of the reverse-mode gradient steps.
e_r = Float64[]

# Dense layer with bias and activation: activation(w * x .+ b).
function dense(w, b, x, activation)
    preactivation = w * x .+ b
    return activation(preactivation)
end

# Dense layer without bias: activation(w * x).
function dense(w, x, activation)
    preactivation = w * x
    return activation(preactivation)
end

# Purely linear layer: w * x.
dense(w, x) = w * x

# Graph version of the loss: sum(0.5 .* (y .- ŷ) .* (y .- ŷ)). The difference
# node is built twice, once for each factor of the square.
function mse_r(y, ŷ)
    half_residual = Constant(0.5) .* (y .- ŷ)
    return sum(half_residual .* (y .- ŷ))
end

# Build the computational graph of the two-layer network (leaky-ReLU hidden
# layer, linear output layer, `mse_r` loss), name the main nodes, and return
# the graph as a topologically sorted vector whose last element is the loss.
function net_r(x, wh, wo, y)
    hidden = dense(wh, x, ReLU_r)
    hidden.name = "x̂"

    prediction = dense(wo, hidden)
    prediction.name = "ŷ"

    loss = mse_r(y, prediction)
    loss.name = "loss"

    return topological_sort(loss)
end

# Topologically sorted node list of the reverse-mode network; this is what
# `forward!` and `backward!` are called with.
graph = net_r(x_r, Wh_r, Wo_r, y_r);

In [507]:
# Gradient descent with momentum: α scales the gradient and β scales the
# previous velocity (see `step_reverse` / `step_forward`).
α = 0.1
β = 0.5
# Velocity buffers, zero-initialised with the shapes of the weight matrices:
# `*_r` for the reverse-mode network, `*_f` for the forward-mode network.
vh_r = zero(Wh_r.output)
vo_r = zero(Wo_r.output)
vh_f = zero(Wh_f)
vo_f = zero(Wo_f)

# One momentum step using reverse-mode gradients. Runs the forward and backward
# pass over `graph`, then for each weight variable updates its velocity
# (v = β*v - α*∇) and adds the velocity to the weights.
# Returns the loss computed before the update.
function step_reverse()
    loss = forward!(graph)
    backward!(graph)
    for (velocity, weights) in ((vh_r, Wh_r), (vo_r, Wo_r))
        velocity[:] = β*velocity .- α*weights.gradient
        weights.output = weights.output + velocity
    end
    return loss
end

# One momentum step using forward-mode gradients. Evaluates the loss, computes
# both weight Jacobians with `dnet_Wh` / `dnet_Wo`, then updates each velocity
# (v = β*v - α*∇) and adds it to the corresponding weights in place.
# Returns the loss computed before the update.
function step_forward()
    e = net_f(x_f, Wh_f[:], Wo_f[:], y_f)
    dWh_f[:] = dnet_Wh(x_f, Wh_f[:], Wo_f[:], y_f)
    dWo_f[:] = dnet_Wo(x_f, Wh_f[:], Wo_f[:], y_f)
    vh_f[:] = β*vh_f .- α*dWh_f
    Wh_f[:] += vh_f[:]
    # The output layer uses its own forward-mode velocity `vo_f`, not the
    # reverse-mode buffer `vo_r`.
    vo_f[:] = β*vo_f .- α*dWo_f
    Wo_f[:] += vo_f[:]
    return e
end

step_forward (generic function with 1 method)

In [512]:
# Plain gradient-descent step (learning rate 0.001) using forward-mode AD.
# The loss is evaluated before the update and appended to `e_f`.
e = net_f(x_f, Wh_f[:], Wo_f[:], y_f)
dnet_Wh(x, wh, wo, y) = J(w -> net_f(x, w, wo, y), wh)
dnet_Wo(x, wh, wo, y) = J(w -> net_f(x, wh, w, y), wo)
dWh_f[:] = dnet_Wh(x_f, Wh_f[:], Wo_f[:], y_f)
dWo_f[:] = dnet_Wo(x_f, Wh_f[:], Wo_f[:], y_f)
# These assignments rebind the globals to new arrays.
Wh_f = Wh_f - 0.001 * dWh_f
Wo_f = Wo_f - 0.001 * dWo_f
push!(e_f, e)

3-element Vector{Float64}:
 676.3729930074519
   8.283530501772796
   2.8080078728759323

In [515]:
# Plain gradient-descent step (learning rate 0.001) using reverse-mode AD.
# The loss is evaluated before the update and appended to `e_r`.
e = forward!(graph)
backward!(graph)
Wh_r.output = Wh_r.output - 0.001 * Wh_r.gradient
Wo_r.output = Wo_r.output - 0.001 * Wo_r.gradient
push!(e_r, e)

3-element Vector{Float64}:
 676.3729930074519
   8.283530501772738
   2.8080078728759292

In [516]:
# Run one momentum step with each AD mode and show both losses side by side.
loss_f = step_forward();
loss_r = step_reverse();
loss_f, loss_r

(1.6206475198683705, 1.6206475198683645)

In [44]:
using BenchmarkTools
# Benchmark body for forward mode: compute both weight Jacobians and store them
# in the gradient buffers. Returns the Jacobian for the output-layer weights.
function autodiff_forward()
    dWh_f[:] = dnet_Wh(x_f, Wh_f[:], Wo_f[:], y_f)
    output_jacobian = dnet_Wo(x_f, Wh_f[:], Wo_f[:], y_f)
    dWo_f[:] = output_jacobian
    return output_jacobian
end

# Benchmark body for reverse mode: a single backward pass over `graph`, using
# the node outputs that are already stored in the graph.
function autodiff_reverse()
    return backward!(graph)
end

# Time gradient computation in both AD modes.
# NOTE(review): `autodiff_reverse` runs only the backward pass and gradients
# are not cleared between benchmark samples — confirm this is the intended
# comparison.
@btime autodiff_forward()
@btime autodiff_reverse()

  1.332 s (320850 allocations: 7.34 GiB)
  32.250 μs (249 allocations: 113.48 KiB)
