In [None]:
using DiffEqFlux, OrdinaryDiffEq, Flux, NNlib, MLDataUtils, Printf
using Flux: logitcrossentropy
using Flux.Data: DataLoader
using MLDatasets
using CUDA
# Disallow scalar (element-by-element) indexing of GPU arrays so that an
# accidental slow fallback raises an error instead of running silently.
CUDA.allowscalar(false)

┌ Info: Precompiling DiffEqFlux [aae7a2af-3d4f-5e19-a356-7da93b79d9d0]
└ @ Base loading.jl:1278
┌ Info: Precompiling MLDataUtils [cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d]
└ @ Base loading.jl:1278
┌ Info: Precompiling MLDatasets [eb30cadb-4394-5ae3-aed4-317e484a6458]
└ @ Base loading.jl:1278


In [None]:
"""
    loadmnist(batchsize = bs, train_val_split = 0.8333)

Load MNIST, one-hot encode the labels and return a `(train, test)` tuple of
`DataLoader`s.

# Arguments
- `batchsize`: mini-batch size used by both loaders. The default reads the
  global `bs`, which must be defined by the time the function is called
  without arguments.
- `train_val_split`: fraction of the observations assigned to the first
  (training) loader by the stratified split; the remainder goes to the second.

# Returns
A tuple `(train_loader, test_loader)`. The training loader shuffles, the test
loader does not. Both hold their data after it has been passed through `gpu`.
"""
function loadmnist(batchsize = bs, train_val_split = 0.8333)
    # Use MLDataUtils LabelEnc for natural onehot conversion
    onehot(labels_raw) = convertlabel(LabelEnc.OneOfK, labels_raw,
                                      LabelEnc.NativeLabels(collect(0:9)))
    # Load MNIST
    # NOTE(review): this reads `MNIST.testdata()`, not the training split --
    # confirm that the smaller set is intended here.
    imgs, labels_raw = MNIST.testdata();
    # Process images into (H,W,C,BS) batches
    x_data = Float32.(reshape(imgs, size(imgs,1), size(imgs,2), 1, size(imgs,3)))
    y_data = onehot(labels_raw)
    # Split by the caller-supplied fraction. The previous code read the global
    # `train_split` here, so the `train_val_split` argument was silently ignored.
    (x_train, y_train), (x_test, y_test) = stratifiedobs((x_data, y_data),
                                                         p = train_val_split)
    return (
        # Use Flux's DataLoader to automatically minibatch and shuffle the data
        DataLoader(gpu.(collect.((x_train, y_train))); batchsize = batchsize,
                   shuffle = true),
        # Don't shuffle the test data
        DataLoader(gpu.(collect.((x_test, y_test))); batchsize = batchsize,
                   shuffle = false)
    )
end

In [None]:
# Main
# Mini-batch size and split fraction passed to `loadmnist` below.
const bs = 128
const train_split = 0.9
(train_dataloader, test_dataloader) = loadmnist(bs, train_split)

# Input stage: flatten each image and map its 784 values to 20 features (tanh).
down = gpu(Chain(flatten, Dense(784, 20, tanh)))

In [None]:

# Network used as the ODE right-hand side: dense tanh layers 20 -> 10 -> 10 -> 20.
nn = gpu(Chain(Dense(20, 10, tanh),
               Dense(10, 10, tanh),
               Dense(10, 20, tanh)))

# Neural ODE layer around `nn`, solved with Tsit5 over the time span (0, 100).
# Saving of the start point and of every intermediate step is switched off.
nn_ode = gpu(NeuralODE(nn, (0.f0, 100.f0), Tsit5();
                       save_everystep = false,
                       reltol = 1e-3, abstol = 1e-3,
                       save_start = false));

In [None]:
# Output stage: a single dense layer mapping the 20 features to 10 class scores.
fc = gpu(Chain(Dense(20, 10)));

In [None]:

# Turn the ODE layer's output into a plain 2-D array for the layers that follow:
# pass it through `gpu`, then reshape it to its first two dimensions.
function DiffEqArray_to_Array(x)
    device_array = x |> gpu
    leading_dims = size(device_array)[1:2]
    return reshape(device_array, leading_dims)
