In [None]:
using Random, Statistics, StatsBase
using Lux, Optimisers, Zygote
using MLDatasets, MLUtils, OneHotArrays
using Plots

In [None]:
# Seed Julia's default RNG with 0 and keep a handle to it in `rng`
# (`Random.seed!` returns the generator it seeded).
rng = Random.seed!(Random.default_rng(), 0)

In [None]:
# Mini-batch size shared by the training and evaluation loaders.
batchsize = 48

x_train, y_train = MNIST(split=:train)[:]
x_test, y_test = MNIST(split=:test)[:]

# Insert a singleton axis (dim 3) in front of the sample axis, matching the single input
# channel of the first `Conv` layer, and one-hot encode the labels over 0:9.
x_train = reshape(x_train, size(x_train, 1), size(x_train, 2), 1, size(x_train, 3))
y_train = onehotbatch(y_train, 0:9)
# Hold out 10 % of the shuffled training observations for validation.
(x_train, y_train), (x_val, y_val) = splitobs((x_train, y_train), at=0.9, shuffle=true)
x_test = reshape(x_test, size(x_test, 1), size(x_test, 2), 1, size(x_test, 3))
y_test = onehotbatch(y_test, 0:9)

# Training keeps `partial=false`, so every optimisation step sees a full batch.
load_train = DataLoader((x_train, y_train); batchsize=batchsize, shuffle=true, partial=false)
# Evaluation loaders keep the trailing short batch (`partial=true`): with `partial=false`
# the observations beyond the last full batch would never be scored.
load_val = DataLoader((x_val, y_val); batchsize=batchsize, shuffle=false, partial=true)
load_test = DataLoader((x_test, y_test); batchsize=batchsize, shuffle=false, partial=true)
# Single-observation loader for the per-sample martingale updates.
load_martingale = DataLoader((x_train, y_train); batchsize=1, shuffle=true, partial=false)
;

In [None]:
# Convolutional classifier: two Conv(5×5, relu) + MaxPool(2×2) stages, a flatten step,
# and a three-layer dense head whose last layer emits 10 outputs with no activation.
model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    # NOTE(review): presumably flattens the first 3 dims (width, height, channel) —
    # confirm against the installed Lux version.
    FlattenLayer(3),
    # 256 input features — presumably 4×4×16 for 28×28 inputs; TODO confirm.
    Chain(Dense(256 => 128, relu), Dense(128 => 84, relu), Dense(84 => 10)),
)

In [None]:
# Create the model's parameters and layer states using the seeded `rng`.
# The lone `;` presumably suppresses the notebook cell output.
params, states = Lux.setup(rng, model)
;

In [None]:
# Cross-entropy criterion configured for raw logits (`logits=Val(true)`).
const lossfn = CrossEntropyLoss(; logits=Val(true))

"""
    loss(model, params, states, x, y)

Run `model` on the batch `x` and return `(value, new_states)`: `value` is the mean of
`lossfn` applied to the model output and the targets `y`, and `new_states` is the state
returned by the forward pass.
"""
function loss(model, params, states, x, y)
    logits, updated_states = model(x, params, states)
    return mean(lossfn(logits, y)), updated_states
end

In [None]:
"""
    validation(model, params, states, loader)

Return the average `lossfn` value of `model` over all observations yielded by `loader`.

Each batch's mean loss is weighted by the number of observations in that batch (the size
of the targets' last dimension), so a shorter final batch does not skew the result. With
equally sized batches this equals the plain mean of the per-batch losses. Returns `NaN`
when `loader` yields no batches.
"""
function validation(model, params, states, loader)
    loss_sum = 0.0
    nseen = 0
    for (x, y) in loader
        yhat, _ = model(x, params, states)
        nbatch = size(y, ndims(y))
        loss_sum += nbatch * mean(lossfn(yhat, y))
        nseen += nbatch
    end

    return loss_sum / nseen
end

