Xception fails to load weights in PyTorch 0.4.0 (with fix) #62

Closed
fabioperez opened this issue Apr 27, 2018 · 4 comments

fabioperez commented Apr 27, 2018

When loading the Xception net with PyTorch 0.4.0, I got the following error:

While copying the parameter named "block1.rep.0.pointwise.weight", whose dimensions in the model are torch.Size([128, 64, 1, 1]) and whose dimensions in the checkpoint are torch.Size([128, 64]).
While copying the parameter named "block1.rep.3.pointwise.weight", whose dimensions in the model are torch.Size([128, 128, 1, 1]) and whose dimensions in the checkpoint are torch.Size([128, 128]).
While copying the parameter named "block2.rep.1.pointwise.weight", whose dimensions in the model are torch.Size([256, 128, 1, 1]) and whose dimensions in the checkpoint are torch.Size([256, 128]).
While copying the parameter named "block2.rep.4.pointwise.weight", whose dimensions in the model are torch.Size([256, 256, 1, 1]) and whose dimensions in the checkpoint are torch.Size([256, 256]).
While copying the parameter named "block3.rep.1.pointwise.weight", whose dimensions in the model are torch.Size([728, 256, 1, 1]) and whose dimensions in the checkpoint are torch.Size([728, 256]).
While copying the parameter named "block3.rep.4.pointwise.weight", whose dimensions in the model are torch.Size([728, 728, 1, 1]) and whose dimensions in the checkpoint are torch.Size([728, 728]).
...
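The checkpoint and the model differ only by two trailing size-1 dimensions, which is why reshaping is lossless. A 1x1 convolution applies the same [out, in] matrix at every pixel, so a 2-D [out, in] weight and a 4-D [out, in, 1, 1] kernel hold exactly the same numbers. Below is a small pure-Python sketch of that equivalence. The helper names and toy sizes are illustrative and do not come from pretrainedmodels, and the sketch assumes stride 1, no padding and no bias.

```python
# Sketch: a 1x1 convolution is a per-pixel matrix multiply over channels,
# so reshaping [out][in] -> [out][in][1][1] does not change the layer.

def linear_per_pixel(weight_2d, image):
    """Apply an [out][in] matrix to every pixel of an [in][H][W] image."""
    c_in, h, w = len(image), len(image[0]), len(image[0][0])
    return [[[sum(row[c] * image[c][y][x] for c in range(c_in))
              for x in range(w)] for y in range(h)] for row in weight_2d]

def conv1x1(weight_4d, image):
    """Naive 1x1 convolution (stride 1, no padding, no bias), [out][in][1][1] kernel."""
    c_in, h, w = len(image), len(image[0]), len(image[0][0])
    return [[[sum(k[c][0][0] * image[c][y][x] for c in range(c_in))
              for x in range(w)] for y in range(h)] for k in weight_4d]

def unsqueeze2(weight_2d):
    """Pure-Python analogue of weights.unsqueeze(-1).unsqueeze(-1)."""
    return [[[[v]] for v in row] for row in weight_2d]

# Toy example: 2 input channels, 3 output channels, 2x2 image.
w2 = [[1, 2], [3, 4], [5, 6]]
img = [[[1, 0], [2, 1]], [[0, 1], [1, 3]]]
print(conv1x1(unsqueeze2(w2), img) == linear_per_pixel(w2, img))  # True
```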

I fixed the problem by reshaping the pointwise weights in the checkpoint file:

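import torch

# Load the original checkpoint on CPU.
state_dict = torch.load('./xception-b5690688.pth', map_location='cpu')

# Pointwise (1x1) conv weights were saved as 2-D [out, in] tensors, but the
# model's Conv2d layers expect 4-D [out, in, 1, 1]: append two size-1 dims.
# Iterate over a copy of the items because the dict is updated in the loop.
for name, weights in list(state_dict.items()):
    if name.endswith('pointwise.weight') and weights.dim() == 2:
        state_dict[name] = weights.unsqueeze(-1).unsqueeze(-1)

torch.save(state_dict, 'xception-fixed.pth')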

Other networks may hit the same error.
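Here is a model-agnostic variant of the same idea, offered as a sketch. Instead of matching on the substring 'pointwise', it compares each checkpoint shape with the shape the model expects and reshapes only tensors that differ by trailing size-1 dimensions. The helper name `plan_singleton_fixes` is hypothetical. It works on plain shape tuples, so it does not depend on a particular network.

```python
def plan_singleton_fixes(model_shapes, ckpt_shapes):
    """Map each checkpoint key to the model's shape when the two differ only
    by trailing size-1 dims, e.g. (128, 64) in the file vs (128, 64, 1, 1)."""
    plan = {}
    for key, ckpt_shape in ckpt_shapes.items():
        model_shape = model_shapes.get(key)
        if model_shape is None:
            continue  # key unknown to the model; leave it alone
        ckpt_shape, model_shape = tuple(ckpt_shape), tuple(model_shape)
        n = len(ckpt_shape)
        # Leading dims must match exactly and every extra model dim must be 1.
        if (len(model_shape) > n
                and model_shape[:n] == ckpt_shape
                and all(d == 1 for d in model_shape[n:])):
            plan[key] = model_shape
    return plan

# Hypothetical usage with PyTorch (model and state_dict are placeholders):
#   model_shapes = {k: tuple(v.shape) for k, v in model.state_dict().items()}
#   ckpt_shapes = {k: tuple(v.shape) for k, v in state_dict.items()}
#   for key, shape in plan_singleton_fixes(model_shapes, ckpt_shapes).items():
#       state_dict[key] = state_dict[key].reshape(shape)
#   model.load_state_dict(state_dict)
```

Checking against the model's own shapes, rather than the parameter name, leaves tensors that are already 4-D untouched. The same loop could therefore apply to other networks, provided their only mismatch is missing size-1 dimensions.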

fabioperez changed the title from "Xception fails to load weights in PyTorch 0.4.0" to "Xception fails to load weights in PyTorch 0.4.0 (with fix)" on Apr 27, 2018.

Cadene (Owner) commented Apr 28, 2018

Thanks for reporting this. I'll do my best to make the library compatible with PyTorch 0.4.x soon.

PapaMadeleine2022 commented

I'm getting the same error.

Cadene (Owner) commented Oct 27, 2018

Fixed

Cadene closed this as completed on Oct 27, 2018.

AnsonCNS commented

I get the same error with torch==1.7.1.

I tried the fix by @fabioperez, but it didn't work:

RuntimeError: Error(s) in loading state_dict for Xception:
        size mismatch for block1.rep.0.pointwise.weight: copying a param with shape torch.Size([128, 64]) from checkpoint, the shape in current model is torch.Size([128, 64, 1, 1]).
        size mismatch for block1.rep.3.pointwise.weight: copying a param with shape torch.Size([128, 128]) from checkpoint, the shape in current model is torch.Size([128, 128, 1, 1]).
        size mismatch for block2.rep.1.pointwise.weight: copying a param with shape torch.Size([256, 128]) from checkpoint, the shape in current model is torch.Size([256, 128, 1, 1]).
        size mismatch for block2.rep.4.pointwise.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([256, 256, 1, 1]).
        size mismatch for block3.rep.1.pointwise.weight: copying a param with shape torch.Size([728, 256]) from checkpoint, the shape in current model is torch.Size([728, 256, 1, 1]).
        size mismatch for block3.rep.4.pointwise.weight: copying a param with shape torch.Size([728, 728]) from checkpoint, the shape in current model is torch.Size([728, 728, 1, 1]).
        size mismatch for block4.rep.1.pointwise.weight: copying a param with shape torch.Size([728, 728]) from checkpoint, the shape in current model is torch.Size([728, 728, 1, 1]).
        size mismatch for block4.rep.4.pointwise.weight: copying a param with shape torch.Size([728, 728]) from checkpoint, the shape in current model is torch.Size([728, 728, 1, 1]).
        size mismatch for block4.rep.7.pointwise.weight: copying a param with shape torch.Size([728, 728]) from checkpoint, the shape in current model is torch.Size([728, 728, 1, 1]).
        size mismatch for block5.rep.1.pointwise.weight: copying a param with shape torch.Size([728, 728]) from checkpoint, the shape in current model is torch.Size([728, 728, 1, 1]).
        size mismatch for block5.rep.4.pointwise.weight: copying a param with shape torch.Size([728, 728]) from checkpoint, the shape in current model is torch.Size([728, 728, 1, 1]).
        size mismatch for block5.rep.7.pointwise.weight: copying a param with shape torch.Size([728, 728]) from checkpoint, the shape in current model is torch.Size([728, 728, 1, 1]).
        size mismatch for block6.rep.1.pointwise.weight: copying a param with shape torch.Size([728, 728]) from checkpoint, the shape in current model is torch.Size([728, 728, 1, 1]).
        size mismatch for block6.rep.4.pointwise.weight: copying a param with shape torch.Size([728, 728]) from checkpoint, the shape in current model is torch.Size([728, 728, 1, 1]).
        size mismatch for block6.rep.7.pointwise.weight: copying a param with shape torch.Size([728, 728]) from checkpoint, the shape in current model is torch.Size([728, 728, 1, 1]).
        size mismatch for block7.rep.1.pointwise.weight: copying a param with shape torch.Size([728, 728]) from checkpoint, the shape in current model is torch.Size([728, 728, 1, 1]).
        size mismatch for block7.rep.4.pointwise.weight: copying a param with shape torch.Size([728, 728]) from checkpoint, the shape in current model is torch.Size([728, 728, 1, 1]).
        size mismatch for block7.rep.7.pointwise.weight: copying a param with shape torch.Size([728, 728]) from checkpoint, the shape in current model is torch.Size([728, 728, 1, 1]).
        size mismatch for block8.rep.1.pointwise.weight: copying a param with shape torch.Size([728, 728]) from checkpoint, the shape in current model is torch.Size([728, 728, 1, 1]).
        size mismatch for block8.rep.4.pointwise.weight: copying a param with shape torch.Size([728, 728]) from checkpoint, the shape in current model is torch.Size([728, 728, 1, 1]).
        size mismatch for block8.rep.7.pointwise.weight: copying a param with shape torch.Size([728, 728]) from checkpoint, the shape in current model is torch.Size([728, 728, 1, 1]).
        size mismatch for block9.rep.1.pointwise.weight: copying a param with shape torch.Size([728, 728]) from checkpoint, the shape in current model is torch.Size([728, 728, 1, 1]).
        size mismatch for block9.rep.4.pointwise.weight: copying a param with shape torch.Size([728, 728]) from checkpoint, the shape in current model is torch.Size([728, 728, 1, 1]).
        size mismatch for block9.rep.7.pointwise.weight: copying a param with shape torch.Size([728, 728]) from checkpoint, the shape in current model is torch.Size([728, 728, 1, 1]).
        size mismatch for block10.rep.1.pointwise.weight: copying a param with shape torch.Size([728, 728]) from checkpoint, the shape in current model is torch.Size([728, 728, 1, 1]).
        size mismatch for block10.rep.4.pointwise.weight: copying a param with shape torch.Size([728, 728]) from checkpoint, the shape in current model is torch.Size([728, 728, 1, 1]).
        size mismatch for block10.rep.7.pointwise.weight: copying a param with shape torch.Size([728, 728]) from checkpoint, the shape in current model is torch.Size([728, 728, 1, 1]).
        size mismatch for block11.rep.1.pointwise.weight: copying a param with shape torch.Size([728, 728]) from checkpoint, the shape in current model is torch.Size([728, 728, 1, 1]).
        size mismatch for block11.rep.4.pointwise.weight: copying a param with shape torch.Size([728, 728]) from checkpoint, the shape in current model is torch.Size([728, 728, 1, 1]).
        size mismatch for block11.rep.7.pointwise.weight: copying a param with shape torch.Size([728, 728]) from checkpoint, the shape in current model is torch.Size([728, 728, 1, 1]).
        size mismatch for block12.rep.1.pointwise.weight: copying a param with shape torch.Size([728, 728]) from checkpoint, the shape in current model is torch.Size([728, 728, 1, 1]).
        size mismatch for block12.rep.4.pointwise.weight: copying a param with shape torch.Size([1024, 728]) from checkpoint, the shape in current model is torch.Size([1024, 728, 1, 1]).
        size mismatch for conv3.pointwise.weight: copying a param with shape torch.Size([1536, 1024]) from checkpoint, the shape in current model is torch.Size([1536, 1024, 1, 1]).
        size mismatch for conv4.pointwise.weight: copying a param with shape torch.Size([2048, 1536]) from checkpoint, the shape in current model is torch.Size([2048, 1536, 1, 1]).
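These are the same 2-D vs. 4-D pointwise mismatches as in the original report. One thing worth ruling out is that the model is still receiving the original checkpoint rather than the rewritten xception-fixed.pth. This is an assumption; the thread does not confirm the cause. A hypothetical helper like the one below lists the remaining shape mismatches for whatever state dict is about to be loaded:

```python
def find_shape_mismatches(model_shapes, ckpt_shapes):
    """Return (key, checkpoint_shape, model_shape) for keys present in both
    mappings whose shapes differ, sorted by key."""
    common = ckpt_shapes.keys() & model_shapes.keys()
    return sorted(
        (key, tuple(ckpt_shapes[key]), tuple(model_shapes[key]))
        for key in common
        if tuple(ckpt_shapes[key]) != tuple(model_shapes[key])
    )

# Hypothetical usage with PyTorch (model and file name are placeholders):
#   state_dict = torch.load('xception-fixed.pth', map_location='cpu')
#   mismatches = find_shape_mismatches(
#       {k: tuple(v.shape) for k, v in model.state_dict().items()},
#       {k: tuple(v.shape) for k, v in state_dict.items()})
#   for key, ckpt, expected in mismatches:
#       print(key, ckpt, '->', expected)
```

An empty result means load_state_dict should not report size mismatches for the keys the two share. Missing or unexpected keys are a separate check.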
