Solve the [Lotka-Volterra](https://en.wikipedia.org/wiki/Lotka–Volterra_equations) problem, on which the original [NeuralNetDiffEq.jl](https://julialang.org/blog/2017/10/gsoc-NeuralNetDiffEq) **failed completely**.

In [1]:
using Plots
using DifferentialEquations
include("NN_solver.jl")

predict (generic function with 2 methods)

In [2]:
default(size = (300, 200)) # plot size

# ODE definition

In [3]:
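# Right-hand side of the Lotka-Volterra system, written in place into `du`.
# u[1] is the prey population, u[2] the predator population.
function lotka_volterra!(t, u, du)
    du[1] = 1.5 .* u[1] - 1.0 .* u[1].*u[2]   # prey: growth minus predation
    du[2] = -3 .* u[2] + u[1].*u[2]           # predators: decay plus growth from predation
end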

y0_list = [1.0, 1.0]
tspan = (0.0,5.0)

(0.0, 5.0)
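
Written out, the cell above encodes the following initial value problem. The coefficients are read directly from `lotka_volterra!`, with $y_1$ standing for `u[1]` and $y_2$ for `u[2]`:

$$
\begin{aligned}
\frac{dy_1}{dt} &= 1.5\,y_1 - y_1 y_2, \\
\frac{dy_2}{dt} &= -3\,y_2 + y_1 y_2,
\end{aligned}
\qquad y_1(0) = y_2(0) = 1, \quad t \in [0, 5].
$$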

# Solve with DifferentialEquations.jl as a reference

In [4]:
prob = ODEProblem(lotka_volterra!, y0_list, tspan)
sol = solve(prob, saveat=0.05, reltol=1e-6, abstol=1e-6);

In [5]:
plot(sol.t, sol[1,:], label="y1", legend = :topleft)
plot!(sol.t, sol[2,:], label="y2")

In [6]:
savefig("reference.svg")
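
A reference solution can be sanity-checked independently of any solver library. For this particular system the quantity $V = y_1 - 3\ln y_1 + y_2 - 1.5\ln y_2$ is conserved along exact trajectories, so its drift measures integration error. The sketch below assumes Julia 1.x and uses a hand-written fixed-step RK4 integrator; `f`, `invariant` and `rk4` are illustrative names, not part of the notebook or of DifferentialEquations.jl.

```julia
# Out-of-place right-hand side with the same coefficients as lotka_volterra!
f(u) = [1.5 * u[1] - u[1] * u[2], -3.0 * u[2] + u[1] * u[2]]

# Conserved quantity of this Lotka-Volterra system
invariant(u) = u[1] - 3.0 * log(u[1]) + u[2] - 1.5 * log(u[2])

# Classical fixed-step fourth-order Runge-Kutta from t0 to t1 in n steps
function rk4(f, u0, t0, t1, n)
    h = (t1 - t0) / n
    u = copy(u0)
    for _ in 1:n
        k1 = f(u)
        k2 = f(u .+ (h / 2) .* k1)
        k3 = f(u .+ (h / 2) .* k2)
        k4 = f(u .+ h .* k3)
        u = u .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
    end
    return u
end

u0 = [1.0, 1.0]
u_end = rk4(f, u0, 0.0, 5.0, 5000)

# Change in the invariant between t = 0 and t = 5; should be close to zero
drift = abs(invariant(u_end) - invariant(u0))
println("invariant drift: ", drift)
```

The same `invariant` function could in principle be applied to each saved state of `sol` to check the reference used in this notebook.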

# My ANN solver

## Initialize ANN

In [7]:
t = collect(linspace(tspan[1], tspan[2], 51))
t = reshape(t, 1, :) # training points

1×51 Array{Float64,2}:
 0.0  0.1  0.2  0.3  0.4  0.5  0.6  0.7  …  4.4  4.5  4.6  4.7  4.8  4.9  5.0

In [8]:
nn = init_nn(lotka_volterra!, t, y0_list, n_hidden = 20);
show(nn) # print basic info

Neural ODE Solver
Number of equations:       2
Initial condition y0:      [1.0, 1.0]
Numnber of hidden units:   20
Number of training points: 51
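
`NN_solver.jl` is not shown in this notebook, so the objective that `init_nn` sets up is not visible here. As a rough guide to what an ANN-based ODE solver typically minimizes, the sketch below uses one common formulation: a trial solution $y(t) = y_0 + t\,N(t)$ that satisfies the initial condition by construction, with the mean squared ODE residual at the training points as the loss. Everything in it (`lv_rhs`, `residual_loss`, the single tanh hidden layer) is an assumption for illustration and may differ from the actual solver. It assumes Julia 1.x.

```julia
# Hypothetical sketch: none of these names come from NN_solver.jl.

# Lotka-Volterra right-hand side applied column-wise; y is a 2×n matrix.
function lv_rhs(y)
    y1 = y[1:1, :]
    y2 = y[2:2, :]
    return vcat(1.5 .* y1 .- y1 .* y2, -3.0 .* y2 .+ y1 .* y2)
end

# Mean squared ODE residual of the trial solution y(t) = y0 + t * N(t),
# where N is a one-hidden-layer tanh network with an analytic time derivative.
function residual_loss(W1, b1, W2, b2, t, y0)
    a  = tanh.(W1 * t .+ b1)           # hidden activations, h×n
    N  = W2 * a .+ b2                  # network output, 2×n
    dN = W2 * ((1 .- a .^ 2) .* W1)    # dN/dt via the chain rule, 2×n
    y  = y0 .+ t .* N                  # trial solution, equals y0 at t = 0
    dy = N .+ t .* dN                  # product rule applied to t * N(t)
    r  = dy .- lv_rhs(y)               # ODE residual at each training point
    return sum(abs2, r) / size(t, 2)
end

t  = reshape(collect(range(0.0, stop = 5.0, length = 51)), 1, :)  # 1×51 training points
y0 = [1.0, 1.0]
h  = 20                                                            # hidden units

# With all weights zero the trial solution is the constant y0, so the loss
# reduces to the squared norm of the right-hand side at y0, [0.5, -2.0].
loss_zero = residual_loss(zeros(h, 1), zeros(h, 1), zeros(2, h), zeros(2, 1), t, y0)
println(loss_zero)
```

Training then amounts to minimizing such a loss over the weights, which is where the optimizers in the following sections come in.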


**Note: The subsequent training can be sensitive to initialization. If it fails to converge, you may need to re-initialize the weights here and train again.**
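
The re-initialize-and-retry advice can be automated. The helper below is a minimal sketch under stated assumptions: `train_with_restarts` is a hypothetical name, `reset!` is any function that re-initializes the model, and `fit!` is any function that trains it and returns the final loss. A toy model stands in for the network so the block runs on its own (Julia 1.x).

```julia
# Hypothetical helper: `reset!` re-initializes the model, `fit!` trains it
# and returns the final loss. Neither name comes from NN_solver.jl.
function train_with_restarts(reset!, fit!, model; max_restarts = 5, tol = 1e-3)
    loss = Inf
    for attempt in 1:max_restarts
        reset!(model)        # start from fresh weights
        loss = fit!(model)   # train and record the final loss
        if loss < tol
            return (converged = true, attempts = attempt, loss = loss)
        end
    end
    return (converged = false, attempts = max_restarts, loss = loss)
end

# Toy stand-in: each "training run" yields the next loss from a fixed list.
losses = [0.5, 0.2, 1e-4]
calls = Ref(0)
toy_reset!(m) = (m[] = 0.0)
toy_fit!(m) = (calls[] += 1; m[] = losses[calls[]])

result = train_with_restarts(toy_reset!, toy_fit!, Ref(0.0); max_restarts = 5, tol = 1e-3)
println(result)

# In this notebook the call might look like the line below, assuming that
# `train!` returns an Optim.jl result object:
# train_with_restarts(reset_weights!,
#     m -> Optim.minimum(train!(m, iterations = 200, method = BFGS())), nn)
```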

In [18]:
reset_weights!(nn)

In [19]:
function quickplot()
    y_pred_list, _ = predict(nn)
    plot(sol.t, sol[1,:], label="y1 true", legend = :topleft)
    plot!(sol.t, sol[2,:], label="y2 true")

    plot!(nn.t[:], y_pred_list[1][:], label="y1 NN", 
         lw=0, marker=:circle, markerstrokewidth = 0, markersize=3)
    plot!(nn.t[:], y_pred_list[2][:], label="y2 NN", 
         lw=0, marker=:circle, markerstrokewidth = 0, markersize=3)
end

quickplot()

In [20]:
savefig("initial_guess.svg")

## First try a first-order optimizer

### Training

The basic Gradient Descent method converges very slowly:

In [21]:
@time opt = train!(nn, iterations=1000, show_every=100, method=GradientDescent())

Iter     Function value   Gradient norm 
     0     3.761501e+02     7.887165e+03
   100     6.440676e-01     4.157988e-01
   200     6.290368e-01     3.670549e-01
   300     6.153545e-01     3.465893e-01
   400     5.985151e-01     3.872297e-01
   500     5.854940e-01     4.000883e-01
   600     5.726636e-01     4.224358e-01
   700     5.458444e-01     4.071303e-01
   800     5.098684e-01     6.417561e-01
   900     4.933963e-01     2.009057e-01
  1000     4.807139e-01     3.641788e-01
 82.773930 seconds (72.87 M allocations: 58.137 GiB, 13.13% gc time)


Results of Optimization Algorithm
 * Algorithm: Gradient Descent
 * Starting Point: [-0.3940013248433644,-0.6813627922751576, ...]
 * Minimizer: [-0.3834351815214026,-0.6673994915460282, ...]
 * Minimum: 4.807139e-01
 * Iterations: 1000
 * Convergence: false
   * |x - x'| < 1.0e-32: false 
     |x - x'| = 7.20e-05 
   * |f(x) - f(x')| / |f(x)| < 1.0e-32: false
     |f(x) - f(x')| / |f(x)| = 2.36e-04 
   * |g(x)| < 1.0e-08: false 
     |g(x)| = 3.64e-01 
   * stopped by an increasing objective: false
   * Reached Maximum Number of Iterations: true
 * Objective Calls: 2545
 * Gradient Calls: 2545
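
One plausible reason for the slow decrease in the log above is ill-conditioning, which first-order methods handle poorly. The toy example below illustrates the effect on a two-dimensional quadratic with condition number 100. It is not the network loss, and an exact Newton step stands in for curvature-aware methods in general. All names are illustrative; Julia 1.x is assumed.

```julia
using LinearAlgebra

A = [1.0 0.0; 0.0 100.0]          # Hessian with condition number 100
f(x) = 0.5 * dot(x, A * x)        # quadratic objective, minimum at the origin
grad(x) = A * x

# Plain gradient descent with a fixed step size
function gradient_descent(x0, step, iters)
    x = copy(x0)
    for _ in 1:iters
        x -= step .* grad(x)
    end
    return x
end

# One Newton step uses the exact Hessian
newton_step(x) = x - A \ grad(x)

x0 = [1.0, 1.0]
x_gd = gradient_descent(x0, 1 / 100, 100)   # step limited by the stiff direction
x_newton = newton_step(x0)

println("f after 100 gradient steps: ", f(x_gd))
println("f after 1 Newton step:      ", f(x_newton))
```

The step size has to be small enough for the stiff direction, so the flat direction shrinks by only 1% per iteration. Methods that use curvature information rescale the two directions and avoid this bottleneck.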

### Compare with reference solution

In [22]:
quickplot()

In [23]:
savefig("train0.svg")

## Then use a second-order (quasi-Newton) optimizer

### Training

In [24]:
@time train!(nn, iterations=200, show_every=20, method=BFGS());

Iter     Function value   Gradient norm 
     0     4.807139e-01     3.641788e-01
    20     2.062586e-01     2.103476e-01
    40     1.855814e-01     9.277237e-02
    60     1.593247e-01     2.649147e+00
    80     1.302820e-01     4.355558e-01
   100     1.045770e-01     1.204467e+00
   120     7.474441e-02     7.298782e-01
   140     5.227283e-02     8.689202e-01
   160     2.627641e-02     3.703598e-01
   180     1.876283e-02     1.666257e-01
   200     1.125537e-02     3.881110e-01
 15.237179 seconds (12.65 M allocations: 11.124 GiB, 14.03% gc time)


### Compare with reference solution

In [25]:
quickplot()

In [26]:
savefig("train1.svg")

## Continue training to further bring down the error

We train progressively, in rounds of 200 BFGS iterations, to see how many iterations are actually needed. The code below is largely repetitive and could be collapsed into a single call.
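
One way to avoid the repetition while still seeing per-round progress is a small loop that stops once a target loss is reached. The sketch below is an illustration, not part of `NN_solver.jl`: `train_in_rounds` is a hypothetical helper, and `fit!` is assumed to run one round of training and return the loss at the end of that round. A toy model that halves its loss each round keeps the block self-contained (Julia 1.x).

```julia
# Hypothetical helper: run up to `rounds` training rounds, stopping early
# once the loss returned by `fit!` drops below `target`.
function train_in_rounds(fit!, model; rounds = 5, target = 1e-5)
    history = Float64[]
    for r in 1:rounds
        push!(history, fit!(model))   # loss at the end of round r
        history[end] < target && break
    end
    return history
end

# Toy stand-in for one training round: the stored loss is halved each time.
toy_round!(m) = (m[] /= 2)

history = train_in_rounds(toy_round!, Ref(1.0); rounds = 10, target = 0.2)
println(history)

# In this notebook the call might look like the line below, assuming that
# `train!` returns an Optim.jl result object:
# train_in_rounds(m -> Optim.minimum(train!(m, iterations = 200,
#     show_every = 20, method = BFGS())), nn; rounds = 5)
```

The length of `history` then tells how many rounds were needed, which is the information the manual rounds below provide.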

### Round 2

In [27]:
@time train!(nn, iterations=200, show_every=20, method=BFGS());

Iter     Function value   Gradient norm 
     0     1.125537e-02     3.881110e-01
    20     7.827297e-03     2.807081e-01
    40     4.528628e-03     1.001401e-01
    60     2.934222e-03     3.621814e-02
    80     2.511388e-03     1.128397e-01
   100     2.142640e-03     4.228054e-02
   120     1.951372e-03     6.955948e-02
   140     1.861941e-03     1.281674e-01
   160     1.639855e-03     3.151713e-02
   180     1.568052e-03     7.095340e-02
   200     1.540943e-03     2.577444e-02
 12.756554 seconds (11.80 M allocations: 10.759 GiB, 15.78% gc time)


In [28]:
quickplot()

In [29]:
savefig("train2.svg")

### Round 3

In [30]:
@time train!(nn, iterations=200, show_every=20, method=BFGS());

Iter     Function value   Gradient norm 
     0     1.540943e-03     2.577444e-02
    20     1.533370e-03     1.849400e-02
    40     1.513330e-03     8.037753e-03
    60     1.481464e-03     6.169330e-02
    80     1.425685e-03     4.736749e-02
   100     1.378786e-03     5.218985e-02
   120     1.301499e-03     1.052588e-01
   140     1.125536e-03     1.888068e-01
   160     9.706648e-04     1.203090e-01
   180     9.456348e-04     1.264325e-01
   200     8.939693e-04     1.026971e-01
 12.429721 seconds (11.78 M allocations: 10.736 GiB, 15.58% gc time)


In [31]:
quickplot()

In [32]:
savefig("train3.svg")

### Round 4

In [33]:
@time train!(nn, iterations=200, show_every=20, method=BFGS());

Iter     Function value   Gradient norm 
     0     8.939693e-04     1.026971e-01
    20     8.668762e-04     3.213256e-02
    40     8.101685e-04     8.214193e-02
    60     7.593789e-04     6.507863e-02
    80     7.116888e-04     8.299522e-02
   100     6.524115e-04     2.657131e-02
   120     5.736572e-04     7.254289e-02
   140     5.022448e-04     7.290899e-02
   160     3.889444e-04     3.709599e-02
   180     3.114184e-04     8.121390e-02
   200     2.704294e-04     9.643871e-02
 12.284763 seconds (11.38 M allocations: 10.373 GiB, 15.71% gc time)


In [34]:
quickplot()

In [35]:
savefig("train4.svg")

### Round 5

In [36]:
@time train!(nn, iterations=200, show_every=20, method=BFGS());

Iter     Function value   Gradient norm 
     0     2.704294e-04     9.643871e-02
    20     2.542466e-04     3.681995e-02
    40     2.152552e-04     4.185301e-02
    60     1.817210e-04     1.078525e-01
    80     1.730984e-04     2.242688e-02
   100     1.403765e-04     1.005905e-01
   120     1.145174e-04     6.649120e-02
   140     9.491102e-05     5.550138e-02
   160     8.076651e-05     1.005759e-01
   180     5.526149e-05     3.633556e-02
   200     5.037231e-05     2.151210e-03
 12.553040 seconds (11.85 M allocations: 10.804 GiB, 15.51% gc time)


In [37]:
quickplot()

In [38]:
savefig("train5.svg")

### Round 6

In [39]:
@time train!(nn, iterations=200, show_every=20, method=BFGS());

Iter     Function value   Gradient norm 
     0     5.037231e-05     2.151210e-03
    20     4.980894e-05     2.730332e-02
    40     4.616251e-05     2.304167e-02
    60     4.204669e-05     8.376277e-02
    80     3.813186e-05     7.406999e-03
   100     3.565229e-05     6.914135e-03
   120     3.319774e-05     7.353812e-03
   140     3.130245e-05     9.317539e-03
   160     2.808777e-05     4.104920e-02
   180     2.548812e-05     2.279862e-02
   200     2.135330e-05     1.448674e-02
 13.430463 seconds (12.22 M allocations: 11.145 GiB, 14.83% gc time)


In [40]:
quickplot()

In [41]:
savefig("train6.svg")
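
The plots give a visual comparison only. A simple number to track alongside them is the maximum absolute error against the reference. The reference was saved every 0.05 and the training points are spaced 0.1 apart, so every second reference sample lines up with a training point. The sketch below uses synthetic placeholder data so it runs without `nn` or `sol`; `max_abs_error` is a hypothetical helper (Julia 1.x).

```julia
# Largest pointwise deviation between a prediction and a reference
max_abs_error(pred, ref) = maximum(abs.(pred .- ref))

t_ref = collect(0.0:0.05:5.0)   # reference grid (saveat = 0.05), 101 points
t_nn  = collect(0.0:0.1:5.0)    # training grid, 51 points
idx   = 1:2:length(t_ref)       # every second reference sample

# Synthetic placeholder data: a "prediction" that is off by a constant 1e-3
y_ref  = sin.(t_ref)
y_pred = sin.(t_nn) .+ 1e-3

err = max_abs_error(y_pred, y_ref[idx])
println("max abs error: ", err)

# With the notebook's objects this might read as follows (an assumption
# about how `sol` and `predict(nn)` can be indexed):
# y_pred_list, _ = predict(nn)
# max_abs_error(y_pred_list[1][:], sol[1, 1:2:end])
```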