Solve the [Lotka–Volterra](https://en.wikipedia.org/wiki/Lotka–Volterra_equations) problem, on which the original [NeuralNetDiffEq.jl](https://julialang.org/blog/2017/10/gsoc-NeuralNetDiffEq) **failed completely**.

```julia
using Plots
using DifferentialEquations
include("NN_solver.jl")  # the notebook's own solver: init_nn, reset_weights!, train!, predict
```

```
predict (generic function with 2 methods)
```

```julia
default(size = (300, 200))  # plot size
```

# ODE definition

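The system is the classic Lotka–Volterra predator–prey model with α = 1.5, β = 1, γ = 3, δ = 1, as encoded in `lotka_volterra!` below. Here y₁ is the prey population (it grows on its own and is eaten by y₂) and y₂ is the predator population (it dies out on its own and grows by eating y₁):

$$
\begin{aligned}
\frac{dy_1}{dt} &= 1.5\,y_1 - y_1 y_2, \\
\frac{dy_2}{dt} &= -3\,y_2 + y_1 y_2,
\end{aligned}
\qquad y_1(0) = y_2(0) = 1, \quad t \in [0, 5].
$$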
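```julia
# Older in-place DifferentialEquations.jl signature f(t, u, du)
# (newer releases use f(du, u, p, t)).
# Broadcasting operators let the same function work on scalar states and,
# presumably as NN_solver.jl needs, on arrays holding one value per training point.
function lotka_volterra!(t, u, du)
    du[1] = 1.5 .* u[1] .- 1.0 .* u[1] .* u[2]
    du[2] = -3.0 .* u[2] .+ u[1] .* u[2]
end

y0_list = [1.0, 1.0]
tspan = (0.0, 5.0)
```

```
(0.0, 5.0)
```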

# Solve by DifferentialEquations.jl as reference

```julia
prob = ODEProblem(lotka_volterra!, y0_list, tspan)
# Tight tolerances; save the solution every 0.05 time units.
sol = solve(prob, saveat=0.05, reltol=1e-6, abstol=1e-6);
```

```julia
plot(sol.t, sol[1,:], label="y1", legend = :topleft)
plot!(sol.t, sol[2,:], label="y2")
```

```julia
savefig("reference.svg")
```

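As an optional, independent sanity check on the reference (not part of the original notebook), one can use the fact that the Lotka–Volterra system has a conserved quantity. For these parameters, `V = y1 - 3 log(y1) + y2 - 1.5 log(y2)` stays constant along exact trajectories. The sketch below integrates the system with a hand-written fixed-step fourth-order Runge–Kutta scheme and measures how much `V` drifts. The names `lv`, `invariant` and `rk4` are illustrative and do not come from NN_solver.jl or DifferentialEquations.jl.

```julia
# Lotka–Volterra right-hand side with the notebook's parameters
# (α = 1.5, β = 1, γ = 3, δ = 1); returns a new 2-element vector.
lv(u) = [1.5 * u[1] - u[1] * u[2], -3.0 * u[2] + u[1] * u[2]]

# First integral of the system: V = δ*y1 - γ*log(y1) + β*y2 - α*log(y2).
invariant(u) = u[1] - 3.0 * log(u[1]) + u[2] - 1.5 * log(u[2])

# Classic fixed-step fourth-order Runge–Kutta integration from t0 to t1.
# Returns the state at every step, including the initial state.
function rk4(f, u0, t0, t1, nsteps)
    h = (t1 - t0) / nsteps
    u = copy(u0)
    traj = [copy(u)]
    for _ in 1:nsteps
        k1 = f(u)
        k2 = f(u .+ (h / 2) .* k1)
        k3 = f(u .+ (h / 2) .* k2)
        k4 = f(u .+ h .* k3)
        u = u .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)  # allocates a new vector
        push!(traj, u)
    end
    return traj
end

traj = rk4(lv, [1.0, 1.0], 0.0, 5.0, 5000)   # step size h = 0.001
drift = maximum(abs(invariant(u) - invariant(traj[1])) for u in traj)
println("max drift of V along the RK4 trajectory: ", drift)
```

The same `invariant` function can be applied to the network's predictions, which gives a rough, reference-free measure of how far they wander off the true orbit.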
# My ANN solver

## Initialize ANN

```julia
# `linspace` is Julia 0.6; from Julia 0.7 on, use range(tspan[1], stop=tspan[2], length=51).
t = collect(linspace(tspan[1], tspan[2], 51))
t = reshape(t, 1, :)  # training points, stored as a 1×51 row
```

```
1×51 Array{Float64,2}:
 0.0  0.1  0.2  0.3  0.4  0.5  0.6  0.7  …  4.4  4.5  4.6  4.7  4.8  4.9  5.0
```

```julia
nn = init_nn(lotka_volterra!, t, y0_list, n_hidden = 20);
show(nn)  # print basic info
```

```
Neural ODE Solver
Number of equations:       2
Initial condition y0:      [1.0, 1.0]
Numnber of hidden units:   20
Number of training points: 51
```


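NN_solver.jl itself is not shown in this notebook. As a minimal sketch of what a solver like this typically computes, the block below uses the classic trial-solution formulation of Lagaris et al. It assumes one small network per equation with `n_hidden` tanh units, a trial solution `y(t) = y0 + t*N(t)` that satisfies the initial condition by construction, and a loss equal to the mean squared ODE residual over the training points. `TinyNet`, `trial`, `dtrial` and `residual_loss` are hypothetical names; the real `init_nn`/`train!` may be structured, initialized or scaled differently, so the loss printed here will not match the values in the logs below.

```julia
# Hypothetical sketch of a Lagaris-style neural ODE solver; not NN_solver.jl's code.

# One network per equation: N(t) = sum(v .* tanh.(w .* t .+ b)) + c.
struct TinyNet
    w::Vector{Float64}
    b::Vector{Float64}
    v::Vector{Float64}
    c::Float64
end

# Random initialization; different random draws give different starting points.
TinyNet(n_hidden::Int) = TinyNet(randn(n_hidden), randn(n_hidden),
                                 0.1 .* randn(n_hidden), 0.0)

# Network output and its exact derivative with respect to the scalar input t.
net(p::TinyNet, t) = sum(p.v .* tanh.(p.w .* t .+ p.b)) + p.c
dnet(p::TinyNet, t) = sum(p.v .* p.w .* (1 .- tanh.(p.w .* t .+ p.b) .^ 2))

# Trial solution y(t) = y0 + t*N(t) equals y0 at t = 0 for any weights,
# so only the ODE residual has to be penalized.
trial(p, y0, t) = y0 + t * net(p, t)
dtrial(p, t) = net(p, t) + t * dnet(p, t)

# Mean (over training points) of the squared residual dy/dt - f(y).
function residual_loss(nets, y0, ts, f)
    total = 0.0
    for t in ts
        y = [trial(nets[i], y0[i], t) for i in eachindex(nets)]
        dy = [dtrial(nets[i], t) for i in eachindex(nets)]
        total += sum(abs2, dy .- f(y))
    end
    return total / length(ts)
end

lv(y) = [1.5 * y[1] - y[1] * y[2], -3.0 * y[2] + y[1] * y[2]]
ts = collect(0.0:0.1:5.0)           # the same 51 training points as above
nets = [TinyNet(20), TinyNet(20)]   # 20 hidden units per equation
loss = residual_loss(nets, [1.0, 1.0], ts, lv)
println("initial residual loss: ", loss)
```

Because the loss depends on the random starting weights, two initializations can lead an optimizer into quite different basins, which is consistent with the sensitivity noted next.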
**Note: The subsequent training can be sensitive to initialization. If it does not converge, re-initialize the weights here and train again.**

```julia
reset_weights!(nn)
```

```julia
# Overlay the network's predictions (dots) on the reference solution (lines).
function quickplot()
    y_pred_list, _ = predict(nn)
    plot(sol.t, sol[1,:], label="y1 true", legend = :topleft)
    plot!(sol.t, sol[2,:], label="y2 true")

    plot!(nn.t[:], y_pred_list[1][:], label="y1 NN",
         lw=0, marker=:circle, markerstrokewidth = 0, markersize=3)
    plot!(nn.t[:], y_pred_list[2][:], label="y2 NN",
         lw=0, marker=:circle, markerstrokewidth = 0, markersize=3)
    ylims!(0, 7)
end

quickplot()
```

```julia
savefig("initial_guess.svg")
```

## First try a first-order optimizer

### Training

The basic gradient descent method converges very slowly: after 1000 iterations (about 70 s), the loss only falls from 0.317 to 0.277.

```julia
@time train!(nn, iterations=1000, show_every=100, method=GradientDescent());
```

```
Iter     Function value   Gradient norm
     0     3.169716e-01     2.202805e-01
   100     3.105593e-01     2.786916e-01
   200     3.053551e-01     2.375907e-01
   300     3.007554e-01     2.112435e-01
   400     2.965636e-01     1.920253e-01
   500     2.926856e-01     1.772866e-01
   600     2.890674e-01     1.657392e-01
   700     2.856757e-01     1.565731e-01
   800     2.824881e-01     1.491864e-01
   900     2.794890e-01     1.431070e-01
  1000     2.766658e-01     1.379711e-01
 69.752413 seconds (62.26 M allocations: 56.790 GiB, 14.89% gc time)
```

### Compare with reference solution

```julia
quickplot()
```

```julia
savefig("train0.svg")
```

## Then switch to a quasi-Newton optimizer (BFGS)

### Training

```julia
@time train!(nn, iterations=200, show_every=20, method=BFGS());
```

```
Iter     Function value   Gradient norm
     0     2.766658e-01     1.379711e-01
    20     1.998111e-01     2.346913e-01
    40     1.814000e-01     2.198355e-01
    60     1.560469e-01     9.629551e-01
    80     1.301012e-01     2.039628e-01
   100     1.172671e-01     7.605691e-01
   120     9.514630e-02     4.782085e-01
   140     7.820542e-02     3.996489e-01
   160     5.846823e-02     5.057250e-01
   180     3.368965e-02     2.030275e-01
   200     1.988986e-02     2.508233e-01
 15.295869 seconds (12.00 M allocations: 10.534 GiB, 13.40% gc time)
```

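Why switching optimizers pays off can be seen on a toy problem. This is an illustration only, not the network's actual loss. On an ill-conditioned quadratic, fixed-step gradient descent crawls along the low-curvature direction, while BFGS builds up curvature information from gradient differences. The sketch below is hand-written in Julia 1.x syntax. Its exact line search is valid only for quadratics, so it is not Optim.jl's `BFGS()`, which uses a general-purpose line search. The names `gradient_descent` and `bfgs` are illustrative.

```julia
using LinearAlgebra

# Ill-conditioned quadratic f(x) = 0.5 * x' * A * x (condition number 100).
A = Diagonal([1.0, 100.0])
f(x) = 0.5 * dot(x, A * x)
grad(x) = A * x

# Plain gradient descent with the fixed step 1/L, where L = 100 is the largest curvature.
function gradient_descent(x0; iters)
    x = copy(x0)
    for _ in 1:iters
        x = x .- 0.01 .* grad(x)
    end
    return x
end

# BFGS: keep an inverse-Hessian approximation H, updated from step/gradient pairs.
function bfgs(x0; tol = 1e-10, maxiter = 50)
    x = copy(x0)
    H = Matrix{Float64}(I, length(x), length(x))
    g = grad(x)
    k = 0
    while norm(g) > tol && k < maxiter
        p = -H * g                           # quasi-Newton search direction
        α = -dot(g, p) / dot(p, A * p)       # exact line search, valid only for this quadratic
        s = α .* p
        xnew = x .+ s
        gnew = grad(xnew)
        y = gnew .- g
        ρ = 1 / dot(y, s)
        # Standard BFGS update of the inverse-Hessian approximation.
        H = (I - ρ * s * y') * H * (I - ρ * y * s') + ρ * s * s'
        x, g = xnew, gnew
        k += 1
    end
    return x, k
end

x0 = [1.0, 1.0]
x_gd = gradient_descent(x0; iters = 100)
x_bfgs, k_bfgs = bfgs(x0)
println("GD after 100 steps:      f = ", f(x_gd))
println("BFGS after $k_bfgs steps: f = ", f(x_bfgs))
```

On an n-dimensional quadratic, BFGS with exact line searches reaches the minimizer in at most n steps (here 2), whereas gradient descent still leaves most of the error in the low-curvature coordinate after 100 steps. Real neural-network losses are not quadratic, but the same curvature effect helps explain the much faster progress in the BFGS log above.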

### Compare with reference solution

```julia
quickplot()
```

```julia
savefig("train1.svg")
```

## Continue training to further bring down the error

We train in rounds of 200 iterations so we can see exactly how many iterations are needed. The cells below are largely repetitive and could be replaced by a single call (or a loop).

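One rough way to decide when to stop is to compare the loss round over round. The values below are copied from the training logs in this notebook (the first BFGS round above and rounds 2 to 6 below). The snippet, in Julia 1.x syntax, simply computes by what factor each 200-iteration BFGS round reduced the loss.

```julia
# Loss before BFGS (end of gradient descent) and at the end of each BFGS round,
# copied from the notebook's training logs.
losses = [2.766658e-01,   # start of BFGS round 1
          1.988986e-02,   # end of round 1
          1.628578e-03,   # end of round 2
          4.397707e-04,   # end of round 3
          6.435730e-05,   # end of round 4
          2.744053e-05,   # end of round 5
          1.843729e-05]   # end of round 6

# Factor by which each round reduced the loss.
factors = losses[1:end-1] ./ losses[2:end]
for (r, factor) in enumerate(factors)
    println("round $r: loss reduced ", round(factor, digits = 2), "x")
end
```

Rounds 5 and 6 reduce the loss by only about 2.3x and 1.5x, compared with more than 10x in each of the first two BFGS rounds. This is a sign of diminishing returns from further training at this network size.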
### Round 2

```julia
@time train!(nn, iterations=200, show_every=20, method=BFGS());
```

```
Iter     Function value   Gradient norm
     0     1.988986e-02     2.508233e-01
    20     1.586872e-02     4.402852e-01
    40     1.084081e-02     2.619599e-01
    60     5.145943e-03     1.182974e-01
    80     3.306572e-03     1.471672e-01
   100     2.409457e-03     1.001235e-01
   120     2.115766e-03     8.810897e-02
   140     1.981713e-03     6.165784e-02
   160     1.919678e-03     3.256366e-02
   180     1.834639e-03     9.909097e-02
   200     1.628578e-03     5.603911e-02
 14.751598 seconds (11.78 M allocations: 10.736 GiB, 13.95% gc time)
```

```julia
quickplot()
```

```julia
savefig("train2.svg")
```

### Round 3

```julia
@time train!(nn, iterations=200, show_every=20, method=BFGS());
```

```
Iter     Function value   Gradient norm
     0     1.628578e-03     5.603911e-02
    20     1.582517e-03     5.185822e-02
    40     1.527072e-03     2.382414e-02
    60     1.438738e-03     2.010915e-01
    80     1.324651e-03     4.502148e-02
   100     1.225022e-03     8.681411e-03
   120     1.053733e-03     9.422423e-02
   140     9.162333e-04     5.680250e-02
   160     6.982738e-04     4.916793e-02
   180     5.489311e-04     1.349164e-01
   200     4.397707e-04     5.070524e-02
 12.442035 seconds (11.40 M allocations: 10.396 GiB, 15.73% gc time)
```

```julia
quickplot()
```

```julia
savefig("train3.svg")
```

### Round 4

```julia
@time train!(nn, iterations=200, show_every=20, method=BFGS());
```

```
Iter     Function value   Gradient norm
     0     4.397707e-04     5.070524e-02
    20     4.026615e-04     4.663021e-02
    40     3.060492e-04     2.179518e-02
    60     2.206999e-04     9.591778e-02
    80     1.743112e-04     3.150744e-02
   100     1.276654e-04     7.907902e-02
   120     1.175746e-04     1.037274e-02
   140     1.079285e-04     2.921081e-02
   160     7.818330e-05     5.205425e-03
   180     7.129838e-05     4.450609e-03
   200     6.435730e-05     3.736333e-02
 12.930993 seconds (12.12 M allocations: 11.054 GiB, 15.87% gc time)
```

```julia
quickplot()
```

```julia
savefig("train4.svg")
```

### Round 5

```julia
@time train!(nn, iterations=200, show_every=20, method=BFGS());
```

```
Iter     Function value   Gradient norm
     0     6.435730e-05     3.736333e-02
    20     6.220481e-05     5.483996e-03
    40     5.596889e-05     2.166454e-02
    60     5.074154e-05     1.809271e-02
    80     4.490502e-05     1.871842e-02
   100     4.326679e-05     6.486733e-03
   120     3.990787e-05     1.263422e-02
   140     3.727240e-05     2.373896e-02
   160     3.626178e-05     1.250467e-02
   180     2.956916e-05     1.350270e-02
   200     2.744053e-05     7.270890e-03
 13.744363 seconds (12.77 M allocations: 11.644 GiB, 16.19% gc time)
```

```julia
quickplot()
```

```julia
savefig("train5.svg")
```

### Round 6

```julia
@time train!(nn, iterations=200, show_every=20, method=BFGS());
```

```
Iter     Function value   Gradient norm
     0     2.744053e-05     7.270890e-03
    20     2.722552e-05     2.236682e-02
    40     2.563928e-05     1.291942e-03
    60     2.500930e-05     5.504895e-03
    80     2.427631e-05     4.498904e-03
   100     2.371244e-05     1.946768e-02
   120     2.295648e-05     1.466324e-02
   140     2.140554e-05     3.237892e-02
   160     2.044879e-05     3.159542e-02
   180     1.950768e-05     2.924390e-03
   200     1.843729e-05     3.775219e-03
 13.685619 seconds (12.77 M allocations: 11.644 GiB, 16.29% gc time)
```

```julia
quickplot()
```

```julia
savefig("train6.svg")
```