# Imports

In [1]:
using Yota;
using MLDatasets;
using NNlib;
using Statistics;
using Distributions;
using Functors;
using Optimisers;
using MLUtils: DataLoader;
using OneHotArrays: onehotbatch
using Metrics;
using TimerOutputs;

# Primitives

## Linear 

In [2]:
"""
    Linear(W, b)

Fully connected layer whose forward pass computes `W * x .+ b`.

`W` is an (out_features × in_features) matrix and `b` a length-`out_features` vector.
The field types are type parameters (rather than the abstract `AbstractMatrix` /
`AbstractVector`) so that field access is type-stable and the forward/backward passes
do not fall back to dynamic dispatch on every call.
"""
mutable struct Linear{M<:AbstractMatrix, V<:AbstractVector}
    W::M
    b::V
end

# Register Linear with Functors so Optimisers.jl can reach W and b
@functor Linear

# Init: weights and bias both drawn from U(-sqrt(1/in_features), sqrt(1/in_features))
function Linear(in_features::Int, out_features::Int)
    bound = sqrt(1 / in_features)
    dist = Uniform(-bound, bound)
    weight = rand(dist, out_features, in_features)  # (out × in)
    bias = rand(dist, out_features)
    return Linear(weight, bias)
end

# Convenience form: Linear(in => out)
Linear(in_out::Pair{Int, Int}) = Linear(first(in_out), last(in_out))

# Compact display such as `Linear(784=>100)`: input features => output features
Base.show(io::IO, l::Linear) = print(io, "Linear(", size(l.W, 2), "=>", size(l.W, 1), ")")

# Forward: affine map W*x + b; when x is a (features × batch) matrix,
# b is broadcast across every column
function (layer::Linear)(x::AbstractVecOrMat)
    return layer.W * x .+ layer.b
end

## Logit Cross Entropy

In [3]:
"""
    logitcrossentropy(ŷ, y; dims=1, agg=mean)

Cross-entropy loss computed directly from unnormalised logits `ŷ` and targets `y`.

The per-sample loss is `-sum(y .* logsoftmax(ŷ))` taken along `dims` (i.e. the negative
log-likelihood of the log-softmax outputs); the per-sample losses are then reduced
with `agg`.
"""
function logitcrossentropy(ŷ, y; dims=1, agg=mean)
    logprobs = logsoftmax(ŷ; dims=dims)
    per_sample = .-sum(y .* logprobs; dims=dims)
    return agg(per_sample)
end

logitcrossentropy (generic function with 1 method)

# Define the model

In [4]:
"""
    Net()

Multilayer perceptron made of two `Linear` layers, `fc1` and `fc2`.
The zero-argument constructor builds the 784 → 100 → 10 configuration.
"""
mutable struct Net
    fc1::Linear
    fc2::Linear
end

# Register Net with Functors; Optimisers.jl needs this to walk into its layers
@functor Net

# Init
function Net()
    hidden = Linear(28 * 28, 100)
    output = Linear(100, 10)
    return Net(hidden, output)
end

# Forward: flatten the input to 784-row columns, then fc1 → ReLU → fc2.
# The result is raw logits (the loss applies logsoftmax itself).
function (model::Net)(x::AbstractArray)
    flat = reshape(x, 28*28, :)
    hidden = relu(model.fc1(flat))
    return model.fc2(hidden)
end

# Data

In [5]:
# Directory holding the MNIST files. The default is the hard-coded path this notebook
# was written with; set the MNIST_DIR environment variable to run it on another machine.
mnist_dir = get(ENV, "MNIST_DIR", "/Users/trevoryu/Code/data/mnist");

train_dataset = MNIST(dir=mnist_dir, split=:train);
test_dataset = MNIST(dir=mnist_dir, split=:test);

X_train = train_dataset.features;
Y_train = train_dataset.targets;

X_test = test_dataset.features;
Y_test = test_dataset.targets;

In [6]:
# Inspect the raw test-feature shape before flattening
size(X_test)

(28, 28, 10000)

In [7]:
# Flatten every 28×28 image into a single 784-element column,
# so features become (dim x batch) matrices
X_train, X_test = map(X -> reshape(X, 784, :), (X_train, X_test));

In [8]:
# Encode the integer digit labels 0:9 as one-hot columns, (dim x batch)
Y_train, Y_test = map(Y -> onehotbatch(Y, 0:9), (Y_train, Y_test));

In [9]:
# Mini-batch iterators over (features, one-hot labels); only training data is shuffled
batch_size = 128;
train_loader = DataLoader((X_train, Y_train); batchsize=batch_size, shuffle=true);
test_loader = DataLoader((X_test, Y_test); batchsize=batch_size, shuffle=false);

