No sanity checks on destructure and loadparams! #1408

Open
mcabbott opened this issue Nov 27, 2020 · 6 comments

@mcabbott
Member

mcabbott commented Nov 27, 2020

Given too many parameters, or parameters of the wrong shapes, destructure and loadparams! silently have a go. I believe it would be safer to make these errors, or at least warnings:

```julia
using Flux

m1 = Chain(Dense(2,2), Dense(2,2))
m1[1].W .= 1
v1, re1 = Flux.destructure(m1)
m2 = re1(rand(100)) # doesn't check length
m2[1].W # new values

m3 = Chain(Dense(2,2), Dense(2,2), Dense(2,2))
Flux.loadparams!(m1, params(m3)) # doesn't check length
m1[1].W # has been overwritten

m4 = Chain(Dense(2,2), Dense(2,3), Dense(3,2))
Flux.loadparams!(m1, params(m3)) # doesn't check shape either
```

When there are too few parameters, it does seem to fail:

```julia
Flux.loadparams!(m4, params(m1)) # ERROR: Expected param size (3, 2), got (2, 2)
re1(rand(3))  # ERROR: BoundsError: attempt to access 3-element Array{Float64,1} at index [1:4]
```
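On the destructure side, the missing check amounts to a length comparison. A minimal sketch, assuming only that `re` is the closure returned by `Flux.destructure` and that the length of the flat vector it returned is the expected parameter count; `checked_restructure` is a hypothetical helper, not part of Flux. The Flux calls are left as comments so the block runs on its own.

```julia
# Hypothetical wrapper (not part of Flux): refuse to rebuild a model from a
# flat vector whose length differs from the one `destructure` produced.
function checked_restructure(re, nparams::Integer, v::AbstractVector)
    length(v) == nparams || throw(DimensionMismatch(
        "expected a vector of length $nparams, got length $(length(v))"))
    return re(v)
end

# Intended usage with the example above (needs Flux, so left as comments):
#   v1, re1 = Flux.destructure(m1)
#   checked_restructure(re1, length(v1), rand(100))  # throws DimensionMismatch
#   checked_restructure(re1, length(v1), v1)         # rebuilds m1
```

This only guards the too-many case; the too-few case already fails, as shown above, but with a less descriptive `BoundsError`.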
@mcabbott
Member Author

mcabbott commented Nov 28, 2020

> However, what if you wanted to loadparams! from a 3-layer MLP onto a 2-layer one or vice versa? Assuming the shapes of the first two layers are the same, this should be supported in Flux. That is what strict=False in load_state_dict gives you: the ability to partially load params into a network without having to slice it up and reconstitute it afterwards (if indeed that's even possible, e.g. in the presence of custom container layers).

Originally posted by @ToucheSir in #1402 (comment)

The PyTorch function is described here, for instance (maybe not the best link?): https://pytorch.org/docs/master/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict

Perhaps we should follow that and add a keyword strict=true by default, which rejects all the size mismatches above. But what exactly is the rule for strict=false? Should the results depend on the order of appearance in params(m)?
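One way the proposed `strict` keyword could behave in the flat, ordered case, sketched in plain Julia over any ordered collections of arrays (a `Vector` of arrays here; presumably also `params(m)`, which is iterable and has a `length`). `loadparams_checked!` is a hypothetical name, not Flux's `loadparams!`. With `strict = true` it validates the count and every shape before copying anything, so a failure leaves the destination untouched. With `strict = false` it assumes one possible rule: load the longest positional prefix whose shapes agree. Under that rule the order of `params(m)` does determine the result, which is exactly the open question.

```julia
# Hypothetical sketch, not Flux's loadparams!.
# Returns the number of arrays copied from `src` into `dest`.
function loadparams_checked!(dest, src; strict::Bool = true)
    matched = collect(zip(dest, src))  # zip stops at the shorter collection
    if strict
        # Validate everything first so a failure copies nothing.
        length(dest) == length(src) || throw(ArgumentError(
            "expected $(length(dest)) parameter arrays, got $(length(src))"))
        for (i, (p, x)) in enumerate(matched)
            size(p) == size(x) || throw(DimensionMismatch(
                "parameter $i: expected size $(size(p)), got $(size(x))"))
        end
        n = length(matched)
    else
        # Non-strict (assumed rule): keep the longest prefix whose shapes agree.
        k = findfirst(((p, x),) -> size(p) != size(x), matched)
        n = k === nothing ? length(matched) : k - 1
    end
    for (p, x) in matched[1:n]
        copyto!(p, x)
    end
    return n
end
```

Loading a 3-layer parameter list onto a 2-layer one would then fail by default and copy the first two layers with `strict = false`.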

@darsnack
Member

If we want to address the strict = false case in a forward-thinking manner, then I think params needs to evolve past the "bag of arrays" approach and store structural information à la functors. This alone would make sanity checking better (by allowing a more narrowly scoped error case than length(params(m)) != length(ps)), and it would make educated guessing in the strict = false case more reliable.

@mcabbott
Copy link
Member Author

The obvious tree-like functor structure here is the model itself. Perhaps it is loadparams!(m1, m5; strict=false) that ought to work, with some rule like ignoring branches absent from either tree? Or does this have downsides I'm missing (like re-used parameter arrays)?

I know there's been discussion of re-designing Params; do you have a link to a summary or entry point on that?
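A rough sketch of the tree-based idea, written in plain Julia so it runs without Flux. Models are mirrored here by nested `NamedTuple`s (named fields) and `Tuple`s (ordered layers, as in a `Chain`); this is an assumption standing in for the real functor tree, and `load_tree!` is a hypothetical name. Under `strict = false` it applies one possible rule: named branches missing from either side are skipped, tuples are loaded positionally up to the shorter length, and a shape mismatch at a leaf is still an error. It does nothing special for re-used (shared) arrays, which would simply be written once per occurrence.

```julia
# Hypothetical sketch, not a Flux API. Leaves are arrays; branches are
# NamedTuples (named fields) or Tuples (ordered layers, as in a Chain).
# Checks happen during traversal, so a strict failure deep in the tree can
# leave earlier branches already copied.

# Leaf: copy in place after a shape check (enforced even when strict = false).
function load_tree!(dest::AbstractArray, src::AbstractArray; strict::Bool = true)
    size(dest) == size(src) || throw(DimensionMismatch(
        "expected size $(size(dest)), got $(size(src))"))
    copyto!(dest, src)
    return dest
end

# Named branch: strict requires identical field names; otherwise recurse
# only into fields present in both trees.
function load_tree!(dest::NamedTuple, src::NamedTuple; strict::Bool = true)
    kd, ks = keys(dest), keys(src)
    if strict && Set(kd) != Set(ks)
        throw(ArgumentError("field names differ: $kd vs $ks"))
    end
    for k in kd                # fields only in src are never visited
        k in ks || continue    # fields only in dest are skipped
        load_tree!(dest[k], src[k]; strict = strict)
    end
    return dest
end

# Ordered branch: strict requires equal length, otherwise load the common prefix.
function load_tree!(dest::Tuple, src::Tuple; strict::Bool = true)
    nd, ns = length(dest), length(src)
    if strict && nd != ns
        throw(ArgumentError("expected $nd branches, got $ns"))
    end
    for i in 1:min(nd, ns)
        load_tree!(dest[i], src[i]; strict = strict)
    end
    return dest
end
```

Matching by structure rather than by position in a flat list is what lets the non-strict case say which branches were skipped, instead of guessing from the order of a bag of arrays.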

@ToucheSir
Member

ToucheSir commented Jan 23, 2021

It would probably look something like https://github.com/FluxML/XLA.jl/blob/master/examples/conv.jl. In other words, Params may not be required at all.

@DhairyaLGandhi
Member

Correct, that's why you have FluxML/Functors.jl#1 and Optimisers.jl, along with #1017 for training.

This is the new API that we are moving towards.

@DhairyaLGandhi
Member

That's why I've recently updated Optimisers.jl to include most of the optimisers from Flux (modulo some Adam derivatives, but those are fairly easy).

Labels
None yet
Projects
None yet

4 participants