In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow import lite

from keras import layers

import tensorflow_hub as hub
import tensorflow_datasets as tfds

import pathlib

In [2]:
# Turn off tensorflow_datasets progress bars so the download/prepare step
# does not flood the cell output.
tfds.disable_progress_bar()

In [3]:
# --- Data and pretrained-model sources ---
dataset_name = 'cats_vs_dogs'  # name passed to tfds.load
model_url = 'https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4'

# --- Output locations (relative to the working directory) ---
checkpoint_filepath = 'tmp/checkpoint'  # weights checkpoint written during training
model_filepath = 'model/'               # SavedModel export directory
tflite_path = 'cat_vs_dog.tflite'       # converted TFLite file

# --- Input geometry fed to the network ---
input_width = 224
input_height = 224
channels = 3

# Per-example input shape (no batch dimension).
model_input_shape = (input_width, input_height, channels)

# Loading the dataset

In [4]:
# Split the single 'train' split 90/10 into training and validation sets.
# as_supervised=True makes each element an (image, label) pair.
train_dataset, val_dataset = tfds.load(
  name=dataset_name,
  split=['train[:90%]', 'train[90%:]'],
  as_supervised=True,
)

In [5]:
def preprocess(image, label):
  """Resize one example to the model's input size and scale it by 1/255.

  Args:
    image: image tensor of a single example.
    label: label of the example; passed through untouched.

  Returns:
    A `(image, label)` tuple with the image resized to
    `[input_width, input_height]` and divided by 255.

  NOTE(review): tf.image.resize reads `size` as [new_height, new_width].
  The order used here matches `model_input_shape` and both values are
  equal, so the result is unaffected; revisit if a non-square input is used.
  """
  resized = tf.image.resize(image, [input_width, input_height])
  return resized / 255., label

In [6]:
# Number of examples per batch; used when batching both dataset splits.
batch_size = 32

In [7]:
# Input pipelines: resize/scale every example, then batch. Only the training
# split is shuffled (buffer of 1000 examples); validation keeps its order.
# AUTOTUNE lets tf.data pick the map parallelism and the prefetch depth
# instead of running `preprocess` sequentially with a fixed one-batch buffer.
# Element order out of `map` stays deterministic.
# NOTE: this cell rebinds train_dataset/val_dataset, so re-running it without
# first re-running the loading cell would preprocess and batch a second time.
train_dataset = (
  train_dataset
  .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
  .shuffle(1000)
  .batch(batch_size)
  .prefetch(tf.data.AUTOTUNE)
)
val_dataset = (
  val_dataset
  .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
  .batch(batch_size)
  .prefetch(tf.data.AUTOTUNE)
)

# Loading the model

In [8]:
# Frozen MobileNetV2 feature extractor loaded from TF Hub.
# The pretrained weights ship inside the SavedModel at `model_url`, so no
# `weights=` argument is passed: `weights='imagenet'` is a keras.applications
# option, not a hub.KerasLayer parameter.
# trainable=False keeps the extractor's variables fixed during training;
# output_shape declares the 1280-wide feature vector (batch dim excluded).
pretrained_layer = hub.KerasLayer(model_url,
                                  trainable=False,
                                  name='mobile_net',
                                  output_shape=[1280])

In [9]:
# Classifier = frozen feature extractor + a 2-unit softmax head.
classifier_layers = [
  pretrained_layer,
  layers.Dense(units=2, activation='softmax', name='output_layer'),
]
model = keras.models.Sequential(classifier_layers, name='classifier')

In [10]:
# Build the model for inputs shaped [None, *model_input_shape]; the leading
# None leaves the batch size unspecified.
model.build([None, *model_input_shape])

In [11]:
# Print the per-layer summary with trainable / non-trainable parameter counts.
model.summary()

Model: "classifier"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 mobile_net (KerasLayer)     multiple                  2257984   
                                                                 
 output_layer (Dense)        multiple                  2562      
                                                                 
Total params: 2,260,546
Trainable params: 2,562
Non-trainable params: 2,257,984
_________________________________________________________________


# Training the model

In [12]:
# Adam optimizer with sparse categorical cross-entropy on the 2-way softmax
# output; accuracy is tracked as the reported metric.
model.compile(
  optimizer='adam',
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy'],
)

In [13]:
# Stop training once val_loss has not improved for 3 consecutive epochs.
early_stopping = keras.callbacks.EarlyStopping(monitor='val_loss', patience=3)

# Keep only the weights of the epoch with the highest val_accuracy.
best_weights_checkpoint = keras.callbacks.ModelCheckpoint(
  filepath=checkpoint_filepath,
  monitor='val_accuracy',
  mode='max',
  save_best_only=True,
  save_weights_only=True,
)

callbacks = [early_stopping, best_weights_checkpoint]

In [14]:
# epochs=500 is only an upper bound: the `callbacks` list is passed in so
# training can end earlier. The returned History object is the cell output.
model.fit(
  x=train_dataset,
  validation_data=val_dataset,
  epochs=500,
  callbacks=callbacks,
)

Epoch 1/500
Epoch 2/500
Epoch 3/500
Epoch 4/500
Epoch 5/500
Epoch 6/500


<keras.callbacks.History at 0x7f519fdc9350>

In [15]:
# Reload the weights stored at checkpoint_filepath (the file the
# ModelCheckpoint callback writes) into the model. The returned load-status
# object is shown as the cell output.
model.load_weights(checkpoint_filepath)

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f519d1dc890>

# Saving the model and converting it to TFLite

In [16]:
# Export the model in SavedModel format under model_filepath; this directory
# is the input to the TFLite converter.
tf.saved_model.save(model, model_filepath)

INFO:tensorflow:Assets written to: model/assets


INFO:tensorflow:Assets written to: model/assets


In [17]:
# Convert the exported SavedModel into a TFLite flatbuffer.
convertor = lite.TFLiteConverter.from_saved_model(model_filepath)
# Optimize.DEFAULT replaces the deprecated OPTIMIZE_FOR_SIZE flag, which the
# TFLite docs describe as doing the same thing as DEFAULT.
convertor.optimizations = [lite.Optimize.DEFAULT]

# Serialized model bytes, ready to be written to disk.
tflite_model = convertor.convert()



In [18]:
# Write the serialized TFLite model to tflite_path. write_bytes returns the
# number of bytes written, which is displayed as the cell output.
path = pathlib.Path(tflite_path)
path.write_bytes(tflite_model)

2649152

# Downloading the converted model

In [20]:
from google.colab import files

In [21]:
# Colab-only: send the converted .tflite file to the browser as a download.
files.download(tflite_path)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>