In [None]:
using Revise

In [None]:
using RegNeuralODE, Plots, OrdinaryDiffEq, Flux
using Plots.PlotMeasures

In [None]:
# Build the spiral train/test dataloaders: `nspiral = 256` noisy spirals sampled on
# [start, stop] = [0, 6π] (b = 0.3, noise_std = 0.3). The positional `256` is
# presumably a batch size — confirm against `load_spiral2d`.
train_dataloader, test_dataloader = load_spiral2d(
    256;
    start = 0.0f0,
    stop = Float32(6π),
    b = 0.3f0,
    nspiral = 256,
    noise_std = 0.3f0,
);
# Save times handed to the ODE solvers: the first column of the time array stored
# in the training data (assumes every sample shares this time grid — TODO confirm).
sts = let times = train_dataloader.data[2]
    times[:, 1]
end;

In [None]:
# Two latent ODE models with identical architecture and solver settings:
#   - `vanilla_node` wraps its dynamics in `NFECounterNeuralODE`,
#   - `reg_node` wraps them in `NFECounterCallbackNeuralODE`; its loss uses λ = 1.0f2.
# The integration window and solver options are defined once so that the two models
# cannot silently diverge. `tspan` presumably must cover the save times in `sts`
# (the data were generated with `stop = Float32(6π)`) — keep the two in sync.
tspan = (0.0f0, Float32(6π))
node_solver_kwargs = (saveat = sts, reltol = 1f-3, abstol = 1f-3)

vanilla_node = ExtrapolationLatentODE(4, 2, 64, 64,
                                      nn -> NFECounterNeuralODE(nn, tspan, Tsit5();
                                                                node_solver_kwargs...),
                                      cpu)
loss_unreg = RegNeuralODE.get_loss_function(vanilla_node)
opt_vanilla_node = ADAM(0.01)

reg_node = ExtrapolationLatentODE(4, 2, 64, 64,
                                  nn -> NFECounterCallbackNeuralODE(nn, tspan, Tsit5();
                                                                    node_solver_kwargs...),
                                  cpu)
# Start with the same initialization, so that differences between the two runs
# come from the regularization rather than the random init. `.=` copies in place
# and throws a DimensionMismatch if the parameter shapes ever differ.
for name in (:p1, :p2, :p3)
    getproperty(reg_node, name) .= getproperty(vanilla_node, name)
end
loss_reg = RegNeuralODE.get_loss_function(reg_node; λ = 1.0f2)
opt_reg_node = ADAM(0.01)

In [None]:
# Train the regularized model for 1000 epochs. `train!` returns the trained model
# (rebound to `reg_node`) together with the NFE counts and the train/test loss
# histories used by the plotting cell. `loss_reg` and `loss_unreg` are both passed.
# NOTE(review): `loss_unreg` was built from `vanilla_node`; confirm it evaluates the
# model handed to `train!` rather than a captured `vanilla_node`.
(reg_node, nfe_count_reg_node,
 train_losses_reg_node, test_losses_reg_node) =
    RegNeuralODE.train!(reg_node, opt_reg_node, 1000,
                        train_dataloader, test_dataloader,
                        loss_reg, loss_unreg);

In [None]:
# Train the unregularized baseline for the same 1000 epochs, using only `loss_unreg`.
# The returned trained model is rebound to `vanilla_node`; the NFE counts and loss
# histories feed the comparison plot.
(vanilla_node, nfe_count_vanilla_node,
 train_losses_vanilla_node, test_losses_vanilla_node) =
    RegNeuralODE.train!(vanilla_node, opt_vanilla_node, 1000,
                        train_dataloader, test_dataloader,
                        loss_unreg);

In [None]:
# Compare both runs on one figure. Dashed = vanilla NeuralODE, solid = RegNODE.
# Left axis (blue): NFE counts. Right axis (red): running training loss.
# Explicit handles are used instead of relying on `current()`. In particular, the
# x-label is set on the primary axis only, instead of via a bare `xlabel!` after
# `twinx()`, which would apply it to every subplot, including the twin overlay.
fig = plot(nfe_count_vanilla_node; lw = 2, color = :blue, linestyle = :dash,
           right_margin = 20mm, legend = false,
           xlabel = "Training Epoch", ylabel = "NFE Count", ylims = (75.0, 200.0))
plot!(fig, nfe_count_reg_node; lw = 2, color = :blue)

# `twinx` overlays a second subplot that shares the x-axis and has its own y-axis on
# the right. The legend lives here, labelling the models by line style.
# The first loss entry is dropped — presumably the pre-training value; confirm
# against `RegNeuralODE.train!`.
loss_ax = twinx(fig)
plot!(loss_ax, train_losses_vanilla_node[2:end]; color = :red, lw = 2,
      linestyle = :dash, label = "NeuralODE", legend = :topleft, right_margin = 20mm)
plot!(loss_ax, train_losses_reg_node[2:end]; color = :red, lw = 2, label = "RegNODE")
ylabel!(loss_ax, "Running (Unregularized) Training Loss")
savefig(fig, "latent_ode_extrapolation.png")