# Image Classification with Vision Transformers

```{article-info}
:avatar: https://avatars.githubusercontent.com/u/25820201?v=4
:avatar-link: https://github.com/PhotonicGluon/
:author: "[Ryan Kan](https://github.com/PhotonicGluon/)"
:date: "Jul 9, 2024"
:read-time: "{sub-ref}`wordcount-minutes` min read"
```

*This notebook is largely inspired by the Keras code example [Image classification with Vision Transformer](https://keras.io/examples/vision/image_classification_with_vision_transformer/) by [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/).*

<center>
    <img alt="CIFAR-100 Dataset" style="width: 25%" src="https://storage.googleapis.com/kaggle-datasets-images/1059701/1782442/763cab8e6130dad7ff7102abdfef54a0/dataset-card.jpg">
</center>

In this example, we will classify images in the [CIFAR-100 dataset](https://www.cs.toronto.edu/~kriz/cifar.html) with a [Vision Transformer (ViT)](https://arxiv.org/pdf/2010.11929v2), the model introduced by Alexey Dosovitskiy et al., built from Keras-MML layers.

:::{important}
We will be using some plotting utilities in this notebook. Run the command below to install them, then restart the kernel.
:::

```python
%pip install matplotlib~=3.9.0 seaborn~=0.13.2
```

:::{note}
We will use the `jax` backend for faster execution of the code. Feel free to ignore the cell below.
:::

```python
import os

# The backend must be chosen before Keras is imported
os.environ["KERAS_BACKEND"] = "jax"
```

## Preparing the Data

Conveniently, the CIFAR-100 dataset is already available in Keras, so we just need to load it from there.

```python
import keras

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()
```

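Let's check the shapes of the downloaded arrays. The training set should contain 50,000 images and the test set 10,000, each with a single integer class label.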

```python
print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")
```

The CIFAR-100 dataset contains 100 distinct classes. Each image is $32 \times 32$ pixels with 3 color channels, so the `input_shape` for our model is `(32, 32, 3)`.

```python
num_classes = 100  # Number of CIFAR-100 classes
input_shape = (32, 32, 3)  # Height, width, and channels of each raw image
```

Before patching, we upscale the images to `image_size` × `image_size` pixels so that each image yields more *patches* for the ViT to learn from.

```python
image_size = 72  # Images are resized to 72 x 72 before patching
```

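To help the model generalize, we build a preprocessing pipeline that normalizes each image, resizes it to `image_size` × `image_size`, and then applies data augmentation in the form of random horizontal flips, small rotations, and zooms. Keras only applies the random augmentation layers during training; at inference time they pass images through unchanged.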

```python
from keras import layers

data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomFlip("horizontal"),
        layers.RandomRotation(factor=0.02),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="data_augmentation",
)

# Compute the mean and the variance of the training data for normalization
data_augmentation.layers[0].adapt(x_train)
```


2024-07-09 08:47:06.204757: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:984] could not open file to read NUMA node: /sys/bus/pci/devices/0000:01:00.0/numa_node
Your kernel may have been built without NUMA support.
2024-07-09 08:47:06.205293: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2251] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


NameError: name 'x_train' is not defined
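The `Normalization` layer's `adapt()` call computes statistics from the training data. With its default `axis=-1`, Keras keeps one mean and one variance per channel and, when called, standardizes each value as roughly $(x - \mu) / \sigma$. The NumPy sketch below reproduces that idea on random stand-in data; `fake_images`, `adapt_mean_variance`, and `normalize` are hypothetical names for illustration, not Keras APIs.

```python
import numpy as np

def adapt_mean_variance(images: np.ndarray) -> tuple:
    """Hypothetical helper: per-channel mean and variance over all other axes."""
    axes = tuple(range(images.ndim - 1))  # every axis except the last (channel) one
    return images.mean(axis=axes), images.var(axis=axes)

def normalize(images: np.ndarray, mean: np.ndarray, variance: np.ndarray, eps: float = 1e-7) -> np.ndarray:
    """Standardize each channel, guarding against a zero standard deviation."""
    return (images - mean) / np.maximum(np.sqrt(variance), eps)

# Random stand-in for a batch of 32 x 32 RGB images with values in [0, 255]
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(100, 32, 32, 3)).astype("float64")

mean, variance = adapt_mean_variance(fake_images)
normalized = normalize(fake_images, mean, variance)
print(mean.shape)                       # (3,): one statistic per channel
print(normalized.mean(axis=(0, 1, 2)))  # close to 0 for each channel
print(normalized.std(axis=(0, 1, 2)))   # close to 1 for each channel
```

Keeping statistics per channel means the red, green, and blue channels are each centered and scaled independently.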

## Model Creation

The ViT works in four steps:

1. Split each image into patches.
2. Embed each patch as a vector.
3. Pass the patch embeddings through a stack of transformer blocks, which refine each embedding using information from all the other patches.
4. Use the refined embeddings for the classification task.

Keras-MML implements the layers needed for steps 1 to 3.

```python
import keras_mml
```

Let's define the size of the patches that we want.

```python
patch_size = 6  # Each patch is 6 x 6 pixels
num_patches = (image_size // patch_size) ** 2  # Non-overlapping patches per image
```
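The patch count follows from integer division: each side of an $s \times s$ image holds $\lfloor s / p \rfloor$ patches of size $p$, giving $\lfloor s / p \rfloor^2$ patches in total. The short sketch below assumes non-overlapping patches with any leftover border pixels discarded (as in the Keras example this notebook builds on), and shows why resizing from 32 to 72 pixels helps. The helper names `count_patches` and `values_per_patch` are hypothetical.

```python
def count_patches(image_size: int, patch_size: int) -> int:
    """Number of non-overlapping patch_size x patch_size patches in a square image."""
    per_side = image_size // patch_size  # pixels that do not fill a whole patch are dropped
    return per_side ** 2

def values_per_patch(patch_size: int, channels: int = 3) -> int:
    """Length of one flattened patch."""
    return patch_size * patch_size * channels

print(count_patches(32, 6))  # 25  (5 x 5 grid; 2 rows and 2 columns left over)
print(count_patches(72, 6))  # 144 (12 x 12 grid, no leftover pixels)
print(values_per_patch(6))   # 108
```

At the native resolution, 6-pixel patches would give the transformer only 25 tokens per image; at 72 pixels it gets 144.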

Let's split a sample image into patches with the `Patches` layer and display them.

```python
import matplotlib.pyplot as plt
import numpy as np
from keras import ops

# Show the original 32 x 32 image
plt.figure(figsize=(4, 4))
image = x_train[0]  # Just as an example
plt.imshow(image.astype("uint8"))
plt.axis("off")

# Resize to image_size x image_size, then split into patches
resized_image = ops.image.resize(
    ops.convert_to_tensor([image]), size=(image_size, image_size)
)
patches = keras_mml.layers.Patches(patch_size)(resized_image)  # Patch generation layer
print(f"Image size: {image_size} X {image_size}")
print(f"Patch size: {patch_size} X {patch_size}")
print(f"Patches per image: {patches.shape[1]}")
print(f"Elements per patch: {patches.shape[-1]}")

# Show the patches in an n x n grid
n = int(np.sqrt(patches.shape[1]))
plt.figure(figsize=(4, 4))
for i, patch in enumerate(patches[0]):
    ax = plt.subplot(n, n, i + 1)
    patch_img = ops.reshape(patch, (patch_size, patch_size, 3))  # Make it back into RGB
    plt.imshow(ops.convert_to_numpy(patch_img).astype("uint8"))
    plt.axis("off")
```

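Conceptually, splitting an image into non-overlapping patches is just a reshape followed by a transpose. The NumPy sketch below is an illustration, not Keras-MML's actual `Patches` implementation; it assumes each patch is flattened in row, column, channel order, which is consistent with the display code reshaping each patch back to `(patch_size, patch_size, 3)`. The function name `extract_patches` is hypothetical here.

```python
import numpy as np

def extract_patches(images: np.ndarray, patch_size: int) -> np.ndarray:
    """Split (batch, H, W, C) images into (batch, num_patches, patch_size * patch_size * C)."""
    batch, height, width, channels = images.shape
    rows, cols = height // patch_size, width // patch_size
    # Drop any border pixels that do not fill a whole patch
    images = images[:, : rows * patch_size, : cols * patch_size, :]
    # (batch, rows, patch, cols, patch, C) -> (batch, rows, cols, patch, patch, C)
    patches = images.reshape(batch, rows, patch_size, cols, patch_size, channels)
    patches = patches.transpose(0, 1, 3, 2, 4, 5)
    # Flatten each patch into one vector; patches are ordered row by row across the image
    return patches.reshape(batch, rows * cols, patch_size * patch_size * channels)

images = np.arange(1 * 72 * 72 * 3).reshape(1, 72, 72, 3)
patches = extract_patches(images, patch_size=6)
print(patches.shape)  # (1, 144, 108)
```

Patch 0 covers the top-left 6 × 6 block, patch 1 the block to its right, and patch 12 starts the second row of blocks.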

To embed the patches, we use the `PatchEmbedding` layer, which is also included in Keras-MML. This layer encodes each patch as a `projection_dim`-dimensional vector that the subsequent transformer blocks take as input.

```python
projection_dim = 64  # Size of each patch embedding vector
```
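We do not reproduce Keras-MML's `PatchEmbedding` here. As an assumption, the sketch below follows the `PatchEncoder` from the Keras example this notebook is based on: each flattened patch is linearly projected to `projection_dim` values, and a per-position vector is added so the model knows where each patch came from. The weights are random placeholders (a real layer learns them), and Keras-MML's layer may parameterize the projection differently.

```python
import numpy as np

rng = np.random.default_rng(0)
num_patches, patch_dim, projection_dim = 144, 108, 64  # 72x72 image, 6x6 RGB patches

# Placeholder parameters; a real layer would learn these during training
projection = rng.normal(scale=0.02, size=(patch_dim, projection_dim))
bias = np.zeros(projection_dim)
position_table = rng.normal(scale=0.02, size=(num_patches, projection_dim))

def embed_patches(patches: np.ndarray) -> np.ndarray:
    """(batch, num_patches, patch_dim) -> (batch, num_patches, projection_dim)."""
    projected = patches @ projection + bias  # same projection applied to every patch
    # Row k of the position table is added to patch k of every image in the batch
    return projected + position_table[np.newaxis, :, :]

patches = rng.random((2, num_patches, patch_dim))
embeddings = embed_patches(patches)
print(embeddings.shape)  # (2, 144, 64)
```

Without the position term, the transformer would see the patches as an unordered set, since self-attention itself does not encode position.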

We are now ready to create the full model. The `transformer_layers` hyperparameter sets the number of transformer blocks in the ViT, and `num_heads` sets the number of attention heads in each block.

```python
transformer_layers = 8  # Number of stacked transformer blocks
num_heads = 4  # Attention heads per transformer block
```
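The core of each transformer block is multi-head self-attention. The NumPy sketch below shows the standard scaled dot-product form from the original Transformer: the `projection_dim` features are split across `num_heads` heads of size 64 / 4 = 16, every patch attends to every other patch, and the heads are merged back. A full transformer block typically also includes layer normalization, residual connections, and a feed-forward network, which are omitted here; `TransformerBlockMML`'s exact internals are not shown in this notebook and may differ. All weights are random placeholders.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_self_attention(x, w_q, w_k, w_v, w_o, num_heads):
    """x: (batch, seq, dim); all weight matrices: (dim, dim)."""
    batch, seq, dim = x.shape
    head_dim = dim // num_heads

    def split_heads(t):
        # (batch, seq, dim) -> (batch, heads, seq, head_dim)
        return t.reshape(batch, seq, num_heads, head_dim).transpose(0, 2, 1, 3)

    q, k, v = split_heads(x @ w_q), split_heads(x @ w_k), split_heads(x @ w_v)
    # Every patch attends to every patch: weights have shape (batch, heads, seq, seq)
    weights = softmax(q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_dim), axis=-1)
    context = weights @ v  # (batch, heads, seq, head_dim)
    # Merge the heads back to (batch, seq, dim), then mix them with the output projection
    context = context.transpose(0, 2, 1, 3).reshape(batch, seq, dim)
    return context @ w_o, weights

rng = np.random.default_rng(0)
projection_dim, num_heads = 64, 4
w_q, w_k, w_v, w_o = (rng.normal(scale=0.1, size=(projection_dim, projection_dim)) for _ in range(4))
x = rng.normal(size=(2, 144, projection_dim))  # 2 images, 144 patch embeddings each
out, attn = multi_head_self_attention(x, w_q, w_k, w_v, w_o, num_heads)
print(out.shape, attn.shape)  # (2, 144, 64) (2, 4, 144, 144)
```

Because the output keeps the `(batch, num_patches, projection_dim)` shape, blocks like this can be stacked `transformer_layers` times.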

```python
model = keras.models.Sequential()
model.add(layers.Input(shape=input_shape))

# Normalize, resize, and augment the data
model.add(data_augmentation)

# Create patches
model.add(keras_mml.layers.Patches(patch_size))

# Create patch embeddings
model.add(keras_mml.layers.PatchEmbedding(num_patches, projection_dim, with_positions=True))

# Use multiple transformer blocks
for _ in range(transformer_layers):
    model.add(keras_mml.layers.TransformerBlockMML(projection_dim, projection_dim * 2, num_heads, rate=0.1))

# Normalize, flatten, and apply dropout
model.add(layers.LayerNormalization(epsilon=1e-6))
model.add(layers.Flatten())
model.add(layers.Dropout(0.5))

# Add SwiGLUMML layers to refine the features before classification
model.add(keras_mml.layers.SwiGLUMML(1024))
model.add(keras_mml.layers.SwiGLUMML(256))

# Final classification head; outputs one raw logit per class (no softmax)
model.add(layers.Dense(num_classes))

model.summary()
```

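SwiGLU is the gated feed-forward unit described in Noam Shazeer's *GLU Variants Improve Transformer*: the input goes through two linear maps, one of them passed through the Swish (SiLU) activation, and their element-wise product is projected again. The sketch below shows that standard full-precision formulation. How `SwiGLUMML` parameterizes its weights, and exactly what its `1024` / `256` argument controls, are assumptions here; we treat the argument as the output width. Toy sizes keep the example light, whereas in the model above the first `SwiGLUMML` receives the flattened $144 \times 64 = 9216$ values.

```python
import numpy as np

def swish(x: np.ndarray) -> np.ndarray:
    """Swish / SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))

def swiglu(x, w_gate, w_value, w_out):
    """Standard SwiGLU: (swish(x @ w_gate) * (x @ w_value)) @ w_out."""
    gated = swish(x @ w_gate) * (x @ w_value)  # element-wise gate
    return gated @ w_out

rng = np.random.default_rng(0)
in_dim, hidden_dim, out_dim = 32, 48, 16  # toy sizes, not the model's real widths
w_gate = rng.normal(scale=0.1, size=(in_dim, hidden_dim))
w_value = rng.normal(scale=0.1, size=(in_dim, hidden_dim))
w_out = rng.normal(scale=0.1, size=(hidden_dim, out_dim))

x = rng.normal(size=(4, in_dim))  # a batch of 4 flattened feature vectors
y = swiglu(x, w_gate, w_value, w_out)
print(y.shape)  # (4, 16)
```

The gate lets the layer scale each hidden feature by a smooth, input-dependent factor instead of hard-thresholding it as ReLU would.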


TODO: CONTINUE