
Training a UODE and get error Cannot convert an object of type Nothing to an object of type Float32 #703

Open
@00krishna

Hello. I was training a simple UODE and am encountering this error when running sciml_train(). The error seems to be somewhere in the interface between DiffEqFlux and GalacticOptim. I have the specific error message, and some code to replicate the problem.

Here is the error message text:

ERROR: MethodError: Cannot `convert` an object of type Nothing to an object of type Float32
Closest candidates are:
  convert(::Type{T}, ::Static.StaticFloat64{N}) where {N, T<:AbstractFloat} at ~/.julia/packages/Static/8hh0B/src/float.jl:26
  convert(::Type{T}, ::LLVM.GenericValue, ::LLVM.LLVMType) where T<:AbstractFloat at ~/.julia/packages/LLVM/tVv0H/src/execution.jl:39
  convert(::Type{T}, ::LLVM.ConstantFP) where T<:AbstractFloat at ~/.julia/packages/LLVM/tVv0H/src/core/value/constant.jl:103
  ...
Stacktrace:
  [1] fill!(dest::Vector{Float32}, x::Nothing)
    @ Base ./array.jl:351
  [2] copyto!
    @ ./broadcast.jl:921 [inlined]
  [3] materialize!
    @ ./broadcast.jl:871 [inlined]
  [4] materialize!(dest::Vector{Float32}, bc::Base.Broadcast.Broad)
    @ Base.Broadcast ./broadcast.jl:868
  [5] (::GalacticOptim.var"#268#278"{GalacticOptim.var"#267#277"{OptimizationFunction{true, GalacticOptim.AutoZygote, DiffEqFlux.var"#84#89"{typeof(loss_uode)}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing}, Nothing}})(::Vector{Float32}, ::Vector{Float32})
    @ GalacticOptim ~/.julia/packages/GalacticOptim/fow0r/src/function/zygote.jl:8
  [6] macro expansion
    @ ~/.julia/packages/GalacticOptim/fow0r/src/solve/flux.jl:27 [inlined]
  [7] macro expansion
    @ ~/.julia/packages/GalacticOptim/fow0r/src/utils.jl:35 [inlined]
  [8] __solve(prob::OptimizationProblem{, opt::ADAM, data::Base.Iterators.Cycle; maxiters::Int64, cb::Function, progress::Bool, save_best::Bool, kwargs::Base.Pairs{Symbol, U)
    @ GalacticOptim ~/.julia/packages/GalacticOptim/fow0r/src/solve/flux.jl:25
  [9] #solve#482
    @ ~/.julia/packages/SciMLBase/OHiiA/src/solve.jl:3 [inlined]
 [10] sciml_train(::typeof(loss_uode), ::Vector{Float32}, ::ADAM, ::Nothing; lower_bounds::Nothing, upper_bounds::Nothing, maxiters::Int64, kwargs::Base.Pairs{Symbol, U)
    @ DiffEqFlux ~/.julia/packages/DiffEqFlux/vJuRw/src/train.jl:91
 [11] top-level scope
    @ ~/Dropbox/sandbox/julia_gend_univ/experiments/uode_benchmarking/mwe.jl:71
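
Frames [1] to [4] of the trace show a `Vector{Float32}` buffer being filled by broadcasting a value of `nothing`, which suggests that the automatic differentiation step returned `nothing` instead of a gradient vector. The sketch below reproduces only that mechanism in plain Julia. The names `G` and `grad` are hypothetical, and it assumes (not confirmed from the GalacticOptim source) that the gradient is written into the buffer with an in-place broadcast.

```julia
# Hypothetical stand-ins: G plays the role of the gradient buffer,
# grad the role of the value returned by the AD backend.
G = zeros(Float32, 3)
grad = nothing

# An in-place broadcast of a scalar-like value dispatches to fill!,
# which tries convert(Float32, nothing) and throws a MethodError.
err = try
    G .= grad
    nothing
catch e
    e
end

println(err isa MethodError)
```

Zygote returns `nothing` for an argument that the differentiated function never uses, so a loss that does not actually depend on its parameter vector is one plausible way to end up in this state.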

Here is the code to replicate the problem. I hardcoded some data so that anyone can replicate the issue exactly. The code itself is based on some of the examples in the Universal ODE GitHub repo.

using StatsBase
using Plots 
using DifferentialEquations
using DiffEqFlux
using StatsPlots
using ComponentArrays
using Distributions
using DifferentialEquations.EnsembleAnalysis
using SciMLBase
using DiffEqCallbacks
using GalacticOptim
using NLopt
using NNlib
using Statistics
using Flux
gr();

rbf(x) = exp.(-(x.^2))

function get_nn(l::Integer, nodes::Integer)
    FastChain(
        FastDense(6, nodes, rbf),
        (FastDense(nodes, nodes, rbf) for _ in 1:l)...,
        FastDense(nodes, 6))
end;

U = get_nn(1, 10)

pinit = initial_params(U)

theta = ComponentVector{Float64}(rattr_f1 = 0.042687970552806695, rattr_f2 = -0.13366589419787755, rattr_f3 = -0.005666828159994795, rattr_m1 = -0.001969619987196797, rattr_m2 = -0.057685466652091864, rattr_m3 = 0.03484072549732818, rhire_f1 = -0.02268797055304856, rhire_f2 = 0.14466589419764714, rhire_f3 = 0.01666682815999726, rhire_m1 = 0.021969619987239963, rhire_m2 = 0.07768546665210119, rhire_m3 = -0.01484072549733023, rprom_f1 = -0.12398041712603916, rprom_f2 = 0.05063523685785915, rprom_f3 = 0.0, rprom_m1 = -0.023528128109958992, rprom_m2 = 0.11427635100343723, rprom_m3 = 0.0, growth_rate_linear = 0.01)

raw_data = [3.3064507529453775, 2.411147842897867, 3.8541953698304243, 8.515747060715876, 18.034486297095288, 32.886093862709245, 4.48549960824742, 2.5562275466943074, 4.058708308896289, 11.51123402926337, 18.08620682567805, 33.57792521585569, 6.0468671504292155, 2.6907051307235283, 4.397601585885496, 14.814845644929573, 18.08529694293005, 34.252912163683725, 7.192672397387545, 2.937028837016381, 4.709731483233916, 18.723508048394834, 17.436642346365506, 34.89649452732027, 7.352580291945779, 3.4231308506611935, 4.888737719134746, 23.479663335819453, 16.209855858422877, 35.33931656393878, 7.2331655891210325, 4.131368800956079, 4.925093311266256, 28.673073555019325, 15.298779271098546, 35.543004556821245, 7.369935260789615, 4.7141389130161775, 4.678267058817686, 32.549852535972846, 14.849690091055214, 35.522268257020734, 8.307311635173907, 4.68681368146203, 4.477904053217947, 34.35541746900183, 15.452784944041495, 35.78857036748951, 10.143679063434147, 4.503984620423812, 4.419740291236412, 35.61459125937783, 16.042619087690753, 36.75482632364377, 11.98711821793013, 4.408672346554513, 4.895176737702706, 37.02628984889759, 16.003122755644856, 37.83154651006416, 13.654191894783967, 4.321165531477556, 5.602231066759056, 39.42383172231546, 15.980116848398685, 38.360016616294025, 15.07127136879804, 4.495640140672605, 6.186352487270125, 41.47585581747662, 16.727453775574347, 37.662018081417905, 15.243668538849033, 5.042530500538143, 6.328609512332072, 41.12613313537764, 19.000477841200336, 36.2447375671493, 14.613889931882971, 5.935769188767932, 6.199844065932502, 37.692953653917044, 22.89764644624489, 35.626941568081406, 13.694132960603485, 7.265264704518828, 5.931120003995033, 32.681378551084386, 26.61944037304275, 36.04946436642747, 12.989840749786113, 8.41647817962317, 5.874204204500485, 28.09604869428785, 28.0865626319234, 36.84136730252351, 12.460187701403568, 8.963034340620466, 6.496033181057384, 24.678755668173096, 27.472004139049975, 37.611189827329774, 
12.60904088368306, 9.267867186695572, 7.099502466318395, 23.266513884738565, 25.25166618362897, 38.75528565771812, 13.178193080833045, 9.741361564156664, 7.786274544862907, 23.17812389142553, 23.045696601549917, 39.809506965115844, 13.990735014147583, 10.013169718744082, 8.982072355245323, 23.985936218734057, 20.230054255412583, 41.73550295219455, 14.99439640550598, 9.347386530026874, 10.941111084027392, 24.595027358946172, 17.757242325900023, 43.47253111918933, 14.877425968003662, 8.016118599612227, 13.350653962924648, 24.511605087832592, 15.5650983784468, 43.9015285476356, 14.563830542982078, 6.542545706368523, 15.775065661175315, 24.378778838851183, 14.539469120814292, 42.98157944534187, 14.057476431389063, 5.696368873306736, 17.02074365038269, 23.889658063571883, 14.724942331998033, 41.03721512324303, 14.113967621903988, 5.306870615252368, 17.62434605139526, 23.76767610436216, 15.317268849390269, 39.450457715937425, 14.229422124813736, 5.118940992390327, 17.828337152051862, 24.02635906876211, 16.164701004347, 38.93877340118323, 13.655222792709683, 5.07773075279201, 17.35735286039698, 23.928076206731614, 16.565210177074103, 39.28698134784897, 13.319193222138987, 4.665122538459729, 16.671204260330004, 23.8006495075749, 17.059360191354823, 39.78697306031208, 12.456792318264581, 4.308599089715807, 15.665121353876547, 24.123939210471914, 17.9263437802465, 40.263216511865735, 11.711382659973731, 3.6816219282402844, 15.099750324843555, 24.956045460013552, 19.03089773993017, 41.29472719013936, 11.215236446812579, 2.904896534655703, 14.69177798658438, 25.553490886574302, 20.365736659078202, 42.549245707105925]

full_data = reshape(raw_data, (6, 31))


function genduniv_ode_ude!(du, u, p, t, q)

    û = U(u, p) # network prediction 

    du[1] = q.rhire_f1*u[1] - q.rattr_f1*u[1] - q.rprom_f1*u[1] - û[1]
    du[2] = q.rhire_f2*u[2] + q.rprom_f1*u[1] - q.rattr_f2*u[2] - q.rprom_f2*u[2] + û[2] 
    du[3] = q.rhire_f3*u[3] + q.rprom_f2*u[2] - q.rattr_f3*u[3] - û[3]
    du[4] = q.rhire_m1*u[4] - q.rattr_m1*u[4] - q.rprom_m1*u[4] - û[4]
    du[5] = q.rhire_m2*u[5] + q.rprom_m1*u[4] - q.rattr_m2*u[5] - q.rprom_m2*u[5] + û[5]
    du[6] = q.rhire_m3*u[6] + q.rprom_m2*u[5] - q.rattr_m3*u[6] + û[6]
    du
end;

nn_dynamics!(du, u, p, t) = genduniv_ode_ude!(du, u, p, t, theta)


u0 = full_data[:, 1]
tspan = (1.0, 31.0)

prob_nn = ODEProblem(nn_dynamics!, u0, tspan, pinit)

function predict_uode(params)
    remake(prob_nn, p=params)
    Array(solve(prob_nn, Vern7(), saveat = 1.0))
end;

function loss_uode(p)
    pred = predict_uode(p)
    loss = sum(abs2, full_data .- pred) #+ Float32(1e-4)*sum(abs2, p)/length(p) # Just sum of squared error
    return loss, pred
end;

callback_display = function (p, l, pred)
    display(l)
    return false  
end

result_ode_uode = DiffEqFlux.sciml_train(loss_uode, pinit, ADAM(0.1), cb = callback_display, maxiters = 100) 
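
One possible cause, offered as a guess rather than a confirmed diagnosis: in `predict_uode` the return value of `remake` is discarded, so `solve` still runs on `prob_nn` with `pinit` and the prediction does not depend on `params`. `remake` returns a new problem and leaves the original untouched. The self-contained sketch below illustrates this on a toy decay ODE; `decay!`, `prob`, `new_prob`, and `predict` are hypothetical names, not part of the code above.

```julia
using DifferentialEquations

# Toy in-place ODE: du/dt = -p[1] * u
function decay!(du, u, p, t)
    du[1] = -p[1] * u[1]
    return nothing
end

prob = ODEProblem(decay!, [1.0], (0.0, 1.0), [1.0])

# remake builds a new problem; prob itself keeps its original parameters
new_prob = remake(prob, p = [2.0])
println(prob.p)      # parameters of the original problem
println(new_prob.p)  # parameters of the remade problem

# Solve the remade problem so that the output depends on params
function predict(params)
    remade = remake(prob, p = params)
    Array(solve(remade, Vern7(), saveat = 1.0))
end
```

Applied to the code above, the analogous change would be to bind the result of `remake(prob_nn, p=params)` to a variable and pass that variable to `solve`.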
