# Imports and Setup

In [None]:
! pip install -q -U tensorflow-addons

In [None]:
import tensorflow as tf
tf.keras.utils.set_random_seed(42)

import tensorflow_addons as tfa
import tensorflow_datasets as tfds

from tensorflow import keras
from tensorflow.keras import layers

import matplotlib.pyplot as plt
import numpy as np

import math

# Hyperparameters

In [None]:
# Mount Google Drive inside the Colab runtime at /content/drive.
# DATA_DIR (defined in the hyperparameter cell) points below this mount point.
from google.colab import drive
drive.mount("/content/drive")

In [None]:
# DATA
# Directory passed to `tfds.load` as `data_dir`.
DATA_DIR = "/content/drive/MyDrive/Colab Notebooks/ViT"
# Buffer size used when shuffling the training split.
BUFFER_SIZE = 1024
BATCH_SIZE = 256
AUTO = tf.data.AUTOTUNE
# Shape of the raw input images; INPUT_SHAPE[0] is used as the side length
# by the training augmentation pipeline.
INPUT_SHAPE = (32, 32, 3)
NUM_CLASSES = 10

# OPTIMIZER
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4

# TRAINING
EPOCHS = 20

# AUGMENTATION
IMAGE_SIZE = 48  # We will resize input images to this size.
PATCH_SIZE = 6  # Size of the patches to be extracted from the input images.
# Number of non-overlapping PATCH_SIZE x PATCH_SIZE patches per image.
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2

# ViT ARCHITECTURE HYPERPARAMETERS
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 512
NUM_HEADS = 4
NUM_LAYERS = 4
# Widths of the per-block MLP. The last entry equals PROJECTION_DIM so the
# MLP output can be added back onto the residual stream.
MLP_UNITS = [
    PROJECTION_DIM * 2,
    PROJECTION_DIM,
]
# NOTE(review): not referenced by any code visible in this file; the
# classification head hard-codes its own hidden width.
MLP_HEAD_UNITS = [
    2048,
    1024,
]

In [None]:
# Load CIFAR-10 through TFDS, using DATA_DIR as the dataset directory.
# The first 90% of the "train" split is used for training and the remaining
# 10% for validation; `as_supervised=True` yields (image, label) pairs.
train_ds, val_ds, test_ds = tfds.load(
    name="cifar10",
    data_dir=DATA_DIR,
    split=["train[:90%]", "train[90%:]", "test"],
    as_supervised=True
)

In [None]:
# Only the training split is shuffled; every split is then batched and
# prefetched so input preparation can overlap with model execution.
train_ds = train_ds.shuffle(BUFFER_SIZE)
train_ds = train_ds.batch(BATCH_SIZE).prefetch(AUTO)
val_ds, test_ds = (
    split.batch(BATCH_SIZE).prefetch(AUTO) for split in (val_ds, test_ds)
)

# Augmentation

In [None]:
def get_train_augmentation_model():
    """Build the preprocessing/augmentation pipeline used on training images.

    The pipeline rescales pixel values by 1/255, resizes images to
    ``INPUT_SHAPE[0] + 20`` pixels per side, takes a random
    ``IMAGE_SIZE x IMAGE_SIZE`` crop and applies a random horizontal flip.

    Returns:
        A `keras.Sequential` model named ``"train_data_augmentation"``.
    """
    pipeline = keras.Sequential(name="train_data_augmentation")
    pipeline.add(layers.Rescaling(1 / 255.0))
    # Upscale past the target size so the random crop has room to move.
    pipeline.add(layers.Resizing(INPUT_SHAPE[0] + 20, INPUT_SHAPE[0] + 20))
    pipeline.add(layers.RandomCrop(IMAGE_SIZE, IMAGE_SIZE))
    pipeline.add(layers.RandomFlip("horizontal"))
    return pipeline


def get_test_augmentation_model():
    """Build the deterministic preprocessing pipeline used for evaluation.

    Pixel values are rescaled by 1/255 and images are resized straight to
    ``IMAGE_SIZE x IMAGE_SIZE``; no random layers are involved.

    Returns:
        A `keras.Sequential` model named ``"test_data_augmentation"``.
    """
    pipeline = keras.Sequential(name="test_data_augmentation")
    pipeline.add(layers.Rescaling(1 / 255.0))
    pipeline.add(layers.Resizing(IMAGE_SIZE, IMAGE_SIZE))
    return pipeline

