-
-
Notifications
You must be signed in to change notification settings - Fork 39
Closed
Description
v0.8.20 seems to have broken autodiff code that worked fine in v0.8.19.
MWE (excuse the clutter)
using DiffEqFlux, DiffEqSensitivity, Flux, OrdinaryDiffEq, Zygote, Test #using Plots
using DiffEqBase: get_tmp, dualcache
using ComponentArrays
using Parameters
using ForwardDiff
using ReverseDiff
p = ComponentArray(lvpara=ComponentArray(α=2.2,β=1.0,δ=2.0,γ=0.4),a=1.0,b=1.0)
u0 = ComponentArray(state=ComponentArray(x=1.0,y=1.0))
ax_p = getaxes(p)
ax_u = getaxes(u0)
chunk_size(dual::Type{ForwardDiff.Dual{T,V,N}}) where {T,V,N} = N
select(a::AbstractArray, u::AbstractArray) = a
select(dc::DiffEqBase.DiffCache, u) = get_tmp(dc,u)
select(a::AbstractArray, u::ReverseDiff.TrackedArray) = begin
x = similar(u,size(a))
x .= a
x
end
struct LotkaVolterra{T}
d::T
end
function (lv::LotkaVolterra)(du,u_,p,t)
d = select(lv.d,u_)
u = ComponentArray(u_,ax_u)
p = ComponentArray(p,ax_p)
@unpack x,y = u.state
@unpack lvpara, a, b = p
@unpack α, β, δ, γ = lvpara
d[1] = a^2+b^2
d[2] = b^2-a^2
du[1] = dx = (α - β*y)x + d[1]
du[2] = dy = (δ*x - γ)y + d[2]
end
u0 = Array(u0)
p = Array(p)
lv = LotkaVolterra(dualcache(zeros(2),Val{6}))
prob = ODEProblem(lv,u0,(0.0,1.0),p)
function predict_rd(p)
Array(solve(prob,Tsit5(),p=p,saveat=0.1,reltol=1e-4,sensealg=ForwardDiffSensitivity()))
end
loss_rd(p) = sum(abs2,x-1 for x in predict_rd(p))
opt = ADAM(0.1)
cb = function (p,l,pred)
display(loss_rd(p))
#display(plot(solve(remake(prob,p=p),Tsit5(),saveat=0.1),ylim=(0,6)))
end
@time res = DiffEqFlux.sciml_train(loss_rd, p, opt, maxiters=100)Error message:
MethodError: no method matching ComponentArray(::var"#7#8", ::Array{Float64,2})
Closest candidates are:
ComponentArray(::Any, !Matched::FlatAxis...) at /home/brian/.julia/packages/ComponentArrays/zHt90/src/componentarray.jl:50
ComponentArray(::Any, !Matched::Union{FlatAxis, ComponentArrays.NullAxis, Axis{IdxMap}} where IdxMap...) at /home/brian/.julia/packages/ComponentArrays/zHt90/src/componentarray.jl:51
ComponentArray(::Any, !Matched::Union{FlatAxis, ComponentArrays.NullAxis, Axis{IdxMap}, ShapedAxis{Shape,IdxMap}} where IdxMap where Shape...) at /home/brian/.julia/packages/ComponentArrays/zHt90/src/componentarray.jl:52
...
rrule(::UnionAll, ::Function, ::Array{Float64,2}) at chainrulescore.jl:23
chain_rrule at chainrules.jl:89 [inlined]
macro expansion at interface2.jl:0 [inlined]
_pullback(::Zygote.Context, ::Type{Base.Generator}, ::var"#7#8", ::Array{Float64,2}) at interface2.jl:9
loss_rd at lotka_volterra.jl:47 [inlined]
_pullback(::Zygote.Context, ::typeof(loss_rd), ::Array{Float64,1}) at interface2.jl:0
#69 at train.jl:3 [inlined]
_pullback(::Zygote.Context, ::DiffEqFlux.var"#69#70"{typeof(loss_rd)}, ::Array{Float64,1}, ::SciMLBase.NullParameters) at interface2.jl:0
adjoint at lib.jl:188 [inlined]
_pullback at adjoint.jl:57 [inlined]
OptimizationFunction at basic_problems.jl:107 [inlined]
_pullback(::Zygote.Context, ::OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss_rd)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}, ::Array{Float64,1}, ::SciMLBase.NullParameters) at interface2.jl:0
adjoint at lib.jl:188 [inlined]
adjoint(::Zygote.Context, ::typeof(Core._apply_iterate), ::typeof(iterate), ::OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss_rd)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}, ::Tuple{Array{Float64,1},SciMLBase.NullParameters}) at none:0
_pullback at adjoint.jl:57 [inlined]
OptimizationFunction at basic_problems.jl:107 [inlined]
_pullback(::Zygote.Context, ::OptimizationFunction{false,GalacticOptim.AutoZygote,OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss_rd)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},GalacticOptim.var"#146#156"{GalacticOptim.var"#145#155"{OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss_rd)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing}},GalacticOptim.var"#149#159"{GalacticOptim.var"#145#155"{OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss_rd)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing}},GalacticOptim.var"#154#164",Nothing,Nothing,Nothing}, ::Array{Float64,1}, ::SciMLBase.NullParameters) at interface2.jl:0
adjoint at lib.jl:188 [inlined]
_pullback at adjoint.jl:57 [inlined]
#8 at solve.jl:94 [inlined]
_pullback(::Zygote.Context, ::GalacticOptim.var"#8#13"{OptimizationProblem{false,OptimizationFunction{false,GalacticOptim.AutoZygote,OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss_rd)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},GalacticOptim.var"#146#156"{GalacticOptim.var"#145#155"{OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss_rd)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing}},GalacticOptim.var"#149#159"{GalacticOptim.var"#145#155"{OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss_rd)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing}},GalacticOptim.var"#154#164",Nothing,Nothing,Nothing},Array{Float64,1},SciMLBase.NullParameters,Nothing,Nothing,Nothing,Base.Iterators.Pairs{Symbol,Int64,Tuple{Symbol},NamedTuple{(:maxiters,),Tuple{Int64}}}},Array{Float64,1},GalacticOptim.NullData}) at interface2.jl:0
pullback(::Function, ::Params) at interface.jl:167
gradient(::Function, ::Params) at interface.jl:48
__solve(::OptimizationProblem{false,OptimizationFunction{false,GalacticOptim.AutoZygote,OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss_rd)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},GalacticOptim.var"#146#156"{GalacticOptim.var"#145#155"{OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss_rd)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing}},GalacticOptim.var"#149#159"{GalacticOptim.var"#145#155"{OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss_rd)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing}},GalacticOptim.var"#154#164",Nothing,Nothing,Nothing},Array{Float64,1},SciMLBase.NullParameters,Nothing,Nothing,Nothing,Base.Iterators.Pairs{Symbol,Int64,Tuple{Symbol},NamedTuple{(:maxiters,),Tuple{Int64}}}}, ::ADAM, ::Base.Iterators.Cycle{Tuple{GalacticOptim.NullData}}; maxiters::Int64, cb::Function, progress::Bool, save_best::Bool, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at solve.jl:93
__solve at solve.jl:66 [inlined]
__solve at solve.jl:6...
Metadata
Metadata
Assignees
Labels
No labels