In [None]:
#Import packages
using OrdinaryDiffEq, NODEData, Plots
using Flux, DiffEqSensitivity, Parameters
using Statistics


In [None]:
"""
    lorenz63(x, p, t)

Out-of-place right-hand side of the Lorenz-63 system.

`p` is unpacked as `(σ, β, ρ)`; the time argument `t` is not used.
Returns the vector `[dx/dt, dy/dt, dz/dt]`.
"""
function lorenz63(x, p, t)
    σ, β, ρ = p
    dxdt = σ * (x[2] - x[1])
    dydt = x[1] * (ρ - x[3]) - x[2]
    dzdt = x[1] * x[2] - β * x[3]
    return [dxdt, dydt, dzdt]
end

# Standard Lorenz-63 parameters. β is Float32, so `p` is promoted to a
# Vector{Float32} and the generated data stays in single precision.
σ = 10
β = 8f0/3
ρ = 28
p = [σ, β, ρ]

# Time grid: integrate through `t_transient` time units that are discarded,
# then record a training window of N_t_train steps followed by a validation
# window of N_t_valid steps, all spaced by `dt`.
t_transient = 100
N_t_train = 500
N_t_valid = N_t_train*3
N_t = N_t_train + N_t_valid
dt = 0.1f0
tspan = (0f0, Float32(t_transient + N_t * dt))

x0 = [0.1f0, 0.1f0, 0.1f0]

prob = ODEProblem(lorenz63, x0, tspan, p)
# Save the solution only after the transient has been integrated through.
sol = solve(prob, Tsit5(), saveat=t_transient:dt:t_transient + N_t * dt)

t_train = t_transient:dt:t_transient+N_t_train*dt
data_train = Array(sol(t_train))

# NOTE(review): t_valid starts at the last training time, so the training and
# validation windows share one sample.
t_valid = t_transient+N_t_train*dt:dt:t_transient+N_t_train*dt+N_t_valid*dt
data_valid = Array(sol(t_valid))

# Wrap both windows in NODEData loaders (the third argument presumably sets
# the length of each sub-trajectory — confirm against NODEData docs).
train = NODEDataloader(Float32.(data_train), t_train, 2)
valid = NODEDataloader(Float32.(data_valid), t_valid, 2)

# Alternative: let NODEData split the solution into train/valid directly.
#train, valid = NODEDataloader(sol, 10; dt=dt, valid_set=0.8)

In [None]:
# 3D phase portrait of the saved (post-transient) Lorenz-63 trajectory.
plot3d(sol[1, :], sol[2, :], sol[3, :];
       camera=(45, 40), xlabel="x", ylabel="y", zlabel="z")

In [None]:
# Define the ANN: 3 inputs, four ReLU layers of width N_WEIGHTS, 3 outputs.
# The hidden layers come from an eager comprehension so the layers are still
# created (and randomly initialised) in input-to-output order.
N_WEIGHTS = 16
nn = Chain(Dense(3 => N_WEIGHTS, relu),
           [Dense(N_WEIGHTS => N_WEIGHTS, relu) for _ in 1:3]...,
           Dense(N_WEIGHTS => 3))
# Flat parameter vector `p` plus a function `re_nn` that rebuilds the network
# from such a vector.
p, re_nn = Flux.destructure(nn)


"""
    neural_lorenz63(u, p, t)

Hybrid Lorenz-63 right-hand side. The first and third components use the
known Lorenz-63 terms (σ = 10, β = 8/3). In the second component, the
`x(ρ - z)` term is replaced by an output of the network rebuilt from `p` with
`re_nn`. `t` is unused.

NOTE(review): the network maps 3 → 3 values, so one output must be picked;
the second output (matching the y-equation) is used here — confirm intent.
"""
function neural_lorenz63(u, p, t)
    σ = 10
    β = 8f0/3
    # Evaluate the network once on the current state `u`
    # (the original body referenced an undefined `x` and subtracted from the
    # `Chain` itself instead of its output).
    nn_out = re_nn(p)(u)
    return [σ * (u[2] - u[1]), nn_out[2] - u[2], u[1] * u[2] - β * u[3]]
end

# Pure neural right-hand side du/dt = NN_p(u); `t` is ignored.
function neural_ode(u, p, t)
    network = re_nn(p)
    return network(u)
end

# Template problem spanning a single sampling step `dt`, using the
# network's flat parameter vector `p`.
node_prob = ODEProblem(neural_ode, x0, (0f0, Float32(dt)), p)

In [None]:
"""
Abstract supertype of the chaotic neural differential equation models.
"""
abstract type AbstractChaoticNDEModel end

