# Imports

In [1]:
using Yota;
using MLDatasets;
using NNlib;
using Statistics;
using Distributions;
using Functors;
using Optimisers;
using MLUtils: DataLoader;
using OneHotArrays: onehotbatch
using Metrics;

# Primitives

## Linear 

In [2]:
"""
    Linear(W, b)

Fully connected (affine) layer holding a weight matrix `W` of size `(out, in)` and a bias
vector `b` of length `out`.

The concrete array types are captured as type parameters `M` and `V`, so reading `l.W`/`l.b`
is concretely typed; abstract field types would make every access in the forward pass
type-unstable. The struct stays `mutable` so existing code that reassigns fields keeps working.
"""
mutable struct Linear{M<:AbstractMatrix, V<:AbstractVector}
    W::M
    b::V
end

# Register Linear with Functors so model-tree walkers (e.g. Optimisers.setup/update!)
# can reach and rebuild its W and b fields.
Functors.@functor Linear

# Init
"""
    Linear(in_features::Integer, out_features::Integer; T=Float64)
    Linear(in_features => out_features; T=Float64)

Build a randomly initialised `Linear` layer with `W` of size `(out_features, in_features)`
and `b` of length `out_features`.

Every entry is drawn from `Uniform(-k, k)` with `k = sqrt(1 / in_features)` and then
converted to element type `T` (default `Float64`). Pass e.g. `T=Float32` when the inputs
are `Float32`, so that `W * x` does not mix element types.

Throws `ArgumentError` if `in_features` is not positive, because the init scale
`sqrt(1 / in_features)` would then be infinite or undefined.
"""
function Linear(in_features::Integer, out_features::Integer;
                T::Type{<:AbstractFloat}=Float64)
    in_features > 0 ||
        throw(ArgumentError("in_features must be positive, got $in_features"))
    k_sqrt = sqrt(1 / in_features)
    d = Uniform(-k_sqrt, k_sqrt)
    # W is drawn before b, so the RNG stream is consumed in a fixed order.
    W = T.(rand(d, out_features, in_features))
    b = T.(rand(d, out_features))
    return Linear(W, b)
end
Linear(in_out::Pair{<:Integer, <:Integer}; kwargs...) =
    Linear(in_out.first, in_out.second; kwargs...)

# Compact one-line display, e.g. `Linear(784=>100)`: input features => output features.
function Base.show(io::IO, l::Linear)
    i = size(l.W, 2)  # columns of W = input features
    o = size(l.W, 1)  # rows of W = output features
    print(io, "Linear($i=>$o)")
    return nothing
end

# Forward
# Affine map W*x + b. For a matrix input the bias is broadcast across its columns.
function (layer::Linear)(x::AbstractVecOrMat)
    return layer.W * x .+ layer.b
end

## Logit Cross Entropy

In [3]:
"""
    logitcrossentropy(ŷ, y; dims=1, agg=mean)

Cross-entropy loss computed directly from raw logits `ŷ` against targets `y`.

Steps:
1. Apply `logsoftmax` to `ŷ` along `dims`.
2. Multiply the result elementwise by `y`, sum along `dims`, and negate. This gives one
   loss value per slice.
3. Reduce those values with `agg` (default `mean`).
"""
function logitcrossentropy(ŷ, y; dims=1, agg=mean)
    log_probs = logsoftmax(ŷ; dims=dims)
    per_slice = .-sum(y .* log_probs; dims=dims)
    return agg(per_slice)
end

logitcrossentropy (generic function with 1 method)

# Define the model

In [4]:
# Two-layer MLP classifier: `fc1` -> ReLU -> `fc2` (wired together in the forward method).
mutable struct Net
    fc1::Linear  # first (hidden) layer
    fc2::Linear  # output layer producing the logits
end

# Register Net with Functors; Optimisers.setup/update! rely on this to traverse
# into fc1 and fc2.
Functors.@functor Net

# Init
function Net()
    hidden = Linear(28 * 28, 100)  # flattened 28×28 input -> 100 hidden units
    output = Linear(100, 10)       # 100 hidden units -> 10 class scores
    return Net(hidden, output)
end

# Forward
function (model::Net)(x::AbstractArray)
    # Flatten into 784-row columns; assumes each sample holds 28*28 values.
    features = reshape(x, 28 * 28, :)
    activations = relu(model.fc1(features))
    return model.fc2(activations)
end

# Data

In [5]:
# Directory holding the MNIST files. Set the MNIST_DIR environment variable to point
# elsewhere instead of editing this machine-specific default.
mnist_dir = get(ENV, "MNIST_DIR", "/Users/trevoryu/Code/data/mnist");
train_dataset = MNIST(dir=mnist_dir, split=:train);
test_dataset = MNIST(dir=mnist_dir, split=:test);

X_train = train_dataset.features;
Y_train = train_dataset.targets;

X_test = test_dataset.features;
Y_test = test_dataset.targets;

In [6]:
X_test |> size  # inspect the raw test-feature array shape

(28, 28, 10000)

In [7]:
# Flatten features to be 784 dim
# Turn each 28×28 image into a 784-long column: arrays become (784 × n_samples).
X_train, X_test = map(X -> reshape(X, 28 * 28, :), (X_train, X_test));

In [8]:
# Convert targets to one-hot vectors
# Encode digit labels 0..9 as one-hot columns: (10 × n_samples).
Y_train, Y_test = map(Y -> onehotbatch(Y, 0:9), (Y_train, Y_test));

In [9]:
# Mini-batches of 128 columns; only the training set is reshuffled.
train_loader = DataLoader((X_train, Y_train); batchsize=128, shuffle=true);
test_loader = DataLoader((X_test, Y_test); batchsize=128, shuffle=false);

# Training setup

In [10]:
# Make model
mlp = Net()  # freshly initialised 784=>100=>10 network; the cell echoes its layers

Net(Linear(784=>100), Linear(100=>10))

In [11]:
# Adam with learning rate 0.001; the default β is (0.9, 0.999).
# `setup` builds the per-parameter optimiser state tree mirroring `mlp`.
state = Optimisers.setup(Optimisers.Adam(0.001), mlp);

In [12]:
# Create objective function to optimize
# Training objective: cross-entropy of the model's logits for `x` against the
# one-hot targets `y`. This is the function handed to Yota's `grad`.
loss_function(model::Net, x::AbstractArray, y::AbstractArray) =
    logitcrossentropy(model(x), y)

loss_function (generic function with 1 method)

# Training loop

In [13]:
# Train for 10 epochs, logging the mean mini-batch loss and timing each epoch.
for epoch in 1:10
    losses = Float64[]
    @time begin
        for (x, y) in train_loader
            # We rebind the optimiser state and model below, so declare them global
            # explicitly; otherwise a non-interactive run would create soft-scope locals.
            global state, mlp

            # loss_function does the forward pass. Yota's `grad` returns the loss and
            # one gradient per argument (g[1] is for the function itself), so the
            # model's parameter gradients are in g[2].
            loss, g = grad(loss_function, mlp, x, y)

            # `update!` may mutate arrays in place, but Optimisers documents that the
            # *returned* state and model are the ones to keep using.
            state, mlp = Optimisers.update!(state, mlp, g[2])
            push!(losses, loss)
        end
        @info("epoch $epoch loss = $(mean(losses))")
    end
end

 25.837786 seconds (157.79 M allocations: 10.050 GiB, 3.78% gc time, 91.53% compilation time: 0% of which was recompilation)


┌ Info: epoch 1 loss = 0.4370590642396691
└ @ Main In[13]:13


  0.757983 seconds (163.01 k allocations: 1.951 GiB, 11.52% gc time)


┌ Info: epoch 2 loss = 0.20768273653621058
└ @ Main In[13]:13


  0.755957 seconds (163.01 k allocations: 1.951 GiB, 9.69% gc time)


┌ Info: epoch 3 loss = 0.15264513650689454
└ @ Main In[13]:13


  0.766106 seconds (163.01 k allocations: 1.951 GiB, 11.41% gc time)


┌ Info: epoch 4 loss = 0.12068533624691201
└ @ Main In[13]:13


  0.790535 seconds (163.01 k allocations: 1.951 GiB, 11.73% gc time)


┌ Info: epoch 5 loss = 0.09948631426212774
└ @ Main In[13]:13


  0.749042 seconds (163.01 k allocations: 1.951 GiB, 10.17% gc time)


┌ Info: epoch 6 loss = 0.08462848175017493
└ @ Main In[13]:13


  0.763076 seconds (163.01 k allocations: 1.951 GiB, 11.58% gc time)


┌ Info: epoch 7 loss = 0.07293157247018839
└ @ Main In[13]:13


  0.737009 seconds (163.01 k allocations: 1.951 GiB, 8.39% gc time)


┌ Info: epoch 8 loss = 0.06216332802989073
└ @ Main In[13]:13


  0.724461 seconds (163.02 k allocations: 1.951 GiB, 8.17% gc time)


┌ Info: epoch 9 loss = 0.05457949791791655
└ @ Main In[13]:13


  0.742743 seconds (163.01 k allocations: 1.951 GiB, 9.23% gc time)


┌ Info: epoch 10 loss = 0.047562408875107556
└ @ Main In[13]:13


# Evaluation

In [14]:
"""
    evaluate(mlp, test_loader)

Return the classification accuracy of `mlp` over every `(x, y)` batch in `test_loader`.

Assumes `mlp(x)` returns class scores and `y` holds one-hot targets, both with one column
per sample. A column's predicted and true class is the row index of its maximum. The loop
over batches is timed with `@time`.

Returns `NaN` (0/0) if `test_loader` yields no samples.
"""
function evaluate(mlp, test_loader)
    correct = 0
    total = 0
    @time begin
        for (x, y) in test_loader
            predicted = _column_argmax(mlp(x))
            actual = _column_argmax(y)
            correct += count(predicted .== actual)
            total += length(actual)
        end
    end
    return correct / total
end

# Row index of the maximum in each column of `A`. `argmax(A; dims=1)` yields an array of
# CartesianIndex, so keep only the first (row) coordinate of each, flattened to a vector.
_column_argmax(A) = vec(map(idx -> idx[1], argmax(A; dims=1)))

evaluate (generic function with 1 method)

In [17]:
evaluate(mlp, test_loader)  # test-set accuracy: fraction of correctly classified samples

  0.068577 seconds (3.42 k allocations: 115.477 MiB, 27.36% gc time)


0.9766

In [16]:
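# Avg time per train epoch (epochs 2–10, excluding the compile-dominated first epoch): 0.753 sec
# Total time for train (all 10 epochs, including compilation): 32.6 sec
# Avg time per eval step (one test batch): 0.77 ms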