-
-
Notifications
You must be signed in to change notification settings - Fork 39
Open
Description
It seems that the ComponentArray constructor is not differentiable.
Context: I have a two-step loss function, where I do some upfront work to estimate some parameters from the data, then predict using those parameters and others, so I'm trying to build a single parameter array combining the sets of parameters.
Reproducible code sample, adapted from docs
using ComponentArrays
using OrdinaryDiffEq
using Plots
using UnPack
using DiffEqFlux: sciml_train
using Flux: glorot_uniform, ADAM
using Optim: LBFGS
# Initial state of the two-component system, stored as Float32.
u0 = Float32[2.0, 0.0]
# Number of time points used when building the saved time grid.
datasize = 30
# Integration interval (start, stop) as Float32 values.
tspan = (0.0f0, 1.5f0)
# Build the parameters of one dense layer as a ComponentArray with fields
# `W` (an `out`-by-`in` matrix from `glorot_uniform`) and `b` (a zero vector of length `out`).
# Note: `zeros(out)` produces Float64 values.
function dense_layer(in, out)
    weights = glorot_uniform(out, in)
    bias = zeros(out)
    return ComponentArray(W=weights, b=bias)
end
# In-place right-hand side of the reference ODE: writes A' * (u .^ 3) into `du`
# for a fixed 2x2 matrix A. `p` and `t` are accepted but not used.
# Returns `du`.
function trueODEfunc(du, u, p, t)
    A = [-0.1 2.0; -2.0 -0.1]
    cubed = u .^ 3
    # (cubed' * A)' is the same product as A' * cubed
    du .= A' * cubed
    return du
end
# Evenly spaced save points covering the integration interval.
t = range(first(tspan), last(tspan); length=datasize)
# Problem built from the reference dynamics; solved with Tsit5 and saved at `t`.
prob = ODEProblem(trueODEfunc, u0, tspan)
# Solution converted to a plain Array; used as the fitting target.
ode_data = Array(solve(prob, Tsit5(); saveat=t))
# Out-of-place right-hand side of the learned ODE: a two-layer network applied to u.^3.
# `p` must expose `L1` and `L2`, each with fields `W` and `b`. `t` is not used.
function dudt(u, p, t)
    @unpack L1, L2 = p
    # Hidden activation: tanh of the first affine map applied to the cubed state
    hidden = tanh.(L1.W * u .^ 3 .+ L1.b)
    return L2.W * hidden .+ L2.b
end
# Rebind `prob` to an ODEProblem driven by `dudt`; this replaces the earlier
# binding that was built from `trueODEfunc`.
prob = ODEProblem(dudt, u0, tspan)
# Two layers created by `dense_layer`: (2, 50) and (50, 2).
layers = (L1=dense_layer(2, 50), L2=dense_layer(50, 2))
# Pack the initial state and the layer parameters into a single ComponentArray.
θ = ComponentArray(u=u0, p=layers)
# Solve `prob` with Tsit5, taking the initial state from `θ.u` and the parameters
# from `θ.p`, saving at the points in `t`. Returns the solution as a plain Array.
function predict_n_ode(θ)
    sol = solve(prob, Tsit5(); u0=θ.u, p=θ.p, saveat=t)
    return Array(sol)
end
# Loss for the neural ODE: sum of squared differences between `ode_data` and the
# prediction. Returns the tuple `(loss, prediction)`.
function loss_n_ode(θ)
    # Random stand-in for parameters that would come from additional up-front work
    extra = rand(3)
    # Rebuild one parameter array from the pieces. The stack trace reported with
    # this sample passes through this constructor call while Zygote differentiates the loss.
    combined = ComponentArray(u = θ.u, p = θ.p, other = extra)
    prediction = predict_n_ode(combined)
    residual = ode_data .- prediction
    return sum(abs2, residual), prediction
end
# Evaluate the loss once directly on the initial parameters (no optimizer involved);
# the returned tuple is discarded.
loss_n_ode(θ)
# Training callback: display the current loss and, when `doplot` is true, plot the
# first row of the prediction against the first row of `ode_data`.
#
# Arguments
# - θ: current parameters (not used in the body)
# - loss: current loss value, passed to `display`
# - pred: current prediction; only row 1 is plotted
# - doplot: plotting is skipped unless this is true (default false)
#
# Always returns `false`.
# NOTE(review): presumably `false` tells the optimizer to keep iterating — confirm
# against the sciml_train callback documentation.
cb = function (θ, loss, pred; doplot=false)
    display(loss)
    if doplot
        # plot current prediction against data
        pl = scatter(t, ode_data[1, :], label = "data")
        scatter!(pl, t, pred[1, :], label = "prediction")
        display(plot(pl))
    end
    return false
