ComponentArrays makes Custom Layers containing Chains type-unstable #417

Closed
arthur-bizzi opened this issue Oct 6, 2023 · 4 comments

arthur-bizzi commented Oct 6, 2023

Hey.

I'm attempting to implement a coupling layer (as in #416), now with custom layers. Again, they work fine with NamedTuple but not with ComponentArray. This doesn't seem to be related to #416. It may be because I don't understand the role of AbstractExplicitContainerLayer and am therefore not using it.

The issue appears when the custom layer contains a Chain of other Lux layers, but not when it contains a single layer.
Here's some code:

using Lux, ComponentArrays, Random
rng = Random.default_rng()

# Coupling layer wrapping a single sub-network
struct LeapFrog{T} <: Lux.AbstractExplicitLayer
    sub_net::T
end

# Maps (x1, x2) to (sub_net(x1) + x2, x1); parameters and states are those of sub_net
(frog::LeapFrog)(x, ps, st) = (frog.sub_net(x[1], ps, st)[1] + x[2], x[1]), st
Lux.initialparameters(rng::AbstractRNG, frog::LeapFrog) = Lux.initialparameters(rng, frog.sub_net)
Lux.initialstates(rng::AbstractRNG, frog::LeapFrog) = Lux.initialstates(rng, frog.sub_net)

#LeapFrog containing a single Dense layer, TYPE STABLE
D = Dense(1=>1)
F = LeapFrog(D)
C = Chain(F,F,F,F)
ps,st = Lux.setup(rng,C)
v = ([10.],[-10.])
C(v,ps,st)
@code_warntype C(v,ps,st) #Type stable
psc = ps |> ComponentArray
@code_warntype C(v,psc,st)#Type stable

#LeapFrog containing a Chain, NOT TYPE STABLE
D = Chain(Dense(1=>10),Dense(10=>1))
F = LeapFrog(D)
C = Chain(F,F,F,F)
ps,st = Lux.setup(rng,C)
v = ([10.],[-10.])
C(v,ps,st)
@code_warntype C(v,ps,st) #Type stable
psc = ps |> ComponentArray
@code_warntype C(v,psc,st)#NOT Type stable
@arthur-bizzi
Author

Interestingly, a single LeapFrog Layer containing a Chain also works fine. The problem seems to be chaining custom layers containing Chains.

Member

avik-pal commented Oct 6, 2023

Can you try changing frog.sub_net(x[1],ps,st)[1] to first(frog.sub_net(x[1],ps,st))? The first one might cause type instability, since you are indexing into a heterogeneous container.

@arthur-bizzi
Author

Tried it, no difference.
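
One possible explanation for seeing no difference: with a literal index, Julia can usually infer `t[1]` on a tuple just as well as `first(t)`, so the two spellings tend to be equivalent for inference. Below is a minimal sketch in plain Julia, without Lux. The tuple `t` and the function names `by_index`, `by_first`, and `by_runtime_index` are made up for illustration and do not come from this thread.

```julia
using Test

# Hypothetical stand-in for a layer's (output, state) return value:
# a heterogeneous tuple of a vector and a NamedTuple.
t = ([1.0], (a = 1,))

by_index(t) = t[1]              # literal index into the tuple
by_first(t) = first(t)          # equivalent accessor
by_runtime_index(t, i) = t[i]   # index only known at run time

# @inferred throws if the inferred return type does not match the actual one.
@inferred by_index(t)
@inferred by_first(t)

# With a runtime index the compiler cannot pick a single element type;
# inspect what it infers instead of asserting on it.
Base.return_types(by_runtime_index, (typeof(t), Int))
```

If both `@inferred` calls pass on your Julia version, the instability reported above is unlikely to come from the `[1]` indexing itself.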

Member

avik-pal commented Oct 9, 2023

(frog::LeapFrogCustom)(x,ps,st) = (frog.sub_net(x[1],ps,st)[1] .+ x[2],x[1]),st

Note the broadcasted addition (.+ instead of +). That makes the code type-inferrable (not sure why, though).

Regardless, if you use Cthulhu, it tells you exactly where the problem is.
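
A rough way to check the suggestion non-interactively is to wrap the call in `Test.@inferred`, which throws when the return type is not inferred as concrete. The sketch below reuses the definitions from this thread with the broadcasted addition. It assumes a Lux release from around the time of this issue, where `Lux.AbstractExplicitLayer` is the layer supertype; whether the check passes may depend on the Lux, ComponentArrays, and Julia versions installed.

```julia
using Lux, ComponentArrays, Random, Test

# Assumption: Lux.AbstractExplicitLayer exists in the installed Lux version.
struct LeapFrogCustom{T} <: Lux.AbstractExplicitLayer
    sub_net::T
end

# Same coupling step as in the original post, but with broadcasted addition (.+)
(frog::LeapFrogCustom)(x, ps, st) = (frog.sub_net(x[1], ps, st)[1] .+ x[2], x[1]), st
Lux.initialparameters(rng::AbstractRNG, frog::LeapFrogCustom) = Lux.initialparameters(rng, frog.sub_net)
Lux.initialstates(rng::AbstractRNG, frog::LeapFrogCustom) = Lux.initialstates(rng, frog.sub_net)

rng = Random.default_rng()
F = LeapFrogCustom(Chain(Dense(1 => 10), Dense(10 => 1)))
C = Chain(F, F, F, F)
ps, st = Lux.setup(rng, C)
psc = ComponentArray(ps)
v = ([10.0], [-10.0])

# Throws an error if the call is not inferred to a concrete return type.
@inferred C(v, psc, st)

# Interactive alternative (requires Cthulhu.jl); steps through the call tree:
# using Cthulhu
# @descend C(v, psc, st)
```

Swapping `.+` back to `+` in the layer definition and rerunning the `@inferred` line should reproduce the failure described in the original post, if it is still present in your environment.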

avik-pal closed this as completed Oct 9, 2023