# Patches

In [None]:
class Patches(layers.Layer):
    """Cut a batch of images into flattened, non-overlapping square patches.

    Args:
        patch_size: Side length (in pixels) of every square patch.
        num_channels: Channel count of the input images. Defaults to 3,
            which reproduces the previously hard-coded behaviour.
        **kwargs: Forwarded to `layers.Layer`.
    """

    def __init__(self, patch_size=PATCH_SIZE, num_channels=3, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.num_channels = num_channels

        # Each flattened patch holds patch_size * patch_size * num_channels
        # values; -1 lets the reshape infer the number of patches per image.
        self.resize = layers.Reshape(
            (-1, patch_size * patch_size * num_channels)
        )

    def call(self, images):
        """Return the patches of `images` as (batch, num_patches, patch_values)."""
        # Window size equals stride, so patches tile the image without overlap.
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )

        # Collapse the patch grid into a single sequence axis.
        return self.resize(patches)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update(
            {
                "patch_size": self.patch_size,
                "num_channels": self.num_channels,
            }
        )
        return config

    def show_patched_image(self, images, patches):
        """Plot one randomly chosen image next to the grid of its patches.

        Args:
            images: Batch of images.
            patches: Patches produced from `images` by this layer.

        Returns:
            The batch index that was plotted, so the caller can reuse it.
        """
        idx = np.random.choice(patches.shape[0])
        print(f"Index selected: {idx}.")

        plt.figure(figsize=(4, 4))
        plt.imshow(keras.utils.array_to_img(images[idx]))
        plt.axis("off")
        plt.show()

        # The patches are laid out on an n x n grid of subplots.
        n = int(np.sqrt(patches.shape[1]))
        plt.figure(figsize=(4, 4))
        for i, patch in enumerate(patches[idx]):
            plt.subplot(n, n, i + 1)
            patch_img = tf.reshape(
                patch, (self.patch_size, self.patch_size, self.num_channels)
            )
            plt.imshow(keras.utils.img_to_array(patch_img))
            plt.axis("off")
        plt.show()

        # Return the index chosen to validate it outside the method.
        return idx

    # taken from https://stackoverflow.com/a/58082878/10319735
    def reconstruct_from_patch(self, patch):
        """Stitch the patches of a *single* image back into that image.

        Args:
            patch: Flattened patches of one image, shaped
                (num_patches, patch_values).

        Returns:
            The reassembled image tensor.
        """
        num_patches = patch.shape[0]
        # Patches are assumed to form a square n x n grid.
        n = int(np.sqrt(num_patches))
        patch = tf.reshape(
            patch,
            (num_patches, self.patch_size, self.patch_size, self.num_channels),
        )
        # Split into n grid rows, join each row's patches side by side
        # (axis=1), then stack the rows vertically (axis=0).
        rows = tf.split(patch, n, axis=0)
        rows = [tf.concat(tf.unstack(x), axis=1) for x in rows]
        return tf.concat(rows, axis=0)

In [None]:
# Get a batch of images.
images, labels = next(iter(train_ds))

# Augment the images.
augmentation_model = get_train_augmentation_model()
augmented_images = augmentation_model(images)

# Define the patch layer.
patch_layer = Patches()

# Get the patches from the batched images.
patches = patch_layer(images=augmented_images)

# Now pass the images and the corresponding patches
# to the `show_patched_image` method.
random_index = patch_layer.show_patched_image(images=augmented_images, patches=patches)

# Take the same image that was just displayed and try reconstructing
# it from its patches.
image = patch_layer.reconstruct_from_patch(patches[random_index])
plt.imshow(image)
plt.axis("off")
plt.show()

In [None]:
# Compare the shape of the raw (un-augmented) batch with the shape of the
# patch tensor computed from its augmented version.
print(images.shape)
print(patches.shape)

# Patch Encoder

