Open
Description
Hello. I was training a simple UODE and am encountering this error when running `sciml_train()`. The error seems to be somewhere in the interface between DiffEqFlux and GalacticOptim. I have the specific error message and some code to replicate the problem.
Here is the error message text:
ERROR: MethodError: Cannot `convert` an object of type Nothing to an object of type Float32
Closest candidates are:
convert(::Type{T}, ::Static.StaticFloat64{N}) where {N, T<:AbstractFloat} at ~/.julia/packages/Static/8hh0B/src/float.jl:26
convert(::Type{T}, ::LLVM.GenericValue, ::LLVM.LLVMType) where T<:AbstractFloat at ~/.julia/packages/LLVM/tVv0H/src/execution.jl:39
convert(::Type{T}, ::LLVM.ConstantFP) where T<:AbstractFloat at ~/.julia/packages/LLVM/tVv0H/src/core/value/constant.jl:103
...
Stacktrace:
[1] fill!(dest::Vector{Float32}, x::Nothing)
@ Base ./array.jl:351
[2] copyto!
@ ./broadcast.jl:921 [inlined]
[3] materialize!
@ ./broadcast.jl:871 [inlined]
[4] materialize!(dest::Vector{Float32}, bc::Base.Broadcast.Broad)
@ Base.Broadcast ./broadcast.jl:868
[5] (::GalacticOptim.var"#268#278"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_uode)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}})(::Vector{Float32}, ::Vector{Float32})
@ GalacticOptim ~/.julia/packages/GalacticOptim/fow0r/src/function/zygote.jl:8
[6] macro expansion
@ ~/.julia/packages/GalacticOptim/fow0r/src/solve/flux.jl:27 [inlined]
[7] macro expansion
@ ~/.julia/packages/GalacticOptim/fow0r/src/utils.jl:35 [inlined]
[8] __solve(prob::OptimizationProblem{, opt::ADAM, data::Base.Iterators.Cycle; maxiters::Int64, cb::Function, progress::Bool, save_best::Bool, kwargs::Base.Pairs{Symbol, U)
@ GalacticOptim ~/.julia/packages/GalacticOptim/fow0r/src/solve/flux.jl:25
[9] #solve#482
@ ~/.julia/packages/SciMLBase/OHiiA/src/solve.jl:3 [inlined]
[10] sciml_train(::typeof(loss_uode), ::Vector{Float32}, ::ADAM, ::Nothing; lower_bounds::Nothing, upper_bounds::Nothing, maxiters::Int64, kwargs::Base.Pairs{Symbol, U)
@ DiffEqFlux ~/.julia/packages/DiffEqFlux/vJuRw/src/train.jl:91
[11] top-level scope
@ ~/Dropbox/sandbox/julia_gend_univ/experiments/uode_benchmarking/mwe.jl:71
Here is the code to replicate the problem. I hardcoded some data so that anyone can precisely replicate the issue. The code itself is based on some of the examples in the Universal ODE GitHub repo.
using StatsBase
using Plots
using DifferentialEquations
using DiffEqFlux
using StatsPlots
using ComponentArrays
using Distributions
using DifferentialEquations.EnsembleAnalysis
using SciMLBase
using DiffEqCallbacks
using GalacticOptim
using NLopt
using NNlib
using Statistics
using Flux
# Select the GR plotting backend (Plots.jl); the visible script does not draw any plots.
gr();
"""
    rbf(x)

Gaussian radial-basis activation `exp(-x^2)`, applied elementwise to arrays
(and directly to scalars).
"""
function rbf(x)
    return @. exp(-x^2)
end
"""
    get_nn(l, nodes)

Build a `FastChain` that maps a 6-element state to 6 outputs:
an input layer `6 → nodes`, then `l` hidden layers `nodes → nodes` (all with the `rbf`
activation), then an output layer `nodes → 6` that uses `FastDense`'s default activation.
"""
function get_nn(l::Integer, nodes::Integer)
    # Layers are constructed in the same order as they appear in the chain.
    input_layer = FastDense(6, nodes, rbf)
    hidden_layers = [FastDense(nodes, nodes, rbf) for _ in 1:l]
    output_layer = FastDense(nodes, 6)
    return FastChain(input_layer, hidden_layers..., output_layer)
end;
# Neural correction term: 6 → 10 input layer, one 10 → 10 hidden layer, 10 → 6 output layer.
U = get_nn(1, 10)
# Flat initial parameter vector of `U`; these are the trainable parameters handed to sciml_train.
pinit = initial_params(U)
# Fixed (non-trained) mechanistic rates read by genduniv_ode_ude! via `q.<name>`:
# rhire_*, rattr_*, rprom_* for compartments f1..f3 (states 1-3) and m1..m3 (states 4-6).
# NOTE(review): rprom_f3, rprom_m3 (both 0.0) and growth_rate_linear are not referenced
# by the ODE right-hand side in this script.
theta = ComponentVector{Float64}(rattr_f1 = 0.042687970552806695, rattr_f2 = -0.13366589419787755, rattr_f3 = -0.005666828159994795, rattr_m1 = -0.001969619987196797, rattr_m2 = -0.057685466652091864, rattr_m3 = 0.03484072549732818, rhire_f1 = -0.02268797055304856, rhire_f2 = 0.14466589419764714, rhire_f3 = 0.01666682815999726, rhire_m1 = 0.021969619987239963, rhire_m2 = 0.07768546665210119, rhire_m3 = -0.01484072549733023, rprom_f1 = -0.12398041712603916, rprom_f2 = 0.05063523685785915, rprom_f3 = 0.0, rprom_m1 = -0.023528128109958992, rprom_m2 = 0.11427635100343723, rprom_m3 = 0.0, growth_rate_linear = 0.01)
# Hard-coded observations, flattened column by column: each consecutive run of 6 values
# is one observation of all 6 states.
raw_data = [3.3064507529453775, 2.411147842897867, 3.8541953698304243, 8.515747060715876, 18.034486297095288, 32.886093862709245, 4.48549960824742, 2.5562275466943074, 4.058708308896289, 11.51123402926337, 18.08620682567805, 33.57792521585569, 6.0468671504292155, 2.6907051307235283, 4.397601585885496, 14.814845644929573, 18.08529694293005, 34.252912163683725, 7.192672397387545, 2.937028837016381, 4.709731483233916, 18.723508048394834, 17.436642346365506, 34.89649452732027, 7.352580291945779, 3.4231308506611935, 4.888737719134746, 23.479663335819453, 16.209855858422877, 35.33931656393878, 7.2331655891210325, 4.131368800956079, 4.925093311266256, 28.673073555019325, 15.298779271098546, 35.543004556821245, 7.369935260789615, 4.7141389130161775, 4.678267058817686, 32.549852535972846, 14.849690091055214, 35.522268257020734, 8.307311635173907, 4.68681368146203, 4.477904053217947, 34.35541746900183, 15.452784944041495, 35.78857036748951, 10.143679063434147, 4.503984620423812, 4.419740291236412, 35.61459125937783, 16.042619087690753, 36.75482632364377, 11.98711821793013, 4.408672346554513, 4.895176737702706, 37.02628984889759, 16.003122755644856, 37.83154651006416, 13.654191894783967, 4.321165531477556, 5.602231066759056, 39.42383172231546, 15.980116848398685, 38.360016616294025, 15.07127136879804, 4.495640140672605, 6.186352487270125, 41.47585581747662, 16.727453775574347, 37.662018081417905, 15.243668538849033, 5.042530500538143, 6.328609512332072, 41.12613313537764, 19.000477841200336, 36.2447375671493, 14.613889931882971, 5.935769188767932, 6.199844065932502, 37.692953653917044, 22.89764644624489, 35.626941568081406, 13.694132960603485, 7.265264704518828, 5.931120003995033, 32.681378551084386, 26.61944037304275, 36.04946436642747, 12.989840749786113, 8.41647817962317, 5.874204204500485, 28.09604869428785, 28.0865626319234, 36.84136730252351, 12.460187701403568, 8.963034340620466, 6.496033181057384, 24.678755668173096, 27.472004139049975, 37.611189827329774, 
12.60904088368306, 9.267867186695572, 7.099502466318395, 23.266513884738565, 25.25166618362897, 38.75528565771812, 13.178193080833045, 9.741361564156664, 7.786274544862907, 23.17812389142553, 23.045696601549917, 39.809506965115844, 13.990735014147583, 10.013169718744082, 8.982072355245323, 23.985936218734057, 20.230054255412583, 41.73550295219455, 14.99439640550598, 9.347386530026874, 10.941111084027392, 24.595027358946172, 17.757242325900023, 43.47253111918933, 14.877425968003662, 8.016118599612227, 13.350653962924648, 24.511605087832592, 15.5650983784468, 43.9015285476356, 14.563830542982078, 6.542545706368523, 15.775065661175315, 24.378778838851183, 14.539469120814292, 42.98157944534187, 14.057476431389063, 5.696368873306736, 17.02074365038269, 23.889658063571883, 14.724942331998033, 41.03721512324303, 14.113967621903988, 5.306870615252368, 17.62434605139526, 23.76767610436216, 15.317268849390269, 39.450457715937425, 14.229422124813736, 5.118940992390327, 17.828337152051862, 24.02635906876211, 16.164701004347, 38.93877340118323, 13.655222792709683, 5.07773075279201, 17.35735286039698, 23.928076206731614, 16.565210177074103, 39.28698134784897, 13.319193222138987, 4.665122538459729, 16.671204260330004, 23.8006495075749, 17.059360191354823, 39.78697306031208, 12.456792318264581, 4.308599089715807, 15.665121353876547, 24.123939210471914, 17.9263437802465, 40.263216511865735, 11.711382659973731, 3.6816219282402844, 15.099750324843555, 24.956045460013552, 19.03089773993017, 41.29472719013936, 11.215236446812579, 2.904896534655703, 14.69177798658438, 25.553490886574302, 20.365736659078202, 42.549245707105925]
# 6 × 31 matrix (column-major reshape): row i is state u[i], column j is compared
# against the ODE prediction saved at t = j.
full_data = reshape(raw_data, (6, 31))
"""
    genduniv_ode_ude!(du, u, p, t, q)

In-place right-hand side of the universal ODE.

Each derivative is a mechanistic part plus a neural-network correction:
- The mechanistic part uses the `rhire_*`, `rattr_*` and `rprom_*` rates read from `q`.
- The correction is the output of the global network `U`, evaluated with parameters `p`.

States 1-3 use the `*_f1..3` rates and states 4-6 use the `*_m1..3` rates. States 1, 2, 4
and 5 each lose `rprom_*·u` to the next state in their group. `t` is unused. The function
writes into `du` and returns it.
"""
function genduniv_ode_ude!(du, u, p, t, q)
    correction = U(u, p)  # network output, one entry per state
    f1, f2, f3 = u[1], u[2], u[3]
    m1, m2, m3 = u[4], u[5], u[6]
    # Term order is kept identical to the original so floating-point results match.
    # NOTE(review): the network term enters state 3 with `-` but state 6 with `+`;
    # harmless for training (the network can learn either sign) but confirm intent.
    du[1] = q.rhire_f1*f1 - q.rattr_f1*f1 - q.rprom_f1*f1 - correction[1]
    du[2] = q.rhire_f2*f2 + q.rprom_f1*f1 - q.rattr_f2*f2 - q.rprom_f2*f2 + correction[2]
    du[3] = q.rhire_f3*f3 + q.rprom_f2*f2 - q.rattr_f3*f3 - correction[3]
    du[4] = q.rhire_m1*m1 - q.rattr_m1*m1 - q.rprom_m1*m1 - correction[4]
    du[5] = q.rhire_m2*m2 + q.rprom_m1*m1 - q.rattr_m2*m2 - q.rprom_m2*m2 + correction[5]
    du[6] = q.rhire_m3*m3 + q.rprom_m2*m2 - q.rattr_m3*m3 + correction[6]
    return du
end;
"""
    nn_dynamics!(du, u, p, t)

In-place `f(du, u, p, t)` right-hand side used to build `prob_nn`. It forwards to
`genduniv_ode_ude!`, passing the fixed mechanistic rates in the global `theta`.
"""
function nn_dynamics!(du, u, p, t)
    return genduniv_ode_ude!(du, u, p, t, theta)
end
# Initial state: the first observed column of `full_data`.
u0 = full_data[:, 1]
# Integrate from t = 1 to t = number of observation columns (31.0), one unit per column.
tspan = (1.0, float(size(full_data, 2)))
# Problem template; network parameters are swapped in per evaluation via `remake`.
prob_nn = ODEProblem(nn_dynamics!, u0, tspan, pinit)
"""
    predict_uode(params)

Solve the universal ODE with network parameters `params` and return the solution as a
dense array. Rows are states; columns are save times t = 1, 2, …, 31 (`saveat = 1.0`
over `tspan`).
"""
function predict_uode(params)
    # `remake` builds a NEW problem and does not modify `prob_nn`. The previous version
    # discarded its result and solved `prob_nn` (still holding `pinit`). The loss therefore
    # did not depend on `params`, Zygote returned a `nothing` gradient, and GalacticOptim
    # failed with "Cannot `convert` an object of type Nothing to an object of type Float32".
    prob = remake(prob_nn; p = params)
    return Array(solve(prob, Vern7(); saveat = 1.0))
end;
"""
    loss_uode(p)

Sum of squared errors between the observations `full_data` and the UODE prediction
for parameters `p`. Returns `(loss, prediction)`. `sciml_train` forwards the extra
`prediction` value to the callback.
"""
function loss_uode(p)
    prediction = predict_uode(p)
    residuals = full_data .- prediction
    # Optional L2 penalty previously tried: + Float32(1e-4) * sum(abs2, p) / length(p)
    return sum(abs2, residuals), prediction
end;
# Training callback: show the current loss and return `false` so optimisation continues.
callback_display = (p, l, pred) -> begin
    display(l)
    false
end
# Train the network parameters with ADAM (step size 0.1) for at most 100 iterations,
# showing the loss after each step.
result_ode_uode = DiffEqFlux.sciml_train(
    loss_uode, pinit, ADAM(0.1);
    cb = callback_display,
    maxiters = 100,
)
Metadata
Metadata
Assignees
Labels
No labels