First load the packages `GeometricMachineLearning`, `LinearAlgebra`, `ProgressMeter`, `Zygote`, `Random`, `Lux` and `MLDatasets`, or just the functions needed from them. A random number generator (rng) is needed because Lux requires one to initialize the network parameters.

In [21]:
using GeometricMachineLearning
using GeometricMachineLearning: ResNet
using LinearAlgebra: norm
using ProgressMeter: @showprogress
using Zygote: gradient 
import MLDatasets
import Lux
import Random

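We train the transformer on the MNIST data set. The images in that data set are 28 $\times$ 28 pixels. For training we reshape each image into a 49 $\times$ 16 matrix: the image is split into 16 patches of size 7 $\times$ 7, and each patch is flattened into a column of length 49. The following offers a visualization: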

Here we fix the constants relating to the data set. `patch_length` describes the size of the image patches. `n_heads` is the number of heads in the multihead attention layers. The number of patches is just the original dimension of the image divided by the patch length and then squared. 

In [4]:
image_dim = 28
patch_length = 7
n_heads = 7
patch_number = (image_dim÷patch_length)^2

16
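
The reshaping of a 28 $\times$ 28 image into a 49 $\times$ 16 matrix can be sketched in plain Julia. The function `split_into_patches` below is a hypothetical stand-in written for illustration; it is not the implementation shipped with `GeometricMachineLearning`, and the order in which the patches are stored as columns is an assumption.

```julia
# Hypothetical stand-in for illustration only; the library routine may order patches differently.
function split_into_patches(image::AbstractMatrix, patch_length::Integer)
    n_rows, n_cols = size(image)
    @assert n_rows % patch_length == 0 && n_cols % patch_length == 0
    patches_per_col = n_rows ÷ patch_length
    patches_per_row = n_cols ÷ patch_length
    out = zeros(eltype(image), patch_length^2, patches_per_col * patches_per_row)
    k = 0
    # walk over the patch grid; the row index of the patch varies fastest
    for j in 1:patches_per_row, i in 1:patches_per_col
        k += 1
        rows = (i - 1) * patch_length + 1 : i * patch_length
        cols = (j - 1) * patch_length + 1 : j * patch_length
        # flatten the patch (column-major) into the k-th column
        out[:, k] = vec(image[rows, cols])
    end
    out
end

image = reshape(Float32.(1:28*28), 28, 28)
patches = split_into_patches(image, 7)
size(patches)  # (49, 16)
```

Every pixel ends up in exactly one column, so the reshaping only reorders the data.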

We now load the data set and perform some preprocessing. The function `split_and_flatten` is part of `GeometricMachineLearning` and performs the reshaping described above. The images are also divided by 255. Leaving this out should not change much, but you can experiment with it.

In [34]:
train_x, train_y = MLDatasets.MNIST(split=:train)[:]
test_x, test_y = MLDatasets.MNIST(split=:test)[:]

# preprocessing steps (reshape each image into patches and rescale the pixel values by 1/255)
function preprocess_x(x)
    x_reshaped = zeros(Float32, patch_length^2, patch_number, size(x, 3))
    for i in axes(x, 3)
        x_reshaped[:, :, i] = split_and_flatten(x[:, :, i], patch_length)/255
    end
    x_reshaped
end

train_x_reshaped = preprocess_x(train_x);
test_x_reshaped = preprocess_x(test_x);

49×16×10000 Array{Float32, 3}:
[:, :, 1] =
 0.0  0.0  0.0  0.0  0.0         …  0.0  0.0          0.00390619  0.0
 0.0  0.0  0.0  0.0  0.0            0.0  0.0          0.00118416  0.0
 0.0  0.0  0.0  0.0  0.0            0.0  0.0          0.0         0.0
 0.0  0.0  0.0  0.0  0.0            0.0  0.0          0.0         0.0
 0.0  0.0  0.0  0.0  0.0            0.0  0.0          0.0         0.0
 0.0  0.0  0.0  0.0  0.0         …  0.0  0.000584391  0.0         0.0
 0.0  0.0  0.0  0.0  0.00129181     0.0  0.00390619   0.0         0.0
 0.0  0.0  0.0  0.0  0.0            0.0  0.0          0.00176855  0.0
 0.0  0.0  0.0  0.0  0.0            0.0  0.0          1.53787f-5  0.0
 0.0  0.0  0.0  0.0  0.0            0.0  0.0          0.0         0.0
 ⋮                               ⋱                                ⋮
 0.0  0.0  0.0  0.0  0.0         …  0.0  0.00318339   0.0         0.0
 0.0  0.0  0.0  0.0  0.0            0.0  0.000276817  0.0         0.0
 0.0  0.0  0.0  0.0  0.0            0.0  0.0     

We also need to preprocess the target data (`train_y` and `test_y`). Each label in 0, ..., 9 is turned into a vector of length 10 with a single nonzero entry. This is referred to as **one-hot encoding**.

In [6]:
function encode_y(y)
    y_encoded = zeros(Bool, 10, length(y))
    for i in axes(y,1)
        y_encoded[y[i]+1,i] = 1
    end
    y_encoded
end

train_y_encoded = encode_y(train_y);
test_y_encoded = encode_y(test_y);

10×10000 Matrix{Bool}:
 0  0  0  1  0  0  0  0  0  0  1  0  0  …  0  0  0  0  0  1  0  0  0  0  0  0
 0  0  1  0  0  1  0  0  0  0  0  0  0     0  0  0  0  0  0  1  0  0  0  0  0
 0  1  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  1  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  0  0  0  0  0  1  0  0  0
 0  0  0  0  1  0  1  0  0  0  0  0  0     0  0  0  0  0  0  0  0  0  1  0  0
 0  0  0  0  0  0  0  0  1  0  0  0  0  …  1  0  0  0  0  0  0  0  0  0  1  0
 0  0  0  0  0  0  0  0  0  0  0  1  0     0  1  0  0  0  0  0  0  0  0  0  1
 1  0  0  0  0  0  0  0  0  0  0  0  0     0  0  1  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  0  0  0  0  0  0     0  0  0  1  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  0  1  0  1  0  0  1     0  0  0  0  1  0  0  0  0  0  0  0
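
As a small check of the encoding, the following sketch encodes three made-up labels in the same way as `encode_y` and then recovers them with `argmax`. The names `labels`, `encoded` and `decoded` are introduced here for illustration only.

```julia
# three made-up labels, encoded exactly as in `encode_y`
labels = [5, 0, 4]
encoded = zeros(Bool, 10, length(labels))
for (i, label) in enumerate(labels)
    # labels start at 0, Julia indices at 1
    encoded[label + 1, i] = true
end

# the position of the single `true` entry in each column gives back the label
decoded = [argmax(col) - 1 for col in eachcol(encoded)]
```

Each column of `encoded` contains exactly one `true`, and `decoded` equals `labels`.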

Now we define several models for comparison. `Classification` is a neural network layer that takes matrices as input. Its first argument is the number of rows of the input and its second argument is the number of labels (the output dimension). The number of layers in the transformer (each consisting of multihead attention followed by a ResNet layer) is specified through $L$. The following options are used below; `add_connection` is passed to `Transformer`, the others to `Classification`:
- `use_bias`: whether to use a bias in the classification layer.
- `add_connection`: whether to use the residual connection for the multihead attention.
- `use_average`: if the input to the layer is a matrix (or a tensor) and this is set to `true`, the layer computes the average over the columns $\frac{1}{\mathtt{n\_col}}\sum_{j = 1}^{\mathtt{n\_col}}x_{i,j}$ after the linear transformation has been applied; if `false`, it takes the first column.
- `use_softmax`: if set to `true`, `softmax` is used as the nonlinearity; if `false`, an elementwise sigmoid is used.
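
To make the effect of `use_average` and `use_softmax` concrete, here is a minimal sketch of a classification head in plain Julia. The function `classify` and the helpers `softmax_vec` and `sigmoid_vec` are hypothetical and written for illustration; they are not the `Classification` layer of `GeometricMachineLearning`. In particular, applying the nonlinearity after the column reduction is an assumption.

```julia
# numerically stable softmax and elementwise sigmoid for vectors
softmax_vec(v) = (e = exp.(v .- maximum(v)); e ./ sum(e))
sigmoid_vec(v) = 1 ./ (1 .+ exp.(-v))

# Hypothetical classification head: linear map, column reduction, nonlinearity.
function classify(W, X; use_average=false, use_softmax=false)
    Y = W * X                                   # (n_labels × n_col)
    y = use_average ? vec(sum(Y, dims=2)) ./ size(Y, 2) : Y[:, 1]
    use_softmax ? softmax_vec(y) : sigmoid_vec(y)
end

W = randn(Float32, 10, 49)   # 49 rows in, 10 labels out
X = randn(Float32, 49, 16)   # one image: 16 patches of length 49

p_first = classify(W, X)                                      # first column, sigmoid
p_avg   = classify(W, X; use_average=true, use_softmax=true)  # column average, softmax
```

With `use_softmax=true` the output sums to one; with the sigmoid each entry lies in $(0, 1)$ independently of the others.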

In [11]:
L = 2

models = (

    model₀ = Lux.Chain(Tuple(map(_ -> ResNet(49, tanh), 1:L))..., Classification(patch_length^2, 10, use_bias=false, use_average=false, use_softmax=false)),

    model₁ = Lux.Chain( Transformer(patch_length^2, n_heads, L, add_connection=false, Stiefel=false),
            Classification(patch_length^2, 10, use_bias=false, use_average=false, use_softmax=false)),

    model₂ = Lux.Chain(Transformer(patch_length^2, n_heads, L, add_connection=true, Stiefel=false),
                        Classification(patch_length^2, 10, use_bias=false, use_average=false, use_softmax=false)),

    model₃ = Lux.Chain(Transformer(patch_length^2, n_heads, L, add_connection=false, Stiefel=true),
                        Classification(patch_length^2, 10, use_bias=false, use_average=false, use_softmax=false)),
                        
    model₄ = Lux.Chain(Transformer(patch_length^2, n_heads, L, add_connection=true, Stiefel=true),
                        Classification(patch_length^2, 10, use_bias=false, use_average=false, use_softmax=false))

)

Chain(
    layer_1 = MultiHeadAttention(),     # 7_203 parameters
    layer_2 = ResNet(49 => 49, tanh_fast),  # 2_450 parameters
    layer_3 = MultiHeadAttention(),     # 7_203 parameters
    layer_4 = ResNet(49 => 49, tanh_fast),  # 2_450 parameters
    layer_5 = Classification(),         # 490 parameters
)         # Total: 19_796 parameters,
          #        plus 0 states, summarysize 200 bytes.
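
The parameter counts printed above can be reproduced by hand. The sketch below assumes that each attention head holds three projection matrices (query, key and value) of size $49 \times 7$, that the ResNet layer holds a $49 \times 49$ weight plus a bias, and that the classification layer holds a single $10 \times 49$ weight (`use_bias=false`). These breakdowns are assumptions chosen to match the printed numbers, not statements about the library internals.

```julia
d = 49          # patch_length^2
n_heads = 7
n_labels = 10
L = 2

# assumed: 3 projections per head, each of size d × (d ÷ n_heads)
attention_params = 3 * n_heads * d * (d ÷ n_heads)        # 7203
# assumed: square weight matrix plus bias vector
resnet_params = d * d + d                                 # 2450
# assumed: weight matrix only, no bias
classification_params = d * n_labels                      # 490

total_params = L * (attention_params + resnet_params) + classification_params  # 19796
```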

Next we define the training routine. If you have an NVIDIA GPU (that is supported by `CUDA.jl`) you can use it here. All the functionality in `GeometricMachineLearning` has been adapted for this (using `KernelAbstractions.jl`). This probably also works with `AMDGPU.jl` and `Metal.jl`, but this would have to be tested. NOTE: during actual training you may want to suppress evaluating the total loss at each iteration, since this is done on the entire training set.

In [30]:
const num = 60000
# NOTE: `using` is only allowed at top level. Run `using CUDA` before calling this with `enable_cuda=true`.
function training(model::Lux.Chain, batch_size=32, n_epochs=0.01, o=AdamOptimizer(), enable_cuda=false)
    ps, st = enable_cuda ? Lux.setup(CUDA.device(), Random.default_rng(), model) : Lux.setup(Random.default_rng(), model)

    #scaled Frobenius norm of the difference between the network output and the targets
    function loss(ps, x, y)
        x_eval = enable_cuda ? Lux.apply(model, x |> cu, ps, st)[1] : Lux.apply(model, x, ps, st)[1]
        enable_cuda ? norm(x_eval - (y |> cu))/sqrt(size(y, 2)) : norm(x_eval - y)/sqrt(size(y, 2))
    end

    #the number of training steps is calculated based on the number of epochs and the batch size
    training_steps = ceil(Int, n_epochs*num/batch_size)
    #this records the training error
    loss_array = zeros(training_steps + 1)
    loss_array[1] = enable_cuda ? loss(ps, train_x_reshaped |> cu, train_y_encoded |> cu) : loss(ps, train_x_reshaped, train_y_encoded)

    println("initial loss: ", loss_array[1])

    #initialize the optimizer cache
    optimizer_instance = enable_cuda ? Optimizer(CUDA.device(), o, model) : Optimizer(o, model)

    @showprogress "Training network ..." for i in 1:training_steps
        #draw a mini batch (indices are sampled uniformly with replacement)
        indices = rand(1:num, batch_size)
        x_batch = enable_cuda ? (train_x_reshaped[:, :, indices] |> cu) : train_x_reshaped[:, :, indices]
        y_batch = enable_cuda ? (train_y_encoded[:, indices] |> cu) : train_y_encoded[:, indices]

        #compute the gradient using Zygote
        dp = gradient(ps -> loss(ps, x_batch, y_batch), ps)[1]

        #update the cache of the optimizer and the parameter
        optimization_step!(optimizer_instance, model, ps, dp)    

        #compute the loss at the current step
        loss_array[1+i] = enable_cuda ? loss(ps, train_x_reshaped |> cu, train_y_encoded |> cu) : loss(ps, train_x_reshaped, train_y_encoded)

    end
    println("final loss: ", loss_array[end])
    enable_cuda ? println("final test loss: ", loss(ps, test_x_reshaped |> cu, test_y_encoded |> cu),"\n") : println("final test loss: ", loss(ps, test_x_reshaped, test_y_encoded),"\n")

    loss_array
end

training (generic function with 5 methods)
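
The quantities used inside `training` can be checked on a small example. The sketch below computes the number of training steps for the settings used in this notebook, evaluates the loss on made-up predictions and targets, and draws one mini batch of indices. The names `prediction`, `target` and `toy_loss` are introduced for illustration only.

```julia
using LinearAlgebra: norm
using Random

# number of optimizer steps for 0.01 epochs of 60000 samples with batch size 32
num, batch_size, n_epochs = 60000, 32, 0.01
training_steps = ceil(Int, n_epochs * num / batch_size)   # 19

# made-up network output and one-hot targets for two samples (one per column)
prediction = Float32[0.9 0.2; 0.1 0.8]
target = Bool[1 0; 0 1]

# Frobenius norm of the error, scaled by the square root of the number of samples
toy_loss = norm(prediction - target) / sqrt(size(target, 2))   # ≈ 0.2236

# one mini batch of indices, sampled uniformly with replacement
rng = Random.MersenneTwister(0)
indices = rand(rng, 1:num, batch_size)
```

The scaling by `sqrt(size(y, 2))` makes the loss comparable between a mini batch and the full data set.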

Now for the actual training: 

In [33]:
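using Plots: plot, plot!

batch_size = 32
n_epochs = 0.01
o = AdamOptimizer(0.001f0, 0.9f0, 0.99f0, 1.0f-8)
enable_cuda = false
# this flag was used but never defined in the original cell; here it only toggles the plot
give_training_error = true

# train every model; the loss histories are stored under the same keys as in `models`
loss_arrays = NamedTuple{keys(models)}(Tuple(training(model, batch_size, n_epochs, o, enable_cuda) for model in models))

# plot the loss histories of all models in one figure
function plot_stuff(loss_arrays)
    p = plot(loss_arrays[1], label="0")
    for i in 2:length(loss_arrays)
        plot!(p, loss_arrays[i], label=string(i - 1))
    end
    p
end

give_training_error ? plot_stuff(loss_arrays) : nothing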

initial loss: 0.9486927545760359

Training network ... 100%|███████████████████████████████| Time: 0:00:53

final loss: 0.9486604864158703
final test loss: 0.9486714172363281



InterruptException: InterruptException: