In [None]:
using Revise

In [None]:
using RegNeuralODE, Plots, OrdinaryDiffEq, Flux
using Plots.PlotMeasures

In [None]:
# TODO: This model comes from the DiffEqFlux tutorial. To train the models from the
#       easy-neural-ode paper we need GPU support (TrackerAdjoint).
#
# Build two classifiers with identical architecture: a baseline NFE-counting Neural ODE
# and a callback-based variant. The callback variant is then overwritten with the
# baseline's parameters so both start from the same initialization.
vanilla_node, reg_node = let
    # Each call builds fresh layers, so the two models never share arrays. Layers are
    # still created in encoder -> dynamics -> decoder order for each model.
    encoder() = Chain(flatten, Dense(784, 20, tanh))
    dynamics() = Chain(Dense(20, 10, tanh),
                       Dense(10, 10, tanh),
                       Dense(10, 20, tanh))
    decoder() = Chain(RegNeuralODE.diffeqsol_to_array, Dense(20, 10))
    solve_kw = (save_everystep = false,
                reltol = 6f-5, abstol = 6f-5,
                save_start = false)

    baseline = ClassifierNODE(
        encoder(),
        NFECounterNeuralODE(dynamics(), [0.f0, 1.f0], Tsit5(); solve_kw...),
        decoder(),
    )
    regularized = ClassifierNODE(
        encoder(),
        NFECounterCallbackNeuralODE(dynamics(), [0.f0, 1.f0], Tsit5(); solve_kw...),
        decoder(),
    )

    # Start with the same initialization: copy the baseline's parameters in place.
    for field in (:p1, :p2, :p3)
        getproperty(regularized, field) .= getproperty(baseline, field)
    end

    (baseline, regularized)
end
opt_vanilla_node = ADAM(0.01)
opt_reg_node = ADAM(0.01)

In [None]:
# MNIST train/test loaders. NOTE(review): 4096 is presumably the batch size — confirm
# against load_mnist.
(train_dataloader, test_dataloader) = load_mnist(4_096);

In [None]:
# Train the baseline model for 50 epochs. train! returns the trained model, its NFE-count
# history, and its train/test accuracy histories (NOTE(review): presumably one entry per
# epoch — confirm against RegNeuralODE.train!).
vanilla_node, vanilla_node_nfe_count, vanilla_node_train_accuracy, vanilla_node_test_accuracy =
    let model = vanilla_node
        loss = RegNeuralODE.get_loss_function(model)
        RegNeuralODE.train!(model, opt_vanilla_node, 50,
                            train_dataloader, test_dataloader, loss)
    end

In [None]:
# Train the callback-based model for 50 epochs with the same loaders. It returns the
# same four values as the baseline training cell.
reg_node, reg_node_nfe_count, reg_node_train_accuracy, reg_node_test_accuracy =
    let model = reg_node
        loss = RegNeuralODE.get_loss_function(model)
        RegNeuralODE.train!(model, opt_reg_node, 50,
                            train_dataloader, test_dataloader, loss)
    end

In [None]:
# Compare both models on one figure:
#   * left axis (blue): NFE count history, clipped to 25-50;
#   * right axis (red): test error in percent, on a twin y-axis.
# Dashed lines are the baseline NeuralODE; solid lines are the RegNODE model.
plt = let
    nfe_style = (legend = :topleft, lw = 2, color = :blue, right_margin = 10mm)
    err_style = (legend = false, lw = 2, color = :red, right_margin = 10mm)
    # Drop the first accuracy entry. NOTE(review): presumably the pre-training
    # evaluation — confirm against RegNeuralODE.train!.
    error_pct(acc) = (1 .- acc[2:end]) * 100

    fig = plot(vanilla_node_nfe_count; label = "NeuralODE", linestyle = :dash,
               nfe_style...)
    plot!(fig, reg_node_nfe_count; label = "RegNODE", nfe_style...)
    ylims!(fig, 25.0, 50.0)
    ylabel!(fig, "NFE Count")

    twin = twinx(fig)
    plot!(twin, error_pct(vanilla_node_test_accuracy); linestyle = :dash, err_style...)
    plot!(twin, error_pct(reg_node_test_accuracy); err_style...)
    ylabel!(twin, "Test Error (%)")
    xlabel!(fig, "Training Epoch")

    twin
end
savefig("mnist_classification.png")