# Retraining Inception model for CIFAR-10 images
--------------------------------------

In this script, we download the CIFAR-10 images and transform them into the input format that Inception expects. The goal is to retrain Google's Inception V3 TensorFlow model to classify CIFAR-10 images.

We start by loading the necessary libraries.

In [1]:
import tensorflow as tf
from tensorflow import keras

from tensorflow.keras.applications.inception_v3 import InceptionV3
from tensorflow.keras.applications.inception_v3 import preprocess_input, decode_predictions

Next, set the parameters.

 - `batch_size`: the number of CIFAR-10 examples in each training batch.
 - `buffer_size`: the number of CIFAR-10 examples held in the shuffle buffer, from which examples are drawn at random.

In [2]:
# Set dataset parameters
batch_size = 32
buffer_size = 1000
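
These two settings determine how many batches make up an epoch and how much of the data the shuffle buffer covers. A quick sketch of that arithmetic, assuming the 50,000-image CIFAR-10 training split (`num_train_examples`, `steps_per_epoch`, and `buffer_fraction` are names introduced here for illustration):

```python
import math

batch_size = 32
buffer_size = 1000

# Assumption: the CIFAR-10 training split holds 50,000 images
num_train_examples = 50_000

# The final batch is smaller than batch_size, so round up
steps_per_epoch = math.ceil(num_train_examples / batch_size)

# Fraction of the training set that fits in the shuffle buffer at once
buffer_fraction = buffer_size / num_train_examples

print(steps_per_epoch)   # 1563
print(buffer_fraction)   # 0.02
```

Because the buffer covers only 2% of the data, shuffling is local. A buffer as large as the dataset would give a uniform shuffle, at the cost of more memory.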

Download the CIFAR-10 dataset using the `tf.keras.datasets` API.

CIFAR-10 consists of 50,000 training images and 10,000 test images, each a 32x32 color image belonging to one of 10 classes.

In [3]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Class names, indexed by the integer labels in y_train and y_test
objects = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']
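
The labels come back as integer class indices in an array of shape `(N, 1)`, so `objects` serves as a lookup table. A minimal sketch using a small hand-made label array in place of the real `y_train` (the array `y_sample` is hypothetical, chosen to avoid downloading the dataset):

```python
import numpy as np

objects = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

# Hypothetical stand-in for y_train: integer labels with shape (N, 1)
y_sample = np.array([[6], [9], [9], [4], [1]], dtype=np.uint8)

# Flatten to shape (N,) and look up each class name
names = [objects[i] for i in y_sample.flatten()]
print(names)  # ['frog', 'truck', 'truck', 'deer', 'automobile']
```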

Initialize the data pipeline using `tf.data.Dataset` for the training and test datasets.

InceptionV3 is pretrained on the ImageNet dataset, so our CIFAR-10 images must match the format the model expects:

- The input width and height must be no smaller than 75 pixels, so we resize our images to 75x75.

- The pixel values must be normalized, so we apply the Inception preprocessing function (`preprocess_input`) to each image.
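
For reference, the Inception preprocessing maps pixel values from `[0, 255]` to `[-1, 1]`. The sketch below reproduces that scaling in NumPy on a few sample values. `scale_to_unit_range` is an illustrative helper, not part of Keras; in the pipeline itself, `preprocess_input` should be used.

```python
import numpy as np

def scale_to_unit_range(pixels):
    """Illustrative equivalent of the Inception scaling: [0, 255] -> [-1, 1]."""
    return pixels.astype(np.float32) / 127.5 - 1.0

# Darkest, middle, and brightest pixel values
sample = np.array([0, 127.5, 255])
scaled = scale_to_unit_range(sample)
print(scaled)
```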



In [4]:
dataset_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
dataset_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))

def preprocess_cifar10(img, label):
    # Cast to float, upsample to 75x75, then apply the Inception normalization
    img = tf.cast(img, tf.float32)
    img = tf.image.resize(img, (75, 75))
    return tf.keras.applications.inception_v3.preprocess_input(img), label

# Shuffle only the training data; preprocessing is applied to whole batches
dataset_train_processed = dataset_train.shuffle(buffer_size).batch(batch_size).map(preprocess_cifar10)
dataset_test_processed = dataset_test.batch(batch_size).map(preprocess_cifar10)
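
A quick way to sanity-check the pipeline is to run it on a handful of synthetic images and inspect one batch. This sketch assumes random `uint8` arrays shaped like CIFAR-10 data (`fake_images` and `fake_labels` are placeholders, not real data):

```python
import numpy as np
import tensorflow as tf

batch_size = 32

# Placeholder data with the same shapes and dtypes as CIFAR-10
fake_images = np.random.randint(0, 256, size=(64, 32, 32, 3), dtype=np.uint8)
fake_labels = np.random.randint(0, 10, size=(64, 1), dtype=np.uint8)

def preprocess_cifar10(img, label):
    img = tf.cast(img, tf.float32)
    img = tf.image.resize(img, (75, 75))
    return tf.keras.applications.inception_v3.preprocess_input(img), label

dataset = (
    tf.data.Dataset.from_tensor_slices((fake_images, fake_labels))
    .batch(batch_size)
    .map(preprocess_cifar10)
)

# Pull a single batch and check its shape and value range
images, labels = next(iter(dataset))
batch_shape = tuple(images.shape)
min_value = float(tf.reduce_min(images))
max_value = float(tf.reduce_max(images))
print(batch_shape, tuple(labels.shape), min_value, max_value)
```

Each batch should have shape `(32, 75, 75, 3)`, with values inside `[-1, 1]`.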

Now, we will create our own model based on the Inception V3 model.

We will load the Inception V3 model using the `tensorflow.keras.applications` API. This API contains pre-trained deep learning models that can be used for prediction, feature extraction and fine-tuning.

We load the ImageNet weights and set `include_top=False` to leave out the original classification head.

In [5]:
inception_model = InceptionV3(
    include_top=False,
    weights="imagenet",
    input_shape=(75, 75, 3)
)
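
With a 75x75 input, the convolutional base reduces each image to a very small feature map. The sketch below builds the same architecture with `weights=None` (an assumption made here only to skip the weight download; the notebook itself uses `weights="imagenet"`) and prints the output shape:

```python
from tensorflow.keras.applications.inception_v3 import InceptionV3

# weights=None gives random initialization and avoids downloading ImageNet weights;
# the layer shapes are the same as in the pretrained model
base = InceptionV3(include_top=False, weights=None, input_shape=(75, 75, 3))

feature_shape = tuple(base.output.shape)
print(feature_shape)  # (None, 1, 1, 2048)
```

The spatial dimensions collapse to 1x1, so a global average pooling layer placed on top of this base effectively just flattens the 2048 channels.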

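We build our own model on top of the Inception V3 model by adding global average pooling followed by a classifier with three fully connected layers. The Inception layers are frozen, so only the new layers are trained.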



In [6]:
x = inception_model.output
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dense(1024, activation="relu")(x)
x = keras.layers.Dense(128, activation="relu")(x)
output = keras.layers.Dense(10, activation="softmax")(x)

model = keras.Model(inputs=inception_model.input, outputs=output)

# Freeze the pretrained Inception layers so only the new classifier is trained
for inception_layer in inception_model.layers:
    inception_layer.trainable = False

# Labels are integer class indices, hence the sparse categorical cross-entropy loss
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
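
Only the new head is trainable, so its size can be worked out by hand: a dense layer with `n_in` inputs and `n_out` units has `n_in * n_out + n_out` parameters. A small sketch, assuming the pooled Inception features have 2048 channels (`dense_params` is a helper defined here for illustration):

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

# Assumption: GlobalAveragePooling2D outputs a 2048-dimensional vector
layer_sizes = [2048, 1024, 128, 10]

per_layer = [dense_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:])]
trainable_total = sum(per_layer)

print(per_layer)        # [2098176, 131200, 1290]
print(trainable_total)  # 2230666
```

This figure should match the trainable parameter count that `model.summary()` reports for the frozen-base model.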

Display the model summary.

In [7]:
model.summary()

Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 75, 75, 3)]  0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 37, 37, 32)   864         input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 37, 37, 32)   96          conv2d[0][0]                     
__________________________________________________________________________________________________
activation (Activation)         (None, 37, 37, 32)   0           batch_normalization[0][0]        
_______________________________________________________________________________________

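Finally, train the model on the training set, using the test set for validation. Since `epochs` is not set, `fit` runs for a single epoch.

In [8]:
model.fit(x=dataset_train_processed,
          validation_data=dataset_test_processed)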



<tensorflow.python.keras.callbacks.History at 0x7fe8086ac400>
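
`fit` returns a `History` object whose `history` attribute stores the per-epoch metrics. The sketch below shows how to read it, using a tiny stand-in model and random data so that it runs quickly (the model, the data, and `epochs=2` are all illustrative, not the notebook's setup):

```python
import numpy as np
from tensorflow import keras

# Illustrative random data: 64 samples, 8 features, 10 classes
x = np.random.rand(64, 8).astype("float32")
y = np.random.randint(0, 10, size=(64,))

tiny_model = keras.Sequential([
    keras.layers.Input(shape=(8,)),
    keras.layers.Dense(10, activation="softmax"),
])
tiny_model.compile(optimizer="adam",
                   loss="sparse_categorical_crossentropy",
                   metrics=["accuracy"])

history = tiny_model.fit(x, y, epochs=2, validation_split=0.25, verbose=0)

# One entry per epoch for each tracked metric
for name, values in history.history.items():
    print(name, values)
```

Applied to the notebook's model, the same pattern would look like `history = model.fit(...)`, after which `history.history` can be inspected or plotted.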