In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
import torch.utils.model_zoo as model_zoo

### Configuration

In [2]:
model_urls = {
    'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
    'densenet121': 'https://download.pytorch.org/models/densenet121-241335ed.pth',
    'densenet169': 'https://download.pytorch.org/models/densenet169-6f0f7f60.pth',
    'densenet201': 'https://download.pytorch.org/models/densenet201-4c113574.pth',
    'densenet161': 'https://download.pytorch.org/models/densenet161-17b70270.pth',
    #truncated key (_google) to match module name
    'inception_v3': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',    
    'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth',
    'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth',
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
    'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
    'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
    'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',    
}

model_names = model_urls.keys()
models_to_test = model_names

### Generic pretrained model loading

In [3]:
def load_model_named(name, **kwargs):   
    #Densenets don't (yet) pass on num_classes, hack it in
    if "densenet" in name:
        if name == 'densenet169':
            return models.DenseNet(num_init_features=64, growth_rate=32, \
                                   block_config=(6, 12, 32, 32), num_classes=num_classes)
        
        elif name == 'densenet121':
            return models.DenseNet(num_init_features=64, growth_rate=32, \
                                   block_config=(6, 12, 24, 16), num_classes=num_classes)
        
        elif name == 'densenet201':
            return models.DenseNet(num_init_features=64, growth_rate=32, \
                                   block_config=(6, 12, 48, 32), num_classes=num_classes)

        elif name == 'densenet161':
             return models.DenseNet(num_init_features=96, growth_rate=48, \
                                    block_config=(6, 12, 36, 24), num_classes=num_classes)
        else:
            raise ValueError("Cirumventing missing num_classes kwargs not implemented for %s" % name)
    
    return models.__dict__[name](**kwargs)       

Variants from https://github.com/pytorch/vision/issues/173

In [4]:
def load_model(name, num_classes):
    model = load_model_named(name, num_classes=num_classes)
    
    #Update
    model_dict = model.state_dict()    
    pretrained_dict = model_zoo.load_url(model_urls[name])

    diff = {k: v for k, v in model_dict.items() if \
            k in pretrained_dict and pretrained_dict[k].size() != v.size()}
    
    pretrained_dict.update(diff)
    
    model.load_state_dict(pretrained_dict)
    return model, diff

In [5]:
def load_model_key_only(name, num_classes):
    model = load_model_named(name, num_classes=num_classes)
    
    #Update
    model_dict = model.state_dict()
    pretrained_dict = model_zoo.load_url(model_urls[name])

    # 1. filter out unnecessary keys
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    model.load_state_dict(model_dict)
    
    return model

## Test generic loading methods

In [6]:
# Note: If cuda isn't available, unpickle fails with (at least) densenets

In [7]:
num_classes = 10

for name in ['alexnet', 'resnet18', 'vgg11']: # models_to_test:
    print("")
    print(name, "with %d classes" % num_classes)
    try:
        model_merged, diff = load_model(name, num_classes)
        diff_keys = list(diff.keys())
        result = ("... merge loading: " + str(diff_keys)).ljust(80) \
        + (" OK" if len(diff_keys) > 0 else " X")
    except Exception as e:
        result = ("... merge loading: " + str(e)).ljust(80) + " X"
    finally:
        print(result)
        
    try:
        model_keys = load_model_key_only(name, num_classes)
        diff_keys = [p[0] for p in \
                    diff_states(model_merged.state_dict(), model_resized.state_dict())]
        result = ("... merge on keys: " + str(diff_keys)).ljust(80) \
        + (" OK" if len(diff_keys) == 0 else " X")
    except Exception as e:
        result = ("... merge on keys: " + str(e)).ljust(80) + " X"
    finally:
        print(result)        


alexnet with 10 classes
... merge loading: ['classifier.6.weight', 'classifier.6.bias']                  OK
... merge on keys: inconsistent tensor size at /Users/soumith/miniconda2/conda-bld/pytorch_1493757603856/work/torch/lib/TH/generic/THTensorCopy.c:51 X

resnet18 with 10 classes
... merge loading: ['fc.weight', 'fc.bias']                                      OK
... merge on keys: inconsistent tensor size at /Users/soumith/miniconda2/conda-bld/pytorch_1493757603856/work/torch/lib/TH/generic/THTensorCopy.c:51 X

vgg11 with 10 classes
... merge loading: ['classifier.6.weight', 'classifier.6.bias']                  OK
... merge on keys: inconsistent tensor size at /Users/soumith/miniconda2/conda-bld/pytorch_1493757603856/work/torch/lib/TH/generic/THTensorCopy.c:51 X
