# Transfer Learning on TPU For Flower Classification

This notebook demonstrates how to use TPUs with TensorFlow to train a model for classifying flower images.

# Imports

In [None]:
from __future__ import annotations

import functools
import math
import os
import warnings

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '4'

from kaggle_datasets import KaggleDatasets
from matplotlib import pyplot as plt
import numpy as np
import optuna
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, optimizers, applications, callbacks, Sequential, Input

# Fix TensorFlow's global random seed.
tf.random.set_seed(42)
# Suppress all Python warnings for the rest of the notebook.
warnings.filterwarnings("ignore")

# Detect TPU

Code borrowed from [Getting started with 100+ flowers on TPU](https://www.kaggle.com/mgornergoogle/getting-started-with-100-flowers-on-tpu) by [Martin Görner](https://www.kaggle.com/mgornergoogle).

In [None]:
# Use a TPU when one can be resolved; if resolution raises ValueError, fall
# back to MirroredStrategy (intended for locally visible GPUs).
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError:
    strategy = tf.distribute.MirroredStrategy()

print(f"Number of accelerators:  {strategy.num_replicas_in_sync}")
strategy

# Moving Data To Google Cloud Storage (GCS)

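TPUs read their input data from Google Cloud Storage (GCS) rather than from the local disk. The utility below looks up the GCS path at which Kaggle makes this dataset available, so that the TPU can read it directly.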

In [None]:
# GCS location of the competition data, as reported by the Kaggle helper.
GCS_DS_PATH = KaggleDatasets().get_gcs_path(
    "flower-classification-with-tpus"
)

In [None]:
# List the contents of the dataset's GCS path to check the URLs
# (IPython shell escape; $GCS_DS_PATH expands the Python variable).
!gsutil ls $GCS_DS_PATH

# Configuration

This section defines some basic configuration that will be used by the rest of the notebook.

In [None]:
# Square input resolution; must be one of the keys of GCS_PATHS below.
IMG_SIZE = 192
# Default epoch count; a tuned "epochs" entry in params overrides it in train().
EPOCHS = 12
# Global batch size: 16 examples per replica.
BATCH_SIZE = 16 * strategy.num_replicas_in_sync
# Let tf.data choose parallelism / prefetch depth at runtime. The stable
# tf.data.AUTOTUNE replaces the deprecated tf.data.experimental alias.
AUTO = tf.data.AUTOTUNE

In [None]:
BASE_GCS_PATH = GCS_DS_PATH

# Template for the per-resolution TFRecord folders; the trailing "" component
# makes os.path.join end the path with a separator.
gcs_fmt = os.path.join(BASE_GCS_PATH, "tfrecords-jpeg-{}x{}", "")

# Square image size -> directory holding the TFRecords at that size.
GCS_PATHS = {side: gcs_fmt.format(side, side) for side in (192, 224, 331, 512)}
DATA_DIR = GCS_PATHS[IMG_SIZE]

In [None]:
# Class names read from the labels CSV. get_title() indexes this list by label
# id, so the CSV row order is assumed to match the label ids.
df = pd.read_csv("../input/flower-classification-labels/flower_classification_labels.csv")
CLASSES = df.loc[:, "class"].to_list()
CLASSES

# Data Augmentation

In [None]:
def data_augment(image, label):
    """Randomly flip, brighten and re-contrast an image; the label passes through.

    Args:
        image: image tensor to augment.
        label: value returned unchanged alongside the augmented image.

    Returns:
        ``(augmented_image, label)``.
    """
    flipped = tf.image.random_flip_left_right(image)
    brightened = tf.image.random_brightness(flipped, 0.2)
    augmented = tf.image.random_contrast(brightened, 0.5, 2.0)
    return augmented, label

# Dataset Functions

This section contains the functions that load the train, validation and test datasets.

In [None]:
def decode_image(image: tf.Tensor, channels: int = 3) -> tf.Tensor:
    """Decode JPEG bytes into a uint8 image of shape ``[IMG_SIZE, IMG_SIZE, channels]``.

    Args:
        image: scalar string tensor holding the encoded JPEG bytes.
        channels: number of colour channels to decode into.

    Returns:
        The decoded image, reshaped to ``[IMG_SIZE, IMG_SIZE, channels]``.
    """
    img = tf.image.decode_jpeg(image, channels=channels)
    # Pin a fixed shape for downstream batching. Use ``channels`` rather than a
    # hard-coded 3, so a non-default channel count does not fail the reshape.
    return tf.reshape(img, [IMG_SIZE, IMG_SIZE, channels])

In [None]:
def read_tfrecord(example: tf.Tensor, has_labels: bool = True) -> tuple[tf.Tensor, tf.Tensor]:
    """Parse one serialized TFRecord example into ``(image, class_or_id)``.

    Args:
        example: serialized example (string tensor) from a TFRecord file.
        has_labels: if truthy, read the int64 ``"class"`` feature; otherwise
            read the string ``"id"`` feature.

    Returns:
        ``(decoded_image, value)`` where ``value`` is the class or the id.
    """
    # Truthiness rather than ``is True`` so any truthy flag selects labels.
    if has_labels:
        key, feature = "class", tf.io.FixedLenFeature([], tf.int64)
    else:
        key, feature = "id", tf.io.FixedLenFeature([], tf.string)

    tfrecord_format = {
        "image": tf.io.FixedLenFeature([], tf.string),
        key: feature,
    }

    parsed = tf.io.parse_single_example(example, tfrecord_format)
    return decode_image(parsed["image"]), parsed[key]

In [None]:
def get_dataset(filepaths: list[str], has_labels: bool = True, ordered: bool = False) -> tf.data.Dataset:
    """Build an unbatched dataset of parsed examples from TFRecord files.

    Args:
        filepaths: TFRecord file paths to read.
        has_labels: forwarded to ``read_tfrecord`` (class labels vs. ids).
        ordered: if falsy, tf.data may yield elements non-deterministically
            (``experimental_deterministic = False``) in exchange for throughput.

    Returns:
        A dataset of ``(image, class_or_id)`` pairs.
    """
    options = tf.data.Options()
    if not ordered:
        options.experimental_deterministic = False

    dataset = tf.data.TFRecordDataset(filepaths, num_parallel_reads=AUTO)
    dataset = dataset.with_options(options)

    reader = functools.partial(read_tfrecord, has_labels=has_labels)
    return dataset.map(reader, num_parallel_calls=AUTO)

In [None]:
def load_from_dir(
    directory: str,
    has_labels: bool = True,
    ordered: bool = False,
    repeat: bool = False,
    cache: bool = False,
    shuffle: bool = False,
    augment: bool = False,
) -> tf.data.Dataset:
    """Build a batched, prefetched dataset from ``DATA_DIR/<directory>/*.tfrec``.

    Args:
        directory: sub-directory of ``DATA_DIR`` (e.g. ``"train"``).
        has_labels: forwarded to ``get_dataset``.
        ordered: forwarded to ``get_dataset``.
        repeat: repeat the data indefinitely.
        cache: cache the parsed examples.
        shuffle: shuffle with a 2048-element buffer.
        augment: apply ``data_augment`` to every example.

    Returns:
        The dataset, batched with ``BATCH_SIZE``.

    Raises:
        FileNotFoundError: if no ``.tfrec`` file matches.
    """
    path = os.path.join(DATA_DIR, directory, "*.tfrec")
    filepaths = tf.io.gfile.glob(path)
    if not filepaths:
        raise FileNotFoundError(f"No TFRecord files match {path!r}")

    dataset = get_dataset(filepaths, has_labels=has_labels, ordered=ordered)

    # Cache the parsed, un-augmented examples before any random or infinite
    # stage. Caching after repeat() would try to store an endless stream, and
    # caching after augmentation/shuffling would replay those random choices
    # every epoch.
    if cache:
        dataset = dataset.cache()

    if augment:
        dataset = dataset.map(data_augment, num_parallel_calls=AUTO)

    if repeat:
        dataset = dataset.repeat()

    if shuffle:
        dataset = dataset.shuffle(2048)

    dataset = dataset.batch(BATCH_SIZE)
    return dataset.prefetch(AUTO)

In [None]:
def n_samples(directory):
    """Sum the example counts encoded in the TFRecord filenames of ``directory``.

    Each file stem is parsed for the integer after its last ``-``
    (e.g. ``name-123.tfrec`` counts as 123). Returns 0 if nothing matches.
    """
    pattern = os.path.join(DATA_DIR, directory, "*.tfrec")

    def _count(filepath):
        stem = os.path.splitext(os.path.basename(filepath))[0]
        return int(stem.rsplit("-", 1)[-1])

    return sum(_count(fp) for fp in tf.io.gfile.glob(pattern))

# Plotting functions

In [None]:
def batch_to_numpy(
    batch: tuple[tf.Tensor, tf.Tensor],
    has_labels=False
) -> tuple[np.ndarray, np.ndarray]:
    """Convert an ``(images, labels)`` batch to NumPy via ``.numpy()``.

    The labels are converted unless ``has_labels`` is exactly ``False``, in
    which case ``None`` is returned in their place.
    """
    images, labels = batch
    np_images = images.numpy()
    if has_labels is not False:
        return np_images, labels.numpy()
    return np_images, None

In [None]:
def get_title(prediction: int | None, label: int | None) -> tuple[str, str]:
    """Return the ``(title, color)`` shown above one image in a grid.

    Args:
        prediction: predicted class index, or ``None`` if there is none.
        label: true class index, or ``None`` for unlabeled (test) data.

    Returns:
        ``(title, color)``. The title uses names from ``CLASSES``. The color is
        ``"red"`` when a prediction disagrees with the label, else ``"black"``.
    """
    ok_color, wrong_color = "black", "red"

    # Neither label nor prediction: no title.
    if prediction is None and label is None:
        return "", ok_color

    # Unlabeled but predicted: show the predicted class name.
    if label is None:
        return CLASSES[prediction], ok_color

    actual = CLASSES[label]

    # Labeled and predicted: always show both names, colored by correctness.
    # A conditional expression is used instead of indexing a list with the
    # comparison result. For NumPy inputs that result is a ``np.bool_``, which
    # is not a reliable list index.
    if prediction is not None:
        title = f"p: {CLASSES[prediction]}\na: {actual}"
        return title, (ok_color if prediction == label else wrong_color)

    # Labeled only: show the true class name.
    return f"{actual}", ok_color

In [None]:
def plot_grid(
    images: np.ndarray,
    labels: np.ndarray | list[None],
    predictions: tf.Tensor | list[None],
    spacing: float = 0.1
) -> None:
    """Draw images in a near-square grid with a title above each one.

    The grid has ``floor(sqrt(n))`` rows and ``n // rows`` columns. Images
    beyond ``rows * cols`` are not drawn. Nothing is drawn for an empty input.

    Args:
        images: images to show.
        labels: one true class index (or ``None``) per image.
        predictions: one predicted class index (or ``None``) per image.
        spacing: gap between subplots, also used to scale the title font.
    """
    n_images = len(images)
    if n_images == 0:
        # sqrt(0) would give 0 rows and a division by zero below.
        return

    # Make a square grid by taking square root
    rows = int(math.sqrt(n_images))
    cols = n_images // rows
    tot = rows * cols

    # Some parameters borrowed from the Getting Started Notebook.
    size = 13
    fontdict = {"verticalalignment": "center"}
    # Depends only on the grid geometry, so compute it once.
    fontsize = size * spacing / max(rows, cols) * 40 + 3

    # plt.subplots creates its own figure, so no separate plt.figure() call
    # (which would leave an extra blank figure). squeeze=False keeps ``axs``
    # 2-D even for a 1x1 grid, so flatten() always works.
    _, axs = plt.subplots(
        rows,
        cols,
        figsize=(size, size),
        squeeze=False,
        constrained_layout=True,
        gridspec_kw={"wspace": spacing, "hspace": spacing},
    )
    axs = axs.flatten()

    # Go over each image, label, prediction and add it to its subplot.
    zipped = zip(images[:tot], labels[:tot], predictions[:tot], axs)
    for image, label, prediction, ax in zipped:
        title, color = get_title(prediction, label)
        ax.imshow(image)
        ax.set_title(title, fontsize=fontsize, color=color, fontdict=fontdict, pad=fontsize / 1.5)
        ax.set_axis_off()

In [None]:
def plot_batch(
    batch: tuple[tf.Tensor, tf.Tensor],
    has_labels: bool = True,
    predictions: tf.Tensor | np.ndarray | None = None
) -> None:
    """Show one batch of images as a grid, titled with labels and/or predictions.

    Args:
        batch: ``(images, labels_or_ids)`` pair of tensors.
        has_labels: whether the second element holds class labels; when
            ``False`` it is ignored.
        predictions: optional predicted class index per image.
    """
    images, labels = batch_to_numpy(batch, has_labels=has_labels)
    n_images = len(images)

    # plot_grid/get_title expect one entry per image; use None where absent.
    labels_ = labels if labels is not None else [None] * n_images
    predictions_ = predictions if predictions is not None else [None] * n_images

    spacing = 0.1
    plot_grid(images, labels_, predictions_, spacing=spacing)

    # Handle whitespace. Test the local ``labels``: the name ``label`` is not
    # defined here and silently resolved to a module-level loop variable.
    # With no titles at all, remove the gaps between images.
    plt.tight_layout()
    if labels is None and predictions is None:
        plt.subplots_adjust(wspace=0, hspace=0)
    else:
        plt.subplots_adjust(wspace=spacing, hspace=spacing)
    plt.show()

## Inspect Datasets

In [None]:
# Training pipeline: repeated, shuffled and augmented. Validation pipeline:
# cached. Then count the examples in each split from the filenames.
train_data = load_from_dir("train", repeat=True, shuffle=True, augment=True)
val_data = load_from_dir("val", cache=True)
n_train, n_val = n_samples("train"), n_samples("val")

In [None]:
# For each split, print its size, then the shape and labels of one batch.
_splits = (
    (f"Number of training samples = {n_train}", "Training data shape and sample labels:", train_data),
    (f"Number of validation samples = {n_val}", "Validation data shape and sample labels:", val_data),
)
for _count_msg, _header, _ds in _splits:
    print(_count_msg)
    print(_header)
    for image, label in _ds.take(1):
        print(image.shape, label)

In [None]:
# Iterator over the training stream re-batched into groups of 20 for plotting.
itrain = iter(train_data.unbatch().batch(batch_size=20))

In [None]:
# Show the next group of 20 training images with their labels.
plot_batch(batch=next(itrain))

In [None]:
# Iterator over the validation data re-batched into groups of 20 for plotting.
ival = iter(val_data.unbatch().batch(batch_size=20))

In [None]:
# Show the next group of 20 validation images with their labels.
plot_batch(batch=next(ival))

# Functions For Building The Model

Transfer learning in the form of fine-tuning is used to adapt a pretrained model to this task. The model is initialized with ImageNet weights and its layers are left trainable. A `GlobalAveragePooling2D` layer is applied to its output. Optionally, additional `Dense` layers can be added. The output is a `Dense` layer with as many units as there are classes and a softmax activation.

The pretrained model can either be set to VGG16, VGG19, Xception or ResNet50 by adding the key `core_model` in `params` with an appropriate value. The default is VGG16. The optional `Dense` layers are added by adding the key `dense_out_features` with a list of integers, each integer being the number of units in the layer. As many layers as the length of the list will be added to the model, in addition to the final output layer.

In [None]:
# Backbone name -> [preprocess_input function, Keras application constructor].
core_model_map = dict(
    vgg16=[applications.vgg16.preprocess_input, applications.VGG16],
    xception=[applications.xception.preprocess_input, applications.Xception],
    vgg19=[applications.vgg19.preprocess_input, applications.VGG19],
    resnet50=[applications.resnet50.preprocess_input, applications.ResNet50],
)

In [None]:
def get_model_params(params):
    """Extract the architecture settings from a hyperparameter dict.

    Args:
        params: hyperparameters, e.g. an Optuna ``trial.params`` dict. The
            hidden ``Dense`` widths come from ``params["dense_out_features"]``
            when present. Otherwise they come from the keys ``dense1_out``,
            ``dense2_out``, ... in index order, stopping at the first
            missing index.

    Returns:
        ``{"core_model": ..., "dense_out_features": [...]}``, with
        ``core_model`` defaulting to ``"vgg16"``.
    """
    explicit = params.get("dense_out_features")
    if explicit is not None:
        dense_out_features = list(explicit)
    else:
        # Match ``dense<i>_out`` keys exactly. A substring test for "dense"
        # would also pick up ``n_dense`` (the layer count) as a layer width.
        # Walking the indices keeps layer order independent of dict order.
        dense_out_features = []
        index = 1
        while f"dense{index}_out" in params:
            dense_out_features.append(params[f"dense{index}_out"])
            index += 1

    return {
        "core_model": params.get("core_model", "vgg16"),
        "dense_out_features": dense_out_features,
    }

In [None]:
def make_model(params):
    """Build and compile a fine-tuning classifier inside ``strategy.scope()``.

    The model stacks, in order:
    1. a ``Lambda`` that casts to float32 and applies the backbone's
       ``preprocess_input``;
    2. the ImageNet-pretrained backbone without its top;
    3. ``GlobalAveragePooling2D``;
    4. optional ReLU ``Dense`` layers;
    5. a softmax ``Dense`` layer with ``len(CLASSES)`` units.

    Args:
        params: hyperparameters. Recognised keys:
            ``core_model``: key of ``core_model_map`` (default ``"vgg16"``).
            ``dense_out_features``: explicit list of hidden ``Dense`` widths.
            ``dense1_out``, ``dense2_out``, ...: hidden widths as produced by
            the Optuna objective; used when ``dense_out_features`` is absent.
            ``lr``: Adam learning rate (default ``1e-3``).
            ``steps_per_execution``: passed to ``compile`` (default 16).

    Returns:
        The compiled ``Sequential`` model.
    """
    core_model = params.get("core_model", "vgg16")
    # Resolve the hidden widths from either form the notebook produces, so the
    # tuned ``dense<i>_out`` values actually turn into layers.
    dense_out_features = _dense_out_features(params)

    shape = [IMG_SIZE, IMG_SIZE, 3]

    with strategy.scope():
        preprocess, core_cls = core_model_map[core_model]

        ip = layers.Lambda(lambda data: preprocess(tf.cast(data, tf.float32)), input_shape=shape)
        core = core_cls(weights="imagenet", include_top=False)

        dense = [layers.Dense(features, activation="relu") for features in dense_out_features]

        model = Sequential(
            [
                ip,
                core,
                layers.GlobalAveragePooling2D(),
                *dense,
                layers.Dense(len(CLASSES), activation="softmax"),
            ]
        )

        optimizer = optimizers.Adam(learning_rate=params.get("lr", 1e-3))
        model.compile(
            optimizer=optimizer,
            loss="sparse_categorical_crossentropy",
            metrics=["sparse_categorical_accuracy"],
            steps_per_execution=params.get("steps_per_execution", 16),
        )

        return model


def _dense_out_features(params):
    """Return hidden Dense widths: ``params["dense_out_features"]`` if given, else ``dense1_out``, ``dense2_out``, ... up to the first missing index."""
    explicit = params.get("dense_out_features")
    if explicit is not None:
        return list(explicit)
    widths = []
    index = 1
    while f"dense{index}_out" in params:
        widths.append(params[f"dense{index}_out"])
        index += 1
    return widths

# Training Function

The training logic is encapsulated in the function below. It returns the validation loss after the final epoch, the trained model and the training history at the end of training.

In [None]:
def train(params):
    """Fit a freshly built model on the training split and report validation loss.

    Args:
        params: hyperparameters forwarded to ``make_model``. ``epochs``
            (default ``EPOCHS``) and ``pruning_callback`` (a list of extra
            Keras callbacks, default empty) are also read here.

    Returns:
        ``(final_val_loss, model, history)``: the last epoch's ``val_loss``,
        the trained model and the Keras ``History`` object.
    """
    train_ds = load_from_dir("train", repeat=True, shuffle=True, augment=True)
    train_count = n_samples("train")

    val_ds = load_from_dir("val", cache=True)
    val_count = n_samples("val")

    model = make_model(params)

    # The training stream repeats forever, so each epoch is bounded explicitly.
    # Validation uses ceiling division so the final partial batch is included.
    per_epoch = train_count // BATCH_SIZE
    val_batches = (val_count + BATCH_SIZE - 1) // BATCH_SIZE

    stopper = callbacks.EarlyStopping(patience=5)
    extra_callbacks = params.get("pruning_callback", [])

    history = model.fit(
        train_ds,
        steps_per_epoch=per_epoch,
        epochs=params.get("epochs", EPOCHS),
        validation_data=val_ds,
        validation_steps=val_batches,
        callbacks=[stopper, *extra_callbacks],
    )

    return history.history["val_loss"][-1], model, history
        

# Optuna Objective

The hyperparameters are tuned using Optuna. It is used to select the architecture that will be used for transfer learning, the learning rate, the number and sizes of the additional `Dense` layers to be added and the number of epochs. Additionally, Optuna is penalized whenever it chooses parameters that lead to early stopping since early stopping suggests that there is overfitting.

In [None]:
def objective(trial):
    """Optuna objective: sample hyperparameters, train, and return a loss to minimise.

    If training ends before the sampled number of epochs (early stopping), the
    worst (maximum) validation loss seen is returned instead of the final one.
    This steers the search away from settings that stop early.
    """
    params = {
        "core_model": trial.suggest_categorical("core_model", ["vgg16", "xception", "vgg19", "resnet50"]),
        "lr": trial.suggest_float("lr", 1e-5, 4e-4),
        "epochs": trial.suggest_int("epochs", 8, 20),
    }

    n_dense = trial.suggest_categorical("n_dense", [1, 2, 3, 4, 5])
    for layer_no in range(1, n_dense + 1):
        name = f"dense{layer_no}_out"
        params[name] = trial.suggest_int(name, 32, 1024)

    # Reports val_loss to the pruner during training.
    params["pruning_callback"] = [optuna.integration.TFKerasPruningCallback(trial, "val_loss")]

    final_loss, _, history = train(params)
    losses = history.history["val_loss"]

    # Penalize tuner for being too aggressive
    stopped_early = len(losses) < params["epochs"]
    return np.max(losses) if stopped_early else final_loss

# Training

# Tune Parameters

Hyperband is a popular, state-of-the-art pruning algorithm for stopping unpromising trials mid-way.

In [None]:
# Hyperband pruning with a multivariate TPE sampler, minimising objective().
pruner = optuna.pruners.HyperbandPruner()
# Pass the seed by keyword. TPESampler's first positional parameter is
# `consider_prior`, so `TPESampler(42, ...)` left the sampler unseeded.
sampler = optuna.samplers.TPESampler(seed=42, multivariate=True)
study = optuna.create_study(
    direction="minimize",
    pruner=pruner,
    sampler=sampler,
)
study.optimize(objective, n_trials=50, gc_after_trial=True)

# Train Final Model Using Best Parameters

In [None]:
# Hyperparameters of the best trial (lowest objective value, as the study minimises).
study.best_trial.params

In [None]:
# Retrain from scratch with the best trial's hyperparameters; keep model and history.
best_trial = study.best_trial
_, model, history = train(params=best_trial.params)

# Loss Curve

In [None]:
# Per-epoch training vs. validation loss of the final model.
train_loss, val_loss = history.history["loss"], history.history["val_loss"]
plt.plot(train_loss, label="Training")
plt.plot(val_loss, label="Validation")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()

# Accuracy Curve

In [None]:
# Per-epoch training vs. validation accuracy of the final model.
# Named *_acc: these are accuracies, not losses.
train_acc = history.history["sparse_categorical_accuracy"]
val_acc = history.history["val_sparse_categorical_accuracy"]
plt.plot(train_acc, label="Training")
plt.plot(val_acc, label="Validation")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()
plt.show()

# Predictions

In [None]:
# Test split: yields ids instead of labels and is read with ordered=True, so
# that repeated passes produce examples in the same order.
test_data = load_from_dir("test", has_labels=False, ordered=True)
n_test = n_samples("test")
# Ceiling division so the final partial batch is predicted too.
test_steps = (n_test + BATCH_SIZE - 1) // BATCH_SIZE

In [None]:
print('Computing predictions...')
# First pass over test_data: feed only the images to the model.
test_images = test_data.map(lambda image, idnum: image)
probabilities = model.predict(test_images, steps=test_steps)
# Predicted class index = column with the highest probability.
predictions = np.argmax(probabilities, axis=-1)
print(predictions)

print('Generating submission.csv file...')
# Second pass over test_data for the ids. Ids and predictions are paired
# purely by position, so this relies on both passes yielding the same order
# (test_data was built with ordered=True).
test_ids = test_data.map(lambda image, idnum: idnum).unbatch()
test_ids = next(iter(test_ids.batch(n_test))).numpy().astype('U') # all in one batch
# Write an "id,label" CSV with one row per test image.
np.savetxt('submission.csv', np.rec.fromarrays([test_ids, predictions]), fmt=['%s', '%d'], delimiter=',', header='id,label', comments='')

# Visual Validation

In [None]:
# Fresh (uncached) validation pipeline, re-batched into groups of 20 for display.
val_data = load_from_dir("val")
batches = val_data.unbatch().batch(batch_size=20)
ibatches = iter(batches)

In [None]:
# Predict one group of validation images and plot predictions against labels.
images, labels = next(ibatches)
probabilities = model.predict(tf.cast(images, tf.float32))
predictions = probabilities.argmax(axis=-1)
plot_batch((images, labels), predictions=predictions)