Skip to content

Error: "NamedTuple has no field axes" #22

@metanoid

Description

@metanoid

It seems that the `ComponentArray` constructor is not differentiable with Zygote.

Context: I have a two-step loss function. I first do some upfront work to estimate some parameters from the data, then predict using those parameters together with others. To do that, I'm trying to build a single parameter array that combines both sets of parameters.

Reproducible code sample, adapted from the docs:

using ComponentArrays
using OrdinaryDiffEq
using Plots
using UnPack

using DiffEqFlux: sciml_train
using Flux: glorot_uniform, ADAM
using Optim: LBFGS

# Initial state of the 2-dimensional system (single precision).
u0 = Float32[2.0, 0.0]
# Number of evenly spaced time points at which trajectories are saved.
datasize = 30
# Integration interval, also kept in Float32.
tspan = (Float32(0), Float32(1.5))

"""
    dense_layer(in, out)

Return the parameters of one dense layer as a `ComponentArray` with two components:
a weight matrix `W` of size `(out, in)` drawn from `glorot_uniform`, and a zero bias
vector `b` of length `out`.

The bias has the same element type as `W`, so the layer (and any `ComponentArray`
built from several layers) is not silently promoted to a wider float type.
"""
function dense_layer(in, out)
    W = glorot_uniform(out, in)
    # A bare `zeros(out)` is always Float64; when W is narrower (e.g. Float32) the
    # combined ComponentArray would be promoted to Float64 as a whole.
    b = zeros(eltype(W), out)
    return ComponentArray(W=W, b=b)
end

"""
    trueODEfunc(du, u, p, t)

In-place right-hand side of the ground-truth ODE: `du = A' * u.^3` with the fixed
matrix `A = [-0.1 2.0; -2.0 -0.1]`. The arguments `p` and `t` are unused.
Returns `du`.
"""
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    cubed = u .^ 3
    # Same as ((u.^3)' * A)' for real A: the row-vector product is evaluated as A' * v.
    du .= true_A' * cubed
    return du
end
# `datasize` evenly spaced save times spanning `tspan`.
t = range(first(tspan), last(tspan); length=datasize)
# Ground-truth problem; note that the name `prob` is reassigned to the neural ODE later.
prob = ODEProblem(trueODEfunc, u0, tspan)
# Reference trajectory sampled at `t`, converted to a plain array.
ode_data = solve(prob, Tsit5(); saveat=t) |> Array


"""
    dudt(u, p, t)

Out-of-place neural-ODE right-hand side: a two-layer network applied to `u.^3`,
using the `L1` and `L2` components of `p` (each with fields `W` and `b`).
The argument `t` is unused.
"""
function dudt(u, p, t)
    @unpack L1, L2 = p
    hidden = tanh.(L1.W * u .^ 3 .+ L1.b)
    return L2.W * hidden .+ L2.b
end

# Neural ODE problem; its right-hand side `dudt` is driven by the layer parameters.
prob = ODEProblem(dudt, u0, tspan)

# Two dense layers mapping 2 -> 50 -> 2, bundled with the initial state into a
# single ComponentArray so both can be trained together.
layers = (L1=dense_layer(2, 50), L2=dense_layer(50, 2))
θ = ComponentArray(; u=u0, p=layers)

"""
    predict_n_ode(θ)

Solve the global neural ODE `prob` using `θ.u` as the initial state and `θ.p` as the
parameters, saving at the global time grid `t`. Returns the solution as an array
whose rows are state components and whose columns correspond to the times in `t`.
"""
function predict_n_ode(θ)
    sol = solve(prob, Tsit5(); u0=θ.u, p=θ.p, saveat=t)
    return Array(sol)
end

"""
    loss_n_ode(θ)

Sum-of-squares loss between `ode_data` and the neural-ODE prediction. Before
predicting, `θ` is repacked together with three extra random values into a new
`ComponentArray`; this stands in for an upfront parameter-estimation step.

Returns `(loss, pred)`, which is the argument order `cb` expects after the loss and
parameters.
"""
function loss_n_ode(θ)
    extra = rand(3)  # stand-in for the additional upfront work
    # NOTE: per the stack trace in this report, differentiating this constructor call
    # with Zygote fails (make_carray_args -> make_idx -> getaxes) with
    # "type NamedTuple has no field axes".
    combined = ComponentArray(; u=θ.u, p=θ.p, other=extra)
    prediction = predict_n_ode(combined)
    total = sum(abs2, ode_data .- prediction)
    return total, prediction
end
# Evaluate the loss once at the initial parameters (returns `(loss, pred)`).
loss_n_ode(θ)

