# Imports

In [1]:
using Yota;
using MLDatasets;
using NNlib;
using Statistics;
using Distributions;
using Functors;
using Optimisers;
using Flux.Data;
using Flux: onehotbatch, @epochs;

# Primitives

## Linear 

In [2]:
"""
    Linear(W, b)

Affine layer holding a weight matrix `W` and a bias vector `b`.

The struct is parameterised on the concrete array types of its two fields, so
`l.W` and `l.b` have concrete, inferrable types instead of the abstract
`AbstractMatrix` / `AbstractVector`. The type stays `mutable`, so a field can
still be reassigned to another array of the same type.
"""
mutable struct Linear{M<:AbstractMatrix,V<:AbstractVector}
    W::M
    b::V
end

# Register `Linear` with Functors.jl.
@functor Linear

# Init
"""
    Linear(in_features, out_features)
    Linear(in_features => out_features)

Create a `Linear` layer with an `out_features × in_features` weight matrix and
a length-`out_features` bias. Every entry is drawn from
`Uniform(-k, k)` with `k = sqrt(1 / in_features)`.

Throws an `ArgumentError` when `in_features` is not positive: a zero fan-in
would make `k` infinite and silently produce NaN parameters.
"""
function Linear(in_features::Integer, out_features::Integer)
    in_features > 0 ||
        throw(ArgumentError("in_features must be positive, got $in_features"))
    k_sqrt = sqrt(1 / in_features)
    d = Uniform(-k_sqrt, k_sqrt)
    # Weights are sampled before the bias, so the RNG consumption order is fixed.
    return Linear(rand(d, out_features, in_features), rand(d, out_features))
end
Linear(in_out::Pair{<:Integer,<:Integer}) = Linear(first(in_out), last(in_out))

# Compact display of a layer as `Linear(in=>out)`, read from the weight shape.
function Base.show(io::IO, l::Linear)
    n_out, n_in = size(l.W)
    return print(io, "Linear(", n_in, "=>", n_out, ")")
end

# Forward: apply the affine map `W * x .+ b`. `x` may be a single input vector
# or a matrix; for a matrix the bias is broadcast across its columns.
function (l::Linear)(x::AbstractVecOrMat{T}) where {T}
    return l.W * x .+ l.b
end

## Logit Cross Entropy

In [3]:
"""
    logitcrossentropy(ŷ, y; dims = 1, agg = mean)

Cross-entropy between the targets `y` and the raw scores (logits) `ŷ`.

`logsoftmax` is applied to `ŷ` along `dims`, the result is weighted by `y`
and summed along the same `dims`, negated, and finally reduced with `agg`.
"""
function logitcrossentropy(ŷ, y; dims = 1, agg = mean)
    log_probs = logsoftmax(ŷ; dims = dims)
    weighted = sum(y .* log_probs; dims = dims)
    return agg(.-weighted)
end

logitcrossentropy (generic function with 1 method)

# Define the model

In [4]:
"""
    Net()
    Net(fc1, fc2)

Model made of two `Linear` layers, `fc1` followed by `fc2`. The zero-argument
constructor builds a 784 → 100 → 10 network.
"""
mutable struct Net
    fc1::Linear
    fc2::Linear
end

@functor Net

# Init
function Net()
    # Build the layers in order (fc1 first) so parameter sampling order is fixed.
    hidden_layer = Linear(28*28, 100)
    output_layer = Linear(100, 10)
    return Net(hidden_layer, output_layer)
end

# Forward: reshape the input to 784 rows (one column per sample), then run
# fc1 -> relu -> fc2 and return the result of fc2.
function (model::Net)(x::AbstractArray)
    flat = reshape(x, 28*28, :)
    hidden = relu(model.fc1(flat))
    return model.fc2(hidden)
end

# Data

In [5]:
# Load the MNIST train and test splits from the local data directory.
train_dataset = MNIST(dir="/Users/trevoryu/Code/data/mnist", split=:train);
test_dataset = MNIST(dir="/Users/trevoryu/Code/data/mnist", split=:test);

# Pull the feature arrays and their targets out of each split.
X_train, Y_train = train_dataset.features, train_dataset.targets;
X_test, Y_test = test_dataset.features, test_dataset.targets;

