# Automatic differentiation of RNN network

##### Author: Michał Tomczyk 311524

In [1]:
using Random, Distributions

function xavier_init(input_dim::Int, output_dim::Union{Int,Nothing}=nothing)
    if output_dim === nothing
    scale = sqrt(2.0 / (input_dim + 1))
        return rand(Normal(0, scale), input_dim)
    end
    scale = sqrt(2.0 / (input_dim + output_dim))
    return rand(Normal(0, scale), input_dim, output_dim)
end

abstract type Node end

mutable struct InputNode <: Node
    output::AbstractVecOrMat
    name::String

    InputNode(output::AbstractVecOrMat; name="?"::String) = new(output, name)
    InputNode(output_size; name="?"::String) = new(zeros(output_size...), name)
end

struct ConstantNode <: Node
    output::AbstractVecOrMat
    name::String
    ConstantNode(output; name="?"::String) = new([output], name)
end

mutable struct VariableNode <: Node
    output::AbstractVecOrMat
    gradient::AbstractVecOrMat
    name::String

    VariableNode(output_size::Tuple{Int,Int}; name="?"::String) = new(xavier_init(output_size...), zeros(output_size...), name)
    VariableNode(output::AbstractVecOrMat; name="?"::String) = new(output, zeros(size(output)), name)
end

mutable struct OperationNode{F} <: Node
    inputs::Vector{Node}
    output::Union{AbstractVecOrMat,Nothing}
    gradient::Union{AbstractVecOrMat,Nothing}
    name::String
    OperationNode(fun::F, inputs::Vector{Node}; name="?"::String) where {F} =
        new{F}(inputs, nothing, nothing, name)

end

# Traversing

In [2]:


function visit(node::Node, visited::Set{Node}, order::Vector{Node})
    if node ∈ visited
    else
        push!(visited, node)
        push!(order, node)
    end
    return nothing
end
    
function visit(node::OperationNode, visited::Set{Node}, order::Vector{Node})
    if node ∈ visited
    else
        push!(visited, node)
        for input in node.inputs
            visit(input, visited, order)
        end
        push!(order, node)
    end
    return nothing
end

function topological_sort(head::Node)
    visited = Set{Node}()
    order = Vector{Node}()
    visit(head, visited, order)
    return order
end


function init_nodes!(order::Vector{Node})
    for node in order
        init_node!(node)
    end
end


init_node!(node::ConstantNode) = nothing
init_node!(node::InputNode) = nothing
init_node!(node::VariableNode) = nothing
function init_node!(node::OperationNode)
    # println("init_node! ", typeof(node))
    # println("input sizes: ", [(input.name,size(input.output)) for input in node.inputs])
    # println("number of inputs: ", length(node.inputs))
    # println("inputs: ", [input.output for input in node.inputs])
    output_size = size(forward(node, [input.output for input in node.inputs]...))
    node.output = zeros(output_size)
    node.gradient = zeros(output_size)
    # println("init_node-successful! ", typeof(node))
end

init_node! (generic function with 4 methods)

# Forward / Backward

In [3]:


reset!(node::ConstantNode) = nothing
reset!(node::InputNode) = nothing
reset!(node::VariableNode) = fill!(node.gradient, zero(eltype(node.gradient)))
reset!(node::OperationNode) = fill!(node.gradient, zero(eltype(node.gradient)))
function reset!(order::Vector{Node})
    for node in order
        reset!(node)
    end
    return nothing
end


reset_operations!(node::ConstantNode) = nothing
reset_operations!(node::InputNode) = nothing
reset_operations!(node::VariableNode) = nothing
reset_operations!(node::OperationNode) = fill!(node.gradient, zero(eltype(node.gradient)))


compute!(node::ConstantNode) = nothing
compute!(node::InputNode) = nothing
compute!(node::VariableNode) = nothing
function compute!(node::OperationNode)
    node.output = forward(node, [input.output for input in node.inputs]...)
end

function forward!(order::Vector{Node})
    for node in order
        compute!(node)
        reset_operations!(node)
    end
    return last(order).output
end