# Training callback: always prints the current loss. When `doplot=true`, it also
# displays a scatter plot of the first state component, comparing the data with the
# current prediction.
# Returns `false`; presumably this tells `sciml_train` not to halt early (confirm
# against the DiffEqFlux callback docs).
cb = function (θ, loss, pred; doplot=false)
    display(loss)
    # Honour `doplot`: previously the keyword was accepted but ignored, so every
    # call plotted.
    if doplot
        pl = scatter(t, ode_data[1, :], label="data")
        scatter!(pl, t, pred[1, :], label="prediction")
        display(plot(pl))
    end
    return false
end


# Report the loss at the initial parameters through the callback.
cb(θ, loss_n_ode(θ)...)

# Stage 1: ADAM for 100 iterations. The reported error is raised from inside this
# first `sciml_train` call, when Zygote differentiates `loss_n_ode`.
res1 = sciml_train(loss_n_ode, θ, ADAM(0.05); cb=cb, maxiters=100)
cb(res1.minimizer, loss_n_ode(res1.minimizer)...; doplot=true)

# Stage 2: refine the ADAM result with LBFGS.
res2 = sciml_train(loss_n_ode, res1.minimizer, LBFGS(); cb=cb)
cb(res2.minimizer, loss_n_ode(res2.minimizer)...; doplot=true)

Error message:

