# Create our model

In [1]:
import torch
from scipy import misc
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import torchvision.models as models # contains the VGG16 pretrained network. A complete list of the available pre-trained networks is given on the [torchvision.models](https://pytorch.org/docs/stable/torchvision/models.html) page.

In [5]:
# Instantiate VGG16 with its pretrained weights (the download is roughly 500MB
# the first time this runs).
model = models.vgg16(pretrained=True)

# Dump the layer-by-layer architecture so we can inspect it.
print(model)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [10]:
from PIL import Image
import torchvision.transforms as transforms

# Preprocessing pipeline: resize, centre-crop to 224x224, convert to a tensor
# and normalise each of the three channels with the given mean/std.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load and preprocess the image.
# - The context manager closes the underlying file handle once we are done.
# - convert('RGB') guarantees exactly three channels, so grayscale, RGBA or
#   palette images do not break the 3-channel Normalize / first conv layer.
with Image.open('my_image.jpeg') as image_file:
    image = preprocess(image_file.convert('RGB'))
image = image.unsqueeze(0)  # simulate a batch: (C, H, W) -> (1, C, H, W)

# Pass the image through the model.
# eval() switches Dropout (present in VGG's classifier head) to inference
# behaviour; without it the prediction can change from run to run.
model.eval()
# no_grad() skips building the autograd graph, which is not needed for a
# plain forward pass and only costs memory.
with torch.no_grad():
    output = model(image)

# Load ImageNet classes from a .txt file (one entry per line; the line index
# is used as the class index below).
with open('classes.txt', 'r') as f:
    imagenet_classes = [line.strip() for line in f]

# Interpret the output: index of the highest score for the single image in
# the batch, converted to a plain Python int.
predicted_class = output.argmax(dim=1).item()

# Print the class name
print('Predicted class:', imagenet_classes[predicted_class])

Predicted class: 386: 'African elephant, Loxodonta africana',
