# Imports and Setup

In [None]:
# Install/upgrade TensorFlow Addons; it provides `tfa.optimizers.AdamW`,
# the optimizer used in `run_experiment` below.
!pip install -q -U tensorflow-addons

In [None]:
import tensorflow as tf

tf.keras.utils.set_random_seed(42)

from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa

from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np

import math

# Hyperparameters

In [None]:
# --- Data pipeline ---------------------------------------------------------
BATCH_SIZE = 256
AUTO = tf.data.AUTOTUNE  # let tf.data pick prefetch buffer sizes
INPUT_SHAPE = (32, 32, 3)  # shape of the raw CIFAR-10 images fed to the model
NUM_CLASSES = 10

# --- Optimizer (AdamW) -----------------------------------------------------
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4

# --- Training schedule -----------------------------------------------------
EPOCHS = 20

# --- Augmentation / patching -----------------------------------------------
IMAGE_SIZE = 48  # side length of the randomly cropped images the ViT sees
PATCH_SIZE = 6  # side length of each square, non-overlapping patch
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2  # patches per image (a square grid)

# --- Vision Transformer architecture ---------------------------------------
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 128
NUM_HEADS = 4
NUM_LAYERS = 4
MLP_UNITS = [PROJECTION_DIM * 2, PROJECTION_DIM]  # hidden widths of each block's MLP

In [None]:
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

# Hold out the last 10,000 official training images as a validation split.
# The validation slice is taken first, before x_train/y_train are truncated.
x_val, y_val = x_train[40000:], y_train[40000:]
x_train, y_train = x_train[:40000], y_train[:40000]

print(f"Training samples: {len(x_train)}")
print(f"Validation samples: {len(x_val)}")
print(f"Testing samples: {len(x_test)}")

# Only the training stream is shuffled; all three are batched and prefetched.
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(BATCH_SIZE * 100)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)
val_ds = (
    tf.data.Dataset.from_tensor_slices((x_val, y_val))
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)
test_ds = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

# Augmentation

In [None]:
# Preprocessing/augmentation applied inside the model (see create_vit_classifier):
# scale pixels by 1/255, resize to (INPUT_SHAPE[0] + 20) pixels square, take an
# IMAGE_SIZE x IMAGE_SIZE random crop, then randomly mirror left/right.
data_augmentation = keras.Sequential(name="train_data_augmentation")
data_augmentation.add(layers.Rescaling(1 / 255.0))
data_augmentation.add(layers.Resizing(INPUT_SHAPE[0] + 20, INPUT_SHAPE[0] + 20))
data_augmentation.add(layers.RandomCrop(IMAGE_SIZE, IMAGE_SIZE))
data_augmentation.add(layers.RandomFlip("horizontal"))

# Patch Encoder

In [None]:
class PatchEncoder(layers.Layer):
    """Project patch features linearly and add learned position embeddings.

    Args:
        num_patches: Number of patch positions; size of the position-embedding
            table and of the index range added to every sample.
        projection_dim: Output feature size of each encoded patch.
        **kwargs: Standard ``keras.layers.Layer`` arguments (e.g. ``name``).

    NOTE(review): the call input is assumed to be
    ``(batch, num_patches, features)`` so that the
    ``(num_patches, projection_dim)`` position embeddings broadcast over the
    batch axis; the output is then ``(batch, num_patches, projection_dim)``.
    """

    def __init__(self, num_patches=NUM_PATCHES, projection_dim=PROJECTION_DIM, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patches):
        # One integer position per patch: 0, 1, ..., num_patches - 1.
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        return self.projection(patches) + self.position_embedding(positions)

    def get_config(self):
        # Needed for serialization/cloning because __init__ takes extra
        # constructor arguments that the base Layer config does not know about.
        config = super().get_config()
        config.update(
            {"num_patches": self.num_patches, "projection_dim": self.projection_dim}
        )
        return config

# MLP

In [None]:
def mlp(x, dropout_rate, hidden_units):
    """Apply a stack of GELU Dense layers, each followed by Dropout.

    Args:
        x: Input tensor.
        dropout_rate: Dropout rate used after every Dense layer.
        hidden_units: Iterable of layer widths, applied in order.

    Returns:
        The output tensor of the last Dropout (``x`` itself if
        ``hidden_units`` is empty).
    """
    out = x
    for width in hidden_units:
        dense = layers.Dense(width, activation=tf.nn.gelu)
        dropout = layers.Dropout(dropout_rate)
        out = dropout(dense(out))
    return out

# TokenLearner