# Training setup

In [10]:
# Make model: the default 784 => 100 => 10 MLP with freshly initialised weights
mlp = Net()

Net(Linear(784=>100), Linear(100=>10))

In [11]:
# Setup Adam optimizer with learning rate 1e-3
# Default β is (0.9, 0.999)
# `state` mirrors the structure of `mlp` and holds the per-parameter optimiser state
state = Optimisers.setup(Optimisers.Adam(1e-3), mlp);

In [12]:
# Create objective function to optimize: forward pass through `model`, then the
# mean logit cross-entropy between its outputs and the one-hot targets `y`
loss_function(model::Net, x::AbstractArray, y::AbstractArray) =
    logitcrossentropy(model(x), y)

loss_function (generic function with 1 method)

# Evaluation function

In [13]:
"""
    evaluate(mlp, test_loader)

Return the classification accuracy of `mlp` over all `(x, y)` batches in `test_loader`.

For each batch, a sample's predicted class is the row index of the largest entry in its
column of `mlp(x)`. Its true class is the row index of the largest entry in the matching
column of `y` (e.g. a one-hot column). Accuracy is the fraction of samples, across all
batches, whose predicted and true classes agree.

Throws `ArgumentError` if `test_loader` yields no samples.
"""
function evaluate(mlp, test_loader)
    correct = 0
    total = 0
    for (x, y) in test_loader
        # argmax with dims=1 gives one CartesianIndex per column;
        # its first component is the row (class) index
        predicted = map(i -> i[1], argmax(mlp(x); dims=1))
        expected = map(i -> i[1], argmax(y; dims=1))
        correct += count(predicted .== expected)
        total += length(expected)
    end
    total > 0 || throw(ArgumentError("test_loader yielded no samples"))
    return correct / total
end

evaluate (generic function with 1 method)

# Training loop

In [14]:
# Setup timing output: a single TimerOutput that every @timeit section below records into
const to = TimerOutput()

