<a href="https://colab.research.google.com/github/LilianYou/dark-lily/blob/master/5_pytorch_pretrained.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
from torch import optim
import numpy as np
from torch import nn
import torchvision
import torchvision.transforms as transforms
import pandas as pd
from PIL import Image
import requests
import io
from torchvision import models, transforms
from torch.autograd import Variable
import random, time, sys
from collections import namedtuple
import numpy as np
import matplotlib.pyplot as plt

normalize = transforms.Normalize(
   mean=[0.485, 0.456, 0.406],
   std=[0.229, 0.224, 0.225]
)
preprocess = transforms.Compose([
   transforms.Resize(256),
   transforms.CenterCrop(224),
   transforms.ToTensor(),
   normalize
])


def image_loader(image_url = 'example_image.png'):
    """load image, returns cuda tensor"""
    response = requests.get(image_url)
    img_pil = Image.open(io.BytesIO(response.content))
    img_tensor = preprocess(img_pil)
    img_tensor.unsqueeze_(0)
    img_variable = Variable(img_tensor)
    img_pil.close()
    return img_variable

def get_imagenet_labels():
    import requests
    LABELS_URL = 'https://s3.amazonaws.com/outcome-blog/imagenet/labels.json'
    imagenet_labels = {int(key):value for (key, value) in requests.get(LABELS_URL).json().items()}
    return imagenet_labels

import warnings
warnings.filterwarnings('ignore')

def print_labels(output_classid, topk=3):
    imagenet_labels = get_imagenet_labels()
    print(''.join(['Rank '+str(i+1) + ' ' +imagenet_labels[cid]+' \n' for i,cid in enumerate(output_classid[:topk])])) #Most Likely

imagenet_labels = get_imagenet_labels()

# Pre-trained models
- Large neural networks can take days to train only multiple GPUs!
- PyTorch provides pre-trained networks, see here https://pytorch.org/docs/stable/torchvision/models.html
- Caution: Some "famous" networks will result in downloading >.5Gb of parameter data
- Pretrained vision models expect a certain data format:
> All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224.

In [0]:
nnet=torchvision.models.squeezenet1_0(pretrained=True);
nnet.eval();

Downloading: "https://download.pytorch.org/models/squeezenet1_0-a815701f.pth" to /root/.cache/torch/checkpoints/squeezenet1_0-a815701f.pth
100%|██████████| 4.79M/4.79M [00:00<00:00, 7.31MB/s]


# Classification with a pre-trained network
Let's try out the classifier!
<img src="https://upload.wikimedia.org/wikipedia/commons/thumb/0/09/A6-EDY_A380_Emirates_31_jan_2013_jfk_%288442269364%29_%28cropped%29.jpg/450px-A6-EDY_A380_Emirates_31_jan_2013_jfk_%288442269364%29_%28cropped%29.jpg" width=10%/>
1. Find the URL a picture from the internet
2. Load as an array, reshape, and crop using the provided image_loader function
3. Feed it to the network
4. Get the most likely class ids
5. Print the three most likely labels using the provided print_labels function

In [0]:
a380_url = "https://upload.wikimedia.org/wikipedia/commons/0/09/A6-EDY_A380_Emirates_31_jan_2013_jfk_%288442269364%29_%28cropped%29.jpg"
data = image_loader(a380_url)

In [0]:
output = nnet(data)
output_classid = output.data.numpy().argsort().squeeze()[::-1]; output_classid
print_labels(output_classid)
output.data.numpy().sort()
nn.Softmax(dim=-1)(output)

Rank 1 airliner 
Rank 2 wing 
Rank 3 warplane, military plane 



tensor([[6.9055e-14, 1.2433e-13, 1.4041e-13, 1.4085e-13, 1.4260e-13, 1.4855e-13,
         1.6276e-13, 1.7337e-13, 1.7562e-13, 1.7886e-13, 1.8197e-13, 2.0160e-13,
         2.0748e-13, 2.0868e-13, 2.1114e-13, 2.2899e-13, 2.5128e-13, 2.6412e-13,
         3.1356e-13, 3.2630e-13, 3.5009e-13, 3.5188e-13, 3.5306e-13, 3.6800e-13,
         3.8951e-13, 3.9250e-13, 4.1948e-13, 4.2376e-13, 4.2501e-13, 4.3119e-13,
         4.4144e-13, 4.5061e-13, 4.5094e-13, 4.6139e-13, 4.6505e-13, 4.7819e-13,
         4.9815e-13, 5.0940e-13, 5.1500e-13, 5.3426e-13, 5.4304e-13, 5.4329e-13,
         5.7304e-13, 5.7595e-13, 5.8692e-13, 6.0142e-13, 6.0494e-13, 6.1340e-13,
         6.1844e-13, 6.2895e-13, 6.4074e-13, 6.4632e-13, 6.4678e-13, 6.5609e-13,
         6.6997e-13, 6.7066e-13, 7.1587e-13, 7.1755e-13, 7.2563e-13, 7.3823e-13,
         7.3893e-13, 7.8339e-13, 7.8438e-13, 7.9005e-13, 8.0006e-13, 8.0964e-13,
         8.2653e-13, 8.2712e-13, 8.6114e-13, 8.7033e-13, 8.8015e-13, 9.1281e-13,
         9.1388e-13, 9.2039e