In [0]:
# Setup: ask Colab for TensorFlow 2.x, then import the libraries used throughout the notebook.
try:
  # %tensorflow_version only exists in Colab; outside Colab the magic fails and the error is ignored.
  %tensorflow_version 2.x
except Exception:
  pass
import tensorflow as tf

# EasyDict gives attribute-style access (e.g. data.train, data.steps.train_steps) used below.
from easydict import EasyDict
import tensorflow_datasets as tfds
from datetime import datetime
# Show which TensorFlow version the runtime actually provides; the code below relies on TF 2.x APIs.
print(tf.__version__)

In [0]:
# Training is much faster on a GPU. In Colab select it via Edit -> Notebook settings ->
# Hardware accelerator. (TPU is not supported by this notebook at the moment.)
# `tf.test.is_gpu_available()` is deprecated; listing physical GPU devices is the supported check.
gpu_devices = tf.config.list_physical_devices("GPU")
print("GPU Available: ", len(gpu_devices) > 0)

In [0]:
# Upgrade tensorflow_hub to its newest release; according to the original author,
# the preinstalled version raised an error.
# NOTE(review): `hub` is not referenced by any other visible cell (the model is built
# from tf.keras.applications.MobileNetV2), so this cell may be removable — confirm.
# If it stays, pin a version and prefer `%pip` so the kernel's environment is targeted.
!pip install --upgrade tensorflow_hub
import tensorflow_hub as hub
# Display the installed tensorflow_hub version (last expression of the cell).
hub.__version__

In [0]:
# Square input resolution fed to MobileNetV2; ld_cifar resizes every CIFAR-10 image to it.
# Per the Keras MobileNetV2 docs, ImageNet weights exist for 96/128/160/192/224.
IMG_SIZE = 96
# height x width x RGB channels, passed as MobileNetV2's `input_shape`.
IMG_SHAPE = (IMG_SIZE,) * 2 + (3,)
# Prefix for the run artifacts: logs and weights go under "./<prefix>_logs/...".
basemodel_path = "mobilenet"

In [0]:
def ld_cifar(batch_size=128, validation_percent=10):
    """Load CIFAR-10 as batched tf.data pipelines preprocessed for MobileNetV2.

    The TFDS ``train`` split is divided into a training part (the first
    ``100 - validation_percent`` percent) and a validation part (the rest).
    The TFDS ``test`` split is kept for the final evaluation.

    Args:
        batch_size: Number of examples per batch in all three pipelines.
        validation_percent: Integer percentage (1-99) of the TFDS train split
            held out for validation.

    Returns:
        EasyDict with keys:
          - ``train``: shuffled, infinitely repeated batches of (image, label);
          - ``validation``: infinitely repeated batches in a fixed order;
          - ``test``: a single pass over the test batches;
          - ``steps``: EasyDict(train_steps, validation_steps, test_steps),
            the number of batches to draw per epoch / evaluation.

    Raises:
        ValueError: if ``validation_percent`` is not an int in [1, 99].
    """
    import math

    if not isinstance(validation_percent, int) or not 0 < validation_percent < 100:
        raise ValueError(
            "validation_percent must be an int in [1, 99], got {!r}".format(validation_percent))

    autotune = tf.data.experimental.AUTOTUNE

    def format_example(image, label):
        # Resize to the model input size and scale pixels the way the
        # ImageNet-pretrained MobileNetV2 weights expect.
        image = tf.cast(image, tf.float32)
        image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
        image = tf.keras.applications.mobilenet_v2.preprocess_input(image)
        return image, label

    train_percent = 100 - validation_percent
    # Percent-slice split strings replace the removed `tfds.Split.subsplit` API.
    split = [
        'train[:{}%]'.format(train_percent),
        'train[{}%:]'.format(train_percent),
        'test',
    ]
    (raw_train, raw_validation, raw_test), metadata = tfds.load(
        'cifar10', split=split, with_info=True, as_supervised=True)

    # Shuffle the small raw images *before* resizing so the shuffle buffer stays cheap.
    train_batches = (raw_train
                     .shuffle(1000)
                     .map(format_example, num_parallel_calls=autotune)
                     .batch(batch_size)
                     .repeat()
                     .prefetch(autotune))
    validation_batches = (raw_validation
                          .map(format_example, num_parallel_calls=autotune)
                          .batch(batch_size)
                          .repeat()
                          .prefetch(autotune))
    test_batches = (raw_test
                    .map(format_example, num_parallel_calls=autotune)
                    .batch(batch_size)
                    .prefetch(autotune))

    # Same rounding TFDS uses for percent slice boundaries (closest integer).
    num_total_train = metadata.splits['train'].num_examples
    num_train = round(num_total_train * train_percent / 100)
    num_validation = num_total_train - num_train
    num_test = metadata.splits['test'].num_examples

    # Training may drop the trailing partial batch (the stream is repeated and shuffled).
    # Validation and test use ceil so the final partial batch is included. Each
    # validation epoch then covers exactly one full pass over the same examples, and
    # evaluate() sees every test example.
    steps_dict = EasyDict(
        train_steps=num_train // batch_size,
        validation_steps=math.ceil(num_validation / batch_size),
        test_steps=math.ceil(num_test / batch_size),
    )
    return EasyDict(train=train_batches, validation=validation_batches, test=test_batches,
                    steps=steps_dict)

