### Structures
Definition of basic structures for computational graph

In [1]:
using MLDatasets, Flux, Statistics
train_data = MLDatasets.MNIST(split=:train)
test_data  = MLDatasets.MNIST(split=:test)


function loader(data)
    dim1, dim2, dim3 = size(data.features)
    x = reshape(data.features, dim1 * dim2, dim3)
    y = data.targets
    #x4dim = reshape(data.features, 28, 28, 1, :) # insert trivial channel dim
    yhot  = Flux.onehotbatch(data.targets, 0:9)  # make a 10×60000 OneHotMatrix
    return x, y, yhot
    #Flux.DataLoader((x4dim, yhot); batchsize, shuffle=true)
end

x1, y1, yhot = loader(train_data);

In [1]:
abstract type GraphNode end
abstract type Operator <: GraphNode end

struct Constant{T} <: GraphNode
    output :: T
end

mutable struct Variable <: GraphNode
    output :: Any
    gradient :: Any
    name :: String
    Variable(output; name="?") = new(output, nothing, name)
end

mutable struct ScalarOperator{F} <: Operator
    inputs :: Any
    output :: Any
    gradient :: Any
    name :: String
    ScalarOperator(fun, inputs...; name="?") = new{typeof(fun)}(inputs, nothing, nothing, name)
end

mutable struct BroadcastedOperator{F} <: Operator
    inputs :: Any
    output :: Any
    gradient :: Any
    name :: String
    BroadcastedOperator(fun, inputs...; name="?") = new{typeof(fun)}(inputs, nothing, nothing, name)
end
import Base: show, summary
show(io::IO, x::ScalarOperator{F}) where {F} = print(io, "op ", x.name, "(", F, ")");
show(io::IO, x::BroadcastedOperator{F}) where {F} = print(io, "op.", x.name, "(", F, ")", x.inputs, x.output, x.gradient);
show(io::IO, x::Constant) = print(io, "const ", x.output)
show(io::IO, x::Variable) = begin
    print(io, "var ", x.name);
    print(io, "\n ┣━ ^ "); summary(io, x.output)
    print(io, "\n ┗━ ∇ ");  summary(io, x.gradient)
end
function visit(node::GraphNode, visited, order)
    if node ∈ visited
    else
        push!(visited, node)
        push!(order, node)
    end
    return nothing
end
    
function visit(node::Operator, visited, order)
    if node ∈ visited
    else
        push!(visited, node)
        for input in node.inputs
            visit(input, visited, order)
        end
        push!(order, node)
    end
    return nothing
end

function topological_sort(head::GraphNode)
    visited = Set()
    order = Vector()
    visit(head, visited, order)
    return order
end
reset!(node::Constant) = nothing
reset!(node::Variable) = node.gradient = nothing
reset!(node::Operator) = node.gradient = nothing

compute!(node::Constant) = nothing
compute!(node::Variable) = nothing
compute!(node::Operator) =
    node.output = forward(node, [input.output for input in node.inputs]...)

function forward!(order::Vector)
    for node in order
        compute!(node)
        reset!(node)
    end
    return last(order).output
end
update!(node::Constant, gradient) = nothing
update!(node::GraphNode, gradient) = if isnothing(node.gradient)
    node.gradient = gradient else node.gradient .+= gradient
end

function backward!(order::Vector; seed=1.0)
    result = last(order)
    result.gradient = seed
    @assert length(result.output) == 1 "Gradient is defined only for scalar functions"
    for node in reverse(order)
        backward!(node)
    end
    return nothing
end

function backward!(node::Constant) end
function backward!(node::Variable) end
function backward!(node::Operator)
    inputs = node.inputs
    gradients = backward(node, [input.output for input in inputs]..., node.gradient)
    for (input, gradient) in zip(inputs, gradients)
        update!(input, gradient)
    end
    return nothing
end
import Base: ^
^(x::GraphNode, n::GraphNode) = ScalarOperator(^, x, n)
forward(::ScalarOperator{typeof(^)}, x, n) = return x^n
backward(::ScalarOperator{typeof(^)}, x, n, g) = tuple(g * n * x ^ (n-1), g * log(abs(x)) * x ^ n)
import Base: sin
sin(x::GraphNode) = ScalarOperator(sin, x)
forward(::ScalarOperator{typeof(sin)}, x) = return sin(x)
backward(::ScalarOperator{typeof(sin)}, x, g) = tuple(g * cos(x))
import Base: exp
exp(x::GraphNode) = ScalarOperator(exp, x)
forward(::ScalarOperator{typeof(exp)}, x) = return exp(x)
backward(::ScalarOperator{typeof(exp)}, x, g) = tuple(g * exp(x))


