In [18]:
import torch
from torchvision import models
from torchvision import transforms


print('finished importing')

finished importing


In [5]:
alexnet = models.alexnet(pretrained=True)
print(alexnet)

Downloading: "https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth" to /Users/christianbauer/.cache/torch/hub/checkpoints/alexnet-owt-4df8aa71.pth
100.0%
AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): Adaptive

In [19]:
def get_transform() -> transforms:
    """
    [1] -> Transform instance which is a combination of all the image transformations to be carried out on the input image
    [2] -> Rezise the image to 256x256
    [3] -> Crop the image to 224x224 about the center
    [4] -> Convert the image to Pytorch Tensor data type
    [5-7] -> Normalize the image by setting its mean and standard deviation to the specified values
    """
    ret_val = transforms.Compose([                    # [1]
        transforms.Resize(256),             # [2]
        transforms.CenterCrop(224),         # [3]
        transforms.ToTensor(),              # [4]
        transforms.Normalize(               # [5]
            mean=[0.485, 0.456, 0.406],     # [6]           
            std=[0.229, 0.224, 0.225]       # [7]
        )
    ])
    return ret_val
transform = get_transform()

Compose(
    Resize(size=256, interpolation=bilinear)
    CenterCrop(size=(224, 224))
    ToTensor()
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)

# Load class lables for model inference

In [26]:
from PIL import Image
img = Image.open('testimages/dogs/balu_hat.jpeg')
img_t = transform(img)
batch_t = torch.unsqueeze(img_t, 0)

# Put our model in eval mode

In [25]:
alexnet.eval()

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

## Carry out the inference

In [28]:
out = alexnet(batch_t)
print(out.shape)

torch.Size([1, 1000])


## Load the class labels

In [23]:
with open('imagenet1000_classes.txt') as image_file:
    classes = [line.strip() for line in image_file.readlines()]


["{0: 'tench, Tinca tinca',", "1: 'goldfish, Carassius auratus',", "2: 'great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias',", "3: 'tiger shark, Galeocerdo cuvieri',", "4: 'hammerhead, hammerhead shark',"]


In [29]:
_, index = torch.max(out, 1)
percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
print(classes[index[0]], percentage[index[0]].item())

256: 'Newfoundland, Newfoundland dog', 79.62720489501953