In [0]:
def main(epochs=50, patience=5):
    """Train a frozen-MobileNetV2 classifier on CIFAR-10 and evaluate the best checkpoint.

    Builds ImageNet-pretrained MobileNetV2 (frozen) -> global average pooling ->
    dropout -> 10-way softmax, trains only the new head with SGD, and logs to
    ``./<basemodel_path>_logs/fit/<timestamp>/`` (TensorBoard events and weights).

    Args:
        epochs: Maximum number of training epochs.
        patience: Epochs without ``val_loss`` improvement before EarlyStopping
            halts training.

    Returns:
        Tuple ``(history, test_metrics)``: the Keras ``History`` returned by
        ``fit`` and the ``[loss, acc]`` list returned by ``evaluate`` on the
        test split, computed with the lowest-``val_loss`` checkpoint.
    """
    data = ld_cifar()

    # ImageNet-pretrained backbone without its classifier; frozen so only the new head trains.
    feature_extractor_layer = tf.keras.applications.MobileNetV2(
        include_top=False, weights='imagenet', input_shape=IMG_SHAPE)
    feature_extractor_layer.trainable = False

    model = tf.keras.Sequential([
        feature_extractor_layer,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(10, activation="softmax"),  # one unit per CIFAR-10 class
    ])
    model.summary()

    model.compile(
        optimizer=tf.keras.optimizers.SGD(),
        loss='sparse_categorical_crossentropy',  # labels are integer class ids, not one-hot
        metrics=['acc'])

    # Timestamped run directory so repeated runs do not overwrite each other.
    now = datetime.now().strftime("%Y%m%d-%H%M%S")
    log_dir = "./" + basemodel_path + "_logs/fit/" + now
    checkpoint_path = log_dir + "/weights/" + basemodel_path + "_weights"

    # EarlyStopping and ModelCheckpoint must watch the same metric so the
    # checkpoint restored below is the one early stopping considered best.
    monitor = 'val_loss'
    callbacks = [
        tf.keras.callbacks.TensorBoard(log_dir=log_dir + "/tensorboard"),
        tf.keras.callbacks.EarlyStopping(monitor=monitor, patience=patience),
        tf.keras.callbacks.ModelCheckpoint(
            checkpoint_path, monitor=monitor, save_weights_only=True,
            save_best_only=True, verbose=True),
    ]

    history = model.fit(
        data.train,
        epochs=epochs,
        steps_per_epoch=data.steps.train_steps,
        validation_data=data.validation,
        validation_steps=data.steps.validation_steps,
        callbacks=callbacks)

    model.load_weights(checkpoint_path)  # restore the best (lowest val_loss) weights
    test_metrics = model.evaluate(data.test, steps=data.steps.test_steps)
    return history, test_metrics
    

In [0]:
# Train the model, then evaluate the best checkpoint on the test split.
# Logs and weights are written under ./<basemodel_path>_logs/fit/<timestamp>/.
main()

In [0]:
# Helper to zip the training logs directory (TensorBoard events and weight checkpoints) for download
def zip_and_download_logs(log_dir=None, archive_name="logs"):
    """Zip the training log directory into ``<archive_name>.zip`` for download.

    Uses ``shutil.make_archive`` rather than a ``!zip`` shell escape. A failing
    IPython shell command does not raise, so a ``try/except`` around it can
    never report the failure; it also depended on a ``zip`` binary being installed.

    Args:
        log_dir: Directory to archive. Defaults to ``./<basemodel_path>_logs``,
            the root under which ``main`` writes its logs and weights.
        archive_name: Archive path without the ``.zip`` extension; a relative
            path is resolved against the current working directory.

    Returns:
        Path of the created archive, or None if ``log_dir`` is not a directory
        or archiving failed. Errors are printed, so the helper stays best-effort.
    """
    import os
    import shutil

    if log_dir is None:
        log_dir = "./" + basemodel_path + "_logs"
    if not os.path.isdir(log_dir):
        print("Nothing to archive: %s is not a directory" % log_dir)
        return None

    # Store entries as "<dirname>/...", the same layout `zip -r logs.zip <dirname>` produced.
    root_dir, base_dir = os.path.split(os.path.abspath(log_dir))
    try:
        archive_path = shutil.make_archive(archive_name, "zip", root_dir=root_dir, base_dir=base_dir)
    except OSError as e:
        print(e)
        return None
    # Download the archive manually from the Colab file browser; the author found that
    # more reliable than google.colab.files.download.
    return archive_path
# Create the logs archive so it can be downloaded from the notebook environment.
zip_and_download_logs()

In [0]:
# Helper to delete the files created by training and zipping (call it manually when done)
def remove_created_files(log_dir=None, archive_path="./logs.zip"):
    """Delete the training log directory and the zipped logs archive, if present.

    Safe to call repeatedly: artifacts that do not exist are skipped instead of
    raising, e.g. when the zip step was never run.

    Args:
        log_dir: Log directory to remove. Defaults to ``./<basemodel_path>_logs``,
            the root under which ``main`` writes its logs and weights.
        archive_path: Path of the zip archive to remove.
    """
    import os
    import shutil

    if log_dir is None:
        log_dir = "./" + basemodel_path + "_logs"
    if os.path.isdir(log_dir):
        shutil.rmtree(log_dir)
    if os.path.isfile(archive_path):
        os.remove(archive_path)
# remove_created_files()

In [0]:
# Training on a TPU. This currently does not work because TPU support (tf.distribute.experimental.TPUStrategy) is still experimental.
# tpu_strategy = tf.distribute.experimental.TPUStrategy()
# with tpu_strategy.scope():
#   main()