end
# Invoke the callback once on the initial parameters; the (loss, pred) tuple
# returned by loss_n_ode is splatted into cb's positional arguments.
cb(θ, loss_n_ode(θ)...)
# NOTE(review): `data` is not referenced again in the visible script — confirm
# whether it is still needed.
data = Iterators.repeated((), 1000)
# First training stage: ADAM(0.05), limited to 100 iterations. The stack trace
# reported with this sample passes through a sciml_train call that received an ADAM optimizer.
res1 = sciml_train(loss_n_ode, θ, ADAM(0.05); cb=cb, maxiters=100)
cb(res1.minimizer, loss_n_ode(res1.minimizer)...; doplot=true)
# Second stage: continue from the first stage's minimizer using LBFGS.
res2 = sciml_train(loss_n_ode, res1.minimizer, LBFGS(); cb=cb)
cb(res2.minimizer, loss_n_ode(res2.minimizer)...; doplot=true)
Error message:
ERROR: type NamedTuple has no field axes
Stacktrace:
[1] getproperty at .\Base.jl:33 [inlined]
[2] getindex at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\set_get.jl:49 [inlined]
[3] _broadcast_getindex_evalf at .\broadcast.jl:631 [inlined]
[4] _broadcast_getindex at .\broadcast.jl:604 [inlined]
[5] (::Base.Broadcast.var"#19#20"{Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(getindex),Tuple{Tuple{Axis{(L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))}},Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(ComponentArrays.getval),Tuple{Tuple{DataType}}}}}})(::Int64) at .\broadcast.jl:1024
[6] ntuple at .\ntuple.jl:41 [inlined]
[7] copy at .\broadcast.jl:1024 [inlined]
[8] materialize(::Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(getindex),Tuple{Tuple{Axis{(L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))}},Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(ComponentArrays.getval),Tuple{Tuple{DataType}}}}}) at .\broadcast.jl:820
[9] #s16#21(::Any, ::Any, ::Any) at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\set_get.jl:74
[10] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any,N} where N) at .\boot.jl:526
[11] getproperty at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\set_get.jl:68 [inlined]
[12] adjoint at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\if_required\zygote.jl:10 [inlined]
[13] _pullback(::Zygote.Context, ::typeof(ZygoteRules.literal_getproperty), ::ComponentArray{Float64,1,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{Axis{(L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))}}}, ::Val{:axes}) at C:\Users\username\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:47
[14] _pullback(::Zygote.Context, ::typeof(getfield), ::ComponentArray{Float64,1,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{Axis{(L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))}}}, ::Symbol) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\lib\lib.jl:221
[15] getaxes at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\set_get.jl:31 [inlined]
[16] make_idx at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\componentarray.jl:131 [inlined]
[17] _pullback(::Zygote.Context, ::typeof(ComponentArrays.make_idx), ::Array{Any,1}, ::ComponentArray{Float64,1,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{Axis{(L1 = View(1:150, (W
= View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))}}}, ::UnitRange{Int64}) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\compiler\interface2.jl:0
[18] make_idx at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\componentarray.jl:121 [inlined]
[19] _pullback(::Zygote.Context, ::typeof(ComponentArrays.make_idx), ::Array{Any,1}, ::NamedTuple{(:u, :p, :other),Tuple{SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},ComponentArray{Float64,1,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{Axis{(L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))}}},Array{Float64,1}}}, ::Int64) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\compiler\interface2.jl:0
[20] make_carray_args at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\componentarray.jl:111 [inlined]
[21] _pullback(::Zygote.Context, ::typeof(ComponentArrays.make_carray_args), ::Type{Array{Float64,1}}, ::NamedTuple{(:u, :p, :other),Tuple{SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},ComponentArray{Float64,1,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{Axis{(L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))}}},Array{Float64,1}}}) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\compiler\interface2.jl:0
[22] make_carray_args at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\componentarray.jl:109 [inlined]
[23] _pullback(::Zygote.Context, ::typeof(ComponentArrays.make_carray_args), ::Type{Float64}, ::NamedTuple{(:u, :p, :other),Tuple{SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},ComponentArray{Float64,1,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{Axis{(L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))}}},Array{Float64,1}}}) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\compiler\interface2.jl:0
[24] make_carray_args at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\componentarray.jl:108 [inlined]
[25] _pullback(::Zygote.Context, ::typeof(ComponentArrays.make_carray_args), ::NamedTuple{(:u, :p, :other),Tuple{SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},ComponentArray{Float64,1,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{Axis{(L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))}}},Array{Float64,1}}}) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\compiler\interface2.jl:0
[26] ComponentArray at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\componentarray.jl:64 [inlined]
[27] _pullback(::Zygote.Context, ::Type{ComponentArray}, ::NamedTuple{(:u, :p, :other),Tuple{SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},ComponentArray{Float64,1,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{Axis{(L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b
= 101:102)))}}},Array{Float64,1}}}) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\compiler\interface2.jl:0
[28] #ComponentArray#12 at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\componentarray.jl:66 [inlined]
[29] _pullback(::Zygote.Context, ::ComponentArrays.var"##ComponentArray#12", ::Base.Iterators.Pairs{Symbol,AbstractArray{Float64,1},Tuple{Symbol,Symbol,Symbol},NamedTuple{(:u, :p, :other),Tuple{SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},ComponentArray{Float64,1,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{Axis{(L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))}}},Array{Float64,1}}}}, ::Type{ComponentArray}) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\compiler\interface2.jl:0 (repeats 2 times)
[30] loss_n_ode at .\untitled-79d0146585cdbfb1aa13a5142027add7:39 [inlined]
[31] _pullback(::Zygote.Context, ::typeof(loss_n_ode), ::ComponentArray{Float64,1,Array{Float64,1},Tuple{Axis{(u = 1:2, p = View(3:254, (L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))))}}}) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\compiler\interface2.jl:0
[32] adjoint at C:\Users\username\.julia\packages\Zygote\uGBKO\src\lib\lib.jl:179 [inlined]
[33] _pullback at C:\Users\username\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:47 [inlined]
[34] #24 at C:\Users\username\.julia\packages\DiffEqFlux\7Lfxh\src\train.jl:99 [inlined]
[35] _pullback(::Zygote.Context, ::DiffEqFlux.var"#24#29"{Tuple{},typeof(loss_n_ode),ComponentArray{Float64,1,Array{Float64,1},Tuple{Axis{(u = 1:2, p = View(3:254, (L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))))}}}}) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\compiler\interface2.jl:0
[36] pullback(::Function, ::Zygote.Params) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\compiler\interface.jl:172
[37] gradient(::Function, ::Zygote.Params) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\compiler\interface.jl:53
[38] macro expansion at C:\Users\username\.julia\packages\DiffEqFlux\7Lfxh\src\train.jl:98 [inlined]
[39] macro expansion at C:\Users\username\.julia\packages\ProgressLogging\g8xnW\src\ProgressLogging.jl:328 [inlined]
[40] (::DiffEqFlux.var"#23#28"{var"#42#44",Int64,Bool,Bool,typeof(loss_n_ode),ComponentArray{Float64,1,Array{Float64,1},Tuple{Axis{(u = 1:2, p = View(3:254, (L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))))}}},Zygote.Params})() at C:\Users\username\.julia\packages\DiffEqFlux\7Lfxh\src\train.jl:43
[41] maybe_with_logger(::DiffEqFlux.var"#23#28"{var"#42#44",Int64,Bool,Bool,typeof(loss_n_ode),ComponentArray{Float64,1,Array{Float64,1},Tuple{Axis{(u = 1:2, p = View(3:254, (L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))))}}},Zygote.Params}, ::Nothing) at C:\Users\username\.julia\packages\DiffEqBase\Co6yv\src\utils.jl:259
[42] sciml_train(::Function, ::ComponentArray{Float64,1,Array{Float64,1},Tuple{Axis{(u = 1:2, p = View(3:254, (L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))))}}}, ::ADAM, ::Base.Iterators.Cycle{Tuple{DiffEqFlux.NullData}}; cb::Function, maxiters::Int64, progress::Bool, save_best::Bool) at C:\Users\username\.julia\packages\DiffEqFlux\7Lfxh\src\train.jl:42
[43] top-level scope at none:0
Metadata
Metadata
Assignees
Labels
No labels