In [None]:
class PatchEncoder(layers.Layer):
    """Linearly project flattened patches and add learned position embeddings.

    Args:
        num_patches: Number of patches per image; also the size of the
            position-embedding table.
        patch_size: Side length of a patch. Stored on the layer but not used
            in the computation.
        projection_dim: Output width of the projection and of each position
            embedding.
        **kwargs: Forwarded to `layers.Layer` (for example ``name``).
    """

    def __init__(self, num_patches=NUM_PATCHES, patch_size=PATCH_SIZE,
        projection_dim=PROJECTION_DIM, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        self.patch_size = patch_size

        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patches):
        """Return ``projection(patches) + position_embedding(positions)``."""
        # One position index per patch; the resulting
        # (num_patches, projection_dim) embedding broadcasts over the batch.
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        return self.projection(patches) + self.position_embedding(positions)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update(
            {
                "num_patches": self.num_patches,
                "patch_size": self.patch_size,
                "projection_dim": self.projection_dim,
            }
        )
        return config

In [None]:
# Create the patch encoder layer.
patch_encoder = PatchEncoder()

# Project the patches and add the position embeddings.
patch_embeddings = patch_encoder(patches=patches)

# The last axis should now equal the projection dimension.
print(patch_embeddings.shape)

# MLP

In [None]:
def mlp(x, dropout_rate, hidden_units):
    """Apply a stack of GELU Dense layers, each followed by Dropout.

    Args:
        x: Input tensor.
        dropout_rate: Rate given to every `Dropout` layer.
        hidden_units: Iterable of layer widths, applied in order.

    Returns:
        The tensor produced by the last Dropout layer, or `x` itself when
        `hidden_units` is empty.
    """
    output = x
    for width in hidden_units:
        hidden = layers.Dense(width, activation=tf.nn.gelu)(output)
        output = layers.Dropout(dropout_rate)(hidden)
    return output

# ViT

In [None]:
class TokenLearner(layers.Layer):
    """Condense a (B, H, W, C) feature map into `number_of_tokens` tokens.

    A small convolutional stack predicts one spatial weighting map per output
    token; each token is the mean over all H*W positions of the input
    features multiplied by its map. The output has shape
    (B, number_of_tokens, C).

    Args:
        number_of_tokens: How many tokens to produce per example.
        **kwargs: Forwarded to `layers.Layer`.
    """

    def __init__(self, number_of_tokens, **kwargs):
        super().__init__(**kwargs)
        self.number_of_tokens = number_of_tokens

    def build(self, input_shape):
        """Create the sub-layers once the spatial input shape is known."""
        _, height, width, channels = input_shape

        self.layer_norm = layers.LayerNormalization(epsilon=1e-6)

        # All four convolutions share everything except their activation.
        shared_conv_args = dict(
            filters=self.number_of_tokens,
            kernel_size=(3, 3),
            padding="same",
            use_bias=False,
        )
        stack = [
            layers.Conv2D(activation=tf.nn.gelu, **shared_conv_args)
            for _ in range(3)
        ]
        # The final sigmoid keeps the weighting maps in (0, 1).
        stack.append(layers.Conv2D(activation="sigmoid", **shared_conv_args))
        # (B, H, W, T) -> (B, H*W, T) -> (B, T, H*W)
        stack.append(layers.Reshape((-1, self.number_of_tokens)))
        stack.append(layers.Permute((2, 1)))
        self.conv_block = keras.Sequential(stack)

        # Flattens the spatial axes and adds a singleton token axis so the
        # features broadcast against the per-token maps.
        self.reshape_input = layers.Reshape((1, height * width, channels))

    def call(self, inputs):
        """Return the learned tokens for `inputs` of shape (B, H, W, C)."""
        normalized = self.layer_norm(inputs)

        # One weighting map per token: (B, num_tokens, H*W).
        token_maps = self.conv_block(normalized)

        # Un-normalized features, flattened to (B, 1, H*W, C).
        flat_inputs = self.reshape_input(inputs)

        # (B, T, H*W, 1) * (B, 1, H*W, C) -> (B, T, H*W, C); average over H*W.
        weighted = tf.expand_dims(token_maps, axis=-1) * flat_inputs
        return tf.reduce_mean(weighted, axis=2)

In [None]:
# Smoke-test TokenLearner on a random (batch=2, 4x4 grid, 256-channel) map.
img = tf.random.normal((2, 4, 4, 256))
token_learner = TokenLearner(8)
out = token_learner(img)

# With 8 tokens requested this should be (2, 8, 256).
out.shape

