# Transformer
We will implement the Transformer architecture presented in class.

In [4]:
import torch
import torch.nn as nn

We start with the attention. Define a class `TransformerAttention` that will contain all the functions related to the Transformer's attention that we need. Add an `__init__` method that takes `hidden_dim` and `num_heads` as parameters.

In [5]:
class TransformerAttention(nn.Module):
    
    def __init__(self, hidden_dim, num_heads):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        #self.head_dim = hidden_dim // num_heads

We now add its methods one at a time, starting with the best part: the attention function. Implement scaled dot-product attention for given `query`, `key`, and `value` tensors. Each of these tensors has dimensions `[batch_size, sequence_length, head_dim]`. Scaled dot-product attention is defined as:
$$\text{DPA}(Q, K, V) = \text{softmax}\left(\frac{Q K^\top}{\sqrt{d}}\right) V$$
where $d$ is `head_dim`.

In [6]:
import math
import torch.nn.functional as F

def dot_product_attention(self, query, key, value):
    # query, key, value: [batch_size, seq_len, head_dim]
    # `key.T` does not work directly on a batched tensor, so we transpose only the last two dimensions.
    # Per batch element: [seq_len, head_dim] x [head_dim, seq_len] -> [seq_len, seq_len]
    head_dim = self.hidden_dim // self.num_heads
    attention_scores = query @ key.transpose(-1, -2) # [batch_size, seq_len, seq_len]
    attention_scores = attention_scores / math.sqrt(head_dim)

    attention_weights = F.softmax(attention_scores, dim=-1) # [batch_size, seq_len, seq_len]
    output = attention_weights @ value # [batch_size, seq_len, head_dim]

    return output



TransformerAttention.dot_product_attention = dot_product_attention
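
As a quick sanity check of the formula, here is a minimal standalone sketch (the function name `sdpa_sketch` and the toy shapes are made up for illustration and are not part of the exercise). When every key is identical, all scores within a row are equal, so the softmax should be uniform and the output at every position should be the mean of the value vectors.

```python
import math

import torch
import torch.nn.functional as F


def sdpa_sketch(query, key, value):
    """Scaled dot-product attention on [batch, seq_len, head_dim] tensors."""
    d = query.size(-1)
    scores = query @ key.transpose(-1, -2) / math.sqrt(d)  # [batch, q_len, k_len]
    weights = F.softmax(scores, dim=-1)                     # each row sums to 1
    return weights @ value, weights


torch.manual_seed(0)
batch, seq_len, head_dim = 2, 5, 4
q = torch.randn(batch, seq_len, head_dim)
k = torch.ones(batch, seq_len, head_dim)   # identical keys -> uniform attention
v = torch.randn(batch, seq_len, head_dim)

out, weights = sdpa_sketch(q, k, v)
uniform_ok = torch.allclose(weights, torch.full_like(weights, 1.0 / seq_len), atol=1e-5)
mean_ok = torch.allclose(out, v.mean(dim=1, keepdim=True).expand_as(out), atol=1e-5)
```

Both `uniform_ok` and `mean_ok` should be `True`; the same kind of check can be run against the `dot_product_attention` method.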

Implement a function `split_to_heads` that takes a tensor of dimensions `[?, ?, hidden_dim]` and splits it into `num_heads` tensors of size `[?, ?, head_dim]`, where $\text{head\_dim} = \frac{\text{hidden\_dim}}{\text{num\_heads}}$. The `?` dimensions are the same as before, but your implementation should be independent of them.

In [7]:
def split_to_heads(self, tensor):
    assert self.hidden_dim % self.num_heads == 0, "Hidden dim must be divisible by the number of heads"
    head_dim = self.hidden_dim // self.num_heads
    return tensor.split(head_dim, dim=-1) # [?, ?, hidden_dim] -> tuple of num_heads tensors, each [?, ?, head_dim]

TransformerAttention.split_to_heads = split_to_heads
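
To make the return type concrete: `Tensor.split` with an integer chunk size returns a tuple of tensors, one per head, rather than a single 4-D tensor. A small sketch with toy sizes (chosen here only for illustration):

```python
import torch

hidden_dim, num_heads = 20, 4
head_dim = hidden_dim // num_heads

x = torch.randn(2, 10, hidden_dim)      # [batch_size, seq_len, hidden_dim]
heads = x.split(head_dim, dim=-1)       # tuple of num_heads tensors

head_shapes = [tuple(h.shape) for h in heads]
# Concatenating along the last dimension restores the original tensor.
roundtrip_ok = torch.equal(torch.cat(heads, dim=-1), x)
```

Because the split acts only on the last dimension, the leading `?` dimensions can have any size.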

Now implement the `forward` method of `TransformerAttention` (and extend the `__init__` method if necessary). It should:
1. project its inputs into `query`, `key` and `value` tensors with 3 separate linear layers
2. split the tensors into chunks for each head to process
3. perform attention for each head separately
4. concatenate the results
5. run the output through another linear layer

Steps 1 and 2 look reversed from the diagram we saw in class, but this is more intuitive and also how Hugging Face implements these operations.

In [8]:
def __init__(self, hidden_dim, num_heads):
    super(TransformerAttention, self).__init__()
    self.hidden_dim = hidden_dim
    self.num_heads = num_heads

    self.query_projection = nn.Linear(hidden_dim, hidden_dim)
    self.key_projection = nn.Linear(hidden_dim, hidden_dim)
    self.value_projection = nn.Linear(hidden_dim, hidden_dim)

    self.output_projection = nn.Linear(hidden_dim, hidden_dim)

def forward(self, x):
    # 1. project x to q,k,v
    q = self.query_projection(x)
    k = self.key_projection(x)
    v = self.value_projection(x) # [batch_size, seq_len, hidden_dim]
    
    # 2. split to num_head tensors
    h_q = self.split_to_heads(q) # num_heads * [batch_size, seq_len, head_dim]
    h_k = self.split_to_heads(k)
    h_v = self.split_to_heads(v)

    # 3. apply dot product attention
    attention_outputs = []
    for head_q, head_k, head_v in zip(h_q, h_k, h_v):  # distinct names avoid shadowing the tuples
        attention_outputs.append(self.dot_product_attention(head_q, head_k, head_v))
    # list of length num_heads with tensors of shape [batch_size, seq_len, head_dim]
    
    # 4. concat heads
    attention_output_tensor = torch.cat(attention_outputs, dim=-1) # [batch_size, seq_len, hidden_dim]

    # 5. linear layer
    output = self.output_projection(attention_output_tensor) # [batch_size, seq_len, hidden_dim]

    return output


TransformerAttention.__init__ = __init__
TransformerAttention.forward = forward
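
The loop over heads is easy to read but processes the heads one after another. A common alternative is to reshape to `[batch_size, num_heads, seq_len, head_dim]` and let one batched matrix multiplication handle all heads. The sketch below is standalone (the names `attention_loop` and `attention_batched` are hypothetical) and only checks that both variants agree on random inputs; it is not meant as a drop-in replacement for the method above.

```python
import math

import torch
import torch.nn.functional as F


def attention_loop(q, k, v, num_heads):
    """Per-head loop, mirroring split -> attend -> concatenate."""
    head_dim = q.size(-1) // num_heads
    outputs = []
    for hq, hk, hv in zip(q.split(head_dim, -1), k.split(head_dim, -1), v.split(head_dim, -1)):
        scores = hq @ hk.transpose(-1, -2) / math.sqrt(head_dim)
        outputs.append(F.softmax(scores, dim=-1) @ hv)
    return torch.cat(outputs, dim=-1)


def attention_batched(q, k, v, num_heads):
    """Same computation with the heads folded into an extra batch dimension."""
    batch, q_len, hidden_dim = q.shape
    head_dim = hidden_dim // num_heads

    def to_heads(t):
        # [batch, len, hidden_dim] -> [batch, num_heads, len, head_dim]
        return t.view(batch, t.size(1), num_heads, head_dim).transpose(1, 2)

    hq, hk, hv = to_heads(q), to_heads(k), to_heads(v)
    scores = hq @ hk.transpose(-1, -2) / math.sqrt(head_dim)
    out = F.softmax(scores, dim=-1) @ hv          # [batch, num_heads, q_len, head_dim]
    return out.transpose(1, 2).reshape(batch, q_len, hidden_dim)


torch.manual_seed(0)
q, k, v = (torch.randn(2, 10, 20) for _ in range(3))
same = torch.allclose(attention_loop(q, k, v, 4), attention_batched(q, k, v, 4), atol=1e-5)
```

Both functions produce the same values up to floating-point tolerance; the batched form simply avoids the Python-level loop.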

Create a class `TransformerAttentionBlock` that runs Transformer attention, then adds the input as a residual to the output and performs layer normalization.

In [9]:
class TransformerAttentionBlock(nn.Module):
    
    def __init__(self, hidden_dim, num_heads):
        super().__init__()
        self.attention = TransformerAttention(hidden_dim, num_heads)
        self.layer_norm = nn.LayerNorm(hidden_dim)
    
    def forward(self, x):
        attention_output = self.attention(x)
        output = self.layer_norm(x + attention_output)
        return output
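
To illustrate what "add the residual, then normalize" does to the activations, here is a small sketch in which a random tensor stands in for the attention output (an assumption made only to keep the example standalone). With a freshly initialized `nn.LayerNorm`, every position should end up with mean close to 0 and variance close to 1 across the hidden dimension.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_dim = 20
layer_norm = nn.LayerNorm(hidden_dim)

x = torch.randn(2, 10, hidden_dim)
sublayer_output = torch.randn(2, 10, hidden_dim)   # stand-in for the attention output

y = layer_norm(x + sublayer_output)                # residual connection, then normalization

# At initialization (weight = 1, bias = 0) each position is standardized over hidden_dim.
max_abs_mean = y.mean(dim=-1).abs().max().item()
max_var_error = (y.var(dim=-1, unbiased=False) - 1).abs().max().item()
```

The statistics are computed per position over the last dimension, so the result does not depend on the batch size or sequence length.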

Create a class `FeedForwardNetwork` that consists of two linear layers with a ReLU in between. Also add a residual connection from the input to the output and apply layer normalization.

In [10]:
class FeedForwardNetwork(nn.Module):
    
    def __init__(self, hidden_dim, inner_dim):
        super().__init__()
        self.fc1 = nn.Linear(hidden_dim, inner_dim)
        self.fc2 = nn.Linear(inner_dim, hidden_dim)
        self.layer_norm = nn.LayerNorm(hidden_dim)
    
    def forward(self, x):
        output = self.fc1(x)
        output = F.relu(output)
        output = self.fc2(output)
        output = self.layer_norm(x + output)
        return output

Now we can combine the `TransformerAttentionBlock` and the `FeedForwardNetwork` into a `TransformerLayer`. 

In [11]:
class TransformerLayer(nn.Module):
    
    def __init__(self, hidden_dim, ffn_inner_dim, num_heads):
        super().__init__()
        self.self_attention = TransformerAttentionBlock(hidden_dim, num_heads)
        self.ffn = FeedForwardNetwork(hidden_dim, ffn_inner_dim)
    
    def forward(self, x):
        output = self.self_attention(x)
        output = self.ffn(output)
        return output

We are ready to compose our `TransformerEncoder` of a given number of `TransformerLayer`s.

In [12]:
class TransformerEncoder(nn.Module):
    
    def __init__(self, hidden_dim, ffn_inner_dim, num_layers, num_heads):
        super().__init__()
        self.layers = nn.ModuleList([TransformerLayer(hidden_dim, ffn_inner_dim, num_heads) for _ in range(num_layers)])
    
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

Let's test our implementation with the hyperparameters...

In [13]:
hidden_dim = 20
embedding_dim = hidden_dim
ffn_dim = 100
num_heads = 4
num_encoder_layers = 6
batch_size = 2
x_len = 10

... and check if it produces the correct output shapes.

In [14]:
x = torch.randn(batch_size, x_len, embedding_dim)
encoder = TransformerEncoder(hidden_dim, ffn_dim, num_encoder_layers, num_heads)
output = encoder(x)
assert list(output.shape) == [batch_size, x_len, hidden_dim], "Wrong output shape"

## Transformer Decoder
For the Transformer decoder, two components are missing.
1. A causal mask in the `TransformerAttention`.
2. A cross-attention module in the `TransformerLayer`.

We start by generalizing the `TransformerAttention` class to use a causal mask in `dot_product_attention` if it is used for decoder self-attention. We check this by accessing an `is_decoder_self_attention` attribute of `self`, which we have to add as an argument to `TransformerAttention`'s `__init__` method first.

In [15]:
# Add an `is_decoder_self_attention` attribute to TransformerAttention.__init__
def __init__(self, hidden_dim, num_heads, is_decoder_self_attention=False):
    super(TransformerAttention, self).__init__()  # zero-argument super() raises a RuntimeError here because this function is defined outside the class body
    self.hidden_dim = hidden_dim
    self.num_heads = num_heads
    self.is_decoder_self_attention = is_decoder_self_attention

    self.query_projection = nn.Linear(hidden_dim, hidden_dim)
    self.key_projection = nn.Linear(hidden_dim, hidden_dim)
    self.value_projection = nn.Linear(hidden_dim, hidden_dim)

    self.output_projection = nn.Linear(hidden_dim, hidden_dim)

TransformerAttention.__init__ = __init__

In [25]:
# Change dot_product_attention to use a causal mask when it is used for decoder self-attention.
def dot_product_attention(self, query, key, value):
    head_dim = self.hidden_dim // self.num_heads
    attention_scores = query @ key.transpose(-1, -2) # [batch_size, seq_len, seq_len]
    attention_scores = attention_scores / math.sqrt(head_dim)

    if self.is_decoder_self_attention:
        seq_len = attention_scores.size(1)
        # True above the diagonal marks future positions, which must not be attended to.
        mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=attention_scores.device), diagonal=1)
        attention_scores = attention_scores.masked_fill(mask, float('-inf'))

    attention_weights = F.softmax(attention_scores, dim=-1) # [batch_size, seq_len, seq_len]
    output = attention_weights @ value # [batch_size, seq_len, head_dim]

    return output

TransformerAttention.dot_product_attention = dot_product_attention
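
To see what the causal mask does, the following standalone sketch (toy sizes, hypothetical helper name `causal_attention_weights`) builds an upper-triangular mask of the same kind and checks two properties: no position puts any weight on a later position, and the first position attends only to itself.

```python
import math

import torch
import torch.nn.functional as F


def causal_attention_weights(query, key):
    """Attention weights with future positions masked out."""
    seq_len = query.size(1)
    scores = query @ key.transpose(-1, -2) / math.sqrt(query.size(-1))
    # True above the diagonal marks the "future" entries to hide.
    mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(mask, float("-inf"))
    return F.softmax(scores, dim=-1)


torch.manual_seed(0)
x = torch.randn(2, 4, 5)                   # [batch_size, seq_len, head_dim]
weights = causal_attention_weights(x, x)

future_weight = weights.triu(diagonal=1).abs().sum().item()
first_row = weights[0, 0].tolist()
rows_sum_to_one = torch.allclose(weights.sum(dim=-1), torch.ones(2, 4))
```

Scores of `-inf` become exactly zero after the softmax, so masked positions contribute nothing to the output while each row still sums to 1.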

Now we add cross-attention. We do this by updating the `TransformerAttention`'s `forward` method to take `encoder_hidden_states` as an optional input. Check the lecture slides to see which input gets projected into queries, keys and values.

In [17]:
def forward(self, x, encoder_hidden_states=None):
    # 1. project x to q; project k and v from x (self-attention) or from the encoder hidden states (cross-attention)
    q = self.query_projection(x)
    if encoder_hidden_states is None:
        k = self.key_projection(x)
        v = self.value_projection(x)
    else:
        k = self.key_projection(encoder_hidden_states)
        v = self.value_projection(encoder_hidden_states)
    
    # 2. split to num_head tensors
    h_q = self.split_to_heads(q) # num_heads * [batch_size, seq_len, head_dim]
    h_k = self.split_to_heads(k)
    h_v = self.split_to_heads(v)

    # 3. apply dot product attention
    attention_outputs = []
    for head_q, head_k, head_v in zip(h_q, h_k, h_v):  # distinct names avoid shadowing the tuples
        attention_outputs.append(self.dot_product_attention(head_q, head_k, head_v))
    # list of length num_heads with tensors of shape [batch_size, seq_len, head_dim]
    
    # 4. concat heads
    attention_output_tensor = torch.cat(attention_outputs, dim=-1) # [batch_size, seq_len, hidden_dim]

    # 5. linear layer
    output = self.output_projection(attention_output_tensor) # [batch_size, seq_len, hidden_dim]

    return output
    
TransformerAttention.forward = forward
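
In cross-attention the queries come from the decoder while the keys and values come from the encoder, so the two sequence lengths may differ. A shape-only sketch for a single head, with random tensors standing in for the projected inputs (sizes chosen for illustration):

```python
import math

import torch
import torch.nn.functional as F

batch_size, x_len, y_len, head_dim = 2, 10, 7, 5

q = torch.randn(batch_size, y_len, head_dim)   # from the decoder input
k = torch.randn(batch_size, x_len, head_dim)   # from the encoder hidden states
v = torch.randn(batch_size, x_len, head_dim)   # from the encoder hidden states

scores = q @ k.transpose(-1, -2) / math.sqrt(head_dim)   # [batch_size, y_len, x_len]
output = F.softmax(scores, dim=-1) @ v                   # [batch_size, y_len, head_dim]

scores_shape, output_shape = tuple(scores.shape), tuple(output.shape)
```

The output keeps the decoder's sequence length. Because the score matrix is not square here, a square causal mask only makes sense for decoder self-attention, not for cross-attention.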

We have to extend the `TransformerAttentionBlock` to allow that additional argument in its `forward` method.

In [18]:
def forward(self, x, encoder_hidden_states=None):
    output = self.attention(x, encoder_hidden_states)
    output = self.layer_norm(x + output)
    return output

TransformerAttentionBlock.forward = forward

Now we implement a `TransformerDecoderLayer` that consists of decoder self-attention, cross-attention and a feed-forward network. In the `forward` method, use the encoder hidden states as inputs to the cross-attention module.

In [20]:
class TransformerDecoderLayer(nn.Module):
    
    def __init__(self, hidden_dim, ffn_inner_dim, num_heads):
        super().__init__()
        self.self_attention = TransformerAttentionBlock(hidden_dim, num_heads)
        self.self_attention.attention.is_decoder_self_attention = True
        self.cross_attention = TransformerAttentionBlock(hidden_dim, num_heads)
        self.ffn = FeedForwardNetwork(hidden_dim, ffn_inner_dim)
    
    def forward(self, x, encoder_hidden_states):
        output = self.self_attention(x)
        output = self.cross_attention(output, encoder_hidden_states)
        output = self.ffn(output)
        return output

Add a `TransformerDecoder` that holds the decoder layers.

In [21]:
class TransformerDecoder(nn.Module):
    
    def __init__(self, hidden_dim, ffn_inner_dim, num_layers, num_heads):
        super().__init__()
        self.layers = nn.ModuleList([TransformerDecoderLayer(hidden_dim, ffn_inner_dim, num_heads) for _ in range(num_layers)])
    
    def forward(self, x, encoder_hidden_states):
        for layer in self.layers:
            x = layer(x, encoder_hidden_states)
        return x

## Transformer Seq2seq Model
We can now put everything together. Create and instantiate a Transformer model that encodes a random input `x`, then generates an output hidden representation for each decoder input `y` that we could then feed into a classifier to predict the words.

In [22]:
class TransformerModel(nn.Module):
    
    def __init__(self, hidden_dim, ffn_dim, num_encoder_layers, num_decoder_layers, num_heads):
        super().__init__()
        self.encoder = TransformerEncoder(hidden_dim, ffn_dim, num_encoder_layers, num_heads)
        self.decoder = TransformerDecoder(hidden_dim, ffn_dim, num_decoder_layers, num_heads)
    
    def forward(self, x, y):
        encoder_hidden_states = self.encoder(x)
        output = self.decoder(y, encoder_hidden_states)
        return output

We will use the following hyperparameters.

In [23]:
hidden_dim = 20
embedding_dim = hidden_dim
ffn_dim = 100
num_heads = 4
num_encoder_layers = 6
num_decoder_layers = 2
batch_size = 2
x_len = 10
y_len = 7

Now we can run our model and test that the output dimensions are correct.

In [26]:
x = torch.randn(batch_size, x_len, embedding_dim)
y = torch.randn(batch_size, y_len, embedding_dim)
model = TransformerModel(hidden_dim, ffn_dim, num_encoder_layers, num_decoder_layers, num_heads)
output = model(x, y)
assert list(output.shape) == [batch_size, y_len, hidden_dim], "Wrong output shape"
num_model_params = sum(param.numel() for param in model.parameters())
assert num_model_params == 50480, f"Wrong number of parameters: {num_model_params}"
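
The expected value of 50480 can be derived by hand. The sketch below repeats the arithmetic in plain Python, assuming every `nn.Linear` has a bias and every `nn.LayerNorm` has a weight and a bias (the PyTorch defaults); the helper names are for illustration only.

```python
def linear_params(in_dim, out_dim):
    return in_dim * out_dim + out_dim          # weight matrix + bias


def layer_norm_params(dim):
    return 2 * dim                             # scale + shift


hidden_dim, ffn_dim = 20, 100
num_encoder_layers, num_decoder_layers = 6, 2

# Query, key, value and output projections, plus the block's layer norm.
attention_block = 4 * linear_params(hidden_dim, hidden_dim) + layer_norm_params(hidden_dim)
ffn = (linear_params(hidden_dim, ffn_dim) + linear_params(ffn_dim, hidden_dim)
       + layer_norm_params(hidden_dim))

encoder_layer = attention_block + ffn          # self-attention + FFN
decoder_layer = 2 * attention_block + ffn      # self-attention + cross-attention + FFN

total = num_encoder_layers * encoder_layer + num_decoder_layers * decoder_layer
print(total)  # 50480
```

Note that `num_heads` does not appear in the calculation: splitting into heads only changes how the projected tensors are partitioned, not how many weights there are.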

## What is missing for a real implementation?
Look at the [implementation of the Transformer layer for BERT by Hugging Face](https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/models/bert/modeling_bert.py#L223), lines 223 to 641.

**Question:** Name the things that Hugging Face's implementation does that are still missing from your own implementation.

**Answer:** 

possible corrections: 
- Dropout
- Positional encoding