In [None]:
"""
    accuracy(model, params, states, loader)

Return the fraction of observations yielded by `loader` whose predicted class matches the
target class.

For every batch `(x, y)` the predicted class of an observation is the position of the
largest model output along the first dimension, and the target class is the position of
the largest entry of `y` along the first dimension (the hot entry for one-hot targets).
Correct predictions are counted over all batches and divided by the total number of
observations, so batches of unequal size are weighted correctly. Returns `NaN` when
`loader` yields no batches.
"""
function accuracy(model, params, states, loader)
    ncorrect = 0
    nseen = 0
    for (x, y) in loader
        yhat, _ = model(x, params, states)
        # Compare arg-max positions directly; this does not depend on the label values.
        ncorrect += count(argmax(yhat; dims=1) .== argmax(y; dims=1))
        nseen += size(y, ndims(y))
    end

    return ncorrect / nseen
end

In [None]:
# Adam with a learning rate of 1e-4 multiplied by the batch size, and the optimiser
# state set up for `params`.
opt = Optimisers.Adam(1e-4 * batchsize)
opt_state = Optimisers.setup(opt, params)
;

In [None]:
"""
    train!(model, params, states, opt_state, load_train, epochs;
           patience=3, val_loader=load_val, test_loader=load_test)

Train `model` for at most `epochs` passes over `load_train`, with early stopping on the
validation loss.

After every epoch the mean training loss, the validation loss on `val_loader` and the
accuracy on `test_loader` are printed and recorded. Whenever the validation loss improves,
a deep copy of `(params, states, opt_state)` is stored as the best checkpoint. Training
stops early once the validation loss has failed to improve for `patience` consecutive
epochs.

The best checkpoint is what gets returned, also when all `epochs` complete without early
stopping, so the returned parameters always correspond to the reported best epoch.

# Keywords
- `patience`: number of consecutive non-improving epochs tolerated before stopping.
- `val_loader`, `test_loader`: loaders used for the per-epoch validation loss and test
  accuracy; they default to the notebook globals `load_val` and `load_test`.

# Returns
`(params, states, opt_state, best_epoch, (losses_train, losses_val, accs_test))`, where the
three vectors hold one entry per completed epoch.
"""
function train!(
    model, params, states, opt_state, load_train, epochs;
    patience=3, val_loader=load_val, test_loader=load_test
)
    time_start = time()
    best_loss_val = Inf
    patience_ctr = 0
    best_epoch = 1
    best_checkpoint = deepcopy((params, states, opt_state))
    losses_train = Float64[]
    losses_val = Float64[]
    accs_test = Float64[]

    for epoch in 1:epochs
        time_0 = time()
        loss_train = 0.0
        # Each fresh iteration over the DataLoader reshuffles the data.
        for (xb, yb) in load_train
            # Single-assignment alias: the closure must not capture `states`, which is
            # reassigned below (captured reassigned variables get boxed).
            st = states
            (ls, states), back = Zygote.pullback(p -> loss(model, p, st, xb, yb), params)

            # Seed the pullback with 1 for the loss value and nothing for the states.
            grads = back((one(ls), nothing))[1]
            opt_state, params = Optimisers.update(opt_state, params, grads)
            loss_train += ls
        end
        time_train = time() - time_0
        loss_train = loss_train / length(load_train)
        loss_val = validation(model, params, states, val_loader)
        time_val = time() - time_0 - time_train
        acc_test = accuracy(model, params, states, test_loader)
        time_test = time() - time_0 - time_train - time_val

        println("Epoch $(epoch)" *
            "\n\tloss_train = $(loss_train)" * "\ttime_train = $(round(time_train, digits=4)) s" *
            "\n\tloss_val   = $(loss_val)  " * "\ttime_val   = $(round(time_val, digits=4)) s" *
            "\n\tacc_test   = $(acc_test)  " * "\ttime_test  = $(round(time_test, digits=4)) s"
        )
        push!(losses_train, loss_train)
        push!(losses_val, loss_val)
        push!(accs_test, acc_test)

        if loss_val < best_loss_val
            best_loss_val = loss_val
            best_checkpoint = deepcopy((params, states, opt_state))
            patience_ctr = 0
            best_epoch = epoch
        else
            patience_ctr += 1
        end

        if patience_ctr >= patience
            println("Early stopping!")
            break
        end
    end

    # Always hand back the best checkpoint so that the returned parameters match
    # `best_epoch`, whether or not early stopping fired.
    params, states, opt_state = best_checkpoint

    time_total = round(time() - time_start, digits=4)
    println("Best epoch: $(best_epoch)")
    println("Total time: $(time_total) s")
    return params, states, opt_state, best_epoch, (losses_train, losses_val, accs_test)
