In [77]:
using Pkg
Pkg.activate("./project")

[32m[1m  Activating[22m[39m project at `e:\Documents\1. Studia\MAGISTERSKIE\Algorytmy w inż danych\Custom-Julia-NN\project`


### Structures
Definition of basic structures for computational graph

In [78]:
abstract type GraphNode end
abstract type Operator <: GraphNode end

struct Constant{T} <: GraphNode
    output :: T
end

mutable struct Variable <: GraphNode
    output :: Any
    gradient :: Any
    name :: String
    Variable(output; name="?") = new(output, nothing, name)
end

mutable struct ScalarOperator{F} <: Operator
    inputs :: Any
    output :: Any
    gradient :: Any
    name :: String
    ScalarOperator(fun, inputs...; name="?") = new{typeof(fun)}(inputs, nothing, nothing, name)
end

mutable struct VectorOperator{F} <: Operator
    inputs :: Any
    output :: Any
    gradient :: Any
    name :: String
    VectorOperator(fun, inputs...; name="?") = new{typeof(fun)}(inputs, nothing, nothing, name)
end

### Pretty-printing


In [79]:
import Base: show, summary
show(io::IO, x::ScalarOperator{F}) where {F} = print(io, "op ", x.name, "(", F, ")");
show(io::IO, x::VectorOperator{F}) where {F} = print(io, "op.", x.name, "(", F, ")");
show(io::IO, x::Constant) = print(io, "const ", x.output)
show(io::IO, x::Variable) = begin
    print(io, "var ", x.name);
    print(io, "\n ┣━ ^ "); summary(io, x.output)
    print(io, "\n ┗━ ∇ ");  summary(io, x.gradient)
end


show (generic function with 277 methods)

### Graph building

In [80]:
function visit(node::GraphNode, visited, order)
    if node ∈ visited
    else
        push!(visited, node)
        push!(order, node)
    end
    return nothing
end
    
function visit(node::Operator, visited, order)
    if node ∈ visited
    else
        push!(visited, node)
        for input in node.inputs
            visit(input, visited, order)
        end
        push!(order, node)
    end
    return nothing
end

function topological_sort(head::GraphNode)
    visited = Set()
    order = Vector()
    visit(head, visited, order)
    return order
end

topological_sort (generic function with 1 method)

### Forward pass

In [81]:
reset!(node::Constant) = nothing
reset!(node::Variable) = node.gradient = nothing
reset!(node::Operator) = node.gradient = nothing

compute!(node::Constant) = nothing
compute!(node::Variable) = nothing
compute!(node::Operator) =
    node.output = forward(node, [input.output for input in node.inputs]...)

function forward!(order::Vector)
    for node in order
        compute!(node)
        reset!(node)
    end
    return last(order).output
end

forward! (generic function with 1 method)

### Backward pass

In [82]:
update!(node::Constant, gradient) = nothing
update!(node::GraphNode, gradient) = if isnothing(node.gradient)
    node.gradient = gradient else node.gradient .+= gradient
end

function backward!(order::Vector; seed=1.0)
    result = last(order)
    result.gradient = seed
    @assert length(result.output) == 1 "Gradient is defined only for scalar functions"
    for node in reverse(order)
        backward!(node)
    end
    return nothing
end

function backward!(node::Constant) end
function backward!(node::Variable) end
function backward!(node::Operator)
    inputs = node.inputs
    gradients = backward(node, [input.output for input in inputs]..., node.gradient)
    for (input, gradient) in zip(inputs, gradients)
        update!(input, gradient)
    end
    return nothing
end

backward! (generic function with 4 methods)

### Implemented operations

In [83]:
import Base: ^
^(x::GraphNode, n::GraphNode) = ScalarOperator(^, x, n)
Base.Broadcast.broadcasted(^, x::GraphNode, n::GraphNode) = VectorOperator(^, x , n)
forward(::ScalarOperator{typeof(^)}, x, n) = return x^n
backward(::ScalarOperator{typeof(^)}, x, n, g) = tuple(g * n * x ^ (n-1), g * log(abs(x)) * x ^ n)

forward(::VectorOperator{typeof(^)}, x, n) = return x .^n
backward(::VectorOperator{typeof(^)}, x, n, g) =  let
    𝟏 = ones(length(node.output))
    Jx = diagm(x.^(n-1) .* 𝟏 *n)
    Jy = diagm(x .* 𝟏)
    tuple(Jx' * g, Jy' * g)
end

backward (generic function with 11 methods)

In [84]:
import Base: sin
sin(x::GraphNode) = ScalarOperator(sin, x)
forward(::ScalarOperator{typeof(sin)}, x) = return sin(x)
backward(::ScalarOperator{typeof(sin)}, x, g) = tuple(g * cos(x))

backward (generic function with 11 methods)

In [85]:
import Base: *
import LinearAlgebra: mul!
# x * y (aka matrix multiplication)
*(A::GraphNode, x::GraphNode) = VectorOperator(mul!, A, x)
forward(::VectorOperator{typeof(mul!)}, A, x) = return A * x
backward(::VectorOperator{typeof(mul!)}, A, x, g) = tuple(g * x', A' * g)

# x .* y (element-wise multiplication)
Base.Broadcast.broadcasted(*, x::GraphNode, y::GraphNode) = VectorOperator(*, x, y)
forward(::VectorOperator{typeof(*)}, x, y) = return x .* y
backward(node::VectorOperator{typeof(*)}, x, y, g) = let
    𝟏 = ones(length(node.output))
    Jx = diagm(y .* 𝟏)
    Jy = diagm(x .* 𝟏)
    tuple(Jx' * g, Jy' * g)
end

backward (generic function with 11 methods)

In [86]:
Base.Broadcast.broadcasted(-, x::GraphNode, y::GraphNode) = VectorOperator(-, x, y)
forward(::VectorOperator{typeof(-)}, x, y) = return x .- y
backward(::VectorOperator{typeof(-)}, x, y, g) = tuple(g,-g)

backward (generic function with 11 methods)

In [87]:
Base.Broadcast.broadcasted(+, x::GraphNode, y::GraphNode) = VectorOperator(+, x, y)
forward(::VectorOperator{typeof(+)}, x, y) = return x .+ y
backward(::VectorOperator{typeof(+)}, x, y, g) = tuple(g, g)

backward (generic function with 11 methods)

In [88]:
import Base: sum
sum(x::GraphNode) = VectorOperator(sum, x)
forward(::VectorOperator{typeof(sum)}, x) = return sum(x)
backward(::VectorOperator{typeof(sum)}, x, g) = let
    𝟏 = ones(length(x))
    J = 𝟏'
    tuple(J' * g)
end

backward (generic function with 11 methods)

In [89]:
softmax(x::GraphNode) = VectorOperator(softmax, x)
forward(::VectorOperator{typeof(softmax)}, x) = return exp.(x) ./ sum(exp.(x))
backward(node::VectorOperator{typeof(softmax)}, x, g) = let
    y = node.output
    J = diagm(y) .- y * y'
    tuple(J' * g)
end

backward (generic function with 11 methods)

In [90]:
Base.Broadcast.broadcasted(/, x::GraphNode, y::GraphNode) = VectorOperator(/, x, y)
forward(::VectorOperator{typeof(/)}, x, y) = return x ./ y
backward(node::VectorOperator{typeof(/)}, x, y::Real, g) = let
    𝟏 = ones(length(node.output))
    Jx = diagm(𝟏 ./ y)
    Jy = (-x ./ y .^2)
    tuple(Jx' * g, Jy' * g)
end

backward (generic function with 11 methods)

In [91]:
import Base: max
Base.Broadcast.broadcasted(max, x::GraphNode, y::GraphNode) = VectorOperator(max, x, y)
forward(::VectorOperator{typeof(max)}, x, y) = return max.(x, y)
backward(::VectorOperator{typeof(max)}, x, y, g) = let
    Jx = diagm(isless.(y, x))
    Jy = diagm(isless.(x, y))
    tuple(Jx' * g, Jy' * g)
end

backward (generic function with 11 methods)

In [92]:
x = Variable(5.0, name="x")
two = Constant(2.0)
squared = x^two
sine = sin(squared)

order = topological_sort(sine)

4-element Vector{Any}:
 var x
 ┣━ ^ Float64
 ┗━ ∇ Nothing
 const 2.0
 op ?(typeof(^))
 op ?(typeof(sin))

In [106]:
using LinearAlgebra
Wh  = Variable(randn(10,2), name="wh")
Wo  = Variable(randn(1,10), name="wo")
x = Variable([1.98, 4.434], name="x")
y = Variable([0.064], name="y")
losses = Float64[]

function dense(w, b, x, activation) return activation(w * x .+ b) end
function dense(w, x, activation) return activation(w * x) end
function dense(w, x) return w * x end

function mean_squared_loss(y, ŷ)
    return Constant(0.5) .* (y .- ŷ) .^ Constant(2)
end

function net(x, wh, wo, y)
    σ(x) = (x) / ((x) + one(x.output.length)
    x̂ = dense(wh, x, σ)
    x̂.name = "x̂"
    ŷ = dense(wo, x̂)
    ŷ.name = "ŷ"
    E = mean_squared_loss(y, ŷ)
    E.name = "loss"

    return topological_sort(E)
end
graph = net(x, Wh, Wo, y)
forward!(graph)
backward!(graph)

for (i,n) in enumerate(graph)
    print(i, ". "); println(n)
end

ErrorException: syntax: missing comma or ) in argument list

In [100]:
x = Variable([1.98, 4.434], name="x")
y = Variable([0.064], name="y")

var y
 ┣━ ^ 1-element Vector{Float64}
 ┗━ ∇ Nothing

In [105]:
x.

2-element Vector{Float64}:
 1.98
 4.434