# Traffic sign classification with BiT

In this notebook, we'll fine-tune a [BigTransfer](https://arxiv.org/abs/1912.11370) (BiT) model from [TensorFlow Hub](https://tfhub.dev/) to classify images of traffic signs from [The German Traffic Sign Recognition Benchmark](http://benchmark.ini.rub.de/?section=gtsrb&subsection=news) using TensorFlow 2 / Keras. This notebook is loosely based on the Keras code example [Image Classification using BigTransfer (BiT)](https://keras.io/examples/vision/bit/) by Sayan Nath.

**Note that using a GPU with this notebook is highly recommended.**

First, the needed imports.

In [None]:
%matplotlib inline

import os, datetime
import pathlib

import tensorflow as tf
import tensorflow_hub as hub

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import applications, optimizers

from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.utils import plot_model

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()
from PIL import Image

print('Using Tensorflow version:', tf.__version__,
      'Keras version:', tf.keras.__version__,
      'backend:', tf.keras.backend.backend())

## Data

The training dataset consists of 5535 images of traffic signs of varying size. There are 43 different types of traffic signs:

![title](imgs/gtsrb-montage.png)

The validation and test sets consist of 999 and 12630 images, respectively.

In [None]:
datapath = "/media/data/gtsrb/train-5535/"
nimages = {'train':5535, 'validation':999, 'test':12630}

### Image paths and labels

In [None]:
def get_paths(dataset):
    data_root = pathlib.Path(datapath+dataset)
    image_paths = list(data_root.glob('*/*'))
    image_paths = [str(path) for path in image_paths]
    image_count = len(image_paths)
    assert image_count == nimages[dataset], "Found {} images, expected {}".format(image_count, nimages[dataset])
    return image_paths

image_paths = dict()
image_paths['train'] = get_paths('train')
image_paths['validation'] = get_paths('validation')
image_paths['test'] = get_paths('test')

In [None]:
label_names = sorted(item.name for item in pathlib.Path(datapath+'train').glob('*/') if item.is_dir())
label_to_index = dict((name, index) for index,name in enumerate(label_names))

def get_labels(dataset):
    return [label_to_index[pathlib.Path(path).parent.name]
            for path in image_paths[dataset]]
    
image_labels = dict()
image_labels['train'] = get_labels('train')
image_labels['validation'] = get_labels('validation')
image_labels['test'] = get_labels('test')
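
To make the mapping from directory names to integer labels concrete, here is a minimal, self-contained sketch of the same idea on a throwaway directory tree. The class directory names (`00000`, `00001`, `00002`) and the file names are made up for illustration; the real dataset layout under `datapath` may differ.

```python
import pathlib
import tempfile

# Build a tiny, hypothetical directory tree: one subdirectory per class.
with tempfile.TemporaryDirectory() as tmp:
    root = pathlib.Path(tmp)
    for class_dir, n_files in [("00002", 1), ("00000", 2), ("00001", 1)]:
        (root / class_dir).mkdir()
        for i in range(n_files):
            (root / class_dir / f"img{i}.ppm").touch()

    # Sorted directory names define the class order, as in the cell above.
    demo_label_names = sorted(p.name for p in root.iterdir() if p.is_dir())
    demo_label_to_index = {name: index
                           for index, name in enumerate(demo_label_names)}

    # Each image's label is the index of its parent directory name.
    demo_paths = sorted(str(p) for p in root.glob("*/*"))
    demo_labels = [demo_label_to_index[pathlib.Path(p).parent.name]
                   for p in demo_paths]

print(demo_label_to_index)
print(demo_labels)
```

Because the class order comes from sorting the directory names, the same directory always gets the same index, regardless of the order in which the file system lists it.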

### Data loading

We now define a function to load the images. The images are in PPM format, so we read them with the PIL library. We also need to resize them to a fixed size (`INPUT_IMAGE_SIZE`).

In [None]:
INPUT_IMAGE_SIZE = [80, 80]

def _load_image(path, label):
    image = Image.open(path.numpy())
    return np.array(image), label

def load_image(path, label):
    image, label = tf.py_function(_load_image, (path, label), (tf.float32, tf.int32))
    image.set_shape([None, None, None])
    label.set_shape([])
    return tf.image.resize(image, INPUT_IMAGE_SIZE), label

### TF Datasets

Let's now define our [TF `Dataset`s](https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/data/Dataset#class_dataset) for training, validation, and test data. 

In [None]:
BATCH_SIZE = 50

train_dataset = tf.data.Dataset.from_tensor_slices((image_paths['train'],
                                                    image_labels['train']))
train_dataset = train_dataset.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
train_dataset = train_dataset.shuffle(2000).batch(BATCH_SIZE, drop_remainder=True)
train_dataset = train_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

validation_dataset = tf.data.Dataset.from_tensor_slices((image_paths['validation'],
                                                         image_labels['validation']))
validation_dataset = validation_dataset.map(load_image,
                                            num_parallel_calls=tf.data.AUTOTUNE)
validation_dataset = validation_dataset.batch(BATCH_SIZE, drop_remainder=True)
validation_dataset = validation_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

test_dataset = tf.data.Dataset.from_tensor_slices((image_paths['test'],
                                                   image_labels['test']))
test_dataset = test_dataset.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
test_dataset = test_dataset.batch(BATCH_SIZE, drop_remainder=False)
test_dataset = test_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
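
As a quick sanity check of the batching settings above, the following sketch works out how many batches each dataset yields per pass, using only the image counts and `drop_remainder` flags from this notebook. It is plain arithmetic and does not touch the actual data.

```python
import math

BATCH_SIZE = 50
nimages = {'train': 5535, 'validation': 999, 'test': 12630}
drop_remainder = {'train': True, 'validation': True, 'test': False}

batches = {}
for name, n in nimages.items():
    if drop_remainder[name]:
        # Only full batches are kept; the incomplete last batch is dropped.
        batches[name] = n // BATCH_SIZE
        unused = n - batches[name] * BATCH_SIZE
    else:
        # The last batch may be smaller than BATCH_SIZE.
        batches[name] = math.ceil(n / BATCH_SIZE)
        unused = 0
    print(f"{name}: {batches[name]} batches, {unused} images left out per pass")
```

With `drop_remainder=True`, some training and validation images are skipped on every pass. The training set is shuffled, so the skipped images should vary between epochs, whereas the validation set is not shuffled, so presumably the same trailing images are left out each time.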

Let's see a couple of our training images:

In [None]:
plt.figure(figsize=(10,10))
for batch, labels in train_dataset.take(1):
    for i in range(9):    
        plt.subplot(3,3,i+1)
        plt.imshow(tf.cast(batch[i,:,:,:], tf.int32))
        plt.title(label_names[labels[i]])
        plt.grid(False)
        plt.xticks([])
        plt.yticks([])
    plt.suptitle('some training images', fontsize=16, y=0.93)

## BiT

### Initialization

Now we specify the pre-trained BiT model we are going to use. The model ["BiT-M R50x1"](https://tfhub.dev/google/bit/m-r50x1/1) is pre-trained on ImageNet-21k (14 million images, 21,843 classes). It outputs 2048-dimensional feature vectors.

In [None]:
bit_model_url = "https://tfhub.dev/google/bit/m-r50x1/1"
bit_model = hub.KerasLayer(bit_model_url)

Next we build the full model. First we apply random augmentation transformations (a small random crop and a contrast adjustment) to the training images each time we loop over them. This way, we "augment" our training dataset to contain more data. The augmentation transformations are implemented as preprocessing layers in Keras. There are various such layers readily available, see [https://keras.io/guides/preprocessing_layers/](https://keras.io/guides/preprocessing_layers/) for more information.

Then we add the BiT model as a layer and finally add the output layer with 43 units and softmax activation. Note that we initialize the output layer to all zeroes as instructed in https://keras.io/examples/vision/bit/.
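
One consequence of the zero initialization can be checked by hand. The sketch below uses only the standard library and assumes the `Dense` layer's bias also starts at zero, which is the Keras default. Under that assumption, the untrained model assigns the same probability to all 43 classes.

```python
import math

NUM_CLASSES = 43

# With an all-zero kernel and an all-zero bias, every logit is 0
# regardless of the 2048-dimensional feature vector coming from BiT.
logits = [0.0] * NUM_CLASSES

# Softmax of equal logits is the uniform distribution.
exps = [math.exp(z) for z in logits]
probs = [e / sum(exps) for e in exps]

# Sparse categorical cross-entropy for any true class is -log(1/43).
initial_loss = -math.log(probs[0])

print(f"initial probability per class: {probs[0]:.4f}")
print(f"expected initial loss: {initial_loss:.4f}")
```

A training loss close to ln(43), roughly 3.76, in the very first training steps is therefore expected rather than a sign of a problem.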

In [None]:
inputs = keras.Input(shape=INPUT_IMAGE_SIZE+[3])
x = layers.Rescaling(scale=1./255)(inputs)

x = layers.RandomCrop(75, 75)(x)
x = layers.RandomContrast(0.1)(x)

x = bit_model(x)

outputs = layers.Dense(43, kernel_initializer="zeros",
                       activation='softmax')(x)

model = keras.Model(inputs=inputs, outputs=outputs,
                    name="gtsrb-bit")

In [None]:
learning_rate, momentum = 0.003, 0.9

optimizer = keras.optimizers.SGD(learning_rate=learning_rate,
                                 momentum=momentum)
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=False)

model.compile(loss=loss_fn, optimizer=optimizer,
              metrics=['accuracy'])

# summary() prints the table itself and returns None, so no print() is needed.
model.summary()

In [None]:
plot_model(model, 'tf2-gtsrb-bit.png', show_shapes=True)

### Learning

We'll set up two callbacks. *EarlyStopping* is used to stop training when the monitored metric has stopped improving. *TensorBoard* is used to visualize our progress during training.
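
To illustrate what `patience` means, here is a simplified pure-Python sketch of the stopping rule. It is not the Keras implementation (it ignores options such as `min_delta` and `baseline`), and the validation accuracies are made-up numbers. Epochs are counted from zero.

```python
def simulate_early_stopping(val_accuracy, patience):
    """Return (stop_epoch, best_epoch) for per-epoch validation accuracies.

    Simplified illustration only; hypothetical helper, not part of Keras.
    """
    best, best_epoch, wait = float("-inf"), None, 0
    for epoch, acc in enumerate(val_accuracy):
        if acc > best:
            # Strict improvement: remember it and reset the counter.
            best, best_epoch, wait = acc, epoch, 0
        else:
            # No improvement: stop once this has happened `patience` times in a row.
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    # Ran through all epochs without triggering the stopping rule.
    return len(val_accuracy) - 1, best_epoch


hypothetical_val_accuracy = [0.80, 0.90, 0.92, 0.91, 0.92, 0.90, 0.91, 0.89]
stop_epoch, best_epoch = simulate_early_stopping(hypothetical_val_accuracy,
                                                 patience=4)
print("training stops after epoch", stop_epoch,
      "and the weights from epoch", best_epoch, "would be restored")
```

In this sketch, matching the previous best value does not count as an improvement, so the counter keeps running until a strictly better validation accuracy is seen.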

In [None]:
logdir = os.path.join(
    os.getcwd(), "logs",
    "gtsrb-bit-"+datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
print('TensorBoard log directory:', logdir)
os.makedirs(logdir)

callbacks = [
    keras.callbacks.EarlyStopping(
        monitor="val_accuracy", patience=4, restore_best_weights=True),
    TensorBoard(log_dir=logdir)]

In [None]:
%%time

EPOCHS = 20

# The dataset is already batched, so `batch_size` is not passed to fit().
history = model.fit(train_dataset,
                    epochs=EPOCHS,
                    validation_data=validation_dataset,
                    callbacks=callbacks)
model.save("gtsrb-bit.h5")

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10,3))

ax1.plot(history.epoch,history.history['loss'], label='training')
ax1.plot(history.epoch,history.history['val_loss'], label='validation')
ax1.set_title('loss')
ax1.set_xlabel('epoch')
ax1.legend(loc='best')

ax2.plot(history.epoch,history.history['accuracy'], label='training')
ax2.plot(history.epoch,history.history['val_accuracy'], label='validation')
ax2.set_title('accuracy')
ax2.set_xlabel('epoch')
ax2.legend(loc='best');

### Inference

In [None]:
%%time

scores = model.evaluate(test_dataset, verbose=2)
print("Test set %s: %.2f%%" % (model.metrics_names[1], scores[1]*100))