# Transformer Taxonomy

Previously, we covered fine-tuning and evaluating a transformer.

In this notebook, the following topics are covered:
1. Implementing the attention mechanism
2. Implementing the attention decoder
3. Architectural differences between the encoder and decoder (for a decoder-only transformer, refer to [nanogpt](https://github.com/JpChii/nanogpt))
4. Taxonomy of transformers

## The Transformer architecture

The Transformer is based on the encoder-decoder architecture, which is widely used for machine translation tasks.

***Encoder*** - Converts a sequence of input tokens into a sequence of embedding vectors, also called the *hidden state* or *context*.

***Decoder*** - Uses the encoder's *hidden state* to iteratively generate an output sequence, one token at a time.

*Encoder-Decoder architecture*

![Architecture](../notes/images/3-transformer-taxonomy/encoder-decoder.png)

* The encoder combines token embeddings and positional embeddings and passes them through a stack of encoder layers to produce the hidden state.
* The encoder's output is fed to each decoder layer. The decoder then predicts the most probable next token.
* Say *Die* and *Zeit* have been predicted; the decoder now receives these two tokens together with the encoder outputs to predict the next token, *fliegt*.
* This process is repeated until the EOS token is predicted.
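The generation loop described above can be sketched in plain Python. This is a toy illustration rather than a trained model: `toy_next_token` is a hypothetical stand-in for a decoder forward pass that simply reads from a hard-coded translation, and the `<eos>` token string and the words after *fliegt* are assumptions made for the example.

```python
# Hard-coded target sentence used by the toy decoder (assumed for illustration).
TARGET = ["Die", "Zeit", "fliegt", "wie", "ein", "Pfeil", "<eos>"]

def toy_next_token(encoder_output, generated):
    # Hypothetical stand-in for a decoder: a real one would attend to
    # encoder_output and to the tokens generated so far.
    return TARGET[len(generated)]

def greedy_decode(encoder_output, max_len=10):
    generated = []
    for _ in range(max_len):
        token = toy_next_token(encoder_output, generated)
        if token == "<eos>":  # stop once the end-of-sequence token is predicted
            break
        generated.append(token)  # feed the prediction back in on the next step
    return generated

encoder_output = ["time", "flies", "like", "an", "arrow"]  # placeholder for hidden states
translation = greedy_decode(encoder_output)
print(translation)
```

The key point is the feedback: every predicted token is appended to the decoder input before the next prediction is made.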

### The Encoder

In the encoder, the stack that follows the token embeddings is made up of encoder layers, each containing the following sublayers:
* a multi-head self-attention layer
* a fully connected feed-forward layer that is applied to each input embedding

The output embeddings of each encoder layer have the same size as the inputs. The main role of the encoder stack is to **update** the input embeddings with contextual information.

Ex: *apple* will be updated to be more *company-like* and less *fruit-like* if the words *keynote* and *phone* are close to it.

*Multi-head attention*

![Multi-head attention](../notes/images/3-transformer-taxonomy/multi-head-attention.png)

Each of these sublayers uses skip connections and layer normalization, which help train deep neural networks effectively.

Let's start with the self-attention layer:

#### Self-Attention

Self-attention computes weights for all hidden states in the same set (for example, all hidden states of the encoder, or all hidden states of the decoder).

***Main idea of self-attention***: Instead of using a fixed embedding for each token, we can use the entire sequence to compute a *weighted average* of all the embeddings. Mathematically, this can be defined as below.

For a sequence of token embeddings x_1, x_2, ..., x_n, self-attention produces a sequence of new embeddings xx_1, xx_2, ..., xx_n, where each xx_i is a linear combination of all the x_j.

*Expression*

![expression](../notes/images/3-transformer-taxonomy/self-attention-expression.png)

The coefficients w_ji are called *attention weights* and are normalized so that they sum to 1 over j. How does this averaging help?

Consider *time flies like an arrow*, where *flies* is a verb. By assigning larger weights to *arrow* and *time*, we arrive at a representation of *flies* that incorporates its context. Embeddings generated this way are called *contextualized embeddings*.
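To make the weighted average concrete, here is a tiny sketch with made-up numbers. The 2-dimensional embeddings and the attention weights are invented for illustration only; a real model learns both.

```python
# Toy 2-d embeddings (made-up numbers): axis 0 ~ "motion", axis 1 ~ "insect".
embeddings = {
    "time":  [0.9, 0.0],
    "flies": [0.5, 0.5],
    "arrow": [1.0, 0.0],
}
# Hypothetical attention weights for the token "flies"; they sum to 1.
weights = {"time": 0.3, "flies": 0.2, "arrow": 0.5}

dim = 2
# New embedding for "flies": weighted average of all token embeddings.
contextual = [
    sum(weights[tok] * embeddings[tok][d] for tok in embeddings)
    for d in range(dim)
]
print(contextual)
```

With these numbers the representation of *flies* moves from `[0.5, 0.5]` to roughly `[0.87, 0.10]`, that is, toward the "motion" reading suggested by *time* and *arrow*.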

Let's look at how these attention weights are calculated.

#### Scaled dot-product attention

The most common way to implement self-attention is *scaled dot-product attention*, from the paper that introduced the Transformer.

There are four main steps:
* Project each token embedding into three vectors called *query*, *key*, and *value*
  * The projections are learned linear layers
* Compute the dot product of the query with the keys to get the attention scores
  * The query of the current token is multiplied with the keys of all tokens. The dot product is large when a query and a key are similar
* Scale the scores by a normalizing factor to keep the dot products from growing too large, then apply softmax so that the resulting attention weights sum to 1
* Finally, update each token embedding with the weighted sum of the value vectors, using the attention weights.
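The four steps above are usually summarized in a single formula, as given in the paper that introduced the Transformer. Here $Q$, $K$, and $V$ are the matrices of query, key, and value vectors (one row per token), and $d_k$ is the dimension of the key vectors:

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
$$

The softmax is applied row by row, so each row of attention weights sums to 1.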

In an encoder setup, each token attends to all other tokens, so the first token, [CLS], can aggregate information about the entire sequence. Its final hidden state can then be used to compute similarity, to perform classification, or be passed on to a decoder.

In a decoder setup, we restrict access to future tokens with a lower-triangular mask (`tril`), which forces the model to predict the next token from past tokens only. This is also called an autoregressive setup.

### Visualizing attention

Attention can be visualized with [BertViz for Jupyter](https://oreil.ly/eQK3I).

To visualize attention weights, the `neuron_view` module traces the computation of the weights to show how the query and key vectors are combined to produce the final weight.

In [1]:
from transformers import AutoTokenizer
from bertviz.transformers_neuron_view import BertModel
from bertviz.neuron_view import show

model_ckpt = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
model = BertModel.from_pretrained(model_ckpt)

text = "time flies like an arrow"
show(model=model,model_type="bert", tokenizer=tokenizer, sentence_a=text, display_mode="light", layer=0, head=8)


In this visualization, *flies* attends most strongly to *arrow*.

*Operations in scaled dot-product attention*

![scaled dot product attention](../notes/images/3-transformer-taxonomy/operations-in-scaled-dot-product-attention.png)

Let's implement the transformer architecture using PyTorch.

First, let's tokenize the text.

In [2]:
inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False) # Removing Special tokens (CLS, SEP) for simplicity
inputs.keys(), inputs

(dict_keys(['input_ids', 'token_type_ids', 'attention_mask']),
 {'input_ids': tensor([[ 2051, 10029,  2066,  2019,  8612]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])})

Next, create some dense embeddings. *Dense* in this context means that each entry in the embeddings contains a nonzero value, whereas one-hot encodings are *sparse* because all entries except one are zero.

We can do this in PyTorch using `torch.nn.Embedding`, which acts as a lookup table for each input ID.

In [3]:
from torch import nn
from transformers import AutoConfig
config = AutoConfig.from_pretrained(model_ckpt)
token_emb = nn.Embedding(
    num_embeddings=config.vocab_size, # Lookup is for each id in vocabulary
    embedding_dim=config.hidden_size, # Embedding dimension for each id
)
token_emb


Embedding(30522, 768)

In [4]:
config.vocab_size

30522

Now we have a lookup table with one 768-dimensional vector for each of the 30522 IDs in the vocabulary.

`AutoConfig` loads the `config.json` associated with a checkpoint in `transformers`. From it we can access parameters such as `vocab_size` and `hidden_size`, as done above.

The token embeddings at this point are independent of context. This means that homonyms (words that have the same spelling but different meanings) have the same representation.

The role of the subsequent attention layers is to update the embedding of each token with information from its context.

In [5]:
# Let's generate some embeddings
input_embeds = token_emb(inputs.input_ids)
input_embeds.size()

torch.Size([1, 5, 768])

We now have a tensor of shape `[batch_size, seq_len, hidden_size]`, where `seq_len` is the number of tokens.

We'll implement positional encodings later. For now, let's implement scaled dot-product attention.

In [6]:
import torch
from math import sqrt

# Use the raw input embeddings as query, key, and value for now
query = key = value = input_embeds
dim_k = key.size(-1)
# (1, 5, 768) @ (1, 768, 5) --> (1, 5, 5)
# Scale by the square root of the embedding dimension
scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(dim_k)
scores.size()

torch.Size([1, 5, 5])

In [7]:
scores

tensor([[[ 2.4924e+01,  2.9782e-01,  7.9484e-01,  2.2592e-01, -1.4588e-02],
         [ 2.9782e-01,  2.9413e+01,  9.5575e-01,  1.1878e+00, -1.4463e+00],
         [ 7.9484e-01,  9.5575e-01,  2.6958e+01, -4.3528e-01,  4.2961e-02],
         [ 2.2592e-01,  1.1878e+00, -4.3528e-01,  2.7287e+01,  4.5015e-01],
         [-1.4588e-02, -1.4463e+00,  4.2961e-02,  4.5015e-01,  2.6981e+01]]],
       grad_fn=<DivBackward0>)

For simplicity we've used the same tensor for the query, key, and value; in practice each is produced by its own learned weight matrix. The scores are scaled by the square root of the embedding dimension so that large dot products do not destabilize training by saturating the softmax that we apply next.
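The effect of the scaling factor can be checked numerically. The sketch below is a stdlib-only experiment under stated assumptions: query and key entries are drawn independently from a standard normal distribution, and `dim_k = 64` is an arbitrary choice. Under those assumptions the raw dot products have a standard deviation near `sqrt(dim_k)`, and dividing by `sqrt(dim_k)` brings it back to about 1.

```python
import math
import random

random.seed(0)
dim_k = 64  # assumed vector dimension for this demo

def random_vector(n):
    return [random.gauss(0.0, 1.0) for _ in range(n)]

def std(values):
    mean = sum(values) / len(values)
    return math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))

# Dot products between independent unit-variance query and key vectors
raw = []
for _ in range(2000):
    q, k = random_vector(dim_k), random_vector(dim_k)
    raw.append(sum(a * b for a, b in zip(q, k)))
scaled = [r / math.sqrt(dim_k) for r in raw]

raw_std, scaled_std = std(raw), std(scaled)
print(f"std of raw scores:    {raw_std:.2f}")     # roughly sqrt(64) = 8
print(f"std of scaled scores: {scaled_std:.2f}")  # roughly 1
```

Scores with a spread of several units push softmax toward a near one-hot output, where gradients become very small; keeping the spread near 1 avoids that.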

`torch.bmm` performs a batch matrix-matrix product: each matrix in the batch is multiplied independently of the others.

In [8]:
# Applying softmax
import torch.nn.functional as F
# Normalizing weights
weights = F.softmax(scores, dim=-1)
weights.sum(dim=-1)

tensor([[1., 1., 1., 1., 1.]], grad_fn=<SumBackward1>)

In [9]:
# Validating normalization
scores[0][0].sum(), weights[0][0].sum()

(tensor(26.2279, grad_fn=<SumBackward0>), tensor(1., grad_fn=<SumBackward0>))

In [103]:
weights.dtype

torch.float32

In [10]:
# Finally, multiply the attention weights with the values
# (1, 5, 5) @ (1, 5, 768) --> (1, 5, 768)
attn_outputs = torch.bmm(weights, value)
attn_outputs.shape

torch.Size([1, 5, 768])

In [11]:
# Let's implement all these steps into a function
def scaled_dot_product_attention(key, query, value):
    """Compute scaled dot-product attention.

     * Query and key are multiplied and scaled by the square root of the embedding dim
     * Scores are normalized with softmax
     * The resulting attention weights are multiplied with the values

    Args:
        key (torch.Tensor): Input token keys, shape (batch, seq_len, dim)
        query (torch.Tensor): Input token queries, shape (batch, seq_len, dim)
        value (torch.Tensor): Input token values, shape (batch, seq_len, dim)

    Returns:
        torch.Tensor: New embeddings, shape (batch, seq_len, dim)
    """

    # Get embedding dim for scaling
    dim_k = query.size(-1)
    # (1, 5, 768) @ (1, 768, 5) --> (1, 5, 5)
    # Scale by the square root of the embedding dim
    scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(dim_k)
    # Normalizing
    weights = F.softmax(scores, dim=-1)
    # Multiply attention weights with values
    attn_outputs = torch.bmm(weights, value)
    return attn_outputs

The query-key dot product will be highest for a token paired with itself, since identical vectors are used as queries and keys (this shows up as the large diagonal of `scores` above). But the meaning of `flies` is defined more by `arrow` than by `flies` itself.

We can overcome this by using different weight matrices for the key, query, and value, projecting them into different spaces.

## Multi-headed attention

In practice, the self-attention layer applies three independent linear transformations to each embedding to generate the query, key, and value vectors. These transformations project the embeddings, and each projection has its own set of learnable parameters, allowing the layer to focus on different semantic aspects of the sequence.

One such set of projections, together with its attention computation, is called an attention *head*. With multiple heads, the model can learn multiple aspects of similarity instead of just one.

This is similar to CNNs, where each kernel learns to detect a different feature before the results are averaged or max-pooled.

*Visualization*
![alt mult-head attention](../notes/images/3-transformer-taxonomy/multi-head-attention-in-depth.png)

In [12]:
# Creating attention head
class AttentionHead(nn.Module):
    def __init__(self, embed_dim, head_dim):
        """Single self-attention head.

        Args:
            embed_dim (int): embedding dimension
            head_dim (int): head dimension
        """
        super().__init__()
        self.head_dim = head_dim
        self.query = nn.Linear(embed_dim, head_dim)
        self.key = nn.Linear(embed_dim, head_dim)
        self.value = nn.Linear(embed_dim, head_dim)

    def scaled_dot_product_attention(self, key, query, value):
        """Compute scaled dot-product attention.

         * Query and key are multiplied and scaled by the square root of the head dim
         * Scores are normalized with softmax
         * The resulting attention weights are multiplied with the values

        Args:
            key (torch.Tensor): Input token keys
            query (torch.Tensor): Input token queries
            value (torch.Tensor): Input token values

        Returns:
            torch.Tensor: New embeddings.
        """

        # Get embedding dim for scaling
        dim_k = query.size(-1)
        scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(dim_k)
        # Normalizing
        weights = F.softmax(scores, dim=-1)
        # Multiply attention weights with values
        attn_outputs = torch.bmm(weights, value)
        return attn_outputs
        
    
    def forward(self, hidden_state):
        """Project the hidden state to query, key, and value, then apply attention.

        The last dimension of hidden_state must match embed_dim.

        Args:
            hidden_state (torch.Tensor): the input embeddings
        """
        attn_outputs = self.scaled_dot_product_attention(
            key=self.key(hidden_state),
            query=self.query(hidden_state),
            value=self.value(hidden_state),
        )
        return attn_outputs

Here the hidden state is fed to the key, query, and value linear layers. Each projection maps `(batch_dim, seq_len, emb_dim) @ (emb_dim, head_dim) --> (batch_dim, seq_len, head_dim)`, which is the shape of the tensors produced by the linear transformations.

Although `head_dim` doesn't need to be smaller than `embed_dim`, in practice it is chosen as `embed_dim` divided by the number of heads, so that the total computation across all heads stays constant.

For example, BERT has 12 attention heads, so the dimension of each head is 768/12 = 64.

Now we have a single attention head. Let's implement multi-head attention.

In [13]:
class MultiHeadAttention(nn.Module):
    def __init__(self, config) -> None:
        """Multi-head attention constructor.

        Args:
            config: configuration object with attribute access (generally a transformers config) providing hidden_size and num_attention_heads
        """
        super().__init__()
        self.config = config
        embed_dim = self.config.hidden_size
        num_heads = self.config.num_attention_heads
        head_dim = embed_dim // num_heads
        self.heads = nn.ModuleList(
            [AttentionHead(embed_dim, head_dim) for _ in range(num_heads)]
        )
        # Concatenated heads: (batch_dim, seq_len, num_heads * head_dim), where num_heads * head_dim == embed_dim
        # The output projection returns (batch_dim, seq_len, embed_dim)
        self.output_linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, hidden_state):
        x = torch.cat(
            [h(hidden_state) for h in self.heads], # List of tensors to be concatenated
            dim=-1
        )
        x = self.output_linear(x)
        return x


The concatenated output of the attention heads has shape `(batch_size, seq_len, num_heads * head_dim)`, and the final linear layer has a weight of shape `(embed_dim, embed_dim)`. For the matrix multiplication to work, `num_heads * head_dim` must equal `embed_dim`. This is ensured by setting `head_dim` to `embed_dim // num_heads`, which requires `embed_dim` to be divisible by `num_heads`.

`(1, 5, 64*12) @ (768, 768) --> (1, 5, 768)` is the output shape of the multi-head attention defined above.

In [14]:
multihead_attn = MultiHeadAttention(config)
attn_outputs = multihead_attn(input_embeds)
attn_outputs.size()

torch.Size([1, 5, 768])

Before wrapping attention, let's visualize it.

In [15]:
from bertviz import head_view
from transformers import AutoModel

model = AutoModel.from_pretrained(model_ckpt, output_attentions=True)

sentence_a = "time flies like an arrow"
sentence_b = "fruit flies like a banana"

viz_inputs = tokenizer(sentence_a, sentence_b, return_tensors="pt")
attention = model(**viz_inputs).attentions
sentence_b_start = (viz_inputs.token_type_ids == 0).sum(dim=1)
tokens = tokenizer.convert_ids_to_tokens(viz_inputs.input_ids[0])

head_view(attention, tokens, sentence_b_start, heads=[8])

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


<IPython.core.display.Javascript object>

## The Feed-Forward Layer

The feed-forward layer in the encoder and decoder is just a simple two-layer fully connected network, but with a twist: instead of processing the whole sequence of embeddings as a single vector, it processes each embedding *independently*.

For this reason, this layer is often called a position-wise feed-forward layer. It is also referred to as a one-dimensional convolution with a kernel size of one (the nomenclature used in the OpenAI GPT codebase).

A rule of thumb is for the hidden size of the first layer to be four times the size of the embeddings, with a GELU activation function.

This is where most of the capacity and memorization is hypothesized to happen, and it is also the part that is most often scaled when scaling up models.
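A rough parameter count supports the capacity claim. The sketch below assumes BERT-base sizes (hidden size 768, 12 heads, intermediate size 4 × 768), linear layers with a bias term, and the per-head attention layout used in this notebook with two layer normalizations per encoder layer.

```python
hidden_size = 768
intermediate_size = 4 * hidden_size  # 3072, following the rule of thumb
num_heads = 12
head_dim = hidden_size // num_heads

def linear_params(n_in, n_out):
    # Weight matrix plus bias vector
    return n_in * n_out + n_out

# Attention: three projections (query, key, value) per head, plus the output projection
attention_params = (
    num_heads * 3 * linear_params(hidden_size, head_dim)
    + linear_params(hidden_size, hidden_size)
)
# Feed-forward: expand to the intermediate size, then project back
feed_forward_params = (
    linear_params(hidden_size, intermediate_size)
    + linear_params(intermediate_size, hidden_size)
)
# Two LayerNorms, each with a scale and a shift vector
layer_norm_params = 2 * (2 * hidden_size)

total = attention_params + feed_forward_params + layer_norm_params
print(attention_params, feed_forward_params, total)
print(f"feed-forward share: {feed_forward_params / total:.0%}")
```

Under these assumptions the feed-forward sublayer holds about two thirds of an encoder layer's parameters, and the total of 7,087,872 agrees with the per-layer figure in the `torchinfo` summary at the end of this notebook.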

In [16]:
config.hidden_size, config.intermediate_size // 4 

(768, 768)

In [17]:
class FeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.linear_1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.linear_2 = nn.Linear(config.intermediate_size, config.hidden_size)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x):
        x = self.linear_1(x)
        x = self.gelu(x)  # the non-linearity sits between the two linear layers
        x = self.linear_2(x)
        x = self.dropout(x)
        return x

When we pass an input of shape `(batch_size, seq_len, hidden_dim)`, `nn.Linear` acts on the last dimension only, so the layer is applied to the embedding of each token independently.

In [18]:
feed_forward = FeedForward(config)
ff_outputs = feed_forward(attn_outputs)
ff_outputs.size()

torch.Size([1, 5, 768])

The feed-forward layer is applied to the attention outputs.

## Adding Layer Normalization

The transformer architecture makes use of layer normalization and skip connections. Layer normalization normalizes each input to have zero mean and unit variance. Skip connections pass a tensor to the next layer without processing it and add it to the processed tensor, which gives gradients a direct path during the backward pass.

There are two main choices adopted in the literature for where to place layer normalization:

*Post layer normalization*:

This is the arrangement used in the Transformer paper. Layer normalization is applied after the skip connection's addition. This arrangement is tricky to train from scratch because the gradients can diverge, so it is usually combined with learning rate warm-up.

![alt post layer norm](../notes/images/3-transformer-taxonomy/post-layer-norm.png)

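*Pre layer normalization*:

This is the more common approach. Layer normalization is placed within the span of the skip connection, so it is applied to the input of each sublayer. This tends to make training more stable and usually does not require learning rate warm-up.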

![alt pre layer norm](../notes/images/3-transformer-taxonomy/pre-layer-norm.png)
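The difference between the two arrangements comes down to where the normalization sits relative to the addition. The following stdlib-only sketch works on a single embedding vector; `sublayer` is a hypothetical stand-in for attention or the feed-forward network, and the layer norm here has no learnable scale or shift.

```python
import math

def layer_norm(x, eps=1e-5):
    # Normalize one embedding vector to zero mean and unit variance
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def sublayer(x):
    # Hypothetical stand-in for attention or feed-forward
    return [2.0 * v + 1.0 for v in x]

def post_ln_block(x):
    # Post-LN: add the skip connection first, then normalize the sum
    return layer_norm([a + b for a, b in zip(x, sublayer(x))])

def pre_ln_block(x):
    # Pre-LN: normalize only the sublayer input; the skip path stays untouched
    return [a + b for a, b in zip(x, sublayer(layer_norm(x)))]

x = [1.0, 2.0, 3.0, 4.0]
post = post_ln_block(x)
pre = pre_ln_block(x)
print(post)
print(pre)
```

In the pre-LN block the input `x` reaches the output unchanged through the addition, which is the direct path that gradients can follow; in the post-LN block everything, including the skip path, passes through the normalization.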




We'll use pre-layer normalization

In [19]:
class TransformerEncoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(config.hidden_size)
        self.layer_norm_2 = nn.LayerNorm(config.hidden_size)
        self.attention = MultiHeadAttention(config)
        self.ff = FeedForward(config)

    def forward(self, x):
        # Normalize embeddings
        hidden_state = self.layer_norm_1(x)

        # Apply attention with skip connection
        x = x + self.attention(hidden_state)

        # Apply feed forward with a skip-connection
        x = x + self.ff(self.layer_norm_2(x))
        return x

In [20]:
ln = nn.LayerNorm(config.hidden_size)
ln(input_embeds).shape

torch.Size([1, 5, 768])

In [21]:
encoder_layer = TransformerEncoderLayer(config)
input_embeds.shape, encoder_layer(input_embeds).shape

(torch.Size([1, 5, 768]), torch.Size([1, 5, 768]))

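One problem with this approach is that the encoder has no information about the positions of the tokens: the multi-head attention layer is effectively a fancy weighted sum, so the output for a token does not depend on where that token sits in the sequence.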
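The lack of positional information can be checked directly. The sketch below is a minimal stdlib-only self-attention with query = key = value and no learned projections (an assumption made to keep it short); the toy embeddings are made up. Shuffling the input tokens only shuffles the outputs in the same way.

```python
import math

def self_attention(xs):
    # Plain self-attention with query = key = value = xs (no learned projections)
    dim = len(xs[0])
    out = []
    for q in xs:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(dim) for k in xs]
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]  # numerically stable softmax
        weights = [e / sum(exps) for e in exps]
        out.append([sum(w * v[d] for w, v in zip(weights, xs)) for d in range(dim)])
    return out

tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # made-up embeddings
shuffled = [tokens[2], tokens[0], tokens[1]]   # same tokens, different order

out = self_attention(tokens)
out_shuffled = self_attention(shuffled)

# Compare each token's output with its output in the shuffled sequence
same = all(
    math.isclose(a, b, abs_tol=1e-9)
    for row_a, row_b in zip([out[2], out[0], out[1]], out_shuffled)
    for a, b in zip(row_a, row_b)
)
print(same)
```

Each token receives the same output vector wherever it appears, so word order is invisible to the layer unless position information is added to the embeddings.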

## Positional Embeddings

Basic idea: augment the token embeddings with a position-dependent pattern of values arranged in a vector. With this, the attention heads and feed-forward layers can incorporate positional information into their transformations.

When the pretraining dataset is large, we can use a learnable pattern for this, with the position index as input.

Let's create a custom `Embeddings` module that combines a token embedding layer, which maps `input_ids` to dense vectors, with a position embedding layer that does the same for `position_ids`. The resulting embedding is simply the sum of both.

In [22]:
class Embeddings(nn.Module):
    def __init__(self, config: AutoConfig):
        super().__init__()
        self.token_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout()

    def forward(self, input_ids):
        # Create position IDs for the input sequence
        # Dimension 1 of input_ids is the sequence length
        seq_length = input_ids.size(1)
        # unsqueeze(0) adds a batch dimension of 1
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device).unsqueeze(0)

        # Create token and position embeddings
        token_embeds = self.token_embeddings(input_ids)
        position_embeds = self.position_embeddings(position_ids)

        # Combine token and position embeddings
        embeddings = token_embeds + position_embeds

        # Apply layer norm
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)

        return embeddings

In [23]:
embedding_layer = Embeddings(config)
embedding_layer(inputs.input_ids).shape

torch.Size([1, 5, 768])

While learnable position embeddings are easy to implement, there are alternatives:

![alt positional-embedding-alternatives](../notes/images/3-transformer-taxonomy/position-embedding-alternatives.png)
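One well-known fixed alternative to learnable position embeddings is the sinusoidal encoding from the original Transformer paper. The sketch below is a stdlib-only illustration; the function name `sinusoidal_position_encoding` and the small sizes are choices made for this example, not part of the notebook's model.

```python
import math

def sinusoidal_position_encoding(seq_len, dim):
    # PE[pos][2i]   = sin(pos / 10000^(2i/dim))
    # PE[pos][2i+1] = cos(pos / 10000^(2i/dim))
    table = []
    for pos in range(seq_len):
        row = []
        for i in range(0, dim, 2):  # i already steps over the even indices 2i
            angle = pos / (10000 ** (i / dim))
            row.append(math.sin(angle))
            row.append(math.cos(angle))
        table.append(row[:dim])  # trim in case dim is odd
    return table

pe = sinusoidal_position_encoding(seq_len=5, dim=8)
print(len(pe), len(pe[0]))
print([round(v, 3) for v in pe[1]])
```

Because the pattern is computed rather than learned, it adds no parameters and can be evaluated for any position; the resulting vectors would be added to the token embeddings in the same way as the learned position embeddings above.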

In [24]:
# Create Encoder
class TransformerEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embeddings = Embeddings(config)
        self.layers = nn.ModuleList([TransformerEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        
    def forward(self, x):
        x = self.embeddings(x)
        for layer in self.layers:
            x = layer(x)
        return x

In [25]:
tf_encoder = TransformerEncoder(config)

In [26]:
tf_encoder(inputs.input_ids).size()

torch.Size([1, 5, 768])

Let's build a classifier next.

## Adding a Classification Head

Transformer models are usually divided into a task-independent body and task-specific head. 

What we've built is the body, so if we want to classify a text, we will need to attach a classification head to the body.

Right now we have a hidden state for each token, but we need to make a single prediction.

Traditionally, the hidden state of the first token is used as the sequence representation: we apply dropout and a linear layer to it to make the classification prediction.

In [27]:
from transformers import BertConfig
class TransformerForTextClassification(nn.Module):
    def __init__(self, config: BertConfig):
        super().__init__()
        self.encoder = TransformerEncoder(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, x):
        x = self.encoder(x)[:, 0, :] # select hidden state of [CLS] token
        x = self.dropout(x)
        x = self.classifier(x)
        return x

In [28]:
config.num_labels = 3
tf_classifier = TransformerForTextClassification(config=config)
preds = tf_classifier(inputs.input_ids)

In [29]:
preds.shape

torch.Size([1, 3])

In [30]:
preds

tensor([[ 0.7946, -1.3903, -0.9027]], grad_fn=<AddmmBackward0>)

In [31]:
F.softmax(preds, dim=-1)

tensor([[0.7718, 0.0868, 0.1414]], grad_fn=<SoftmaxBackward0>)

## The Decoder

The main difference between the decoder and the encoder is that the decoder has *two* attention sublayers:

*Masked multi-head self-attention layer*:

Ensures that the tokens we generate at each timestep are based only on past outputs and the current token being predicted. Without this, the decoder could cheat during training by simply copying the target translations.

*Encoder-decoder attention layer*:

Performs multi-head attention over the output key and value vectors of the encoder stack, with the intermediate representations of the decoder acting as the queries.

This way, the encoder-decoder attention layer learns how to relate tokens from two different sequences, such as two different languages. The decoder has access to the encoder keys and values in each block.

Let's implement the masked attention layer.
I've already implemented it in the [nanogpt notebook](https://github.com/JpChii/nanogpt/blob/main/gpt.ipynb) using `tril`, which is used in the book as well.

In [105]:
seq_len = inputs.input_ids.size(-1)
mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0)
mask[0], mask.shape

(tensor([[1., 0., 0., 0., 0.],
         [1., 1., 0., 0., 0.],
         [1., 1., 1., 0., 0.],
         [1., 1., 1., 1., 0.],
         [1., 1., 1., 1., 1.]]),
 torch.Size([1, 5, 5]))

With `tril()` we've created a lower-triangular matrix. Next, we use it with `masked_fill` to replace the scores at positions where the mask is 0 with negative infinity, which prevents the model from looking into the future.

In [106]:
scores.masked_fill(mask == 0, -float("inf"))

tensor([[[ 2.4924e+01,        -inf,        -inf,        -inf,        -inf],
         [ 2.9782e-01,  2.9413e+01,        -inf,        -inf,        -inf],
         [ 7.9484e-01,  9.5575e-01,  2.6958e+01,        -inf,        -inf],
         [ 2.2592e-01,  1.1878e+00, -4.3528e-01,  2.7287e+01,        -inf],
         [-1.4588e-02, -1.4463e+00,  4.2961e-02,  4.5015e-01,  2.6981e+01]]],
       grad_fn=<MaskedFillBackward0>)

In [107]:
scores.dtype

torch.float32

In [100]:
scores = F.softmax(scores, dim=-1)

tensor([[[1.0000e+00, 2.0184e-11, 3.3179e-11, 1.8784e-11, 1.4769e-11],
         [2.2675e-13, 1.0000e+00, 4.3782e-13, 5.5216e-13, 3.9635e-14],
         [4.3407e-12, 5.0985e-12, 1.0000e+00, 1.2686e-12, 2.0466e-12],
         [1.7676e-12, 4.6252e-12, 9.1251e-13, 1.0000e+00, 2.2119e-12],
         [1.8875e-12, 4.5089e-13, 1.9993e-12, 3.0041e-12, 1.0000e+00]]],
       grad_fn=<SoftmaxBackward0>)

In [104]:
scores.dtype

torch.float32

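By filling the masked positions with `-inf`, softmax assigns them a weight of exactly 0, because softmax exponentiates each score and the exponential of `-inf` is 0. Note that `masked_fill` returns a new tensor rather than modifying `scores` in place, so the softmax output shown above was computed on the unmasked scores; the masked result has to be kept and passed to softmax for the mask to take effect.
The same masking is implemented inside `scaled_dot_product_attention` in the next cell.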
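The effect of `-inf` on softmax can be verified without tensors. This stdlib-only sketch applies a hand-written `softmax` to one row of scores (the numbers are rounded from the third row of the score matrix above), once unmasked and once with the two future positions set to `-inf`.

```python
import math

def softmax(row):
    top = max(row)
    exps = [math.exp(v - top) for v in row]  # math.exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

neg_inf = float("-inf")
# Third row of a 5x5 score matrix: the last two positions lie in the future
scores_row = [0.8, 1.0, 27.0, -0.4, 0.04]
masked_row = [0.8, 1.0, 27.0, neg_inf, neg_inf]

unmasked_weights = softmax(scores_row)
masked_weights = softmax(masked_row)
print(unmasked_weights)
print(masked_weights)
```

The unmasked weights for the future positions are tiny but nonzero, while the masked weights are exactly zero and the remaining weights still sum to 1.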

In [113]:
def scaled_dot_product_attention(query, key, value, mask=None):
    # Scale by the query/key dimension (last axis), not the sequence length
    dim_k = query.size(-1)
    scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(dim_k)
    if mask:
        # Lower-triangular causal mask over the sequence positions
        seq_len = query.size(1)
        causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)).unsqueeze(0)
        scores = scores.masked_fill(causal_mask == 0, -float("inf"))

    weights = F.softmax(scores, dim=-1)
    return weights.bmm(value)

In [114]:
scaled_dot_product_attention(input_embeds, input_embeds, input_embeds, mask=True)


tensor([[[-0.6512, -0.3236,  2.5604,  ..., -0.0562, -1.4029,  0.1642],
         [-0.3357, -0.1322, -0.3980,  ..., -3.1651, -0.1336,  0.5005],
         [-1.6203,  1.2405, -0.7478,  ...,  1.5923, -0.0281, -0.7642],
         [ 0.4592, -0.7998,  0.8267,  ..., -0.7702,  1.3250, -0.0132],
         [-0.8928,  0.8185,  1.5782,  ...,  2.0118,  0.9097,  1.7563]]],
       grad_fn=<BmmBackward0>)

Encoder-decoder attention:
* inputs --> `x` (from the decoder) and `hidden_state` (from the encoder)
* use `x` to compute the queries, and project `hidden_state` to keys and values using linear layers
* perform scaled dot-product attention without a mask

Let's create the multi-head attention layers for the decoder: masked multi-head self-attention and encoder-decoder attention.

In [115]:
class MaskedAttentionHead(nn.Module):
    def __init__(self, embed_dim, head_dim):
        super().__init__()
        self.head_dim = head_dim
        self.embed_dim = embed_dim
        self.query = nn.Linear(self.embed_dim, self.head_dim)
        self.key = nn.Linear(self.embed_dim, self.head_dim)
        self.value = nn.Linear(self.embed_dim, self.head_dim)

    def scaled_dot_product_attention(self, query, key, value, mask=True):
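        # Scale by the head dimension (last axis), not the sequence length
        dim_k = query.size(-1)
        scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(dim_k)
        if mask:
            # Lower-triangular causal mask over the sequence positions
            seq_len = query.size(1)
            causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)).unsqueeze(0)
            scores = scores.masked_fill(causal_mask == 0, -float("inf"))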

        weights = F.softmax(scores, dim=-1)
        return weights.bmm(value)
    
    def forward(self, hidden_state):
        attn_outputs = self.scaled_dot_product_attention(
            query=self.query(hidden_state),
            key=self.key(hidden_state),
            value=self.value(hidden_state),
            mask=True
        )
        return attn_outputs

In [124]:
ma = MaskedAttentionHead(embed_dim=config.hidden_size, head_dim=config.num_attention_heads)
ma(input_embeds).size()

torch.Size([1, 5, 12])

In [117]:
class MaskedMultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = self.config.hidden_size
        self.num_heads = self.config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.heads = nn.ModuleList(
            [MaskedAttentionHead(self.embed_dim, self.head_dim) for _ in range(self.num_heads)]
        )
        self.output_linear = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(self, hidden_state):
        x = torch.cat(
            [h(hidden_state) for h in self.heads],
            dim=-1
        )
        x = self.output_linear(x)
        return x

In [122]:
mma = MaskedMultiHeadAttention(config)
mma(input_embeds).size()

torch.Size([1, 5, 768])

In [119]:
class DecoderAttentionHead(nn.Module):
    def __init__(self, embed_dim, head_dim):
        super().__init__()
        self.head_dim = head_dim
        self.embed_dim = embed_dim
        self.query = nn.Linear(self.embed_dim, self.head_dim)
        self.key = nn.Linear(self.embed_dim, self.head_dim)
        self.value = nn.Linear(self.embed_dim, self.head_dim)

    def scaled_dot_product_attention(self, query, key, value):
        # Scale by the head dimension (last axis), not the sequence length
        dim_k = query.size(-1)
        scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(dim_k)
        weights = F.softmax(scores, dim=-1)
        return weights.bmm(value)
    
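    def forward(self, hidden_state, encoder_hidden_state):
        """Attend from decoder states (queries) to encoder states (keys and values).

        Args:
            hidden_state (torch.Tensor): hidden state from the masked multi-head attention
            encoder_hidden_state (torch.Tensor): hidden state from the encoder

        Returns:
            torch.Tensor: attention outputs of shape (batch, decoder_seq_len, head_dim)
        """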
        attn_outputs = self.scaled_dot_product_attention(
            query=self.query(hidden_state),
            key=self.key(encoder_hidden_state),
            value=self.value(encoder_hidden_state),
        )
        return attn_outputs

In [121]:
dah = DecoderAttentionHead(embed_dim=config.hidden_size, head_dim=config.num_attention_heads)
dah(input_embeds, input_embeds).size()

torch.Size([1, 5, 12])

In [125]:
class MultiHeadDecoderAttentionHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = self.config.hidden_size
        self.num_heads = self.config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.heads = nn.ModuleList(
            [DecoderAttentionHead(self.embed_dim, self.head_dim) for _ in range(self.num_heads)]
        )
        self.output_linear = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(self, hidden_state, encoder_hidden_state):
        x = torch.cat(
            [h(hidden_state, encoder_hidden_state) for h in self.heads],
            dim=-1
        )
        x = self.output_linear(x)
        return x

In [126]:
mhda = MultiHeadDecoderAttentionHead(config)
mhda(input_embeds, input_embeds).size()

torch.Size([1, 5, 768])

In [127]:
# Implementing the decoder layer
# Compared with the encoder layer, forward also accepts the encoder hidden state
class TransformerDecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(config.hidden_size)
        self.layer_norm_2 = nn.LayerNorm(config.hidden_size)
        self.layer_norm_3 = nn.LayerNorm(config.hidden_size)
        self.masked_attention = MaskedMultiHeadAttention(config)
        self.encoder_decoder_attention = MultiHeadDecoderAttentionHead(config)
        self.ff = FeedForward(config)

    def forward(self, x, encoder_hidden_state):
        # Pre-layer normalization: normalize each sublayer's input, keep the skip path untouched
        hidden_state = self.layer_norm_1(x)
        x = x + self.masked_attention(hidden_state)
        # Encoder-decoder attention
        x = x + self.encoder_decoder_attention(self.layer_norm_2(x), encoder_hidden_state)
        x = x + self.ff(self.layer_norm_3(x))
        return x

In [130]:
tf_dec_layer = TransformerDecoderLayer(config)
tf_dec_layer(input_embeds, input_embeds).size()

torch.Size([1, 5, 768])

In [129]:
class TransformerDecoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embeddings = Embeddings(config)
        self.layers = nn.ModuleList([TransformerDecoderLayer(config) for _ in range(config.num_hidden_layers)])
    
    def forward(self, x, encoder_hidden_state):
        x = self.embeddings(x)
        for layer in self.layers:
            x = layer(x, encoder_hidden_state)
        return x

In [135]:
tf_dec = TransformerDecoder(config)
tf_dec(inputs.input_ids, tf_encoder(inputs.input_ids)).size()

torch.Size([1, 5, 768])

In [136]:
class TransformerEncoderDecoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.encoder = TransformerEncoder(config)
        self.decoder = TransformerDecoder(config)
    
    def forward(self, x):
        hidden_state = self.encoder(x)
        x = self.decoder(x, hidden_state)
        return x

In [141]:
tf_encdec = TransformerEncoderDecoder(config)
last_hidden_states = tf_encdec(inputs.input_ids)
last_hidden_states.size()

torch.Size([1, 5, 768])

In [142]:
from torchinfo import summary
summary(tf_encdec)

Layer (type:depth-idx)                                       Param #
TransformerEncoderDecoder                                    --
├─TransformerEncoder: 1-1                                    --
│    └─Embeddings: 2-1                                       --
│    │    └─Embedding: 3-1                                   23,440,896
│    │    └─Embedding: 3-2                                   393,216
│    │    └─LayerNorm: 3-3                                   1,536
│    │    └─Dropout: 3-4                                     --
│    └─ModuleList: 2-2                                       --
│    │    └─TransformerEncoderLayer: 3-5                     7,087,872
│    │    └─TransformerEncoderLayer: 3-6                     7,087,872
│    │    └─TransformerEncoderLayer: 3-7                     7,087,872
│    │    └─TransformerEncoderLayer: 3-8                     7,087,872
│    │    └─TransformerEncoderLayer: 3-9                     7,087,872
│    │    └─TransformerEncoderLayer: 3-10       

Now we've implemented the decoder with encoder-decoder attention. In the encoder-decoder (cross) attention layer, the decoder's hidden state provides the queries, while the encoder's hidden state is projected to the keys and values.
At the end we have a 768-dimensional hidden state for each token, which could be used for predictions by passing texts from two different languages to the encoder and the decoder.
Right now we feed the same input to both the encoder and the decoder, so the closest application of this setup would probably be an autocomplete model.

*Transformer tree of life*
![alt](../notes/images/3-transformer-taxonomy/transformer-tree-of-life.png)

Refer to the book for a short description of each of these models.

To do in the future:
* Try training a model with this architecture and get predictions
* Add a short description for each of these models