In [74]:
abstract type GraphNode end
abstract type Operator <: GraphNode end

struct Constant{T} <: GraphNode
    output :: T
end

mutable struct Variable <: GraphNode
    output :: Any
    gradient :: Any
    name :: String
    Variable(output; name="?") = new(output, nothing, name)
end

mutable struct ScalarOperator{F} <: Operator
    inputs :: Any
    output :: Any
    gradient :: Any
    name :: String
    ScalarOperator(fun, inputs...; name="?") = new{typeof(fun)}(inputs, nothing, nothing, name)
end

mutable struct BroadcastedOperator{F} <: Operator
    inputs :: Any
    output :: Any
    gradient :: Any
    name :: String
    BroadcastedOperator(fun, inputs...; name="?") = new{typeof(fun)}(inputs, nothing, nothing, name)
end

import Base: show, summary
show(io::IO, x::ScalarOperator{F}) where {F} = print(io, "op ", x.name, "(", F, ")");
show(io::IO, x::BroadcastedOperator{F}) where {F} = print(io, "op.", x.name, "(", F, ")");
show(io::IO, x::Constant) = print(io, "const ", x.output)
show(io::IO, x::Variable) = begin
    print(io, "var ", x.name);
    print(io, "\n ┣━ ^ "); summary(io, x.output)
    print(io, "\n ┗━ ∇ ");  summary(io, x.gradient)
end

function visit(node::GraphNode, visited, order) 
    if node ∉ visited
        push!(visited, node)
        push!(order, node)
    end 
end

function visit(node::Operator, visited, order) 
    if node ∉ visited
        push!(visited, node)
        for input in node.inputs
            visit(input, visited, order)
        end
        push!(order, node)
    end 
end

function topological_sort(head::GraphNode) 
    visited = Set()
    order = Vector() 
    visit(head, visited, order) 
    return order
end

reset!(node::Constant) = nothing
reset!(node::Variable) = node.gradient = nothing
reset!(node::Operator) = node.gradient = nothing

compute!(node::Constant) = nothing
compute!(node::Variable) = nothing
compute!(node::Operator) = node.output = forward(node, [input.output for input in node.inputs]...)

function forward!(order::Vector)
    for node in order
        compute!(node)
        reset!(node)
    end
    return last(order).output
end

update!(node::Constant, gradient) = nothing
update!(node::GraphNode, gradient) = if isnothing(node.gradient)
    node.gradient = gradient else node.gradient .+= gradient
end

function backward!(order::Vector; seed=1.0)
    result = last(order)
    result.gradient = seed
    @assert length(result.output) == 1
    for node in reverse(order)
        backward!(node)
    end
    return nothing
end

function backward!(node::Constant) end
function backward!(node::Variable) end
function backward!(node::Operator)
    inputs = node.inputs
    gradients = backward(node, [input.output for input in inputs]..., node.gradient)
    for (input, gradient) in zip(inputs, gradients)
        update!(input, gradient)
    end
    return nothing
end

backward! (generic function with 4 methods)

In [75]:
import LinearAlgebra: diagm
import LinearAlgebra: mul!

import Base: *
*(A::GraphNode, x::GraphNode) = BroadcastedOperator(mul!, A, x)
forward(::BroadcastedOperator{typeof(mul!)}, A, x) = return A * x
backward(::BroadcastedOperator{typeof(mul!)}, A, x, g) = tuple(g * x', A' * g)

Base.Broadcast.broadcasted(*, x::GraphNode, y::GraphNode) = BroadcastedOperator(*, x, y)
forward(::BroadcastedOperator{typeof(*)}, x, y) = return x .* y
backward(node::BroadcastedOperator{typeof(*)}, x, y, g) = let
    o = ones(length(node.output))
    Jx = diagm(vec(y .* o))
    Jy = diagm(vec(x .* o))
    tuple(Jx' * g, Jy' * g)
end

