In [1]:
from torchvision import models, transforms
from torch.autograd import Variable
from torch.nn import functional as F
import numpy as np
from PIL import Image
import cv2
from functools import partial
import sys
import matplotlib.pyplot as plt
# load a pretrained model, such a model already has a global pooling at the end
# model_id: 1 - SqueezeNet, 2 - ResNet, 3 - DenseNet
def load_model(model_id):
    """Build a pretrained torchvision classifier and name the layer to hook.

    Args:
        model_id: 1 for SqueezeNet 1.1, 2 for ResNet-101, 3 for DenseNet-161.

    Returns:
        A ``(model, final_conv_layer)`` tuple. ``final_conv_layer`` is the
        name of the top-level submodule whose output serves as the activation
        for the class activation map.

    Any other ``model_id`` terminates the interpreter through ``sys.exit``
    with the message 'No such model!'.
    """
    if model_id == 1:
        model = models.squeezenet1_1(pretrained=True)
        # The activation that the classifier weights apply to is the output
        # of `features`. The previous value 'classifier.1' names a nested
        # module (the classifier conv itself), which a lookup among the
        # model's direct children cannot find.
        final_conv_layer = 'features'
    elif model_id == 2:
        model = models.resnet101(pretrained=True)
        final_conv_layer = 'layer4'
    elif model_id == 3:
        model = models.densenet161(pretrained=True)
        final_conv_layer = 'features'
    else:
        sys.exit('No such model!')

    return model, final_conv_layer

# a hook to a given layer
def hook(module, input, output, feature_blob):
    """Forward hook that records a layer's output as a NumPy array.

    Args:
        module: The module the hook fired on (unused).
        input: The inputs that were passed to the module (unused).
        output: The tensor produced by the module.
        feature_blob: List the converted output is appended to.

    Returns:
        None, so the hooked module's output is left untouched.
    """
    # detach() drops the autograd graph without going through the legacy
    # `.data` accessor, and cpu() lets the conversion succeed for tensors
    # that do not live in host memory (`.numpy()` only accepts CPU tensors).
    feature_blob.append(output.detach().cpu().numpy())

# load and preprocess an image
def load_image(filename = './cat.jpg'):
    """Load an image file and preprocess it into a single-image batch.

    The image is converted to RGB, resized to 224x224, turned into a tensor
    and normalized with the per-channel mean/std given below.

    Args:
        filename: Path of the image file to read.

    Returns:
        The preprocessed image with a leading batch dimension of size 1
        (shape 1 x 3 x 224 x 224), wrapped in ``Variable``.
    """
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize
    ])

    # The normalization above is defined for exactly three channels, so
    # grayscale / palette / RGBA files are converted to RGB first. The
    # context manager closes the underlying file once the pixels are read.
    with Image.open(filename) as image:
        tensor = preprocess(image.convert('RGB'))

    return Variable(tensor.unsqueeze(0))

# read in labels, original file url: https://s3.amazonaws.com/outcome-blog/imagenet/labels.json
def get_labels(filename = '.1000.txt'):
    """Read class labels from a text file, one label per line.

    Args:
        filename: Path of the label file.

    Returns:
        A list of strings obtained by splitting the file content on '\\n',
        so the list index equals the zero-based line number. A file that
        ends with a newline yields a trailing empty string.
    """
    # NOTE(review): the default path looks like a typo for './1000.txt'
    # (the path the script in this file passes explicitly) — confirm before
    # changing it.
    # The context manager closes the file instead of leaking the handle.
    with open(filename) as label_file:
        return label_file.read().split('\n')