In [6]:
# Inspect the dimensions of the raw test feature array.
size(X_test)

(28, 28, 10000)

In [7]:
# Flatten every image into a 784-element column: (dim x batch)
X_train, X_test = reshape(X_train, 784, :), reshape(X_test, 784, :);

In [8]:
# One-hot encode the targets over the classes 0 through 9: (dim x batch)
Y_train, Y_test = map(labels -> onehotbatch(labels, 0:9), (Y_train, Y_test));

In [9]:
# Iterate over the training set in shuffled mini-batches of 128 (x, y) pairs.
loader = DataLoader((X_train, Y_train); batchsize=128, shuffle=true);

# Training setup

In [14]:
# Make model: a freshly initialised Net (784 => 100 => 10).
# No trailing semicolon, so the notebook displays the model summary.
mlp = Net()

Net(Linear(784=>100), Linear(100=>10))

In [15]:
# Set up the Adam optimiser with a learning rate of 1e-3.
# Default β is (0.9, 0.999)
# `setup` builds the optimiser state tree that matches the structure of `mlp`.
state = Optimisers.setup(Optimisers.Adam(1e-3), mlp);

In [16]:
# Objective: run `model` on the inputs `x` and score the resulting logits
# against the targets `y` with `logitcrossentropy`.
loss_function(model::Net, x::AbstractArray, y::AbstractArray) =
    logitcrossentropy(model(x), y)

loss_function (generic function with 1 method)

# Training loop

In [18]:
# Train for 10 epochs; each epoch is timed and its mean mini-batch loss logged.
for epoch in 1:10
    # `Optimisers.update!` returns the new optimiser state and model, and the old
    # ones must not be relied on afterwards. Rebind the top-level variables to the
    # returned values; `global` makes this work both in a notebook and in a script.
    global state, mlp
    # Typed container for the scalar batch losses (avoids a Vector{Any}).
    losses = Float64[]
    @time begin
        for (x, y) in loader
            # loss_function does forward pass
            # Yota.jl grad function computes parameter gradients
            loss, g = grad(loss_function, mlp, x, y)
            # Optimiser updates parameters; g[2] is the gradient passed for `mlp`
            state, mlp = Optimisers.update!(state, mlp, g[2])
            push!(losses, loss)
            # TODO: Add accuracy computation
        end
        @info("epoch $epoch loss = $(mean(losses))")
    end
end

  0.774987 seconds (162.82 k allocations: 1.951 GiB, 12.71% gc time)


┌ Info: epoch 1 loss = 0.04094025307690705
└ @ Main In[18]:12


  0.646125 seconds (163.01 k allocations: 1.951 GiB, 10.26% gc time)


┌ Info: epoch 2 loss = 0.03572923218462636
└ @ Main In[18]:12


  0.652080 seconds (163.01 k allocations: 1.951 GiB, 9.54% gc time)


┌ Info: epoch 3 loss = 0.03166147383022924
└ @ Main In[18]:12


  0.682782 seconds (163.01 k allocations: 1.951 GiB, 11.74% gc time)


┌ Info: epoch 4 loss = 0.027879119237525167
└ @ Main In[18]:12


  0.666789 seconds (163.01 k allocations: 1.951 GiB, 8.82% gc time)


┌ Info: epoch 5 loss = 0.02481490266472439
└ @ Main In[18]:12


  0.624532 seconds (163.01 k allocations: 1.951 GiB, 8.85% gc time)


┌ Info: epoch 6 loss = 0.021184432907923646
└ @ Main In[18]:12


  0.604429 seconds (163.01 k allocations: 1.951 GiB, 8.57% gc time)


┌ Info: epoch 7 loss = 0.019366974012950174
└ @ Main In[18]:12


  0.622498 seconds (163.01 k allocations: 1.951 GiB, 8.82% gc time)


┌ Info: epoch 8 loss = 0.016626981172646552
└ @ Main In[18]:12


  0.610364 seconds (163.01 k allocations: 1.951 GiB, 8.79% gc time)


┌ Info: epoch 9 loss = 0.014004979298361142
└ @ Main In[18]:12


  0.661410 seconds (163.01 k allocations: 1.951 GiB, 10.90% gc time)


┌ Info: epoch 10 loss = 0.012511318528209223
└ @ Main In[18]:12