[0m[1m ────────────────────────────────────────────────────────────────────[22m
[0m[1m                   [22m         Time                    Allocations      
                   ───────────────────────   ────────────────────────
 Tot / % measured:      345ms /   0.0%           45.3MiB /   0.0%    

 Section   ncalls     time    %tot     avg     alloc    %tot      avg
 ────────────────────────────────────────────────────────────────────
[0m[1m ────────────────────────────────────────────────────────────────────[22m

In [15]:
last_loss = 0.0;
@timeit to "total_training_time" begin
    for epoch in 1:10
        # This loop runs at global scope. Declare the globals it rebinds so that the
        # assignments below update them in every context. Outside an interactive session
        # (e.g. `julia script.jl`), soft-scope assignment would otherwise create
        # loop-locals, and `last_loss` would silently stay 0.0.
        global state, mlp, last_loss

        # Epoch 1 includes JIT compilation, so it is recorded under separate labels
        timing_name = epoch > 1 ? "average_epoch_training_time" : "train_jit"
        @timeit to timing_name begin
            losses = Float64[]
            for (x, y) in train_loader
                # loss_function does forward pass
                # Yota.jl grad returns gradients for (loss_function, mlp, x, y),
                # so g[2] holds the model parameter gradients
                loss, g = grad(loss_function, mlp, x, y)

                # Optimiser updates parameters; rebind to the returned state and model,
                # as Optimisers.jl recommends, instead of relying only on in-place mutation
                state, mlp = Optimisers.update!(state, mlp, g[2])
                push!(losses, loss)
            end
            last_loss = mean(losses)
            @info("epoch $epoch loss = $(last_loss)")
        end
        timing_name = epoch > 1 ? "average_inference_time" : "eval_jit"
        @timeit to timing_name begin
            acc = evaluate(mlp, test_loader)
            @info("epoch $epoch eval accuracy = $(acc)")
        end
    end
end

┌ Info: epoch 1 loss = 0.4375657584243529
└ @ Main In[15]:17
┌ Info: epoch 1 eval accuracy = 0.9304
└ @ Main In[15]:22
┌ Info: epoch 2 loss = 0.21220267083301655
└ @ Main In[15]:17
┌ Info: epoch 2 eval accuracy = 0.9471
└ @ Main In[15]:22
┌ Info: epoch 3 loss = 0.1592959047496635
└ @ Main In[15]:17
┌ Info: epoch 3 eval accuracy = 0.9596
└ @ Main In[15]:22
┌ Info: epoch 4 loss = 0.12536456292602383
└ @ Main In[15]:17
┌ Info: epoch 4 eval accuracy = 0.9649
└ @ Main In[15]:22
┌ Info: epoch 5 loss = 0.10332849547590327
└ @ Main In[15]:17
┌ Info: epoch 5 eval accuracy = 0.9678
└ @ Main In[15]:22
┌ Info: epoch 6 loss = 0.08698907676456932
└ @ Main In[15]:17
┌ Info: epoch 6 eval accuracy = 0.9693
└ @ Main In[15]:22
┌ Info: epoch 7 loss = 0.0743854635873991
└ @ Main In[15]:17
┌ Info: epoch 7 eval accuracy = 0.9723
└ @ Main In[15]:22
┌ Info: epoch 8 loss = 0.06411696637116769
└ @ Main In[15]:17
┌ Info: epoch 8 eval accuracy = 0.9747
└ @ Main In[15]:22
┌ Info: epoch 9 loss = 0.055365824426301447

In [16]:
# Display the timing table; first-epoch "*_jit" sections are recorded separately
# from the "average_*" sections that cover the remaining epochs
to

[0m[1m ────────────────────────────────────────────────────────────────────────────────[22m
[0m[1m                               [22m         Time                    Allocations      
                               ───────────────────────   ────────────────────────
       Tot / % measured:            33.3s /  98.1%           29.0GiB /  99.7%    

 Section               ncalls     time    %tot     avg     alloc    %tot      avg
 ────────────────────────────────────────────────────────────────────────────────
 total_training_time        1    32.6s  100.0%   32.6s   28.9GiB  100.0%  28.9GiB
   train_jit                1    26.0s   79.6%   26.0s   10.0GiB   34.8%  10.0GiB
   average_epoch_tr...      9    5.80s   17.8%   645ms   17.6GiB   60.8%  1.95GiB
   eval_jit                 1    559ms    1.7%   559ms    283MiB    1.0%   283MiB
   average_inferenc...      9    296ms    0.9%  32.9ms   1.02GiB    3.5%   116MiB
[0m[1m ──────────────────────────────────────────────────────────────

In [17]:
# Train time
# Exclude jit time: epoch 1 is recorded as "train_jit", so this section holds the later epochs
epoch_timer = to["total_training_time"]["average_epoch_training_time"]
# TimerOutputs reports nanoseconds. Divide by the recorded call count (instead of a
# hard-coded 9) so the average stays correct if the number of epochs changes,
# then convert to seconds.
average_epoch_train_time =
    TimerOutputs.time(epoch_timer) / (TimerOutputs.ncalls(epoch_timer) * 1e9)

0.644896347

In [18]:
# Eval batch time
# Exclude jit time: epoch 1 is recorded as "eval_jit", so this section holds the later epochs
inference_timer = to["total_training_time"]["average_inference_time"]
num_batches = length(test_loader)
# Number of timed evaluation passes, read from the timer rather than hard-coded
num_eval_passes = TimerOutputs.ncalls(inference_timer)
# TimerOutputs reports nanoseconds; average over every pass and batch, convert to milliseconds
average_eval_batch_time =
    TimerOutputs.time(inference_timer) / (num_eval_passes * 1e6 * num_batches)

0.41594755133614625

In [19]:
# Total time of the whole training section, left as the raw TimerOutputs value (nanoseconds)
total_train_time = TimerOutputs.time(to["total_training_time"])
# Accuracy of the final trained model on the full test set
final_eval_accuracy = evaluate(mlp, test_loader)

0.9744

In [20]:
# Benchmark summary. Units: total_training_time is raw nanoseconds,
# average_epoch_training_time is seconds, average_batch_inference_time is milliseconds.
# NOTE(review): the "final_trianing_loss" key spelling is kept as-is because other result
# files or consumers may depend on it; confirm before renaming.
metrics = Dict{String, Any}(
    "model_name" => "MLP",
    "dataset" => "MNIST Digits",
    "framework_name" => "Avalon.jl",
    "task" => "classification",
    "total_training_time" => total_train_time,
    "average_epoch_training_time" => average_epoch_train_time,
    "average_batch_inference_time" => average_eval_batch_time,
    "final_trianing_loss" => last_loss,
    "final_evaluation_accuracy" => final_eval_accuracy,
)

Dict{String, Any} with 9 entries:
  "task"                         => "classification"
  "framework_name"               => "Avalon.jl"
  "final_trianing_loss"          => 0.0483891
  "total_training_time"          => 32647958292
  "average_epoch_training_time"  => 0.644896
  "final_evaluation_accuracy"    => 0.9744
  "model_name"                   => "MLP"
  "dataset"                      => "MNIST Digits"
  "average_batch_inference_time" => 0.415948

In [22]:
using JSON;

# Serialise the metrics dictionary to a JSON file in the working directory
open(io -> JSON.print(io, metrics), "m1-avalon-mlp.json", "w")