end

In [None]:
# Training budget: `epochs` is the integer quotient of `nsteps` and `batchsize`.
# NOTE(review): dividing a step count by the batch size does not yield a number of passes
# over the data; because `train!` stops early this acts as an upper bound — confirm the
# intended budget.
nsteps = 50_000
epochs = div(nsteps, batchsize)
println("nsteps = $(nsteps),\tbatchsize = $(batchsize),\tepochs = $(epochs)")
# Rebind the notebook globals to the values returned by training; `losses` is the tuple
# (training losses, validation losses, test accuracies).
params, states, opt_state, epoch, losses = train!(
    model, params, states, opt_state, load_train, epochs
)
;

In [None]:
# Test-set accuracy of the parameters returned by training.
acc_test = accuracy(model, params, states, load_test)
println("Final test accuracy = $(acc_test)")

In [None]:
# Two-panel summary figure: loss curves on the left, test accuracy on the right.
colors = palette(:default)[1:10]
plt = plot(; size=(1000, 500), layout=(1, 2), margin=5Plots.mm)

# Left panel: training and validation loss, with a dashed vertical line at `epoch`.
plot!(plt, losses[1]; subplot=1, label="training loss", color=colors[1])
plot!(plt, losses[2]; subplot=1, label="validation loss", color=colors[2])
vline!(plt, [epoch]; subplot=1, label="early stopping", color=colors[2], linestyle=:dash)
title!(plt, "Loss during training"; subplot=1)
ylabel!(plt, "loss"; subplot=1)

# Right panel: recorded test accuracy, with a dashed horizontal line at `acc_test`.
plot!(plt, losses[3]; subplot=2, label="test accuracy", color=colors[3])
hline!(plt, [acc_test]; subplot=2, label="early stopping", color=colors[3], linestyle=:dash)
title!(plt, "Final accuracy"; subplot=2)
ylabel!(plt, "accuracy"; subplot=2)

# Both panels share the same x-axis label.
for panel in 1:2
    xlabel!(plt, "epochs"; subplot=panel)
end
display(plt)

In [None]:
"""
    martingale_bootstrap!(model, params, states, load_martingale, niter)

Apply up to `niter` single-batch update steps to `params`, drawing batches `(xb, yb)` from
`load_martingale`, and return the updated parameters.

Step `i` (counting from 1) uses the step size `ϵ = 1 / (100 + i)` and replaces every
parameter array `ps` by `ps .+ ϵ .* gs`, where `gs` is the gradient of `loss` with respect
to `ps`. The step size is converted to the element type of each parameter array so that
e.g. `Float32` parameters are not promoted to `Float64` by the update.

Fewer than `niter` steps are taken if the loader is exhausted first, and none when
`niter <= 0`. `niter` must be an integer. The layer states are threaded through the steps,
but only the parameters are returned; the `params` passed in is rebound locally, not
overwritten.
"""
function martingale_bootstrap!(model, params, states, load_martingale, niter)
    # Each fresh iteration over the DataLoader reshuffles the data.
    for (iter, (xb, yb)) in enumerate(Iterators.take(load_martingale, max(niter, 0)))
        # Single-assignment alias so the closure does not capture the reassigned `states`.
        st = states
        (ls, states), back = Zygote.pullback(p -> loss(model, p, st, xb, yb), params)

        grads = back((one(ls), nothing))[1]
        ϵ = 1.0 / (100.0 + iter)
        # NOTE(review): this step moves *up* the gradient of `loss` (`+ ϵ .* gs`). If the
        # intended update follows the score ∇ log p(y | x, θ) and `loss` is a negative
        # log-likelihood, the sign should be `-` — confirm before relying on the results.
        params = Lux.fmap(params, grads) do ps, gs
            ps .+ convert(eltype(ps), ϵ) .* gs
        end
    end

    return params
