mismatch with torchvision resnets #228

Closed · CarloLucibello opened this issue Apr 24, 2023 · 1 comment · Fixed by #229

CarloLucibello (Member) commented Apr 24, 2023

I'm using the script
https://github.com/FluxML/Metalhead.jl/blob/master/scripts/port_torchvision.jl
to load torchvision's models and copy their weights into Metalhead's ones.

With the vggX models everything works fine.

With the resnets, however, I get the following mismatches (a sketch of the size check involved follows the list):

  • ResNet18:
    flux_key = "model.layers[1].layers[3].layers[1].layers[1].layers[1].conv_weight"
    size(flux_param) = (1, 1, 64, 128)
    pytorch_key = "layer2.0.conv1.weight"
    size(pytorch_param) = (3, 3, 64, 128)
    
  • ResNet34:
    flux_key = "model.layers[1].layers[3].layers[1].layers[1].layers[1].conv_weight"
    size(flux_param) = (1, 1, 64, 128)
    pytorch_key = "layer2.0.conv1.weight"
    size(pytorch_param) = (3, 3, 64, 128)
    
  • ResNet50:
    flux_key = "model.layers[1].layers[2].layers[1].layers[1].layers[1].conv_weight"
    size(flux_param) = (1, 1, 64, 256)
    pytorch_key = "layer1.0.conv1.weight"
    size(pytorch_param) = (1, 1, 64, 64)
    
  • ResNet101:
    flux_key = "model.layers[1].layers[2].layers[1].layers[1].layers[1].conv_weight"
    size(flux_param) = (1, 1, 64, 256)
    pytorch_key = "layer1.0.conv1.weight"
    size(pytorch_param) = (1, 1, 64, 64)
    
  • ResNet152:
    flux_key = "model.layers[1].layers[2].layers[1].layers[1].layers[1].conv_weight"
    size(flux_param) = (1, 1, 64, 256)
    pytorch_key = "layer1.0.conv1.weight"
    size(pytorch_param) = (1, 1, 64, 64)
    
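For context, a minimal sketch (in Julia) of the kind of size check that produces these reports; the helper name is hypothetical, not the port script's actual API, and it assumes the usual layout difference between the frameworks: Flux stores conv weights as (kW, kH, in, out) while PyTorch uses (out, in, kH, kW), so the dimensions are reversed before comparing.

    # Hypothetical sketch: copy one PyTorch tensor into the matching Flux array.
    # Reversing the dimension order maps PyTorch's (out, in, kH, kW) layout
    # onto Flux's (kW, kH, in, out) layout before the sizes are compared.
    function copy_weight!(flux_param, pytorch_param, flux_key, pytorch_key)
        permuted = permutedims(pytorch_param, ndims(pytorch_param):-1:1)
        size(flux_param) == size(permuted) ||
            error("mismatch: $flux_key $(size(flux_param)) vs $pytorch_key $(size(permuted))")
        copyto!(flux_param, permuted)
    end

Note that the permuted torchvision sizes above already look sensible (a 3×3 conv shows up as (3, 3, 64, 128)), which suggests the failures are in the pairing of keys rather than in the permutation itself: a Flux key for a 1×1 shortcut conv is being matched against a different torchvision conv.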

@theabhirath (Member) commented:

It's because Parallel gives the layers in a different order than the one in torchvision. Enumerating over reverse(node.layers) at

    for (i, n) in enumerate(node.layers)

should fix this.
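
A minimal sketch of the suggested change; the surrounding walker and the visit call are hypothetical paraphrases, not the script's exact code:

    # Before the fix the walker enumerated node.layers directly. Reversing
    # the children makes Flux's Parallel branch order (e.g. shortcut vs.
    # main path) line up with torchvision's module order.
    for (i, n) in enumerate(reverse(node.layers))
        visit(i, n)  # hypothetical recursion into each child node
    end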
