In [22]:
import Base: exp, log, sin, cos, tan, +, -, *, /, sqrt, convert, promote_rule, zero


In [23]:
"""
    Dual{T}

Node of a reverse-mode automatic-differentiation graph.

# Fields
- `x::T`: forward (primal) value of the node.
- `dfdy::T`: adjoint accumulated by `backprop!` — the derivative of the seeded
  output with respect to this node.
- `parent`: the node's input(s): a single `Dual{T}` (unary operations), a
  vector of `Dual{T}` (binary operations), an `Int` input position (leaf nodes
  created by `D`), or `nothing` (constants).
- `bp!::Function`: invoked as `bp!(dfdy, parent)`; adds this node's
  contribution to the `dfdy` of its parent(s).
"""
mutable struct Dual{T <: Number} <: Number
    x::T
    dfdy::T
    parent::Union{Dual{T}, Array{Dual{T},1}, Int, Nothing}
    # Typed so a non-callable is rejected at construction rather than at backprop time.
    bp!::Function
end

In [24]:
# A number already of the element type becomes a constant node: no parent, and
# a backprop callback that does nothing.
convert(::Type{Dual{T}}, x::T) where T <: Number =
    Dual(x, zero(T), nothing, (_dfdy, _parents) -> nothing)

# Any other number is first converted to the element type `T`, then wrapped as
# a constant node.
convert(::Type{Dual{T}}, x::S) where {T, S <: Number} =
    Dual(T(x), zero(T), nothing, (_dfdy, _parents) -> nothing)

# A node of the requested type is returned as the very same object, so its place
# in the graph is kept.
convert(::Type{Dual{T}}, x::Dual{T}) where T = x


convert (generic function with 189 methods)

In [25]:
# Additive identity for a node: a fresh constant zero that is not linked to any graph.
zero(x::Dual{T}) where T = Dual(zero(T), zero(T), nothing, (_dfdy, _parents) -> nothing)

zero (generic function with 24 methods)

In [26]:
# Combining a Dual with another number type (in either argument order) yields a
# Dual over the promoted element type.
promote_rule(::Type{Dual{T}}, ::Type{S}) where {T, S <: Number} = Dual{promote_type(T, S)}
promote_rule(::Type{T}, ::Type{Dual{S}}) where {T <: Number, S} = Dual{promote_type(T, S)}

# Dedicated method for irrational constants (π, ℯ, …) — presumably needed to
# resolve a method ambiguity with Base's own AbstractIrrational promote rule.
promote_rule(::Type{S}, ::Type{Dual{T}}) where {S <: AbstractIrrational, T} =
    Dual{promote_type(S, T)}


promote_rule (generic function with 128 methods)

In [27]:
# Append the graph predecessors stored in a node's `parent` field to `queue`.
# Constants (`nothing`) and leaves (`Int` input position) have no predecessors.
push_parents!(queue::Array{Dual{T}, 1}, ::Nothing) where T = nothing
push_parents!(queue::Array{Dual{T}, 1}, i::Int) where T = nothing
push_parents!(queue::Array{Dual{T}, 1}, ls::Array{Dual{T}, 1}) where T = append!(queue, ls)
push_parents!(queue::Array{Dual{T}, 1}, l::Dual{T}) where T = push!(queue, l)

push_parents! (generic function with 8 methods)

In [28]:
"""
    backprop!(l::Dual)
    backprop!(queue::Vector{Dual{T}})

Propagate adjoints from the given root node(s) to every node they depend on.
Each reachable node's `bp!` is called exactly once.

Nodes are processed in reverse topological order. By the time a node's `bp!`
runs, every node that consumes it has already added its contribution to the
node's `dfdy`. Shared subexpressions (for example `a = x*x; a*a`) therefore
propagate their complete adjoint once, instead of a partial adjoint once per
incoming edge.

The caller seeds the roots' `dfdy` beforehand (e.g. with `one`). The vector
argument itself is not modified.
"""
function backprop!(l::Dual{T}) where T
    # `Dual` is mutable, so the vector holds a reference to `l` itself, not a copy.
    return backprop!(Dual{T}[l])
end

function backprop!(queue::Array{Dual{T},1}) where T
    for node in Iterators.reverse(_topological_order(queue))
        node.bp!(node.dfdy, node.parent)
    end
    return nothing
end