import Base: *
import LinearAlgebra: mul!
# x * y (aka matrix multiplication)
*(A::GraphNode, x::GraphNode) = BroadcastedOperator(mul!, A, x)
forward(::BroadcastedOperator{typeof(mul!)}, A, x) = return A * x
backward(::BroadcastedOperator{typeof(mul!)}, A, x, g) = tuple(g * x', A' * g)

# x .* y (element-wise multiplication)pu
Base.Broadcast.broadcasted(*, x::GraphNode, y::GraphNode) = BroadcastedOperator(*, x, y)
forward(::BroadcastedOperator{typeof(*)}, x, y) = return x .* y
backward(node::BroadcastedOperator{typeof(*)}, x, y, g) = let
    ones_vec = ones(length(node.output)) # I wektor jednostkowy
    Jx = diagm(y .* ones_vec) # I(length(node.output)) * yI
    Jy = diagm(x .* ones_vec)
    tuple(Jx' * g, Jy' * g)
end
Base.Broadcast.broadcasted(-, x::GraphNode, y::GraphNode) = BroadcastedOperator(-, x, y)
forward(::BroadcastedOperator{typeof(-)}, x, y) = return x .- y
backward(::BroadcastedOperator{typeof(-)}, x, y, g) = tuple(g,-g)
Base.Broadcast.broadcasted(+, x::GraphNode, y::GraphNode) = BroadcastedOperator(+, x, y)
forward(::BroadcastedOperator{typeof(+)}, x, y) = return x .+ y
backward(::BroadcastedOperator{typeof(+)}, x, y, g) = tuple(g, g)
import Base: sum
sum(x::GraphNode) = BroadcastedOperator(sum, x)
forward(::BroadcastedOperator{typeof(sum)}, x) = return sum(x)
backward(::BroadcastedOperator{typeof(sum)}, x, g) = let
    𝟏 = ones(length(x))
    J = 𝟏'
    tuple(J' * g)
end


Base.Broadcast.broadcasted(/, x::GraphNode, y::GraphNode) = BroadcastedOperator(/, x, y)
forward(::BroadcastedOperator{typeof(/)}, x, y) = return x ./ y
backward(node::BroadcastedOperator{typeof(/)}, x, y::Real, g) = let
    𝟏 = ones(length(node.output))
    Jx = diagm(𝟏 ./ y)
    Jy = (-x ./ y .^2)
    tuple(Jx' * g, Jy' * g)
end

import Base: max
Base.Broadcast.broadcasted(max, x::GraphNode, y::GraphNode) = BroadcastedOperator(max, x, y)
forward(::BroadcastedOperator{typeof(max)}, x, y) = return max.(x, y)
backward(::BroadcastedOperator{typeof(max)}, x, y, g) = let
    Jx = diagm(isless.(y, x))
    Jy = diagm(isless.(x, y))
    tuple(Jx' * g, Jy' * g)
end

import Base: log
log(x::GraphNode) = BroadcastedOperator(log, x)
forward(::BroadcastedOperator{typeof(log)}, x) = return log.(x)
backward(::BroadcastedOperator{typeof(log)}, x, g) = let
    J = diagm(1 ./ x)
    tuple(J' * g)
end


σ(x::GraphNode) = BroadcastedOperator(σ, x)
forward(::BroadcastedOperator{typeof(σ)}, x) = return 1.0 ./ (1.0 .+ exp.(-x))
backward(::BroadcastedOperator{typeof(σ)}, x, g) = let
    J = diagm(1.0 ./ (1.0 .+ exp.(-x))).*(1.0 .- (1.0 ./ (1.0 .+ exp.(-x))))
    tuple(J' * g)
end

Base.Broadcast.broadcasted(^, x::GraphNode, y::GraphNode) = BroadcastedOperator(^, x, y)
forward(::BroadcastedOperator{typeof(^)}, x, y) = return x .^ y
backward(node::BroadcastedOperator{typeof(^)}, x, y, g) = let
    Jx = diagm(y .* x .^ (y .- 1.0))
    Jy = diagm(log.(abs.(x)) .* x .^ y)
    tuple(Jx' * g, Jy' * g)
end


backward (generic function with 13 methods)

In [2]:
logit_cross_entropy(y_predicted::GraphNode, y::GraphNode) = BroadcastedOperator(logit_cross_entropy, y_predicted, y)
forward(::BroadcastedOperator{typeof(logit_cross_entropy)}, y_predicted, y) =
    let
        y_predicted = y_predicted .- maximum(y_predicted)
        y_predicted = exp.(y_predicted) ./ sum(exp.(y_predicted))
        loss = sum(log.(y_predicted) .* y) * -1.0
        return loss
    end
backward(::BroadcastedOperator{typeof(logit_cross_entropy)}, y_predicted, y, g) =
    let
        y_predicted = y_predicted .- maximum(y_predicted)
        y_predicted = exp.(y_predicted) ./ sum(exp.(y_predicted))
        return tuple(g .* (y_predicted - y))
    end

backward (generic function with 14 methods)

In [11]:
x1 = Variable([1.9514772580227577, 1.848623792879922, -6.456514365666464, -5.10149420660439, 3.5482988599977925, -2.210665982825723, 4.338133361390976, -8.256227686203022, -9.262014936475973, -1.3432975629321589], name="x1")
x2 = Variable([0., 0., 0., 0., 0., 1., 0., 0., 0., 0.], name="x2")

b = logit_cross_entropy(x1, x2)

order = topological_sort(b)
currentloss = forward!(order)
backward!(order)


In [19]:
order[3].gradient

1.0

In [23]:
sin(25)

-0.13235175009777303

In [27]:
function log_softmax(x)
    # Subtract the max value for numerical stability
    shift_x = x .- maximum(x)
    log_exp_sum = log.(sum(exp.(shift_x)))
    return shift_x .- log_exp_sum
end


function logit_cross_entropy(ŷ::GraphNode, y::GraphNode)
    return .-sum(y .* log(ŷ) .+ (1 .- y) .* log(1 .- ŷ))
end

logit_cross_entropy (generic function with 1 method)

### The simplest multilayer-perceptron

In [113]:
using LinearAlgebra
Wh  = Variable(randn(10,784), name="wh")
Wo  = Variable(randn(2,10), name="wo")
x = Variable(x1[:, 1], name="x")
#y = Variable(yhot[:,1], name="y")
#x = Variable([1.98, 4.434], name="x")
y = Variable([0.064, 0.1234], name="y")

losses = Float64[]

function dense(w, b, x, activation) return activation(w * x .+ b) end
function dense(w, x, activation) return activation(w * x) end
function dense(w, x) return w * x end

function mean_squared_loss(y, ŷ)
    return sum(Constant(0.5) .* (y .- ŷ) .^ Constant(2))
end

function net(x, wh, wo, y)
    x̂ = dense(wh, x, σ)
    x̂.name = "x̂"
    ŷ = dense(wo, x̂)
    ŷ.name = "ŷ"
    E = logit_cross_entropy(ŷ, y)
    E.name = "loss"
    return topological_sort(E)
end
graph = net(x, Wh, Wo, y)
forward!(graph)
backward!(graph)

for (i,n) in enumerate(graph)
    print(i, ". "); println(n)
end

MethodError: MethodError: no method matching log(::BroadcastedOperator{typeof(mul!)})

Closest candidates are:
  log(!Matched::Float64)
   @ Base special\log.jl:267
  log(!Matched::BigFloat)
   @ Base mpfr.jl:727
  log(!Matched::Missing, !Matched::Missing)
   @ Base math.jl:1584
  ...


### Manual derivatives for comparison

In [19]:

for i=1:10
    currentloss = forward!(graph)
    backward!(graph)
    Wh.output -= 0.01Wh.gradient
    Wo.output -= 0.01Wo.gradient
    println("Current loss: ", currentloss)
    push!(losses, first(currentloss))
end


Current loss: [2.6569395463993826]
Current loss: [2.2946906133559444]
Current loss: [1.9682890420032868]
Current loss: [1.6770251382036299]
Current loss: [1.4201807795922223]
Current loss: [1.1966209745623364]
Current loss: [1.00452876090821]
Current loss: [0.8413701620291453]
Current loss: [0.7040700678099885]
Current loss: [0.5892978041825421]


In [69]:
losses

2-element Vector{Float64}:
 8.862781216261547
 8.015221933977145

In [20]:
using PyPlot
semilogy(losses, ".")
xlabel("epoch")
ylabel("loss")
grid()

┌ Info: Installing matplotlib via the Conda matplotlib package...
└ @ PyCall C:\Users\mbili\.julia\packages\PyCall\1gn3u\src\PyCall.jl:719
┌ Info: Running `conda install -y matplotlib` in root environment
└ @ Conda C:\Users\mbili\.julia\packages\Conda\sDjAP\src\Conda.jl:181


Channels:
 - defaults
 - conda-forge
Platform: win-64
Collecting package metadata (repodata.json): ...working... done
Solving environment: ...working... done




    current version: 23.11.0
    latest version: 24.3.0

Please update conda by running

    $ conda update -n base -c conda-forge conda





## Package Plan ##

  environment location: C:\Users\mbili\.julia\conda\3\x86_64

  added / updated specs:
    - matplotlib


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    ca-certificates-2024.3.11  |       haa95532_0         128 KB
    ------------------------------------------------------------
                                           Total:         128 KB

The following packages will be UPDATED:

  ca-certificates                     2023.12.12-haa95532_0 --> 2024.3.11-haa95532_0 



Downloading and Extracting Packages: ...working... done
Preparing transaction: ...working... done
Verifying transaction: ...working... done
Executing transaction: ...working... done


InitError: InitError: PyError (PyImport_ImportModule

The Python package matplotlib could not be imported by pyimport. Usually this means
that you did not install matplotlib in the Python version being used by PyCall.

PyCall is currently configured to use the Julia-specific Python distribution
installed by the Conda.jl package.  To install the matplotlib module, you can
use `pyimport_conda("matplotlib", PKG)`, where PKG is the Anaconda
package that contains the module matplotlib, or alternatively you can use the
Conda package directly (via `using Conda` followed by `Conda.add` etcetera).

Alternatively, if you want to use a different Python distribution on your
system, such as a system-wide Python (as opposed to the Julia-specific Python),
you can re-configure PyCall with that Python.   As explained in the PyCall
documentation, set ENV["PYTHON"] to the path/name of the python executable
you want to use, run Pkg.build("PyCall"), and re-launch Julia.

) <class 'ImportError'>
ImportError('DLL load failed while importing _imaging: Nie można odnaleźć określonego modułu.')
  File "C:\Users\mbili\.julia\conda\3\x86_64\lib\site-packages\matplotlib\__init__.py", line 161, in <module>
    from . import _api, _version, cbook, _docstring, rcsetup
  File "C:\Users\mbili\.julia\conda\3\x86_64\lib\site-packages\matplotlib\rcsetup.py", line 27, in <module>
    from matplotlib.colors import Colormap, is_color_like
  File "C:\Users\mbili\.julia\conda\3\x86_64\lib\site-packages\matplotlib\colors.py", line 52, in <module>
    from PIL import Image
  File "C:\Users\mbili\.julia\conda\3\x86_64\lib\site-packages\PIL\Image.py", line 84, in <module>
    from . import _imaging as core

during initialization of module PyPlot

In [21]:
softmax(x::GraphNode) = BroadcastedOperator(softmax, x)
forward(::BroadcastedOperator{typeof(softmax)}, x) = return exp.(x) ./ sum(exp.(x))
backward(node::BroadcastedOperator{typeof(softmax)}, x, g) = let
    y = node.output
    J = diagm(y) .- y * y'
    tuple(J' * g)
end

backward (generic function with 12 methods)

In [22]:
rosenbrock(x, y) = (Constant(1.0) .- x .* x) .+ Constant(100.0) .* (y .- x .* x) .* (y .- x .* x)

rosenbrock (generic function with 1 method)

In [23]:
x = Variable([0.], name="x")
y = Variable([0.], name="y")
graph = topological_sort(rosenbrock(x, y))

13-element Vector{Any}:
 const 1.0
 var x
 ┣━ ^ 1-element Vector{Float64}
 ┗━ ∇ Nothing
 op.?(typeof(*))
 op.?(typeof(-))
 const 100.0
 var y
 ┣━ ^ 1-element Vector{Float64}
 ┗━ ∇ Nothing
 op.?(typeof(*))
 op.?(typeof(-))
 op.?(typeof(*))
 op.?(typeof(*))
 op.?(typeof(-))
 op.?(typeof(*))
 op.?(typeof(+))

In [24]:
v  = -1:.1:+1
n  = length(v)
z  = zeros(n, n)
dz = zeros(n, n, 2)
for i=1:n, j=1:n
    x.output .= v[i]
    y.output .= v[j]
    z[i,j] = first(forward!(graph)); backward!(graph)
    dz[i,j,1] = first(x.gradient)
    dz[i,j,2] = first(y.gradient)
end


In [None]:
using PyPlot
xv = repeat(v, 1, n)
yv = repeat(v',n, 1)
contourf(xv, yv, z)
quiver(xv, yv, dz[:,:,1], dz[:,:,2])