-
-
Notifications
You must be signed in to change notification settings - Fork 39
Closed
Description
I am not sure if this is a ComponentArrays issue.
I tried extending the DiffEqFlux tutorial to an SDE, and the performance took a huge knock, each iteration of training takes around 2 minutes on my machine.
using ComponentArrays
using OrdinaryDiffEq
using Plots
using UnPack
using DiffEqFlux: sciml_train
using Flux: glorot_uniform, ADAM
using Optim: LBFGS
using DifferentialEquations
using DiffEqSensitivity
u0 = Float32[2.; 0.]
datasize = 30
tspan = (0.0f0, 1.5f0)
dense_layer(in, out) = ComponentArray(W=glorot_uniform(out, in), b=zeros(out))
function trueODEfunc(du, u, p, t)
true_A = [-0.1 2.0; -2.0 -0.1]
du .= ((u.^3)'true_A)'
end
t = range(tspan[1], tspan[2], length = datasize)
trueprob = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(trueprob, Tsit5(), saveat = t)) .+ (randn(2, datasize) .* 0.3)
function dudt(u, p, t)
# display(typeof(p))
@unpack L1, L2 = p.p
return L2.W * tanh.(L1.W * u.^3 .+ L1.b) .+ L2.b
end
function noisy(u,p,t)
@unpack noise_params = p
return noise_params .* u .* 0.1
end
prob = SDEProblem(dudt, noisy, u0, tspan)
layers = (L1=dense_layer(2, 50), L2=dense_layer(50, 2))
θ = ComponentArray(u=u0, p=( p = layers, noise_params = rand(2)))
predict_n_ode(θ) = Array(solve(prob, Tsit5(), u0=θ.u, p=θ.p, saveat=t, sensealg=BacksolveAdjoint()))
function loss_n_ode(θ)
pred = predict_n_ode(θ)
loss = sum(abs2, ode_data .- pred)
return loss, pred
end
loss_n_ode(θ)
cb = function (θ, loss, pred; doplot=false)
display(loss)
# plot current prediction against data
pl = scatter(t, ode_data[1,:], label = "data")
scatter!(pl, t, pred[1,:], label = "prediction")
display(plot(pl))
return false
end
cb(θ, loss_n_ode(θ)...)
data = Iterators.repeated((), 1000)
res1 = sciml_train(loss_n_ode, θ, ADAM(0.05); cb=cb, maxiters=100)
cb(res1.minimizer, loss_n_ode(res1.minimizer)...; doplot=true)
res2 = sciml_train(loss_n_ode, res1.minimizer, LBFGS(); cb=cb)
cb(res2.minimizer, loss_n_ode(res2.minimizer)...; doplot=true)Metadata
Metadata
Assignees
Labels
No labels