In [1]:
using Flux, Statistics
using Parameters: @with_kw

using Printf
using MLDatasets

# Training hyper-parameters. `@with_kw` (Parameters.jl) generates a keyword
# constructor from the defaults below, so e.g. `Args(epochs = 5)` overrides one field.
@with_kw mutable struct Args
    learning_rate::Float64 = 0.0003  # step size handed to the optimizer
    batch_size::Int = 1024           # samples per DataLoader batch
    epochs::Int = 10                 # number of passes over the training set
    verbose::Bool = true             # print dataset statistics while loading
end

Args

In [4]:
"""
    load_data(args)

Load the MNIST train/test splits as `Float32`, flatten each image into a feature
column, one-hot encode the digit labels over `0:9`, and wrap each split in a
`Flux.Data.DataLoader` with `args.batch_size` samples per batch. Only the training
loader is shuffled. When `args.verbose` is true, the sample count and feature size
of each split are printed.

Returns `(train_data, test_data)`.
"""
function load_data(args)
    # Raw splits
    Xtrain, ytrain = MLDatasets.MNIST.traindata(Float32)
    Xtest, ytest = MLDatasets.MNIST.testdata(Float32)

    # Images -> feature columns, labels -> one-hot columns
    Xtrain = Flux.flatten(Xtrain)
    Xtest = Flux.flatten(Xtest)
    ytrain = Flux.onehotbatch(ytrain, 0:9)
    ytest = Flux.onehotbatch(ytest, 0:9)

    if args.verbose
        # After flattening, dimension 2 counts samples and dimension 1 counts features.
        @printf "--Loaded Dataset statistics--\n"
        @printf "Training data amount :   %d\n" size(Xtrain, 2)
        @printf "Training data size   :   %d\n" size(Xtrain, 1)
        @printf "-----------------------------\n"
        @printf "Testing data amount  :   %d\n" size(Xtest, 2)
        @printf "Testing data size    :   %d\n" size(Xtest, 1)
    end

    train_loader = Flux.Data.DataLoader(Xtrain, ytrain;
                                        batchsize = args.batch_size, shuffle = true)
    test_loader = Flux.Data.DataLoader(Xtest, ytest; batchsize = args.batch_size)
    return train_loader, test_loader
end

"""
    loss(model)

Return a two-argument closure `(x, y) -> Flux.logitcrossentropy(model(x), y)`,
suitable as the loss argument of `Flux.train!` or `loss_all`.
"""
loss(model) = (x, y) -> Flux.logitcrossentropy(model(x), y)

"""
    loss_all(data, loss_func)

Return the sample-weighted mean of `loss_func(x, y)` over every batch `(x, y)`
yielded by `data`.

Each batch's loss is multiplied by its sample count `size(x, 2)` (samples are
columns here, as produced by `Flux.flatten`). This keeps a shorter final batch from
being over-weighted, so the result is the loss over the whole dataset rather than
a mean of batch means. Assumes `loss_func` returns a per-sample mean over the batch,
as `loss(model)` (`Flux.logitcrossentropy`) does with its default aggregation;
confirm this if you pass a different loss.

Works with any iterable of `(x, y)` pairs; `length(data)` is not required.

Throws `ArgumentError` if `data` yields no samples.
"""
function loss_all(data, loss_func)
    total_loss = 0f0   # named to avoid shadowing the module-level `loss` function
    nsamples = 0
    for (x, y) in data
        n = size(x, 2)
        total_loss += loss_func(x, y) * n
        nsamples += n
    end
    nsamples > 0 || throw(ArgumentError("loss_all: `data` contains no samples"))
    return total_loss / nsamples
end