# compute class activation map
def compute_cam(activation, softmax_weight, class_ids, size = (224, 224)):
    """Compute one class activation map per requested class.

    Args:
        activation: Array of shape (batch, channels, height, width). The
            reshape below only succeeds for a batch of one.
        softmax_weight: Array indexed by class id whose rows hold one weight
            per activation channel.
        class_ids: Iterable of class indices to build maps for.
        size: Output size handed to ``cv2.resize``; defaults to (224, 224).

    Returns:
        A list with one uint8 map (values in [0, 255]) per entry of
        ``class_ids``, each resized to ``size``.
    """
    _, c, h, w = activation.shape
    # Flatten the spatial dimensions once instead of on every iteration.
    features = activation.reshape(c, h * w)

    cams = []
    for idx in class_ids:
        cam = softmax_weight[idx].dot(features).reshape(h, w)
        # normalize to [0, 1]; a constant map has zero range, so it is left
        # at all zeros rather than being divided by zero (which would give
        # NaN and an undefined uint8 cast)
        cam = cam - cam.min()
        peak = cam.max()
        if peak > 0:
            cam = cam / peak
        # convert to [0, 255]
        cam = np.uint8(255 * cam)
        cams.append(cv2.resize(cam, size))

    return cams


# load a pretrained model
model, final_conv_layer = load_model(2)    # model_id: 1 - SqueezeNet, 2 - ResNet, 3 - DenseNet
model.eval()

# add a hook to a given layer
feature_blob = []
# Look the layer up by its qualified name through the public API: an unknown
# name now fails with a KeyError naming the layer, instead of `.get()`
# returning None and the failure surfacing on the next attribute access.
target_layer = dict(model.named_modules())[final_conv_layer]
target_layer.register_forward_hook(partial(hook, feature_blob = feature_blob))

# get the softmax (last fc layer) weight
params = list(model.parameters())
softmax_weight = np.squeeze(params[-2].data.numpy())

image_path = './cat.jpg'
input = load_image(image_path)

output = model(input)   # scores

labels = get_labels('./1000.txt')

# The scores come from a batch of one image, so the class axis is dim 1.
# Passing it explicitly avoids the implicit-dimension deprecation warning.
probs = F.softmax(output, dim = 1).data.squeeze()
probs, idx = probs.sort(0, descending = True)

# output the top-5 prediction
for i in range(5):
    print('{:.3f} -> {}'.format(probs[i], labels[idx[i]]))

# generate class activation map for the top-5 prediction
cams = compute_cam(feature_blob[0], softmax_weight, idx[0: 5])

# The source image is the same for every overlay, so read it only once.
img = cv2.imread(image_path)
h, w, _ = img.shape

for i, cam in enumerate(cams):
    # render cam and original image
    filename = labels[idx[i]] + '.jpg'
    print('output %s for the top-%s prediction: %s' % (filename, (i + 1), labels[idx[i]]))

    # cv2.resize takes the target size as (width, height)
    heatmap = cv2.applyColorMap(cv2.resize(cam, (w, h)), cv2.COLORMAP_JET)
    result = heatmap * 0.3 + img * 0.5
    cv2.imwrite(filename, result)

  from .autonotebook import tqdm as notebook_tqdm
Downloading: "https://download.pytorch.org/models/resnet101-63fe2227.pth" to /home/yanai-lab/xiong-p/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth
100%|██████████| 171M/171M [00:02<00:00, 88.9MB/s] 


0.608 -> n02124075 Egyptian cat
0.117 -> n02123159 tiger cat
0.094 -> n02123394 Persian cat
0.044 -> n02123045 tabby, tabby cat
0.014 -> n01622779 great grey owl, great gray owl, Strix nebulosa
output n02124075 Egyptian cat.jpg for the top-1 prediction: n02124075 Egyptian cat
output n02123159 tiger cat.jpg for the top-2 prediction: n02123159 tiger cat
output n02123394 Persian cat.jpg for the top-3 prediction: n02123394 Persian cat
output n02123045 tabby, tabby cat.jpg for the top-4 prediction: n02123045 tabby, tabby cat
output n01622779 great grey owl, great gray owl, Strix nebulosa.jpg for the top-5 prediction: n01622779 great grey owl, great gray owl, Strix nebulosa


  probs = F.softmax(output).data.squeeze()