In [None]:
def get_encoder(num_patches=NUM_PATCHES, projection_dim=PROJECTION_DIM,
    num_heads=NUM_HEADS, num_layers=NUM_LAYERS, num_tokens=8):
    """Build the Transformer encoder with a TokenLearner inserted mid-stack.

    Every block is LayerNorm -> multi-head self-attention -> residual add ->
    LayerNorm -> MLP -> residual add. After the block whose index equals
    ``num_layers // 2`` the token sequence is reshaped onto a square grid and
    reduced to `num_tokens` tokens by `TokenLearner`; later blocks operate on
    that shorter sequence.

    Args:
        num_patches: Length of the input token sequence. Must be a perfect
            square so the tokens can be arranged on a grid.
        projection_dim: Width of every token.
        num_heads: Number of attention heads per block.
        num_layers: Number of Transformer blocks.
        num_tokens: Number of tokens produced by the TokenLearner.

    Returns:
        A `keras.Model` mapping (batch, num_patches, projection_dim) inputs to
        the output of the last block.

    Raises:
        ValueError: If the sequence length reaching the TokenLearner is not a
            perfect square.
    """
    # The MLP output is added onto the residual stream, so its final width
    # must be `projection_dim`. Earlier widths still come from MLP_UNITS.
    mlp_units = list(MLP_UNITS[:-1]) + [projection_dim]

    # inputs are the encoded patches
    inputs = layers.Input((num_patches, projection_dim))

    x = inputs
    for i in range(num_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x)

        # Create a multi-head attention layer.
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)

        # Skip connection 1.
        x2 = layers.Add()([attention_output, x])

        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2)

        # MLP.
        x3 = mlp(x3, hidden_units=mlp_units, dropout_rate=0.1)

        # Skip connection 2.
        x = layers.Add()([x3, x2])

        if i == num_layers // 2:
            _, num, dims = x.shape

            # TokenLearner expects a (H, W, C) map, so lay the tokens out on
            # a square grid; fail early if that is impossible.
            h = int(math.sqrt(num))
            if h * h != num:
                raise ValueError(
                    "TokenLearner needs a square number of tokens, "
                    f"got {num}."
                )
            x = layers.Reshape((h, h, dims))(x)
            x = TokenLearner(num_tokens)(x)
    # return the model
    return keras.Model(inputs=inputs, outputs=x)

In [None]:
# Reset Keras' global state before building the encoder.
keras.backend.clear_session()
# Get the encoder
encoder = get_encoder()

# Run the patch embeddings from the earlier cell through the encoder.
encoded_features = encoder(patch_embeddings)

print(encoded_features.shape)

# MLP Head

In [None]:
def get_mlp_head(projection_dim=PROJECTION_DIM, num_classes=NUM_CLASSES,
    num_tokens=8, hidden_units=1024):
    """Build the classification head applied to the encoder output.

    The head layer-normalizes the tokens, flattens them, applies one GELU
    Dense layer and finishes with a softmax over the classes.

    Args:
        projection_dim: Width of every incoming token.
        num_classes: Number of output classes.
        num_tokens: Number of incoming tokens per example. Defaults to 8, the
            value that was previously hard-coded.
        hidden_units: Width of the hidden Dense layer. Defaults to 1024, the
            value that was previously hard-coded.

    Returns:
        A `keras.Model` mapping (batch, num_tokens, projection_dim) inputs to
        (batch, num_classes) class probabilities.
    """
    inputs = layers.Input((num_tokens, projection_dim))

    x = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(inputs)
    x = layers.Flatten()(x)
    # Add MLP.
    x = layers.Dense(units=hidden_units, activation=tf.nn.gelu)(x)
    x = layers.Dense(units=num_classes, activation="softmax")(x)

    return keras.Model(inputs, x)

In [None]:
mlp_head = get_mlp_head()
preds = mlp_head(encoded_features)
print(preds.shape)

