Correct counting of shared parameters in Base.show. #2335

Draft · wants to merge 3 commits into base: master

Conversation

@codetalker7 (Contributor) commented Sep 14, 2023

This PR fixes the show function so that non-trainable parameters are counted correctly. The earlier counting code duplicated shared parameters in its count (noncnt = _childarray_sum(_->1, m) - length(ps)), and hence some shared trainable parameters were reported as non-trainable. The change is in the _big_finale function: instead of deriving the non-trainable count by subtraction, we use an IdSet to keep track of which parameters have already been counted, so no parameter is counted twice.
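
As a rough standalone illustration of the IdSet idea (a hedged sketch with illustrative names, not the actual _big_finale code), one could count each numeric array exactly once by walking the Functors children:

using Flux, Functors

# Sketch only: walk the model's functor children and count each numeric array
# exactly once, using an identity-based set so a shared (===) array, e.g. the
# weight of a tied layer, is not counted twice.
function count_params_once(model)
    seen = Base.IdSet{Any}()          # stores elements by identity (===)
    total = 0
    function walk(x)
        if x isa AbstractArray{<:Number}
            if !(x in seen)
                push!(seen, x)
                total += length(x)
            end
        else
            foreach(walk, Functors.children(x))
        end
    end
    walk(model)
    return total
end

d = Dense(10 => 10)
count_params_once(Chain(Embedding(10 => 10), d, d))   # 210: d's weight and bias counted once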

As an example, now the following code shows the correct output:

julia> using Flux;

julia> d = Dense(10 => 10);

julia> shared_layer = Chain(Embedding(10, 10), d, d)
Chain(
  Embedding(10 => 10),                  # 100 parameters
  Dense(10 => 10),                      # 110 parameters
  Dense(10 => 10),                      # 110 parameters
)                   # Total: 3 arrays, 210 parameters, 1.055 KiB.

julia> normal_layer = Chain(Embedding(10, 10), Dense(10 => 10), Dense(10 => 10))
Chain(
  Embedding(10 => 10),                  # 100 parameters
  Dense(10 => 10),                      # 110 parameters
  Dense(10 => 10),                      # 110 parameters
)                   # Total: 5 arrays, 320 parameters, 1.562 KiB.

TODO:

  • Add tests.
  • Add an example in the docs for shared parameters?

Closes #2321.

@codetalker7 (Contributor, Author)

If this looks good, I'll go ahead and add some tests and an example in the documentation as well.

@codetalker7 (Contributor, Author)

The documentation CI is failing for an RNN. I assume even there the output is incorrect? It's probably showing the state parameter as non-trainable (which is trainable, right?)

@ToucheSir (Member) commented Sep 15, 2023

> It's probably showing the state parameter as non-trainable (which is trainable, right?)

That's Recur.state, which should be non-trainable. Note how only cell is included below:

julia> Flux.trainable(RNN(2 => 5))
(cell = RNNCell(2 => 5, tanh),)

@codetalker7 (Contributor, Author) commented Sep 15, 2023

> > It's probably showing the state parameter as non-trainable (which is trainable, right?)
>
> That's Recur.state, which should be non-trainable. Note how only cell is included below:
>
> julia> Flux.trainable(RNN(2 => 5))
> (cell = RNNCell(2 => 5, tanh),)

I see, yes, that makes sense. I think I understand now why it's not showing any non-trainable parameters for RNN(2 => 5): both the initial state (the cell's state0) and Recur.state are initialized to the same zero matrix, and hence pushing both of these to the IdSet adds just one matrix instead of two. Instead, we would have to push the parameter names into the IdSet as well (to distinguish distinct parameters that happen to have the same value), but even that might not work, since two layers can share the same parameter name and the same parameter values and still be different.

Just to confirm: is it true that all parameters in Flux (i.e., Functors.children(m), where m is some layer) have unique names associated with them? If not, I don't immediately see a way of counting the total number of distinct parameters.

@ToucheSir (Member)

Yes, tied parameters are tricky as we found out while working on Optimisers.jl. Sometimes it feels like a philosophical question. Do we consider array wrappers like Adjoint and Transpose as aliases? Which wrappers in particular? What about reshapes of an Array, which share the same data but have different objectids and thus aren't caught by using an IdSet? It's not an easy problem, but this PR is a good start.
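
For concreteness, a quick REPL illustration of those aliasing cases (plain Base Julia, nothing Flux-specific):

julia> w = rand(3, 3);

julia> wt = transpose(w);          # a Transpose wrapper around w's storage

julia> parent(wt) === w            # shares the data ...
true

julia> wt === w                    # ... but is a different object, so an IdSet counts it separately
false

julia> v = reshape(w, 9);          # reshape of an Array also reuses the memory

julia> v[1] = 0; w[1, 1]           # mutating one mutates the other
0.0

julia> objectid(v) == objectid(w)  # yet the objectids differ, so identity-based tracking misses the tie
false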

@codetalker7 (Contributor, Author)

> Yes, tied parameters are tricky as we found out while working on Optimisers.jl. Sometimes it feels like a philosophical question. Do we consider array wrappers like Adjoint and Transpose as aliases? Which wrappers in particular? What about reshapes of an Array, which share the same data but have different objectids and thus aren't caught by using an IdSet? It's not an easy problem, but this PR is a good start.

Taking inspiration from Flux.params!, I tried pushing the whole layer to the IdSet instead of just the AbstractArrays, and that seems to give correct results. How does it look now?

@ToucheSir (Member)

I believe that'd run into the same problem with shared params across nominally different layers. Maybe one idea would be to separately count the number of shared params and report that?

@mcabbott (Member)

Can we farm more of this out to Functors / Optimisers? Instead of building an IdSet by hand, let Functors cache things. Then this will inherit its understanding of Adjoint etc.

(I believe Optimisers.jl has a trainable-only walk definition, since it owns that concept.)

@codetalker7 (Contributor, Author)

> I believe that'd run into the same problem with shared params across nominally different layers. Maybe one idea would be to separately count the number of shared params and report that?

Hi @ToucheSir, could you explain the "nominally different layers" part? I didn't quite follow it. Maybe an example?

> Can we farm more of this out to Functors / Optimisers? Instead of building an IdSet by hand, let Functors cache things. Then this will inherit its understanding of Adjoint etc.
>
> (I believe Optimisers.jl has a trainable-only walk definition, since it owns that concept.)

Sure; I'll take a look at both Functors and Optimisers more closely.

@ToucheSir (Member)

Something like this:

d1 = Dense(3 => 4)
d2 = Dense(d1.weight)
d1.weight === d2.weight # tied
d1 !== d2 # but pushing the whole layer won't capture that

@codetalker7 (Contributor, Author)

> d1 !== d2

I see, yes, that makes sense. I think any solution which counts distinct (or shared) parameters in a model must use some form of unique ID associated with each parameter (I can't think of other ways at the moment; maybe there are more clever ones). Can we somehow associate such an ID with every parameter in a Flux model? Or, more generally, attach some metadata to each leaf of a struct?

@mcabbott (Member)

Functors uses a cache which should detect such sharing. It's a little smarter than just using objectid, so as not to catch immutable objects which are accidentally ===.

julia> using Functors, Flux

julia> let mat = rand(2,2)
         model = Chain(Dense(mat), Dense(mat')) # separate bias vectors
         cnt = Ref(0)
         fmapstructure(model; exclude=x->x isa Array) do x
           cnt[] += 1
         end
       end
(layers = ((weight = 1, bias = 2, σ = ()), (weight = (parent = 1,), bias = 3, σ = ())),)

julia> using StaticArrays

julia> [1,2] === [1,2]  # different arrays with same value, not shared
false

julia> SA[1,2] === SA[1,2]  # here the value is the only identity... still not shared.
true

julia> let mat = @SMatrix rand(2,2)
         model = Chain(Dense(mat), Dense(mat'))  # still has mat === mat
         cnt = Ref(0)
         fmapstructure(model; exclude=x->x isa AbstractArray) do x
           cnt[] += 1
         end
       end
(layers = ((weight = 1, bias = 2, σ = ()), (weight = 3, bias = 4, σ = ())),)

I think fmap like this ought to be equivalent to Flux.params. But the trainable count needs a modified walk to exclude some children.
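
For example (a hedged sketch with an illustrative name, not an existing Flux/Optimisers function), a trainable-only variant of the earlier counting walk could recurse through Flux.trainable instead of Functors.children, keeping the identity-based deduplication; as discussed above, this still won't catch reshapes or wrappers that alias the same memory:

using Flux

# Sketch: recurse only through trainable children (so fields excluded by
# Flux.trainable, such as Recur.state above, are never visited) and count
# each numeric array once by identity.
function count_trainable_once(model)
    seen = Base.IdSet{Any}()
    total = 0
    function walk(x)
        if x isa AbstractArray{<:Number}
            if !(x in seen)
                push!(seen, x)
                total += length(x)
            end
        else
            foreach(walk, Flux.trainable(x))
        end
    end
    walk(model)
    return total
end

d = Dense(10 => 10)
count_trainable_once(Chain(Embedding(10 => 10), d, d))   # 210, matching the example in the PR description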

Merging this pull request may close: show is confused by shared parameters (#2321)