In [None]:
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torchvision
import torchvision.models as models

### Configuration

In [None]:
# Pretrained ImageNet checkpoints, keyed by torchvision constructor name.
model_urls = {
    'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
    'densenet121': 'https://download.pytorch.org/models/densenet121-241335ed.pth',
    'densenet169': 'https://download.pytorch.org/models/densenet169-6f0f7f60.pth',
    'densenet201': 'https://download.pytorch.org/models/densenet201-4c113574.pth',
    'densenet161': 'https://download.pytorch.org/models/densenet161-17b70270.pth',
    # Key drops the URL's "_google" suffix so it matches the constructor name.
    'inception_v3': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth',
    'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth',
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
}

# A real list rather than a dict_keys view, so it can be indexed and copied.
model_names = list(model_urls)

# Input sizes per model family; presumably (height, width) — confirm.
input_sizes = {
    'alexnet': (224, 224),
    'densenet': (224, 224),
    'resnet': (224, 224),
    'inception': (299, 299),
    'squeezenet': (224, 224),  # not 255,255 acc. to https://github.com/pytorch/pytorch/issues/1120
    'vgg': (224, 224),
}

# Test every model by default. models_to_test is a separate list so that
# entries can be removed (e.g. models_to_test.remove('densenet169')) without
# touching model_names. To test only a subset, use e.g.:
#   models_to_test = ['alexnet', 'densenet169', 'inception_v3',
#                     'resnet34', 'squeezenet1_1', 'vgg13']
models_to_test = list(model_names)

### Generic pretrained model loading

In [None]:
#We solve the dimensionality mismatch between
#final layers in the constructed vs pretrained
#modules at the data level.
def diff_states(dict_canonical, dict_subset):
    names1, names2 = (list(dict_canonical.keys()), list(dict_subset.keys()))
    
    #Sanity check that param names overlap
    #Note that params are not necessarily in the same order
    #for every pretrained model
    not_in_1 = [n for n in names1 if n not in names2]
    not_in_2 = [n for n in names2 if n not in names1]
    assert len(not_in_1) == 0
    assert len(not_in_2) == 0

    for name, v1 in dict_canonical.items():
        v2 = dict_subset[name]
        assert hasattr(v2, 'size')
        if v1.size() != v2.size():
            yield (name, v1)                

def load_model_named(name):   
    #Densenets don't (yet) pass on num_classes, hack it in
    if "densenet" in name:
        if name == 'densenet169':
            return models.DenseNet(num_init_features=64, growth_rate=32, \
                                   block_config=(6, 12, 32, 32), num_classes=num_classes)
        
        elif name == 'densenet121':
            return models.DenseNet(num_init_features=64, growth_rate=32, \
                                   block_config=(6, 12, 24, 16), num_classes=num_classes)
        
        elif name == 'densenet201':
            return models.DenseNet(num_init_features=64, growth_rate=32, \
                                   block_config=(6, 12, 48, 32), num_classes=num_classes)

        elif name == 'densenet161':
             return models.DenseNet(num_init_features=96, growth_rate=48, \
                                    block_config=(6, 12, 36, 24), num_classes=num_classes)
        else:
            raise ValueError("Cirumventing missing num_classes kwargs not implemented for %s" % name)
    
    return models.__dict__[name](num_classes=num_classes)
    
            
def load_model(name, num_classes):
    
    model = load_model_named(name)
    pretrained_state = model_zoo.load_url(model_urls[name])

    #Diff
    diff = list(diff_states(model.state_dict(), pretrained_state))
    
    for name, value in diff:
        pretrained_state[name] = value
    
    assert len(list(diff_states(model.state_dict(), pretrained_state))) == 0
    
    #Merge
    model.load_state_dict(pretrained_state)
    return model, diff

In [None]:
# Method to mutate module programmatically (PR #175)
# https://github.com/pytorch/vision/pull/175

def resize_network_output(net, output_size):
    if isinstance(net, torch.nn.DataParallel):
        return resize_network_output(net.module, output_size)

    # Edit: Can't index iterable in python3
    #output_layer = net._modules.keys()[-1]
    for output_layer in net._modules.keys():
        pass
    old_output_layer = net._modules[output_layer]

    if isinstance(old_output_layer, nn.Sequential):
        return resize_network_output(old_output_layer, output_size)
    elif isinstance(old_output_layer, nn.modules.pooling.AvgPool2d):
        # Go back in the layer sequence and find the last conv layer and resize that
        # Only happens for squeezenet1_0
        # Edit: iteritems deprecated in python3
        for name, layer in list(net._modules.items())[::-1][1:]:
            if isinstance(layer, nn.modules.conv.Conv2d):
                net._modules[name] = nn.modules.conv.Conv2d(layer.in_channels, output_size, layer.kernel_size,
                                                            layer.stride, layer.padding, layer.dilation, layer.groups)
                return
        assert False

    assert isinstance(old_output_layer, nn.Linear), 'Class of old_output_layer {}'.format(old_output_layer.__class__.__name__)
    input_size = old_output_layer.weight.size()[1]

    net._modules[output_layer] = nn.Linear(input_size, output_size)


def load_model_resize_post(name, num_classes):
    model = load_model_named(name)
    resize_network_output(model, num_classes)
    return model

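## Compare generic loading methods

For each model in `models_to_test`, build a `num_classes`-output network in two ways and check that both produce identical state-dict shapes:

- merging pretrained weights into a freshly built model (`load_model`);
- resizing the output layer after loading (`load_model_resize_post`).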

In [None]:
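# If CUDA is not available, unpickling the densenet169 checkpoint fails.
# TODO: possibly fixed by re-saving the checkpoint with CPU tensors or by
# loading it with map_location='cpu' (unverified). Until then, uncomment
# the line below to skip that model:
# models_to_test.remove('densenet169')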

In [None]:
num_classes = 10

for name in models_to_test:
    print("")
    print(name, "with %d classes" % num_classes)
    try:
        model_merged, diff = load_model(name, num_classes)
        diff_vanilla = [d[0] for d in diff]
        result = ("... merge loading: " + str(diff_vanilla)).ljust(99) \
        + "OK" if len(diff_vanilla) > 0 else "X"
    except Exception as e:
        result = ("... merge loading: " + str(e)).ljust(99) + "X"
    finally:
        print(result)
    
    try:
        model_resized = load_model_resize_post(name, num_classes)
        diff_merged_resized = [p[0] for p in \
                               diff_states(model_merged.state_dict(), model_resized.state_dict())]
        result = ("... resizing after load: " + str(diff_merged_resized)).ljust(99) \
        + "OK" if len(diff_merged_resized) == 0 else "X"
    except Exception as e:
        result = ("... resizing after load: " + str(e)).ljust(99) + "X"
    finally:
        print(result)        