update!(node::ConstantNode, gradient) = nothing
update!(node::VariableNode, gradient) = let 
    # println("node: ", size(node.gradient))
    # println("gradient: ", size(gradient))
if length(size(node.gradient)) == 1 || size(node.gradient)[2]==1
    # println("bias!")
    # println("summed: ",size(sum(gradient,dims=2)))
    if  isnothing(node.gradient)
        # println("nothing!")
        node.gradient = sum(gradient,dims=2)
    else
        # println("appending!")
        node.gradient .+= sum(gradient,dims=2)
    end
else
    if  isnothing(node.gradient)
        node.gradient = gradient
    else
        node.gradient .+= gradient
    end
end
# println("done!")
end
update!(node::InputNode, gradient) = nothing
update!(node::OperationNode, gradient) =
    let
        # println("node: ", size(node.gradient))
        # println("gradient: ", size(gradient))
        if length(size(node.gradient)) == 1 || size(node.gradient)[2]==1
            # println("bias!")
            # println("summed: ",size(sum(gradient,dims=2)))
            if  isnothing(node.gradient)
                # println("nothing!")
                node.gradient = sum(gradient,dims=2)
            else
                # println("appending!")
                node.gradient .+= sum(gradient,dims=2)
            end
        else
            if  isnothing(node.gradient)
                node.gradient = gradient
            else
                node.gradient .+= gradient
            end
        end
# println("done!")

    end

function backward!(order::Vector{Node}; seed=1.0)
    result = last(order)
    result.gradient = [seed]
    for node in reverse(order)
        backward!(node)
    end
    return nothing
end

function backward!(node::ConstantNode) end
function backward!(node::VariableNode) end
function backward!(node::InputNode) end
function backward!(node::OperationNode)
    inputs = node.inputs
    gradients = backward(node, [input.output for input in inputs]..., node.gradient)
    for (input, gradient) in zip(inputs, gradients)
        update!(input, gradient)
    end
    return nothing
end

backward! (generic function with 5 methods)

# Operators

In [4]:
import Base: +
+(x::Node, y::Node) = OperationNode(+, Node[x, y])
forward(::OperationNode{typeof(+)}, x, y) = return x .+ y
backward(::OperationNode{typeof(+)}, x, y, g) = tuple(g, g)


import Base: -
Base.Broadcast.broadcasted(-, x::Node, y::Node) = OperationNode(-, Node[x, y])
forward(::OperationNode{typeof(-)}, x, y) = return x .- y
backward(::OperationNode{typeof(-)}, x, y, g) = tuple(g, -g)


import Base: *
import LinearAlgebra: mul!
# x * y (aka matrix multiplication)
*(A::Node, x::Node) = OperationNode(mul!, Node[A, x])
forward(::OperationNode{typeof(mul!)}, A, x) = return A * x
backward(::OperationNode{typeof(mul!)}, A, x, g) = tuple(g * x', A' * g)


# x .* y (element-wise multiplication)
import Base: broadcast
broadcasted(*, x::Node, y::Node) = OperationNode(*, Node[x, y])
forward(::OperationNode{typeof(*)}, x, y) = return x .* y
# backward(::OperationNode{typeof(*)}, x, y, g) = tuple(g .* y, g .* x)
# backward(::OperationNode{typeof(*)}, x, y, g) = tuple(g .* y, x .* g)
backward(node::OperationNode{typeof(*)}, x, y, g) =
    let
return tuple(g .* y, g .* x)
    end