import Base: exp
Base.Broadcast.broadcasted(exp, x::GraphNode) = BroadcastedOperator(exp, x)
forward(::BroadcastedOperator{typeof(exp)}, x) = return exp.(x)
backward(node::BroadcastedOperator{typeof(exp)}, x, grad) = let
    o = ones(length(node.output))
    J = diagm(vec(node.output .* o))
    tuple(J' * grad)
end

import Base: sum
sum(x::GraphNode) = BroadcastedOperator(sum, x)
forward(::BroadcastedOperator{typeof(sum)}, x) = return sum(x)
backward(::BroadcastedOperator{typeof(sum)}, x, g) = let
    o = ones(length(x))
    J = o'
    tuple(J' * g)
end

Base.Broadcast.broadcasted(/, x::GraphNode, y::GraphNode) = BroadcastedOperator(/, x, y)
forward(::BroadcastedOperator{typeof(/)}, x, y) = return x ./ y
backward(node::BroadcastedOperator{typeof(/)}, x, y::Real, g) = let
    o = ones(length(node.output))
    Jx = diagm(vec(o ./ y))
    Jy = (-x ./ y .^2)
    tuple(Jx' * g, Jy' * g)
end

import Base: log
Base.Broadcast.broadcasted(log, x::GraphNode) = BroadcastedOperator(log, x)
forward(::BroadcastedOperator{typeof(log)}, x) = return log.(x)
backward(::BroadcastedOperator{typeof(log)}, x, g) = let
    o = ones(length(x))
    J = o' ./ x
    tuple(J' * g)
end

backward (generic function with 11 methods)

In [76]:
function convOperation(I, K)
    n, m = size(I) .- size(K) .+ 1
    J = zeros(n, m)
    for i=1:n, j=1:m
        J[i, j] = sum(I[i:i+1, j:j+1] .* K)
    end
    return J
end
convLayer(x::GraphNode, k::GraphNode) = BroadcastedOperator(convLayer, x, k)
forward(::BroadcastedOperator{typeof(convLayer)}, x, k) = let
    return convOperation(x, k)
end
backward(node::BroadcastedOperator{typeof(convLayer)}, I, K, g) = let
    kh, kw = size(K)
    fgrad = zeros(Float32, size(K))
    outh, outw = size(node.output)

    for i in 1:outh
        for j in 1:outw
            fgrad += g[i, j] * I[i:i+kh-1, j:j+kw-1]
        end    
    end
    return fgrad
end
function convlayerInit(x, k , activation) return activation(convLayer(x, k)) end
function convlayerInit(x, k) return convLayer(x, k) end

function flatten(input)       
    return reshape(input, (:, 1))
end
flatten(x::GraphNode) = BroadcastedOperator(flatten, x)
forward(::BroadcastedOperator{typeof(flatten)}, x) = return flatten(x)
backward(::BroadcastedOperator{typeof(flatten)}, x, grad) = let
    result = reshape(grad, size(x))
    tuple(result)
end

function dense(w, x, activation) return activation(w * x) end
function dense(w, x) return w * x end

dense (generic function with 2 methods)

In [77]:
using MLDatasets
using Flux: onehotbatch
function load_fashion_mnist()

    train_data, train_labels = FashionMNIST(split=:train)[1:10]
    test_data, test_labels = FashionMNIST(split=:test)[1:10]

    train_data = train_data ./ 255.0
    test_data = test_data ./ 255.0

    train_data = reshape(train_data, (:, 1, 28, 28))
    test_data = reshape(test_data, (:, 1, 28, 28))

    train_labels = onehotbatch(train_labels, 0:9)
    test_labels = onehotbatch(test_labels, 0:9)

    X = Variable(train_data, name="images")
    y = Variable(train_labels, name = "labels")

    return X, y, test_data, test_labels
end

load_fashion_mnist (generic function with 1 method)

In [78]:
relu(x::GraphNode) = BroadcastedOperator(relu, x)
forward(::BroadcastedOperator{typeof(relu)}, x) = return max.(0, x)
backward(::BroadcastedOperator{typeof(relu)}, x, g) = let 
        result = g .* (x .>= 0)
        return tuple(result)
 end

elu(x::GraphNode) = BroadcastedOperator(elu, x)
forward(::BroadcastedOperator{typeof(elu)}, x) = return (x .>= 0) .* x + 1 .* (exp.(x) .- 1) .* (x .< 0)
backward(::BroadcastedOperator{typeof(elu)}, x, g) = let
        grad = g .* ((x .>= 0) .+ 1 .* exp.(x) .* (x .< 0))
        return tuple(grad)
end

softmax(x::GraphNode) = BroadcastedOperator(softmax, x)
forward(::BroadcastedOperator{typeof(softmax)}, x) = exp.(x) ./ sum(exp.(x))
backward(node::BroadcastedOperator{typeof(softmax)}, x, g) = let
        y = ones(length(node.output))
        J = diagm(vec(node.output .* y)) - node.output * node.output'
        tuple(J' * g)
end

function cross_entropy_loss(x::GraphNode, y::GraphNode)
        return sum(y .* log.(exp.(x))) * Constant(-1.0)
end

function net(x, wc, wd, wo, y)
        c = convlayerInit(x, wc, relu)
        c.name = "conv layer"
        f = flatten(c)
        d1 = dense(wd, f, elu)
        d1.name = "dense leyer"
        d2 = dense(wo, d1, relu)
        d2.name = "output"
        E = cross_entropy_loss(y, d2)
        E.name = "loss"
    
        return topological_sort(E)
end

function inicializeTestData()
        Wc  = Variable(rand(2,2), name="Wagi conv")
        Wd  = Variable(rand(10,729), name="Wagi dense")
        Wo  = Variable(rand(10,10), name="Wagi out")
        X, Y, test_data, test_labels = load_fashion_mnist()  
        return Wc,Wd,Wo,X, Y
end

function trainSGD(epochs::Int, learning_rate::Real, expectedValueOfLoss::Real)

        Wc, Wd, Wo, X, Y = inicializeTestData()
        forwardVal = 0.0
        for epoch in 1:epochs
                for j in 1:size(X.output, 1)
                        x = Variable(X.output[j, 1, :, :])
                        y = Variable(Y.output[:, j])
                        graph = net(x, Wc, Wd, Wo, y)
                        forwardVal = forward!(graph)
                        if abs(forwardVal) < expectedValueOfLoss
                                print("Epoch nr. ", epoch, " Loss: ", round(forwardVal, digits=6), "\n")
                                return
                        end
                        backward!(graph)
                        if (forwardVal>0)
                        Wc.output .-= learning_rate .* Wc.gradient
                        Wd.output .-= learning_rate .* Wd.gradient
                        Wo.output .-= learning_rate .* Wo.gradient
                        else
                                Wc.output .+= learning_rate .* Wc.gradient
                                Wd.output .+= learning_rate .* Wd.gradient
                                Wo.output .+= learning_rate .* Wo.gradient
                        end
                end
                print("Epoch nr. ", epoch, " Loss: ", round(forwardVal, digits=6), "\n")
        end
end

trainSGD (generic function with 1 method)

In [79]:
trainSGD(100, 0.001, 0.001)

Epoch nr. 1 Loss: -4.234983
Epoch nr. 2 Loss: -4.061186
Epoch nr. 3 Loss: -3.889983
Epoch nr. 4 Loss: -3.721267
Epoch nr. 5 Loss: -3.55493
Epoch nr. 6 Loss: -3.390867
Epoch nr. 7 Loss: -3.228973
Epoch nr. 

8 Loss: -3.069147
Epoch nr. 9 Loss: -2.911285
Epoch nr. 10 Loss: -2.755289
Epoch nr. 11 Loss: -2.601059
Epoch nr. 12 Loss: -2.448497


Epoch nr. 13 Loss: -2.297507
Epoch nr. 14 Loss: -2.147993
Epoch nr. 15 Loss: -1.999859
Epoch nr. 

16 Loss: -1.856836
Epoch nr. 17 Loss: -1.716372
Epoch nr. 18 Loss: -1.577984
Epoch nr. 19 Loss: -1.444295
Epoch nr. 

20 Loss: -1.31673
Epoch nr. 21 Loss: -1.191693
Epoch nr. 22 Loss: -1.069817


Epoch nr. 23 Loss: -0.954566
Epoch nr. 24 Loss: -0.845227
Epoch nr. 25 Loss: -0.739913
Epoch nr. 26 Loss: -0.637605
Epoch nr. 27 Loss: -0.541069
Epoch nr. 28 Loss: -0.458316
Epoch nr. 29 Loss: -0.386488
Epoch nr. 30 Loss: -0.328714
Epoch nr. 31 Loss: -0.286841
Epoch nr. 32 Loss: -0.2509
Epoch nr. 33 Loss: -0.221252
Epoch nr. 34 Loss: -0.199701
Epoch nr. 35 Loss: -0.184162
Epoch nr. 

36 Loss: -0.170782
Epoch nr. 37 Loss: -0.158488
Epoch nr. 38 Loss: -0.147134
Epoch nr. 39 Loss: -0.137323
Epoch nr. 40 Loss: -0.128374
Epoch nr. 41 Loss: -0.119932
Epoch nr. 

42 Loss: -0.112068
Epoch nr. 43 Loss: -0.104635
Epoch nr. 44 Loss: -0.097573
Epoch nr. 45 Loss: -0.090743
Epoch nr. 46 Loss: -0.084129
Epoch nr. 47 Loss: -0.077728
Epoch nr. 48 Loss: -0.071809
Epoch nr. 49 Loss: -0.066119
Epoch nr. 50 Loss: -0.060642
Epoch nr. 51 Loss: -0.055476
Epoch nr. 52 Loss: -0.050769
Epoch nr. 53 Loss: -0.046445
Epoch nr. 

54 Loss: -0.042537
Epoch nr. 55 Loss: -0.039074
Epoch nr. 56 Loss: -0.036174


Epoch nr. 57 Loss: -0.033844
Epoch nr. 58 Loss: -0.032283
Epoch nr. 59 Loss: -0.031155
Epoch nr. 60 Loss: -0.03026


Epoch nr. 61 Loss: -0.029447
Epoch nr. 62 Loss: -0.028672
Epoch nr. 63 Loss: -0.027917
Epoch nr. 64 Loss: -0.027205
Epoch nr. 65 Loss: -0.026506
Epoch nr. 66 Loss: -0.025869
Epoch nr. 67 Loss: -0.025241
Epoch nr. 

68 Loss: -0.02462
Epoch nr. 69 Loss: -0.024029
Epoch nr. 70 Loss: -0.023469
Epoch nr. 71 Loss: -0.022916
Epoch nr. 

72 Loss: -0.022367
Epoch nr. 73 Loss: -0.021823
Epoch nr. 74 Loss: -0.021284
Epoch nr. 75 Loss: -0.020745
Epoch nr. 76 Loss: -0.020213
Epoch nr. 77 Loss: -0.019683
Epoch nr. 

78 Loss: -0.019155
Epoch nr. 79 Loss: -0.018631
Epoch nr. 80 Loss: -0.018107


Epoch nr. 81 Loss: -0.017587
Epoch nr. 82 Loss: -0.017069
Epoch nr. 83 Loss: -0.016558
Epoch nr. 84 Loss: -0.016054
Epoch nr. 85 Loss: -0.015552
Epoch nr. 86 Loss: -0.01505
Epoch nr. 87 Loss: -0.014549
Epoch nr. 88 Loss: -0.014049


Epoch nr. 89 Loss: -0.013549
Epoch nr. 90 Loss: -0.013051
Epoch nr. 91 Loss: -0.012569
Epoch nr. 92 Loss: -0.012094
Epoch nr. 93 Loss: -0.011623
Epoch nr. 94 Loss: -0.011154
Epoch nr. 95 Loss: -0.010685
Epoch nr. 96 Loss: -0.010219


Epoch nr. 97 Loss: -0.00977
Epoch nr. 98 Loss: -0.009336
Epoch nr. 99 Loss: -0.008903
Epoch nr. 100 Loss: -0.008476
