# Transfer Learning

Steps to follow:

- Load the existing model
- Use full model or select the first layers manually using `model.layers`
- Freeze the gotten model/layers
- Create your full network (Function API may be easier)
    - Adding pre_processing (tf models already come with pre_processing layer)
    - Adding last layers
- Train
- Unfreeze
- Train again

In [6]:
import tensorflow as tf
from tensorflow.keras.preprocessing import image_dataset_from_directory
import os
import matplotlib.pyplot as plt
import numpy as np

_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')

train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')

BATCH_SIZE = 32
IMG_SIZE = (160, 160)

train_dataset = image_dataset_from_directory(train_dir, shuffle=True, batch_size=BATCH_SIZE, image_size=IMG_SIZE)

Found 2000 files belonging to 2 classes.


In [7]:
validation_dataset = image_dataset_from_directory(validation_dir, batch_size=BATCH_SIZE, image_size=IMG_SIZE)

Found 1000 files belonging to 2 classes.


In [8]:
val_batches = tf.data.experimental.cardinality(validation_dataset)
test_dataset = validation_dataset.take(val_batches // 5)
validation_dataset = validation_dataset.skip(val_batches // 5)

In [9]:
print('Number of train batches: %d' % tf.data.experimental.cardinality(train_dataset))
print('Number of validation batches: %d' % tf.data.experimental.cardinality(validation_dataset))
print('Number of test batches: %d' % tf.data.experimental.cardinality(test_dataset))

Number of validation batches: 26
Number of test batches: 6


## Rescale images

MobileNetV2 has images ranging `[-1, 1]`

In [14]:
for img, labels in train_dataset.take(1):
    print(np.max(img.numpy()))    # Our dataset seems to be from [0, 255]

255.0


In [55]:
rescale = tf.keras.layers.experimental.preprocessing.Rescaling(1./127.5, offset= -1)
preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input

## Get MobileNetV2

In [48]:
IMG_SHAPE = IMG_SIZE + (3,)
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE, include_top=False, weights='imagenet')

In [49]:
base_model.summary()

Model: "mobilenetv2_1.00_160"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_4 (InputLayer)            [(None, 160, 160, 3) 0                                            
__________________________________________________________________________________________________
Conv1 (Conv2D)                  (None, 80, 80, 32)   864         input_4[0][0]                    
__________________________________________________________________________________________________
bn_Conv1 (BatchNormalization)   (None, 80, 80, 32)   128         Conv1[0][0]                      
__________________________________________________________________________________________________
Conv1_relu (ReLU)               (None, 80, 80, 32)   0           bn_Conv1[0][0]                   
_______________________________________________________________________________

In [50]:
base_model.layers

[<tensorflow.python.keras.engine.input_layer.InputLayer at 0x7f9a084665e0>,
 <tensorflow.python.keras.layers.convolutional.Conv2D at 0x7f9a086ac100>,
 <tensorflow.python.keras.layers.normalization_v2.BatchNormalization at 0x7f9a2816f190>,
 <tensorflow.python.keras.layers.advanced_activations.ReLU at 0x7f9a08417ee0>,
 <tensorflow.python.keras.layers.convolutional.DepthwiseConv2D at 0x7f9a08564940>,
 <tensorflow.python.keras.layers.normalization_v2.BatchNormalization at 0x7f9a08659250>,
 <tensorflow.python.keras.layers.advanced_activations.ReLU at 0x7f9a2819aaf0>,
 <tensorflow.python.keras.layers.convolutional.Conv2D at 0x7f9a0857e070>,
 <tensorflow.python.keras.layers.normalization_v2.BatchNormalization at 0x7f9a0857edf0>,
 <tensorflow.python.keras.layers.convolutional.Conv2D at 0x7f9a0857ff10>,
 <tensorflow.python.keras.layers.normalization_v2.BatchNormalization at 0x7f9a08437670>,
 <tensorflow.python.keras.layers.advanced_activations.ReLU at 0x7f9a08505c40>,
 <tensorflow.python.keras.

In [51]:
for i in range(0, len(base_model.layers)):
    base_model.layers[i].trainable = False
# base_model.trainable = False

In [52]:
base_model.layers[1].trainable

False

## Create network
We are going to use Functional API

In [57]:
inputs = tf.keras.Input(shape=(160, 160, 3))
x = preprocess_input(inputs)
x = base_model(x, training=False)
# With this parameter there is no need to set the trainable layers to False right?
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = tf.keras.layers.Dense(1)(x)

model = tf.keras.Model(inputs, outputs)

In [58]:
base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.Adam(lr=base_learning_rate),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])

In [59]:
len(model.trainable_variables)

2

In [60]:
model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_8 (InputLayer)         [(None, 160, 160, 3)]     0         
_________________________________________________________________
tf.math.truediv_1 (TFOpLambd (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.subtract_1 (TFOpLamb (None, 160, 160, 3)       0         
_________________________________________________________________
mobilenetv2_1.00_160 (Functi (None, 5, 5, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dropout (Dropout)            (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1281  

## Train

In [61]:
initial_epochs = 10

loss0, accuracy0 = model.evaluate(validation_dataset)



In [62]:
history = model.fit(train_dataset, epochs=initial_epochs, validation_data=validation_dataset)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


## Unfreeze an keep training

In [63]:
base_model.trainable = True
model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_8 (InputLayer)         [(None, 160, 160, 3)]     0         
_________________________________________________________________
tf.math.truediv_1 (TFOpLambd (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.subtract_1 (TFOpLamb (None, 160, 160, 3)       0         
_________________________________________________________________
mobilenetv2_1.00_160 (Functi (None, 5, 5, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dropout (Dropout)            (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1281  

In [64]:
# One must always compile again
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer = tf.keras.optimizers.RMSprop(lr=base_learning_rate/10),
              metrics=['accuracy'])

In [65]:
fine_tune_epochs = 10
total_epochs =  initial_epochs + fine_tune_epochs

history_fine = model.fit(train_dataset,
                         epochs=total_epochs,
                         initial_epoch=history.epoch[-1],
                         validation_data=validation_dataset)

Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