"""
    accuracy(data, model)

Return the fraction (in `[0, 1]`) of samples in `data` that `model` classifies
correctly.

For each batch `(x, y)`, the predicted class is `Flux.onecold` of the model output
and the true class is `Flux.onecold` of the one-hot labels `y`. Both are moved to
the CPU first with `cpu`. Correct predictions and samples are counted across all
batches and divided once at the end. This gives the exact overall accuracy even
when the last batch is smaller, unlike averaging per-batch accuracies.

Throws `ArgumentError` if `data` yields no samples.
"""
function accuracy(data, model)
    ncorrect = 0   # Int accumulators keep the loop type-stable
    ntotal = 0
    for (x, y) in data
        predicted  = Flux.onecold(cpu(model(x)))
        true_label = Flux.onecold(cpu(y))
        ncorrect += count(predicted .== true_label)
        ntotal += length(true_label)
    end
    ntotal > 0 || throw(ArgumentError("accuracy: `data` contains no samples"))
    return ncorrect / ntotal
end

accuracy (generic function with 1 method)

In [5]:
# Hyper-parameters, taken from the Args defaults
args = Args()

# Data: shuffled training batches and unshuffled test batches
train_data, test_data = load_data(args)

# Two-layer MLP: 28*28 flattened pixels -> 32 ReLU units -> 10 class logits
model = Chain(Dense(28 * 28, 32, relu), Dense(32, 10))

# Report the loss over the full training set, at most once every 5 seconds
evalcb = () -> @show(loss_all(train_data, loss(model)))
optimizer = ADAM(args.learning_rate)

Flux.@epochs args.epochs Flux.train!(loss(model), params(model), train_data, optimizer;
                                      cb = Flux.throttle(evalcb, 5))

--Loaded Dataset statistics--
Training data amount :   60000
Training data size   :   784
-----------------------------
Testing data amount  :   10000
Testing data size    :   784


┌ Info: Epoch 1
└ @ Main C:\Users\alona\.julia\packages\Flux\Fj3bt\src\optimise\train.jl:121


loss_all(train_data, loss(model)) = 2.3231714f0


┌ Info: Epoch 2
└ @ Main C:\Users\alona\.julia\packages\Flux\Fj3bt\src\optimise\train.jl:121


loss_all(train_data, loss(model)) = 1.337967f0


┌ Info: Epoch 3
└ @ Main C:\Users\alona\.julia\packages\Flux\Fj3bt\src\optimise\train.jl:121


loss_all(train_data, loss(model)) = 0.810371f0


┌ Info: Epoch 4
└ @ Main C:\Users\alona\.julia\packages\Flux\Fj3bt\src\optimise\train.jl:121


loss_all(train_data, loss(model)) = 0.6032391f0


┌ Info: Epoch 5
└ @ Main C:\Users\alona\.julia\packages\Flux\Fj3bt\src\optimise\train.jl:121


loss_all(train_data, loss(model)) = 0.49972752f0


┌ Info: Epoch 6
└ @ Main C:\Users\alona\.julia\packages\Flux\Fj3bt\src\optimise\train.jl:121


loss_all(train_data, loss(model)) = 0.43774837f0
loss_all(train_data, loss(model)) = 0.39761862f0


┌ Info: Epoch 7
└ @ Main C:\Users\alona\.julia\packages\Flux\Fj3bt\src\optimise\train.jl:121
┌ Info: Epoch 8
└ @ Main C:\Users\alona\.julia\packages\Flux\Fj3bt\src\optimise\train.jl:121


loss_all(train_data, loss(model)) = 0.36809814f0


┌ Info: Epoch 9
└ @ Main C:\Users\alona\.julia\packages\Flux\Fj3bt\src\optimise\train.jl:121


loss_all(train_data, loss(model)) = 0.34596434f0


┌ Info: Epoch 10
└ @ Main C:\Users\alona\.julia\packages\Flux\Fj3bt\src\optimise\train.jl:121


loss_all(train_data, loss(model)) = 0.32793108f0


In [11]:
# Report accuracy as a percentage, training split first, then test split
train_pct = accuracy(train_data, model) * 100
@printf "Model accuracy on training data:   %0.3f \n" train_pct
test_pct = accuracy(test_data, model) * 100
@printf "Model accuracy on testing data :   %0.3f \n" test_pct

Model accuracy on training data:   91.517 
Model accuracy on testing data :   91.504 
