In [12]:
from keras.applications.vgg16 import VGG16, preprocess_input, decode_predictions
from keras.models import Model
from keras.optimizers import SGD
import keras.utils as kUtils # ? Solves a deprecation in keras.preprocessing.
import numpy as np

In [14]:
# Build VGG16 with its ImageNet-trained weights, keeping the fully connected
# classifier head (flatten -> fc1 -> fc2 -> predictions) so it outputs 1000 classes.
# ! The weight file is a ~553MB download on first use.
model = VGG16(include_top=True, weights='imagenet')

In [15]:
# Show every layer with its output shape, in sequential order.
# `layer.output.shape` is used instead of `layer.output_shape`: the latter
# attribute was removed in Keras 3, while `layer.output.shape` works on both
# Keras 2 (tf.keras) and Keras 3.
for i, layer in enumerate(model.layers):
    print(i, layer.name, layer.output.shape)

0 input_2 [(None, 224, 224, 3)]
1 block1_conv1 (None, 224, 224, 64)
2 block1_conv2 (None, 224, 224, 64)
3 block1_pool (None, 112, 112, 64)
4 block2_conv1 (None, 112, 112, 128)
5 block2_conv2 (None, 112, 112, 128)
6 block2_pool (None, 56, 56, 128)
7 block3_conv1 (None, 56, 56, 256)
8 block3_conv2 (None, 56, 56, 256)
9 block3_conv3 (None, 56, 56, 256)
10 block3_pool (None, 28, 28, 256)
11 block4_conv1 (None, 28, 28, 512)
12 block4_conv2 (None, 28, 28, 512)
13 block4_conv3 (None, 28, 28, 512)
14 block4_pool (None, 14, 14, 512)
15 block5_conv1 (None, 14, 14, 512)
16 block5_conv2 (None, 14, 14, 512)
17 block5_conv3 (None, 14, 14, 512)
18 block5_pool (None, 7, 7, 512)
19 flatten (None, 25088)
20 fc1 (None, 4096)
21 fc2 (None, 4096)
22 predictions (None, 1000)


In [16]:
# Attach an optimizer and loss to the model.
# NOTE(review): compile() only matters for training/evaluation; the predict()
# call further down does not depend on it. Also, lr=0.1 with momentum 0.9 is
# aggressive for fine-tuning pretrained weights -- confirm before calling fit().
sgd = SGD(momentum=0.9, learning_rate=0.1)
model.compile(loss='categorical_crossentropy', optimizer=sgd)

In [17]:
# Load the test image at the 224x224 input resolution VGG16 expects and turn it
# into a single-image batch.
img_path = 'cat.jpg'
img = kUtils.load_img(img_path, target_size=(224, 224))
x = kUtils.img_to_array(img)[np.newaxis, ...]  # (224, 224, 3) -> (1, 224, 224, 3)
# VGG16 preprocessing: convert RGB -> BGR and zero-center each channel using the
# ImageNet means (no rescaling to [0, 1]).
x = preprocess_input(x)

In [18]:
# Predict class scores for every image in the batch `x`.
out = model.predict(x)
print(out.shape)  # (batch, 1000): one score per ImageNet class
# Index of the predicted class for the first image. Taking the argmax over
# out[0] (rather than the flattened array) keeps this a valid class index in
# [0, 1000) even when `x` holds several images, and matches decoded[0] below.
print(np.argmax(out[0]))

# Decode the prediction to a human-readable label.
# decoded[i] is the top-k list for image i; each entry is a
# (class_name, class_description, score) tuple.
decoded = decode_predictions(out, top=1)
label = decoded[0][0]  # first batch element, top result
print('%s (%.2f%%)' % (label[1], label[2] * 100))

(1, 1000)
283
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json
Persian_cat (22.33%)
