# Transformer
We will implement the Transformer architecture presented in class.

In [41]:
import torch
import torch.nn as nn
import torch.nn.functional as F

We start with the attention. Define a class `TransformerAttention` that will contain all the functions related to the Transformer's attention that we need. Add an `__init__` method that takes `hidden_dim` and `num_heads` as parameters.

In [42]:
class TransformerAttention(nn.Module):
    
    def __init__(self, hidden_dim, num_heads):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

Now we add its methods one at a time, starting with the best part: the attention function. Implement scaled dot-product attention given `query`, `key`, and `value` tensors as inputs. These tensors have dimensions `[batch_size, sequence_length, head_dim]`. Scaled dot-product attention is defined as:
$$\text{DPA}(Q, K, V) = \text{softmax}\left(\frac{Q K^\top}{\sqrt{d}}\right) V$$
where $d$ is the head dimension `head_dim`.

In [43]:
import math

def dot_product_attention(self, query, key, value):
    # query: [batch_size, seq_length, head_dim]
    # key: [batch_size, seq_length, head_dim]
    # value: [batch_size, seq_length, head_dim]

    attention_scores = query @ key.transpose(-1, -2) # [batch_size, seq_length, seq_length]
    head_dim = query.size(-1)
    scaled_attention_scores = (attention_scores) / math.sqrt(head_dim)
    attention_weights = F.softmax(scaled_attention_scores, dim=-1)

    return attention_weights @ value # [batch_size, seq_length, head_dim]

TransformerAttention.dot_product_attention = dot_product_attention
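As a sanity check, the formula can be compared against PyTorch's built-in `F.scaled_dot_product_attention`, which (assuming PyTorch 2.0 or later is installed) computes the same quantity with a default scale of $1/\sqrt{d}$. This sketch is standalone and does not use the class above.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_length, head_dim = 2, 5, 8
query = torch.randn(batch_size, seq_length, head_dim)
key = torch.randn(batch_size, seq_length, head_dim)
value = torch.randn(batch_size, seq_length, head_dim)

# softmax(Q K^T / sqrt(d)) V, written out step by step
scores = query @ key.transpose(-1, -2) / math.sqrt(head_dim)  # [batch_size, seq_length, seq_length]
weights = F.softmax(scores, dim=-1)                            # each row sums to 1
manual = weights @ value                                       # [batch_size, seq_length, head_dim]

# PyTorch's implementation (PyTorch >= 2.0) uses the same 1/sqrt(head_dim) scaling by default
builtin = F.scaled_dot_product_attention(query, key, value)

print(torch.allclose(manual, builtin, atol=1e-5))  # True
```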

Implement a function `split_to_heads` that takes a tensor of dimensions `[?, ?, hidden_dim]` and splits it into `num_heads` tensors of size `[?, ?, head_dim]`, where $\text{head\_dim} = \frac{\text{hidden\_dim}}{\text{num\_heads}}$. The `?` dimensions are the same as before, but your implementation should be independent of them.

In [44]:
def split_to_heads(self, tensor):
    assert self.hidden_dim % self.num_heads == 0, "Hidden dim must be divisible by the number of heads"
    head_dim = self.hidden_dim // self.num_heads
    return tensor.split(head_dim, dim=-1)

TransformerAttention.split_to_heads = split_to_heads

model = TransformerAttention(15, 3)
my_tensor = torch.zeros(10, 5, 15)
model.split_to_heads(my_tensor)

Output (truncated): a tuple of three all-zero tensors, each of shape `[10, 5, 5]`.
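`tensor.split` returns a tuple of `num_heads` separate views, which the `forward` method then processes in a Python loop. A common alternative, similar in spirit to the reshape-and-permute step in Hugging Face's BERT attention, is to reshape the last dimension into `[num_heads, head_dim]` and move the head axis in front of the sequence axis, so that all heads can be handled by one batched matrix multiplication. The sketch checks that both layouts hold the same numbers; the helper name `split_to_heads_batched` is hypothetical.

```python
import torch

def split_to_heads_batched(tensor, num_heads):
    # Hypothetical helper: [..., seq, hidden_dim] -> [..., num_heads, seq, head_dim]
    *leading, hidden_dim = tensor.shape
    assert hidden_dim % num_heads == 0, "Hidden dim must be divisible by the number of heads"
    head_dim = hidden_dim // num_heads
    # [..., seq, hidden_dim] -> [..., seq, num_heads, head_dim] -> [..., num_heads, seq, head_dim]
    return tensor.view(*leading, num_heads, head_dim).transpose(-3, -2)

hidden_dim, num_heads = 15, 3
x = torch.arange(10 * 5 * hidden_dim, dtype=torch.float32).view(10, 5, hidden_dim)

as_tuple = x.split(hidden_dim // num_heads, dim=-1)  # num_heads * [10, 5, 5]
as_batch = split_to_heads_batched(x, num_heads)      # [10, 3, 5, 5]

print(len(as_tuple), as_tuple[0].shape)  # 3 torch.Size([10, 5, 5])
print(as_batch.shape)                    # torch.Size([10, 3, 5, 5])
# Head i of the batched layout equals the i-th chunk returned by split
print(all(torch.equal(as_batch[:, i], as_tuple[i]) for i in range(num_heads)))  # True
```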

Now implement the `forward` method of `TransformerAttention` (and extend the `__init__` method if necessary). It should:
1. project its inputs into `query`, `key` and `value` tensors with 3 separate linear layers
2. split the tensors into chunks for each head to process
3. perform attention for each head separately
4. concatenate the results
5. run the output through another linear layer

Steps 1 and 2 appear in the reverse order compared with the diagram from class, but this order is more intuitive, gives the same result, and is also how Hugging Face implements these operations.
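Why does projecting first and splitting second match the per-head projections in the diagram? `nn.Linear(hidden_dim, hidden_dim)` stores its weight as `[out_features, in_features]`, so the output features belonging to head `i` depend only on weight rows `i*head_dim:(i+1)*head_dim` and the matching bias entries. Slicing those rows out therefore gives exactly one smaller `hidden_dim -> head_dim` projection per head. A minimal standalone sketch:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
hidden_dim, num_heads = 12, 3
head_dim = hidden_dim // num_heads
x = torch.randn(2, 4, hidden_dim)  # [batch_size, seq_length, hidden_dim]

# Order used in this notebook: one big projection, then split into heads
query_projection = nn.Linear(hidden_dim, hidden_dim)
project_then_split = query_projection(x).split(head_dim, dim=-1)

# Order from the diagram: one smaller projection per head, built from slices of the same weights
per_head = []
for i in range(num_heads):
    rows = slice(i * head_dim, (i + 1) * head_dim)
    # Output features of nn.Linear correspond to rows of its [out_features, in_features] weight
    per_head.append(F.linear(x, query_projection.weight[rows], query_projection.bias[rows]))

print(all(torch.allclose(a, b, atol=1e-5) for a, b in zip(project_then_split, per_head)))  # True
```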

In [45]:
def __init__(self, hidden_dim, num_heads):
    super(TransformerAttention, self).__init__()
    self.hidden_dim = hidden_dim
    self.num_heads = num_heads

    self.query_projection = nn.Linear(hidden_dim, hidden_dim)
    self.key_projection = nn.Linear(hidden_dim, hidden_dim)
    self.value_projection = nn.Linear(hidden_dim, hidden_dim)

    self.output_projection = nn.Linear(hidden_dim, hidden_dim)

def forward(self, x):
    # x: [batch_size, seq_length, hidden_dim]

    # project inputs into query, key and value tensors
    q = self.query_projection(x) # [batch_size, seq_length, hidden_dim]
    k = self.key_projection(x) # [batch_size, seq_length, hidden_dim]
    v = self.value_projection(x) # [batch_size, seq_length, hidden_dim]

    # split tensors into chunks for each head to process
    q_split = self.split_to_heads(q) # num_heads * [batch_size, seq_length, head_dim]
    k_split = self.split_to_heads(k) # num_heads * [batch_size, seq_length, head_dim]
    v_split = self.split_to_heads(v) # num_heads * [batch_size, seq_length, head_dim]

    # perform attention for each head separately
    attentions = [self.dot_product_attention(q_head, k_head, v_head) for q_head, k_head, v_head in zip(q_split, k_split, v_split)] # num_heads * [batch_size, seq_length, head_dim]

    # concatenate the results
    concatenated_attentions = torch.cat(attentions, dim=-1) # [batch_size, seq_length, hidden_dim]

    # run the output through another linear layer
    output = self.output_projection(concatenated_attentions) # [batch_size, seq_length, hidden_dim]

    return output

TransformerAttention.__init__ = __init__
TransformerAttention.forward = forward

Create a class `TransformerAttentionBlock` that runs Transformer attention, then adds the input as a residual to the output and performs layer normalization.

In [46]:
class TransformerAttentionBlock(nn.Module):
    
    def __init__(self, hidden_dim, num_heads):
        super().__init__()
        self.attention = TransformerAttention(hidden_dim, num_heads)
        self.layer_norm = nn.LayerNorm(hidden_dim)
    
    def forward(self, x):
        attention_output = self.attention(x)
        return self.layer_norm(x + attention_output)
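For reference, `nn.LayerNorm(hidden_dim)` normalizes each position's `hidden_dim`-vector to zero mean and unit variance (using the biased variance and a small `eps`, 1e-5 by default), then applies a learnable elementwise scale and shift that start at 1 and 0. The sketch recomputes the block's `layer_norm(x + attention_output)` by hand; a random tensor stands in for the attention output, purely for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_dim = 20
x = torch.randn(2, 10, hidden_dim)                 # block input
attention_output = torch.randn(2, 10, hidden_dim)  # stand-in for self.attention(x)

layer_norm = nn.LayerNorm(hidden_dim)  # eps=1e-5, weight=1, bias=0 at initialization
residual = x + attention_output        # residual connection

mean = residual.mean(dim=-1, keepdim=True)
var = ((residual - mean) ** 2).mean(dim=-1, keepdim=True)  # biased variance over hidden_dim
manual = (residual - mean) / torch.sqrt(var + layer_norm.eps) * layer_norm.weight + layer_norm.bias

print(torch.allclose(manual, layer_norm(residual), atol=1e-5))  # True
```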

Create a class `FeedForwardNetwork` that consists of two linear layers with a ReLU in between. Also add a residual connection from the input to the output and apply layer normalization.

In [47]:
class FeedForwardNetwork(nn.Module):
    
    def __init__(self, hidden_dim, inner_dim):
        super().__init__()
        self.linear1 = nn.Linear(hidden_dim, inner_dim)
        self.linear2 = nn.Linear(inner_dim, hidden_dim)
        self.layer_norm = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        output = self.linear1(x)
        output = F.relu(output)
        output = self.linear2(output)
        return self.layer_norm(x + output)
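Because `nn.Linear` acts only on the last dimension, this feed-forward network treats every sequence position independently (it is often called "position-wise" for this reason). The standalone sketch below runs the two layers once on the whole sequence and once per position and compares the results; residual and layer norm are left out because they also operate per position.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
hidden_dim, inner_dim, seq_length = 20, 100, 10
linear1 = nn.Linear(hidden_dim, inner_dim)
linear2 = nn.Linear(inner_dim, hidden_dim)

def ffn(t):
    # Same computation as FeedForwardNetwork before the residual and layer norm
    return linear2(F.relu(linear1(t)))

x = torch.randn(2, seq_length, hidden_dim)
whole_sequence = ffn(x)  # [2, seq_length, hidden_dim]
one_at_a_time = torch.stack([ffn(x[:, i]) for i in range(seq_length)], dim=1)  # same shape

print(torch.allclose(whole_sequence, one_at_a_time, atol=1e-5))  # True
```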

Now we can combine the `TransformerAttentionBlock` and the `FeedForwardNetwork` into a `TransformerLayer`. 

In [48]:
class TransformerLayer(nn.Module):
    
    def __init__(self, hidden_dim, ffn_inner_dim, num_heads):
        super().__init__()
        self.self_attention = TransformerAttentionBlock(hidden_dim, num_heads)
        self.ffn = FeedForwardNetwork(hidden_dim, ffn_inner_dim)
    
    def forward(self, x):
        output = self.self_attention(x)
        return self.ffn(output)

We are now ready to compose our `TransformerEncoder` from a given number of `TransformerLayer`s.

In [49]:
class TransformerEncoder(nn.Module):
    
    def __init__(self, hidden_dim, ffn_inner_dim, num_layers, num_heads):
        super().__init__()
        self.layers = nn.ModuleList([TransformerLayer(hidden_dim, ffn_inner_dim, num_heads) for _ in range(num_layers)])
    
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
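As a rough cross-check, PyTorch ships `nn.TransformerEncoderLayer`, which by default also uses post-layer-normalization (`norm_first=False`) and a ReLU feed-forward network. Its attention packs the query, key and value projections into one `[3 * d_model, d_model]` weight, but with `dropout=0.0` the layer should have exactly as many parameters as one of our `TransformerLayer`s. The sketch compares it against a hand-derived count rather than the class above, so it runs on its own.

```python
import torch.nn as nn

hidden_dim, ffn_inner_dim, num_heads = 20, 100, 4

def our_layer_param_count(h, f):
    attention = 4 * (h * h + h)      # query, key, value and output projections (weights + biases)
    layer_norms = 2 * (2 * h)        # attention-block and FFN layer norms (scale + shift)
    ffn = (h * f + f) + (f * h + h)  # linear1 and linear2
    return attention + layer_norms + ffn

reference = nn.TransformerEncoderLayer(
    d_model=hidden_dim, nhead=num_heads, dim_feedforward=ffn_inner_dim,
    dropout=0.0, batch_first=True,  # batch_first matches our [batch, seq, hidden] layout
)
reference_count = sum(p.numel() for p in reference.parameters())

print(our_layer_param_count(hidden_dim, ffn_inner_dim), reference_count)  # 5880 5880
```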

Let's test our implementation with the hyperparameters...

In [50]:
hidden_dim = 20
embedding_dim = hidden_dim
ffn_dim = 100
num_heads = 4
num_encoder_layers = 6
batch_size = 2
x_len = 10

... and check if it produces the correct output shapes.

In [51]:
x = torch.randn(batch_size, x_len, embedding_dim)
encoder = TransformerEncoder(hidden_dim, ffn_dim, num_encoder_layers, num_heads)
output = encoder(x)
assert list(output.shape) == [batch_size, x_len, hidden_dim], "Wrong output shape"

## Transformer Decoder
For the Transformer decoder, two components are still missing:
1. A causal mask in the `TransformerAttention`.
2. A cross-attention module in the `TransformerLayer`.

We start by generalizing the `TransformerAttention` class to use a causal mask in `dot_product_attention` if it is used for decoder self-attention. We check this by accessing an `is_decoder_self_attention` attribute of `self`, which we have to add as an argument to `TransformerAttention`'s `__init__` method first.

In [52]:
# Add an `is_decoder_self_attention` attribute to TransformerAttention.__init__
def __init__(self, hidden_dim, num_heads, is_decoder_self_attention=False):
    super(TransformerAttention, self).__init__()  # zero-argument super() fails here because this function is defined outside the class body

    self.hidden_dim = hidden_dim
    self.num_heads = num_heads
    self.is_decoder_self_attention = is_decoder_self_attention

    self.query_projection = nn.Linear(hidden_dim, hidden_dim)
    self.key_projection = nn.Linear(hidden_dim, hidden_dim)
    self.value_projection = nn.Linear(hidden_dim, hidden_dim)

    self.output_projection = nn.Linear(hidden_dim, hidden_dim)


TransformerAttention.__init__ = __init__

In [53]:
# Change the dot_product attention to use a causal mask in case it is used in the decoder self-attention.
def dot_product_attention(self, query, key, value):
    # query: [batch_size, seq_length, head_dim]
    # key: [batch_size, seq_length, head_dim]
    # value: [batch_size, seq_length, head_dim]

    attention_scores = query @ key.transpose(-1, -2) # [batch_size, seq_length, seq_length]
    head_dim = query.size(-1)
    scaled_attention_scores = (attention_scores) / math.sqrt(head_dim)

    # add causal mask
    if self.is_decoder_self_attention:
        causal_mask = torch.ones_like(attention_scores).bool().tril()
        scaled_attention_scores = scaled_attention_scores.masked_fill(~causal_mask, -1e8) # a masked score of 0 would still get weight exp(0) after softmax; a large negative score drives its weight to ~0

    attention_weights = F.softmax(scaled_attention_scores, dim=-1)

    return attention_weights @ value # [batch_size, seq_length, head_dim]

TransformerAttention.dot_product_attention = dot_product_attention
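The effect of the mask is easiest to see on a tiny example: after the softmax, position `i` gives zero weight to every later position `j > i`, and each row still sums to 1. This standalone sketch uses `float("-inf")` instead of `-1e8`; both drive the masked weights to zero, and `-inf` is safe here because the diagonal is never masked, so no row consists only of `-inf`.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_length, head_dim = 4, 8
query = torch.randn(1, seq_length, head_dim)
key = torch.randn(1, seq_length, head_dim)

scores = query @ key.transpose(-1, -2) / math.sqrt(head_dim)  # [1, seq_length, seq_length]
causal_mask = torch.ones_like(scores).tril().bool()            # True on and below the diagonal
scores = scores.masked_fill(~causal_mask, float("-inf"))
weights = F.softmax(scores, dim=-1)

print(causal_mask[0])
print(weights[0].triu(diagonal=1).sum().item())  # 0.0: no weight on future positions
```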

Now we add cross-attention. We do this by updating the `TransformerAttention`'s `forward` method to take `encoder_hidden_states` as an optional input. Check the lecture slides to see which input gets projected into queries, keys and values.

In [54]:
def forward(self, x, encoder_hidden_states=None):
    # x: [batch_size, seq_length, hidden_dim]

    # project inputs into query, key and value tensors
    q = self.query_projection(x) # [batch_size, seq_length, hidden_dim]

    if encoder_hidden_states is None:
        # self-attention (encoder, or decoder self-attention): keys and values come from x
        k = self.key_projection(x) # [batch_size, seq_length, hidden_dim]
        v = self.value_projection(x) # [batch_size, seq_length, hidden_dim]
    else:
        # decoder cross-attention: keys and values come from the encoder output
        k = self.key_projection(encoder_hidden_states) # [batch_size, encoder_seq_length, hidden_dim]
        v = self.value_projection(encoder_hidden_states) # [batch_size, encoder_seq_length, hidden_dim]

    # split tensors into chunks for each head to process
    q_split = self.split_to_heads(q) # num_heads * [batch_size, seq_length, head_dim]
    k_split = self.split_to_heads(k) # num_heads * [batch_size, seq_length, head_dim]
    v_split = self.split_to_heads(v) # num_heads * [batch_size, seq_length, head_dim]

    # perform attention for each head separately
    attentions = [self.dot_product_attention(q_head, k_head, v_head) for q_head, k_head, v_head in zip(q_split, k_split, v_split)] # num_heads * [batch_size, seq_length, head_dim]

    # concatenate the results
    concatenated_attentions = torch.cat(attentions, dim=-1) # [batch_size, seq_length, hidden_dim]

    # run the output through another linear layer
    output = self.output_projection(concatenated_attentions) # [batch_size, seq_length, hidden_dim]

    return output
    
TransformerAttention.forward = forward
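In cross-attention the queries come from the decoder while the keys and values come from the encoder, so the two sequence lengths may differ: the attention matrix has shape `[batch_size, y_len, x_len]`, and the output keeps the decoder length `y_len`. No causal mask is applied here, since the decoder may attend to every encoder position. A standalone single-head sketch of the shapes (the names `decoder_states` and `encoder_states` are illustrative):

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, x_len, y_len, hidden_dim = 2, 10, 7, 20
encoder_states = torch.randn(batch_size, x_len, hidden_dim)  # encoder output
decoder_states = torch.randn(batch_size, y_len, hidden_dim)  # decoder self-attention output

query_projection = nn.Linear(hidden_dim, hidden_dim)
key_projection = nn.Linear(hidden_dim, hidden_dim)
value_projection = nn.Linear(hidden_dim, hidden_dim)

q = query_projection(decoder_states)  # [batch_size, y_len, hidden_dim]
k = key_projection(encoder_states)    # [batch_size, x_len, hidden_dim]
v = value_projection(encoder_states)  # [batch_size, x_len, hidden_dim]

weights = F.softmax(q @ k.transpose(-1, -2) / math.sqrt(hidden_dim), dim=-1)  # [batch_size, y_len, x_len]
output = weights @ v                                                          # [batch_size, y_len, hidden_dim]

print(weights.shape, output.shape)  # torch.Size([2, 7, 10]) torch.Size([2, 7, 20])
```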

We have to extend the `TransformerAttentionBlock` to allow that additional argument in its `forward` method.

In [55]:
def forward(self, x, encoder_hidden_states=None):
    output = self.attention(x, encoder_hidden_states)
    return self.layer_norm(x + output)

TransformerAttentionBlock.forward = forward

Now we implement a `TransformerDecoderLayer` that consists of decoder self-attention, cross-attention and a feed-forward network. In the `forward` method, use the encoder hidden states as inputs to the cross-attention module.

In [56]:
class TransformerDecoderLayer(nn.Module):
    
    def __init__(self, hidden_dim, ffn_inner_dim, num_heads):
        super(TransformerDecoderLayer, self).__init__()
        self.self_attention = TransformerAttentionBlock(hidden_dim, num_heads)
        self.cross_attention = TransformerAttentionBlock(hidden_dim, num_heads)
        self.ffn = FeedForwardNetwork(hidden_dim, ffn_inner_dim)

        self.self_attention.attention.is_decoder_self_attention = True
    
    def forward(self, x, encoder_hidden_states):
        output = self.self_attention(x)
        output = self.cross_attention(output, encoder_hidden_states)
        output = self.ffn(output)
        return output

Add a `TransformerDecoder` that holds the decoder layers.

In [57]:
class TransformerDecoder(nn.Module):
    
    def __init__(self, hidden_dim, ffn_inner_dim, num_layers, num_heads):
        super(TransformerDecoder, self).__init__()
        self.layers = nn.ModuleList([TransformerDecoderLayer(hidden_dim, ffn_inner_dim, num_heads) for _ in range(num_layers)])
    
    def forward(self, x, encoder_hidden_states):
        for layer in self.layers:
            x = layer(x, encoder_hidden_states)
        return x

## Transformer Seq2seq Model
We can now put everything together. Create and instantiate a Transformer model that encodes a random input `x` and then produces a hidden representation for each position of the decoder input `y`; these representations could then be fed into a classifier that predicts the output words.

In [58]:
class TransformerModel(nn.Module):
    
    def __init__(self, hidden_dim, ffn_dim, num_encoder_layers, num_decoder_layers, num_heads):
        super().__init__()
        self.encoder = TransformerEncoder(hidden_dim, ffn_dim, num_encoder_layers, num_heads)
        self.decoder = TransformerDecoder(hidden_dim, ffn_dim, num_decoder_layers, num_heads)
    
    def forward(self, x, y):
        encoder_hidden_states = self.encoder(x)
        return self.decoder(y, encoder_hidden_states)
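The decoder output can be turned into word predictions with one more linear layer over a vocabulary, followed by an `argmax` at inference time or `F.cross_entropy` on the raw logits during training. The sketch uses a random tensor in place of `model(x, y)`; `vocab_size` and `target_ids` are hypothetical values for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, y_len, hidden_dim = 2, 7, 20
vocab_size = 1000  # hypothetical vocabulary size

decoder_output = torch.randn(batch_size, y_len, hidden_dim)  # stand-in for model(x, y)
classifier = nn.Linear(hidden_dim, vocab_size)

logits = classifier(decoder_output)    # [batch_size, y_len, vocab_size]
predicted_ids = logits.argmax(dim=-1)  # [batch_size, y_len], one word id per position

# Training loss: cross_entropy expects [N, vocab_size] logits and [N] integer targets
target_ids = torch.randint(0, vocab_size, (batch_size, y_len))  # hypothetical targets
loss = F.cross_entropy(logits.reshape(-1, vocab_size), target_ids.reshape(-1))

print(logits.shape, predicted_ids.shape, loss.item())
```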

We will use the following hyperparameters.

In [59]:
hidden_dim = 20
embedding_dim = hidden_dim
ffn_dim = 100
num_heads = 4
num_encoder_layers = 6
num_decoder_layers = 2
batch_size = 2
x_len = 10
y_len = 7

Now we can run our model and test that the output dimensions are correct.

In [60]:
x = torch.randn(batch_size, x_len, embedding_dim)
y = torch.randn(batch_size, y_len, embedding_dim)
model = TransformerModel(hidden_dim, ffn_dim, num_encoder_layers, num_decoder_layers, num_heads)
output = model(x, y)
assert list(output.shape) == [batch_size, y_len, hidden_dim], "Wrong output shape"
num_model_params = sum(param.numel() for param in model.parameters())
assert num_model_params == 50480, f"Wrong number of parameters: {num_model_params}"
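Where does 50480 come from? With `hidden_dim = 20` and `ffn_dim = 100`, each attention module has four `20 x 20` linear layers with biases, each layer norm has 40 parameters (scale and shift), and the feed-forward network has two linear layers plus its layer norm. Encoder layers contain one attention block and decoder layers two (self-attention and cross-attention). The arithmetic as a short check:

```python
hidden_dim, ffn_dim = 20, 100
num_encoder_layers, num_decoder_layers = 6, 2

attention = 4 * (hidden_dim * hidden_dim + hidden_dim)  # q, k, v, output projections: 1680
layer_norm = 2 * hidden_dim                             # scale + shift: 40
attention_block = attention + layer_norm                # 1720
ffn_block = (hidden_dim * ffn_dim + ffn_dim) + (ffn_dim * hidden_dim + hidden_dim) + layer_norm  # 4160

encoder_layer = attention_block + ffn_block      # 5880
decoder_layer = 2 * attention_block + ffn_block  # 7600 (self-attention + cross-attention)
total = num_encoder_layers * encoder_layer + num_decoder_layers * decoder_layer

print(total)  # 50480
```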

## What is missing for a real implementation?
Look at the [implementation of the Transformer layer for BERT by HuggingFace](https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/models/bert/modeling_bert.py#L223), lines 223 to 641.

**Question:** List the things HuggingFace's implementation does that your own implementation is still missing.

**Answer:** 
