# TensorFlow/Keras transfer learning example

This notebook gives a quick introduction to taking a CNN trained on ImageNet and reusing it for a different classification problem.

In this case, we take VGG16 (because it's simple) and retrain it to classify the handwritten digits in MNIST.

In [None]:
import keras
from keras.applications.vgg16 import VGG16
from keras.layers import Flatten, Dense
from keras.models import Model
from keras.datasets import mnist
from keras.utils import to_categorical, model_to_dot
from IPython.display import Image, display, HTML
import numpy as np
import matplotlib.pyplot as plt

## Widen the notebook container so long model summaries fit on screen
display(HTML("<style>.container { width:90% !important; }</style>"))

In [None]:
## Load the pretrained model, just to have a look at it
pretrained_model = VGG16()

pretrained_model.summary()

Note that it expects a fixed input size (224×224 RGB images), and its final layer predicts one of the 1000 ImageNet classes it was trained on.
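To see why the full model is tied to one input size, here is a small pure-Python sketch (the helper names are hypothetical, not Keras functions). It assumes VGG16's standard layout: 3×3 convolutions with 'same' padding, which keep the height and width, and five 2×2 max-pooling layers, which halve them. The length of the vector reaching the first fully connected layer therefore depends on the input size, and the pretrained `fc1` weights only fit 224×224 inputs.

```python
def vgg16_feature_shape(height, width, channels_out=512, n_pools=5):
    """Shape of VGG16's final conv-block output for a given input size.

    Assumes 'same'-padded 3x3 convolutions (height/width unchanged) and
    n_pools 2x2, stride-2 max-pooling layers (height/width halved, rounded down).
    """
    for _ in range(n_pools):
        height, width = height // 2, width // 2
    return height, width, channels_out


def vgg16_flatten_size(height, width):
    """Length of the vector that Flatten hands to the first Dense layer."""
    h, w, c = vgg16_feature_shape(height, width)
    return h * w * c


print(vgg16_feature_shape(224, 224))  # (7, 7, 512)
print(vgg16_flatten_size(224, 224))   # 25088 inputs, what the pretrained fc1 expects
print(vgg16_flatten_size(32, 32))     # 512 inputs, so the pretrained fc1 cannot be reused
```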

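Now let's load the VGG network again with our own options: no fully connected top layers, ImageNet weights, and a 32×32×3 input, which is the smallest size VGG16 accepts.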

In [None]:
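## include_top=False drops the fully connected layers, so the input size can change.
## (The `classes` argument only applies when include_top=True, so it is omitted here.)
retraining_model = VGG16(include_top=False, weights='imagenet', input_shape=(32, 32, 3))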
retraining_model.summary()

Note that there are no densely connected layers at the end now: this is just the convolutional feature-extraction part of the network, and for a 32×32 input its output is a 1×1×512 feature map.
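Unlike dense layers, a convolution's weights depend only on its kernel size and channel counts, not on the image size, which is why the ImageNet weights can be reused at 32×32. As a sketch, the count below is computed by hand from VGG16's standard configuration (the `VGG16_CONV_BLOCKS` table is written out here for illustration, not read from Keras). Pooling layers have no parameters, and the total agrees with the 14,714,688 parameters Keras reports for `VGG16(include_top=False)`.

```python
# Output channels of each 3x3 conv layer in VGG16, grouped by block
# (written out by hand for illustration).
VGG16_CONV_BLOCKS = [
    [64, 64],         # block1
    [128, 128],       # block2
    [256, 256, 256],  # block3
    [512, 512, 512],  # block4
    [512, 512, 512],  # block5
]


def conv_params(in_channels, out_channels, kernel=3):
    """kernel*kernel*in*out weights plus one bias per output channel."""
    return kernel * kernel * in_channels * out_channels + out_channels


def vgg16_conv_param_counts(in_channels=3):
    """Return {layer_name: param_count} for the 13 conv layers of VGG16."""
    counts = {}
    for b, block in enumerate(VGG16_CONV_BLOCKS, start=1):
        for i, out_channels in enumerate(block, start=1):
            counts[f"block{b}_conv{i}"] = conv_params(in_channels, out_channels)
            in_channels = out_channels
    return counts


counts = vgg16_conv_param_counts()
print(sum(counts.values()))  # 14714688, whatever the input image size
```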

Now we add our own classifier on top: a `Flatten` layer, a 4096-unit fully connected hidden layer, and a 10-unit softmax output layer, one unit per digit.

In [None]:
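flattened = Flatten()(retraining_model.output)   # 1x1x512 feature map -> 512-vector
fc1 = Dense(4096, activation='relu')(flattened)  # new hidden layer
fc2 = Dense(10, activation='softmax')(fc1)       # one output per digit

## Join the pretrained base and the new head into a single model
mnist_model = Model(inputs=retraining_model.input, outputs=fc2)
mnist_model.summary()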
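A `Dense` layer has `n_in × n_out` weights plus `n_out` biases, so the new head can be sized by hand. This pure-Python check (with a hypothetical `dense_params` helper) assumes the 32×32 input used here, for which `Flatten` produces 1×1×512 = 512 values; compare it with the counts printed by `mnist_model.summary()`.

```python
def dense_params(n_in, n_out):
    """Weight matrix (n_in x n_out) plus one bias per output unit."""
    return n_in * n_out + n_out


CONV_BASE_PARAMS = 14_714_688  # VGG16 without its top (include_top=False)
flat_size = 1 * 1 * 512        # block5_pool output for a 32x32 input

head_fc1 = dense_params(flat_size, 4096)  # Dense(4096, activation='relu')
head_fc2 = dense_params(4096, 10)         # Dense(10, activation='softmax')

print(head_fc1, head_fc2)                      # 2101248 40970
print(CONV_BASE_PARAMS + head_fc1 + head_fc2)  # 16856906 parameters in total
```

For comparison, the original ImageNet `fc1` layer alone takes 25088 inputs, giving 25088 × 4096 + 4096 = 102,764,544 parameters.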



In [None]:
## This line visualises the VGG_MNIST network graph (needs the pydot package and Graphviz installed)
display(Image(model_to_dot(mnist_model).create(prog='dot', format='png')))

We're now ready to start training the network! However, we have a couple of things to do:
- We should freeze some layers in the pretrained part of the model (how many is up to you)
- We need to load the dataset and reshape it to match the 32×32×3 input we gave VGG16

In [None]:
## Freeze every layer except the last 4 (block5_conv1, block5_conv2, block5_conv3
## and block5_pool), so only block5 and our new dense layers get trained
for layer in retraining_model.layers[:-4]:
    layer.trainable = False

mnist_model.summary()  ## Note how the number of trainable params has gone down a lot
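For reference, `retraining_model.layers` contains 19 layers: the input layer, 13 convolutions and 5 poolings. This pure-Python sketch rebuilds the layer names to show what the slice `[:-4]` leaves trainable, and estimates the trainable parameter count the summary should now report. The input layer's name is a placeholder here, since Keras names it differently across versions.

```python
# Layer names of VGG16(include_top=False), in order. "input" is a placeholder:
# the real input layer name varies by Keras version (e.g. "input_1").
layer_names = ["input"]
for block, n_convs in enumerate([2, 2, 3, 3, 3], start=1):
    layer_names += [f"block{block}_conv{i}" for i in range(1, n_convs + 1)]
    layer_names.append(f"block{block}_pool")

frozen = layer_names[:-4]     # same slice as the freezing loop
trainable = layer_names[-4:]
print(len(layer_names), trainable)
# 19 ['block5_conv1', 'block5_conv2', 'block5_conv3', 'block5_pool']

# Trainable params after freezing: block5's three 512->512 3x3 convs plus the new head.
block5 = 3 * (3 * 3 * 512 * 512 + 512)
head = (512 * 4096 + 4096) + (4096 * 10 + 10)
print(block5 + head)  # 9221642, out of 16856906 in total
```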

In [None]:
from keras.applications.vgg16 import preprocess_input

## Now we can load the MNIST data: 28x28 greyscale digits
(x_train_raw, y_train_raw), (x_test_raw, y_test_raw) = mnist.load_data()
print(x_train_raw.shape)

## Pad around the images with zeros to make them 32x32
x_train = np.pad(x_train_raw, ((0, 0), (2, 2), (2, 2)), 'constant', constant_values=0)
x_test = np.pad(x_test_raw, ((0, 0), (2, 2), (2, 2)), 'constant', constant_values=0)

## Make them 3-channel by repeating the grey channel (float32 halves memory vs float64)
x_train = np.repeat(x_train[..., None], 3, axis=-1).astype('float32')
x_test = np.repeat(x_test[..., None], 3, axis=-1).astype('float32')

## Apply the same preprocessing the ImageNet weights were trained with
x_train = preprocess_input(x_train)
x_test = preprocess_input(x_test)

## Convert integer labels to one-hot (categorical) vectors
y_train = to_categorical(y_train_raw)
y_test = to_categorical(y_test_raw)

print(x_train.shape)
print(y_train.shape)
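The reshaping steps can be checked on a small fake batch without downloading MNIST. This NumPy-only sketch uses random 28×28 arrays as stand-ins for digits. It shows that `np.repeat` produces the same values as multiplying by `np.ones(3)` with broadcasting, while keeping the compact `uint8` dtype (the multiply yields float64, roughly 1.5 GB for the 60,000 padded training images). It also imitates, for illustration only, the 'caffe'-style preprocessing that Keras's `vgg16.preprocess_input` applies: reverse RGB to BGR and subtract the per-channel ImageNet mean, without rescaling. `caffe_preprocess` is a hypothetical helper, not the Keras function.

```python
import numpy as np

rng = np.random.default_rng(0)
fake_raw = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)  # stand-in for MNIST

# Pad 2 zero pixels on each side of both spatial axes: 28x28 -> 32x32.
padded = np.pad(fake_raw, ((0, 0), (2, 2), (2, 2)), mode="constant", constant_values=0)

# Copy the single grey channel into 3 identical channels (stays uint8).
rgb = np.repeat(padded[..., None], 3, axis=-1)

# Same values as the broadcasting-multiply approach (which returns float64).
assert np.array_equal(rgb, padded[:, :, :, None] * np.ones(3)[None, None, None, :])

IMAGENET_MEAN_BGR = np.array([103.939, 116.779, 123.68], dtype=np.float32)


def caffe_preprocess(x):
    """Illustrative only: RGB -> BGR, then subtract the ImageNet channel means."""
    return x[..., ::-1].astype(np.float32) - IMAGENET_MEAN_BGR


pre = caffe_preprocess(rgb)
print(fake_raw.shape, padded.shape, rgb.shape, pre.dtype)
# (5, 28, 28) (5, 32, 32) (5, 32, 32, 3) float32
```

Because the three channels are identical copies, the RGB-to-BGR flip changes nothing here; only the mean subtraction matters for MNIST.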

Now that everything is ready, we can retrain the network!

In [None]:
mnist_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])

## NB: deliberately trains on only the first 4096 images for 5 epochs so it finishes quickly
history = mnist_model.fit(x_train[:4096], y_train[:4096], epochs=5, batch_size=32, validation_split=0.2)

# Plot training & validation accuracy values
plt.plot(history.history['acc'])
plt.plot(history.history['val_acc'])
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')
plt.show()

# Plot training & validation loss values
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')
plt.show()
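`validation_split=0.2` does not shuffle before splitting: Keras holds out the last 20% of the arrays passed to `fit` and trains on the rest. So the `val_acc` and `val_loss` curves come from the final images of `x_train[:4096]`, not from the MNIST test set. A pure-Python sketch of the arithmetic, assuming the cut point is rounded down (the exact rounding is a Keras implementation detail):

```python
import math

n_samples = 4096        # x_train[:4096]
validation_split = 0.2
batch_size = 32

# The *last* fraction of the data becomes validation data, before any shuffling.
split_at = math.floor(n_samples * (1 - validation_split))
n_train, n_val = split_at, n_samples - split_at
steps_per_epoch = math.ceil(n_train / batch_size)  # the last batch may be partial

print(n_train, n_val, steps_per_epoch)  # 3276 820 103
```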

Now we can evaluate the model on the held-out test data (only the first 100 images here, to keep it quick).

In [None]:
mnist_model.evaluate(x_test[:100], y_test[:100])
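`evaluate` returns `[loss, accuracy]`, i.e. the compiled loss and the `'acc'` metric. Scoring only 100 images gives a noisy estimate; passing the full `x_test` and `y_test` scores all 10,000 MNIST test images. This NumPy sketch shows what the two numbers measure, on made-up softmax outputs rather than real model predictions, and ignores the small epsilon clipping Keras applies inside the loss.

```python
import numpy as np

# Made-up softmax outputs for 4 samples over 3 classes.
probs = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.3, 0.3, 0.4],
    [0.5, 0.4, 0.1],
])
y_true = np.eye(3)[[0, 1, 2, 1]]  # one-hot labels, as produced by to_categorical

# Categorical cross-entropy: mean over samples of -log(probability of the true class).
loss = float(np.mean(-np.sum(y_true * np.log(probs), axis=1)))

# Accuracy: fraction of samples whose highest-probability class matches the label.
acc = float(np.mean(np.argmax(probs, axis=1) == np.argmax(y_true, axis=1)))

print(round(loss, 4), acc)  # 0.6031 0.75
```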



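The next step would be to unfreeze a few more layers (for example block4) and fine-tune the network with an optimizer using a small learning rate, until we are happy with its performance. Remember to call `compile()` again after changing `trainable`, otherwise the change won't take effect.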

That's left as an exercise for the reader!