In [7]:
import tensorflow as tf
import tensorflow_datasets as tfds

from sklearn.datasets import load_sample_images 


  from .autonotebook import tqdm as notebook_tqdm


# Transfer learning

Using CNN models that are built into Keras.

First, we see how to load a pretrained model (ResNet50) and use it to make predictions.

In [6]:
# Load ResNet50 together with the weights it learned on ImageNet
model = tf.keras.applications.ResNet50(weights="imagenet")

# ResNet50 takes 224x224 inputs, so the sample images are resized first
# (crop_to_aspect_ratio=True crops instead of distorting the picture)
images = load_sample_images()["images"]
resize_to_224 = tf.keras.layers.Resizing(height=224, width=224, crop_to_aspect_ratio=True)
images_resized = resize_to_224(images)

# Each built-in architecture ships its own preprocess_input; apply ResNet50's
inputs = tf.keras.applications.resnet50.preprocess_input(images_resized)

# Run inference: one row of 1000 class probabilities per input image,
# so the output shape is (number_of_images, 1000)
Y_proba = model.predict(inputs)
print(Y_proba.shape)

# decode_predictions maps the probabilities back to (class_id, name, probability)
# tuples; keep the 3 most likely classes for each image
top_K = tf.keras.applications.resnet50.decode_predictions(Y_proba, top=3)
for image_index, image_predictions in enumerate(top_K):
  print(f"Image #{image_index}")
  for class_id, name, y_proba in image_predictions:
    print(f"  {class_id} - {name:12s} {y_proba:.2%}")


(2, 1000)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json
Image #0
  n03877845 - palace       54.69%
  n03781244 - monastery    24.71%
  n02825657 - bell_cote    18.55%
Image #1
  n04522168 - vase         32.67%
  n11939491 - daisy        17.82%
  n03530642 - honeycomb    12.04%


### Xception as a base model to classify flower types

Now let's use Xception to classify types of flowers

In [14]:
# Load tf_flowers, which only ships a "train" split, and carve it up ourselves:
#   first 10%  -> test, next 15% -> validation, remaining 75% -> training
dataset, info = tfds.load(
    "tf_flowers",
    split=["train[:10%]", "train[10%:25%]", "train[25%:]"],
    as_supervised=True,  # yield (image, label) tuples
    with_info=True,      # also return the dataset metadata
)

# The three datasets come back in the same order as the split list above
test_set_raw, valid_set_raw, train_set_raw = dataset

# Metadata used later: total example count, label names, and number of classes
label_feature = info.features["label"]
dataset_size = info.splits["train"].num_examples
class_names = label_feature.names
n_classes = label_feature.num_classes

In [15]:
# Preprocess the images

batch_size = 32

# Resizing and using the Xception built-in preprocessing as a single Keras preprocessing model
preprocess = tf.keras.Sequential([
  tf.keras.layers.Resizing(height=224, width=224, crop_to_aspect_ratio=True),
  tf.keras.layers.Lambda(tf.keras.applications.xception.preprocess_input)
])
train_set = train_set_raw.map(lambda X, y: (preprocess(X), y))

# Shuffle and batch the training set (shuffling is only wanted for training)
train_set = train_set.shuffle(1000, seed=42).batch(batch_size).prefetch(1)

# Preprocess the validation and test sets.
# They must be batched as well: the model consumes 4D batches
# (batch, height, width, channels), so feeding the unbatched per-image elements
# to fit(validation_data=...) or evaluate() would give it 3D tensors instead.
valid_set = valid_set_raw.map(lambda X, y: (preprocess(X), y)).batch(batch_size)
test_set = test_set_raw.map(lambda X, y: (preprocess(X), y)).batch(batch_size)


In [17]:
# Data augmentation pipeline

# Random transformations meant to be applied to images at training time.
# NOTE(review): none of the visible cells call data_augmentation, so as written
# it has no effect on training - confirm whether it should be wired into the model.
augmentation_layers = [
  tf.keras.layers.RandomFlip(mode="horizontal", seed=42),  # mirror left/right
  tf.keras.layers.RandomRotation(factor=0.5, seed=42),     # factor is a fraction of a full turn
  tf.keras.layers.RandomContrast(factor=0.5, seed=42),
]
data_augmentation = tf.keras.Sequential(augmentation_layers)

In [19]:
# Build the transfer-learning model on top of Xception

# include_top=False drops Xception's own global average pooling and dense
# ImageNet classifier, leaving only the convolutional feature extractor
base_model = tf.keras.applications.xception.Xception(weights="imagenet", include_top=False)

# New "top": pool the feature maps, then a softmax over the flower classes
avg = tf.keras.layers.GlobalAveragePooling2D()(base_model.output)
output = tf.keras.layers.Dense(n_classes, activation="softmax")(avg)
model = tf.keras.Model(inputs=base_model.input, outputs=output)

# Freeze every pretrained layer so only the new top is updated at first;
# otherwise the large initial gradients could wreck the pretrained weights
for pretrained_layer in base_model.layers:
  pretrained_layer.trainable = False

# Changing `trainable` only takes effect at compile time, so compile afterwards
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
model.compile(
    optimizer=optimizer,
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)



In [None]:
# Fitting - USE GPU! very slow

# Phase 1: train only the new top for 3 epochs while the pretrained base stays frozen
history = model.fit(
    train_set,
    epochs=3,
    validation_data=valid_set,
)

In [None]:
# Fitting more layers - USE GPU! very slow

# Phase 2: the new top is now calibrated, so its gradients are small enough
# that the upper part of the pretrained base (layer 56 onwards) can be
# unfrozen without corrupting the well trained weights
for unfrozen_layer in base_model.layers[56:]:
  unfrozen_layer.trainable = True

# Changing `trainable` requires re-compiling. The learning rate is also
# lowered (0.1 -> 0.01) to protect the newly unfrozen pretrained weights.
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9)
model.compile(
    optimizer=optimizer,
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

# Continue training for 3 more epochs with the unfrozen layers
history = model.fit(
    train_set,
    epochs=3,
    validation_data=valid_set,
)


# Classification and Localization

Doing transfer learning on the Xception model to output 4 regression values that will be used for a bounding box around an object.

**This is just for demonstration. The data does not have bounding boxes that the model can learn from**

In [None]:
# Two-headed model: a shared Xception feature extractor feeding both a
# classifier and a bounding-box regressor
base_model = tf.keras.applications.xception.Xception(weights="imagenet", include_top=False)

avg = tf.keras.layers.GlobalAveragePooling2D()(base_model.output)

# Output the class label: one probability per flower class (n_classes),
# not the 1000 ImageNet labels of the original Xception top
class_output = tf.keras.layers.Dense(n_classes, activation="softmax")(avg)

# Output the 4 location values (regression, hence no activation)
loc_output = tf.keras.layers.Dense(4)(avg)

# The model now has two output types (class probability and location regression values)
model = tf.keras.Model(inputs=base_model.input, outputs=[class_output, loc_output])

# Use a fresh optimizer: an optimizer instance is tied to the variables of the
# model it was first used with, so reusing the one from the previous model can
# fail. Hyperparameters match the fine-tuning optimizer used before.
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9)

# Losses, loss weights and metrics are given per output, in the same order as
# `outputs`: accuracy for the classifier, MSE for the box regressor (accuracy
# is meaningless for regression values).
model.compile(loss=["sparse_categorical_crossentropy", "mse"],
              loss_weights=[0.8, 0.2], optimizer=optimizer,
              metrics=[["accuracy"], ["mse"]])
