Skip to content

ERROR: Compiling ...: try/catch is not supported. #660

@JinraeKim

Description

@JinraeKim

I'm trying to find which operation works with DiffEqFlux.jl.

For now, I'm trying to make the output of ODE a dictionary like ode_data = (; t=_sol.t, u=_sol.u).
The following code is modified from a tutorial.
Note that it works when the dictionary is replaced by NamedTuple.

What should I do if I wanna incorporate dictionary-like data form?

  • Code
using DiffEqFlux, DifferentialEquations, Plots
using DataFrames
gr()
# unicodeplots()


function main()
    # u0 = Float32[2.; 0.]
    u0 = [2.0, 1.0]
    datasize = 30
    # tspan = (0.0f0,1.5f0)
    tspan = (0.0, 1.5)

    function trueODEfunc(du,u,p,t)
        true_A = [-0.1 2.0; -2.0 -0.1]
        du .= ((u.^3)'true_A)'
    end
    t = range(tspan[1], tspan[2], length=datasize)
    prob = ODEProblem(trueODEfunc,u0,tspan)
    _sol = solve(prob, Tsit5(), saveat=t)
    # ode_data = _sol.u
    # ode_data = (; t=_sol.t, u=_sol.u)
    ode_data = Dict("t" => _sol.t, "u" => _sol.u)
    # ode_data = DataFrame(
    #                time = _sol.t,
    #                sol = _sol.u,
    #               )
    # ode_data = Array(solve(prob,Tsit5(),saveat=t))
    @show ode_data

    dudt2 = Chain(
                  Dense(2,50,tanh),
                  Dense(50,2),
                 )
    p,re = Flux.destructure(dudt2) # use this p as the initial condition!
    dudt(u,p,t) = re(p)(u) # need to restructure for backprop!
    prob = ODEProblem(dudt,u0,tspan)

    function predict_n_ode()
        # Array(solve(prob,Tsit5(),u0=u0,p=p,saveat=t))
        _sol = solve(prob,Tsit5(),u0=u0,p=p,saveat=t)
        # sol = _sol.u
        # (; t=_sol.t, u=_sol.u)
        Dict("t" => _sol.t, "u" => _sol.u)
        # sol = DataFrame(
        #                 time = _sol.t,
        #                 sol = _sol.u,
        #                )
    end

    function loss_n_ode()
        pred = predict_n_ode()
        loss = sum(abs2, hcat((ode_data["u"] .- pred["u"])...))
        loss
    end

    loss_n_ode() # n_ode.p stores the initial parameters of the neural ODE

    cb = function (;doplot=false) # callback function to observe training
        pred = predict_n_ode()
        # display(sum(abs2, hcat((ode_data .- pred)...)))
        display(sum(abs2, hcat((ode_data["u"] .- pred["u"])...)))
        # plot current prediction against data
        pl = scatter(t, hcat(ode_data["u"]...)', label="data")
        scatter!(pl, t, hcat(pred["u"]...)', label="prediction")
        display(plot(pl))
        return false
    end

    # Display the ODE with the initial parameter values.
    cb()

    data = Iterators.repeated((), 1000)
    Flux.train!(loss_n_ode, Flux.params(u0, p), data, ADAM(0.05), cb = cb)
end
  • Result
julia> main()
ode_data = Dict{Symbol, Vector}(:u => [[2.0, 1.0], [1.6862338960057244, 1.6718793919756492], [1.0319926450388706, 1.9265992260865], [0.27927162692550755, 1.9277348444612303], [-0.44108061884417915, 1.8905379645843547], [-1.099394804047414, 1.7998796981753649], [-1.575819586008001, 1.508232110833719], [-1.7685434957097326, 0.9927467752939698], [-1.7808339647673836, 0.4048093142853212], [-1.7533838123474577, -0.1665247268985358], [-1.7142675840018735, -0.7076658551805182], [-1.5951207960146605, -1.1811846904105572], [-1.3201585947206524, -1.4991858648843748], [-0.9059421368674584, -1.6312753851031085], [-0.4451488059926799, -1.643953795405061], [0.006614913904206654, -1.6235362077639393], [0.4399648514335317, -1.5996393052658449], [0.8454003544551636, -1.5488250899589673], [1.1839973764833611, -1.4178276628884785], [1.4036848081489357, -1.1758209616758337], [1.4998342218933636, -0.8491937206357362], [1.515299249134128, -0.48998366495261025], [1.501666229797059, -0.13402233642947514], [1.4843567660530883, 0.21026351700407508], [1.461124734420293, 0.5412488731339737], [1.4092246360030882, 0.8476591177188557], [1.297019818716926, 1.1026048775496469], [1.109096726239168, 1.2780039297165904], [0.8612172546154794, 1.3673022234110102], [0.5844235838429759, 1.3940953012997077]], :t => [0.0, 0.05172413793103448, 0.10344827586206896, 0.15517241379310345, 0.20689655172413793, 0.25862068965517243, 0.3103448275862069, 0.3620689655172414, 0.41379310344827586, 0.46551724137931033, 0.5172413793103449, 0.5689655172413793, 0.6206896551724138, 0.6724137931034483, 0.7241379310344828, 0.7758620689655172, 0.8275862068965517, 0.8793103448275862, 0.9310344827586207, 0.9827586206896551, 1.0344827586206897, 1.0862068965517242, 1.1379310344827587, 1.1896551724137931, 1.2413793103448276, 1.293103448275862, 1.3448275862068966, 1.396551724137931, 1.4482758620689655, 1.5])
160.55385058999687
ERROR: Compiling Tuple{Type{Dict}, Tuple{Pair{Symbol, Vector{Float64}}, Pair{Symbol, Vector{Vector{Float64}}}}}: try/catch is not supported.
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] instrument(ir::IRTools.Inner.IR)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/reverse.jl:121
  [3] #Primal#20
    @ ~/.julia/packages/Zygote/bJn8I/src/compiler/reverse.jl:202 [inlined]
  [4] Zygote.Adjoint(ir::IRTools.Inner.IR; varargs::Nothing, normalise::Bool)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/reverse.jl:315
  [5] _generate_pullback_via_decomposition(T::Type)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/emit.jl:101
  [6] #s3063#1218
    @ ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:28 [inlined]
  [7] var"#s3063#1218"(::Any, ctx::Any, f::Any, args::Any)
    @ Zygote ./none:0
  [8] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:580
  [9] _pullback
    @ ./dict.jl:125 [inlined]
 [10] _pullback(::Zygote.Context, ::Type{Dict}, ::Pair{Symbol, Vector{Float64}}, ::Pair{Symbol, Vector{Vector{Float64}}})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [11] _pullback
    @ ~/.julia/dev/ContinuousTimePolicyGradients/test/model-estimation/toy.jl:44 [inlined]
 [12] _pullback(::Zygote.Context, ::var"#predict_n_ode#410"{Vector{Float32}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Vector{Float64}})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [13] _pullback
    @ ~/.julia/dev/ContinuousTimePolicyGradients/test/model-estimation/toy.jl:52 [inlined]
 [14] _pullback(::Zygote.Context, ::var"#loss_n_ode#411"{var"#predict_n_ode#410"{Vector{Float32}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Vector{Float64}}, Dict{Symbol, Vector}})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [15] _apply
    @ ./boot.jl:814 [inlined]
 [16] adjoint
    @ ~/.julia/packages/Zygote/bJn8I/src/lib/lib.jl:200 [inlined]
 [17] _pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
 [18] _pullback
    @ ~/.julia/packages/Flux/BPPNj/src/optimise/train.jl:105 [inlined]
 [19] _pullback(::Zygote.Context, ::Flux.Optimise.var"#39#45"{var"#loss_n_ode#411"{var"#predict_n_ode#410"{Vector{Float32}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Vector{Float64}}, Dict{Symbol, Vector}}, Tuple{}})
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
 [20] pullback(f::Function, ps::Params)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface.jl:351
 [21] gradient(f::Function, args::Params)
    @ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface.jl:75
 [22] macro expansion
    @ ~/.julia/packages/Flux/BPPNj/src/optimise/train.jl:104 [inlined]
 [23] macro expansion
    @ ~/.julia/packages/Juno/n6wyj/src/progress.jl:134 [inlined]
 [24] train!(loss::Function, ps::Params, data::Base.Iterators.Take{Base.Iterators.Repeated{Tuple{}}}, opt::ADAM; cb::var"#406#412"{var"#406#407#413"{var"#predict_n_ode#410"{Vector{Float32}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Vector{Float64}}, Dict{Symbol, Vector}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}}})
    @ Flux.Optimise ~/.julia/packages/Flux/BPPNj/src/optimise/train.jl:102
 [25] main()
    @ Main ~/.julia/dev/ContinuousTimePolicyGradients/test/model-estimation/toy.jl:74
 [26] top-level scope
    @ REPL[38]:1
 [27] top-level scope
    @ ~/.julia/packages/CUDA/YpW0k/src/initialization.jl:52

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions