# PyTorch pretrained model demo

> *Demo based on [this gist](https://gist.github.com/jkarimi91/d393688c4d4cdb9251e3f939f138876e).*

This demo shows how to use a pretrained PyTorch model to make predictions.
Specifically, we will classify an image downloaded from the internet
with VGG16.

References:
* PyTorch pretrained models doc: http://pytorch.org/docs/master/torchvision/models.html
* PyTorch image transforms example: http://pytorch.org/tutorials/beginner/data_loading_tutorial.html#transforms

In [None]:
import io

from PIL import Image
import requests

import numpy as np

Let's download an image.

In [None]:
IMG_URL = 'https://upload.wikimedia.org/wikipedia/en/5/5f/Original_Doge_meme.jpg'

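response = requests.get(IMG_URL)
response.raise_for_status()  # Fail early if the download did not succeed.
img = Image.open(io.BytesIO(response.content))  # Wrap the downloaded bytes in a file-like object and open them as a PIL image.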

print(f'Image size: {img.size}')
img

In [None]:
# Class labels used when training VGG, as JSON; courtesy of http://blog.outcome.io/pytorch-quick-start-classifying-an-image/
LABELS_URL = 'https://s3.amazonaws.com/outcome-blog/imagenet/labels.json'

# Let's get our class labels.
response = requests.get(LABELS_URL)  # Make an HTTP GET request and store the response
response.raise_for_status()  # Fail early if the request did not succeed
labels = {int(key): value for key, value in response.json().items()}  # JSON keys are strings; convert them to int class indices
assert set(labels.keys()) == set(range(1000))  # Make sure we have exactly the indices 0..999
labels = np.array([v for (k, v) in sorted(labels.items())])  # Sort by index and store as a NumPy array, so labels[i] is the name of class i

print(f'Total labels: {len(labels)}')
print(f'Example labels: {labels[200:205]}')
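
To see what the parsing step does, here is a minimal sketch on a tiny made-up payload. It assumes the labels file is a JSON object that maps stringified class indices to names, which is what the `int(key)` conversion above implies. The three sample entries and the names `sample_json` and `sample_labels` are invented for illustration, and a plain list stands in for the NumPy array.

```python
import json

# Hypothetical miniature of the labels file. JSON object keys are always
# strings, and they need not arrive in numeric order.
sample_json = '{"2": "great white shark", "0": "tench", "1": "goldfish"}'

raw = json.loads(sample_json)
by_index = {int(key): value for key, value in raw.items()}  # "2" -> 2, etc.

# Same sanity check as above, scaled down to three classes.
assert set(by_index) == set(range(len(by_index)))

# Sorting by integer key makes the list position equal to the class index.
sample_labels = [name for _, name in sorted(by_index.items())]
print(sample_labels)
```

Converting the keys to `int` before sorting matters: as strings, `'10'` would sort before `'2'`, and the positions would no longer match the class indices.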

Now that we have an image, we need to preprocess it.
Specifically, we need to:

* Resize the image to 224x224 px, preferably preserving aspect ratio;
* Convert it to a PyTorch Tensor;
* Normalize it, as noted in the [PyTorch pretrained models doc](https://pytorch.org/vision/stable/models.html),
with `mean = [0.485, 0.456, 0.406]` and `std = [0.229, 0.224, 0.225]`.
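
As a rough illustration of the last two steps, the sketch below applies the same arithmetic to a single RGB pixel in plain Python: scale each channel from 0..255 to 0..1, then subtract the channel mean and divide by the channel std. The helper `normalize_pixel` and the sample pixel are hypothetical; a real preprocessing step does this for every pixel at once.

```python
MEAN = [0.485, 0.456, 0.406]  # per-channel means quoted above (R, G, B)
STD = [0.229, 0.224, 0.225]   # per-channel standard deviations quoted above


def normalize_pixel(rgb):
    """Scale one 8-bit RGB pixel to 0..1, then standardize each channel."""
    scaled = [channel / 255.0 for channel in rgb]
    return [(c - m) / s for c, m, s in zip(scaled, MEAN, STD)]


# Made-up pixel: full red, mid green, no blue.
normalized = normalize_pixel((255, 128, 0))
print([round(value, 3) for value in normalized])
```

After normalization the values are no longer confined to 0..1; channels darker than the mean become negative.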

We can do all this preprocessing using a transform pipeline.

In [None]:
import torchvision.transforms as transforms

transform_pipeline = transforms.Compose([
    # Transforms taken from https://github.com/pytorch/examples/blob/4db11160c21d0e26634ca1fcb94a73ad8d870ba7/imagenet/main.py#L227-L230
    transforms.Resize(256),      # Resize smaller side of the image to 256 px, preserving aspect ratio
    transforms.CenterCrop(224),  # Crop a square with size 224x224 px from the center of the image
    transforms.ToTensor(),       # Convert PIL image (uint8, 0..255) to PyTorch Tensor (float32, 0..1)
    transforms.Normalize(        # Subtract mean and divide by std
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
])
tensor = transform_pipeline(img)
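
The geometry of the first two transforms can be sketched without any imaging library. The code below assumes that `Resize(256)` scales the shorter side to 256 px and truncates the longer side to an integer, and that `CenterCrop(224)` rounds its offsets. Both helpers (`resized_size`, `center_crop_box`) are hypothetical names, and they only approximate what torchvision does internally.

```python
def resized_size(width, height, target=256):
    """Scale so the shorter side equals `target`, keeping the aspect ratio."""
    if width <= height:
        return target, int(target * height / width)
    return int(target * width / height), target


def center_crop_box(width, height, crop=224):
    """Return (left, top, right, bottom) of a centered crop-by-crop square."""
    left = round((width - crop) / 2)
    top = round((height - crop) / 2)
    return left, top, left + crop, top + crop


# Example with a made-up 640x480 landscape image.
new_width, new_height = resized_size(640, 480)
box = center_crop_box(new_width, new_height)
print(f'resized to {new_width}x{new_height}, crop box {box}')
```

For a landscape image the crop discards the left and right margins, so objects near the edges may be lost before the model ever sees them.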

PyTorch pretrained models expect an input Tensor of shape (batch size, color channels, height, width).
Our Tensor currently has shape (color channels, height, width), so let's fix this by inserting a new axis.

In [None]:
print(f'shape before: {tensor.shape}')
tensor = tensor.unsqueeze(0)  # Insert the new axis at index 0 i.e. in front of the other axes/dims.
print(f'shape after: {tensor.shape}')
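
The shape bookkeeping behind `unsqueeze` can be sketched in plain Python. The helper `unsqueeze_shape` below is hypothetical: it works on shape tuples only, holds no tensor data, and is meant to mirror the rule that a size-1 axis is inserted at the given position.

```python
def unsqueeze_shape(shape, dim):
    """Return the shape after inserting a size-1 axis at position `dim`."""
    ndim = len(shape)
    if not -(ndim + 1) <= dim <= ndim:
        raise IndexError(f'dim must be in [{-(ndim + 1)}, {ndim}], got {dim}')
    if dim < 0:
        dim += ndim + 1  # Negative positions count from the end.
    return shape[:dim] + (1,) + shape[dim:]


image_shape = (3, 224, 224)                    # (channels, height, width)
batch_shape = unsqueeze_shape(image_shape, 0)  # (batch, channels, height, width)
print(batch_shape)
```

The new axis has size 1 because we are passing a single image; a batch of several images would have a larger first dimension.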

Now let's load our model and get a prediction!

In [None]:
import torchvision.models as models

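# Downloads the ImageNet weights on first use, which may take a few minutes.
# Note: torchvision 0.13+ deprecates `pretrained=True` in favor of
# `weights=models.VGG16_Weights.IMAGENET1K_V1`.
vgg = models.vgg16(pretrained=True)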

# Switch the model from training mode to inference (evaluation) mode.
# This alters the behavior of some layers; for VGG16 it disables dropout in the classifier.
vgg.eval()

In [None]:
logits_tensor = vgg(tensor)  # Returns a Tensor of shape (batch size, number of class labels)
logits = logits_tensor.detach().numpy()  # Detach the Tensor from the computational graph and convert it into a NumPy array
logits = logits[0]  # Drop the batch axis added by .unsqueeze(0)
prediction = logits.argmax()  # Our prediction is the index of the class with the largest logit.
print(labels[prediction])  # Look up the class name for that index in our labels array

Let's also compute the top 5 predictions and their corresponding probabilities:

In [None]:
import scipy.special

indices = logits.argsort()[-5:][::-1]  # Indices of the 5 largest logits, highest first
probs = scipy.special.softmax(logits)  # Convert logits into probabilities that sum to 1
for idx in indices:
    print(f'{probs[idx] * 100:>5.2f} | {labels[idx]}')

img
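
For readers curious about what the softmax and top-5 steps compute, here is a minimal plain-Python sketch on made-up logits. The helpers `softmax` and `top_k` and the values in `toy_logits` are illustrative assumptions, not part of the demo above; subtracting the maximum before exponentiating is a common trick to avoid overflow and does not change the result.

```python
import math


def softmax(values):
    """Turn raw scores into probabilities that sum to 1."""
    peak = max(values)  # Shift by the max so exp() cannot overflow.
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(values, k):
    """Return the indices of the k largest values, largest first."""
    order = sorted(range(len(values)), key=lambda i: values[i], reverse=True)
    return order[:k]


toy_logits = [1.0, 3.0, 0.5, 2.0]  # Made-up scores for four classes.
toy_probs = softmax(toy_logits)
for idx in top_k(toy_logits, 2):
    print(f'{toy_probs[idx] * 100:>5.2f} | class {idx}')
```

Softmax preserves the ordering of the logits, so ranking by logit or by probability gives the same top classes.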