In [None]:
class ViT(keras.Model):
    """Trainer model wiring augmentation, patching, encoding and the head.

    The class overrides `train_step` and `test_step`; it does not define
    `call`, so it is meant to be driven through `fit` / `evaluate`.

    Args:
        train_augmentation_model: Model applied to images during training.
        test_augmentation_model: Model applied to images during evaluation.
        patch_layer: Layer turning images into flattened patches.
        patch_encoder: Layer embedding the patches.
        encoder: Model encoding the patch embeddings.
        mlp_head: Model mapping encoded features to class predictions.
        **kwargs: Forwarded to `keras.Model`.
    """

    def __init__(
        self,
        train_augmentation_model,
        test_augmentation_model,
        patch_layer,
        patch_encoder,
        encoder,
        mlp_head,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.train_augmentation_model = train_augmentation_model
        self.test_augmentation_model = test_augmentation_model
        self.patch_layer = patch_layer
        self.patch_encoder = patch_encoder
        self.encoder = encoder
        self.mlp_head = mlp_head

    def calculate_loss(self, images, labels, test=False):
        """Run the full forward pass and compute the compiled loss.

        Args:
            images: Batch of raw images.
            labels: Labels matching `images`.
            test: When True, use the test augmentation model and run the
                sub-models in inference mode; otherwise run in training mode.

        Returns:
            A ``(total_loss, predictions)`` tuple.
        """
        # `training` is passed explicitly: these calls do not happen inside a
        # parent layer's call, so Keras cannot infer the mode, and leaving it
        # out would run Dropout and the random augmentation layers in
        # inference mode even during `fit`.
        training = not test

        # Augment the input images.
        if test:
            augmented_images = self.test_augmentation_model(
                images, training=False
            )
        else:
            augmented_images = self.train_augmentation_model(
                images, training=True
            )

        # Patch the augmented images.
        patches = self.patch_layer(augmented_images)

        # Encode the patches.
        patch_embeddings = self.patch_encoder(patches)

        encoded_features = self.encoder(patch_embeddings, training=training)
        predictions = self.mlp_head(encoded_features, training=training)

        total_loss = self.compiled_loss(labels, predictions)

        return total_loss, predictions

    def train_step(self, inputs):
        """Run one optimization step on an ``(images, labels)`` batch.

        Returns:
            A dict mapping each metric name to its current result.
        """
        # get the image and the label
        images, labels = inputs

        with tf.GradientTape() as tape:
            total_loss, predictions = self.calculate_loss(images, labels, test=False)

        # Gather the trainable variables of every component into one flat
        # list so gradients and variables pair up one-to-one.
        train_vars = (
            self.train_augmentation_model.trainable_variables
            + self.patch_layer.trainable_variables
            + self.patch_encoder.trainable_variables
            + self.encoder.trainable_variables
            + self.mlp_head.trainable_variables
        )
        grads = tape.gradient(total_loss, train_vars)
        self.optimizer.apply_gradients(zip(grads, train_vars))

        # Report progress.
        self.compiled_metrics.update_state(labels, predictions)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, inputs):
        """Evaluate one ``(images, labels)`` batch without updating weights.

        Returns:
            A dict mapping each metric name to its current result.
        """
        # get the image and the label
        images, labels = inputs

        # Compute the predictions; the loss value itself is not needed here.
        _, predictions = self.calculate_loss(images, labels, test=True)

        # Update the trackers.
        self.compiled_metrics.update_state(labels, predictions)
        return {m.name: m.result() for m in self.metrics}

In [None]:
# Build fresh instances of every component, then wire them into the trainer.
train_augmentation_model = get_train_augmentation_model()
test_augmentation_model = get_test_augmentation_model()
patch_layer = Patches()
patch_encoder = PatchEncoder()
encoder = get_encoder()
mlp_head = get_mlp_head()

vit_model = ViT(
    train_augmentation_model=train_augmentation_model,
    test_augmentation_model=test_augmentation_model,
    patch_layer=patch_layer,
    patch_encoder=patch_encoder,
    encoder=encoder,
    mlp_head=mlp_head,
)

In [None]:
# AdamW optimizer from TensorFlow Addons, configured from the
# hyperparameter constants.
optimizer = tfa.optimizers.AdamW(
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY
)

# Compile and train the model.
# The sparse loss and metrics take integer class labels, and the MLP head
# already ends in a softmax, so the loss uses its default settings.
vit_model.compile(
    optimizer=optimizer,
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
    ],
)


# Train for EPOCHS epochs, validating on the held-out split after each one.
history = vit_model.fit(
    train_ds,
    epochs=EPOCHS,
    validation_data=val_ds,
)

In [None]:
# Evaluate on the test split.
# NOTE(review): the unpacking assumes `evaluate` returns three values ordered
# as loss, accuracy, top-5 accuracy; confirm against the metric order.
_, accuracy, top_5_accuracy = vit_model.evaluate(test_ds)
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")