# Example 3 - Pre-trained models for image classification

The pre-trained models available in torchvision are listed here:

https://pytorch.org/vision/stable/models.html

In [0]:
import torch
from PIL import Image
from torchvision import models
from torchvision import transforms
import matplotlib.pyplot as plt

# Download a sample image and the ImageNet class labels (imagenet_classes.txt)

In [0]:
# Fetch the sample photo and the class-name file. Both use -O so a re-run overwrites
# the existing file instead of wget saving extra copies (imagenet_classes.txt.1, ...).
!wget -q https://upload.wikimedia.org/wikipedia/commons/thumb/1/17/Tiger_in_Ranthambhore.jpg/1024px-Tiger_in_Ranthambhore.jpg -O tiger.jpg
!wget -q https://github.com/nmilosev/pytorch-arm-builds/raw/master/imagenet_classes.txt -O imagenet_classes.txt

In [0]:
# Explicit encoding so reading the label file does not depend on the platform locale.
with open('imagenet_classes.txt', encoding='utf-8') as f:
    # One class name per line; infer() looks names up by model output index,
    # so the line order must be kept as-is.
    labels = [line.strip() for line in f]

In [0]:
# Quick sanity check: peek at the first few class names.
labels[0:5]

# Initialize model

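Do not forget to switch the model to evaluation mode with `net.eval()`: layers such as batch normalization and dropout behave differently during training and inference.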

In [0]:
# torchvision >= 0.13 replaced the deprecated `pretrained=True` flag with explicit
# weight enums; IMAGENET1K_V1 is the checkpoint the old flag loaded.
# Fall back to the flag on older torchvision releases that lack the enum.
if hasattr(models, 'ShuffleNet_V2_X1_0_Weights'):
    net = models.shufflenet_v2_x1_0(weights=models.ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1)
else:
    net = models.shufflenet_v2_x1_0(pretrained=True)
# Inference mode; the trailing `;` suppresses the (very long) module repr output.
net.eval();

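# Preprocessing

This is the standard ImageNet preprocessing: resize the shorter side to 256 pixels, center-crop to 224×224, convert to a tensor, and normalize each channel with the ImageNet mean and standard deviation. This puts our image in the same representation as the images the model was trained on.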

In [0]:
# Preprocessing pipeline applied to every input image, in order.
transform = transforms.Compose(
    [
        transforms.Resize(256),      # shorter side -> 256 px, aspect ratio kept
        transforms.CenterCrop(224),  # central 224x224 crop
        transforms.ToTensor(),       # PIL image -> float tensor (C, H, W) in [0, 1]
        # Per-channel (R, G, B) ImageNet statistics.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Helper function

Here we define a helper function for inference, which receives a PIL image and returns the top 5 classes together with their predicted probabilities (in percent).

In [0]:
def infer(img, topk=5, model=None, preprocess=None, class_names=None):
    """Classify a PIL image and return its most likely classes.

    Args:
        img: PIL image to classify. It is converted to RGB first, so grayscale,
            palette or RGBA images are accepted too.
        topk: number of classes to return (default 5); capped at the number of
            model outputs.
        model: classifier to run; defaults to the notebook's ``net``.
        preprocess: callable turning the image into one unbatched tensor;
            defaults to the notebook's ``transform``.
        class_names: sequence mapping output index -> class name; defaults to
            the notebook's ``labels``.

    Returns:
        List of ``(class_name, percentage)`` tuples ordered from most to least
        likely, where ``percentage`` is the softmax probability times 100.
    """
    model = net if model is None else model
    preprocess = transform if preprocess is None else preprocess
    class_names = labels if class_names is None else class_names

    # The default transform normalizes exactly 3 channels; converting avoids a
    # shape mismatch for 1-channel (L/P) or 4-channel (RGBA) images.
    batch = preprocess(img.convert('RGB')).unsqueeze(0)  # add a batch dim of size 1

    # Pure inference: don't build an autograd graph (saves memory and time).
    with torch.no_grad():
        logits = model(batch)

    percentages = torch.nn.functional.softmax(logits, dim=1)[0] * 100
    # softmax is monotonic, so ranking probabilities is the same as ranking logits.
    k = min(topk, percentages.numel())
    top_pct, top_idx = torch.topk(percentages, k)
    return [(class_names[i], p) for i, p in zip(top_idx.tolist(), top_pct.tolist())]

# Load image

In [0]:
# Load the downloaded photo and display it.
img = Image.open('tiger.jpg')
fig, ax = plt.subplots()
ax.imshow(img)
ax.set_title('tiger.jpg')
ax.axis('off')  # pixel-coordinate ticks carry no information for a photo
plt.show()      # render the figure without echoing the AxesImage repr

# Run inference on our image

In [0]:
# Top-5 (class name, percentage) pairs for the loaded image.
predictions = infer(img)
predictions