# Introduction

This notebook is a companion to `1.0-flowers-in-fastai.ipynb`. It introduces TensorFlow and Keras to readers who already know fastai.

# Setup

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np
import pickle
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

In [None]:
with open('path.pkl', 'rb') as f:
    path = pickle.load(f)

# Dataset and dataloaders

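We load our flowers using `image_dataset_from_directory`, much as we did with a `DataBlock` in fastai. Passing the same `validation_split` and `seed` to both calls gives a reproducible 80/20 split into training and validation sets that do not overlap.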

In [None]:
image_size = (224, 224)

In [None]:
flowers_train = keras.preprocessing.image_dataset_from_directory(
    path,
    batch_size=64,
    image_size=image_size,
    validation_split=0.2,
    seed=42,
    subset="training",
)

In [None]:
flowers_val = keras.preprocessing.image_dataset_from_directory(
    path,
    batch_size=64,
    image_size=image_size,
    validation_split=0.2,
    seed=42,
    subset="validation",
)
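
To see why both calls need the same `seed` and `validation_split`, here is a minimal NumPy sketch of a seeded split. The helper `split_indices` is hypothetical and only illustrates the idea; it is not Keras's actual implementation.

```python
import numpy as np

def split_indices(n_samples, validation_split, seed, subset):
    """Return the sample indices that belong to the requested subset."""
    # Shuffle all indices deterministically: the same seed gives the same order.
    rng = np.random.RandomState(seed)
    order = rng.permutation(n_samples)
    # Reserve the last `validation_split` fraction of the shuffled order for validation.
    n_val = int(validation_split * n_samples)
    if subset == "training":
        return order[:n_samples - n_val]
    if subset == "validation":
        return order[n_samples - n_val:]
    raise ValueError("subset must be 'training' or 'validation'")

train_idx = split_indices(100, 0.2, seed=42, subset="training")
val_idx = split_indices(100, 0.2, seed=42, subset="validation")
print(len(train_idx), len(val_idx))       # sizes of the two subsets
print(set(train_idx) & set(val_idx))      # indices shared by both subsets
```

With different seeds the two calls would shuffle differently, and some images could end up in both subsets.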

In [None]:
flowers_val.class_names

Here are a few elements from the first batch:

In [None]:
plt.figure(figsize=(10, 10))
for images, labels in flowers_train.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(flowers_train.class_names[int(labels[i])])
        plt.axis("off")


# Train a model

In fastai we used a `cnn_learner` with a pretrained ResNet model as the base model. Let's do it in a similar way with Keras. 

## Data augmentation

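We'll need some data augmentation. In Keras, augmentation can be built into the model itself as layers. The random layers are active only during training and pass their input through unchanged at inference.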

Here are some examples of data augmentation layers:

In [None]:
data_augmentation = keras.Sequential(
    [
        keras.layers.Resizing(224, 224),
        keras.layers.RandomFlip("horizontal"),
        keras.layers.RandomRotation(0.1),
        keras.layers.RandomZoom(0.1),
        keras.layers.RandomContrast(factor=0.01)
    ]
)


In [None]:
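plt.figure(figsize=(10, 10))
for images, _ in flowers_train.take(1):
    for i in range(9):
        # Re-apply the random augmentations on every pass; we always plot the
        # first image of the batch, so we see nine different variants of it.
        augmented_images = data_augmentation(images, training=True)
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(augmented_images[0].numpy().astype("uint8"))
        plt.axis("off")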
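
As a rough illustration of what a random augmentation layer does, here is a minimal NumPy sketch of a horizontal flip that is applied only in training mode. The function `random_flip_horizontal` and the toy data are hypothetical; this is not how Keras implements `RandomFlip`.

```python
import numpy as np

def random_flip_horizontal(images, training, rng):
    """Flip each image left-right with probability 0.5, but only in training mode."""
    if not training:
        # At inference the augmentation is a pass-through.
        return images
    flip = rng.random(len(images)) < 0.5        # one coin toss per image
    out = images.copy()
    out[flip] = out[flip][:, :, ::-1, :]        # reverse the width axis
    return out

rng = np.random.default_rng(0)
toy_batch = rng.integers(0, 256, size=(4, 8, 8, 3)).astype("float32")

unchanged = random_flip_horizontal(toy_batch, training=False, rng=rng)
augmented = random_flip_horizontal(toy_batch, training=True, rng=rng)
```

This is also why the preview cell above passes `training=True`: without it, the random layers would leave the images untouched.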


## Instantiate a base model and load pre-trained weights

We drop the top of the model (its ImageNet classification layers) by passing `include_top=False`; we'll add our own head later:

In [None]:
resnet_model = keras.applications.resnet.ResNet50(weights="imagenet",
                                                   input_shape=image_size + (3,), 
                                                   include_top=False)

In [None]:
#keras.utils.plot_model(resnet_model, show_shapes=True)

## Preprocessing

Because the model was pre-trained on ImageNet, we need to preprocess our images in the same way as the original training data. We do that by inserting a preprocessing step at the beginning of the model.

### Extra: check the effect of preprocessing

In [None]:
inputs = keras.Input(shape=image_size + (3,))
outputs = keras.applications.resnet.preprocess_input(inputs)
preprocess_m = keras.Model(inputs, outputs)

In [None]:
preprocess_m.summary()

Here's a batch of images:

In [None]:
batch = next(iter(flowers_train.take(1)))[0]

In [None]:
batch.shape

In [None]:
np.mean(batch)

Here's the batch after it has been passed through the preprocessing model:

In [None]:
batch_pred = preprocess_m.predict(batch)
batch_pred.shape

In [None]:
batch_pred.mean()
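
For intuition, the change in the mean can be reproduced by hand. To my knowledge, `keras.applications.resnet.preprocess_input` uses "caffe"-style preprocessing: it reorders the channels from RGB to BGR and subtracts the ImageNet channel means, without scaling. The NumPy sketch below follows that assumption; `caffe_style_preprocess` is a hypothetical helper, not the Keras function.

```python
import numpy as np

# Per-channel ImageNet means in BGR order (assumed to match "caffe"-style preprocessing).
IMAGENET_MEAN_BGR = np.array([103.939, 116.779, 123.68], dtype="float32")

def caffe_style_preprocess(rgb_images):
    """Convert RGB images with values in [0, 255] to zero-centred BGR images."""
    bgr = rgb_images[..., ::-1].astype("float32")   # RGB -> BGR
    return bgr - IMAGENET_MEAN_BGR                  # subtract channel means, no scaling

# A tiny 2x2 image that is pure red in RGB.
red = np.zeros((1, 2, 2, 3), dtype="float32")
red[..., 0] = 255.0

red_preprocessed = caffe_style_preprocess(red)
print(red_preprocessed[0, 0, 0])   # BGR values of the first pixel after centring
```

Under this assumption, the mean of a preprocessed batch drops by roughly the average of the three channel means, which is consistent with the shift seen between the two cells above.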

## Add a new head to the model

In [None]:
# Inputs are images (tensors) of a particular size
inputs = keras.Input(image_size + (3,))

# Data augmentation (the random layers are only active during training)
x = keras.layers.Resizing(224, 224)(inputs)
x = keras.layers.RandomFlip("horizontal")(x)
x = keras.layers.RandomRotation(0.1)(x)
x = keras.layers.RandomZoom(0.1)(x)
x = keras.layers.RandomContrast(factor=0.01)(x)

# We preprocess the tensors to be compatible with the pretrained base model:
x = keras.applications.resnet.preprocess_input(x)

# Base model:
x = resnet_model(x, training=False)

# Custom head:
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.BatchNormalization(axis=-1)(x)
x = keras.layers.Dropout(rate=0.25)(x)
x = keras.layers.Dense(512, activation="relu")(x)
x = keras.layers.BatchNormalization(axis=-1)(x)
x = keras.layers.Dropout(rate=0.5)(x)

# Outputs
outputs = keras.layers.Dense(5, activation="softmax")(x)
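
To make the shapes in the head concrete, here is a minimal NumPy sketch of its two key steps: global average pooling and the final softmax layer. It assumes the base model emits feature maps of shape `(batch, 7, 7, 2048)` for 224×224 inputs, and it skips the batch normalization, dropout, and hidden `Dense(512)` layers. The weights are random placeholders, not trained values.

```python
import numpy as np

def global_average_pool(features):
    """Average over the spatial axes: (batch, H, W, C) -> (batch, C)."""
    return features.mean(axis=(1, 2))

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
feature_maps = rng.normal(size=(2, 7, 7, 2048))      # assumed base-model output shape
pooled = global_average_pool(feature_maps)           # shape (2, 2048)
toy_weights = rng.normal(scale=0.01, size=(2048, 5)) # hypothetical final Dense weights
toy_probs = softmax(pooled @ toy_weights)            # shape (2, 5), one row per image

print(pooled.shape, toy_probs.shape)
```

Each row of `toy_probs` sums to 1, so it can be read as a probability distribution over the five flower classes.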

In [None]:
model = keras.Model(inputs, outputs)

In [None]:
model.summary()

In [None]:
keras.utils.plot_model(model, show_shapes=True)

## Train the head of the model

In [None]:
train_ds = flowers_train.prefetch(buffer_size=64)
val_ds = flowers_val.prefetch(buffer_size=64)

In [None]:
# Freeze the base model so that only the new head is trained
resnet_model.trainable = False

In [None]:
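# Our labels are integers and the model outputs softmax probabilities,
# hence the sparse loss with from_logits=False.
model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-3),
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=["accuracy"])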
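
As a sketch of what this loss computes, assuming integer class labels and predicted probabilities (which is what `from_logits=False` means), the mean negative log-probability of the true class can be written in NumPy as follows. The helper name and toy values are illustrative only.

```python
import numpy as np

def sparse_categorical_crossentropy(labels, probs, eps=1e-7):
    """Mean negative log-probability assigned to the true class of each sample."""
    probs = np.clip(probs, eps, 1.0)                  # avoid log(0)
    picked = probs[np.arange(len(labels)), labels]    # probability of the true class
    return float(-np.log(picked).mean())

toy_labels = np.array([0, 2])                         # integer class ids, not one-hot
toy_predictions = np.array([[0.50, 0.25, 0.25],
                            [0.10, 0.10, 0.80]])
toy_loss = sparse_categorical_crossentropy(toy_labels, toy_predictions)
print(toy_loss)
```

The "sparse" variant fits here because `image_dataset_from_directory` yields integer labels by default, so no one-hot encoding is needed.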

Note: to stay closer to fastai's training approach, we could use cyclical learning rates from [TensorFlow addons](https://www.tensorflow.org/addons/tutorials/optimizers_cyclicallearningrate), together with a [learning rate finder](https://pyimagesearch.com/2019/08/05/keras-learning-rate-finder/).

In [None]:
#from tensorflow_addons.optimizers import CyclicalLearningRate
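
For reference, here is a minimal pure-Python sketch of a triangular cyclical learning rate schedule, following the usual triangular formulation. The function name and parameter values are assumptions for illustration; it is not the TensorFlow addons implementation, and it differs from fastai's one-cycle policy.

```python
import math

def triangular_lr(step, base_lr, max_lr, step_size):
    """Learning rate that ramps linearly from base_lr to max_lr and back.

    One full cycle lasts 2 * step_size training steps.
    """
    cycle = math.floor(1 + step / (2 * step_size))
    x = abs(step / step_size - 2 * cycle + 1)   # distance from the peak, in [0, 1]
    return base_lr + (max_lr - base_lr) * max(0.0, 1.0 - x)

# Hypothetical settings: ramp between 1e-4 and 1e-3 over 100 steps.
for step in (0, 50, 100, 150, 200):
    print(step, triangular_lr(step, base_lr=1e-4, max_lr=1e-3, step_size=100))
```

The rate starts at `base_lr`, peaks at `max_lr` after `step_size` steps, and returns to `base_lr` at the end of each cycle.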

In [None]:
model.fit(train_ds, epochs=3, validation_data=val_ds)

## Unfreeze and fine-tune the model

In [None]:
# Unfreeze the base model so that its weights are fine-tuned as well
resnet_model.trainable = True

In [None]:
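# Recompile so that the change to `trainable` takes effect, and use a much
# lower learning rate to avoid wrecking the pretrained weights.
model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-5),
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=["accuracy"])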

In [None]:
model.fit(train_ds, epochs=2, validation_data=val_ds)

# Evaluate the model

In [None]:
loss, acc = model.evaluate(flowers_val)
print("Accuracy: ", acc)
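
The accuracy reported here is the fraction of validation images whose most probable class matches the label. A minimal NumPy sketch of that computation, using made-up predictions rather than real model output:

```python
import numpy as np

def accuracy_from_probs(labels, probs):
    """Fraction of samples whose highest-probability class equals the label."""
    return float((probs.argmax(axis=-1) == labels).mean())

demo_labels = np.array([0, 1, 1, 2])
demo_probs = np.array([[0.7, 0.2, 0.1],    # predicted class 0 (correct)
                       [0.1, 0.8, 0.1],    # predicted class 1 (correct)
                       [0.2, 0.3, 0.5],    # predicted class 2 (wrong)
                       [0.1, 0.1, 0.8]])   # predicted class 2 (correct)
demo_acc = accuracy_from_probs(demo_labels, demo_probs)
print(demo_acc)
```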