# Implementing Transformers From Scratch With Tensorflow

Transformers are machine learning models that have proven to be incredibly powerful for natural language processing (NLP) applications. Standard NLP tools such as LSTM or GRU networks suffer from vanishing gradients, therefore,  performance tends to suffer when processing long sequences. Transformers, however, use a special attention mechanism to bypass this vanishing gradient problem.

In this project we create a portuguese-to-english translator in Tensorflow. We do this by training and building a transformer model from first principles. 

NOTE: 
- The model in this example is not optimized for performance. This project is for the purposes of learning about transformers in general.
- This project was inspired by the following Tensorflow tutorial: https://www.tensorflow.org/text/tutorials/transformer

I'm using the same portuguese to english dataset as in the tutorial featured above ^. While implementing the transformer itself, however, I've avoided looking at any outside resources (aside from the original research paper **"Attention Is All You Need"**: https://arxiv.org/abs/1706.03762).

## Imports and Initialization

In [1]:
# Essential Imports
import numpy as np

# Tensorflow Imports
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_text as text

# Misc Imports
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


## Data Preprocessing

### Download Dataset

We start by acquiring our dataset. Tensoflow datasets contains a collection of ~50k examples of portuguese text (as well as the corresponding english translation):

In [2]:
# Download Dataset
examples, metadata = tfds.load('ted_hrlr_translate/pt_to_en', 
                               with_info=True,
                               as_supervised=True)
train_examples, val_examples = examples['train'], examples['validation']

2022-08-07 16:32:44.073517: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


### Tokenize Dataset

Now that we have our dataset downloaded, we need to convert it into a vectorized format. There are many different tokenization techniques, but here we'll utilize the tokenizer that comes with the dataset.

In [3]:
# Download a pre-fitted tokenizer 
model_name = 'ted_hrlr_translate_pt_en_converter'
tf.keras.utils.get_file(
    f'{model_name}.zip',
    f'https://storage.googleapis.com/download.tensorflow.org/models/{model_name}.zip',
    cache_dir='.', cache_subdir='', extract=True
)
tokenizers = tf.saved_model.load(f"{model_name}")

pt_vocab_size = tokenizers.pt.get_vocab_size().numpy()
en_vocab_size = tokenizers.en.get_vocab_size().numpy()

In [4]:
# Apply the tokenizer to the dataset
def tokenize_data(_x,_y):
    return tokenizers.pt.tokenize(_x).to_tensor(), tokenizers.en.tokenize(_y).to_tensor()

def batch_data(_dataset, _batch_size):
    return _dataset.batch(_batch_size).map(tokenize_data)

batch_size = 128
train_dataset_batched = batch_data(train_examples, batch_size)
test_dataset_batched = batch_data(val_examples, batch_size)

### Positional Encoding

Traditional recurrent neural networks function by processing tokens sequentially. A hidden state is calculated at each time step, and is then fed into the network at the next time step. Processing inputs sequentially can be slow, and as the the number of time steps increases, then network begins to "forget" about tokens at the start of the sequence. 

In contrast, transformer networks process the inputs in parallel. This is much faster, but there's one problem. Consider the following two sentences:

- The cat ran over the lazy dog.
- The dog ran over the lazy cat.

The ordering of a sentence matters. These two sentences have very different meanings. But if we process inputs in parrallel, then these sentences are functionally the same. The network inputs contain no temporal information. We'll need to add it ourselves.

In [5]:
class PositionalEncodingLayer(tf.keras.layers.Layer):
    def __init__(self, d_model):
        self.d_model = d_model
        super(PositionalEncodingLayer, self).__init__()
    
    def pos_encoding_mask(self, inputs): 
        # pos - sentence index
        # i - embedding index 
        # inputs - shape (batch_size, sentence_length, embedding_dim)

        pos_arr = tf.range(start = 0, limit = tf.shape(inputs)[1] , dtype = tf.float32)[:, np.newaxis]
        i_arr = tf.range(start = 0, limit = self.d_model , dtype = tf.float32)[np.newaxis, :]
        
        # assign a fixed frequency to each dimension of the embedding
        omega = 1 / np.power(10000, (2 * (i_arr//2)) / np.float32(self.d_model))
        angles = np.matmul(pos_arr, i_arr)
        
        pos_mask = np.zeros((tf.shape(inputs)[1], self.d_model))
        pos_mask[:,0::2] = tf.math.cos(angles)[:,0::2]
        pos_mask[:,1::2] = tf.math.sin(angles)[:,1::2]
        pos_mask = tf.convert_to_tensor(pos_mask, dtype = tf.float32)
        return pos_mask 
    
    def call(self, inputs):
        pos_mask = self.pos_encoding_mask(inputs)
        return tf.math.add(inputs, pos_mask)

The solution proposed by the authors of "Attention Is All You Need":

Suppose each token `token` at position `sentence_pos` is represented by a d-dimensional vector. We want to create a new representation `token_new = pos_mask + token` that encodes both the semantic information of `token`, while also containing information about its position in the sentence. 

One solution: assign a unique (fixed) frequency `f_embedding_dim` to each dimenstion of the word embedding. The value of `pos_mask` is determined by a sinusoid that oscillates with `f_embedding_dim` over `sentence_pos`:

$f($ `sentence_pos` $)$ = $sin($ `f_embedding_dim` *`sentence_pos`$)$

where `f_embedding_dim` is a number that increases with the value of `embedding_dim`. The idea is that as we reach higher values of `embedding_dim`, it becomes easier to resolve words that are close together, but more difficult to resolve words that are further apart. The opposite logic also applies to lower dimensions.

The embedding used by the authors of "Attention is All You Need":
![title](imgs/pos_encoding.png)

In [6]:
d_model = 20 # dimensionality of the word embedding
n_heads = 8 # number of attention heads

## Create Transformer Model

The attention mechanism is a powerful technique used, in this case, to process sequences of text. Suppose we have an `n x d` array, representing an embedded sequence of tokens. After applying the attention mechanism, we want to end up with another `n x d` array, representing the original sentence after the attention block has filtered out and processed pertinent information. We call this `MultiHeadedAttention` because we'll stack multiple attention blocks together. Hopefully, each block attends to different parts of the sentence and learns to gather useful information.
![title](imgs/attention.png)

In [7]:
class MultiHeadedAttention(tf.keras.layers.Layer):
    def __init__(self, n_heads, d_model ):
        self.n_heads = n_heads
        self.d_model = d_model
        super(MultiHeadedAttention, self).__init__()
    
    def build(self, input_shape):        
        # Create Self-Attention Kernels
        self.kernels = {head_ind:{} for head_ind in range(n_heads)}
        kernel_types = ["query", "key", "value"]
        for head_ind in range(n_heads):
            for kernel_type in kernel_types:
                self.kernels[head_ind][kernel_type] = tf.keras.layers.Dense(self.d_model, 
                                                                            activation = 'relu')
        self.attention_normalization = tf.keras.layers.Normalization()
        self.linear_output_layer = tf.keras.layers.Dense(self.d_model, activation = 'relu')
        
        # Residual Connections
        self.res_normalization_layer = tf.keras.layers.Normalization()
        self.res_linear_layer = tf.keras.layers.Dense(self.d_model, 
                                                  activation = 'relu')  
        self.dropout = tf.keras.layers.Dropout(0.2)
        
    def call(self, inputs, padding_mask = None):
        '''
        The attention block has three inputs:
            - Queries
            - Keys
            - Values
            
        You can think of the keys as the "memory" of the attention block.
        The attention block compares the queries to the keys to produce an attention matrix.
        This attention matrix acts as a filter that is applied to the values, the output 
            of the model is a weighted average of the value sequence - containing only 
            information that the attention block deems important.
        '''
        
        query_inp, key_inp, value_inp = inputs
        # inputs - shape (batch_size, sentence_length, embedding_dim)
        
        ## Self-Attention Block
        all_values = []
        for head_ind in range(self.n_heads):         
            # First we pass the queries and keys through 
            # their own dense layers. this allows each attention head
            # to generate unique attention vectors
            queries = self.kernels[head_ind]["query"](query_inp)
            keys = self.kernels[head_ind]["key"](key_inp)
            
            # Compute the attention matrix
            scores = tf.matmul(queries, tf.transpose(keys, perm = [0,2,1]))
            
            # Rescale the scores
            d_k = (tf.shape(scores).numpy()[-1])
            scaled_scores = (1/np.sqrt(d_k))*scores
            
            # (Optional) Apply a mask to the input
            if padding_mask is not None:
                scaled_scores += (padding_mask*-1e9)
    
            softmax_scores = tf.keras.activations.softmax(scaled_scores, axis = 2)
        
            # Apply the attention matrix to the values
            values = tf.matmul(softmax_scores, tf.transpose(value_inp, perm = [0,1,2]))
            all_values.append(values)
        all_values = tf.concat(all_values, axis = 2)
        normalized_values = self.attention_normalization(all_values)
        linear_layer_values = self.dropout(self.linear_output_layer(normalized_values))
        
        # Residual Connection
        # Add the attention block output to the original input and normalize
        residual_connection = tf.math.add(query_inp, linear_layer_values)
        linear_layer_out = self.res_linear_layer(residual_connection)
        output = self.res_normalization_layer(linear_layer_out)
        return output

In [8]:
def outerprod(input_x):
    values = input_x[:, tf.newaxis, :]
    values = tf.matmul(tf.transpose(values, perm = [0,2,1]), values)
    return values

In [9]:
def create_padding_mask(seq):
    seq = tf.cast(tf.math.equal(seq, 0), tf.float32)
    # add extra dimensions to add the padding
    # to the attention logits.
    return seq#[:, tf.newaxis, :]  # (batch_size, 1, 1, seq_len)

def create_lookahead_mask(input_x):
    # tf.linalg.band_part(input, -1, 0) set the lower diagonal to 0 
    lookahead_mask = 1 - tf.linalg.band_part(tf.ones_like(input_x), -1, 0)
    return tf.cast(lookahead_mask, tf.float32) 

Now we're ready to build the transformer model itself. The transformer outputs a probability vector of dimensions `(1,vocab_size)`, predicting the next token in the sequence. The new sequence is fed back into the model, and the process continues until we ask the model to stop.


**Encoder Block**

The initial sequence is tokenized and embedded using the positional encoding technique mentioned above. The embedded sequence is passed through a multi-headed attention layer. The output of the encoder block is a new sequence that represents a weighted average of the original sequence, containing only the information that the attention block deems important.
![title](imgs/encoder.png)

In [10]:
class Encoder(tf.keras.layers.Layer):
    def __init__(self, embedding_dim, n_heads):
        self.embedding_dim = embedding_dim
        self.n_heads = n_heads
        super(Encoder, self).__init__(dynamic = True)

    def build(self, input_shape):
        input_length = input_shape[1]
        self.embedding_layer = tf.keras.layers.Embedding(input_dim = pt_vocab_size,
                                                    output_dim = self.embedding_dim, 
                                                    input_length = input_length)
        self.pos_encoding_layer = PositionalEncodingLayer(self.embedding_dim)
        self.multiheadedattention = MultiHeadedAttention(n_heads = self.n_heads, d_model = self.embedding_dim)
    
    def call(self, inputs):
        # inputs - shape(batch_size, sentence_length , embedding_size)
        padding_mask = create_padding_mask(outerprod(inputs))
        
        # Convert the tokenized sequence into embedded vectors
        embedding = self.embedding_layer(inputs) 
        positional_embedding = self.pos_encoding_layer(embedding)
        
        # Self attention layer
        attention = self.multiheadedattention([positional_embedding]*3, padding_mask)
        return attention
    
    def compute_output_shape(self, input_shape):
        return tf.TensorShape((input_shape[1], self.embedding_dim))

**Decoder Block**

(1) The output from the previous timestep is tokenized and embedded using the positional encoding technique mentioned above. The embedded sequence is passed through a multi-headed attention layer. 

(2) We'll take the output from the encoder block and use it as the values and keys for another multi-headed attention block. the queries are the output of (1). The idea is to compare the transformer's output against the orignal sequence to produce the next output.

(3) The output of (2) is passed through a softmax layer to produce a probability vector of size `(1, vocab_size)`, which is used to produce the next token.
![title](imgs/decoder.png)

In [11]:
class Decoder(tf.keras.layers.Layer):
    def __init__(self, embedding_dim, n_heads):
        self.embedding_dim = embedding_dim
        self.n_heads = n_heads
        super(Decoder, self).__init__(dynamic = True)

    def build(self, input_shape):
        input_length = input_shape[1]
        self.embedding_layer = tf.keras.layers.Embedding(input_dim = pt_vocab_size,
                                                    output_dim = self.embedding_dim, 
                                                    input_length = input_length)
        self.pos_encoding_layer = PositionalEncodingLayer(self.embedding_dim)
        self.multiheadedattention_1 = MultiHeadedAttention(n_heads = self.n_heads, 
                                                         d_model = self.embedding_dim)
        self.multiheadedattention_2 = MultiHeadedAttention(n_heads = self.n_heads, 
                                                         d_model = self.embedding_dim)
        self.softmax_layer = tf.keras.layers.Dense(en_vocab_size, activation = 'softmax')  
        
    def call(self, inputs):
        # inputs - shape(batch_size, sentence_length , embedding_size)
        queries_encoder, keys_encoder, output_prev = inputs
        
        # Create padding masks
        padding_mask = create_padding_mask(outerprod(output_prev))
        lookahead_mask = create_lookahead_mask(outerprod(output_prev))
        mask = tf.math.maximum(padding_mask, lookahead_mask)
        
        # Encode the output of the previous timestep
        output_prev_embedded = self.pos_encoding_layer(self.embedding_layer(output_prev)) 
        # Pass the previous timestep's output through a self attention layer
        attention_1 = self.multiheadedattention_1([output_prev_embedded]*3, mask)
        
        # Use the output of the previous timestep as the queries for a new attention layer
        # The keys and values are the output of the encoder block
        attention_2 = self.multiheadedattention_2([attention_1, queries_encoder, keys_encoder])
        
        # produce a probability vector to predict the next token in the sequence
        probs = self.softmax_layer(attention_2)
        return probs
    
    def compute_output_shape(self, input_shape):

        return tf.TensorShape((input_shape[0][0], en_vocab_size))

## Test Model

### Build Model

In [12]:
input_encoder = tf.keras.layers.Input(shape = (None,))
input_prev_output = tf.keras.layers.Input(shape = (None,))

encoder_output = Encoder(d_model, n_heads)(input_encoder)
p_tokens = Decoder(d_model, n_heads)([encoder_output, encoder_output, input_prev_output])#, transformer_output])

model = tf.keras.Model(inputs = [input_encoder, input_prev_output], 
                       outputs = [p_tokens]) 
optimizer = tf.keras.optimizers.Adam(learning_rate = 1e-4)

### Train Model

In [None]:
n_episodes = 10
for ep in tqdm(range(n_episodes)):
    
    # Iterate through each batch
    for batch_ind, (port, eng) in enumerate(train_dataset_batched):  
        prev_transformer_output = tf.zeros((batch_size, tf.shape(eng)[1].numpy()),dtype = tf.int64)
        with tf.GradientTape(persistent=True) as tape:
            # Query Model
            token_probs = model([port,prev_transformer_output])
            next_token = tf.argmax(token_probs, axis = 2)
            
            # Calculate loss
            eng_one_hot = tf.one_hot(eng, depth = en_vocab_size)
            cce = tf.keras.losses.CategoricalCrossentropy()
            loss = cce(eng_one_hot, token_probs)
        
        # Update Model Parameters
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        #if batch_ind == 300: break
        
    # Compare the predicted translation to the true translation
    translation_pred = tokenizers.en.detokenize(next_token)
    translation_true = tokenizers.en.detokenize(eng)
    orig = tokenizers.pt.detokenize(port)
        
    for pt_line,en_line_pred, en_line_true in zip(orig.numpy(), translation_pred.numpy(), translation_true.numpy()):
        print(loss)
        print(f"Portuguese: {pt_line.decode('utf-8')}")
        print(f"Translation (Pred): {en_line_pred.decode('utf-8')}")
        print(f"Translation (True): {en_line_true.decode('utf-8')}")
        break


 10%|████████▎                                                                          | 1/10 [01:05<09:46, 65.18s/it]

tf.Tensor(8.88289, shape=(), dtype=float32)
Portuguese: agora , os transgressores tambem beneficiam .
Translation (Pred): ##rationsrations graphle claimduced grapholis tearsrations metabolic belong metabolicrations tearsolis belongduced claimle graphrations exit + discovery watched clicks graph discovery endingrationsrations belong theoryrations belongrations ending discovery graph clicks watched manage +rationsrations graph ending claim claim grapholis managerations metabolic belong graphrations discoveryolis belonguts claimle graphrations medicine metabolic belong watched clicks graph your clicksrationsrations neither opportunitiesrations belongrationsrations discovery graph clicks watched manage +rationsrations readerrations + claim grapholis metabolicrations metabolic
Translation (True): now , the offenders , they also benefit .


 20%|████████████████▌                                                                  | 2/10 [02:09<08:36, 64.59s/it]

tf.Tensor(8.878183, shape=(), dtype=float32)
Portuguese: agora , os transgressores tambem beneficiam .
Translation (Pred): ##rationsrations graphle claimduced grapholis belongrations metabolic belong metabolicrations belongolis belong iron claimle graphrations exit + discovery watched clicks graph your endingrationsrationsrations opportunitiesrations belongrations ending discovery graph clicks watched manage +rationsrations graph ending claim claim grapholis managerations metabolic belong watchedrations discoveryolis belonguts claimle graphrations medicine metabolic belong watched clicks graph your worstrationsrations neither opportunitiesrations belongrationsrations discovery graph clicks watched manage +rationsrations graphrations + claim grapholis metabolicrations metabolic
Translation (True): now , the offenders , they also benefit .


 30%|████████████████████████▉                                                          | 3/10 [03:12<07:27, 63.97s/it]

tf.Tensor(8.873693, shape=(), dtype=float32)
Portuguese: agora , os transgressores tambem beneficiam .
Translation (Pred): ##rationsrations graphle claimduced grapholis belongrations metabolic belong metabolicrations belongolis belong iron claimle graphrations exit + discovery medicine clicks graph your endingrationsrationsrations opportunitiesrations belongrations ending discovery graph clicks watched manage +rationsrations graph ending + claim grapholis watchedrations metabolic belong watchedrations discoveryolis belong iron claimle graphrations medicine metabolic belong watched clicks graph your worstrationsrations neither opportunitiesrationstrationsrations discovery graph clicks watched manage +rationsrations graphrations + claim grapholis metabolicrations metabolic
Translation (True): now , the offenders , they also benefit .


 40%|█████████████████████████████████▏                                                 | 4/10 [04:15<06:21, 63.54s/it]

tf.Tensor(8.869382, shape=(), dtype=float32)
Portuguese: agora , os transgressores tambem beneficiam .
Translation (Pred): ##rationsrations graph astduced grapholis belongrations metabolic belong metabolicrations belongolis belong iron claimle graphrations exit + discovery medicine clicks graph your endingrationsrationsrations opportunitiesrationstrations ending your graph clicks watched manage +rationsrations reader ast claim grapholis watchedrations metabolic belong medicinerations worldolis belong ironlele graphrations medicine metabolic metabolic medicine clicks graph your endingrationsrations neither opportunitiesrationstrationsrations discovery graph clicks watched manage +rations metabolic graphrations + claim graphle metabolicrations metabolic
Translation (True): now , the offenders , they also benefit .
