# SciKeras Transfer Learning & Fine Tuning

Transfer learning is a popular deep learning approach that re-purposes a model trained on one dataset for use on another dataset. This notebook shows how to implement it in SciKeras. We follow the [Keras tutorial](https://www.tensorflow.org/tutorials/images/transfer_learning) on the topic, which covers far more depth and breadth than we do here. You are highly encouraged to work through that tutorial if you want to learn about fine tuning and transfer learning in general.



<table align="left"><td>
<a target="_blank" href="https://colab.research.google.com/github/adriangb/scikeras/blob/master/notebooks/Basic_Usage.ipyn">
    <img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>  
</td><td>
<a target="_blank" href="https://github.com/adriangb/scikeras/blob/master/notebooks/Basic_Usage.ipynb"><img width=32px src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a></td></table>

### Table of contents

* [Data](#Data)
* [Define Keras Model](#Define-Keras-Model)
* [Set up model for fine tuning](#Set-up-model-for-fine-tuning)
* [Wrap the model with KerasClassifier](#Wrap-the-model-with-KerasClassifier)

Install SciKeras

In [None]:
!python -m pip install git+https://github.com/adriangb/scikeras.git@master

Silence TensorFlow warnings to keep the output succinct.

In [36]:
import warnings
from tensorflow import get_logger
get_logger().setLevel('ERROR')
warnings.filterwarnings("ignore", message="Setting the random state for TF")

In [37]:
import numpy as np
from scikeras.wrappers import KerasClassifier, KerasRegressor
from tensorflow import keras

## Data

We load the dataset from the Keras tutorial. The dataset consists of images of cats and dogs.

In [38]:
import numpy as np
import os
import tensorflow as tf

from tensorflow.keras.preprocessing import image_dataset_from_directory

_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')

train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')

IMG_SIZE = (160, 160)

train_dataset = image_dataset_from_directory(train_dir,
                                             shuffle=True,
                                             image_size=IMG_SIZE)

validation_dataset = image_dataset_from_directory(validation_dir,
                                                  shuffle=True,
                                                  image_size=IMG_SIZE)

Found 2000 files belonging to 2 classes.
Found 1000 files belonging to 2 classes.


Scikit-Learn (and by extension SciKeras) does not support `tf.data.Dataset` objects. Although this may change in the future, for now you need to convert your data to NumPy arrays to use SciKeras. The conversion has a performance cost and limits the dataset size to what fits in memory. If you run into problems with this approach, you may need to use Keras directly.

In [40]:
X_train = []
y_train = []
for batch in train_dataset.as_numpy_iterator():
    X_train.append(batch[0])
    y_train.append(batch[1])
X_train = np.concatenate(X_train)
y_train = np.concatenate(y_train)
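
To get a feel for the in-memory limitation mentioned above, the following sketch estimates how much memory the converted training array needs. The helper name `estimate_array_bytes` is hypothetical (it is not part of SciKeras or Keras), and the estimate assumes the images are stored as `float32`, which is what the dataset reports for this notebook.

```python
from math import prod

import numpy as np


def estimate_array_bytes(n_samples, sample_shape, dtype=np.float32):
    """Return the bytes needed to hold n_samples arrays of sample_shape."""
    return n_samples * prod(sample_shape) * np.dtype(dtype).itemsize


# 2000 training images of 160x160 RGB pixels, assumed to be stored as float32
train_bytes = estimate_array_bytes(2000, (160, 160, 3))
print(f"X_train needs about {train_bytes / 2**20:.0f} MiB")
```

Running the same calculation before converting a larger dataset is a cheap way to check whether the NumPy approach is feasible at all.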

In this tutorial, we will not use the validation set outside of Keras. We can therefore leave it as a `tf.data.Dataset` and pass it to `tf.keras.Model.fit` via the `fit__validation_data` routed parameter. Keras will then use it to compute validation metrics after each epoch.
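
The idea behind routed parameters can be sketched in a few lines of plain Python: a name such as `fit__validation_data` is split at the double underscore, and the remainder is forwarded to the matching destination. The helper `route_params` below is a hypothetical illustration of that convention, not SciKeras's actual implementation, and the string value stands in for the real dataset.

```python
def route_params(params, prefix):
    """Return the parameters whose names start with `prefix__`, prefix removed."""
    marker = prefix + "__"
    return {
        name[len(marker):]: value
        for name, value in params.items()
        if name.startswith(marker)
    }


# Hypothetical parameter dictionary mirroring the settings used in this notebook
params = {
    "epochs": 10,
    "optimizer__learning_rate": 1e-5,
    "fit__validation_data": "validation_dataset",  # placeholder for the real dataset
}

fit_kwargs = route_params(params, "fit")
optimizer_kwargs = route_params(params, "optimizer")
print(fit_kwargs)
print(optimizer_kwargs)
```

Parameters without a prefix, such as `epochs`, are left untouched by this sketch.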

## Define Keras Model

We load a pre-trained MobileNet V2. We request the weights learned on the ImageNet dataset by passing `weights='imagenet'`, and we drop the original ImageNet classification layers by passing `include_top=False`.

In [41]:
# Create the base model from the pre-trained model MobileNet V2
IMG_SHAPE = IMG_SIZE + (3,)
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')

### Set up model for fine tuning

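To fine tune the model, we must first pick which layers to train and which to freeze. Generally, we want to train the upper layers of the network (those closest to the output), because they encode the features most specific to the original dataset, and keep the lower layers frozen.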

In [42]:
base_model.trainable = True

# Let's take a look to see how many layers are in the base model
print("Number of layers in the base model: ", len(base_model.layers))

# Fine-tune from this layer onwards
fine_tune_at = 100

# Freeze all the layers before the `fine_tune_at` layer
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False

Number of layers in the base model:  155
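
The slice `base_model.layers[:fine_tune_at]` covers indices 0 through 99, so 100 layers are frozen and the remaining 55 stay trainable. The sketch below reproduces that bookkeeping with stand-in objects (plain `SimpleNamespace` instances, not real Keras layers), so it runs without downloading the model.

```python
from types import SimpleNamespace

# Stand-in objects with a `trainable` flag; these are NOT real Keras layers
layers = [SimpleNamespace(name=f"layer_{i}", trainable=True) for i in range(155)]

fine_tune_at = 100

# Same slicing as in the cell above: indices 0..99 are frozen, 100..154 stay trainable
for layer in layers[:fine_tune_at]:
    layer.trainable = False

n_frozen = sum(not layer.trainable for layer in layers)
n_trainable = sum(layer.trainable for layer in layers)
print(f"frozen: {n_frozen}, trainable: {n_trainable}")
```

The same two `sum` expressions can be applied to `base_model.layers` to confirm the split on the real model.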


We now wrap `base_model` with input preprocessing and a classifier head: global average pooling, dropout, and a single sigmoid output unit for the binary cat-vs-dog prediction.

In [43]:
preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input

In [44]:
inputs = tf.keras.Input(shape=(160, 160, 3))
x = preprocess_input(inputs)
x = base_model(x, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = tf.keras.layers.Dense(1, activation="sigmoid")(x)
model = tf.keras.Model(inputs, outputs)
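
To make the shapes in the classifier head concrete, here is a NumPy-only sketch of what global average pooling followed by a single sigmoid unit computes. It assumes the base model emits feature maps of shape `(batch, 5, 5, 1280)` for 160x160 inputs, uses random stand-in values for both the features and the dense weights, and leaves out dropout, which is inactive at inference time.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed feature-map shape: (batch, height, width, channels) = (4, 5, 5, 1280)
features = rng.normal(size=(4, 5, 5, 1280)).astype(np.float32)

# Global average pooling: average over the two spatial axes
pooled = features.mean(axis=(1, 2))  # shape (4, 1280)

# Dense(1, activation="sigmoid") with randomly initialised stand-in weights
weights = rng.normal(scale=0.01, size=(1280, 1)).astype(np.float32)
bias = np.zeros(1, dtype=np.float32)
logits = pooled @ weights + bias
probabilities = 1.0 / (1.0 + np.exp(-logits))  # shape (4, 1)

print(pooled.shape, probabilities.shape)
```

Each row of `probabilities` is a single value between 0 and 1, which is why the model is paired with a binary cross-entropy loss.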

### Wrap the model with KerasClassifier

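In addition to freezing some layers, for fine tuning you will usually want to set a relatively low learning rate. A low learning rate keeps the pre-trained weights from changing too quickly, which reduces the risk of overfitting to the new dataset and of losing the general features learned during pre-training. Losing those features would defeat the purpose of transfer learning.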

In [45]:
clf = KerasClassifier(
    model=model,
    loss="binary_crossentropy",
    optimizer__learning_rate=1e-5,
    metrics=['accuracy'],
)
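
As a rough illustration of why the learning rate matters, the sketch below applies one plain gradient-descent step to some made-up weights at two learning rates. This is an assumption-laden simplification: the wrapper above uses RMSprop, whose updates are scaled differently, and the weight and gradient values here are hypothetical. The point is only that the step size scales with the learning rate.

```python
import numpy as np

# Hypothetical pre-trained weights and a gradient computed on the new dataset
pretrained = np.array([0.5, -1.2, 0.8])
gradient = np.array([2.0, -1.0, 0.5])


def sgd_step(weights, grad, learning_rate):
    """One plain gradient-descent update (not RMSprop)."""
    return weights - learning_rate * grad


for lr in (1e-3, 1e-5):
    updated = sgd_step(pretrained, gradient, lr)
    drift = np.abs(updated - pretrained).max()
    print(f"learning_rate={lr:g}: largest weight change = {drift:.1e}")
```

With the smaller learning rate, each update moves the weights about a hundred times less, so the pre-trained features are preserved for longer.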

In [46]:
clf.set_params(epochs=10)
clf.set_params(fit__validation_data=validation_dataset)
clf.fit(X_train, y_train)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


KerasClassifier(
	model=<tensorflow.python.keras.engine.functional.Functional object at 0x150b023a0>
	build_fn=None
	warm_start=False
	random_state=None
	optimizer=rmsprop
	loss=binary_crossentropy
	metrics=['accuracy']
	batch_size=None
	verbose=1
	callbacks=None
	validation_split=0.0
	shuffle=True
	run_eagerly=False
	epochs=10
	optimizer__learning_rate=1e-05
	class_weight=None
	fit__validation_data=<BatchDataset shapes: ((None, 160, 160, 3), (None,)), types: (tf.float32, tf.int32)>
)