# Vision Transformer optimisation using TFMOT

An example notebook demonstrating how TFMOT can be used to optimise complex transformer models such as ViT.

## Background

The [Vision Transformer (ViT)](https://arxiv.org/pdf/2010.11929.pdf) architecture applies stacked transformer encoder blocks to images, for tasks such as image classification. The encoder blocks are architecturally similar to those of the popular [NLP transformers](https://arxiv.org/pdf/1706.03762.pdf). The inputs to the transformer encoders are embeddings of patches extracted from the image. For a classification task, an additional feed-forward network (the classification head) is added at the end.

<img src="https://github.com/google-research/vision_transformer/blob/main/vit_figure.png?raw=true" alt="Vision Transformer architecture" width="700"/>
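The patch extraction step can be illustrated with a small NumPy sketch. It assumes a 28x28 single-channel image and non-overlapping 4x4 patches, which are the MNIST settings used later in this notebook. `extract_patches` is a hypothetical helper and is not part of TFMOT or Keras. In the model below, the same split is performed implicitly by a `Conv2D` whose kernel size and stride both equal the patch size. That amounts to applying one shared linear projection to each flattened patch.

```python
import numpy as np

def extract_patches(image, patch_size):
    """Split an (H, W, C) image into flattened, non-overlapping patches.

    Returns an array of shape (num_patches, patch_h * patch_w * C).
    Assumes H and W are divisible by the patch size.
    """
    h, w, c = image.shape
    ph, pw = patch_size
    # (H, W, C) -> (H/ph, ph, W/pw, pw, C)
    x = image.reshape(h // ph, ph, w // pw, pw, c)
    # Bring the two patch-grid axes to the front: (H/ph, W/pw, ph, pw, C)
    x = x.transpose(0, 2, 1, 3, 4)
    # Flatten each patch into one vector, patches in row-major grid order
    return x.reshape(-1, ph * pw * c)

image = np.random.rand(28, 28, 1).astype(np.float32)
patches = extract_patches(image, (4, 4))
seq_len = patches.shape[0] + 1  # +1 for the class token
print(patches.shape, seq_len)  # (49, 16) 50
```

With the extra class token, the encoder therefore sees a sequence of 50 tokens for each MNIST image.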

In this notebook:
1. First, a ViT model is created and trained from scratch on the MNIST dataset. In practice, pre-trained weights could be loaded instead.
2. Next, unstructured weight pruning, clustering and quantisation-aware training (QAT) are applied sequentially using the collaborative optimisation features of the [TensorFlow Model Optimization Toolkit (TFMOT)](https://www.tensorflow.org/model_optimization).
3. Finally, an integer-only (int8) TFLite model is generated and evaluated.

## TFMOT limitations
- Subclassed models are not supported; only Sequential and functional models are. (Pruning, Clustering & QAT)
- Custom subclassed layers are not fully supported. (Clustering & QAT)
    - Clustering only works with subclassed layers if the weight variables to be clustered are not nested within another layer (e.g. MHA).
    - QAT only works correctly if the subclassed layer performs a single operation.
- Low-level TensorFlow operators such as `tf.linalg.matmul` are not supported. (QAT only)
    - QAT expects every quantised layer to be a subclass of `tf.keras.layers.Layer`.

In [None]:
import math
import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

tf.random.set_seed(0)

print('TensorFlow version: {}'.format(tf.__version__))
print('TFMOT version: {}'.format(tfmot.__version__))

## Model definition

Due to the above-mentioned limitations, custom Keras layers must be defined for all of the low-level TensorFlow operators in order to perform QAT (each layer must only contain a single operation).

None of these layers has any prunable or clusterable weights. We therefore first create a base class that is both prunable and clusterable, and the custom layers extend it instead of `tf.keras.layers.Layer`. If any weights in a custom layer should be pruned or clustered, return them as a list from `get_prunable_weights` or `get_clusterable_weights` respectively. Refer to the TFMOT documentation for more details.

In [None]:
class PrunableClusterableLayer(tf.keras.layers.Layer,
                               tfmot.sparsity.keras.PrunableLayer,
                               tfmot.clustering.keras.ClusterableLayer):
    def get_prunable_weights(self): return []
    def get_clusterable_weights(self): return []

### 1. Define each of the TensorFlow operations ViT uses as a Keras subclassed layer:

Note that some of these layers have trainable weights defined using the `add_weight` method. These weights will not be pruned or clustered.

In [None]:
class MatMul(PrunableClusterableLayer):
    def __init__(self, transpose_b=False, **kwargs):
        super().__init__(**kwargs)
        self.transpose_b = transpose_b

    def call(self, inputs):
        return tf.linalg.matmul(*inputs, transpose_b=self.transpose_b)

    def get_config(self):
        config = super().get_config()
        config.update({'transpose_b': self.transpose_b})
        return config

class Multiply(PrunableClusterableLayer):
    def call(self, inputs):
        return tf.multiply(*inputs)

# Calling Multiply with a scalar input will lead to an error.
# Use the following ScalarMultiply class instead.
class ScalarMultiply(PrunableClusterableLayer):
    def __init__(self, scalar, **kwargs):
        super().__init__(**kwargs)
        self.scalar = scalar

    def call(self, x):
        return tf.math.multiply(x, self.scalar)

    def get_config(self):
        config = super().get_config()
        config.update({'scalar': self.scalar})
        return config

class Add(PrunableClusterableLayer):
    def call(self, inputs):
        return tf.math.add(*inputs)

# Calling Add with a scalar input will lead to an error.
# Use the following ScalarAdd class instead.
class ScalarAdd(PrunableClusterableLayer):
    def __init__(self, scalar, **kwargs):
        super().__init__(**kwargs)
        self.scalar = scalar

    def call(self, x):
        return tf.math.add(x, self.scalar)

    def get_config(self):
        config = super().get_config()
        config.update({'scalar': self.scalar})
        return config

class Slice(PrunableClusterableLayer):
    def __init__(self, seq_idx, **kwargs):
        super().__init__(**kwargs)
        self.seq_idx = seq_idx

    def call(self, x):
        return x[:, self.seq_idx, ...]

    def get_config(self):
        config = super().get_config()
        config.update({'seq_idx': self.seq_idx})
        return config

class Mean(PrunableClusterableLayer):
    def __init__(self, axes=None, keepdims=True, **kwargs):
        super().__init__(**kwargs)
        self.axes = axes
        self.keepdims = keepdims

    def call(self, x):
        return tf.math.reduce_mean(x, axis=self.axes, keepdims=self.keepdims)

    def get_config(self):
        config = super().get_config()
        config.update({'axes': self.axes,
                       'keepdims': self.keepdims})
        return config

class Subtract(PrunableClusterableLayer):
    def call(self, inputs):
        return tf.math.subtract(*inputs)

class StopGradient(PrunableClusterableLayer):
    def call(self, x):
        return tf.stop_gradient(x)

class RSqrt(PrunableClusterableLayer):
    def call(self, x):
        return tf.math.rsqrt(x)

class ClipMin(PrunableClusterableLayer):
    def __init__(self, min_val=0, **kwargs):
        super().__init__(**kwargs)
        self.min_val = min_val

    def call(self, x):
        return tf.math.maximum(x, self.min_val)

    def get_config(self):
        config = super().get_config()
        config.update({'min_val': self.min_val})
        return config

class BroadcastToken(PrunableClusterableLayer):
    """Layer to broadcast the class token"""
    def __init__(self, embedding_dim, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim

    def build(self, input_shape):
        self.w = self.add_weight(shape=(1, 1, self.embedding_dim), initializer='zeros', 
                                 trainable=True, name='token')
        super().build(input_shape)

    def call(self, x):
        batch_size = tf.shape(x)[0]
        return tf.broadcast_to(self.w, [batch_size, 1, self.embedding_dim])

    def get_config(self):
        config = super().get_config()
        config.update({'embedding_dim': self.embedding_dim})
        return config

class AddPositionalEmbedding(PrunableClusterableLayer):
    """Layer to add positional embeddings to the tokens"""
    def __init__(self, seq_len, embedding_dim, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.seq_len = seq_len

    def build(self, input_shape):
        self.w = self.add_weight(shape=(1, self.seq_len, self.embedding_dim), initializer=None,
                                 trainable=True, name='pos_emb')
        super().build(input_shape)

    def call(self, x):
        return x + self.w

    def get_config(self):
        config = super().get_config()
        config.update({'embedding_dim': self.embedding_dim, 'seq_len': self.seq_len})
        return config

class Scale(PrunableClusterableLayer):
    """Multiply with gamma (LayerNorm)"""
    def __init__(self, axes, **kwargs):
        super().__init__(**kwargs)
        self.axes = axes

    def build(self, input_shape):
        param_shape = [input_shape[dim] for dim in self.axes]
        self.w = self.add_weight(name='gamma', shape=param_shape,
                                 trainable=True, initializer='ones')
        super().build(input_shape)

    def call(self, x):
        return tf.multiply(x, self.w)

    def get_config(self):
        config = super().get_config()
        config.update({'axes': self.axes})
        return config

class Centre(PrunableClusterableLayer):
    """Add beta (LayerNorm)"""
    def __init__(self, axes, **kwargs):
        super().__init__(**kwargs)
        self.axes = axes

    def build(self, input_shape):
        param_shape = [input_shape[dim] for dim in self.axes]
        self.w = self.add_weight(name='beta', shape=param_shape,
                                 trainable=True, initializer='zeros')
        super().build(input_shape)

    def call(self, x):
        return tf.math.add(x, self.w)

    def get_config(self):
        config = super().get_config()
        config.update({'axes': self.axes})
        return config

### 2. Now that these low-level operators are defined as Keras layers, we can start writing ViT layers such as multi-head attention or layer normalisation functionally:

In [None]:
Tanh = tf.keras.layers.Activation('tanh')

def patch_encoder(inp, patch_size, num_patches, embedding_dim):
    """
    Patch encoder layer: extracts patches from the image, embeds and flattens them,
    then prepends the class token and adds the positional embedding vectors.
    """
    x = tf.keras.layers.Conv2D(filters=embedding_dim, kernel_size=patch_size,
                               strides=patch_size, name='patch_encoder/conv2d')(inp)
    x = tf.keras.layers.Reshape((num_patches, embedding_dim))(x)

    # add the class token
    cls_token = BroadcastToken(embedding_dim=embedding_dim, name='patch_encoder/cls_token')(inp)
    x = tf.keras.layers.Concatenate(axis=1)([cls_token, x])

    x = AddPositionalEmbedding(seq_len=(num_patches + 1),  # +1 for the class token
                               embedding_dim=embedding_dim,
                               name='patch_encoder/add_pos_emb')(x)
    return x

def self_attention(x, n_heads, dim, name='mha'):
    """Multi-head attention layer"""
    depth = dim // n_heads

    q = tf.keras.layers.Dense(units=dim, name=f'{name}/query')(x)
    k = tf.keras.layers.Dense(units=dim, name=f'{name}/key')(x)
    v = tf.keras.layers.Dense(units=dim, name=f'{name}/value')(x)

    q = tf.keras.layers.Reshape((-1, n_heads, depth))(q)
    q = tf.keras.layers.Permute((2, 1, 3))(q)
    k = tf.keras.layers.Reshape((-1, n_heads, depth))(k)
    k = tf.keras.layers.Permute((2, 1, 3))(k)
    v = tf.keras.layers.Reshape((-1, n_heads, depth))(v)
    v = tf.keras.layers.Permute((2, 1, 3))(v)

    qk = ScalarMultiply(depth ** -0.5)(MatMul(transpose_b=True)([q, k]))
    attn_weights = tf.keras.layers.Softmax(axis=-1)(qk)

    attn_out = MatMul()([attn_weights, v]) 
    attn_out = tf.keras.layers.Permute((2, 1, 3))(attn_out)
    attn_out = tf.keras.layers.Reshape((-1, dim))(attn_out)
    out = tf.keras.layers.Dense(dim, name=f'{name}/output_dense')(attn_out)

    return out

def layer_norm(x, axes=2, epsilon=0.001, name='layer_norm', trainable=True):
    """LayerNormalization"""
    if isinstance(axes, int): axes = [axes]

    mean = Mean(axes=axes)(x)
    ## This block can be replaced with a squared_difference layer ##
    diff = Subtract()([x, StopGradient()(mean)])                  ##
    sq_diff = Multiply()([diff, diff])                            ##
    ## ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ##
    variance = Mean(axes=axes, name=f'{name}/variance')(sq_diff)
    if not trainable:
        inv = RSqrt()(variance)
        x = Multiply()([diff, inv])
    else:
        inv = RSqrt()(ClipMin(min_val=epsilon)(variance))  # ClipMin prevents division by 0.
        x = Subtract(name=f'{name}/grad_subtract')([x, mean])  # This layer is removed for inference so it is named.
        x = Multiply()([x, inv])

    x = Scale(axes=axes)(x)
    x = Centre(axes=axes)(x)

    return x

def gelu(x):
    """Functional definition of approximate GELU with Keras layers"""
    res = Add()([x, ScalarMultiply(0.044715)(Multiply()([x, Multiply()([x, x])]))])
    res = ScalarAdd(1.0)(Tanh(ScalarMultiply(math.sqrt(2 / math.pi))(res)))
    res = ScalarMultiply(0.5)(res)
    res = Multiply()([x, res])
    return res

def mlp(x, hidden_dim, out_dim):
    """Multi-layer perceptron block"""
    x = tf.keras.layers.Dense(units=hidden_dim)(x)
    x = gelu(x)
    x = tf.keras.layers.Dense(units=out_dim)(x)
    return x
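As a TensorFlow-free sanity check of the `gelu` function above, the same chain of single-operation steps can be written in NumPy. The result can then be compared with the exact GELU, x * Phi(x), where Phi is the standard normal CDF computed via `math.erf`. This is an illustrative sketch, and the helper names are hypothetical.

```python
import math
import numpy as np

def gelu_tanh(x):
    """Tanh approximation of GELU, following the layer-by-layer `gelu` above."""
    inner = x + 0.044715 * x * (x * x)            # Multiply x2, ScalarMultiply, Add
    t = np.tanh(math.sqrt(2 / math.pi) * inner)   # ScalarMultiply, Tanh
    return x * (0.5 * (t + 1.0))                  # ScalarAdd, ScalarMultiply, Multiply

def gelu_exact(x):
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return np.array([v * 0.5 * (1.0 + math.erf(v / math.sqrt(2))) for v in x])

x = np.linspace(-5, 5, 1001)
max_err = np.max(np.abs(gelu_tanh(x) - gelu_exact(x)))
print(f'max |approx - exact| on [-5, 5]: {max_err:.1e}')
```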

### 3. Full functional model definition:

In [None]:
def get_vision_transformer(input_shape,
                           n_classes,
                           patch_size,
                           embedding_dim,
                           n_layers,
                           n_attention_heads,
                           mlp_hidden_dim,
                           trainable=True):
    """
    Args:
        input_shape (tuple): Shape of the inputs, including the batch size.
        n_classes (int): Number of classes in the dataset.
        patch_size (int / tuple of ints): Size of the patches to extract from the images.
        embedding_dim (int): Size of the embedded patch vectors.
        n_layers (int): Number of transformer encoder layers.
        n_attention_heads (int): Number of attention heads.
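        mlp_hidden_dim (int): Hidden layer size for the intermediate MLPs.
        trainable (bool): If True, build the training graph. If False, build the inference graph,
            which drops the ClipMin and gradient Subtract layers from each layer normalisation.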

    Returns:
        model (tf.keras.Model): The Keras model.
    """

    if isinstance(patch_size, int): patch_size = (patch_size, patch_size)

    # Calculate the number of patches
    num_patches = (input_shape[1] * input_shape[2]) // (patch_size[0] * patch_size[1])

    inp = tf.keras.layers.Input(shape=input_shape[1:], batch_size=input_shape[0], name='image')

    # Patch encoder layer
    x = patch_encoder(inp, patch_size, num_patches, embedding_dim)

    for block in range(n_layers):
        # Attention block
        x1 = layer_norm(x, name=(f'layer_norm_{2 * block}' if block != 0 else 'layer_norm'), trainable=trainable)
        x1 = self_attention(x1, n_attention_heads, embedding_dim, name=(f'mha_{block}' if block != 0 else 'mha'))
        x1 = tf.keras.layers.Add()([x1, x])

        # MLP block
        x2 = layer_norm(x1, name=f'layer_norm_{2 * block + 1}', trainable=trainable)
        x2 = mlp(x2, mlp_hidden_dim, embedding_dim)
        x = tf.keras.layers.Add()([x2, x1])

    # Final layer norm; uses n_layers rather than the loop variable so that n_layers=0 does not fail.
    x = layer_norm(x, name=f'layer_norm_{2 * n_layers}', trainable=trainable)

    ## ~ Classification head ~ ##
    cls_head = Slice(0)(x)
    out = tf.keras.layers.Dense(n_classes, kernel_initializer='zeros', name='cls_head')(cls_head)
    ## ~~~~~~~~~~~~~~~~~~~~~~~ ##

    model = tf.keras.Model(inputs=inp, outputs=out)

    return model
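Each encoder block above calls `self_attention`. Its `Reshape`/`Permute`/`MatMul` sequence is easier to follow as plain array operations. The following NumPy sketch mirrors it with random weights and no biases (the `Dense` layers in the model do have biases). The shapes assume the configuration trained below: `embedding_dim=16`, 2 heads, and 49 patches plus the class token. Keras' `Permute((2, 1, 3))` does not count the batch axis, so it corresponds to `transpose(0, 2, 1, 3)` here. The function names are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def self_attention_np(x, wq, wk, wv, wo, n_heads):
    """NumPy mirror of `self_attention` (biases omitted for brevity)."""
    batch, seq_len, dim = x.shape
    depth = dim // n_heads

    def split_heads(t):
        # (batch, seq, dim) -> (batch, seq, heads, depth) -> (batch, heads, seq, depth)
        return t.reshape(batch, seq_len, n_heads, depth).transpose(0, 2, 1, 3)

    q, k, v = split_heads(x @ wq), split_heads(x @ wk), split_heads(x @ wv)
    scores = (q @ k.transpose(0, 1, 3, 2)) * depth ** -0.5  # MatMul(transpose_b=True), ScalarMultiply
    attn = softmax(scores, axis=-1)                          # Softmax(axis=-1)
    out = (attn @ v).transpose(0, 2, 1, 3)                   # MatMul, Permute((2, 1, 3))
    out = out.reshape(batch, seq_len, dim)                   # Reshape((-1, dim))
    return out @ wo, attn                                    # output Dense

rng = np.random.default_rng(0)
dim, n_heads, seq_len = 16, 2, 50
x = rng.standard_normal((1, seq_len, dim))
wq, wk, wv, wo = (rng.standard_normal((dim, dim)) * 0.1 for _ in range(4))
out, attn = self_attention_np(x, wq, wk, wv, wo, n_heads)
print(out.shape, attn.shape)  # (1, 50, 16) (1, 2, 50, 50)
```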

## Training

In [None]:
BATCH_SIZE = 32

# Load the MNIST dataset
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
X_train, X_test = (X_train[..., np.newaxis] / 255.0), (X_test[..., np.newaxis] / 255.0)
train_ds = tf.data.Dataset.from_tensor_slices((X_train, y_train)).shuffle(1000) \
                                                                 .batch(BATCH_SIZE, drop_remainder=True) \
                                                                 .prefetch(tf.data.AUTOTUNE)
test_ds = tf.data.Dataset.from_tensor_slices((X_test, y_test)).batch(BATCH_SIZE, drop_remainder=True) \
                                                              .prefetch(tf.data.AUTOTUNE)

def compile_and_fit(model, **kwargs):
    model.compile(optimizer="adam",
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=["accuracy"])
    model.fit(train_ds, validation_data=test_ds, **kwargs)

In [None]:
model = get_vision_transformer(input_shape=(BATCH_SIZE, 28, 28, 1),  # (batch_size, height, width, channels)
                               n_classes=10,
                               patch_size=(4, 4),
                               embedding_dim=16,
                               n_layers=2,
                               n_attention_heads=2,
                               mlp_hidden_dim=16)

compile_and_fit(model, epochs=3)

## Pruning, Clustering & QAT

### 1. Apply the pruning API

In [None]:
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude
strip_pruning = tfmot.sparsity.keras.strip_pruning

N_EPOCHS = 1
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.1, final_sparsity=0.5,
                                                             begin_step=0, end_step=int(len(train_ds)*N_EPOCHS*0.7))
}
pruned_model = prune_low_magnitude(model, **pruning_params)
# Fine-tune with pruning
compile_and_fit(pruned_model, epochs=N_EPOCHS, callbacks=[tfmot.sparsity.keras.UpdatePruningStep()])
stripped_pruned_model = strip_pruning(pruned_model)
print('Success')
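The `PolynomialDecay` schedule above ramps the target sparsity from `initial_sparsity` to `final_sparsity` between `begin_step` and `end_step`. The plain-Python sketch below shows that curve. It assumes the cubic exponent (`power=3`) that TFMOT uses by default. It ignores the `frequency` parameter, which controls how often the pruning masks are actually updated. `polynomial_decay_sparsity` is a hypothetical helper. With 60,000 MNIST training images and a batch size of 32, one epoch has 1875 steps, so `end_step` is 1312 here.

```python
def polynomial_decay_sparsity(step, initial_sparsity, final_sparsity,
                              begin_step, end_step, power=3):
    """Target sparsity of a polynomial-decay schedule at a given training step."""
    # Progress through the schedule, clamped to [0, 1]
    progress = min(max((step - begin_step) / (end_step - begin_step), 0.0), 1.0)
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** power

steps_per_epoch = 60000 // 32            # MNIST train set, batch size 32, drop_remainder=True
end_step = int(steps_per_epoch * 1 * 0.7)
for step in [0, end_step // 2, end_step, steps_per_epoch]:
    print(step, round(polynomial_decay_sparsity(step, 0.1, 0.5, 0, end_step), 3))
```

Because of the cubic exponent, most of the increase in sparsity happens early in the schedule.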

#### 1.1. Check that the weights are pruned

In [None]:
def print_sparsity(model):
    for w in model.weights:
        n_weights = w.numpy().size
        n_zeros = np.count_nonzero(w == 0)
        sparsity = n_zeros / n_weights * 100.0
        if sparsity > 0:
            print('    {} - {:.1f}% sparsity'.format(w.name, sparsity))

In [None]:
print('Sparse weights:')
print_sparsity(stripped_pruned_model)
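`print_sparsity` simply counts entries that are exactly zero. Conceptually, low-magnitude pruning produces those zeros by removing the smallest-magnitude weights of a tensor. The NumPy sketch below illustrates the idea for a single tensor at a fixed target sparsity. It is a simplification, not TFMOT's implementation, which maintains pruning masks and updates them over the course of the schedule. `magnitude_prune` is a hypothetical name.

```python
import numpy as np

def magnitude_prune(weights, sparsity):
    """Zero out the `sparsity` fraction of entries with the smallest magnitude."""
    flat = np.abs(weights).ravel()
    k = int(round(sparsity * flat.size))       # number of weights to remove
    if k == 0:
        return weights.copy()
    # Indices of the k smallest-magnitude weights
    prune_idx = np.argsort(flat, kind='stable')[:k]
    mask = np.ones(flat.size, dtype=bool)
    mask[prune_idx] = False
    return weights * mask.reshape(weights.shape)

rng = np.random.default_rng(0)
w = rng.standard_normal((16, 16))
pruned = magnitude_prune(w, 0.5)
print('{:.1f}% sparsity'.format(np.count_nonzero(pruned == 0) / pruned.size * 100))  # 50.0% sparsity
```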

### 2. Apply the clustering API

In [None]:
from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster

cluster_weights = cluster.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization
strip_clustering = tfmot.clustering.keras.strip_clustering

# Add sparsity-preserving clustering wrappers
pruned_clustered_model = cluster_weights(stripped_pruned_model,
                                         number_of_clusters=4,
                                         cluster_centroids_init=CentroidInitialization.KMEANS_PLUS_PLUS,
                                         preserve_sparsity=True)
# Fine-tune with clustering
compile_and_fit(pruned_clustered_model, epochs=1)
stripped_pruned_clustered_model = strip_clustering(pruned_clustered_model)
print('Success')

#### 2.1. Check that the weights are pruned and clustered

In [None]:
def print_clusters(model):
    for w in model.weights:
        n_weights = w.numpy().size
        n_unique = len(np.unique(w))
        if n_unique < n_weights:
            print('    {} - {} unique weights'.format(w.name, n_unique))

In [None]:
print('Sparse weights:')
print_sparsity(stripped_pruned_clustered_model)
print('Clustered weights:')
print_clusters(stripped_pruned_clustered_model)
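Sparsity-preserving clustering replaces the non-zero weights of a tensor with a small set of shared centroid values and leaves pruned weights at zero. A clustered tensor therefore ends up with only a handful of unique values. The NumPy sketch below assumes plain k-means with linearly spaced initial centroids (the notebook itself uses `KMEANS_PLUS_PLUS` initialisation) and uses a hypothetical helper name. It is not TFMOT's implementation.

```python
import numpy as np

def cluster_preserving_sparsity(weights, n_clusters=4, n_iters=20):
    """Cluster the non-zero weights into `n_clusters` shared values, keeping zeros at zero.

    Uses plain k-means with linearly spaced initial centroids (not k-means++).
    """
    w = weights.ravel().copy()
    nz = w != 0                                   # sparsity pattern to preserve
    if not nz.any():
        return weights.copy()
    vals = w[nz]
    centroids = np.linspace(vals.min(), vals.max(), n_clusters)
    for _ in range(n_iters):
        # Assign each non-zero weight to its nearest centroid
        assign = np.argmin(np.abs(vals[:, None] - centroids[None, :]), axis=1)
        # Move each non-empty cluster's centroid to the mean of its members
        for c in range(n_clusters):
            if np.any(assign == c):
                centroids[c] = vals[assign == c].mean()
    w[nz] = centroids[assign]
    return w.reshape(weights.shape)

rng = np.random.default_rng(0)
w = rng.standard_normal((16, 16))
w[np.abs(w) < np.median(np.abs(w))] = 0.0       # about 50% sparse, as after pruning
clustered = cluster_preserving_sparsity(w, n_clusters=4)
print(len(np.unique(clustered)))  # at most 5: four centroids plus 0
```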

**Warning: The original model is modified after calling [`prune_low_magnitude`](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/sparsity/keras/prune_low_magnitude) or [`cluster_weights`](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/clustering/keras/cluster_weights).**

### 3. Quantisation-aware training API
#### 3.1. To use the custom Keras layers we defined, we need to pass a [`QuantizeConfig`](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/quantization/keras/QuantizeConfig) for each of these layers.

For Keras layers that TFMOT already supports, a default `QuantizeConfig` is assigned automatically. However, custom `QuantizeConfig` instances can also be created for these layers to give more control over how they are quantised.

In [None]:
# Access the public API through the `tfmot` alias imported above
QuantizeConfig = tfmot.quantization.keras.QuantizeConfig
quantizers = tfmot.quantization.keras.quantizers

LastValueQuantizer = quantizers.LastValueQuantizer
MovingAverageQuantizer = quantizers.MovingAverageQuantizer
AllValuesQuantizer = quantizers.AllValuesQuantizer

class NoOpQuantizeConfig(QuantizeConfig):
    """QuantizeConfig which does not quantize any part of the layer."""

    def get_weights_and_quantizers(self, layer):
        return []

    def get_activations_and_quantizers(self, layer):
        return []

    def set_quantize_weights(self, layer, quantize_weights):
        pass

    def set_quantize_activations(self, layer, quantize_activations):
        pass

    def get_output_quantizers(self, layer):
        return []

    def get_config(self):
        return {}

class OutputQuantizeConfig(QuantizeConfig):
    """QuantizeConfig which only quantizes the output of a layer."""

    def get_weights_and_quantizers(self, layer):
        return []

    def get_activations_and_quantizers(self, layer):
        return []

    def set_quantize_weights(self, layer, quantize_weights):
        pass

    def set_quantize_activations(self, layer, quantize_activations):
        pass

    def get_output_quantizers(self, layer):
        return [MovingAverageQuantizer(num_bits=8, per_axis=False, symmetric=False, narrow_range=False)]

    def get_config(self):
        return {}

class WeightQuantizeConfig(QuantizeConfig):
    """QuantizeConfig which quantizes the custom weights in the patch encoder and layer normalisation layers."""

    def __init__(self):
        self.weight_quantizer = LastValueQuantizer(num_bits=8, per_axis=False,
                                                   symmetric=True, narrow_range=True)
        self.activation_quantizer = MovingAverageQuantizer(num_bits=8, per_axis=False,
                                                           symmetric=False, narrow_range=False)

    def get_weights_and_quantizers(self, layer):
        return [(layer.w, self.weight_quantizer)]

    def get_activations_and_quantizers(self, layer):
        return []

    def set_quantize_weights(self, layer, quantize_weights):
        layer.w = quantize_weights[0]

    def set_quantize_activations(self, layer, quantize_activations):
        pass

    def get_output_quantizers(self, layer):
        return [self.activation_quantizer]

    def get_config(self):
        return {}

class VarianceQuantizeConfig(QuantizeConfig):
    """QuantizeConfig for the variance calculation in the layer normalisation layer."""

    def get_weights_and_quantizers(self, layer):
        return []

    def get_activations_and_quantizers(self, layer):
        return []

    def set_quantize_weights(self, layer, quantize_weights):
        pass

    def set_quantize_activations(self, layer, quantize_activations):
        pass

    def get_output_quantizers(self, layer):
        return [AllValuesQuantizer(num_bits=8, per_axis=False, symmetric=False, narrow_range=False)]

    def get_config(self):
        return {}
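The configs above use two kinds of quantizer settings: `symmetric=True, narrow_range=True` for the custom weights and `symmetric=False` for outputs. The simplified NumPy sketch below shows 8-bit quantize-dequantize ("fake quantization") under each setting. It assumes the float range `[x_min, x_max]` is already known. It does not reproduce how TFMOT tracks that range or any range nudging it may apply. `fake_quant` is a hypothetical helper.

```python
import numpy as np

def fake_quant(x, x_min, x_max, num_bits=8, symmetric=False, narrow_range=False):
    """Quantize-dequantize `x` on a signed `num_bits` integer grid (simplified sketch)."""
    qmin, qmax = -2 ** (num_bits - 1), 2 ** (num_bits - 1) - 1   # -128, 127 for 8 bits
    if narrow_range:
        qmin += 1                                                 # use [-127, 127]
    if symmetric:
        # Zero point fixed at 0; the scale covers the largest magnitude
        bound = max(abs(x_min), abs(x_max))
        scale = max(bound, 1e-12) / qmax
        zero_point = 0
    else:
        # The range must contain 0.0 so that zero is exactly representable
        x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
        scale = max(x_max - x_min, 1e-12) / (qmax - qmin)
        zero_point = int(round(qmin - x_min / scale))
    q = np.clip(np.round(x / scale) + zero_point, qmin, qmax)
    return (q - zero_point) * scale, scale, zero_point

# Weights: symmetric, narrow range (as in WeightQuantizeConfig's weight quantizer)
w = np.array([-0.5, -0.1, 0.0, 0.25, 0.5])
w_fq, w_scale, w_zp = fake_quant(w, w.min(), w.max(), symmetric=True, narrow_range=True)
# Activations in [0, 1]: asymmetric, full range (as in the output quantizers)
a = np.linspace(0.0, 1.0, 11)
a_fq, a_scale, a_zp = fake_quant(a, 0.0, 1.0)
print(w_zp, a_zp)  # 0 -128
```

In both cases the rounding error is at most half a quantisation step. This is the error that QAT exposes the model to during fine-tuning.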

Because the model uses custom layers and custom `QuantizeConfig`s, it cannot be wrapped with QAT wrappers directly. Instead, we first write a helper that clones the model and wraps only the individual layers listed in `layer_param_dict`:

In [None]:
def apply_wrapper(wrapper_function, layer_param_dict):

    def wrap_layer(layer):
        if layer.name in layer_param_dict.keys():
            return wrapper_function(layer, **layer_param_dict[layer.name])
        return layer

    return wrap_layer

def layer_wrapper(model, wrapper_function, layer_param_dict):
    return tf.keras.models.clone_model(model, clone_function=apply_wrapper(wrapper_function, layer_param_dict))

The custom layers should be quantized with the following `QuantizeConfig` classes:

| Custom Layer | QuantizeConfig |
| :- | :-: |
| ClipMin | NoOpQuantizeConfig |
| Slice | NoOpQuantizeConfig |
| StopGradient | NoOpQuantizeConfig |
| MatMul | OutputQuantizeConfig |
| Multiply | OutputQuantizeConfig |
| ScalarMultiply | OutputQuantizeConfig |
| Add | OutputQuantizeConfig |
| ScalarAdd | OutputQuantizeConfig |
| Subtract | OutputQuantizeConfig |
| RSqrt | OutputQuantizeConfig |
| Mean <br> Mean (variance) | OutputQuantizeConfig <br> VarianceQuantizeConfig |
| BroadcastToken | WeightQuantizeConfig |
| AddPositionalEmbedding | WeightQuantizeConfig |
| Scale | WeightQuantizeConfig |
| Centre | WeightQuantizeConfig |
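The table uses `AllValuesQuantizer` only for the variance `Mean` and `MovingAverageQuantizer` elsewhere. Both track a float range for the tensor during training. Roughly, the moving-average variant follows an exponential moving average of each batch's min/max. The all-values variant keeps the overall min/max seen so far. The NumPy sketch below uses an assumed decay of 0.999 and a hypothetical helper name. It shows how a single outlier batch affects each estimate.

```python
import numpy as np

def track_ranges(batches, ema_decay=0.999):
    """Compare moving-average and all-values min/max tracking over a stream of batches."""
    ema_min = ema_max = None
    all_min, all_max = np.inf, -np.inf
    for b in batches:
        b_min, b_max = float(b.min()), float(b.max())
        if ema_min is None:
            ema_min, ema_max = b_min, b_max             # initialise on the first batch
        else:
            ema_min = ema_decay * ema_min + (1 - ema_decay) * b_min
            ema_max = ema_decay * ema_max + (1 - ema_decay) * b_max
        # The all-values range only ever grows
        all_min, all_max = min(all_min, b_min), max(all_max, b_max)
    return (ema_min, ema_max), (all_min, all_max)

batches = [np.full(4, 1.0)] * 99 + [np.array([0.5, 40.0])]
(ema_min, ema_max), (all_min, all_max) = track_ranges(batches)
print(f'moving average: [{ema_min:.4f}, {ema_max:.4f}]')
print(f'all values:     [{all_min:.4f}, {all_max:.4f}]')
```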

In [None]:
def get_quant_configs(model):
    layer_param_dict = {}  # stores {Layer_Name: QuantizeConfig} pairs
    scope = {}  # stores all custom objects

    for layer in model.layers:

        # Check the more specific keywords first: e.g. 'patch_encoder/add_pos_emb'
        # also contains 'add' and would otherwise get OutputQuantizeConfig.
        if any(x in layer.name for x in ['patch_encoder/cls_token', 'patch_encoder/add_pos_emb',
                                         'scale', 'centre']):
            layer_param_dict[layer.name] = {'quantize_config': WeightQuantizeConfig()}
            scope[layer.__class__.__name__] = layer.__class__

        elif 'variance' in layer.name:
            layer_param_dict[layer.name] = {'quantize_config': VarianceQuantizeConfig()}
            scope[layer.__class__.__name__] = layer.__class__

        elif any(x in layer.name for x in ['clip', 'slice', 'stop_gradient']):
            layer_param_dict[layer.name] = {'quantize_config': NoOpQuantizeConfig()}
            scope[layer.__class__.__name__] = layer.__class__

        elif any(x in layer.name for x in ['mat_mul', 'multiply', 'scalar_multiply', 'add',
                                           'scalar_add', 'mean', 'subtract', 'r_sqrt']):
            layer_param_dict[layer.name] = {'quantize_config': OutputQuantizeConfig()}
            scope[layer.__class__.__name__] = layer.__class__

    scope['NoOpQuantizeConfig'] = NoOpQuantizeConfig
    scope['OutputQuantizeConfig'] = OutputQuantizeConfig
    scope['WeightQuantizeConfig'] = WeightQuantizeConfig
    scope['VarianceQuantizeConfig'] = VarianceQuantizeConfig

    return layer_param_dict, scope
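`get_quant_configs` matches keywords as substrings of layer names, so one name can satisfy several rules. For example, `patch_encoder/add_pos_emb` contains both `add_pos_emb` and `add`. The pure-Python sketch below uses hypothetical `RULES` and `pick_config` names. It shows how, with first-match-wins dispatch, the order of the checks decides which config such a layer receives.

```python
# Ordered (keywords, config name) rules: the first rule with a keyword in the layer name wins.
RULES = [
    (['patch_encoder/cls_token', 'patch_encoder/add_pos_emb', 'scale', 'centre'],
     'WeightQuantizeConfig'),
    (['variance'], 'VarianceQuantizeConfig'),
    (['clip', 'slice', 'stop_gradient'], 'NoOpQuantizeConfig'),
    (['mat_mul', 'multiply', 'scalar_multiply', 'add', 'scalar_add', 'mean',
      'subtract', 'r_sqrt'], 'OutputQuantizeConfig'),
]

def pick_config(layer_name, rules=RULES):
    """Return the config name of the first matching rule, or None (use the API defaults)."""
    for keywords, config in rules:
        if any(k in layer_name for k in keywords):
            return config
    return None

for name in ['patch_encoder/add_pos_emb', 'add_3', 'layer_norm/variance', 'mha/query']:
    print(name, '->', pick_config(name))

# Checking the output rule first changes the result for the positional embedding layer
output_first = [RULES[3]] + RULES[:3]
print(pick_config('patch_encoder/add_pos_emb', output_first))  # OutputQuantizeConfig
```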

#### 3.2. Load the necessary API classes/functions

In [None]:
quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_apply = tfmot.quantization.keras.quantize_apply
quantize_scope = tfmot.quantization.keras.quantize_scope
Default8BitClusterPreserveQuantizeScheme = tfmot.experimental.combine.Default8BitClusterPreserveQuantizeScheme
strip_clustering_cqat = tfmot.experimental.combine.strip_clustering_cqat

#### 3.3. Apply QAT

When calling the `quantize_apply` function, if an unsupported layer is missing from `layer_param_dict` or the `scope`, TFMOT will throw an error.

In [None]:
layer_param_dict, scope = get_quant_configs(stripped_pruned_clustered_model)

# Wrap each custom layer with the corresponding QuantizeConfig:
pcqat_model = layer_wrapper(stripped_pruned_clustered_model, quantize_annotate_layer, layer_param_dict)
# Quantize the rest of the model with the API defaults:
pcqat_model = quantize_annotate_model(pcqat_model)

with quantize_scope(scope):
    pcqat_model = quantize_apply(pcqat_model, scheme=Default8BitClusterPreserveQuantizeScheme(preserve_sparsity=True))

compile_and_fit(pcqat_model, epochs=2)
pcqat_model = strip_clustering_cqat(pcqat_model)  # strip clustering variables

WEIGHTS_PATH = './ViT_PCQAT.h5'
pcqat_model.save_weights(WEIGHTS_PATH)
print('Success')

#### 3.4. Check that the weights are still pruned and clustered

In [None]:
print('Sparse weights:')
print_sparsity(pcqat_model)
print('Clustered weights:')
print_clusters(pcqat_model)

### 4. Generate an int8 TFLite file

If we attempt to generate a TFLite file directly from the fine-tuned model above:
1. Its batch size will be the training batch size (32) rather than 1.
2. It will contain operators that are unnecessary during inference. Specifically, the extra `Subtract` and `ClipMin` operators in each layer normalisation block are only needed during training and fine-tuning. They should be removed from the graph before creating the TFLite file.

Therefore, the network is redefined with a batch size of 1 and without the redundant operators. The weights of the fine-tuned optimised model are then loaded into this new model.
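The redefinition relies on the inference-graph layer normalisation computing the same function as the training graph. The NumPy sketch below covers both variants built by `layer_norm`; the helper names are hypothetical, and the TFMOT wrappers and quantisation are ignored. It shows that the two agree whenever the variance exceeds `epsilon`. They differ only for (near-)constant inputs, where the training graph clips the variance and the inference graph does not.

```python
import numpy as np

def layer_norm_train(x, gamma, beta, epsilon=0.001):
    """Training-graph layer norm: ClipMin on the variance, then a second Subtract."""
    mean = x.mean(axis=-1, keepdims=True)                     # Mean
    diff = x - mean                                           # Subtract (mean via StopGradient)
    variance = (diff * diff).mean(axis=-1, keepdims=True)     # Multiply, Mean (variance)
    inv = 1.0 / np.sqrt(np.maximum(variance, epsilon))        # ClipMin, RSqrt
    return (x - mean) * inv * gamma + beta                    # grad_subtract, Multiply, Scale, Centre

def layer_norm_infer(x, gamma, beta):
    """Inference-graph layer norm: ClipMin and the extra Subtract are removed."""
    mean = x.mean(axis=-1, keepdims=True)
    diff = x - mean
    variance = (diff * diff).mean(axis=-1, keepdims=True)
    return diff / np.sqrt(variance) * gamma + beta            # RSqrt, Multiply, Scale, Centre

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 50, 16))                          # (batch, tokens, embedding_dim)
gamma, beta = rng.standard_normal(16), rng.standard_normal(16)
max_diff = np.max(np.abs(layer_norm_train(x, gamma, beta) - layer_norm_infer(x, gamma, beta)))
print(f'max difference: {max_diff:.1e}')
```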

In [None]:
tf.keras.backend.clear_session()  # reset layer name counters

net = get_vision_transformer(input_shape=(1, 28, 28, 1),  # (batch_size, height, width, channels)
                             n_classes=10,
                             patch_size=(4, 4),
                             embedding_dim=16,
                             n_layers=2,
                             n_attention_heads=2,
                             mlp_hidden_dim=16,
                             trainable=False)
layer_param_dict, scope = get_quant_configs(net)
net = quantize_annotate_model(layer_wrapper(net, quantize_annotate_layer, layer_param_dict))
with quantize_scope(scope):
    net = quantize_apply(net)

net.load_weights(WEIGHTS_PATH, by_name=True)

In [None]:
MODEL_PATH = './ViT_PCQAT_int8.tflite'

converter = tf.lite.TFLiteConverter.from_keras_model(net)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8

# Private, experimental converter flag: keeps BatchMatMul ops as single ops instead of
# unfolding them, which improves efficiency on some devices
converter._experimental_disable_batchmatmul_unfold = True

tflite_model = converter.convert()
with open(MODEL_PATH, "wb+") as tflite_file:
    tflite_file.write(tflite_model)

### 5. Evaluate the TFLite model

In [None]:
interpreter = tf.lite.Interpreter(model_path=MODEL_PATH)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
input_scale, input_zero_point = input_details[0]['quantization']
output_scale, output_zero_point = output_details[0]['quantization']

int8_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='int8_accuracy')
progbar = tf.keras.utils.Progbar(len(X_test), stateful_metrics=['accuracy'])
for step, (img, lbl) in enumerate(zip(X_test, y_test)):
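    # Set input tensor: quantise to int8 with rounding and saturation (a plain cast would truncate)
    dtype = input_details[0]['dtype']
    img = np.round(img[np.newaxis, ...] / input_scale + input_zero_point)
    img = np.clip(img, np.iinfo(dtype).min, np.iinfo(dtype).max).astype(dtype)
    interpreter.set_tensor(input_details[0]['index'], img)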
    interpreter.invoke()

    # Get output tensor
    output_data = interpreter.get_tensor(output_details[0]['index'])
    output_data = output_scale * (output_data.astype(np.float32) - output_zero_point)

    # Update accuracy
    int8_accuracy.update_state(lbl, output_data)
    progbar.update(step + 1, values=[('accuracy', int8_accuracy.result().numpy())])

print('Accuracy:', int8_accuracy.result().numpy())