# Post-order depth-first walk over `parent` links starting from `roots`.
# Returns every reachable node once, each placed after all of its parents
# (inputs). The walk is iterative so long operation chains cannot overflow the
# call stack. Nodes are tracked by object identity, because distinct nodes can
# hold equal values.
function _topological_order(roots::Array{Dual{T},1}) where T
    order = Dual{T}[]
    visited = IdDict{Dual{T},Nothing}()
    # (node, expanded): `expanded == true` means all of the node's parents have
    # already been scheduled, so it can be emitted when popped.
    stack = Tuple{Dual{T},Bool}[(r, false) for r in roots]
    parents = Dual{T}[]
    while !isempty(stack)
        node, expanded = pop!(stack)
        if expanded
            push!(order, node)
        elseif !haskey(visited, node)
            visited[node] = nothing
            push!(stack, (node, true))
            empty!(parents)
            push_parents!(parents, node.parent)
            for p in parents
                haskey(visited, p) || push!(stack, (p, false))
            end
        end
    end
    return order
end

backprop! (generic function with 4 methods)

In [29]:
"""
    collect_outputs(l::Dual)

Return the distinct leaf nodes reachable from `l` through `parent` links, in
breadth-first order. Leaf nodes are those whose `parent` is an `Int` input
position.

Nodes are tracked by object identity and visited at most once. Shared
subexpressions therefore neither duplicate leaves in the result nor get
re-traversed once per path, which would be exponential for repeated squaring.
Constant nodes (`parent === nothing`) contribute nothing.
"""
function collect_outputs(l::Dual{T}) where T
    queue = Dual{T}[l]
    seen = IdDict{Dual{T},Nothing}(l => nothing)
    outputs = Dual{T}[]
    parents = Dual{T}[]

    while !isempty(queue)
        node = popfirst!(queue)
        node.parent isa Int && push!(outputs, node)

        # Dispatch on the parent kind: Dual, vector of Duals, Int or nothing.
        empty!(parents)
        push_parents!(parents, node.parent)
        for p in parents
            if !haskey(seen, p)
                seen[p] = nothing
                push!(queue, p)
            end
        end
    end

    return outputs
end


collect_outputs (generic function with 2 methods)

In [30]:
"""
    D(f)

Return a function `dfdx` that differentiates `f` in reverse mode.

`dfdx` has three methods:
- `dfdx(x::Number)`: derivative of the scalar function `f` at `x`.
- `dfdx(x::Array)`: gradient of `f`, which is called with an array of `Dual`s
  built from `x`. The result is a vector with one entry per element of `x`, in
  linear-index order.
- `dfdx(x...)`: gradient of `f(x...)` with respect to each positional argument.
  The arguments are `promote`d to a common type first, so mixed inputs such as
  `(20, 3.0)` produce a single-typed graph.

Each input is wrapped in a leaf `Dual` whose `parent` is its 1-based position.
Then `f` is evaluated, the output's adjoint is seeded with `one`, and
`backprop!` accumulates derivatives into the leaves. Those leaves are read back
directly, so an input that `f` never uses reports a zero derivative.
"""
function D(f)
    # Seed the output adjoint and push it back through the graph.
    function seed_and_backprop!(result::Dual)
        result.dfdy = one(result.x)
        backprop!(result)
        return nothing
    end
    # `f` returned a plain number, so it does not depend on its inputs:
    # nothing to propagate, and every leaf derivative stays zero.
    seed_and_backprop!(::Number) = nothing

    # Leaf node for input position `i`; leaves have no backprop work to do.
    leaf(x, i) = Dual(x, zero(x), i, (dfdy, parents) -> nothing)
    make_leaves(xs) = [leaf(xelt, i) for (i, xelt) in enumerate(xs)]

    function dfdx(x::T) where T <: Number
        xd = leaf(x, 1)
        seed_and_backprop!(f(xd))
        # Read the derivative from the leaf created here rather than searching
        # the graph, which may also contain Int-tagged leaves not created here.
        return xd.dfdy
    end

    function dfdx(x::Array{T}) where T <: Number
        leaves = make_leaves(x)
        seed_and_backprop!(f(leaves))
        return [l.dfdy for l in vec(leaves)]
    end

    function dfdx(x...)
        leaves = make_leaves(promote(x...))
        seed_and_backprop!(f(leaves...))
        return [l.dfdy for l in leaves]
    end

    return dfdx
end

D (generic function with 2 methods)

In [31]:
"""
    D(i::Integer, f)

Return a function that evaluates `D(f)` at its arguments and returns only the
`i`-th entry of the resulting gradient.
"""
function D(i::Integer, f)
    grad = D(f)
    return (args...) -> grad(args...)[i]
end

D (generic function with 2 methods)

In [32]:
# Backprop rule for `+`: both operands receive the output adjoint unchanged.
function bpp!(dfdy, xy)
    lhs, rhs = xy
    lhs.dfdy += dfdy
    rhs.dfdy += dfdy
end

