Skip to content

Autodiff broken in latest release #67

@bgroenks96

Description

@bgroenks96

v0.8.20 seems to have broken autodiff code that worked fine in v0.8.19.

MWE (excuse the clutter)

using DiffEqFlux, DiffEqSensitivity, Flux, OrdinaryDiffEq, Zygote, Test #using Plots
using DiffEqBase: get_tmp, dualcache
using ComponentArrays
using Parameters
using ForwardDiff
using ReverseDiff

p = ComponentArray(lvpara=ComponentArray=2.2=1.0=2.0=0.4),a=1.0,b=1.0)
u0 = ComponentArray(state=ComponentArray(x=1.0,y=1.0))
ax_p = getaxes(p)
ax_u = getaxes(u0)

chunk_size(dual::Type{ForwardDiff.Dual{T,V,N}}) where {T,V,N} = N
select(a::AbstractArray, u::AbstractArray) = a
select(dc::DiffEqBase.DiffCache, u) = get_tmp(dc,u)
select(a::AbstractArray, u::ReverseDiff.TrackedArray) = begin
  x = similar(u,size(a))
  x .= a
  x
end

struct LotkaVolterra{T}
  d::T
end

function (lv::LotkaVolterra)(du,u_,p,t)
  d = select(lv.d,u_)
  u = ComponentArray(u_,ax_u)
  p = ComponentArray(p,ax_p)
  @unpack x,y = u.state
  @unpack lvpara, a, b = p
  @unpack α, β, δ, γ = lvpara
  d[1] = a^2+b^2
  d[2] = b^2-a^2
  du[1] = dx =- β*y)x + d[1]
  du[2] = dy =*x - γ)y + d[2]
end
u0 = Array(u0)
p = Array(p)
lv = LotkaVolterra(dualcache(zeros(2),Val{6}))
prob = ODEProblem(lv,u0,(0.0,1.0),p)
function predict_rd(p)
  Array(solve(prob,Tsit5(),p=p,saveat=0.1,reltol=1e-4,sensealg=ForwardDiffSensitivity()))
end
loss_rd(p) = sum(abs2,x-1 for x in predict_rd(p))

opt = ADAM(0.1)
cb = function (p,l,pred)
  display(loss_rd(p))
  #display(plot(solve(remake(prob,p=p),Tsit5(),saveat=0.1),ylim=(0,6)))
end

@time res = DiffEqFlux.sciml_train(loss_rd, p, opt, maxiters=100)

Error message:

MethodError: no method matching ComponentArray(::var"#7#8", ::Array{Float64,2})
Closest candidates are:
  ComponentArray(::Any, !Matched::FlatAxis...) at /home/brian/.julia/packages/ComponentArrays/zHt90/src/componentarray.jl:50
  ComponentArray(::Any, !Matched::Union{FlatAxis, ComponentArrays.NullAxis, Axis{IdxMap}} where IdxMap...) at /home/brian/.julia/packages/ComponentArrays/zHt90/src/componentarray.jl:51
  ComponentArray(::Any, !Matched::Union{FlatAxis, ComponentArrays.NullAxis, Axis{IdxMap}, ShapedAxis{Shape,IdxMap}} where IdxMap where Shape...) at /home/brian/.julia/packages/ComponentArrays/zHt90/src/componentarray.jl:52
  ...
rrule(::UnionAll, ::Function, ::Array{Float64,2}) at chainrulescore.jl:23
chain_rrule at chainrules.jl:89 [inlined]
macro expansion at interface2.jl:0 [inlined]
_pullback(::Zygote.Context, ::Type{Base.Generator}, ::var"#7#8", ::Array{Float64,2}) at interface2.jl:9
loss_rd at lotka_volterra.jl:47 [inlined]
_pullback(::Zygote.Context, ::typeof(loss_rd), ::Array{Float64,1}) at interface2.jl:0
#69 at train.jl:3 [inlined]
_pullback(::Zygote.Context, ::DiffEqFlux.var"#69#70"{typeof(loss_rd)}, ::Array{Float64,1}, ::SciMLBase.NullParameters) at interface2.jl:0
adjoint at lib.jl:188 [inlined]
_pullback at adjoint.jl:57 [inlined]
OptimizationFunction at basic_problems.jl:107 [inlined]
_pullback(::Zygote.Context, ::OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss_rd)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}, ::Array{Float64,1}, ::SciMLBase.NullParameters) at interface2.jl:0
adjoint at lib.jl:188 [inlined]
adjoint(::Zygote.Context, ::typeof(Core._apply_iterate), ::typeof(iterate), ::OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss_rd)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing}, ::Tuple{Array{Float64,1},SciMLBase.NullParameters}) at none:0
_pullback at adjoint.jl:57 [inlined]
OptimizationFunction at basic_problems.jl:107 [inlined]
_pullback(::Zygote.Context, ::OptimizationFunction{false,GalacticOptim.AutoZygote,OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss_rd)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},GalacticOptim.var"#146#156"{GalacticOptim.var"#145#155"{OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss_rd)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing}},GalacticOptim.var"#149#159"{GalacticOptim.var"#145#155"{OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss_rd)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing}},GalacticOptim.var"#154#164",Nothing,Nothing,Nothing}, ::Array{Float64,1}, ::SciMLBase.NullParameters) at interface2.jl:0
adjoint at lib.jl:188 [inlined]
_pullback at adjoint.jl:57 [inlined]
#8 at solve.jl:94 [inlined]
_pullback(::Zygote.Context, ::GalacticOptim.var"#8#13"{OptimizationProblem{false,OptimizationFunction{false,GalacticOptim.AutoZygote,OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss_rd)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},GalacticOptim.var"#146#156"{GalacticOptim.var"#145#155"{OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss_rd)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing}},GalacticOptim.var"#149#159"{GalacticOptim.var"#145#155"{OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss_rd)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing}},GalacticOptim.var"#154#164",Nothing,Nothing,Nothing},Array{Float64,1},SciMLBase.NullParameters,Nothing,Nothing,Nothing,Base.Iterators.Pairs{Symbol,Int64,Tuple{Symbol},NamedTuple{(:maxiters,),Tuple{Int64}}}},Array{Float64,1},GalacticOptim.NullData}) at interface2.jl:0
pullback(::Function, ::Params) at interface.jl:167
gradient(::Function, ::Params) at interface.jl:48
__solve(::OptimizationProblem{false,OptimizationFunction{false,GalacticOptim.AutoZygote,OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss_rd)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},GalacticOptim.var"#146#156"{GalacticOptim.var"#145#155"{OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss_rd)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing}},GalacticOptim.var"#149#159"{GalacticOptim.var"#145#155"{OptimizationFunction{true,GalacticOptim.AutoZygote,DiffEqFlux.var"#69#70"{typeof(loss_rd)},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},Nothing}},GalacticOptim.var"#154#164",Nothing,Nothing,Nothing},Array{Float64,1},SciMLBase.NullParameters,Nothing,Nothing,Nothing,Base.Iterators.Pairs{Symbol,Int64,Tuple{Symbol},NamedTuple{(:maxiters,),Tuple{Int64}}}}, ::ADAM, ::Base.Iterators.Cycle{Tuple{GalacticOptim.NullData}}; maxiters::Int64, cb::Function, progress::Bool, save_best::Bool, kwargs::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}) at solve.jl:93
__solve at solve.jl:66 [inlined]
__solve at solve.jl:6...

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions