-
-
Notifications
You must be signed in to change notification settings - Fork 161
Closed
Description
I'm trying to find which operation works with DiffEqFlux.jl.
For now, I'm trying to make the output of ODE a dictionary like ode_data = (; t=_sol.t, u=_sol.u).
The following code is modified from a tutorial.
Note that it works when the dictionary is replaced by NamedTuple.
What should I do if I wanna incorporate dictionary-like data form?
- Code
using DiffEqFlux, DifferentialEquations, Plots
using DataFrames
gr()
# unicodeplots()
function main()
# u0 = Float32[2.; 0.]
u0 = [2.0, 1.0]
datasize = 30
# tspan = (0.0f0,1.5f0)
tspan = (0.0, 1.5)
function trueODEfunc(du,u,p,t)
true_A = [-0.1 2.0; -2.0 -0.1]
du .= ((u.^3)'true_A)'
end
t = range(tspan[1], tspan[2], length=datasize)
prob = ODEProblem(trueODEfunc,u0,tspan)
_sol = solve(prob, Tsit5(), saveat=t)
# ode_data = _sol.u
# ode_data = (; t=_sol.t, u=_sol.u)
ode_data = Dict("t" => _sol.t, "u" => _sol.u)
# ode_data = DataFrame(
# time = _sol.t,
# sol = _sol.u,
# )
# ode_data = Array(solve(prob,Tsit5(),saveat=t))
@show ode_data
dudt2 = Chain(
Dense(2,50,tanh),
Dense(50,2),
)
p,re = Flux.destructure(dudt2) # use this p as the initial condition!
dudt(u,p,t) = re(p)(u) # need to restructure for backprop!
prob = ODEProblem(dudt,u0,tspan)
function predict_n_ode()
# Array(solve(prob,Tsit5(),u0=u0,p=p,saveat=t))
_sol = solve(prob,Tsit5(),u0=u0,p=p,saveat=t)
# sol = _sol.u
# (; t=_sol.t, u=_sol.u)
Dict("t" => _sol.t, "u" => _sol.u)
# sol = DataFrame(
# time = _sol.t,
# sol = _sol.u,
# )
end
function loss_n_ode()
pred = predict_n_ode()
loss = sum(abs2, hcat((ode_data["u"] .- pred["u"])...))
loss
end
loss_n_ode() # n_ode.p stores the initial parameters of the neural ODE
cb = function (;doplot=false) # callback function to observe training
pred = predict_n_ode()
# display(sum(abs2, hcat((ode_data .- pred)...)))
display(sum(abs2, hcat((ode_data["u"] .- pred["u"])...)))
# plot current prediction against data
pl = scatter(t, hcat(ode_data["u"]...)', label="data")
scatter!(pl, t, hcat(pred["u"]...)', label="prediction")
display(plot(pl))
return false
end
# Display the ODE with the initial parameter values.
cb()
data = Iterators.repeated((), 1000)
Flux.train!(loss_n_ode, Flux.params(u0, p), data, ADAM(0.05), cb = cb)
end- Result
julia> main()
ode_data = Dict{Symbol, Vector}(:u => [[2.0, 1.0], [1.6862338960057244, 1.6718793919756492], [1.0319926450388706, 1.9265992260865], [0.27927162692550755, 1.9277348444612303], [-0.44108061884417915, 1.8905379645843547], [-1.099394804047414, 1.7998796981753649], [-1.575819586008001, 1.508232110833719], [-1.7685434957097326, 0.9927467752939698], [-1.7808339647673836, 0.4048093142853212], [-1.7533838123474577, -0.1665247268985358], [-1.7142675840018735, -0.7076658551805182], [-1.5951207960146605, -1.1811846904105572], [-1.3201585947206524, -1.4991858648843748], [-0.9059421368674584, -1.6312753851031085], [-0.4451488059926799, -1.643953795405061], [0.006614913904206654, -1.6235362077639393], [0.4399648514335317, -1.5996393052658449], [0.8454003544551636, -1.5488250899589673], [1.1839973764833611, -1.4178276628884785], [1.4036848081489357, -1.1758209616758337], [1.4998342218933636, -0.8491937206357362], [1.515299249134128, -0.48998366495261025], [1.501666229797059, -0.13402233642947514], [1.4843567660530883, 0.21026351700407508], [1.461124734420293, 0.5412488731339737], [1.4092246360030882, 0.8476591177188557], [1.297019818716926, 1.1026048775496469], [1.109096726239168, 1.2780039297165904], [0.8612172546154794, 1.3673022234110102], [0.5844235838429759, 1.3940953012997077]], :t => [0.0, 0.05172413793103448, 0.10344827586206896, 0.15517241379310345, 0.20689655172413793, 0.25862068965517243, 0.3103448275862069, 0.3620689655172414, 0.41379310344827586, 0.46551724137931033, 0.5172413793103449, 0.5689655172413793, 0.6206896551724138, 0.6724137931034483, 0.7241379310344828, 0.7758620689655172, 0.8275862068965517, 0.8793103448275862, 0.9310344827586207, 0.9827586206896551, 1.0344827586206897, 1.0862068965517242, 1.1379310344827587, 1.1896551724137931, 1.2413793103448276, 1.293103448275862, 1.3448275862068966, 1.396551724137931, 1.4482758620689655, 1.5])
160.55385058999687
ERROR: Compiling Tuple{Type{Dict}, Tuple{Pair{Symbol, Vector{Float64}}, Pair{Symbol, Vector{Vector{Float64}}}}}: try/catch is not supported.
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] instrument(ir::IRTools.Inner.IR)
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/reverse.jl:121
[3] #Primal#20
@ ~/.julia/packages/Zygote/bJn8I/src/compiler/reverse.jl:202 [inlined]
[4] Zygote.Adjoint(ir::IRTools.Inner.IR; varargs::Nothing, normalise::Bool)
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/reverse.jl:315
[5] _generate_pullback_via_decomposition(T::Type)
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/emit.jl:101
[6] #s3063#1218
@ ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:28 [inlined]
[7] var"#s3063#1218"(::Any, ctx::Any, f::Any, args::Any)
@ Zygote ./none:0
[8] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
@ Core ./boot.jl:580
[9] _pullback
@ ./dict.jl:125 [inlined]
[10] _pullback(::Zygote.Context, ::Type{Dict}, ::Pair{Symbol, Vector{Float64}}, ::Pair{Symbol, Vector{Vector{Float64}}})
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[11] _pullback
@ ~/.julia/dev/ContinuousTimePolicyGradients/test/model-estimation/toy.jl:44 [inlined]
[12] _pullback(::Zygote.Context, ::var"#predict_n_ode#410"{Vector{Float32}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Vector{Float64}})
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[13] _pullback
@ ~/.julia/dev/ContinuousTimePolicyGradients/test/model-estimation/toy.jl:52 [inlined]
[14] _pullback(::Zygote.Context, ::var"#loss_n_ode#411"{var"#predict_n_ode#410"{Vector{Float32}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Vector{Float64}}, Dict{Symbol, Vector}})
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[15] _apply
@ ./boot.jl:814 [inlined]
[16] adjoint
@ ~/.julia/packages/Zygote/bJn8I/src/lib/lib.jl:200 [inlined]
[17] _pullback
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:65 [inlined]
[18] _pullback
@ ~/.julia/packages/Flux/BPPNj/src/optimise/train.jl:105 [inlined]
[19] _pullback(::Zygote.Context, ::Flux.Optimise.var"#39#45"{var"#loss_n_ode#411"{var"#predict_n_ode#410"{Vector{Float32}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Vector{Float64}}, Dict{Symbol, Vector}}, Tuple{}})
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface2.jl:0
[20] pullback(f::Function, ps::Params)
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface.jl:351
[21] gradient(f::Function, args::Params)
@ Zygote ~/.julia/packages/Zygote/bJn8I/src/compiler/interface.jl:75
[22] macro expansion
@ ~/.julia/packages/Flux/BPPNj/src/optimise/train.jl:104 [inlined]
[23] macro expansion
@ ~/.julia/packages/Juno/n6wyj/src/progress.jl:134 [inlined]
[24] train!(loss::Function, ps::Params, data::Base.Iterators.Take{Base.Iterators.Repeated{Tuple{}}}, opt::ADAM; cb::var"#406#412"{var"#406#407#413"{var"#predict_n_ode#410"{Vector{Float32}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, Vector{Float64}}, Dict{Symbol, Vector}, StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}}})
@ Flux.Optimise ~/.julia/packages/Flux/BPPNj/src/optimise/train.jl:102
[25] main()
@ Main ~/.julia/dev/ContinuousTimePolicyGradients/test/model-estimation/toy.jl:74
[26] top-level scope
@ REPL[38]:1
[27] top-level scope
@ ~/.julia/packages/CUDA/YpW0k/src/initialization.jl:52
Metadata
Metadata
Assignees
Labels
No labels