# Sum of two nodes; both operands are recorded as parents for `bpp!`.
+(x::Dual{T}, y::Dual{T}) where T = Dual(x.x + y.x, zero(T), [x, y], bpp!)

# Backprop rule for binary `-`: +dfdy to the minuend, -dfdy to the subtrahend.
function bpm!(dfdy, xy)
    minuend, subtrahend = xy
    minuend.dfdy += dfdy
    subtrahend.dfdy -= dfdy
end

# Difference of two nodes; both operands are recorded as parents for `bpm!`.
-(x::Dual{T}, y::Dual{T}) where T = Dual(x.x - y.x, zero(T), [x, y], bpm!)

# Backprop rule for unary negation: the single parent receives -dfdy.
bpum!(dfdy, x) = (x.dfdy -= dfdy)

# Negated node whose sole parent is `x`.
-(x::Dual{T}) where T = Dual(-x.x, zero(T), x, bpum!)

# Backprop rule for `*` (product rule): each operand's adjoint gains the output
# adjoint times the other operand's value. Factor order is kept as written so
# the rule stays valid for number types whose multiplication is not commutative.
function bpt!(dfdy, xy)
    lhs, rhs = xy
    lhs.dfdy += dfdy * rhs.x
    rhs.dfdy += lhs.x * dfdy
end

# Product of two nodes; both operands are recorded as parents for `bpt!`.
*(x::Dual{T}, y::Dual{T}) where T = Dual(x.x * y.x, zero(T), [x, y], bpt!)

# Quotient of two nodes, computed as x * (1/y). The reciprocal is computed once
# and shared by the forward value and the backprop rule.
function /(x::Dual{T}, y::Dual{T}) where T
    inv_y = one(T) / y.x

    # d(a/b)/da = 1/b and d(a/b)/db = -a/b^2
    quotient_bp! = function (dfdy, xy)
        num, den = xy
        num.dfdy += dfdy * inv_y
        den.dfdy -= num.x * dfdy * (inv_y * inv_y)
    end

    return Dual(x.x * inv_y, zero(T), [x, y], quotient_bp!)
end

# exp of a node. Because d/dx exp(x) = exp(x), the forward value is reused in
# the backprop rule.
function exp(x::Dual{T}) where T
    ex = exp(x.x)
    return Dual(ex, zero(ex), x, (dfdy, p) -> (p.dfdy += dfdy * ex))
end

# sin of a node; d/dx sin(x) = cos(x), evaluated when backprop runs.
function sin(x::Dual{T}) where T
    s = sin(x.x)
    sin_bp!(dfdy, p) = (p.dfdy += cos(x.x) * dfdy)
    return Dual(s, zero(s), x, sin_bp!)
end

# cos of a node; d/dx cos(x) = -sin(x), evaluated when backprop runs.
function cos(x::Dual{T}) where T
    c = cos(x.x)
    cos_bp!(dfdy, p) = (p.dfdy -= sin(x.x) * dfdy)
    return Dual(c, zero(c), x, cos_bp!)
end

# tan of a node; d/dx tan(x) = 1/cos(x)^2, with cos(x) computed on the forward pass.
function tan(x::Dual{T}) where T
    cosx = cos(x.x)
    t = tan(x.x)
    tan_bp!(dfdy, p) = (p.dfdy += dfdy / (cosx * cosx))
    return Dual(t, zero(t), x, tan_bp!)
end

# sqrt of a node; d/dx sqrt(x) = 1/(2 sqrt(x)), reusing the forward value.
function sqrt(x::Dual{T}) where T
    root = sqrt(x.x)
    sqrt_bp!(dfdy, p) = (p.dfdy += dfdy / (2 * root))
    return Dual(root, zero(root), x, sqrt_bp!)
end

# Natural log of a node; d/dx log(x) = 1/x, evaluated when backprop runs.
# NOTE(review): construction only succeeds when log(x.x) has type T (e.g.
# Float64 in, Float64 out), because the parent must be a Dual of the result's
# element type.
function log(x::Dual{T}) where T
    lx = log(x.x)
    log_bp!(dfdy, p) = (p.dfdy += dfdy / x.x)
    return Dual(lx, zero(lx), x, log_bp!)
end

log (generic function with 22 methods)

In [33]:
# Sample function for the gradient demo: x² + y³.
test(x, y) = x^2 + y^3

test (generic function with 2 methods)

In [34]:
d = test |> D  # gradient function of `test` with respect to both arguments

(::var"#dfdx#37"{typeof(test)}) (generic function with 3 methods)

In [37]:
d(20, 3)  # expected [2*20, 3*3^2] == [40, 27]

2-element Array{Int64,1}:
 40
 27