In [1]:
from torchvision import models

In [2]:
alexnet = models.AlexNet()

In [3]:
resnet = models.resnet101(pretrained=True)



In [4]:
resnet

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [5]:
from torchvision import transforms
preprocess = transforms.Compose([transforms.Resize(256),
                                transforms.CenterCrop(226),
                                transforms.ToTensor(),
                                transforms.Normalize(
                                    mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225]
                                )])

In [6]:
from PIL import Image

In [7]:
img = Image.open("bobby.png")

In [8]:
img.show()

In [9]:
img_t = preprocess(img)

In [10]:
import torch
batch_t = torch.unsqueeze(img_t, 0)

In [11]:
resnet.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [12]:
out = resnet(batch_t)

In [13]:
out

tensor([[-2.6669e+00, -6.3305e-01, -2.0768e+00, -2.8874e+00, -3.6836e+00,
         -1.0111e+00, -2.0622e+00, -2.2859e+00, -7.4176e-01, -2.3564e+00,
         -1.4962e+00, -8.1862e-01, -1.9687e+00, -2.8206e+00, -2.4564e+00,
         -1.8665e+00, -2.2856e+00, -5.3533e-01, -6.6849e-02, -2.6929e-01,
         -2.2879e+00, -2.9237e+00, -1.3560e+00, -3.4610e-01, -1.0395e+00,
         -8.8899e-01, -2.6200e+00, -2.3711e+00, -2.6203e+00, -2.2671e+00,
         -2.9333e+00, -1.6149e+00, -1.1653e+00, -1.4308e+00, -2.1787e+00,
         -2.6310e+00, -9.6855e-01, -9.2492e-01, -9.7570e-01, -1.0822e+00,
         -3.0576e-01, -1.6790e+00,  8.3466e-01,  9.4613e-02, -1.9601e+00,
         -1.3090e+00,  3.4557e-01, -1.1514e+00, -2.3866e+00, -2.3161e+00,
         -1.7685e+00, -1.1849e+00, -2.0881e+00, -2.6026e+00, -2.3625e+00,
         -1.7506e+00, -1.6391e+00, -2.4976e+00, -2.9841e+00, -5.3286e-01,
         -9.4467e-01, -1.4177e+00, -1.1234e+00, -1.1655e+00, -1.8669e+00,
         -2.1423e+00, -1.9696e+00, -1.

In [14]:
# imagenet-classes.txt
with open('imagenet-classes.txt') as f:
    labels = [line.strip() for line in f.readlines()]

In [15]:
labels

['tench, Tinca tinca',
 'goldfish, Carassius auratus',
 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',
 'tiger shark, Galeocerdo cuvieri',
 'hammerhead, hammerhead shark',
 'electric ray, crampfish, numbfish, torpedo',
 'stingray',
 'cock',
 'hen',
 'ostrich, Struthio camelus',
 'brambling, Fringilla montifringilla',
 'goldfinch, Carduelis carduelis',
 'house finch, linnet, Carpodacus mexicanus',
 'junco, snowbird',
 'indigo bunting, indigo finch, indigo bird, Passerina cyanea',
 'robin, American robin, Turdus migratorius',
 'bulbul',
 'jay',
 'magpie',
 'chickadee',
 'water ouzel, dipper',
 'kite',
 'bald eagle, American eagle, Haliaeetus leucocephalus',
 'vulture',
 'great grey owl, great gray owl, Strix nebulosa',
 'European fire salamander, Salamandra salamandra',
 'common newt, Triturus vulgaris',
 'eft',
 'spotted salamander, Ambystoma maculatum',
 'axolotl, mud puppy, Ambystoma mexicanum',
 'bullfrog, Rana catesbeiana',
 'tree frog, tree-f

In [16]:
_, index = torch.max(out, 1)

In [17]:
index

tensor([207])

In [18]:
percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
labels[index[0]], percentage[index[0]].item()

('golden retriever', 94.82124328613281)

In [19]:
_, indices = torch.sort(out, descending=True)
[(labels[idx], percentage[idx].item()) for idx in indices[0][:5]]

[('golden retriever', 94.82124328613281),
 ('cocker spaniel, English cocker spaniel, cocker', 2.4743893146514893),
 ('Labrador retriever', 1.7057245969772339),
 ('redbone', 0.10562468320131302),
 ('Irish setter, red setter', 0.08026900887489319)]

In [20]:
flower_img = Image.open("flower.jpg")

In [21]:
flower_img.show()

In [22]:
flower_img_t = preprocess(flower_img)

In [23]:
batch_t = torch.unsqueeze(flower_img_t, 0)

In [24]:
out = resnet(batch_t)

In [25]:
out

tensor([[-9.2134e-01,  2.4723e+00, -1.5630e+00, -1.1914e+00,  7.9858e-02,
          3.4111e+00,  1.8897e-01,  3.1460e+00,  6.9476e-01, -1.9332e+00,
         -3.1268e+00, -2.3248e+00, -2.3487e+00, -2.6226e+00,  1.1502e+00,
         -3.3319e+00, -1.9047e+00,  6.5934e-01, -1.0779e+00,  1.7284e-01,
         -2.7812e+00, -2.6687e+00, -2.2839e+00,  4.1294e-01, -2.6705e+00,
         -2.0314e+00, -9.4407e-01, -2.2425e-02,  1.0766e+00, -7.4741e-02,
         -5.6137e-01,  1.4257e+00,  1.1785e+00,  6.2516e-01,  2.2287e+00,
         -1.2259e+00,  7.7980e-01, -1.4582e+00,  5.7789e-01,  9.7299e-01,
         -7.4610e-01, -1.8156e+00,  5.2746e-01,  2.9221e+00, -1.2157e+00,
          7.6495e-01, -5.1446e-01, -1.5628e+00, -2.1398e+00, -1.5833e+00,
         -2.6123e+00,  5.0988e+00,  6.2954e-01,  1.0248e+00,  1.5408e-01,
          2.7363e+00,  5.3312e-01, -1.9594e+00,  3.1176e-01,  1.7925e+00,
          1.2894e+00,  8.2233e-01,  1.5573e+00,  3.6072e-01,  1.2687e+00,
          2.0810e+00,  1.8602e+00, -8.

In [26]:
_, index = torch.max(out, 1)

In [27]:
index

tensor([107])

In [28]:
percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
labels[index[0]], percentage[index[0]].item()

('jellyfish', 30.89031219482422)

In [29]:
_, indices = torch.sort(out, descending=True)
[(labels[idx], percentage[idx].item()) for idx in indices[0][:35]]

[('jellyfish', 30.89031219482422),
 ('head cabbage', 21.822839736938477),
 ('knot', 6.181674957275391),
 ('mortar', 3.0100040435791016),
 ('isopod', 2.6433591842651367),
 ('cauliflower', 2.4141666889190674),
 ('swab, swob, mop', 1.8999606370925903),
 ('sea anemone, anemone', 1.8166102170944214),
 ('hermit crab', 1.7150697708129883),
 ('coral reef', 1.5847384929656982),
 ('sea slug, nudibranch', 1.3024439811706543),
 ('starfish, sea star', 1.2349121570587158),
 ('ice cream, icecream', 1.086985468864441),
 ('velvet', 1.080807089805603),
 ('chain', 1.0241081714630127),
 ('brain coral', 0.9706169366836548),
 ('chambered nautilus, pearly nautilus, nautilus', 0.9664548635482788),
 ('coil, spiral, volute, whorl, helix', 0.683440089225769),
 ('crayfish, crawfish, crawdad, crawdaddy', 0.6280921101570129),
 ('sea urchin', 0.5656934976577759),
 ('vase', 0.5301814675331116),
 ('perfume, essence', 0.5199890732765198),
 ('bell pepper', 0.5063238143920898),
 ('scorpion', 0.4940244257450104),
 ('goble