In [None]:
class TokenLearner(layers.Layer):
    """Reduce a spatial feature map to a small set of learned tokens.

    A stack of 3x3 convolutions predicts ``number_of_tokens`` sigmoid maps over
    the H x W grid. Each output token is the mean over all H*W positions of the
    input features multiplied by one of those maps.

    Args:
        number_of_tokens: Number of tokens to produce.
        **kwargs: Standard ``keras.layers.Layer`` arguments (e.g. ``name``).

    Input shape: ``(B, H, W, C)``. Output shape: ``(B, number_of_tokens, C)``.
    """

    def __init__(self, number_of_tokens, **kwargs):
        super().__init__(**kwargs)
        self.number_of_tokens = number_of_tokens

    def build(self, input_shape):
        _, H, W, C = input_shape

        self.layer_norm = layers.LayerNormalization(epsilon=1e-6)

        # Three GELU convs then a sigmoid conv; all use number_of_tokens
        # filters, 3x3 kernels, "same" padding and no bias.
        activations = [tf.nn.gelu, tf.nn.gelu, tf.nn.gelu, "sigmoid"]
        conv_layers = [
            layers.Conv2D(
                filters=self.number_of_tokens,
                kernel_size=(3, 3),
                activation=activation,
                padding="same",
                use_bias=False,
            )
            for activation in activations
        ]
        self.conv_block = keras.Sequential(
            conv_layers
            + [
                # (B, H, W, T) -> (B, H*W, T) -> (B, T, H*W)
                layers.Reshape((-1, self.number_of_tokens)),
                layers.Permute((2, 1)),
            ]
        )

        # (B, H, W, C) -> (B, 1, H*W, C) so it broadcasts against the T maps.
        self.reshape_input = layers.Reshape((1, H * W, C))

    def call(self, inputs):
        # inputs: (B, H, W, C)
        x = self.layer_norm(inputs)

        # One spatial map per token, computed from the normalized input.
        x = self.conv_block(x)  # (B, T, H*W)

        # The maps weight the *un-normalized* input features.
        inputs = self.reshape_input(inputs)  # (B, 1, H*W, C)
        # (B, T, H*W, 1) * (B, 1, H*W, C) -> (B, T, H*W, C), mean over H*W.
        return tf.reduce_mean(x[..., tf.newaxis] * inputs, axis=2)

    def get_config(self):
        # Needed for serialization/cloning because __init__ takes
        # number_of_tokens, which the base Layer config does not include.
        config = super().get_config()
        config.update({"number_of_tokens": self.number_of_tokens})
        return config

## ViT model with optional TokenLearner

In [None]:
def create_vit_classifier(use_token_learner=True, token_learner_units=8):
    """Build a ViT image classifier, optionally with a TokenLearner.

    Args:
        use_token_learner: If True, insert a ``TokenLearner`` right after the
            Transformer block with index ``NUM_LAYERS // 2``. The token sequence
            is reshaped to a square grid and reduced to ``token_learner_units``
            tokens for the remaining blocks.
        token_learner_units: Number of tokens produced by the TokenLearner.

    Returns:
        An uncompiled ``keras.Model`` mapping ``INPUT_SHAPE`` images to
        ``NUM_CLASSES`` softmax probabilities.

    Raises:
        ValueError: If the token count at the TokenLearner insertion point is
            not a perfect square, so it cannot be reshaped into a square grid.
    """
    inputs = layers.Input(shape=INPUT_SHAPE)
    # Augment data.
    augmented = data_augmentation(inputs)

    # Create patches: a strided conv maps each non-overlapping
    # PATCH_SIZE x PATCH_SIZE patch to a PROJECTION_DIM feature vector.
    patches = layers.Conv2D(
        PROJECTION_DIM,
        (PATCH_SIZE, PATCH_SIZE),
        strides=(PATCH_SIZE, PATCH_SIZE),
        padding="VALID",
    )(augmented)
    _, h, w, c = patches.shape
    # One token per patch: (B, h, w, c) -> (B, h*w, c). Flattening everything
    # into a single (B, 1, h*w*c) vector would make PatchEncoder broadcast the
    # same content to every position.
    patches = layers.Reshape((h * w, c))(patches)

    # Encode patches.
    encoded_patches = PatchEncoder(NUM_PATCHES, PROJECTION_DIM)(patches)
    encoded_patches = layers.Dropout(0.1)(encoded_patches)

    # Create multiple layers of the Transformer block (pre-norm).
    for i in range(NUM_LAYERS):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)

        # Create a multi-head self-attention layer.
        attention_output = layers.MultiHeadAttention(
            num_heads=NUM_HEADS, key_dim=PROJECTION_DIM, dropout=0.1
        )(x1, x1)

        # Skip connection 1.
        x2 = layers.Add()([attention_output, encoded_patches])

        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2)

        # MLP.
        x3 = mlp(x3, hidden_units=MLP_UNITS, dropout_rate=0.1)

        # Skip connection 2.
        encoded_patches = layers.Add()([x3, x2])

        # Add TokenLearner: back to a square (side, side, dim) grid, then
        # learn token_learner_units tokens from it.
        if use_token_learner and i == NUM_LAYERS // 2:
            _, num_tokens, projection_dim = encoded_patches.shape
            side = int(math.sqrt(num_tokens))
            if side * side != num_tokens:
                raise ValueError(
                    f"TokenLearner needs a square token grid, got {num_tokens} tokens."
                )
            encoded_patches = layers.Reshape((side, side, projection_dim))(
                encoded_patches
            )
            encoded_patches = TokenLearner(token_learner_units)(encoded_patches)

    # Create a [batch_size, projection_dim] tensor.
    representation = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)
    representation = layers.GlobalAvgPool1D()(representation)

    # Classify outputs.
    outputs = layers.Dense(NUM_CLASSES, activation="softmax")(representation)

    # Create the Keras model.
    return keras.Model(inputs=inputs, outputs=outputs)