import Base: sum
sum(x::Node) = OperationNode(sum, Node[x])
forward(::OperationNode{typeof(sum)}, x) = return [sum(x)]
# backward(::OperationNode{typeof(sum)}, x, g) = tuple(g .* ones(size(x)))
# JEBIE SIE NA BACKWARDZIE
backward(::OperationNode{typeof(sum)}, x, g) =
    let
        𝟏 = ones(length(x))
        J = 𝟏'
        tuple(J' * g)
    end


import Base: ^
^(x::Node, n::Node) = OperationNode(^, Node[x, n])
forward(::OperationNode{typeof(^)}, x, n) = return x .^ n
backward(::OperationNode{typeof(^)}, x, n, g) =
    let
        return tuple(g .* n .* x .^ (n .- 1), g .* log.(abs.(x)) .* x .^ n)
    end


# tanh function overload with forward and backward methods
import Base: tanh
tanh(x::Node) = OperationNode(tanh, Node[x])
forward(::OperationNode{typeof(tanh)}, x) = return tanh.(x)
backward(::OperationNode{typeof(tanh)}, x, g) = tuple(g .* (1 .- tanh.(x) .^ 2))

# sigmoid function overload with forward and backward methods
import Base: broadcast
sigmoid(x::Node) = OperationNode(sigmoid, Node[x])
forward(::OperationNode{typeof(sigmoid)}, x) = return sigmoid.(x)
backward(::OperationNode{typeof(sigmoid)}, x, g) = tuple(g .* sigmoid.(x) .* (1 .- sigmoid.(x)))


cross_entropy_loss(y_hat::Node, y::Node) = OperationNode(cross_entropy_loss, Node[y_hat, y])
forward(::OperationNode{typeof(cross_entropy_loss)}, y_hat, y) =
    let
        y_hat = y_hat .- maximum(y_hat, dims=1)
        y_hat = exp.(y_hat) ./ sum(exp.(y_hat), dims=1)
        loss = sum(log.(y_hat) .* y, dims=1) * -1.0
        return loss
    end
backward(::OperationNode{typeof(cross_entropy_loss)}, y_hat, y, g) =
    let
        y_hat = y_hat .- maximum(y_hat, dims=1)
        y_hat = exp.(y_hat) ./ sum(exp.(y_hat), dims=1)
        return tuple(g .* (y_hat - y))
    end



backward (generic function with 9 methods)

# Network

In [5]:

mutable struct Network
   inputs::Vector{InputNode}
   
   Wx::VariableNode
   Wh::VariableNode
   b::VariableNode
   h::VariableNode
   
   Wy::VariableNode
   by::VariableNode

   desired_output::InputNode

   output_graph::Vector{Node}
   loss_graph::Vector{Node}
    
end



function declare_RNN(input_lenght::Int, output_length::Int, neurons::Int)
   inputs = Vector{InputNode}()

      # Wx = VariableNode((input_lenght,neurons))
      Wx = VariableNode((neurons, input_lenght),name="Wx")
      Wh = VariableNode((neurons, neurons),name="Wh")
      b = VariableNode((neurons,1),name="b")
      h = VariableNode(zeros(neurons,1), name="h")
      Wy = VariableNode((output_length, neurons),name="Wy")
      by = VariableNode((output_length,1),name="by")
      desired_output = InputNode(zeros(output_length), name="desired_output")
      output_graph = Vector{Node}()
      loss_graph = Vector{Node}()
      Network(inputs, Wx, Wh, b, h, Wy, by, desired_output, output_graph, loss_graph)
   end

function unfold!(network::Network, n_sequences::Int; batchsize=1)
      h = network.h
      Wh = network.Wh
      Wx = network.Wx
      b = network.b
      Wy = network.Wy
      by = network.by
      y = network.desired_output
      output_graph = network.output_graph
      loss_graph = network.loss_graph


      for i in 1:n_sequences
         x = InputNode(zeros((size(Wx.output)[2],batchsize)), name="x")
         push!(network.inputs, x)
         h = tanh((Wx * x) .+ (Wh * h) .+ b)
      end
      y_hat = (Wy * h) .+ by

      network.output_graph= topological_sort(y_hat)
      network.loss_graph = topological_sort(cross_entropy_loss(y_hat, y))
      init_nodes!(network.loss_graph)
end

function feed_with_sequence!(network::Network, sequences...)
      for (input, x_i) in zip(network.inputs, sequences)
         input.output = x_i
      end
      return nothing
end

function feed_desired_output!(network::Network, y::AbstractVecOrMat)
      network.desired_output.output = y
      return nothing
end

function adjust!(net::Network, lr, batchsize)
   net.Wx.output .-= lr .* (net.Wx.gradient ./ batchsize)
   net.Wh.output .-= lr .* (net.Wh.gradient ./ batchsize)
   net.b.output .-= lr .* (net.b.gradient ./ batchsize)
   net.Wy.output .-= lr .* (net.Wy.gradient ./ batchsize)
   net.by.output .-= lr .* (net.by.gradient ./ batchsize)
   
end

adjust! (generic function with 1 method)

In [6]:
function calculate_accuracy(net::Network, data)
    correct = 0
    graph = net.output_graph

    for (x, y) in loader(data, batchsize=settings.batchsize)
            feed_with_sequence!(net,
                view(x, 1:196, :),
                view(x, 197:392, :),
                view(x, 393:588, :),
                view(x, 589:784, :)
            )

            y = y[:, :]
            ŷ = forward!(graph)




            for i in 1:size(y, 2)
                if Flux.onecold(ŷ[:, i]) == Flux.onecold(y[:, i])
                    correct += 1
                end
            end
    end
    println("Correct: ", round(100 * correct / length(data); digits=2), "%")
end

calculate_accuracy (generic function with 1 method)

In [7]:
using MLDatasets, Flux
train_data = MLDatasets.MNIST(split=:train)
test_data  = MLDatasets.MNIST(split=:test)

function loader(data; batchsize::Int=1)
    x1dim = reshape(data.features, 28 * 28, :) # reshape 28×28 pixels into a vector of pixels
    yhot  = Flux.onehotbatch(data.targets, 0:9) # make a 10×60000 OneHotMatrix
    Flux.DataLoader((x1dim, yhot); batchsize, shuffle=true)
end

loader (generic function with 1 method)

# Testing

In [8]:
settings = (;
    epochs = 5,
    batchsize = 100,
    lr = 0.05
)


(epochs = 5, batchsize = 100, lr = 0.05)

In [9]:
n_sequences = 4

net = declare_RNN(14*14, 10, 64)
unfold!(net, n_sequences,batchsize=settings.batchsize)

In [10]:
calculate_accuracy(net, test_data)

Correct: 15.85%


In [11]:
loss_in_epoch = 0.0
losses = []

graph = net.loss_graph

@time for epoch in 1:settings.epochs
    println("Epoch: ", epoch)
    reset!(graph)
    loss_in_epoch = 0.0
    @time for (x, y) in loader(train_data, batchsize=settings.batchsize)
        reset!(graph)
            feed_with_sequence!(net,
                view(x, 1:196,:),
                view(x, 197:392,:),
                view(x, 393:588,:),
                view(x, 589:784,:))
            feed_desired_output!(net, y)
            forward!(graph)
            loss = forward!(graph)
            loss_in_epoch += sum(loss)
            backward!(graph)
        adjust!(net, settings.lr, settings.batchsize)
    end
    println("Current loss: ", loss_in_epoch)
    push!(losses, first(loss_in_epoch))
    print("train: ")
    calculate_accuracy(net, train_data)
    print("test: ")
    
    calculate_accuracy(net, test_data)

end

Epoch: 1
 11.927221 seconds (3.70 M allocations: 3.502 GiB, 2.44% gc time, 26.74% compilation time)
Current loss: 35727.77719722139
train: Correct: 90.6%
test: Correct: 90.97%
Epoch: 2
  8.480574 seconds (544.33 k allocations: 3.296 GiB, 2.11% gc time)
Current loss: 17615.490734986957
train: Correct: 92.79%
test: Correct: 92.99%
Epoch: 3
  8.380669 seconds (544.33 k allocations: 3.296 GiB, 2.11% gc time)
Current loss: 14019.592758277879
train: Correct: 94.09%
test: Correct: 94.3%
Epoch: 4
  8.442271 seconds (544.33 k allocations: 3.296 GiB, 2.15% gc time)
Current loss: 11930.279809664424
train: Correct: 94.88%
test: Correct: 94.83%
Epoch: 5
  8.202175 seconds (544.33 k allocations: 3.296 GiB, 2.17% gc time)
Current loss: 10506.721461909621
train: Correct: 95.35%
test: Correct: 95.22%
 59.416249 seconds (8.01 M allocations: 23.091 GiB, 2.34% gc time, 5.53% compilation time)


In [12]:
calculate_accuracy(net, test_data)

Correct: 95.22%
