In [1]:
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from torchvision import transforms, models

In [2]:
# Build the bare VGG16 architecture. No ImageNet weights are requested:
# the strict load_state_dict() call below replaces every parameter anyway,
# so downloading the pretrained checkpoint would only cost time/bandwidth
# (and make the notebook fail when offline).
model = models.vgg16()

# New fully connected head: 25088 features -> 100 hidden units -> 2 classes.
model.classifier = nn.Sequential(
    nn.Linear(25088, 100),
    nn.ReLU(),
    nn.Dropout(p=0.5),
    nn.Linear(100, 2),
)

# The model is kept on the CPU in this notebook, so map the checkpoint
# tensors to the CPU explicitly; without map_location a checkpoint saved
# from a GPU cannot be deserialized on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load checkpoints
# from a trusted source.
state_dict = torch.load('model/cat_dog.pth', map_location='cpu')
model.load_state_dict(state_dict)

# Switch to inference mode (disables Dropout); as the last expression of the
# cell this also displays the model summary.
model.eval()



VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

In [3]:
# Class names, indexed by the position of the model's output score.
label = np.array(['cat', 'dog'])

# Preprocessing pipeline applied to each PIL image before inference:
# resize with size=224, then convert to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
    ]
)

In [4]:
def predict(image_path):
    """Classify a single image file and print the predicted class name.

    Uses the module-level ``transform``, ``model`` and ``label`` objects.

    Args:
        image_path: Location of the image, in any form accepted by
            ``PIL.Image.open``.

    Returns:
        None. The predicted class name is printed to stdout.
    """
    # Open the image and force three RGB channels: grayscale, RGBA or
    # palette files would otherwise yield 1- or 4-channel tensors, which the
    # network's first convolution (3 input channels) rejects. The context
    # manager releases the file handle once the pixels are decoded.
    with Image.open(image_path) as img:
        rgb_img = img.convert('RGB')

    # Preprocess and add a leading batch dimension.
    batch = transform(rgb_img).unsqueeze(0)

    # Forward pass without autograd bookkeeping -- inference only.
    with torch.no_grad():
        output = model(batch)

    # Index of the highest score along the class dimension.
    _, predicted = torch.max(output, 1)

    # Look up and print the class name.
    print(label[predicted.item()])

In [5]:
# Directory prefix and file extension shared by all sample images.
route = 'data/dog_vs_cat/train/'
filedot = '.jpg'
# Base names (without extension) of the images to classify.
images = ['cat.1358', 'dog.1358', 'cat.5297', 'dog.5297']

# For each sample: print its name, then the predicted class, then a blank line.
for image in images:
    print(f'{image}: ')
    predict(f'{route}{image}{filedot}')
    print()

cat.1358: 
cat

dog.1358: 
dog

cat.5297: 
cat

dog.5297: 
dog