end


In [None]:

# Assemble the full model: input stage, neural ODE, array conversion, output stage.
model = gpu(Chain(down, nn_ode, DiffEqArray_to_Array, fc));

In [None]:
# Slice out the first observation, keeping the batch dimension (size 1), for
# quick checks of the forward pass.
img = train_dataloader.data[1][:, :, :, 1:1]
lab = train_dataloader.data[2][:, 1:1];

In [None]:
# Run the single-sample batch through the input stage alone.
x_d = img |> down;

In [None]:
# Check that the ODE right-hand-side network accepts the down-sampled features.
x_d |> nn;

In [79]:

# Forward pass through the complete model, including the neural ODE layer,
# on the single-sample batch.
x_m = img |> model;

In [80]:
# Index of the largest entry in each column of `x`, one per column.
classify(x) = [argmax(column) for column in eachcol(x)]

"""
    accuracy(model, data; n_batches = 100)

Return the fraction of correctly classified observations over at most the
first `n_batches` batches of `data`.

# Arguments
- `model`: callable applied to each input batch `x`; its output is moved to the
  CPU and reduced to one predicted class per column with `classify`.
- `data`: iterable yielding `(x, y)` batches.

# Keywords
- `n_batches`: maximum number of batches to evaluate.

# Returns
`total_correct / total` as a floating-point number. The result is `NaN` when no
batch is evaluated (`0 / 0`).
"""
function accuracy(model, data; n_batches = 100)
    total_correct = 0
    total = 0
    # Iterate the loader lazily. Wrapping it in `collect` would materialize
    # every batch of the dataset before the early exit below could apply.
    for (i, (x, y)) in enumerate(data)
        # Only evaluate accuracy for n_batches
        i > n_batches && break
        target_class = classify(cpu(y))
        predicted_class = classify(cpu(model(x)))
        total_correct += sum(target_class .== predicted_class)
        total += length(target_class)
    end
    return total_correct / total
end

# Burn-in: evaluate accuracy once before training starts.
accuracy(model, train_dataloader)

# Training objective: logit cross-entropy between the model output and targets.
function loss(x, y)
    return logitcrossentropy(model(x), y)
end

# Burn-in: evaluate the loss once on the single-sample batch.
loss(img, lab)

# Optimiser and the global iteration counter incremented by the callback.
opt = ADAM(0.05)
iter = 0

# Training callback: bump the global iteration counter and, on iterations
# 1, 11, 21, ..., print train/test accuracy to confirm that the weights do in
# fact update.
function cb()
    global iter += 1
    # Skip reporting on every iteration that is not 1 modulo 10.
    iter % 10 == 1 || return nothing
    train_pct = 100 * accuracy(model, train_dataloader)
    # The test set is evaluated in full rather than on the default 100 batches.
    test_pct = 100 * accuracy(model, test_dataloader;
                              n_batches = length(test_dataloader))
    @printf("Iter: %3d || Train Accuracy: %2.3f || Test Accuracy: %2.3f\n",
            iter, train_pct, test_pct)
    return nothing
end

# Train on the training loader. The optimised parameters are those of `down`,
# the neural ODE parameter vector `nn_ode.p`, and `fc`; `cb` runs as the
# training callback.
Flux.train!(loss, params(down, nn_ode.p, fc), train_dataloader, opt; cb = cb)

Iter:   1 || Train Accuracy: 9.377 || Test Accuracy: 9.710
Iter:  11 || Train Accuracy: 9.743 || Test Accuracy: 9.710
Iter:  21 || Train Accuracy: 9.799 || Test Accuracy: 9.810
Iter:  31 || Train Accuracy: 9.821 || Test Accuracy: 9.810
Iter:  41 || Train Accuracy: 9.743 || Test Accuracy: 9.710
Iter:  51 || Train Accuracy: 9.799 || Test Accuracy: 9.810
Iter:  61 || Train Accuracy: 9.821 || Test Accuracy: 9.810
Iter:  71 || Train Accuracy: 10.277 || Test Accuracy: 10.310
