
Cannot load pretrained weights for ResNet on master #206

Closed
lorenzoh opened this issue Nov 27, 2022 · 13 comments · Fixed by #235
Labels
bug Something isn't working

Comments

@lorenzoh
Member

Package Version

0.8.0-DEV (master)

Julia Version

1.8.3

OS / Environment

] st
Project Metalhead v0.8.0-DEV
Status `~/.julia/dev/Metalhead/Project.toml`
  [fbb218c0] BSON v0.3.6
  [052768ef] CUDA v3.12.0
  [d360d2e6] ChainRulesCore v1.15.6
  [587475ba] Flux v0.13.8
  [d9f16b24] Functors v0.3.0
  [f1d291b0] MLUtils v0.3.1
  [872c559c] NNlib v0.8.10
  [a00861dc] NNlibCUDA v0.2.4
  [570af359] PartialFunctions v1.1.1
  [56f22d72] Artifacts
  [4af54fe1] LazyArtifacts
  [9a3f8284] Random
  [10745b16] Statistics 

Describe the bug

Running

using Metalhead
Metalhead.ResNet(18, pretrain=true)

errors with:

ERROR: ArgumentError: Tried to load (MODEL)
    @ Base ./reduce.jl:58
  [9] foldl_impl
    @ ./reduce.jl:48 [inlined]
 [10] mapfoldl_impl
    @ ./reduce.jl:44 [inlined]
 [11] #mapfoldl#259
    @ ./reduce.jl:170 [inlined]
 [12] #foldl#260
    @ ./reduce.jl:193 [inlined]
 [13] foreach
    @ ./tuple.jl:556 [inlined]
 [14] #loadmodel!#352
    @ ~/.julia/packages/Flux/FKl3M/src/loading.jl:91 [inlined]
 [15] (::Flux.var"#353#356"{Flux.var"#354#357", Base.IdSet{Any}, DimensionMismatch})(ldst::Tuple{MODEL})
    @ Flux ~/.julia/packages/Flux/FKl3M/src/loading.jl:98
 [16] foreach(::Function, ::MODEL)
    @ Base ./abstractarray.jl:2775
 [17] loadmodel!(dst::MODEL; filter::Function, cache::Base.IdSet{Any})
    @ Flux ~/.julia/packages/Flux/FKl3M/src/loading.jl:91
 [18] loadmodel!(dst::MODEL)
    @ Flux ~/.julia/packages/Flux/FKl3M/src/loading.jl:84
 [19] loadpretrain!(model::MODEL, name::String)
    @ Metalhead ~/.julia/dev/Metalhead/src/pretrain.jl:26
 [20] ResNet(depth::Int64; pretrain::Bool, inchannels::Int64, nclasses::Int64)
    @ Metalhead ~/.julia/dev/Metalhead/src/convnets/resnets/resnet.jl:26
 [21] top-level scope
    @ REPL[6]:1

Steps to Reproduce

Install master and run the above code.

Expected Results

Expected it to load the weights into the model.

Observed Results

It threw the error.

Relevant log output

No response

@lorenzoh added the bug label on Nov 27, 2022
@pri1311
Contributor

pri1311 commented Dec 17, 2022

I think the problem is that the pretrained weights have the structure

Chain(Parallel(
    PartialFunctions.PartialFunction{typeof(Metalhead.addact), Tuple{typeof(NNlib.relu)}, NamedTuple{(), Tuple{}}}(Metalhead.addact, (NNlib.relu,), NamedTuple()), 
    Chain(Conv((3, 3), 64 => 128, pad=1, stride=2, bias=false), 
    BatchNorm(128), relu, 
    Conv((3, 3), 128 => 128, pad=1, bias=false),
    BatchNorm(128)),
    Chain(Conv((1, 1), 64 => 128, stride=2, bias=false), 
    BatchNorm(128))),

whereas the model implemented in Metalhead.jl has the structure

Chain(Parallel(
    PartialFunctions.PartialFunction{typeof(Metalhead.addact), Tuple{typeof(NNlib.relu)}, NamedTuple{(), Tuple{}}}(Metalhead.addact, (NNlib.relu,), NamedTuple()), 
    Chain(Conv((1, 1), 64 => 128, stride=2, bias=false), 
    BatchNorm(128)), 
    Chain(Conv((3, 3), 64 => 128, pad=1, stride=2, bias=false), 
    BatchNorm(128), relu, 
    Conv((3, 3), 128 => 128, pad=1, bias=false), 
    BatchNorm(128))),

The above comparison is based on the structures printed in the error message.

@pri1311
Contributor

pri1311 commented Dec 22, 2022

@theabhirath @darsnack do we plan to tweak the model implementation to match torchvision?

@shivance
Contributor

For me it's working fine:

julia> model = ResNet(18; pretrain=true)
  Downloaded artifact: resnet18
ResNet(
  Chain(
    Chain([
      Conv((7, 7), 3 => 64, pad=3, stride=2, bias=false),  # 9_408 parameters
      BatchNorm(64, relu),              # 128 parameters, plus 128
      MaxPool((3, 3), pad=1, stride=2),
      Parallel(
        Metalhead.addrelu,
        Chain(
          Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
          BatchNorm(64, relu),          # 128 parameters, plus 128
          Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
          BatchNorm(64),                # 128 parameters, plus 128
        ),
        identity,
      ),
      Parallel(
        Metalhead.addrelu,
        Chain(
          Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
          BatchNorm(64, relu),          # 128 parameters, plus 128
          Conv((3, 3), 64 => 64, pad=1, bias=false),  # 36_864 parameters
          BatchNorm(64),                # 128 parameters, plus 128
        ),
        identity,
      ),
      Parallel(
        Metalhead.addrelu,
        Chain(
          Conv((3, 3), 64 => 128, pad=1, stride=2, bias=false),  # 73_728 parameters
          BatchNorm(128, relu),         # 256 parameters, plus 256
          Conv((3, 3), 128 => 128, pad=1, bias=false),  # 147_456 parameters
          BatchNorm(128),               # 256 parameters, plus 256
        ),
        Chain([
          Conv((1, 1), 64 => 128, stride=2, bias=false),  # 8_192 parameters
          BatchNorm(128),               # 256 parameters, plus 256
        ]),
      ),
      Parallel(
        Metalhead.addrelu,
        Chain(
          Conv((3, 3), 128 => 128, pad=1, bias=false),  # 147_456 parameters
          BatchNorm(128, relu),         # 256 parameters, plus 256
          Conv((3, 3), 128 => 128, pad=1, bias=false),  # 147_456 parameters
          BatchNorm(128),               # 256 parameters, plus 256
        ),
        identity,
      ),
      Parallel(
        Metalhead.addrelu,
        Chain(
          Conv((3, 3), 128 => 256, pad=1, stride=2, bias=false),  # 294_912 parameters
          BatchNorm(256, relu),         # 512 parameters, plus 512
          Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
          BatchNorm(256),               # 512 parameters, plus 512
        ),
        Chain([
          Conv((1, 1), 128 => 256, stride=2, bias=false),  # 32_768 parameters
          BatchNorm(256),               # 512 parameters, plus 512
        ]),
      ),
      Parallel(
        Metalhead.addrelu,
        Chain(
          Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
          BatchNorm(256, relu),         # 512 parameters, plus 512
          Conv((3, 3), 256 => 256, pad=1, bias=false),  # 589_824 parameters
          BatchNorm(256),               # 512 parameters, plus 512
        ),
        identity,
      ),
      Parallel(
        Metalhead.addrelu,
        Chain(
          Conv((3, 3), 256 => 512, pad=1, stride=2, bias=false),  # 1_179_648 parameters
          BatchNorm(512, relu),         # 1_024 parameters, plus 1_024
          Conv((3, 3), 512 => 512, pad=1, bias=false),  # 2_359_296 parameters
          BatchNorm(512),               # 1_024 parameters, plus 1_024
        ),
        Chain([
          Conv((1, 1), 256 => 512, stride=2, bias=false),  # 131_072 parameters
          BatchNorm(512),               # 1_024 parameters, plus 1_024
        ]),
      ),
      Parallel(
        Metalhead.addrelu,
        Chain(
          Conv((3, 3), 512 => 512, pad=1, bias=false),  # 2_359_296 parameters
          BatchNorm(512, relu),         # 1_024 parameters, plus 1_024
          Conv((3, 3), 512 => 512, pad=1, bias=false),  # 2_359_296 parameters
          BatchNorm(512),               # 1_024 parameters, plus 1_024
        ),
        identity,
      ),
    ]),
    Chain(
      AdaptiveMeanPool((1, 1)),
      MLUtils.flatten,
      Dense(512 => 1000),               # 513_000 parameters
    ),
  ),
)         # Total: 62 trainable arrays, 11_689_512 parameters,
          # plus 40 non-trainable, 9_600 parameters, summarysize 44.642 MiB.

@jeremiedb

It works fine on the latest release (v0.7.3, commit b37bee7).

However, I can confirm that it's broken on current master.
Trying out some earlier commits, loading of pre-trained models already appears broken at da5d3a7.

So a commit between June 26 and Sept 2 broke pre-trained ResNet loading.

I also tried commit 588d703 (July 21, 2022), but the ResNet function wasn't exported there.

@theabhirath Since you were working on the ResNets at the time, would you be able to take a look at this?

@theabhirath
Member

theabhirath commented Jan 14, 2023

IIRC, the reason it's broken is that we made some structural changes to the model and didn't port the weights again. I'll be happy to port the weights, but we ran into issues with loading weights on 1.6 if the weights were ported on later Julia versions, due to minor edge cases like RNG changes. I'll take a look at what we can do to avoid that (of interest may be FluxML/Functors.jl#56, which should allow us to save only the trainable parameters and thus avoid these edge cases).
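To illustrate the idea (a minimal sketch only; this is not necessarily what FluxML/Functors.jl#56 provides, and the transform and file name below are placeholders):

using Functors, BSON, Metalhead

model = Metalhead.ResNet(18)

# Keep only the numeric arrays; activation functions, integers and other
# configuration leaves are replaced with nothing, so the saved file holds
# nothing whose representation could drift between Julia versions.
arrays_only = Functors.fmapstructure(x -> x isa AbstractArray ? x : nothing, model)

BSON.@save "resnet18_state.bson" arrays_only

Loading such a file back still requires the destination model to have a matching structure, which is the other half of this issue.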

@ToucheSir
Member

That function is just a convenient wrapper, so we can already do this. The only remaining bit is to give the user-facing API a reasonable name so that people can actually remember it.

@christiangnrd
Contributor

It seems like ResNeXt is also broken with the same error on master.

@ToucheSir
Member

It's included under the general ResNet umbrella here, so all the above posts apply. In the meantime, you could load the weights separately, patch up the mismatched portion, and call Flux.loadmodel! yourself.
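For anyone who needs a workaround right now, the shape of that approach is roughly the following (a hypothetical sketch: the artifact path, the BSON key, and the patching step all depend on the actual weight file):

using Flux, BSON, Metalhead

# Load the raw pretrained weights (the path and key below are placeholders).
weights = BSON.load("/path/to/resnet18.bson")[:model]

model = Metalhead.ResNet(18)

# ... manually reorder / patch the mismatched Parallel branches in `weights`
# here so that they line up with `model` ...

Flux.loadmodel!(model, weights)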

@RomeoV

RomeoV commented Feb 8, 2023

Can you elaborate on what the "mismatched portion" is?

@darsnack
Member

darsnack commented Feb 8, 2023

The mismatch (@theabhirath can correct me) is that the pre-trained models contain dropout layers, while the model we are loading into may not, depending on the configuration.

We added FluxML/Flux.jl#2041 specifically for this purpose, and I think that's what @ToucheSir is referencing.
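Assuming that refers to the filter keyword on Flux.loadmodel! (it shows up in the stack trace above), a sketch of skipping the extra dropout layers might look like this; the exact semantics of the predicate should be checked against the Flux docs:

using Flux, BSON, Metalhead

model = Metalhead.ResNet(18)
weights = BSON.load("/path/to/resnet18.bson")[:model]   # placeholder path and key

# Hypothetical usage: assumes the predicate is called on each sub-layer and that
# returning false skips it while the two structures are walked in parallel.
Flux.loadmodel!(model, weights; filter = l -> !(l isa Flux.Dropout))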

@theabhirath
Member

Actually, the current mismatch is that the order of the weights in the PyTorch state dict doesn't match the order of our model, so iterating through both of them in parallel and simply trying to load the weights won't work. I am trying to see if something can be done about this.

@darsnack
Member

darsnack commented Feb 8, 2023

Why can't we port the existing weights that have already been released instead of going back to PyTorch? Sure, that will leave out the new variants we've added, but it will at least let us release without a regression.

@darsnack
Member

darsnack commented Feb 8, 2023

In the case of the iteration mismatch, we have access to the parameter keys in both dicts. Since those keys reflect the underlying structure, it should be possible to map one to the other semi-automatically. Then we can iterate over one dict and access the other by the matching key.
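A naive sketch of that idea, assuming both sides are available as flat Dicts keyed by path-like strings such as "layer1.0.conv1.weight" (the matching heuristic is purely illustrative, and src_state / dst_state are hypothetical variables):

# Match each destination key to an unused source key with the same path suffix
# (e.g. both end in "conv1.weight") and the same array size.
function match_keys(dst, src)
    mapping = Dict{String,String}()
    used = Set{String}()
    for dk in sort!(collect(keys(dst)))
        dv = dst[dk]
        suffix = last(split(dk, '.'), 2)        # e.g. ["conv1", "weight"]
        for sk in sort!(collect(keys(src)))
            sk in used && continue
            if last(split(sk, '.'), 2) == suffix && size(src[sk]) == size(dv)
                mapping[dk] = sk
                push!(used, sk)
                break
            end
        end
    end
    return mapping
end

for (dk, sk) in match_keys(dst_state, src_state)
    copyto!(dst_state[dk], src_state[sk])
end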
