In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.utils import plot_model
from tensorflow.keras import models

In [2]:
# Load CIFAR-10; images come back as arrays of shape (N, 32, 32, 3).
(training_images, training_labels), (test_images, test_labels) = keras.datasets.cifar10.load_data()
print(training_images.shape)

# Cast both splits to the SAME dtype. float32 matches Keras' default compute
# dtype; plain 'float' would make the test split float64 (2x memory, and
# inconsistent with the training split).
training_images = training_images.astype('float32')
test_images = test_images.astype('float32')

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
(50000, 32, 32, 3)


In [3]:
# Preprocess exactly as VGG16's ImageNet weights expect: RGB->BGR and per-channel
# ImageNet mean subtraction, no scaling. An ad-hoc `x / 255.0 - 1` maps pixels
# to [-1, 0], which matches neither this convention nor a [-1, 1] range.
# preprocess_input can modify float numpy inputs in place, so pass copies to keep
# the float32 images from the previous cell unchanged (the cell stays re-runnable).
train_norm = keras.applications.vgg16.preprocess_input(training_images.copy())
test_norm = keras.applications.vgg16.preprocess_input(test_images.copy())

In [4]:
# Load the VGG16 convolutional base with ImageNet weights. include_top=False drops
# the original classifier; pooling="max" applies global max pooling, so each
# 32x32x3 image becomes a single 512-d feature vector.
# Use the same `keras` namespace (tensorflow.keras, imported at the top) as the
# rest of the notebook. Importing VGG16 from the standalone `keras` package can
# give a model that is incompatible with tf.keras inputs/layers on installs where
# the two packages differ.
base_model = keras.applications.VGG16(
    include_top=False,
    weights="imagenet",
    input_tensor=None,
    input_shape=(32, 32, 3),
    pooling="max",
)
base_model.summary()



Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5
Model: "vgg16"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 32, 32, 3)]       0         
                                                                 
 block1_conv1 (Conv2D)       (None, 32, 32, 64)        1792      
                                                                 
 block1_conv2 (Conv2D)       (None, 32, 32, 64)        36928     
                                                                 
 block1_pool (MaxPooling2D)  (None, 16, 16, 64)        0         
                                                                 
 block2_conv1 (Conv2D)       (None, 16, 16, 128)       73856     
                                                                 
 block2_conv2 (Conv2D)       (None, 16, 16, 128)      

In [5]:
# Stage 1 setup: freeze every VGG16 layer so only the new head is trained.
base_model.trainable = False

# Classification head: pooled VGG16 features -> 128-unit ReLU -> 10-way softmax.
inputs = keras.Input((32, 32, 3))
features = base_model(inputs, training=False)  # run the base in inference mode
hidden = keras.layers.Dense(units=128, activation='relu')(features)
outputs = keras.layers.Dense(units=10, activation='softmax')(hidden)
model = keras.models.Model(inputs, outputs)

model.summary()


Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 32, 32, 3)]       0         
                                                                 
 vgg16 (Functional)          (None, 512)               14714688  
                                                                 
 dense (Dense)               (None, 128)               65664     
                                                                 
 dense_1 (Dense)             (None, 10)                1290      
                                                                 
Total params: 14,781,642
Trainable params: 66,954
Non-trainable params: 14,714,688
_________________________________________________________________


In [6]:
# Compile for head-only training (base_model is frozen at this point) using Adam
# at its default learning rate. The sparse loss/metric take labels as class
# indices rather than one-hot vectors.
# compile() documents `metrics` as a list; newer Keras versions reject a bare
# Metric object, so wrap it.
model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.SparseCategoricalCrossentropy(),
              metrics=[keras.metrics.SparseCategoricalAccuracy()])

In [8]:
# Stage 1: train only the classification head. Keep the History (per-epoch
# loss/accuracy) for later inspection instead of echoing its repr.
# NOTE(review): the test split doubles as validation data; fine for monitoring,
# but avoid tuning hyperparameters against it.
history = model.fit(x=train_norm, y=training_labels, epochs=10,
                    validation_data=(test_norm, test_labels))

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x7f82a02ef4d0>

In [15]:
# Stage 2: fine-tune the entire model by unfreezing every VGG16 layer.
# The model must be recompiled (next cell) for this change to take effect.
base_model.trainable = True

In [17]:

# Recompile after unfreezing so the new trainable state is used. The very low
# learning rate keeps updates to the pretrained VGG16 weights small.
# `metrics` is passed as a list, as compile() documents; a bare Metric object is
# rejected by newer Keras versions.
model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-5),
              loss=keras.losses.SparseCategoricalCrossentropy(),
              metrics=[keras.metrics.SparseCategoricalAccuracy()])

In [None]:
# Stage 2: fine-tune all layers. Keep this run's History separately from the
# head-only run so the two stages can be compared.
# NOTE(review): validation_data is the test split, as in stage 1.
fine_tune_history = model.fit(x=train_norm, y=training_labels, epochs=10,
                              validation_data=(test_norm, test_labels))

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x7fae855b3350>