# Introduction

This notebook expands upon the image classification example in notebook `1.1-flowers-in-tensorflow.ipynb`, illustrating some additional ideas and concepts from TensorFlow.

**Main takeaways and motivation:**

* Last time we had a look at an image classification example in TensorFlow. There, we used high-level functionality from Keras to read images from disk. This time we'll use the more efficient TFRecords format to store the images, and show how to work with TFRecords in TensorFlow.
* ..

# Setup

In [None]:
%matplotlib inline
import numpy as np, pandas as pd, matplotlib.pyplot as plt
import pickle, os
import PIL.Image  # a bare `import PIL` does not load the Image submodule used below
from pathlib import Path

os.environ["CUDA_VISIBLE_DEVICES"] = "2"  # use only GPU 2; set this before importing TensorFlow

import tensorflow as tf
from tensorflow import keras

# Load the flowers data and store as TFRecords

This is the data set we downloaded in the notebook `0.1-download_flowers_data.ipynb` and studied in the two notebooks `1.0` (fastai) and `1.1` (TensorFlow).

In [None]:
with open('path.pkl', 'rb') as f:
    path = pickle.load(f)

In [None]:
list(path.iterdir())

**Plot some random images:**

In [None]:
import random

In [None]:
kind = 'sunflowers'
nb = 9
# Sample without replacement so the grid shows nine different images
images = random.sample(list((path/kind).iterdir()), k=nb)

plt.figure(figsize=(10, 10))

for i in range(nb):
    ax = plt.subplot(3, 3, i + 1)
    img = PIL.Image.open(images[i])
    plt.imshow(img)
    plt.axis("off")

## Split images into train and test

In [None]:
all_images = list(path.glob("*/*.jpg"))
nb_images = len(all_images)
nb_images

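`path.glob` typically returns the images grouped by class folder, so slicing the list directly would put only the last class or two in the test set. We therefore shuffle the images first, so that all classes are (very likely) represented in both the training and the test set.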

In [None]:
random.shuffle(all_images)

In [None]:
test_size = 0.2

In [None]:
split_idx = int((1 - test_size) * nb_images)
train_images = all_images[:split_idx]
test_images = all_images[split_idx:]

print(f"nb train images: {len(train_images)}\nnb test images: {len(test_images)}")
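Shuffling and slicing makes it very likely, but does not guarantee, that every class appears in the test set in roughly the right proportion. A stratified split guarantees it by splitting each class separately. Below is a minimal pure-Python sketch of that alternative; the function `stratified_split` and its `seed` parameter are hypothetical (not part of this notebook), and it assumes, like `get_label` below, that the class is the name of the image's parent folder.

```python
import random
from collections import defaultdict
from pathlib import Path


def stratified_split(image_paths, test_size=0.2, seed=42):
    """Hypothetical helper: split image paths per class (parent folder name)."""
    by_class = defaultdict(list)
    for p in image_paths:
        by_class[Path(p).parent.name].append(p)

    rng = random.Random(seed)  # local RNG, so the split is reproducible
    train, test = [], []
    for cls in sorted(by_class):
        items = sorted(by_class[cls])  # fixed order before shuffling
        rng.shuffle(items)
        n_test = int(round(test_size * len(items)))
        test.extend(items[:n_test])   # the same fraction of every class goes to test
        train.extend(items[n_test:])

    # Mix the classes again so neither list is grouped by class
    rng.shuffle(train)
    rng.shuffle(test)
    return train, test
```

It would be used as `train_images, test_images = stratified_split(all_images, test_size=test_size)`.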

## Load and store as TFRecord

> TFRecord is a data format for storing sequences of binary records. Performance-wise, binary formats allow for fast reading and writing. TensorFlow has several optimizations that are based on TFRecord, allowing for seamless integration with everything from preprocessing layers to distributed data sets. 

https://www.tensorflow.org/tutorials/load_data/tfrecord

We're going to store the images as binary records, using TFRecord. For this, we need to specify the structure of the data, including the labels we want to assign to each image. 

### Get labels

In [None]:
sorted(list(path.iterdir()))

In [None]:
labels_dict = {0: 'daisy',
               1: 'dandelion',
               2: 'roses',
               3: 'sunflowers',
               4: 'tulips'}

In [None]:
labels_dict_reversed = {v: k for k,v in labels_dict.items()}

In [None]:
labels_dict_reversed

In [None]:
def get_label(img_path):
    label = img_path.parent.stem
    return labels_dict_reversed[label]
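Instead of typing the class mapping by hand, it can be derived from the sub-folder names sorted alphabetically, which reproduces the mapping above for the five flower folders. A minimal sketch; `make_label_maps` is a hypothetical helper. It keeps only directories, since the data folder can also contain plain files (for example the `.tfrecord` files written later in this notebook).

```python
from pathlib import Path


def make_label_maps(data_dir):
    """Hypothetical helper: map class index <-> class folder name."""
    # Only sub-directories are classes; stray files in data_dir are ignored
    class_names = sorted(p.name for p in Path(data_dir).iterdir() if p.is_dir())
    labels_dict = dict(enumerate(class_names))
    labels_dict_reversed = {name: i for i, name in labels_dict.items()}
    return labels_dict, labels_dict_reversed
```

With the notebook's `path`, `labels_dict, labels_dict_reversed = make_label_maps(path)` should give the same two dictionaries as the cells above.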

**Test**

In [None]:
test_img_path = train_images[10]
test_img = PIL.Image.open(test_img_path)
test_img

In [None]:
get_label(test_img_path)

In [None]:
labels_dict[get_label(test_img_path)]

### Read the raw image bytes into a tensor

In [None]:
test_img_tensor = tf.io.read_file(str(test_img_path))

In [None]:
len(test_img_tensor.numpy())

In [None]:
test_img_tensor.numpy()[:50]

### Construct `tf.train.Example`

To specify the structure of the data stored as byte strings, we can use `tf.train.Example`, a [protocol buffer](https://developers.google.com/protocol-buffers/?hl=en) message. Protocol buffers are Google's format for serializing structured data; an `Example` is essentially a mapping from feature names to `tf.train.Feature` values: https://www.tensorflow.org/api_docs/python/tf/train/Example. 
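As a small illustration of what the protocol buffer gives us: an `Example` can be serialized to a byte string and parsed back with a feature specification. The values below are made up for illustration (they mirror the `image_raw`/`label` features used in this notebook, but the "image" is just a few bytes); the parsing spec is the same kind of `tf.io.FixedLenFeature` dictionary used later when reading the TFRecords.

```python
import tensorflow as tf

# Build an Example by hand: one bytes feature and one int64 feature
example = tf.train.Example(features=tf.train.Features(feature={
    "image_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"\xff\xd8fake-jpeg"])),
    "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[3])),
}))

# Serialize to the byte string that would be written to a TFRecord file
serialized = example.SerializeToString()

# Parse it back; the spec must match the names and dtypes used when writing
spec = {
    "image_raw": tf.io.FixedLenFeature([], tf.string),
    "label": tf.io.FixedLenFeature([], tf.int64),
}
parsed = tf.io.parse_single_example(serialized, spec)
print(parsed["image_raw"].numpy(), parsed["label"].numpy())
```

If the spec uses a name or dtype that was not written, parsing fails, which is why the reading code must mirror the writing code exactly.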

In [None]:
#?tf.train.Example

In [None]:
#?tf.train.Feature

Helper functions that wrap a single value in a `tf.train.Feature` (via a one-element `BytesList` or `Int64List`), which is the form `tf.train.Example` expects for each feature:

In [None]:
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

**Test:**

In [None]:
test_example = tf.train.Example(features=tf.train.Features(
    feature = {'image_raw': _bytes_feature(test_img_tensor.numpy()), 
               'label': _int64_feature(get_label(test_img_path))}
))

In [None]:
#test_example

### Write all images to TFRecord files

In [None]:
tfrecord_train_fn = str(path/'flowers_dataset_train.tfrecord')
tfrecord_test_fn = str(path/'flowers_dataset_test.tfrecord')

In [None]:
def write_tfrecord(images, tfrecord_fn):
    # Skip the (slow) conversion if the file already exists
    if os.path.isfile(tfrecord_fn):
        print(f"TFRecords already written to disk: {tfrecord_fn}")
        return

    with tf.io.TFRecordWriter(tfrecord_fn) as writer:
        for img_path in images:
            try:
                # Raw JPEG bytes; decoding happens later, when reading the records
                raw_file = tf.io.read_file(str(img_path))
            except tf.errors.OpError:
                print(f"File {img_path} could not be found (or read)")
                continue

            example = tf.train.Example(features=tf.train.Features(
                feature={'image_raw': _bytes_feature(raw_file.numpy()),
                         'label': _int64_feature(get_label(img_path))}
            ))
            writer.write(example.SerializeToString())

    print(f"TFRecord written to {tfrecord_fn}")

In [None]:
write_tfrecord(train_images, tfrecord_train_fn)

In [None]:
write_tfrecord(test_images, tfrecord_test_fn)
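To check that the files were written as expected, we can count the records in each file and compare with the number of images. A minimal sketch; `count_records` is a hypothetical helper, and the comparison only holds if no image was skipped by the `except` branch in `write_tfrecord`.

```python
import tensorflow as tf


def count_records(tfrecord_fn):
    """Hypothetical helper: count the serialized records in a TFRecord file."""
    # Iterating a TFRecordDataset yields one scalar string tensor per record
    return sum(1 for _ in tf.data.TFRecordDataset(tfrecord_fn))
```

For example, `count_records(tfrecord_train_fn) == len(train_images)` should be `True` after the cells above.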

# Dataloader

We can now construct data sets and dataloaders that get data from the stored TFRecords. For this, we use the `tf.data` API (see https://www.tensorflow.org/guide/data_performance for performance tips). We'll partly follow the example at https://keras.io/examples/keras_recipes/tfrecord/; have a look at it for some additional details and links.

In [None]:
BATCH_SIZE = 64
IMAGE_SIZE = (224, 224)

We need to decode the byte strings representing JPEG images:

In [None]:
def decode_image(image):
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, IMAGE_SIZE)
    image = tf.cast(image, tf.float32)
    return image

In [None]:
def read_tfrecord(example):
    # Must match the feature names and dtypes written in write_tfrecord
    tfrecord_format = {
        "image_raw": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.int64),
    }
    example = tf.io.parse_single_example(example, tfrecord_format)

    image = decode_image(example["image_raw"])
    label = tf.cast(example["label"], tf.int32)
    return image, label

We construct data sets by shuffling, prefetching and batching the data read from TFRecords:

In [None]:
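def load_dataset(fn, shuffle=True):
    dataset = tf.data.TFRecordDataset(fn)

    # Shuffle the small serialized records before decoding: a buffer of 1024
    # decoded 224x224x3 float32 images would need roughly 600 MB of memory
    if shuffle:
        dataset = dataset.shuffle(1024)

    # Parse and decode the images in parallel
    dataset = dataset.map(read_tfrecord, num_parallel_calls=tf.data.AUTOTUNE)

    # Batch first, then prefetch, so whole batches are prepared ahead of time
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    return dataset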

In [None]:
train_dataset = load_dataset(tfrecord_train_fn)
test_dataset = load_dataset(tfrecord_test_fn)
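To see what the shuffle, batch and prefetch steps do without touching the images, here they are applied to a toy data set of the integers 0 to 9, a stand-in for the parsed records (the buffer and batch sizes are chosen only for illustration). Batching groups consecutive elements, the last batch is smaller unless `drop_remainder=True` is passed, and `prefetch(tf.data.AUTOTUNE)` lets TensorFlow prepare the next batch while the current one is consumed. Note that a shuffle buffer smaller than the data set, like the 1024 used above, only gives an approximate shuffle; here it is fine because the images were already shuffled before writing.

```python
import tensorflow as tf

toy = tf.data.Dataset.range(10)             # stand-in for the parsed (image, label) records
toy = toy.shuffle(buffer_size=10, seed=0)   # buffer >= data set size: a full shuffle
toy = toy.batch(4)                          # 10 elements -> batches of 4, 4 and 2
toy = toy.prefetch(tf.data.AUTOTUNE)        # overlap producing and consuming batches

batches = [b.numpy() for b in toy]
print([len(b) for b in batches])
```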

**Plot some images:**

As you know very well by now, it's always a good idea to look at the data after each step. Here are a few images from batches drawn from the training and test data sets:

In [None]:
def show_batch(image_batch, label_batch):
    plt.figure(figsize=(10, 10))
    for n in range(25):
        ax = plt.subplot(5, 5, n + 1)
        plt.imshow(image_batch[n] / 255.0)
        plt.title(labels_dict[label_batch[n]])
        plt.axis("off")

In [None]:
image_batch, label_batch = next(iter(train_dataset))

In [None]:
show_batch(image_batch.numpy(), label_batch.numpy())

From the test set:

In [None]:
image_batch, label_batch = next(iter(test_dataset))

In [None]:
show_batch(image_batch.numpy(), label_batch.numpy())

# Train a model

Now we can train a model. We'll follow the setup in notebook `1.1`, with some minor modifications.

This time we'll try out a learning rate schedule:

In [None]:
initial_learning_rate = 0.01
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps=20, decay_rate=0.96, staircase=True
)
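The schedule computes the learning rate from the optimizer step count, i.e. the number of batches seen, not the number of epochs: $\eta(t) = \eta_0 \, r^{\lfloor t / T \rfloor}$ with `staircase=True`, where $\eta_0$ is `initial_learning_rate`, $r$ is `decay_rate` and $T$ is `decay_steps`. With `decay_steps=20` the rate drops by 4% every 20 batches, and one epoch is about `len(train_images) / BATCH_SIZE` steps. The plain-Python function below re-implements the documented formula so values can be checked by hand; it is a sketch for intuition, not TensorFlow's code.

```python
import math


def exponential_decay(step, initial_learning_rate=0.01, decay_steps=20,
                      decay_rate=0.96, staircase=True):
    """Learning rate after `step` optimizer steps (one step = one batch)."""
    exponent = step / decay_steps
    if staircase:
        exponent = math.floor(exponent)  # decay in discrete jumps every decay_steps
    return initial_learning_rate * decay_rate ** exponent


for step in [0, 19, 20, 100]:
    print(step, exponential_decay(step))
```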

We'll set up two different base models: an Xception model and a ResNet50 model.

In [None]:
base_model_xc = tf.keras.applications.Xception(
        input_shape=(*IMAGE_SIZE, 3), include_top=False, weights="imagenet"
    )

preprocess_input_xc = keras.applications.xception.preprocess_input

In [None]:
base_model_rn = keras.applications.resnet.ResNet50(
         input_shape=(*IMAGE_SIZE,3), include_top=False, weights="imagenet")

preprocess_input_rn = keras.applications.resnet.preprocess_input
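The two `preprocess_input` functions both expect pixel values in [0, 255], which is what `decode_image` produces (it resizes but does not rescale), yet they transform the pixels differently: the Xception one scales them to [-1, 1], while the ResNet one converts RGB to BGR and subtracts the per-channel ImageNet mean without scaling. The NumPy sketch below mimics both for intuition only; the function names are hypothetical, and in the model the Keras functions above should be used.

```python
import numpy as np

# ImageNet channel means in BGR order, as used by Keras' "caffe"-style preprocessing
IMAGENET_MEAN_BGR = np.array([103.939, 116.779, 123.68], dtype=np.float32)


def xception_style_preprocess(x):
    """Scale pixels from [0, 255] to [-1, 1]."""
    return x / 127.5 - 1.0


def resnet_style_preprocess(x):
    """Flip the channel axis RGB -> BGR, then subtract the ImageNet mean per channel."""
    return x[..., ::-1] - IMAGENET_MEAN_BGR
```

This is also why the base model and its preprocessing function are passed to `make_model` together: feeding ResNet50 inputs scaled for Xception (or vice versa) would silently degrade the pretrained features. When comparing against the Keras functions on a NumPy array, pass a copy, as they may modify float arrays in place.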

The model below is essentially a copy of the one in notebook `1.1`; consult that notebook for additional details.

In [None]:
def make_model(base_model, preprocess_input):
    
    base_model.trainable = False

    inputs = keras.layers.Input([*IMAGE_SIZE, 3])
    
    # Resize (a no-op here, since the data set already yields IMAGE_SIZE images) and augment
    x = keras.layers.Resizing(*IMAGE_SIZE)(inputs)
    x = keras.layers.RandomFlip("horizontal")(x)
    x = keras.layers.RandomRotation(0.1)(x)
    x = keras.layers.RandomZoom(0.1)(x)
    x = keras.layers.RandomContrast(factor=0.01)(x) 
    
    # Preprocess according to the base model
    x = preprocess_input(x)
    
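    # Pass through the base model; training=False keeps its BatchNormalization layers
    # in inference mode, which matters if the base model is later unfrozen for fine-tuning
    x = base_model(x, training=False)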
    
    # Head
    x = keras.layers.GlobalAveragePooling2D()(x)
    x = keras.layers.BatchNormalization(axis=-1)(x)
    x = keras.layers.Dropout(rate=0.25)(x)
    x = keras.layers.Dense(512, activation="relu")(x)
    x = keras.layers.BatchNormalization(axis=-1)(x)
    x = keras.layers.Dropout(rate=0.5)(x)

    outputs = keras.layers.Dense(5, activation="softmax")(x)

    # Create and compile the model
    model = keras.Model(inputs=inputs, outputs=outputs)

    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=["accuracy"],
    )

    return model
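Since `base_model.trainable = False`, only the head is trained. A back-of-the-envelope count of its trainable parameters, assuming the base model outputs 2048 feature channels (true for both Xception and ResNet50 with `include_top=False`), is sketched below; the result should match the "Trainable params" line of `model.summary()`. The helper names are hypothetical.

```python
def dense_params(n_in, n_out):
    return n_in * n_out + n_out  # weight matrix + bias vector


def batchnorm_trainable_params(n):
    return 2 * n  # gamma and beta; the moving mean/variance are not trainable


FEATURES = 2048          # channels after GlobalAveragePooling2D (assumption, see above)
HIDDEN, N_CLASSES = 512, 5

trainable = (
    batchnorm_trainable_params(FEATURES)   # first BatchNormalization
    + dense_params(FEATURES, HIDDEN)       # Dense(512)
    + batchnorm_trainable_params(HIDDEN)   # second BatchNormalization
    + dense_params(HIDDEN, N_CLASSES)      # Dense(5) output layer
)
print(f"{trainable:,}")
```

Dropout and pooling layers have no parameters, so almost all trainable weights sit in the `Dense(512)` layer.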

In [None]:
model = make_model(base_model=base_model_rn, preprocess_input=preprocess_input_rn)
#tf.keras.utils.plot_model(model, show_shapes=True)

Let's train it. We'll add a TensorBoard callback. More on that below.

In [None]:
tensorboard_callback = keras.callbacks.TensorBoard(log_dir="./logs")
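With a fixed `log_dir="./logs"`, repeated runs (for example with the two different base models) write their event files into the same folders, which makes them hard to tell apart. A common pattern is one sub-folder per run; TensorBoard shows each sub-folder under `--logdir` as a separate run. A minimal sketch; `make_log_dir` and its naming scheme are hypothetical.

```python
import os
from datetime import datetime


def make_log_dir(run_name, root="logs"):
    """Hypothetical helper: a unique log folder per run, e.g. logs/resnet50-20240101-120000."""
    stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    return os.path.join(root, f"{run_name}-{stamp}")
```

It would be used as `keras.callbacks.TensorBoard(log_dir=make_log_dir("resnet50"))`; `%tensorboard --logdir logs` then picks up all runs.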

In [None]:
history = model.fit(
    train_dataset,
    epochs=3,
    validation_data=test_dataset,
    callbacks=[tensorboard_callback]
)

### TensorBoard

TensorBoard is useful for following the training process as it happens, and for visualizing and inspecting the trained model:

In [None]:
%load_ext tensorboard

In [None]:
%tensorboard --logdir logs

# Evaluate

Now we can evaluate the model in all the ways discussed previously. Here we simply compute the accuracy on the validation data and plot some predictions.

In [None]:
loss, acc = model.evaluate(test_dataset)

In [None]:
print(f"Loss: {loss:.4f}\nAccuracy: {acc:.4f}")

In [None]:
def show_batch_predictions(image_batch, label_batch):
    # Predict the whole batch in one call instead of one model.predict call per image
    pred_labels = model.predict(image_batch).argmax(axis=1)
    label_batch = np.asarray(label_batch)
    fig = plt.figure(figsize=(12, 12))
    fig.suptitle("prediction / actual", y=0.93)
    for n in range(25):
        ax = plt.subplot(5, 5, n + 1)
        plt.imshow(image_batch[n] / 255.0)
        title = f"{labels_dict[pred_labels[n]]} / {labels_dict[label_batch[n]]}"
        plt.title(title, fontsize=10)
        plt.axis("off")

In [None]:
image_batch, label_batch = next(iter(test_dataset))

show_batch_predictions(image_batch, label_batch)
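A natural next step beyond overall accuracy is a confusion matrix, which shows which flower classes get mixed up. A minimal NumPy sketch; `confusion_matrix` is a hypothetical helper here (scikit-learn has a function of the same name, but this one does not depend on it). To use it, loop once over `test_dataset`, collecting `labels.numpy()` and `model.predict(images).argmax(axis=1)` for each batch in the same loop (the test set is reshuffled on every pass), then concatenate both lists.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, n_classes):
    """Rows: actual class, columns: predicted class."""
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    # np.add.at increments correctly even when the same (true, pred) pair repeats
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm
```

The diagonal holds the correct predictions, so `cm.trace() / cm.sum()` gives the accuracy, and the row and column indices map to class names through `labels_dict`.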