end

In [None]:
"""
    martingale_posterior!(model, params, states, load_martingale, niter)

Apply up to `niter` update steps to `params` using labels sampled from the model itself,
and return the updated parameters.

For every batch drawn from `load_martingale` the observed labels are discarded. Instead,
the model output is turned into class probabilities with `softmax` along the first
dimension, one label in `0:9` is sampled per observation from those probabilities, and the
one-hot encoding of the sampled labels is used as the target of `loss`.

Step `i` (counting from 1) uses the step size `ϵ = 1 / (100 + i)` and replaces every
parameter array `ps` by `ps .+ ϵ .* gs`, where `gs` is the gradient of `loss` with respect
to `ps`. The step size is converted to the element type of each parameter array so that
e.g. `Float32` parameters are not promoted to `Float64` by the update.

Fewer than `niter` steps are taken if the loader is exhausted first, and none when
`niter <= 0`. `niter` must be an integer. Only the parameters are returned; the `params`
passed in is rebound locally, not overwritten.
"""
function martingale_posterior!(model, params, states, load_martingale, niter)
    # Each fresh iteration over the DataLoader reshuffles the data.
    for (iter, (xb, _)) in enumerate(Iterators.take(load_martingale, max(niter, 0)))
        # Draw one label per observation from the current predictive distribution.
        yhat, _ = model(xb, params, states)
        probs = softmax(yhat; dims=1)
        preds = [sample(0:9, Weights(probs[:, j])) for j in axes(probs, 2)]
        # Encoding a vector of labels yields a one-hot matrix with one column per
        # observation, the same layout as the model output.
        yb = onehotbatch(preds, 0:9)

        # Single-assignment alias so the closure does not capture the reassigned `states`.
        st = states
        (ls, states), back = Zygote.pullback(p -> loss(model, p, st, xb, yb), params)

        grads = back((one(ls), nothing))[1]
        ϵ = 1.0 / (100.0 + iter)
        # NOTE(review): this step moves *up* the gradient of `loss` (`+ ϵ .* gs`). If the
        # intended update follows the score ∇ log p(y | x, θ) and `loss` is a negative
        # log-likelihood, the sign should be `-` — confirm before relying on the results.
        params = Lux.fmap(params, grads) do ps, gs
            ps .+ convert(eltype(ps), ϵ) .* gs
        end
    end

    return params
end

In [None]:
# Run each update scheme for up to 200 steps. Both calls receive the same `params`, and
# the results are bound to new names, so `params` itself is left as it was.
params_1 = martingale_bootstrap!(model, params, states, load_martingale, 200)
params_2 = martingale_posterior!(model, params, states, load_martingale, 200)
;

In [None]:
# Test accuracy of the trained parameters (`init`) and of the two martingale-updated
# parameter sets (`boot`, `post`), all evaluated with the same layer states.
acc_test, acc_test_1, acc_test_2 = map((params, params_1, params_2)) do ps
    accuracy(model, ps, states, load_test)
end
println("Accuracy:\n\tinit = $(acc_test)\n\tboot = $(acc_test_1)\n\tpost = $(acc_test_2)")

In [None]:
using LinearAlgebra

In [None]:
# Query the number of threads the BLAS library is currently using.
BLAS.get_num_threads()