Title: Image classification with Vision Transformer
Description: Implementing the Vision Transformer (ViT) model for image classification.

## Introduction

This example implements the [Vision Transformer (ViT)](https://arxiv.org/abs/2010.11929)
model by Alexey Dosovitskiy et al. for image classification,
and demonstrates it on the so2sat dataset.
The ViT model applies the Transformer architecture with self-attention to sequences of
image patches, without using convolution layers.

This example requires TensorFlow 2.4 or higher, as well as
[TensorFlow Addons](https://www.tensorflow.org/addons/overview),
which can be installed using the following command:

```shell
pip install -U tensorflow-addons
```

## Setup

In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa

In [2]:
import tensorflow_datasets as tfds

In [45]:
from tensorflow import __version__ as tf_ver

# Compare (major, minor) as a tuple so that, e.g., version 3.0 also passes the check.
major, minor = [int(t) for t in str(tf_ver).split('.')[:2]]

assert (major, minor) >= (2, 4), f"This example requires TensorFlow 2.4 or higher, but the installed version is {tf_ver} "

AssertionError: This example requires TensorFlow 2.4 or higher, but the installed version is 2.3.1 
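
Version strings are not always three plain integers; pre-release builds can carry suffixes such as `rc1`. The following is a minimal sketch of a more tolerant check, assuming only the leading `major.minor` pair matters. The helper names `parse_major_minor` and `meets_minimum` are hypothetical and not part of TensorFlow.

```python
import re


def parse_major_minor(version: str) -> tuple:
    """Return (major, minor) from a version string such as '2.4.0-rc1'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def meets_minimum(version: str, minimum=(2, 4)) -> bool:
    # Tuples compare element by element, so (3, 0) >= (2, 4) is True,
    # whereas testing `major >= 2 and minor >= 4` would reject 3.0.
    return parse_major_minor(version) >= minimum


for candidate in ["2.3.1", "2.4.0-rc1", "3.0.0"]:
    print(candidate, meets_minimum(candidate))
```

In a notebook, `meets_minimum(tf.__version__)` could then feed the assertion directly.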

## Prepare the data

### Construct a tf.data.Dataset

In [3]:
# ! sudo umount /mnt/tb && sudo rmdir mnt/tb && sudo rmdir mnt

In [4]:
# ! sudo mkdir -p /mnt/tb && sudo mount -o rw /dev/xvdg /mnt/tb && sudo chmod 777 /mnt/tb

In [5]:
from shutil import disk_usage
from pathlib import Path

downloaded_samples = 7.5e3
downloaded_gb = 7.7 # so, about 1 GB per 1k samples, or 1 MB per sample.
total_samples = 53e3
total_gb = total_samples * downloaded_gb // downloaded_samples # this is how much free space we need to download the so2sat dataset

data_dir = Path("/mnt/tb/")
fs = data_dir
_, _, free_bytes = disk_usage(fs)

giga = 1 << 30
free_gb = free_bytes // giga
assert total_gb < (free_gb - 1), f"need {total_gb} GB of disk space, but only {free_gb} GB is free on {fs}"
print(f"{free_gb} GB free on {fs}")

920 GB free on /mnt/tb
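
The space estimate is a simple proportion: scale the space used by the samples downloaded so far up to the full sample count. The sketch below reworks that arithmetic with the same numbers as the cell above; rounding up with `math.ceil` instead of flooring is an assumption made here to keep the estimate conservative.

```python
import math

downloaded_samples = 7.5e3   # samples fetched so far
downloaded_gb = 7.7          # disk space those samples occupied
total_samples = 53e3         # sample count used in the estimate

gb_per_sample = downloaded_gb / downloaded_samples
# Round up so the estimate errs on the side of needing more space.
estimated_gb = math.ceil(total_samples * gb_per_sample)

print(f"{gb_per_sample * 1024:.2f} MB per sample")
print(f"about {estimated_gb} GB needed")
```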


In [6]:
! du /mnt/tb -h -d2

0	/mnt/tb/downloads/extracted
52G	/mnt/tb/downloads
26G	/mnt/tb/so2sat/all
26G	/mnt/tb/so2sat
78G	/mnt/tb


In [7]:
# Dataset: https://www.tensorflow.org/datasets/catalog/so2sat (53k samples)
# Website: http://www.so2sat.eu/
# Paper:   https://arxiv.org/abs/1912.12171
# Code:    https://github.com/zhu-xlab/So2Sat-LCZ42#description-of-the-content-of-sen1

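import tempfile

# Rebinding an imported name does not change the module, so set tempfile.tempdir directly.
tmp_path = Path(data_dir, "tmp")
tmp_path.mkdir(parents=True, exist_ok=True)
tempfile.tempdir = tmp_path.as_posix()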
train_ds, test_ds = tfds.load('so2sat/all', split=['train', 'validation'], data_dir=data_dir)
train_ds, len(train_ds), test_ds, len(test_ds)

(<PrefetchDataset shapes: {label: (), sample_id: (), sentinel1: (32, 32, 8), sentinel2: (32, 32, 10)}, types: {label: tf.int64, sample_id: tf.int64, sentinel1: tf.float32, sentinel2: tf.float32}>,
 352366,
 <PrefetchDataset shapes: {label: (), sample_id: (), sentinel1: (32, 32, 8), sentinel2: (32, 32, 10)}, types: {label: tf.int64, sample_id: tf.int64, sentinel1: tf.float32, sentinel2: tf.float32}>,
 24119)

### Build your input pipeline

In [8]:
%matplotlib notebook
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt

In [9]:
n = 4
plt.figure(figsize=(n, n))

# train_ds = train_ds_pref.shuffle(1024).batch(32).prefetch(tf.data.AUTOTUNE)
x_for_plot = train_ds.shuffle(1024).batch(32).prefetch(1024)
for i, example in enumerate(x_for_plot.take(n*n)):
  ax = plt.subplot(n, n, i + 1)
  image, label = example["sentinel1"], example["label"]
  plt.imshow(image[0,:,:,0:3].numpy().astype("float32"))
  ax.title.set_text(label[0].numpy())
plt.show()

TclError: couldn't connect to display "localhost:10.0"

In [25]:
from tqdm import tqdm

num_classes = 17
input_shape = (32, 32, 8)

y_train, ids_train, x_train = list(zip(*[[ex['label'], ex['sample_id'], ex['sentinel1']] for ex in tqdm(tfds.as_numpy(train_ds), total=len(train_ds))]))
y_test, ids_test, x_test = list(zip(*[[ex['label'], ex['sample_id'], ex['sentinel1']] for ex in tqdm(tfds.as_numpy(test_ds), total=len(test_ds))]))

print(type(x_train), len(x_train), x_train[0].shape)

from numpy import array
y_train, ids_train, x_train = array(y_train), array(ids_train), array(x_train)
y_test, ids_test, x_test = array(y_test), array(ids_test), array(x_test)


print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")

100%|██████████| 352366/352366 [01:40<00:00, 3521.73it/s]
100%|██████████| 24119/24119 [00:06<00:00, 3588.58it/s]


<class 'tuple'> 352366 (32, 32, 8)
x_train shape: (352366, 32, 32, 8) - y_train shape: (352366,)
x_test shape: (24119, 32, 32, 8) - y_test shape: (24119,)


## Configure the hyperparameters

In [26]:
learning_rate = 0.001
weight_decay = 0.0001
batch_size = 256
num_epochs = 100
image_size = 72  # We'll resize input images to this size
patch_size = 6  # Size of the patches to be extracted from the input images
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 4
transformer_units = [
    projection_dim * 2,
    projection_dim,
]  # Size of the transformer layers
transformer_layers = 8
mlp_head_units = [2048, 1024]  # Size of the dense layers of the final classifier
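
These settings fix several tensor sizes used later. As a quick worked example, the sketch below derives them by hand; the channel count of 8 is taken from `input_shape`, and the variable names `patches_per_side`, `elements_per_patch`, and `flattened_size` are introduced here only for illustration.

```python
image_size = 72
patch_size = 6
channels = 8          # sentinel1 has 8 channels (see input_shape)
projection_dim = 64

patches_per_side = image_size // patch_size
num_patches = patches_per_side ** 2
# Each patch is flattened into one vector before projection.
elements_per_patch = patch_size * patch_size * channels
# Size of the representation if every encoded patch is flattened together.
flattened_size = num_patches * projection_dim

print(num_patches, elements_per_patch, flattened_size)
```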

## Use data augmentation

In [34]:
data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(factor=0.02),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)
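
To make the normalization step concrete, here is a NumPy sketch of what adapting and applying a per-channel normalization amounts to. It assumes statistics are computed per channel (the last axis) and uses synthetic data in place of `x_train`; the small epsilon is an illustrative choice, not the layer's exact behavior.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic stand-in for x_train: 100 images of 32x32 with 8 channels.
x = rng.normal(loc=3.0, scale=2.0, size=(100, 32, 32, 8)).astype("float32")

# "Adapt": one mean and one variance per channel,
# reduced over the batch, height, and width axes.
mean = x.mean(axis=(0, 1, 2))
var = x.var(axis=(0, 1, 2))

# "Apply": shift and scale each channel to zero mean and unit variance.
normalized = (x - mean) / np.sqrt(var + 1e-7)

print(mean.shape, normalized.shape)
```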

## Implement multilayer perceptron (MLP)

In [35]:
def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x
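
The MLP uses the GELU activation. For reference, the sketch below evaluates the exact, erf-based form of GELU in plain Python, on the assumption that `tf.nn.gelu` is called with its default (non-approximate) setting as in the cell above.

```python
import math


def gelu(x: float) -> float:
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


for value in (-2.0, -1.0, 0.0, 1.0, 2.0):
    print(f"gelu({value:+.1f}) = {gelu(value):+.4f}")
```

Unlike ReLU, GELU lets small negative inputs through with a small negative output rather than clamping them to zero.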

## Implement patch creation as a layer

In [36]:
class Patches(layers.Layer):
    def __init__(self, patch_size):
        super(Patches, self).__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches
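
The same non-overlapping split can be sketched in NumPy with a reshape and a transpose, which may make the resulting shapes easier to follow. This is an illustrative equivalent under the assumption that patch size equals stride, not the TensorFlow implementation; the function name `extract_patches` is reused here only for readability.

```python
import numpy as np


def extract_patches(images: np.ndarray, patch_size: int) -> np.ndarray:
    """Split (batch, H, W, C) images into flattened non-overlapping patches."""
    batch, height, width, channels = images.shape
    rows, cols = height // patch_size, width // patch_size
    # Drop any remainder at the bottom/right edge, like padding="VALID".
    images = images[:, : rows * patch_size, : cols * patch_size, :]
    grid = images.reshape(batch, rows, patch_size, cols, patch_size, channels)
    # Bring the two grid axes together, then the two within-patch axes.
    grid = grid.transpose(0, 1, 3, 2, 4, 5)
    return grid.reshape(batch, rows * cols, patch_size * patch_size * channels)


images = np.arange(2 * 72 * 72 * 8, dtype="float32").reshape(2, 72, 72, 8)
patches = extract_patches(images, patch_size=6)
print(patches.shape)
```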

Let's display patches for a sample image:

In [37]:
import matplotlib.pyplot as plt

plt.figure(figsize=(4, 4))
image = x_train[np.random.choice(range(x_train.shape[0]))]
# sentinel1 images have 8 float channels; show the first three, as in the earlier plot.
plt.imshow(image[:, :, 0:3].astype("float32"))
plt.axis("off")

resized_image = tf.image.resize(
    tf.convert_to_tensor([image]), size=(image_size, image_size)
)
patches = Patches(patch_size)(resized_image)
print(f"Image size: {image_size} X {image_size}")
print(f"Patch size: {patch_size} X {patch_size}")
print(f"Patches per image: {patches.shape[1]}")
print(f"Elements per patch: {patches.shape[-1]}")

n = int(np.sqrt(patches.shape[1]))
plt.figure(figsize=(4, 4))
for i, patch in enumerate(patches[0]):
    ax = plt.subplot(n, n, i + 1)
    # Each flattened patch holds patch_size * patch_size * channels values.
    patch_img = tf.reshape(patch, (patch_size, patch_size, -1))
    plt.imshow(patch_img[:, :, 0:3].numpy().astype("float32"))
    plt.axis("off")

TclError: couldn't connect to display "localhost:10.0"

## Implement the patch encoding layer

The `PatchEncoder` layer will linearly transform a patch by projecting it into a
vector of size `projection_dim`. In addition, it adds a learnable position
embedding to the projected vector.

In [38]:
class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super(PatchEncoder, self).__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded
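
As a shape-level illustration of the encoding step, the NumPy sketch below projects each flattened patch and adds one position vector per patch index. The parameter arrays are random stand-ins for the weights the layer would learn, and the sizes (144 patches of 288 values, projected to 64) follow the hyperparameters used in this example.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, num_patches, patch_dims, projection_dim = 2, 144, 288, 64

patches = rng.normal(size=(batch, num_patches, patch_dims))

# Stand-ins for the learned parameters (random here, trained in the real layer).
kernel = rng.normal(size=(patch_dims, projection_dim))
bias = np.zeros(projection_dim)
position_table = rng.normal(size=(num_patches, projection_dim))

positions = np.arange(num_patches)
projected = patches @ kernel + bias              # (batch, num_patches, projection_dim)
encoded = projected + position_table[positions]  # broadcast over the batch axis

print(encoded.shape)
```

Every image in the batch receives the same position vectors, so the model can tell patch 0 from patch 143 regardless of content.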

## Build the ViT model

The ViT model consists of multiple Transformer blocks,
which use the `layers.MultiHeadAttention` layer as a self-attention mechanism
applied to the sequence of patches. The Transformer blocks produce a
`[batch_size, num_patches, projection_dim]` tensor, which is processed by a
classifier head that outputs one logit per class; applying softmax to these logits gives the final class probabilities.

Unlike the technique described in the [paper](https://arxiv.org/abs/2010.11929),
which prepends a learnable embedding to the sequence of encoded patches to serve
as the image representation, all the outputs of the final Transformer block are
reshaped with `layers.Flatten()` and used as the image
representation input to the classifier head.
Note that the `layers.GlobalAveragePooling1D` layer
could also be used instead to aggregate the outputs of the Transformer block,
especially when the number of patches and the projection dimensions are large.
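
To see why pooling can matter as sizes grow, the following back-of-the-envelope sketch compares the parameter count of the first classifier-head dense layer under each choice. It assumes the hyperparameters used in this example (144 patches, projection dimension 64, 2048 units in the first head layer); `dense_params` is a hypothetical helper.

```python
num_patches = 144
projection_dim = 64
first_head_units = 2048  # first entry of mlp_head_units


def dense_params(inputs: int, units: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return inputs * units + units


flatten_inputs = num_patches * projection_dim   # Flatten keeps every patch
pooled_inputs = projection_dim                  # average pooling collapses patches

print("Flatten:", dense_params(flatten_inputs, first_head_units))
print("Pooling:", dense_params(pooled_inputs, first_head_units))
```

With flattening, the first head layer alone holds roughly 18.9 million parameters, against about 133 thousand with pooling.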

In [39]:
def create_vit_classifier():
    inputs = layers.Input(shape=input_shape)
    # Augment data.
    augmented = data_augmentation(inputs)
    # Create patches.
    patches = Patches(patch_size)(augmented)
    # Encode patches.
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

    # Create multiple layers of the Transformer block.
    for _ in range(transformer_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        # Create a multi-head attention layer.
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Skip connection 1.
        x2 = layers.Add()([attention_output, encoded_patches])
        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP.
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        # Skip connection 2.
        encoded_patches = layers.Add()([x3, x2])

    # Create a [batch_size, num_patches * projection_dim] tensor.
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = layers.Flatten()(representation)
    representation = layers.Dropout(0.5)(representation)
    # Add MLP.
    features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
    # Classify outputs.
    logits = layers.Dense(num_classes)(features)
    # Create the Keras model.
    model = keras.Model(inputs=inputs, outputs=logits)
    return model
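
For intuition about the attention step inside each block, here is a deliberately simplified NumPy sketch of single-head scaled dot-product self-attention followed by a skip connection. It omits layer normalization, dropout, multiple heads, and the output projection, so it is a teaching aid under those assumptions rather than a description of what `layers.MultiHeadAttention` does internally.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    shifted = x - x.max(axis=axis, keepdims=True)  # for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def self_attention(x, w_query, w_key, w_value):
    """Single-head scaled dot-product self-attention over a patch sequence."""
    query, key, value = x @ w_query, x @ w_key, x @ w_value
    scores = query @ key.transpose(0, 2, 1) / np.sqrt(key.shape[-1])
    weights = softmax(scores, axis=-1)       # (batch, patches, patches)
    return weights @ value, weights


rng = np.random.default_rng(0)
batch, num_patches, dim = 2, 144, 64
x = rng.normal(size=(batch, num_patches, dim))
# Random stand-ins for the learned query/key/value projections.
w_q, w_k, w_v = (rng.normal(size=(dim, dim)) * dim ** -0.5 for _ in range(3))

attended, weights = self_attention(x, w_q, w_k, w_v)
residual = x + attended   # skip connection, as with `layers.Add()`
print(residual.shape, weights.shape)
```

Each row of `weights` sums to one, so every output patch is a weighted average of all value vectors in the sequence.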

## Compile, train, and evaluate the model

In [40]:
def run_experiment(model):
    optimizer = tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    )

    model.compile(
        optimizer=optimizer,
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
        ],
    )

    checkpoint_filepath = "/tmp/checkpoint"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )

    history = model.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_split=0.1,
        callbacks=[checkpoint_callback],
    )

    model.load_weights(checkpoint_filepath)
    _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

    return history


vit_classifier = create_vit_classifier()
history = run_experiment(vit_classifier)
history

AttributeError: module 'tensorflow.keras.layers' has no attribute 'MultiHeadAttention'
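
Because the model emits raw logits and the loss is built with `from_logits=True`, the softmax is applied inside the loss. The plain-Python sketch below shows, for a single made-up example, how logits map to a cross-entropy value and how top-k accuracy is decided. The helper names and the logit values are hypothetical.

```python
import math


def softmax(logits):
    peak = max(logits)                       # subtract the max for stability
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_cross_entropy(logits, label):
    """Negative log-probability of the true class, computed from logits."""
    return -math.log(softmax(logits)[label])


def in_top_k(logits, label, k=5):
    """True if the true class is among the k highest-scoring classes."""
    ranked = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    return label in ranked[:k]


logits = [0.5, 2.0, -1.0, 0.0, 1.5, 0.2, -0.3]   # made-up scores for 7 classes
label = 4
print(sparse_cross_entropy(logits, label))
print(in_top_k(logits, label, k=1), in_top_k(logits, label, k=5))
```

Here the true class has the second-highest score, so it counts as a miss for plain accuracy but as a hit for top-5 accuracy.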

The accuracy figures that follow are for the CIFAR-100 dataset, not so2sat: after 100 epochs, the ViT model achieves around 55% accuracy and
82% top-5 accuracy on the CIFAR-100 test data. These are not competitive results for that dataset,
as a ResNet50V2 trained from scratch on the same data can achieve 67% accuracy.

Note that the state-of-the-art results reported in the
[paper](https://arxiv.org/abs/2010.11929) are achieved by pre-training the ViT model using
the JFT-300M dataset, then fine-tuning it on the target dataset. To improve the model quality
without pre-training, you can try to train the model for more epochs, use a larger number of
Transformer layers, resize the input images, change the patch size, or increase the projection dimensions.
In addition, as mentioned in the paper, the quality of the model is affected not only by architecture choices,
but also by parameters such as the learning rate schedule, optimizer, and weight decay.
In practice, it's recommended to fine-tune a ViT model
that was pre-trained using a large, high-resolution dataset.
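
As one illustration of a learning rate schedule, the sketch below implements linear warmup followed by cosine decay in plain Python. This example itself trains with a constant learning rate, so the schedule shape, the function name `warmup_cosine`, and the `warmup_steps` value are all assumptions chosen for illustration.

```python
import math


def warmup_cosine(step, total_steps, base_lr=0.001, warmup_steps=1000):
    """Linear warmup to base_lr, then cosine decay to zero."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * min(1.0, progress)))


total_steps = 10_000
for step in (0, 500, 999, 5_500, 10_000):
    print(step, f"{warmup_cosine(step, total_steps):.6f}")
```

A function like this could be wrapped in a Keras learning rate schedule or callback; how best to do that depends on the TensorFlow version in use.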