"""
    ChaoticNDE{P,R,A,K} <: AbstractChaoticNDEModel

Bundle of a (neural) differential-equation problem with the solver settings
used to integrate it, so that the whole thing can be trained as a Flux model.

# Fields
- `p`: parameter vector (the keyword constructor takes it from `prob.p`)
- `prob`: the `DEProblem` to integrate
- `alg`: solver algorithm passed to `solve`
- `kwargs`: extra keyword arguments forwarded to `solve` (e.g. `sensealg`)

# Constructor

`ChaoticNDE(prob; alg=Tsit5(), kwargs...)`
"""
struct ChaoticNDE{P,R,A,K} <: AbstractChaoticNDEModel
    p::P
    prob::R
    alg::A
    kwargs::K
end

"""
    ChaoticNDE(prob; alg=Tsit5(), kwargs...)

Build a `ChaoticNDE` around `prob`, using `prob.p` as its parameter vector.
`alg` and any further keyword arguments are stored and later passed to `solve`.
"""
ChaoticNDE(prob; alg=Tsit5(), kwargs...) = ChaoticNDE(prob.p, prob, alg, kwargs)

# Let Flux traverse ChaoticNDE, but only `p` is exposed to optimisers;
# `prob`, `alg` and `kwargs` are not trained.
Flux.@functor ChaoticNDE
Flux.trainable(m::ChaoticNDE) = (; p = m.p)

"""
    (m::ChaoticNDE)(X, p=m.p)

Integrate `m.prob` for `X = (t, x)`: start from the first column of `x`,
span `t[1]` to `t[end]`, save at the times `t`, and return the result as an
`Array`. Passing `p` overrides the stored parameters.
"""
function (m::ChaoticNDE)(X, p=m.p)
    t, x = X
    t_start, t_stop = t[1], t[end]
    problem = remake(m.prob; tspan=(t_start, t_stop), u0=x[:, 1], p=p)
    solution = solve(problem, m.alg; saveat=t, m.kwargs...)
    return Array(solution)
end

# Wrap the template NODE problem (solver defaults to Tsit5) and run one
# forward pass on the first training batch as a smoke test.
model = ChaoticNDE(node_prob; alg=Tsit5())
model(train[1])

"""
    loss(x, y)

Sum of squared differences between prediction `x` and target `y`.
"""
function loss(x, y)
    residual = x - y
    return sum(abs2, residual)
end
# Initial loss on the first training batch (smoke test for model + loss).
let batch = train[1]
    loss(model(batch), batch[2])
end

# AdamW optimiser with learning rate η; `opt_state` holds its per-parameter state.
η = 1f-3
opt = Flux.AdamW(η)
opt_state = Flux.setup(opt, model)

In [None]:
#Train model
# Each epoch takes one gradient step per batch of `train`, records the mean
# batch loss in `loss_log`, and halves the learning rate every
# `LR_DECAY_EVERY` epochs.
N_EPOCHS = 30
# NOTE(review): with LR_DECAY_EVERY == N_EPOCHS the decay fires only after the
# final epoch, so it only affects a re-run of this cell.
LR_DECAY_EVERY = 30

loss_log = Float32[]
for epoch in 1:N_EPOCHS
    # `η` is reassigned below; declaring it global makes that work outside
    # interactive soft-scope sessions too (a plain script would otherwise
    # create an undefined local and throw).
    global η
    if epoch % 5 == 0
        @info "Epoch:" epoch
    end
    losses = Float32[]
    for (i, data) in enumerate(train)
        t, x = data

        val, grads = Flux.withgradient(model) do m
            # Everything to be differentiated (forward pass and loss) must
            # happen inside this closure.
            result = m((t, x))
            loss(result, x)
        end

        # Record the forward-pass loss (computed outside the gradient).
        push!(losses, val)

        # On an Inf/NaN loss, warn and skip the parameter update.
        if !isfinite(val)
            @warn "loss is $val on item $i" epoch
            continue
        end

        Flux.update!(opt_state, model, grads[1])
    end
    push!(loss_log, Statistics.mean(losses))

    if epoch % LR_DECAY_EVERY == 0
        η /= 2
        Flux.adjust!(opt_state, η)
    end
end

In [None]:
# Mean training loss per epoch.
loss_log |> plot

In [None]:
#Test model
# Integrate the trained model from a state on the attractor (the first
# validation sample) and overlay the true trajectory over the same horizon.
# `x0` is the t = 0 state, which lies in the transient that was integrated
# through but never saved, i.e. outside the training data.
t = Float32.(0:0.1:2)
x_test0 = data_valid[:, 1]
rec_sol = model((t, x_test0))
# `neural_ode` ignores time and data_valid is sampled every dt = 0.1, so the
# first length(t) validation columns cover the same horizon.
truth = data_valid[:, 1:length(t)]
plot3d(rec_sol[1, :], rec_sol[2, :], rec_sol[3, :]; label="NODE")
plot3d!(truth[1, :], truth[2, :], truth[3, :]; label="truth")