Extension of Neural ODE tutorial to SDE case is very slow #26

@metanoid

I am not sure whether this is a ComponentArrays issue.
I tried extending the DiffEqFlux Neural ODE tutorial to an SDE, and performance dropped sharply: each training iteration takes around 2 minutes on my machine.

using ComponentArrays
using OrdinaryDiffEq
using Plots
using UnPack

using DiffEqFlux: sciml_train
using Flux: glorot_uniform, ADAM
using Optim: LBFGS
using DifferentialEquations
using DiffEqSensitivity

u0 = Float32[2.; 0.]
datasize = 30
tspan = (0.0f0, 1.5f0)

dense_layer(in, out) = ComponentArray(W=glorot_uniform(out, in), b=zeros(out))

function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end
t = range(tspan[1], tspan[2], length = datasize)
trueprob = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(trueprob, Tsit5(), saveat = t)) .+ (randn(2, datasize) .* 0.3)


function dudt(u, p, t)
    # display(typeof(p))
    @unpack L1, L2 = p.p
    return L2.W * tanh.(L1.W * u.^3 .+ L1.b) .+ L2.b
end
function noisy(u,p,t)
    @unpack noise_params = p
    return noise_params .* u .* 0.1
end

prob = SDEProblem(dudt, noisy, u0, tspan)

layers = (L1=dense_layer(2, 50), L2=dense_layer(50, 2))
θ = ComponentArray(u=u0, p=( p = layers, noise_params = rand(2)))
predict_n_ode(θ) = Array(solve(prob, Tsit5(), u0=θ.u, p=θ.p, saveat=t, sensealg=BacksolveAdjoint()))

function loss_n_ode(θ)
    pred = predict_n_ode(θ)
    loss = sum(abs2, ode_data .- pred)
    return loss, pred
end
loss_n_ode(θ)

cb = function (θ, loss, pred; doplot=false)
    display(loss)
    # plot current prediction against data
    pl = scatter(t, ode_data[1,:], label = "data")
    scatter!(pl, t, pred[1,:], label = "prediction")
    display(plot(pl))
    return false
end

cb(θ, loss_n_ode(θ)...)

data = Iterators.repeated((), 1000)

res1 = sciml_train(loss_n_ode, θ, ADAM(0.05); cb=cb, maxiters=100)
cb(res1.minimizer, loss_n_ode(res1.minimizer)...; doplot=true)

res2 = sciml_train(loss_n_ode, res1.minimizer, LBFGS(); cb=cb)
cb(res2.minimizer, loss_n_ode(res2.minimizer)...; doplot=true)
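
To check whether the 2 minutes per iteration is steady-state run time rather than first-call compilation, one option is to time a call only after a warm-up run. The following is a minimal sketch using only Base Julia; `time_after_warmup` is a hypothetical helper introduced here for illustration and is not part of ComponentArrays, DiffEqFlux, or any other package.

```julia
# Hypothetical helper (not from any package): run `f(args...)` once so that
# compilation is excluded, then return the fastest of `runs` timed calls.
function time_after_warmup(f, args...; runs = 3)
    f(args...)  # warm-up call; triggers JIT compilation
    return minimum(@elapsed(f(args...)) for _ in 1:runs)  # seconds
end

# Toy call on a cheap function, just to show the calling convention.
toy_seconds = time_after_warmup(x -> sum(abs2, x), rand(Float32, 100))
println("toy call took ", toy_seconds, " s")
```

With the script above, `time_after_warmup(loss_n_ode, θ)` would time the forward solve alone. Comparing that number with the time of a full training iteration gives a rough idea of how much of the cost comes from the gradient computation.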
