# Sequence to Sequence Transformer model optimisation using TFMOT (Quantization Aware Training)

Example notebook demonstrating how TFMOT can be used to optimise complex sequence to sequence transformer models.

## Background

The sequence to sequence transformer is one of the earliest transformer model architectures. The core idea behind the Transformer model is self-attention: the ability to attend to different positions of the input sequence to compute a representation of that sequence. The paper ["Attention Is All You Need"](https://arxiv.org/pdf/1706.03762.pdf) gives deeper insight into the transformer model and its self-attention mechanism.

<img src="https://deepfrench.gitlab.io/deep-learning-project/resources/transformer.png" alt="Sequence to sequence transformer" width="1000" align="center" title="Transformer">

[1] The above image was taken from [here](https://deepfrench.gitlab.io/deep-learning-project/)
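
To make the idea of self-attention concrete, here is a minimal pure-Python sketch of scaled dot-product attention over a short sequence. It is an illustration only: it assumes a single head and uses the token vectors directly as queries, keys and values, whereas the real model applies learned projections and multiple heads. The helper names `softmax` and `self_attention` are made up for this example.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(x):
    """Single-head self-attention without learned projections (assumption).

    x: list of token vectors with shape (seq_len, dim).
    Returns (outputs, weights), where weights[i][j] is how strongly
    position i attends to position j.
    """
    dim = len(x[0])
    weights = []
    for query in x:
        # Scaled dot product between this query and every key
        scores = [
            sum(a * b for a, b in zip(query, key)) / math.sqrt(dim) for key in x
        ]
        weights.append(softmax(scores))
    # Each output is the attention-weighted average of the value vectors
    outputs = [
        [sum(w * value[d] for w, value in zip(row, x)) for d in range(dim)]
        for row in weights
    ]
    return outputs, weights


tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
outputs, weights = self_attention(tokens)
for row in weights:
    print([round(w, 3) for w in row])
```

Each row of `weights` sums to 1, so every output vector is a convex combination of the input token vectors.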

#### In this notebook:

* Train the Transformer model from the [Keras tutorial](https://keras.io/examples/nlp/neural_machine_translation_with_transformer/)
* Rewrite that model as a functional FP32 model
* Perform Quantization Aware Training (QAT) on the FP32 model
* Create and test the tflite model generated from the FP32 model after performing QAT on it

Note: This tutorial reuses some code and explanation from the original [Keras tutorial](https://keras.io/examples/nlp/neural_machine_translation_with_transformer/).

#### TFMOT limitations
- Subclassed models are not supported. Only sequential and functional model definitions are supported. (Pruning, Clustering & QAT)
- Custom subclassed layers are not supported. (Clustering & QAT)
    - Clustering will only work with subclassed layers if the weight variables you have to cluster are not nested within another layer (e.g. MHA).
    - QAT works correctly if the subclassed layer performs only 1 operation.
- Low-level TensorFlow operators such as `tf.linalg.matmul` are not supported. (Only for QAT)
    - QAT expects all quantised layers to be a subclass of `tf.keras.layers.Layer`.

### 1. Setup

In [None]:
import pathlib
import random
import tempfile
import zipfile
import re
import os
import nltk
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import TextVectorization
import tensorflow_model_optimization as tfmot
from collections import defaultdict

def reset_random_seeds():
   os.environ['PYTHONHASHSEED']=str(2)
   tf.random.set_seed(2)
   np.random.seed(2)
   random.seed(2)

reset_random_seeds()

print('TensorFlow version: {}'.format(tf.__version__))
print('TFMOT version: {}'.format(tfmot.__version__))
print("NLTK version: {}".format(nltk.__version__))
print("Numpy version: {}".format(np.__version__))

### 2. Downloading the data

The dataset used here is an English to Spanish translation dataset.

In [None]:
text_file = keras.utils.get_file(
    fname="spa-eng.zip",
    origin="http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip",
    extract=True,
)
text_file = pathlib.Path(text_file).parent / "spa-eng" / "spa.txt"

### 3. Parsing the data

Each target sentence (which is in Spanish) has the `[start]` and `[end]` tokens prepended and appended, respectively, at this stage.

In [None]:
with open(text_file) as f:
    lines = f.read().split("\n")[:-1]
text_pairs = []
for line in lines:
    eng, spa = line.split("\t")
    spa = "[start] " + spa + " [end]"
    text_pairs.append((eng, spa))

In [None]:
for _ in range(5):
    print(random.choice(text_pairs))

Split the dataset into training, validation and test sets (70%, 15% and 15% of the pairs, respectively).

In [None]:
random.shuffle(text_pairs)
num_val_samples = int(0.15 * len(text_pairs))
num_train_samples = len(text_pairs) - 2 * num_val_samples
train_pairs = text_pairs[:num_train_samples]
val_pairs = text_pairs[num_train_samples : num_train_samples + num_val_samples]
test_pairs = text_pairs[num_train_samples + num_val_samples :]

print(f"{len(text_pairs)} total pairs")
print(f"{len(train_pairs)} training pairs")
print(f"{len(val_pairs)} validation pairs")
print(f"{len(test_pairs)} test pairs")

### 4. Vectorizing the text data

Vectorization refers to the preprocessing step where text features are mapped to integer sequences, where each integer represents the index of a word in a vocabulary. For this, the [`tf.keras.layers.TextVectorization`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/TextVectorization) layer is used.

In our case, vectorization for English sequences is a little different from that of Spanish sequences:

- For English string sequences, the default standardization is used, which lowercases the text and strips all punctuation characters
- For Spanish string sequences, a custom standardization is used, which lowercases the text and strips every character that is not a lowercase letter `a-z`, a space, or one of `.`, `?`, `!`, `,`, `¿`, `[`, `]`
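
The following pure-Python sketch mimics the two steps described above on a toy corpus: the custom Spanish standardization, followed by mapping words to vocabulary indices with padding. It is a simplified stand-in and not the `TextVectorization` implementation. The helper names (`standardize_spanish`, `build_vocab`, `vectorize`) are made up for illustration. The index conventions (0 for padding, 1 for out-of-vocabulary words) follow the layer's `int` output mode.

```python
import re
from collections import Counter


def standardize_spanish(text):
    """Lowercase, then drop every character outside ' a-z.?!,¿[]'."""
    return re.sub(r"[^ a-z.?!,¿\[\]]", "", text.lower())


def build_vocab(sentences, max_tokens):
    """Most frequent words first; index 0 is padding, index 1 is unknown."""
    counts = Counter(word for s in sentences for word in s.split())
    most_common = [word for word, _ in counts.most_common(max_tokens - 2)]
    return ["", "[UNK]"] + most_common


def vectorize(sentence, vocab, seq_len):
    """Map words to indices, truncate to seq_len, then pad with zeros."""
    index = {word: i for i, word in enumerate(vocab)}
    ids = [index.get(word, 1) for word in sentence.split()][:seq_len]
    return ids + [0] * (seq_len - len(ids))


raw = ["[start] ¿Qué hora es? [end]", "[start] Es tarde. [end]"]
corpus = [standardize_spanish(s) for s in raw]
vocab = build_vocab(corpus, max_tokens=10)
ids = vectorize(corpus[0], vocab, seq_len=8)
print(corpus[0])
print(ids)
```

Note that the accented `é` is removed by the regex, because only the unaccented letters `a-z` are kept.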

In [None]:
vocab_size = 15000
seq_len = 20
batch_size = 64
embed_dim = 256
latent_dim = 2048
num_heads = 8


def custom_standardization(input_string):
    lowercase = tf.strings.lower(input_string)
    # The following regex replaces with "" every character
    # that is not one of the following:
    # 1. Lower case alphabet
    # 2. Space
    # 3. One of these characters: ".", "?", "!", ",", "¿", "[", "]"
    return tf.strings.regex_replace(lowercase, r'[^ a-z.?!,¿\[\]]', "")


eng_vectorization = TextVectorization(
    max_tokens=vocab_size, output_mode="int", output_sequence_length=seq_len,
)
spa_vectorization = TextVectorization(
    max_tokens=vocab_size,
    output_mode="int",
    output_sequence_length=seq_len + 1,
    standardize=custom_standardization,
)

train_eng_texts = [pair[0] for pair in train_pairs]
train_spa_texts = [pair[1] for pair in train_pairs]
eng_vectorization.adapt(train_eng_texts)
spa_vectorization.adapt(train_spa_texts)

At each training step, the model will seek to predict target words N+1 (and beyond) using the source (or the input) sentence and the target words 0 to N. For this reason, each dataset element is an (`inputs`, `targets`) pair:

- `inputs`:

    After vectorization, our dataset is formatted to include the following four entries in `inputs` (`inputs` is essentially a dictionary of four inputs):

    * encoder_inputs : which contains the vectorized English sentence data
    * decoder_inputs : which contains the vectorized Spanish (target) sentence data, i.e. target_sentence[:, :-1]. It is also the target sentence "so far", that is to say, the words 0 to N used to predict word N+1 (and beyond) in the target sentence.
    * encoder_masks  : which contains the corresponding mask data for encoder_inputs
    * decoder_masks  : which contains the corresponding mask data for decoder_inputs

    Please note that the two mask inputs are only required for the custom FP32 functional model, as the original Keras model is able to generate its own masks. The original model therefore ignores the two mask inputs, and no action is needed from the user. <br><br>
    
- `targets`:

    After vectorization, our dataset is formatted to assign the target sentence offset by one (i.e. target_sentence[:, 1:]) as the `targets`. In other words, this is what the model will try to predict.
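
As a small illustration of the offset-by-one layout described above, the sketch below builds one (`inputs`, `targets`) pair from plain Python lists. The token ids are hypothetical (2 for `[start]`, 3 for `[end]`, 0 for padding), and `make_example` is a made-up helper that works on a single sentence rather than a batch.

```python
def make_example(source_ids, target_ids):
    """Build one (inputs, targets) pair from padded token-id lists (0 = padding)."""
    decoder_inputs = target_ids[:-1]  # target words 0..N ("so far")
    targets = target_ids[1:]          # target words 1..N+1 (to be predicted)
    inputs = {
        "encoder_inputs": source_ids,
        "encoder_masks": [float(t != 0) for t in source_ids],
        "decoder_inputs": decoder_inputs,
        "decoder_masks": [float(t != 0) for t in decoder_inputs],
    }
    return inputs, targets


# Hypothetical ids: 2 = [start], 3 = [end], 0 = padding
inputs, targets = make_example([7, 8, 9, 0], [2, 11, 12, 3, 0])
print(inputs["decoder_inputs"], targets)
```

At every position `i`, `targets[i]` is the token that follows `decoder_inputs[i]` in the target sentence.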

In [None]:
def format_dataset(eng, spa):
    eng = eng_vectorization(eng)
    spa = spa_vectorization(spa)

    # Create input masks
    encoder_masks=tf.cast(tf.not_equal(np.int64(0),eng),tf.float32)
    decoder_masks=tf.cast(tf.not_equal(np.int64(0),spa[:, :-1]),tf.float32)
    
    return ({"encoder_inputs": eng, "encoder_masks": encoder_masks, "decoder_inputs": spa[:, :-1], "decoder_masks": decoder_masks}, spa[:, 1:])

def make_dataset(pairs, batch_size=64):
    eng_texts, spa_texts = zip(*pairs)
    eng_texts = list(eng_texts)
    spa_texts = list(spa_texts)
    dataset = tf.data.Dataset.from_tensor_slices((eng_texts, spa_texts))
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.map(format_dataset)
    
    return dataset.shuffle(2048).prefetch(16).cache()


train_ds = make_dataset(train_pairs, batch_size)
val_ds = make_dataset(val_pairs, batch_size)
test_ds = make_dataset(test_pairs, batch_size)

In [None]:
for inputs, targets in train_ds.take(1):
    print(f'inputs["encoder_inputs"].shape: {inputs["encoder_inputs"].shape}')
    print(f'inputs["decoder_inputs"].shape: {inputs["decoder_inputs"].shape}')
    print(f'inputs["encoder_masks"].shape: {inputs["encoder_masks"].shape}')
    print(f'inputs["decoder_masks"].shape: {inputs["decoder_masks"].shape}')
    print(f"targets.shape: {targets.shape}")

### 5. Utility functions

Typically, BLEU score is used to measure the quality of a translation.
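
BLEU is built on clipped n-gram precisions, combined across several n-gram sizes and multiplied by a brevity penalty. The sketch below shows only the clipped precision part, for a single reference and a single hypothesis. It is a simplified illustration and not a replacement for `nltk.translate.bleu_score.corpus_bleu`. The helper names and the example sentences are made up.

```python
from collections import Counter


def ngrams(tokens, n):
    """All contiguous n-grams of a token list, as tuples."""
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]


def clipped_precision(reference, hypothesis, n):
    """Share of hypothesis n-grams found in the reference.

    Each n-gram is credited at most as many times as it occurs in the
    reference, so repeating a correct word does not inflate the score.
    """
    hyp_counts = Counter(ngrams(hypothesis, n))
    ref_counts = Counter(ngrams(reference, n))
    matched = sum(min(count, ref_counts[gram]) for gram, count in hyp_counts.items())
    total = sum(hyp_counts.values())
    return matched / total if total else 0.0


reference = "el gato está en la casa".split()
hypothesis = "el gato en la casa".split()
p1 = clipped_precision(reference, hypothesis, 1)
p2 = clipped_precision(reference, hypothesis, 2)
print(p1, p2)
```

Here all five hypothesis unigrams occur in the reference, while only three of its four bigrams do, so the bigram precision is lower.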

In [None]:
def bleu_score(real_text, predicted_text):
    '''Get BLEU score'''
    return (nltk.translate.bleu_score.corpus_bleu(real_text,predicted_text))

For decoding (in other words, translating a source sentence to a target sentence), we provide a vectorized source sentence as `encoder_inputs` and a vectorized `[start]` token (padded to match the right sequence length) as the `decoder_inputs`. We then repeatedly generate the next token until we hit the `[end]` token.

A key thing to note is that in the custom FP32 functional model used in this notebook `encoder_masks` and `decoder_masks` are also fed into the model.
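
The decoding loop can be sketched independently of the model. In the example below, a hypothetical lookup table (`TABLE`) stands in for the transformer and returns a score for each candidate next token. The real functions in this notebook instead call the Keras or tflite model and take the `argmax` over the vocabulary at position `i`.

```python
def greedy_decode(next_token_scores, start_token, end_token, max_len):
    """Append the highest-scoring next token until end_token or max_len."""
    decoded = [start_token]
    for _ in range(max_len):
        scores = next_token_scores(decoded)
        best = max(scores, key=scores.get)
        decoded.append(best)
        if best == end_token:
            break
    return decoded


# Hypothetical stand-in for the model: scores depend only on the last token
TABLE = {
    "[start]": {"hola": 0.9, "mundo": 0.1},
    "hola": {"mundo": 0.8, "[end]": 0.2},
    "mundo": {"[end]": 1.0},
}


def toy_model(decoded_so_far):
    return TABLE[decoded_so_far[-1]]


result = greedy_decode(toy_model, "[start]", "[end]", max_len=20)
print(" ".join(result))
```

Greedy decoding keeps only the single best token at each step; it is simple and fast, but it does not explore alternative translations.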

In [None]:
def get_text_result(model, num_samples_to_eval =200, no_input_masks=False):
    '''
    Function to calculate BLEU score on test set

    num_samples_to_eval: Represents the total number of test sentences to
                         consider during evaluation. If you want the entire 
                         test set to be used for evaluation then set 
                         num_samples_to_eval = -1
    
    no_input_masks: Set as True for the original transformer model from
                    keras example
    '''

    spa_vocab = spa_vectorization.get_vocabulary()
    spa_index_lookup = dict(zip(range(len(spa_vocab)), spa_vocab))
    max_decoded_sentence_length = 20

    def decode_sequence_func(input_sentence):

        tokenized_input_sentence = eng_vectorization([input_sentence])
        encoder_mask = tf.cast(tf.not_equal(np.int64(0),tokenized_input_sentence), tf.float32)

        decoded_sentence = "[start]"
        for i in range(max_decoded_sentence_length):
            tokenized_target_sentence = spa_vectorization([decoded_sentence])[:, :-1]
            decoder_mask=tf.cast(tf.not_equal(np.int64(0),tokenized_target_sentence), tf.float32)
            if no_input_masks:
                predictions = model([tokenized_input_sentence, tokenized_target_sentence])
            else:
                predictions = model([tokenized_input_sentence, encoder_mask, tokenized_target_sentence,decoder_mask])
            sampled_token_index = np.argmax(predictions[0, i, :])
            sampled_token = spa_index_lookup[sampled_token_index]
            decoded_sentence += " " + sampled_token

            if sampled_token == "[end]":
                break

        return decoded_sentence


    hypothesis= []
    references = []
    test_sample_count = sum(1 for e in test_pairs) 
    progbar = tf.keras.utils.Progbar(test_sample_count if num_samples_to_eval == -1 else num_samples_to_eval)

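    # A slice ending at -1 would drop the last pair, so handle -1 explicitly
    eval_pairs = test_pairs if num_samples_to_eval == -1 else test_pairs[:num_samples_to_eval]
    for step, (inp, target) in enumerate(eval_pairs):
        translated = decode_sequence_func(inp)
        target=target.lower()
        target=re.sub(r'[^ a-z.?!,¿\[\]]', "",target)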
        hypothesis.append(translated.split()[1:-1])
        references.append([target.split()[1:-1]])
        progbar.update(step + 1)

    print(str("Bleu Score: ") + str(bleu_score(references[:], hypothesis[:])))

    # Print the first 10 actual and predicted Spanish translations for a sanity check
    for i in range(min(10, len(hypothesis))):
        print(references[i][0])
        print(hypothesis[i])
        print("-----------------------\n")


Suggestion: When running inference on a tflite file, make sure that the scale, zero_point and data type used for each input and output match the values reported by the interpreter.
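
The scale and zero point define an affine mapping between real values and int8 values. The sketch below shows that mapping in plain Python. The parameter values are hypothetical; in practice they come from the `quantization` entry of the interpreter's input and output details. This sketch rounds to the nearest integer and clamps to the int8 range, which is a common convention but an assumption here, since the helper functions in this notebook cast directly to the tensor's data type.

```python
def quantize(values, scale, zero_point, qmin=-128, qmax=127):
    """Real values -> int8 values: q = round(x / scale) + zero_point, clamped."""
    return [min(qmax, max(qmin, round(v / scale) + zero_point)) for v in values]


def dequantize(q_values, scale, zero_point):
    """int8 values -> real values: x = scale * (q - zero_point)."""
    return [scale * (q - zero_point) for q in q_values]


# Hypothetical quantization parameters for a mask input with values in [0, 1]
scale, zero_point = 1.0 / 255.0, -128
mask = [1.0, 1.0, 0.0]
q_mask = quantize(mask, scale, zero_point)
restored = dequantize(q_mask, scale, zero_point)
print(q_mask, restored)
```

With these parameters a mask value of 1.0 maps to 127 and 0.0 maps to -128, and dequantizing recovers the original values up to floating-point error.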

In [None]:
def get_text_result_tflite(model_path, input_type = 'int8/32', output_type = 'int8', num_samples_to_eval = 200):
    '''
    Function to calculate BLEU score for a given tflite file on the test set

    model_path: Path to the tflite file

    input_type: Could be float32 or int8/32. If the inputs in tflite graph
                are float32 set this value to 'float32' but if inputs are
                int8 (mask inputs) and int32 (non-mask inputs) set this
                value to 'int8/32'.

    output_type: Could be float32 or int8. If the outputs in tflite graph
                 are float32 set this value to 'float32' but if output
                 are int8 set this value to 'int8'.
                
    num_samples_to_eval: Evaluating the entire test set takes a long time,
                         so only the first 200 samples are evaluated by
                         default. To evaluate the entire test set, set
                         this value to -1
    '''
    assert(input_type in ['float32', 'int8/32']), "input_type not supported"
    assert(output_type in ['float32', 'int8']), "output_type not supported"

    print('Performing BLEU evaluation for tflite file at {}'.format(model_path))

    spa_vocab = spa_vectorization.get_vocabulary()
    spa_index_lookup = dict(zip(range(len(spa_vocab)), spa_vocab))
    max_decoded_sentence_length = seq_len

    interpreter = tf.lite.Interpreter(model_path=model_path)

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    input_scale_1, input_zero_point_1 = input_details[0]['quantization']
    input_scale_2, input_zero_point_2 = input_details[1]['quantization']
    input_scale_3, input_zero_point_3 = input_details[2]['quantization']
    input_scale_4, input_zero_point_4 = input_details[3]['quantization']
    output_scale, output_zero_point = output_details[0]['quantization']

    interpreter.allocate_tensors()

    def decode_sequence_func(input_sentence):

        input_1 = eng_vectorization([input_sentence])
        input_2 = tf.cast(tf.not_equal(np.int64(0),input_1), tf.float32)
        if input_type == 'int8/32':
            input_2 = input_2/ input_scale_2 + input_zero_point_2

        decoded_sentence = "[start]"

        for i in range(max_decoded_sentence_length):
            input_3 = spa_vectorization([decoded_sentence])[:, :-1]
            input_4=tf.cast(tf.not_equal(np.int64(0),input_3), tf.float32)

            # Set input tensor
            interpreter.set_tensor(input_details[0]['index'], tf.cast(input_1, input_details[0]['dtype']))

            # Set input tensor
            interpreter.set_tensor(input_details[1]['index'], tf.cast(input_2, input_details[1]['dtype']))

            # Set input tensor
            interpreter.set_tensor(input_details[2]['index'], tf.cast(input_3, input_details[2]['dtype']))

            # Set input tensor
            if input_type == 'int8/32':
                input_4 = input_4/ input_scale_4 + input_zero_point_4
            interpreter.set_tensor(input_details[3]['index'], tf.cast(input_4, input_details[3]['dtype']))

            interpreter.invoke()
            
            # Get output tensor
            output_data = interpreter.get_tensor(output_details[0]['index'])
            predictions = output_data.astype(np.float32)
            if output_type == 'int8':
                predictions = output_scale * (predictions - output_zero_point)
            
            sampled_token_index = np.argmax(predictions[0, i, :])
            sampled_token = spa_index_lookup[sampled_token_index]
            decoded_sentence += " " + sampled_token

            if sampled_token == "[end]":
                break

        return decoded_sentence


    hypothesis= []
    references = []
    test_sample_count = sum(1 for e in test_pairs) 
    progbar = tf.keras.utils.Progbar(test_sample_count if num_samples_to_eval == -1 else num_samples_to_eval)

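    # A slice ending at -1 would drop the last pair, so handle -1 explicitly
    eval_pairs = test_pairs if num_samples_to_eval == -1 else test_pairs[:num_samples_to_eval]
    for step, (inp, target) in enumerate(eval_pairs):
        translated = decode_sequence_func(inp)
        target=target.lower()
        target=re.sub(r'[^ a-z.?!,¿\[\]]', "",target)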
        hypothesis.append(translated.split()[1:-1])
        references.append([target.split()[1:-1]])
        progbar.update(step + 1)

    print(str("Bleu Score: ") + str(bleu_score(references[:], hypothesis[:])))
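    # Print the first 10 actual and predicted Spanish translations for a sanity check
    for i in range(min(10, len(hypothesis))):
        print(references[i][0])
        print(hypothesis[i])
        print("-----------------------\n")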


In [None]:
def get_tflite_accuracy(model_path, input_type = 'int8/32', output_type = 'int8', num_samples_to_eval = 200):
    '''
    Function to calculate accuracy of a given tflite file on the test set

    model_path: Path to the tflite file

    input_type: Could be float32 or int8/32. If the inputs in tflite graph
                are float32 set this value to 'float32' but if inputs are
                int8 (mask inputs) and int32 (non-mask inputs) set this
                value to 'int8/32'.

    output_type: Could be float32 or int8. If the outputs in tflite graph
                 are float32 set this value to 'float32' but if output
                 are int8 set this value to 'int8'.
                
    num_samples_to_eval: Evaluating the entire test set takes a long time,
                         so only the first 200 samples are evaluated by
                         default. To evaluate the entire test set, set
                         this value to -1
    '''
    assert(input_type in ['float32', 'int8/32']), "input_type not supported"
    assert(output_type in ['float32', 'int8']), "output_type not supported"

    print('Performing accuracy evaluation for tflite file at {}'.format(model_path))

    interpreter = tf.lite.Interpreter(model_path=model_path)

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    input_scale_1, input_zero_point_1 = input_details[0]['quantization']
    input_scale_2, input_zero_point_2 = input_details[1]['quantization']
    input_scale_3, input_zero_point_3 = input_details[2]['quantization']
    input_scale_4, input_zero_point_4 = input_details[3]['quantization']
    output_scale, output_zero_point = output_details[0]['quantization']
    interpreter.allocate_tensors()

    test_ds_tflite = make_dataset(test_pairs, 1)
    accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')
    progbar = tf.keras.utils.Progbar(sum(1 for e in test_ds_tflite) if num_samples_to_eval == -1 else num_samples_to_eval, stateful_metrics=['accuracy'])

    for step, (input, target) in enumerate(test_ds_tflite):

        # Set input tensor
        input_1 = input['encoder_inputs']
        interpreter.set_tensor(input_details[0]['index'], tf.cast(input_1, input_details[0]['dtype']))

        # Set input tensor
        input_2=input['encoder_masks']
        if input_type == 'int8/32':
            input_2 = tf.cast(input_2, tf.float32)
            input_2 = input_2/ input_scale_2 + input_zero_point_2
        interpreter.set_tensor(input_details[1]['index'], tf.cast(input_2, input_details[1]['dtype']))

        # Set input tensor
        input_3 = input['decoder_inputs']
        interpreter.set_tensor(input_details[2]['index'], tf.cast(input_3, input_details[2]['dtype']))

        # Set input tensor
        input_4=input['decoder_masks']
        if input_type == 'int8/32':
            input_4 = tf.cast(input_4, tf.float32)
            input_4 = input_4/ input_scale_4 + input_zero_point_4
        interpreter.set_tensor(input_details[3]['index'], tf.cast(input_4, input_details[3]['dtype']))
        interpreter.invoke()
        
        # Get output tensor
        output_data = interpreter.get_tensor(output_details[0]['index'])
        output_data = output_data.astype(np.float32)
        if output_type == 'int8':
            output_data = output_scale * (output_data - output_zero_point)
        
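        # Update accuracy, weighting each position by the decoder padding mask
        # (1.0 for real tokens, 0.0 for padding) rather than by the raw token ids
        mask = input['decoder_masks']
        accuracy.update_state(target, output_data, mask)
        progbar.update(step + 1, values=[('accuracy', accuracy.result().numpy())])
        
        # Stop once exactly num_samples_to_eval samples have been evaluated
        if step + 1 == num_samples_to_eval:
            break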

Use the following function to get the size of the tflite file when zipped

In [None]:
def get_gzipped_model_size(file):
  '''Returns the size of a gzipped tflite file in kilobytes'''

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)

  return os.path.getsize(zipped_file)/1000

### 6. Functions related to Training the model

The loss function used is a masked sparse categorical crossentropy loss (which uses `tf.keras.losses.SparseCategoricalCrossentropy`, but with masks).
The loss function needs the masks to be propagated correctly through the model layers down to the loss function. The custom FP32 model was not able to do this correctly, so a custom training loop is used to calculate the loss correctly.
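
To show what the mask does to the loss, here is a minimal pure-Python sketch of a masked sparse categorical crossentropy for a single sentence. One assumption to flag: this sketch averages over the unmasked positions only, whereas the Keras loss object called with a sample weight may normalise differently depending on its reduction setting. The function name and the example values are made up.

```python
import math


def masked_sparse_categorical_crossentropy(y_true, y_pred, mask):
    """Average negative log-likelihood over positions where mask is 1.0.

    y_true: list of target token ids, one per position
    y_pred: list of probability rows (one distribution per position)
    mask:   list of 1.0 (real token) or 0.0 (padding) per position
    """
    per_token = [-math.log(probs[label]) for label, probs in zip(y_true, y_pred)]
    weighted = [loss * m for loss, m in zip(per_token, mask)]
    real_tokens = sum(mask)
    return sum(weighted) / real_tokens if real_tokens else 0.0


y_true = [1, 0, 2]
y_pred = [[0.1, 0.8, 0.1], [0.5, 0.25, 0.25], [0.3, 0.3, 0.4]]
mask = [1.0, 1.0, 0.0]  # the last position is padding
loss = masked_sparse_categorical_crossentropy(y_true, y_pred, mask)
print(round(loss, 4))
```

Because the last position is masked out, changing the prediction at that position leaves the loss unchanged.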

In [None]:
epochs = 11

def evaluate(model_to_eval, training=False):

    val_loss = tf.keras.metrics.SparseCategoricalCrossentropy()
    val_acc = tf.keras.metrics.SparseCategoricalAccuracy()

    @tf.function
    def eval_step(inp, y_true):
        preds = model_to_eval(inp, training=training)
        # masked loss
        val_loss.update_state(y_true, preds,tf.cast(tf.not_equal(np.int64(0),inp['decoder_inputs']),tf.float32))  
        # masked accuracy
        val_acc.update_state(y_true, preds,tf.cast(tf.not_equal(np.int64(0),inp['decoder_inputs']),tf.float32))  

    for step, (inp, y_true) in enumerate(val_ds):
        eval_step(inp, y_true)

    return {'loss': val_loss.result().numpy(), 'accuracy': val_acc.result().numpy()}


def train(model_to_train, save_best_weights =True, model_type='original', lr=1e-3, epochs = epochs):

    if model_type == 'original':
        ckpt_path = './eng_spa_transformer_qat_tutorial_model.h5'
    elif model_type == 'fp32':
        ckpt_path = './eng_spa_transformer_qat_tutorial_fp32_model.h5'
    elif model_type == 'qat':
        ckpt_path = './eng_spa_transformer_qat_tutorial_qat_model.h5'
    else:
        print('Please select the correct model type!!')
        return None
    
    print('Training (save_best_weights={}, model_type={})'.format(save_best_weights, model_type))

    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
    optimiser = tf.keras.optimizers.Adam(learning_rate=lr)
    train_acc = tf.keras.metrics.SparseCategoricalAccuracy()
    model_to_train.optimizer = optimiser

    @tf.function
    def train_step(inp, y_true):
        mask =tf.cast(tf.not_equal(np.int64(0),inp['decoder_inputs']),tf.float32)
        preds=None
        loss=None
        
        with tf.GradientTape() as tape:
            preds = model_to_train(inp, training=True)
            # Masked loss
            loss = loss_fn(y_true, preds, mask)
        # Compute and apply the gradients outside the tape context so that
        # these operations are not recorded by the tape
        grads = tape.gradient(loss, model_to_train.trainable_weights)
        optimiser.apply_gradients(zip(grads, model_to_train.trainable_weights))

        # Masked accuracy    
        train_acc.update_state(y_true, preds, mask)
        return loss

    max_val = float('-inf')

    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch + 1, epochs), flush=True)
        # Train
        progbar = tf.keras.utils.Progbar(len(train_ds), interval=.5,
                                        stateful_metrics=['acc'])        

        for step, (inp, y_true) in enumerate(train_ds):
                loss = train_step(inp, y_true)
                progbar.update(step + 1, values=[('loss', loss),
                                                ('acc', train_acc.result())])

        # Evaluate
        val_results = evaluate(model_to_train)

        validation_accuracy = val_results['accuracy']
        print('Validation accuracy: {}'.format(validation_accuracy))

        if save_best_weights and validation_accuracy > max_val:
            
            print('Best validation accuracy so far, saving weights')
            model_to_train.save_weights(ckpt_path)
            max_val = validation_accuracy

        train_acc.reset_states()        

    if not save_best_weights:
        model_to_train.save_weights(ckpt_path)
    # Load weights
    model_to_train.load_weights(ckpt_path)

### 7. Building the original Transformer Keras model mentioned in the [Keras tutorial](https://keras.io/examples/nlp/neural_machine_translation_with_transformer/)


(a) Define the custom layers for the model

In [None]:
class TransformerEncoder(layers.Layer):
    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):
        super(TransformerEncoder, self).__init__(**kwargs)
        self.embed_dim = embed_dim
        self.dense_dim = dense_dim
        self.num_heads = num_heads
        self.attention = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.dense_proj = keras.Sequential(
            [layers.Dense(dense_dim, activation="relu"), layers.Dense(embed_dim),]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.supports_masking = True

    def call(self, inputs, mask=None):
        # Default to no attention mask when the layer is called without one
        padding_mask = None
        if mask is not None:
            padding_mask = tf.cast(mask[:, tf.newaxis, tf.newaxis, :], dtype="int32")
        attention_output = self.attention(
            query=inputs, value=inputs, key=inputs, attention_mask=padding_mask
        )
        proj_input = self.layernorm_1(inputs + attention_output)
        proj_output = self.dense_proj(proj_input)
        return self.layernorm_2(proj_input + proj_output)
    
    def get_config(self):
        config = super().get_config()
        config.update({'embed_dim': self.embed_dim,
                       'dense_dim': self.dense_dim,
                       'num_heads': self.num_heads})
        return config


class PositionalEmbedding(layers.Layer):
    def __init__(self, seq_len, vocab_size, embed_dim, **kwargs):
        super(PositionalEmbedding, self).__init__(**kwargs)
        self.token_embeddings = layers.Embedding(
            input_dim=vocab_size, output_dim=embed_dim
        )
        self.position_embeddings = layers.Embedding(
            input_dim=seq_len, output_dim=embed_dim
        )
        self.seq_len = seq_len
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

    def call(self, inputs):
        length = tf.shape(inputs)[-1]
        positions = tf.range(start=0, limit=length, delta=1)
        embedded_tokens = self.token_embeddings(inputs)
        embedded_positions = self.position_embeddings(positions)
        return embedded_tokens + embedded_positions

    def compute_mask(self, inputs, mask=None):
        return tf.math.not_equal(inputs, 0)

    def get_config(self):
        config = super().get_config()
        config.update({'embed_dim': self.embed_dim,
                       'vocab_size': self.vocab_size,
                       'seq_len': self.seq_len})
        return config


class TransformerDecoder(layers.Layer):
    def __init__(self, embed_dim, latent_dim, num_heads, **kwargs):
        super(TransformerDecoder, self).__init__(**kwargs)
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim
        self.num_heads = num_heads
        self.attention_1 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.attention_2 = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )
        self.dense_proj = keras.Sequential(
            [layers.Dense(latent_dim, activation="relu"), layers.Dense(embed_dim),]
        )
        self.layernorm_1 = layers.LayerNormalization()
        self.layernorm_2 = layers.LayerNormalization()
        self.layernorm_3 = layers.LayerNormalization()
        self.supports_masking = True

    def call(self, inputs, encoder_outputs, mask=None):
        causal_mask = self.get_causal_attention_mask(inputs)
        # Default to no padding mask when the layer is called without one
        padding_mask = None
        if mask is not None:
            padding_mask = tf.cast(mask[:, tf.newaxis, :], dtype="int32")
            padding_mask = tf.minimum(padding_mask, causal_mask)

        attention_output_1 = self.attention_1(
            query=inputs, value=inputs, key=inputs, attention_mask=causal_mask
        )
        out_1 = self.layernorm_1(inputs + attention_output_1)

        attention_output_2 = self.attention_2(
            query=out_1,
            value=encoder_outputs,
            key=encoder_outputs,
            attention_mask=padding_mask,
        )
        out_2 = self.layernorm_2(out_1 + attention_output_2)

        proj_output = self.dense_proj(out_2)
        return self.layernorm_3(out_2 + proj_output)

    def get_causal_attention_mask(self, inputs):
        input_shape = tf.shape(inputs)
        batch_size, seq_len = input_shape[0], input_shape[1]
        i = tf.range(seq_len)[:, tf.newaxis]
        j = tf.range(seq_len)
        mask = tf.cast(i >= j, dtype="int32")
        mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))
        mult = tf.concat(
            [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)],
            axis=0,
        )
        return tf.tile(mask, mult)
    
    def get_config(self):
        config = super().get_config()
        config.update({'embed_dim': self.embed_dim,
                       'latent_dim': self.latent_dim,
                       'num_heads': self.num_heads})
        return config
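
The causal mask built by `get_causal_attention_mask`, and the way the decoder combines it with the padding mask through an element-wise minimum, can be sketched for a single sentence in plain Python. This is an illustration of the masking logic only, without the batch dimension or any TensorFlow operations. The helper names below are made up.

```python
def causal_mask(seq_len):
    """mask[i][j] == 1 when query position i may attend to key position j (j <= i)."""
    return [[1 if i >= j else 0 for j in range(seq_len)] for i in range(seq_len)]


def combine_with_padding(causal, padding):
    """Element-wise minimum: a position stays visible only if both masks allow it."""
    return [[min(c, p) for c, p in zip(row, padding)] for row in causal]


causal = causal_mask(4)
combined = combine_with_padding(causal, [1, 1, 1, 0])  # last token is padding
for row in combined:
    print(row)
```

The lower-triangular shape prevents a position from attending to later positions, and the minimum with the padding mask additionally hides the padded position in every row.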

(b) Build the end-to-end model

In [None]:
def get_encoder_decoder_model():
    encoder_inputs = keras.Input(shape=(20,), dtype="int64", name="encoder_inputs")
    x = PositionalEmbedding(seq_len, vocab_size, embed_dim)(encoder_inputs)
    encoder_outputs = TransformerEncoder(embed_dim, latent_dim, num_heads)(x)
    decoder_inputs = keras.Input(shape=(20,), dtype="int64", name="decoder_inputs")
    encoded_seq_inputs = encoder_outputs
    x = PositionalEmbedding(seq_len, vocab_size, embed_dim)(decoder_inputs)
    x = TransformerDecoder(embed_dim, latent_dim, num_heads)(x, encoded_seq_inputs)
    x = layers.Dropout(0.5)(x)
    decoder_outputs = layers.Dense(vocab_size, activation="softmax")(x)
    
    transformer = keras.Model(
        [encoder_inputs, decoder_inputs], decoder_outputs, name="transformer"
    )

    return transformer

transformer = get_encoder_decoder_model()

(c) Training the original Transformer model from Keras example

In [None]:
transformer.summary()
train(transformer, model_type='original')

(d) Evaluate performance

In [None]:
# Get BLEU score on test set for original transformer model
get_text_result(transformer, no_input_masks=True)

In [None]:
# Get accuracy on test set for the original transformer model from Keras example
evaluate(transformer)

### 8. Create an FP32 Functional Model for the Transformer model

Custom Keras layers must be defined for all of the low-level TensorFlow operators (each must only contain a single operation for QAT).

Since none of these will have any prunable weights, first we create a base prunable layer class to extend, instead of `tf.keras.layers.Layer`.

(a) Create a base prunable layer class 

In [None]:
class PrunableLayer(tf.keras.layers.Layer, tfmot.sparsity.keras.PrunableLayer):
    def get_prunable_weights(self): return []

(b) Define low level TensorFlow operations as Keras subclassed layers

Note that some of these layers have trainable weights defined using the `add_weight` method. These weights will not be pruned or clustered.

In [None]:
class Tanh(PrunableLayer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x):
        return tf.math.tanh(x)


class Relu(PrunableLayer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        
    def call(self, x):
        return tf.maximum(0., x)

    
class MatMul(PrunableLayer):
    def __init__(self, transpose_b=False, **kwargs):
        super().__init__(**kwargs)
        self.transpose_b = transpose_b       
    
    def call(self, inputs):
        return tf.linalg.matmul(*inputs, transpose_b=self.transpose_b)
    
    def get_config(self):
        config = super().get_config()
        config.update({'transpose_b': self.transpose_b})
        return config


class Multiply(PrunableLayer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        
    def call(self, inputs):
        return tf.multiply(*inputs)


# Calling Multiply with a scalar input will lead to an error.
# Use the following ScalarMultiply class instead.
class ScalarMultiply(PrunableLayer):
    def __init__(self, scalar, **kwargs):
        super().__init__(**kwargs)
        self.scalar = scalar        
        
    def call(self, x):
        return tf.math.multiply(x, self.scalar)
    
    def get_config(self):
        config = super().get_config()
        config.update({'scalar': self.scalar})
        return config


class Add(PrunableLayer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
       
    def call(self, inputs):
        return tf.math.add(*inputs)


# Calling Add with a scalar input will lead to an error.
# Use the following ScalarAdd class instead.
class ScalarAdd(PrunableLayer):
    def __init__(self, scalar, **kwargs):
        super().__init__(**kwargs)
        self.scalar = scalar   
    
    def call(self, x):
        return tf.math.add(x, self.scalar)
    
    def get_config(self):
        config = super().get_config()
        config.update({'scalar': self.scalar})
        return config


class Slice(PrunableLayer):
    def __init__(self, seq_idx, **kwargs):
        super().__init__(**kwargs)
        self.seq_idx = seq_idx      
    
    def call(self, x):
        return x[:, self.seq_idx, ...]
    
    def get_config(self):
        config = super().get_config()
        config.update({'seq_idx': self.seq_idx})
        return config


class Mean(PrunableLayer):
    def __init__(self, axes=None, keepdims=True, **kwargs):
        super().__init__(**kwargs)
        self.axes=axes
        self.keepdims = keepdims      
    
    def call(self, x):
        return tf.math.reduce_mean(x, axis=self.axes, keepdims=self.keepdims)
    
    def get_config(self):
        config = super().get_config()
        config.update({'axes': self.axes,
                       'keepdims': self.keepdims})
        return config


class Subtract(PrunableLayer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)       
        
    def call(self, inputs):
        return tf.math.subtract(*inputs)


class ScalarSubtract(PrunableLayer):
    def __init__(self, scalar, **kwargs):
        super().__init__(**kwargs)
        self.scalar = scalar   
    
    def call(self, x):
        return tf.math.subtract(self.scalar,x)
    
    def get_config(self):
        config = super().get_config()
        config.update({'scalar': self.scalar})
        return config


class SquaredDiffrence(PrunableLayer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)       
        
    def call(self,inputs):
        return tf.math.squared_difference(*inputs)


class StopGradient(PrunableLayer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        
    def call(self, x):
        return tf.stop_gradient(x)


class RSqrt(PrunableLayer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
    
    def call(self, x):
        return tf.math.rsqrt(x)


class Clip(PrunableLayer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
       
    def call(self, x):
        return tf.clip_by_value(x, 0.001, 255.0)


class BroadcastToken(PrunableLayer):
    """Layer to broadcast the class token"""
    def __init__(self, embedding_dim, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim

    def build(self, input_shape):
        self.w = self.add_weight(shape=(1, 1, self.embedding_dim), initializer='zeros', 
                                 trainable=True, name='token')
        super().build(input_shape)

    def call(self, x):
        batch_size = tf.shape(x)[0]
        return tf.broadcast_to(self.w, [batch_size, 1, self.embedding_dim])

    def get_config(self):
        config = super().get_config()
        config.update({'embedding_dim': self.embedding_dim})
        return config


class AddPositionalEmbedding(PrunableLayer):
    """Layer to add positional embeddings to the tokens"""
    def __init__(self, seq_len, embedding_dim, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.seq_len = seq_len

    def build(self, input_shape):
        self.w = self.add_weight(shape=(self.seq_len, self.embedding_dim), initializer= 'uniform',
                                 trainable=True, name='pos_emb')
        super().build(input_shape)

    def call(self, x):
        return x + self.w

    def get_config(self):
        config = super().get_config()
        config.update({'embedding_dim': self.embedding_dim, 'seq_len': self.seq_len})
        return config


class AddTokenEmbedding(PrunableLayer): 
    """Layer to add token embeddings to the tokens"""
    def __init__(self, vocab_size, embedding_dim, train = True, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.vocab_size = vocab_size
        self.train = train

    def build(self, input_shape):
        self.w = self.add_weight(shape=(self.vocab_size, self.embedding_dim), initializer= 'uniform',
                                 trainable=self.train, name='token_emb')
        super().build(input_shape)

    def call(self, x):
        return tf.gather(self.w,x)

    def get_config(self):
        config = super().get_config()
        config.update({'embedding_dim': self.embedding_dim, 'vocab_size': self.vocab_size, 'train': self.train})
        return config

    def compute_output_shape(self, input_shape):
        return(input_shape[-1], self.embedding_dim)


class Scale(PrunableLayer):
    """Multiply with gamma (LayerNorm)"""
    def __init__(self, axes, **kwargs):
        super().__init__(**kwargs)
        self.axes = axes        
        
    def build(self, input_shape):
        param_shape = [input_shape[dim] for dim in self.axes]
        self.w = self.add_weight(name='gamma', shape=param_shape,
                                 trainable=True, initializer='ones')
        super().build(input_shape)
        
    def call(self, x):
        return tf.multiply(x, self.w)
    
    def get_config(self):
        config = super().get_config()
        config.update({'axes': self.axes})
        return config

    
class Centre(PrunableLayer):
    """Add beta (LayerNorm)"""
    def __init__(self, axes, **kwargs):
        super().__init__(**kwargs)
        self.axes = axes        
        
    def build(self, input_shape):
        param_shape = [input_shape[dim] for dim in self.axes]
        self.w = self.add_weight(name='beta', shape=param_shape,
                                 trainable=True, initializer='zeros')
        super().build(input_shape)
        
    def call(self, x):
        return tf.math.add(x, self.w)
    
    def get_config(self):
        config = super().get_config()
        config.update({'axes': self.axes})
        return config


class Minimum(PrunableLayer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)       
        
    def call(self,inputs):
        return tf.minimum(*inputs)


class MinimumScalar(PrunableLayer):
    def __init__(self, scalar, **kwargs):
        super().__init__(**kwargs)       
        self.scalar=scalar

    def call(self,inputs):
        return tf.minimum(inputs, self.scalar)
    
    def get_config(self):
        config = super().get_config()
        config.update({'scalar': self.scalar})
        return config


class MaximumScalar(PrunableLayer):
    def __init__(self, scalar, **kwargs):
        super().__init__(**kwargs)       
        self.scalar=scalar

    def call(self,inputs):
        return tf.maximum(inputs, self.scalar)
    
    def get_config(self):
        config = super().get_config()
        config.update({'scalar': self.scalar})
        return config


class Cast(PrunableLayer):
    def __init__(self, type = tf.int32, **kwargs):
        super().__init__(**kwargs)  
        self.type=type    

    def call(self,inputs):
        return tf.cast(inputs, self.type)

    def get_config(self):
        config = super().get_config()
        config.update({'type': self.type})
        return config

(c) Define Transformer layers such as multi-head attention and layer normalisation functionally

In [None]:
def self_attention(query, key, value, n_heads, dim, mask=None, name='mha', block_name=None, out_dim=None):
    """Multi-head attention layer"""
    depth = dim // n_heads
    if out_dim is None: out_dim = query.shape[-1]
    q = tf.keras.layers.Dense(units=dim, name=f'{name}/query')(query)
    k = tf.keras.layers.Dense(units=dim, name=f'{name}/key')(key)
    v = tf.keras.layers.Dense(units=dim, name=f'{name}/value')(value)

    q = tf.keras.layers.Reshape((-1, n_heads, depth))(q)
    q = tf.keras.layers.Permute((2, 1, 3))(q)
    k = tf.keras.layers.Reshape((-1, n_heads, depth))(k)
    k = tf.keras.layers.Permute((2, 1, 3))(k)
    v = tf.keras.layers.Reshape((-1, n_heads, depth))(v)
    v = tf.keras.layers.Permute((2, 1, 3))(v)
    qk = ScalarMultiply(depth ** -0.5)(MatMul(transpose_b=True)([q, k]))

    if mask is not None:
        if isinstance(mask, tf.Tensor):
            qk = ScalarMultiply(mask)(qk)
            mask=1. - mask
            mask = mask * -10
            qk = ScalarAdd(mask)(qk)
            
        else:
            qk = Multiply()([qk, mask])
            mask = ScalarSubtract(1.)(mask)
            mask = ScalarMultiply(-10)(mask)
            qk = Add(name=f'add/{name}')([qk, (mask)])
            
    attn_weights = tf.keras.layers.Softmax(axis=-1)(qk)
    attn_out = MatMul()([attn_weights, v]) 
    attn_out = tf.keras.layers.Permute((2, 1, 3))(attn_out)
    attn_out = tf.keras.layers.Reshape((-1, dim))(attn_out)
    out = tf.keras.layers.Dense(out_dim, name=f'{name}/output_dense',  dtype="float32")(attn_out)
    
    return out, attn_weights

def AddPositionalEmbeddingForEncoderDecoder(inputs, seq_len, vocab_size, embed_dim, block_name, freeze):
    x = AddTokenEmbedding(vocab_size, embed_dim, train = not freeze, name= ('token_embedding/' + block_name))(inputs)
    x = AddPositionalEmbedding(seq_len, embed_dim, name= ('positional_embedding/' + block_name))(x)
    return x
   
def enc_padding_mask(inputs):
    computed_mask=tf.keras.layers.Reshape((1, 1, -1))(inputs)
    return computed_mask   

def causal_mask(inputs):
    seq_len=inputs.shape[1]
    causal_mask = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
    return causal_mask

def dec_padding_mask(inputs, cau_mask):
    padding_mask=  enc_padding_mask(inputs)
    padding_mask = MinimumScalar(scalar=cau_mask)(padding_mask)
    return padding_mask

def layer_norm(x, axes=2, epsilon=0.001, name='layer_norm', trainable = True):
    """LayerNormalization"""
    if isinstance(axes, int): axes = [axes]
        
    mean = Mean(axes=axes, dtype=x.dtype)(x)
    ## This block can be replaced with a squared_difference layer ##
    diff = Subtract()([x, StopGradient()(mean)])                  ##
    sq_diff = Multiply()([diff, diff])                            ##
    ## ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ##
    variance = Mean(axes=axes,dtype=x.dtype ,name=f'{name}/variance')(sq_diff)
    if not trainable:
        inv = RSqrt()(variance)
        x = Multiply()([diff, inv])
    else:
        # MaximumScalar prevents division by 0.
        inv = RSqrt()(MaximumScalar(epsilon)(variance))
        # This layer is removed for inference so it is named.
        x = Subtract(name=f'{name}/grad_subtract')([x, mean]) 
        x = Multiply()([x, inv])

    x = Scale(axes=axes)(x)
    x = Centre(axes=axes)(x)
    
    return x

def mlp(x, hidden_dim, out_dim=None):
    """Multi-layer perceptron block"""
    if out_dim is None: out_dim = x.shape[-1]

    x = tf.keras.layers.Dense(units=hidden_dim)(x)
    x = Relu()(x)
    x = tf.keras.layers.Dense(units=out_dim)(x)
    return x
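
The masking inside `self_attention` avoids boolean operations: the scores are multiplied by a 0/1 mask and `(1 - mask) * -10` is then added, so masked positions receive a fixed score of -10 before the softmax. Below is a minimal plain-Python sketch of that arithmetic for a single query row. It does not use TensorFlow, the scores are made-up values, and the helper names are hypothetical.

```python
import math

def causal_mask(seq_len):
    """Lower-triangular 0/1 mask, the pattern tf.linalg.band_part(ones, -1, 0) yields."""
    return [[1.0 if j <= i else 0.0 for j in range(seq_len)] for i in range(seq_len)]

def masked_softmax(scores, mask, penalty=-10.0):
    """Apply scores * mask + (1 - mask) * penalty, then a softmax over the row."""
    masked = [s * m + (1.0 - m) * penalty for s, m in zip(scores, mask)]
    peak = max(masked)  # subtract the max for numerical stability
    exps = [math.exp(v - peak) for v in masked]
    total = sum(exps)
    return [e / total for e in exps]

mask = causal_mask(4)
# Row 1 may attend to positions 0 and 1 only.
weights = masked_softmax([0.5, 1.0, 2.0, 3.0], mask[1])
print(mask[1])
print([round(w, 4) for w in weights])
```

Because the penalty is -10 rather than negative infinity, masked positions keep a very small non-zero weight instead of exactly zero.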

(d) Build end-to-end model

In [None]:
from collections import defaultdict
def get_translation_model(input_shape, batch_size=batch_size, seq_len=seq_len, vocab_size=vocab_size, embed_dim=embed_dim, num_heads=num_heads, freeze= False, trainable=True):
    
    aux_output=defaultdict(list)
    ## Encoder
    
    # Input to encoder
    enc_inputs = tf.keras.Input(shape=input_shape, batch_size=batch_size, name="encoder_inputs")
    encoder_inputs=Cast()(enc_inputs)
    
    x = AddPositionalEmbeddingForEncoderDecoder(encoder_inputs, seq_len, vocab_size, embed_dim, 'encoder', freeze)
    encoder_padding_mask_inputs = tf.keras.Input(shape=input_shape, batch_size=batch_size, name="encoder_masks")
    encoder_padding_mask = enc_padding_mask(encoder_padding_mask_inputs)

    # Encoder Attention block
    attention_output, attention_weights = self_attention(x, x, x, num_heads, embed_dim*num_heads, mask=encoder_padding_mask, name=(f'mha'), block_name=(f'encoder'))
    proj_input = tf.keras.layers.Add()([x, attention_output])
    proj_input = layer_norm(proj_input, name=(f'layer_norm'), trainable=trainable)

    # MLP block
    proj_output = mlp(proj_input, latent_dim, embed_dim)
    x = tf.keras.layers.Add()([proj_input, proj_output])
    encoder_outputs = layer_norm(x, name=(f'layer_norm_1'), trainable=trainable)
    
    ## Decoder
    
    # Input to decoder
    dec_inputs = tf.keras.Input(shape=input_shape, batch_size=batch_size, name="decoder_inputs")
    decoder_inputs=Cast()(dec_inputs)

    x = AddPositionalEmbeddingForEncoderDecoder(decoder_inputs, seq_len, vocab_size, embed_dim, 'decoder', freeze)
    decoder_causal_mask = causal_mask(decoder_inputs)
    decoder_padding_mask_inputs = tf.keras.Input(shape=input_shape, batch_size=batch_size, name="decoder_masks")
    decoder_padding_mask = dec_padding_mask(decoder_padding_mask_inputs, decoder_causal_mask)
    
    
    # Decoder Attention Block 1
    attention_output_1, attention_weights_1 = self_attention(x, x, x, num_heads, embed_dim*num_heads, mask=decoder_causal_mask, name=(f'mha_1'), block_name=(f'decoder_1'))
    x1 = tf.keras.layers.Add()([x, attention_output_1])
    out_1 = layer_norm(x1,  name=(f'layer_norm_2'), trainable=trainable)
    
    # Decoder Attention Block 2
    attention_output_2, attention_weights_2 = self_attention(out_1, encoder_outputs, encoder_outputs, num_heads, embed_dim*num_heads, mask=decoder_padding_mask, name=(f'mha_2'), block_name=(f'decoder_2'))
    x2 = tf.keras.layers.Add()([out_1, attention_output_2])
    out_2 = layer_norm(x2,  name=(f'layer_norm_3'), trainable=trainable)
    
    # MLP Block
    proj_output = mlp(out_2, latent_dim, embed_dim)
    x3 =  tf.keras.layers.Add()([out_2, proj_output])
    x3 = layer_norm(x3, name=(f'layer_norm_4'), trainable=trainable)
    

    x3 = tf.keras.layers.Dropout(0.5)(x3)
    x3 = tf.keras.layers.Dense(units=vocab_size, name="dense_last", activation='softmax')(x3)

    transformer = keras.Model(
        [enc_inputs,encoder_padding_mask_inputs, dec_inputs, decoder_padding_mask_inputs], x3, name="transformer"
    )
    
    return transformer

tf.keras.backend.clear_session()  # reset layer name counters

transform = get_translation_model(input_shape = (seq_len,), batch_size = batch_size)

(e) Train the FP32 model

In [None]:
transform.summary()
train(transform, model_type='fp32')

(f) Evaluate Performance

In [None]:
# Get BLEU score on test set for FP32 transformer model
get_text_result(transform)

In [None]:
# Get accuracy on test set for the FP32 functional transformer model
evaluate(transform)

### 9. Convert FP32 model to FP32 tflite model

(a) Generate a non-optimized tflite (float32 operations) file for FP32 model

In [None]:
i = tf.keras.Input(shape=(20,), batch_size=1)
j = tf.keras.Input(shape=(20,), batch_size=1)
k = tf.keras.Input(shape=(20,), batch_size=1)
l = tf.keras.Input(shape=(20,), batch_size=1)
net = tf.keras.Model(inputs=[i, j,k,l,], outputs=transform.call([i,j,k,l]))

MODEL_PATH = './encoder_decoder_fp32.tflite'

converter = tf.lite.TFLiteConverter.from_keras_model(net)
tflite_model = converter.convert()
with open(MODEL_PATH, "wb+") as tflite_file:
    tflite_file.write(tflite_model)

(b) Evaluate performance

In [None]:
get_tflite_accuracy(MODEL_PATH, input_type='float32', output_type='float32')

In [None]:
get_text_result_tflite(MODEL_PATH, input_type='float32', output_type='float32')

In [None]:
print("Model size: ", get_gzipped_model_size(MODEL_PATH), ' KB')
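
`get_gzipped_model_size` is defined elsewhere in this notebook. For readers without that cell at hand, a helper of this kind might look like the sketch below. The name `gzipped_size_kb` and the choice of the `gzip` module are assumptions, not the notebook's actual implementation.

```python
import gzip
import os
import tempfile

def gzipped_size_kb(path):
    """Return the gzip-compressed size of the file at `path`, in kilobytes."""
    with open(path, "rb") as src:
        compressed = gzip.compress(src.read())
    return len(compressed) / 1000.0

# Demonstration on a throwaway file of highly compressible bytes.
with tempfile.NamedTemporaryFile(delete=False, suffix=".bin") as tmp:
    tmp.write(b"\x00" * 100_000)
    demo_path = tmp.name

raw_kb = os.path.getsize(demo_path) / 1000.0
gz_kb = gzipped_size_kb(demo_path)
os.remove(demo_path)
print(f"raw: {raw_kb:.1f} KB, gzipped: {gz_kb:.3f} KB")
```

Comparing compressed sizes is mostly informative for pruned or clustered models; for a plain FP32 versus INT8 comparison the raw file size tells a similar story.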

### 10. Perform QAT on FP32 model with TFMOT

(a) To use the custom Keras layers we defined, we need to pass a [`QuantizeConfig`](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/quantization/keras/QuantizeConfig) for each of these layers.

For Keras layers which are already supported in TFMOT, a default QuantizeConfig class is assigned to each one. However, custom QuantizeConfig instances could also be created for these layers to give more control over how they are quantised.

In [None]:
from tensorflow_model_optimization.quantization.keras import QuantizeConfig, quantizers

LastValueQuantizer = quantizers.LastValueQuantizer
MovingAverageQuantizer = quantizers.MovingAverageQuantizer
AllValuesQuantizer = quantizers.AllValuesQuantizer

class NoOpQuantizeConfig(QuantizeConfig):
    """QuantizeConfig which does not quantize any part of the layer."""

    def get_weights_and_quantizers(self, layer):
        return []

    def get_activations_and_quantizers(self, layer):
        return []

    def set_quantize_weights(self, layer, quantize_weights):
        pass

    def set_quantize_activations(self, layer, quantize_activations):
        pass

    def get_output_quantizers(self, layer):
        return []
        
    def get_config(self):
        return {}


class TFOpQuantizeConfig(QuantizeConfig):
    """QuantizeConfig which only quantizes the output of a layer."""

    def get_weights_and_quantizers(self, layer):
        return []

    def get_activations_and_quantizers(self, layer):
        return []

    def set_quantize_weights(self, layer, quantize_weights):
        pass

    def set_quantize_activations(self, layer, quantize_activations):
        pass

    def get_output_quantizers(self, layer):
        return [MovingAverageQuantizer(num_bits=8, per_axis=False, symmetric=False, narrow_range=False)]

    def get_config(self):
        return {}


class MaskOpQuantizeConfig(QuantizeConfig):
    """QuantizeConfig which only quantizes the output of a layer and is meant for the input masks."""

    def get_weights_and_quantizers(self, layer):
        return []

    def get_activations_and_quantizers(self, layer):
        return []

    def set_quantize_weights(self, layer, quantize_weights):
        pass

    def set_quantize_activations(self, layer, quantize_activations):
        pass

    def get_output_quantizers(self, layer):
        return [AllValuesQuantizer(num_bits=8, per_axis=False, symmetric=False, narrow_range=False)]

    def get_config(self):
        return {}

    
class VarianceQuantizeConfig(QuantizeConfig):
    """QuantizeConfig for the variance calculation in the layer normalisation layer."""

    def get_weights_and_quantizers(self, layer):
        return []

    def get_activations_and_quantizers(self, layer):
        return []

    def set_quantize_weights(self, layer, quantize_weights):
        pass

    def set_quantize_activations(self, layer, quantize_activations):
        pass

    def get_output_quantizers(self, layer):
        return [AllValuesQuantizer(num_bits=8, per_axis=False, symmetric=False, narrow_range=False)]

    def get_config(self):
        return {}
        

class WeightQuantizeConfig(QuantizeConfig):
    """QuantizeConfig which quantizes the custom weights in the embedding and layer normalisation layers."""

    def __init__(self):
        self.weight_quantizer = LastValueQuantizer(num_bits=8, per_axis=False,
                                                   symmetric=True, narrow_range=True)
        self.activation_quantizer = MovingAverageQuantizer(num_bits=8, per_axis=False,
                                                           symmetric=False, narrow_range=False)

    def get_weights_and_quantizers(self, layer):
        return [(layer.w, self.weight_quantizer)]

    def get_activations_and_quantizers(self, layer):
        return []

    def set_quantize_weights(self, layer, quantize_weights):
        layer.w = quantize_weights[0]

    def set_quantize_activations(self, layer, quantize_activations):
        pass

    def get_output_quantizers(self, layer):
        return [self.activation_quantizer]

    def get_config(self):
        return {}
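
To make the quantizer arguments concrete, the sketch below shows in plain Python what an 8-bit asymmetric quantize-dequantize round trip does to a value, given a `[min, max]` range. It is a simplified illustration, not TFMOT's implementation: it skips the range nudging that TensorFlow's fake-quant ops apply so that zero is exactly representable, and `fake_quantize` is a hypothetical name.

```python
def fake_quantize(x, min_val, max_val, num_bits=8, narrow_range=False):
    """Quantize x to an integer level in [min_val, max_val], then map it back to float."""
    qmin = 1 if narrow_range else 0
    qmax = 2 ** num_bits - 1
    scale = (max_val - min_val) / (qmax - qmin)
    clamped = min(max(x, min_val), max_val)
    level = round((clamped - min_val) / scale) + qmin
    return (level - qmin) * scale + min_val

# With the range [-1, 1], neighbouring levels are 2 / 255 apart.
step = 2.0 / 255
for value in (-1.5, -0.3, 0.0, 0.42, 2.0):
    q = fake_quantize(value, -1.0, 1.0)
    in_range = min(max(value, -1.0), 1.0)
    print(f"{value:+.2f} -> {q:+.5f} (error {abs(q - in_range):.5f})")
```

Values outside the range are clamped to its ends, and values inside it move by at most half a step. The quantizer classes used above differ mainly in how they track `min_val` and `max_val` during training (last batch, moving average, or all values seen).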

(b) Define wrapper function

Since custom layers and QuantizeConfigs are used, the whole model cannot directly be wrapped with QAT wrappers.
So first we write a function to wrap the individual layers with QAT wrappers:

In [None]:
def apply_wrapper(wrapper_function, layer_param_dict):
    
    def wrap_layer(layer):
        if layer.name in layer_param_dict.keys():
            return wrapper_function(layer, **layer_param_dict[layer.name])
        return layer

    return wrap_layer

def layer_wrapper(model, wrapper_function, layer_param_dict):
    return tf.keras.models.clone_model(model, clone_function=apply_wrapper(wrapper_function, layer_param_dict))
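
`clone_model` passes layers to `clone_function` one at a time, so `apply_wrapper` only needs to return a closure that looks each layer name up in the dictionary. The plain-Python sketch below mimics that dispatch with stand-in objects. `FakeLayer` and `annotate` are hypothetical placeholders for Keras layers and `quantize_annotate_layer`.

```python
from dataclasses import dataclass

@dataclass
class FakeLayer:
    """Stand-in for a Keras layer; only the name matters here."""
    name: str

def annotate(layer, quantize_config=None):
    """Stand-in for a wrapper function: records which config a layer received."""
    return ("annotated", layer.name, quantize_config)

def apply_wrapper(wrapper_function, layer_param_dict):
    def wrap_layer(layer):
        if layer.name in layer_param_dict:
            return wrapper_function(layer, **layer_param_dict[layer.name])
        return layer  # layers without an entry pass through unchanged
    return wrap_layer

params = {"mat_mul": {"quantize_config": "TFOpQuantizeConfig"}}
wrap = apply_wrapper(annotate, params)
results = [wrap(FakeLayer(n)) for n in ("mat_mul", "cast")]
print(results)
```

Only layers with an entry in the dictionary are wrapped, which is what later allows selected layers to be excluded simply by dropping their names from `layer_param_dict`.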

(c) Assign QuantizeConfigs to custom layers

In [None]:
def get_quantize_config(model):
    layer_param_dict = {}  # stores {Layer_Name: QuantizeConfig} pairs
    scope = {}  # stores all custom objects

    for layer in model.layers:
            
            if layer.name.startswith(('clip', 'minimum', 'minimum_scalar', 'maximum_scalar', 'cast', 'stop_gradient')):
                layer_param_dict[layer.name] = {'quantize_config': NoOpQuantizeConfig()}
                scope[layer.__class__.__name__] = layer.__class__
            
            elif 'grad_subtract' in layer.name or layer.name.startswith(('mat_mul', 'multiply', 'scalar_multiply', 'add',
                                                                         'scalar_add', 'slice', 'mean', 'subtract',
                                                                         'scalar_subtract', 'r_sqrt', 'relu')):
                layer_param_dict[layer.name] = {'quantize_config': TFOpQuantizeConfig()}
                scope[layer.__class__.__name__] = layer.__class__
                
            elif layer.name.startswith(( 'scale', 'centre', 'positional_embedding', 'token_embedding' )):
                layer_param_dict[layer.name] = {'quantize_config': WeightQuantizeConfig()}
                scope[layer.__class__.__name__] = layer.__class__

            # Annotate the encoder and decoder mask input layers so that they are quantized to INT8
            
            elif layer.name.startswith(('encoder_masks', 'decoder_masks' )):
                layer_param_dict[layer.name] = {'quantize_config': MaskOpQuantizeConfig()}

            elif 'variance' in layer.name:
                layer_param_dict[layer.name] = {'quantize_config': VarianceQuantizeConfig()}
                scope[layer.__class__.__name__] = layer.__class__
        
    scope['NoOpQuantizeConfig'] = NoOpQuantizeConfig
    scope['TFOpQuantizeConfig'] = TFOpQuantizeConfig
    scope['WeightQuantizeConfig'] = WeightQuantizeConfig
    scope['VarianceQuantizeConfig'] = VarianceQuantizeConfig
    scope['MaskOpQuantizeConfig'] = MaskOpQuantizeConfig

    return layer_param_dict, scope

layer_param_dict, scope = get_quantize_config(transform)
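
`get_quantize_config` relies on the automatic layer names Keras assigns, which are the snake_case form of the class name plus an optional numeric suffix (for example `r_sqrt` or `scalar_multiply_1`). The sketch below approximates that naming rule with two regular expressions and mirrors the branch order of the function, which makes it easy to check which bucket a given layer name falls into. `to_snake_case` and `bucket` are illustrative helpers, and the regular expressions approximate Keras' behaviour rather than reproduce its source.

```python
import re

def to_snake_case(class_name):
    """Approximate the snake_case name Keras derives from a layer class name."""
    step = re.sub(r"(.)([A-Z][a-z0-9]+)", r"\1_\2", class_name)
    return re.sub(r"([a-z])([A-Z])", r"\1_\2", step).lower()

NO_OP = ("clip", "minimum", "minimum_scalar", "maximum_scalar", "cast", "stop_gradient")
TF_OP = ("mat_mul", "multiply", "scalar_multiply", "add", "scalar_add", "slice",
         "mean", "subtract", "scalar_subtract", "r_sqrt", "relu")
WEIGHT = ("scale", "centre", "positional_embedding", "token_embedding")
MASK = ("encoder_masks", "decoder_masks")

def bucket(layer_name):
    """Return the QuantizeConfig bucket a layer name lands in (first match wins)."""
    if layer_name.startswith(NO_OP):
        return "NoOp"
    if "grad_subtract" in layer_name or layer_name.startswith(TF_OP):
        return "TFOp"
    if layer_name.startswith(WEIGHT):
        return "Weight"
    if layer_name.startswith(MASK):
        return "MaskOp"
    if "variance" in layer_name:
        return "Variance"
    return "default"

for cls in ("MatMul", "RSqrt", "ScalarMultiply", "StopGradient"):
    print(cls, "->", to_snake_case(cls), "->", bucket(to_snake_case(cls)))
print(bucket("layer_norm/variance"), bucket("layer_norm/grad_subtract"), bucket("dense_last"))
```

Explicitly named layers such as `layer_norm/variance` bypass the automatic naming, which is why the function matches them by substring instead of by prefix.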


A few layers, such as `cast`, `encoder_inputs` and `decoder_inputs`, mustn't be annotated with any QuantizeConfig, as this would add a `quantize` node after the inputs in the tflite graph and pass an int8 value down to the `tfl.gather` operation. <br>Since `tfl.gather` accepts only int32 and int64 indices, an int8 value there results in an error ([see the TF Lite Ops page](https://www.tensorflow.org/mlir/tfl_ops#tflgather_mlirtflgatherop)).

In [None]:
def remove_unwanted_layers(model, layer_param_dict):
    # Layers whose names contain any of these substrings are left un-annotated (not quantized).
    # Add further substrings here to exclude more layers.
    excluded = ['cast', 'encoder_inputs', 'decoder_inputs']
    layers_to_quantise = [x.name for x in model.layers if not any(y in x.name for y in excluded)]
    layer_param_dict = {k: v for k, v in layer_param_dict.items() if k in layers_to_quantise}
    # Remaining layers without a custom entry fall back to the default QuantizeConfig.
    for k in layers_to_quantise:
        if k not in layer_param_dict:
            layer_param_dict[k] = {'quantize_config': None}

    return layer_param_dict

layer_param_dict = remove_unwanted_layers(transform, layer_param_dict)    

(d) Load the necessary API classes/functions

In [None]:
quantize_annotate_layer = tfmot.quantization.keras.quantize_annotate_layer
quantize_annotate_model = tfmot.quantization.keras.quantize_annotate_model
quantize_apply = tfmot.quantization.keras.quantize_apply
quantize_scope = tfmot.quantization.keras.quantize_scope

(e) Annotate individual layers

When calling the `quantize_apply` function, TFMOT will throw an error if an unsupported (custom) layer is missing from the scope.

In [None]:
# Wrap each custom layer with the corresponding QuantizeConfig:

qat_model = layer_wrapper(transform, quantize_annotate_layer, layer_param_dict)

with quantize_scope(scope):
    qat_model = quantize_apply(qat_model)


(f) Perform QAT

In [None]:
qat_model.summary()
train(qat_model, model_type='qat', epochs=3)

(g) Evaluate Performance

In [None]:
get_text_result(qat_model)

In [None]:
evaluate(qat_model)

### 11. Create INT8 tflite file for QAT FP32 model

If we attempt to directly generate a TFLite file using the fine-tuned model above:

- It will not have a correct batch size of 1.
- It will have operators which are unnecessary during inference. Specifically, the extra `Subtract` operators and the `MaximumScalar` operator in the layer normalisation blocks, which are used during training and fine-tuning, should be removed from the graph before creating the TFLite file.

Therefore the network should be redefined with a batch size of 1 and with the redundant operators removed. The weights of the fine-tuned optimised model can then be loaded into this new model.
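
To see why the extra operators can be dropped, the plain-Python sketch below follows the two branches of `layer_norm` for one vector, omitting the learned `Scale` and `Centre` steps. The training branch clamps the variance with `max(epsilon, variance)` and subtracts the mean a second time; the inference branch reuses `diff` and skips the clamp. The function names and sample values are hypothetical, and the two branches agree only while the variance stays above `epsilon`.

```python
import math

def normalise_training(x, epsilon=0.001):
    """Training branch: clamp the variance, then subtract the mean a second time."""
    mean = sum(x) / len(x)
    diff = [v - mean for v in x]                    # Subtract (StopGradient on the mean)
    variance = sum(d * d for d in diff) / len(x)
    inv = 1.0 / math.sqrt(max(epsilon, variance))   # MaximumScalar + RSqrt
    return [(v - mean) * inv for v in x]            # grad_subtract + Multiply

def normalise_inference(x):
    """Inference branch: reuse diff and apply RSqrt to the raw variance."""
    mean = sum(x) / len(x)
    diff = [v - mean for v in x]
    variance = sum(d * d for d in diff) / len(x)
    inv = 1.0 / math.sqrt(variance)
    return [d * inv for d in diff]

sample = [1.0, 2.0, 3.0, 4.0]
train_out = normalise_training(sample)
infer_out = normalise_inference(sample)
print([round(v, 4) for v in train_out])
print([round(v, 4) for v in infer_out])
```

The inference branch has no guard against a zero variance, which is the case the clamp protects against during training.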

(a) Remove layers which are not required

In [None]:
tf.keras.backend.clear_session()  # reset layer name counters

new_qat_model = get_translation_model(input_shape = (seq_len,), batch_size = batch_size, trainable = False)

(b) Annotate individual layers

In [None]:
# Get the QuantizeConfig and Scope which would be used to annotate the layers
layer_param_dict, scope = get_quantize_config(new_qat_model)
# Remove unwanted QuantizeConfigs
layer_param_dict = remove_unwanted_layers(new_qat_model, layer_param_dict)    

new_qat_model = layer_wrapper(new_qat_model, quantize_annotate_layer, layer_param_dict)

with quantize_scope(scope):
    new_qat_model = quantize_apply(new_qat_model)

(c) Load weights into the model

In [None]:
new_qat_model.load_weights('./eng_spa_transformer_qat_tutorial_qat_model.h5', by_name=True)

In [None]:
# Sanity check to see if weights are loaded correctly
evaluate(new_qat_model)

(d) Create tflite file (int8 ops)

In [None]:
i = tf.keras.Input(shape=(20,), batch_size=1, dtype = tf.int32)
j = tf.keras.Input(shape=(20,), batch_size=1)
k = tf.keras.Input(shape=(20,), batch_size=1, dtype = tf.int32)
l = tf.keras.Input(shape=(20,), batch_size=1)

# The following is done to ensure that the batch size of input in
# tflite graph is 1
net = tf.keras.Model(inputs=[i, j,k,l,], outputs=new_qat_model.call([i,j,k,l]))

MODEL_PATH = './encoder_decoder_qat.tflite'

converter = tf.lite.TFLiteConverter.from_keras_model(net)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

# The following two lines ensure that the mask inputs
# and the output are int8
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8

# Toggle this option to fold/unfold batchmatmul
converter._experimental_disable_batchmatmul_unfold = True

tflite_model = converter.convert()
with open(MODEL_PATH, "wb+") as tflite_file:
    tflite_file.write(tflite_model)
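
With `inference_input_type = tf.int8`, float inputs such as the masks have to be mapped to int8 before they are fed to the interpreter, using the scale and zero point that the TFLite interpreter reports for each input tensor. The sketch below shows that affine mapping in plain Python. The scale and zero point are illustrative values for a tensor covering [0, 1], not numbers read from this model, and the function names are hypothetical.

```python
def quantize_to_int8(values, scale, zero_point):
    """Map floats to int8 with q = round(x / scale) + zero_point, clamped to [-128, 127]."""
    return [max(-128, min(127, round(v / scale) + zero_point)) for v in values]

def dequantize_from_int8(q_values, scale, zero_point):
    """Invert the mapping: x = (q - zero_point) * scale."""
    return [(q - zero_point) * scale for q in q_values]

# Illustrative parameters for a tensor whose float range is [0, 1].
scale, zero_point = 1.0 / 255.0, -128

mask = [1.0, 1.0, 1.0, 0.0, 0.0]
q_mask = quantize_to_int8(mask, scale, zero_point)
restored = dequantize_from_int8(q_mask, scale, zero_point)
print(q_mask)
print([round(v, 4) for v in restored])
```

In practice the pair would come from the `quantization` entry of each element returned by the interpreter's `get_input_details()`, and the same inverse mapping applies to the int8 output. The token inputs stay int32 and need no such conversion.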

(e) Evaluate Performance

NOTE: These steps are slow to execute, so the number of samples used for evaluation is set to 200 by default (this can be changed by the user).

In [None]:
get_tflite_accuracy(MODEL_PATH)

In [None]:

get_text_result_tflite(MODEL_PATH)

In [None]:
print("Model size: ", get_gzipped_model_size(MODEL_PATH), ' KB')