
Why is Flux.destructure type unstable? #2405

Closed
irisallevi opened this issue Mar 21, 2024 · 3 comments
irisallevi commented Mar 21, 2024

I was building a simple model and at some point I needed to "unroll" it to get all the parameters in an array.

So I tried Flux.destructure. I saw some type instability, so I checked the documentation and ran the example provided there:

model = Chain(Dense(2 => 1, tanh), Dense(1 => 1))
@code_warntype Flux.destructure(model)

But this gives a type instability as well!

Flux.destructure(model) = (Float32[0.27410066, 0.6508191, 0.0, 0.16767712, 0.0], Restructure(Chain, ..., 5))
MethodInstance for Optimisers.destructure(::Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})
  from destructure(x) @ Optimisers
Arguments
  #self#::Core.Const(Optimisers.destructure)
  x::Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}
Locals
  @_3::Int64
  len::Int64
  off::NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}}}}
  flat::AbstractVector
Body::Tuple{AbstractVector, Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, S} where S<:(NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}}}})}
1 ─ %1  = Optimisers._flatten(x)::Tuple{AbstractVector, NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}}}}, Int64}
│   %2  = Base.indexed_iterate(%1, 1)::Core.PartialStruct(Tuple{AbstractVector, Int64}, Any[AbstractVector, Core.Const(2)])
│         (flat = Core.getfield(%2, 1))
│         (@_3 = Core.getfield(%2, 2))
│   %5  = Base.indexed_iterate(%1, 2, @_3::Core.Const(2))::Core.PartialStruct(Tuple{NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}}}}, Int64}, Any[NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}}}}, Core.Const(3)])
│         (off = Core.getfield(%5, 1))
│         (@_3 = Core.getfield(%5, 2))
│   %8  = Base.indexed_iterate(%1, 3, @_3::Core.Const(3))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(4)])
│         (len = Core.getfield(%8, 1))
│   %10 = flat::AbstractVector
│   %11 = Optimisers.Restructure(x, off, len)::Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, S} where S<:(NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}}}})
│   %12 = Core.tuple(%10, %11)::Tuple{AbstractVector, Optimisers.Restructure{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, S} where S<:(NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), <:Tuple{Any, Any, Tuple{}}}}}})}
└──       return %12

What am I missing?


mcabbott commented Mar 22, 2024

Flux uses Functors.jl for all kinds of recursive walks, and when there are mutable objects, it keeps a cache of their objectids to look for duplicates. This means that it branches on the values of objectid, which is usually type-unstable. You can see it for instance in @code_warntype f64(model).

The reason it does this is to allow for shared parameters: the same array may appear multiple times. To be honest this is a giant pain, and maybe it's a feature not worth preserving? Forbidding it would simplify many things. It is turned off for e.g. models with SMatrix parameters (which have no identity beyond their value).
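As a minimal sketch of what that sharing support buys (using a plain NamedTuple rather than a Flux layer, to keep it self-contained): destructure flattens a repeated array only once, and restructuring restores the identity.

```julia
using Optimisers

w = [1.0, 2.0]
tied = (a = w, b = w)          # the same array appears twice

v, re = Optimisers.destructure(tied)
length(v)                      # 2, not 4: the shared array is flattened once

m2 = re([10.0, 20.0])
m2.a === m2.b                  # true: sharing is restored on restructure
```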

Type stability is super-important deep inside tight loops, and hence drummed into us when learning Julia, but often doesn't matter at all for larger objects. E.g. removing type parameters from Flux layers often has no impact on performance, as there are enough function barriers between there and operations which take all the time.
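A minimal sketch of the function-barrier point, using the model from the original post: the unstable destructure call runs once, and any function that receives v and re specializes on their concrete types, so the hot path compiles to stable code.

```julia
using Flux, Optimisers

model = Chain(Dense(2 => 1, tanh), Dense(1 => 1))
v, re = Flux.destructure(model)    # type-unstable, but called only once

# Function barrier: inside `loss`, v and re have concrete types,
# so the rebuild and forward pass compile as usual.
loss(v, re, x) = sum(abs2, re(v)(x))

x = rand(Float32, 2, 8)
loss(v, re, x)                     # inspect with @code_warntype loss(v, re, x)
```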

Having said all that, I'm not 100% sure this isn't an XY problem. Is your question actually why it is unstable, or are you really implying that you believe this is the cause of a performance problem?

@mcabbott

Same question on Discourse here; please link cross-posts so people don't waste time on duplicates.

Here's a quick example to show some design differences between ComponentArrays and Optimisers.destructure:

julia> using ComponentArrays, Optimisers

julia> arr = [1.0, 2.0];

julia> v, re = Optimisers.destructure((one=arr, two=[3f0], three=arr))  # this notices & preserves one === three
([1.0, 2.0, 3.0], Restructure(NamedTuple, ..., 3))

julia> v .= 99;  # this does not mutate arr, v is a copy

julia> nt = re([10, 20, 30.0])
(one = [10.0, 20.0], two = Float32[30.0], three = [10.0, 20.0])

julia> nt.one === nt.three  # identity is restored
true

julia> ca = ComponentArray(one=arr, two=[3f0], three=arr)  # this ignores the identity
ComponentVector{Float64}(one = [1.0, 2.0], two = [3.0], three = [1.0, 2.0])

julia> getfield(ca, :data)
5-element Vector{Float64}:
 1.0
 2.0
 3.0
 1.0
 2.0

julia> ca.two  # type has been promoted on construction
1-element view(::Vector{Float64}, 3:3) with eltype Float64:
 3.0

julia> ca.three .= 99;  # structured form is a view of flat form

julia> ca
ComponentVector{Float64}(one = [1.0, 2.0], two = [3.0], three = [99.0, 99.0])

@irisallevi

Thank you @mcabbott, and sorry for not linking. These component arrays seem very nice, especially since you can easily access them and (apparently) mutate them in place.

> Having said all that, I'm not 100% sure this isn't an XY problem. Is your question actually why it is unstable, or are you really implying that you believe this is the cause of a performance problem?

Both, actually. Since the code is quite simple for now, I'd like to keep as much of it as I can under control, so I'd like to understand what is going on.