ERROR: type NamedTuple has no field axes
Stacktrace:
 [1] getproperty at .\Base.jl:33 [inlined]
 [2] getindex at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\set_get.jl:49 [inlined]
 [3] _broadcast_getindex_evalf at .\broadcast.jl:631 [inlined]
 [4] _broadcast_getindex at .\broadcast.jl:604 [inlined]
 [5] (::Base.Broadcast.var"#19#20"{Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(getindex),Tuple{Tuple{Axis{(L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))}},Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(ComponentArrays.getval),Tuple{Tuple{DataType}}}}}})(::Int64) at .\broadcast.jl:1024
 [6] ntuple at .\ntuple.jl:41 [inlined]
 [7] copy at .\broadcast.jl:1024 [inlined]
 [8] materialize(::Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(getindex),Tuple{Tuple{Axis{(L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))}},Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple},Nothing,typeof(ComponentArrays.getval),Tuple{Tuple{DataType}}}}}) at .\broadcast.jl:820
 [9] #s16#21(::Any, ::Any, ::Any) at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\set_get.jl:74
 [10] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any,N} where N) at .\boot.jl:526
 [11] getproperty at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\set_get.jl:68 [inlined]
 [12] adjoint at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\if_required\zygote.jl:10 [inlined]
 [13] _pullback(::Zygote.Context, ::typeof(ZygoteRules.literal_getproperty), ::ComponentArray{Float64,1,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{Axis{(L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))}}}, ::Val{:axes}) at C:\Users\username\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:47
 [14] _pullback(::Zygote.Context, ::typeof(getfield), ::ComponentArray{Float64,1,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{Axis{(L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))}}}, ::Symbol) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\lib\lib.jl:221
 [15] getaxes at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\set_get.jl:31 [inlined]
 [16] make_idx at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\componentarray.jl:131 [inlined]
 [17] _pullback(::Zygote.Context, ::typeof(ComponentArrays.make_idx), ::Array{Any,1}, ::ComponentArray{Float64,1,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{Axis{(L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))}}}, ::UnitRange{Int64}) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\compiler\interface2.jl:0
 [18] make_idx at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\componentarray.jl:121 [inlined]
 [19] _pullback(::Zygote.Context, ::typeof(ComponentArrays.make_idx), ::Array{Any,1}, ::NamedTuple{(:u, :p, :other),Tuple{SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},ComponentArray{Float64,1,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{Axis{(L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))}}},Array{Float64,1}}}, ::Int64) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\compiler\interface2.jl:0
 [20] make_carray_args at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\componentarray.jl:111 [inlined]
 [21] _pullback(::Zygote.Context, ::typeof(ComponentArrays.make_carray_args), ::Type{Array{Float64,1}}, ::NamedTuple{(:u, :p, :other),Tuple{SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},ComponentArray{Float64,1,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{Axis{(L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))}}},Array{Float64,1}}}) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\compiler\interface2.jl:0
 [22] make_carray_args at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\componentarray.jl:109 [inlined]
 [23] _pullback(::Zygote.Context, ::typeof(ComponentArrays.make_carray_args), ::Type{Float64}, ::NamedTuple{(:u, :p, :other),Tuple{SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},ComponentArray{Float64,1,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{Axis{(L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))}}},Array{Float64,1}}}) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\compiler\interface2.jl:0
 [24] make_carray_args at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\componentarray.jl:108 [inlined]
 [25] _pullback(::Zygote.Context, ::typeof(ComponentArrays.make_carray_args), ::NamedTuple{(:u, :p, :other),Tuple{SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},ComponentArray{Float64,1,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{Axis{(L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))}}},Array{Float64,1}}}) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\compiler\interface2.jl:0
 [26] ComponentArray at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\componentarray.jl:64 [inlined]
 [27] _pullback(::Zygote.Context, ::Type{ComponentArray}, ::NamedTuple{(:u, :p, :other),Tuple{SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},ComponentArray{Float64,1,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{Axis{(L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))}}},Array{Float64,1}}}) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\compiler\interface2.jl:0
 [28] #ComponentArray#12 at C:\Users\username\.julia\packages\ComponentArrays\fNphq\src\componentarray.jl:66 [inlined]
 [29] _pullback(::Zygote.Context, ::ComponentArrays.var"##ComponentArray#12", ::Base.Iterators.Pairs{Symbol,AbstractArray{Float64,1},Tuple{Symbol,Symbol,Symbol},NamedTuple{(:u, :p, :other),Tuple{SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},ComponentArray{Float64,1,SubArray{Float64,1,Array{Float64,1},Tuple{UnitRange{Int64}},true},Tuple{Axis{(L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))}}},Array{Float64,1}}}}, ::Type{ComponentArray}) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\compiler\interface2.jl:0 (repeats 2 times)
 [30] loss_n_ode at .\untitled-79d0146585cdbfb1aa13a5142027add7:39 [inlined]
 [31] _pullback(::Zygote.Context, ::typeof(loss_n_ode), ::ComponentArray{Float64,1,Array{Float64,1},Tuple{Axis{(u = 1:2, p = View(3:254, (L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))))}}}) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\compiler\interface2.jl:0
 [32] adjoint at C:\Users\username\.julia\packages\Zygote\uGBKO\src\lib\lib.jl:179 [inlined]
 [33] _pullback at C:\Users\username\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:47 [inlined]
 [34] #24 at C:\Users\username\.julia\packages\DiffEqFlux\7Lfxh\src\train.jl:99 [inlined]
 [35] _pullback(::Zygote.Context, ::DiffEqFlux.var"#24#29"{Tuple{},typeof(loss_n_ode),ComponentArray{Float64,1,Array{Float64,1},Tuple{Axis{(u = 1:2, p = View(3:254, (L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))))}}}}) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\compiler\interface2.jl:0
 [36] pullback(::Function, ::Zygote.Params) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\compiler\interface.jl:172
 [37] gradient(::Function, ::Zygote.Params) at C:\Users\username\.julia\packages\Zygote\uGBKO\src\compiler\interface.jl:53
 [38] macro expansion at C:\Users\username\.julia\packages\DiffEqFlux\7Lfxh\src\train.jl:98 [inlined]
 [39] macro expansion at C:\Users\username\.julia\packages\ProgressLogging\g8xnW\src\ProgressLogging.jl:328 [inlined]
 [40] (::DiffEqFlux.var"#23#28"{var"#42#44",Int64,Bool,Bool,typeof(loss_n_ode),ComponentArray{Float64,1,Array{Float64,1},Tuple{Axis{(u = 1:2, p = View(3:254, (L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))))}}},Zygote.Params})() at C:\Users\username\.julia\packages\DiffEqFlux\7Lfxh\src\train.jl:43
 [41] maybe_with_logger(::DiffEqFlux.var"#23#28"{var"#42#44",Int64,Bool,Bool,typeof(loss_n_ode),ComponentArray{Float64,1,Array{Float64,1},Tuple{Axis{(u = 1:2, p = View(3:254, (L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))))}}},Zygote.Params}, ::Nothing) at C:\Users\username\.julia\packages\DiffEqBase\Co6yv\src\utils.jl:259
 [42] sciml_train(::Function, ::ComponentArray{Float64,1,Array{Float64,1},Tuple{Axis{(u = 1:2, p = View(3:254, (L1 = View(1:150, (W = View(1:100, ShapedAxis((50, 2), NamedTuple())), b = 101:150)), L2 = View(151:252, (W = View(1:100, ShapedAxis((2, 50), NamedTuple())), b = 101:102)))))}}}, ::ADAM, ::Base.Iterators.Cycle{Tuple{DiffEqFlux.NullData}}; cb::Function, maxiters::Int64, progress::Bool, save_best::Bool) at C:\Users\username\.julia\packages\DiffEqFlux\7Lfxh\src\train.jl:42
 [43] top-level scope at none:0

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions