In [1]:
import torch
from torchvision import models
from torchvision.models import ResNet101_Weights
from PIL import Image
from torchvision import transforms

# Pre-trained ResNet-101 built with the IMAGENET1K_V1 weights
resnet101 = models.resnet101(
    weights=ResNet101_Weights.IMAGENET1K_V1,
)


# Image preprocessing: resize, center-crop to 224x224, convert to a tensor,
# then normalize each channel with the usual ImageNet mean/std statistics.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


# Load the image and turn it into a single-image batch.
# `.convert("RGB")` guarantees 3 channels even for grayscale/RGBA/palette
# files, which would otherwise break the 3-channel Normalize step; the `with`
# block releases the file handle explicitly.
with Image.open("bobby.jpg") as img_file:  # img_file.size: (1280, 720)
    img = img_file.convert("RGB")
# img  # Show the image inline
# img.show()  # Show the image in a new pop-up viewer window
img_t = preprocess(img)  # img_t.shape: torch.Size([3, 224, 224])
batch_t = torch.unsqueeze(img_t, 0)  # batch_t.shape: torch.Size([1, 3, 224, 224])


# Put the network in `eval()` mode (inference behavior for layers such as
# BatchNorm) and run the forward pass without recording an autograd graph,
# since no gradients are needed here.
resnet101.eval()
with torch.no_grad():
    out = resnet101(batch_t)  # out.shape: torch.Size([1, 1000])


# Load the class labels (1,000 lines) from the `.txt` file.
# Assumes line i holds the label for class index i, matching the model output.
# An explicit encoding keeps decoding independent of the platform locale.
with open("imagenet_classes.txt", encoding="utf-8") as f:
    labels = [line.strip() for line in f]


# Determine the index corresponding to the maximum score in the `out` tensor,
# and convert the raw scores into softmax percentages for the single image.
_, index = torch.max(out, 1)  # index: tensor([207])
percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100  # percentage.shape: torch.Size([1000])
# A bare expression here would not be shown (a notebook cell only echoes its
# last expression, and a script echoes nothing), so print it explicitly.
print((labels[index[0]], percentage[index[0]].item()))  # ('golden retriever', 96.57185363769531)


# Rank every class by score (highest first) and list the five best
# (label, percentage) pairs. As the cell's final expression, this is echoed.
_, indices = torch.sort(out, descending=True)
[
    (labels[class_idx], percentage[class_idx].item())
    for class_idx in indices[0, :5]
]
# [('golden retriever', 96.57185363769531),
#  ('Labrador retriever', 2.6082706451416016),
#  ('cocker spaniel, English cocker spaniel, cocker', 0.2699621915817261),
#  ('redbone', 0.17958936095237732),
#  ('tennis ball', 0.10991999506950378)]

[('golden retriever', 96.57185363769531),
 ('Labrador retriever', 2.6082706451416016),
 ('cocker spaniel, English cocker spaniel, cocker', 0.2699621915817261),
 ('redbone', 0.17958936095237732),
 ('tennis ball', 0.10991999506950378)]