## Training utility

In [None]:
def run_experiment(model, use_token_learner=True, checkpoint_filepath="/tmp/checkpoint"):
    """Compile, train and evaluate ``model`` on the CIFAR-10 splits.

    Trains for ``EPOCHS`` epochs on ``train_ds`` while validating on
    ``val_ds``. Saves the weights with the best validation accuracy, restores
    them, then evaluates on ``test_ds`` and prints the results.

    Args:
        model: A Keras classifier producing softmax probabilities over
            ``NUM_CLASSES`` (e.g. from ``create_vit_classifier``).
        use_token_learner: Only selects the TensorBoard log-directory prefix
            (``logs-tokenlearner`` vs ``logs-no-tokenlearner``).
        checkpoint_filepath: Where the best weights are saved and reloaded
            from. Pass distinct paths to keep checkpoints of different
            experiments apart.

    Returns:
        ``(accuracy, top_5_accuracy)`` on the test set, so callers can
        aggregate results over several trials.
    """
    optimizer = tfa.optimizers.AdamW(
        learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
    )

    model.compile(
        optimizer=optimizer,
        loss="sparse_categorical_crossentropy",
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
        ],
    )

    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )
    # Timestamped log directory so repeated runs do not overwrite each other.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    root_dir = "logs-tokenlearner" if use_token_learner else "logs-no-tokenlearner"
    tensorboard_callback = keras.callbacks.TensorBoard(
        log_dir=f"{root_dir}-{timestamp}"
    )

    model.fit(
        train_ds,
        epochs=EPOCHS,
        validation_data=val_ds,
        callbacks=[checkpoint_callback, tensorboard_callback],
    )

    # Evaluate the best-validation weights, not the last-epoch weights.
    model.load_weights(checkpoint_filepath)
    _, accuracy, top_5_accuracy = model.evaluate(test_ds)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
    return accuracy, top_5_accuracy

## Experiments

In [None]:
# Number of independent train/evaluate runs per model variant.
# Should be at least 5 for a meaningful comparison; with 1, each variant gets
# only a single measurement.
num_trials: int = 1

In [None]:
# Train/evaluate `num_trials` freshly built ViTs that include TokenLearner.
for trial_idx in range(num_trials):
    vit_token_learner = create_vit_classifier()
    run_experiment(vit_token_learner, use_token_learner=True)

In [None]:
# Baseline: train/evaluate `num_trials` freshly built ViTs without TokenLearner.
for trial_idx in range(num_trials):
    vit = create_vit_classifier(use_token_learner=False)
    run_experiment(vit, use_token_learner=False)

In [None]:
# Parameter counts as (with TokenLearner, without TokenLearner); left as a
# bare expression so the notebook displays the tuple.
(
    create_vit_classifier(use_token_learner=True).count_params(),
    create_vit_classifier(use_token_learner=False).count_params(),
)

## References

* [Official TokenLearner code](https://github.com/google-research/scenic/blob/main/scenic/projects/token_learner/model.py)
* [Image Classification with ViTs on keras.io](https://keras.io/examples/vision/image_classification_with_vision_transformer/)