Reference: [PyTorch – How to Load & Predict using Resnet Model](https://vitalflux.com/pytorch-load-predict-pretrained-resnet-model/)

In [2]:
import torch
import torchvision.models as models
import torchvision.transforms as transforms
import PIL.Image as Image

In [None]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = models.resnext101_32x8d(pretrained=True)
# torch.save(model, "resnext101_32x8d")

1. Load the file containing the 1,000 labels for the ImageNet dataset classes
2. Find the index (tensor) corresponding to the maximum score in the out tensor. Torch.max function can be used to find the information.
3. Find the score in terms of percentage by using torch.nn.functional.softmax function which normalizes the output to range [0,1] and multiplying by 100.
4. Print the name along with score of the object identified by the model.
5. Print the top 5 scores along with the image label. Sort function is invoked on the torch to sort the scores.

In [3]:
def predictImage(imagePath, best_only=True):
    '''
    Requires:
    torch
    torchvision.models as models
    torchvision.transforms as transforms
    PIL.Image as Image
    '''

    torch.device("cuda")
    model = torch.load("resnext101_32x8d.pt")
    
    image = Image.open(imagePath)

    preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(244),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
        )
    ])

    img_pps = preprocess(image)
    batch_img_pps = torch.unsqueeze(img_pps, 0)
    model.eval()
    out = model(batch_img_pps)

    # Getting the labels
    with open('./imagenet_classes.txt') as f:
        labels = [line.strip() for line in f.readlines()]

    _, index = torch.max(out, 1)

    percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100

    if best_only:
        return labels[index[0]], percentage[index[0]].item()
    else:
        _, indices = torch.sort(out, descending=True)
        return [(labels[i], percentage[i].item()) for i in indices[0][:5]]

In [4]:
predictImage("./images/cat.jpg")

('tabby, tabby cat', 83.1136245727539)