# Transformer
We will implement the Transformer architecture presented in class.

In [1]:
import torch
import torch.nn as nn

We start with the attention mechanism. Define a class `TransformerAttention` that will hold all the attention-related methods we need. Add an `__init__` method that takes `hidden_dim` and `num_heads` as parameters.

In [2]:
class TransformerAttention(nn.Module):
    
    def __init__(self, hidden_dim, num_heads):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

Now we add its methods one by one, starting with the best part: the attention function. Implement scaled dot-product attention given `query`, `key`, and `value` tensors as inputs. These tensors have shape `[batch_size, sequence_length, head_dim]`. Scaled dot-product attention is defined as:
$$\text{DPA}(Q, K, V) = \text{softmax}\left(\frac{Q K^\top}{\sqrt{d}}\right) V$$
where $d$ is the head dimension `head_dim` and the softmax is taken over the keys (the last axis of the score matrix).

In [3]:
import math
import torch.nn.functional as F

def dot_product_attention(self, query, key, value):
    head_dim = query.size(-1)
    attention_scores = torch.matmul(query, key.transpose(-1, -2))
    attention_scores = attention_scores / math.sqrt(head_dim)
    attention_probs = F.softmax(attention_scores, dim=-1)
    return torch.matmul(attention_probs, value)

TransformerAttention.dot_product_attention = dot_product_attention
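As a sanity check, the method can be compared with PyTorch's built-in `F.scaled_dot_product_attention`. This is a sketch assuming PyTorch 2.0 or later, where that function exists. The standalone function below is a copy of the method above without `self`; the built-in is called with a singleton head dimension added, giving the `[batch, heads, seq, head_dim]` layout it is documented with.

```python
import math

import torch
import torch.nn.functional as F


def dot_product_attention(query, key, value):
    # Standalone copy of the notebook's method (no `self`).
    head_dim = query.size(-1)
    scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(head_dim)
    return torch.matmul(F.softmax(scores, dim=-1), value)


torch.manual_seed(0)
batch_size, seq_len, head_dim = 2, 5, 8
query = torch.randn(batch_size, seq_len, head_dim)
key = torch.randn(batch_size, seq_len, head_dim)
value = torch.randn(batch_size, seq_len, head_dim)

ours = dot_product_attention(query, key, value)
# Built-in version (default scale is 1 / sqrt(head_dim)); add and remove a head dimension.
reference = F.scaled_dot_product_attention(
    query.unsqueeze(1), key.unsqueeze(1), value.unsqueeze(1)
).squeeze(1)
print(ours.shape)  # torch.Size([2, 5, 8])
print(torch.allclose(ours, reference, atol=1e-5))  # True
```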

Implement a function `split_to_heads` that takes a tensor of shape `[?, ?, hidden_dim]` and splits it into `num_heads` tensors of shape `[?, ?, head_dim]`, where $\text{head\_dim} = \frac{\text{hidden\_dim}}{\text{num\_heads}}$. The `?` dimensions stay unchanged, and your implementation should not depend on their sizes.

In [4]:
def split_to_heads(self, tensor):
    assert self.hidden_dim % self.num_heads == 0, "Hidden dim needs to be divisible by num_heads"
    head_dim = self.hidden_dim // self.num_heads
    # Alternative:
    # return tensor.view(tensor.size(0), tensor.size(1), self.num_heads, head_dim).unbind(2)
    return tensor.split(head_dim, dim=-1)

TransformerAttention.split_to_heads = split_to_heads
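A quick check on a small tensor (sizes chosen arbitrarily for illustration) that `split` and the commented-out `view(...).unbind(2)` alternative produce the same heads. The last lines show the single-tensor layout with an extra head dimension, which is what the vectorization note in the `forward` cell refers to.

```python
import torch

batch_size, seq_len, hidden_dim, num_heads = 2, 3, 8, 4
head_dim = hidden_dim // num_heads
x = torch.arange(batch_size * seq_len * hidden_dim, dtype=torch.float32)
x = x.view(batch_size, seq_len, hidden_dim)

# Version used above: a tuple of `num_heads` tensors of shape [batch, seq, head_dim].
heads_split = x.split(head_dim, dim=-1)
# Commented-out alternative: add a head dimension, then unbind it.
heads_unbind = x.view(batch_size, seq_len, num_heads, head_dim).unbind(2)

print(len(heads_split), heads_split[0].shape)  # 4 torch.Size([2, 3, 2])
print(all(torch.equal(a, b) for a, b in zip(heads_split, heads_unbind)))  # True

# Vectorized layout: one tensor of shape [batch, num_heads, seq, head_dim].
heads_stacked = x.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
print(heads_stacked.shape)  # torch.Size([2, 4, 3, 2])
print(torch.equal(heads_stacked[:, 1], heads_split[1]))  # True
```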

Now implement the `forward` method of `TransformerAttention` (and extend the `__init__` method if necessary). It should:
1. project its inputs into `query`, `key` and `value` tensors with 3 separate linear layers
2. split the tensors into chunks for each head to process
3. perform attention for each head separately
4. concatenate the results
5. run the output through another linear layer

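Steps 1 and 2 appear in reverse order compared to the diagram we saw in class, but this order is more intuitive and is also how Hugging Face implements these operations: projecting with a single `hidden_dim`-to-`hidden_dim` layer and then splitting gives each head its own `hidden_dim`-to-`head_dim` projection.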

In [5]:
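def __init__(self, hidden_dim, num_heads):
    # This function is defined outside the class body, so the zero-argument super()
    # would fail here; the class has to be named explicitly.
    super(TransformerAttention, self).__init__()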
    self.hidden_dim = hidden_dim
    self.num_heads = num_heads
    self.query_projection = nn.Linear(hidden_dim, hidden_dim)
    self.key_projection = nn.Linear(hidden_dim, hidden_dim)
    self.value_projection = nn.Linear(hidden_dim, hidden_dim)
    self.output_projection = nn.Linear(hidden_dim, hidden_dim)

def forward(self, x):
    query = self.query_projection(x)
    key = self.key_projection(x)
    value = self.value_projection(x)

    # Note: this operation can be vectorized for efficiency:
    # Instead of creating `num_heads` tensors, keep 1 tensor and add a dimension of size `num_heads`
    head_queries = self.split_to_heads(query)
    head_keys = self.split_to_heads(key)
    head_values = self.split_to_heads(value)
    
    attention_outputs = []
    for head_query, head_key, head_value in zip(head_queries, head_keys, head_values):
        attention_outputs.append(self.dot_product_attention(head_query, head_key, head_value))

    attention_output_tensor = torch.cat(attention_outputs, dim=-1)

    return self.output_projection(attention_output_tensor)

TransformerAttention.__init__ = __init__
TransformerAttention.forward = forward
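The note in the `forward` cell suggests vectorizing the loop over heads. A minimal sketch of that idea, assuming standalone `nn.Linear` projections in place of the class attributes: the helpers `attention_vectorized` and `attention_loop` are hypothetical and not part of the notebook. The vectorized version keeps one tensor with a `num_heads` dimension, and the check confirms it matches the per-head loop.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


def attention_vectorized(x, q_proj, k_proj, v_proj, o_proj, num_heads):
    """Hypothetical helper: all heads at once, using an extra head dimension."""
    batch_size, seq_len, hidden_dim = x.shape
    head_dim = hidden_dim // num_heads

    def to_heads(t):
        # [batch, seq, hidden_dim] -> [batch, num_heads, seq, head_dim]
        return t.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)

    query, key, value = to_heads(q_proj(x)), to_heads(k_proj(x)), to_heads(v_proj(x))
    scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(head_dim)
    heads = torch.matmul(F.softmax(scores, dim=-1), value)  # [batch, num_heads, seq, head_dim]
    # Merge the heads back into [batch, seq, hidden_dim], in the same order as torch.cat.
    merged = heads.transpose(1, 2).reshape(batch_size, seq_len, hidden_dim)
    return o_proj(merged)


def attention_loop(x, q_proj, k_proj, v_proj, o_proj, num_heads):
    """Hypothetical helper mirroring the per-head loop in the forward method above."""
    head_dim = x.size(-1) // num_heads
    head_queries = q_proj(x).split(head_dim, dim=-1)
    head_keys = k_proj(x).split(head_dim, dim=-1)
    head_values = v_proj(x).split(head_dim, dim=-1)
    outputs = []
    for q, k, v in zip(head_queries, head_keys, head_values):
        probs = F.softmax(torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(head_dim), dim=-1)
        outputs.append(torch.matmul(probs, v))
    return o_proj(torch.cat(outputs, dim=-1))


torch.manual_seed(0)
hidden_dim, num_heads = 20, 4
q_proj, k_proj, v_proj, o_proj = (nn.Linear(hidden_dim, hidden_dim) for _ in range(4))
x = torch.randn(2, 10, hidden_dim)

out_vec = attention_vectorized(x, q_proj, k_proj, v_proj, o_proj, num_heads)
out_loop = attention_loop(x, q_proj, k_proj, v_proj, o_proj, num_heads)
print(out_vec.shape)  # torch.Size([2, 10, 20])
print(torch.allclose(out_vec, out_loop, atol=1e-5))  # True
```

The vectorized form replaces `num_heads` small matrix multiplications with one batched multiplication, which is typically faster on a GPU.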

Create a class `TransformerAttentionBlock` that runs Transformer attention, then adds the input as a residual to the output and performs layer normalization.

In [6]:
class TransformerAttentionBlock(nn.Module):
    
    def __init__(self, hidden_dim, num_heads):
        super().__init__()
        self.attention = TransformerAttention(hidden_dim, num_heads)
        self.layer_norm = nn.LayerNorm(hidden_dim)
    
    def forward(self, x):
        output = self.attention(x)
        return self.layer_norm(x + output)
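To see what the residual connection followed by `nn.LayerNorm` does, here is a minimal sketch in which a random tensor stands in for the attention output (an assumption for illustration; the real block uses `self.attention(x)`). LayerNorm normalizes each position over its `hidden_dim` features, not over the batch or the sequence.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_dim = 20
x = torch.randn(2, 10, hidden_dim)
attention_output = torch.randn(2, 10, hidden_dim)  # stand-in for self.attention(x)

layer_norm = nn.LayerNorm(hidden_dim)
out = layer_norm(x + attention_output)  # residual connection, then layer normalization

# At initialization (weight = 1, bias = 0), each position is normalized over its
# hidden_dim features to mean ~0 and population variance ~1.
mean = out.mean(dim=-1)
variance = ((out - out.mean(dim=-1, keepdim=True)) ** 2).mean(dim=-1)
print(mean.abs().max().item(), variance.mean().item())
```

Applying the normalization after the residual addition is the post-LN arrangement of the original Transformer; many later models normalize the sublayer input instead (pre-LN), which is often reported to make training deep stacks more stable.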

Create a class `FeedForwardNetwork` that consists of two linear layers with a ReLU in between. Also add a residual connection from the input to the output and apply layer normalization.

In [7]:
class FeedForwardNetwork(nn.Module):
    
    def __init__(self, hidden_dim, inner_dim):
        super().__init__()
        self.linear1 = nn.Linear(hidden_dim, inner_dim)
        self.linear2 = nn.Linear(inner_dim, hidden_dim)
        self.layer_norm = nn.LayerNorm(hidden_dim)
    
    def forward(self, x):
        output = self.linear1(x)
        output = F.relu(output)
        output = self.linear2(output)
        return self.layer_norm(x + output)
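The feed-forward network is applied to every position independently ("position-wise"), so changing one token's input changes only that token's output. A minimal sketch checking this, with standalone modules that mirror `FeedForwardNetwork` (the helper `ffn` is hypothetical, not part of the notebook):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
hidden_dim, inner_dim = 20, 100
linear1 = nn.Linear(hidden_dim, inner_dim)
linear2 = nn.Linear(inner_dim, hidden_dim)
layer_norm = nn.LayerNorm(hidden_dim)


def ffn(x):
    # Same computation as FeedForwardNetwork.forward: Linear-ReLU-Linear, residual, LayerNorm.
    return layer_norm(x + linear2(F.relu(linear1(x))))


x = torch.randn(1, 5, hidden_dim)
x_changed = x.clone()
x_changed[0, 3, 0] += 1.0  # perturb one feature of position 3 only

with torch.no_grad():
    diff = (ffn(x) - ffn(x_changed)).abs().amax(dim=-1)[0]  # largest change per position
print(diff)  # non-zero only at position 3
```

Mixing information across positions is left entirely to the attention sublayers.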

Now we can combine the `TransformerAttentionBlock` and the `FeedForwardNetwork` into a `TransformerLayer`. 

In [8]:
class TransformerLayer(nn.Module):
    
    def __init__(self, hidden_dim, ffn_inner_dim, num_heads):
        super().__init__()
        self.self_attention = TransformerAttentionBlock(hidden_dim, num_heads)
        self.ffn = FeedForwardNetwork(hidden_dim, ffn_inner_dim)
    
    def forward(self, x):
        out = self.self_attention(x)
        out = self.ffn(out)
        return out

We are ready to build our `TransformerEncoder` as a stack of a given number of `TransformerLayer`s.

In [9]:
class TransformerEncoder(nn.Module):
    
    def __init__(self, hidden_dim, ffn_inner_dim, num_layers, num_heads):
        super().__init__()
        self.layers = nn.ModuleList([TransformerLayer(hidden_dim, ffn_inner_dim, num_heads) for _ in range(num_layers)])
    
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
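`nn.ModuleList` is what makes the layers' parameters visible to the encoder. With a plain Python list the submodules are not registered, so `parameters()` (and therefore the optimizer, `.to(device)` and `state_dict()`) would not see them. A minimal sketch with two hypothetical toy classes:

```python
import torch.nn as nn


class WithPlainList(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(4, 4) for _ in range(3)]  # NOT registered as submodules


class WithModuleList(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])  # registered


def count_params(module):
    return sum(p.numel() for p in module.parameters())


# Each nn.Linear(4, 4) has 4 * 4 weights + 4 biases = 20 parameters.
print(count_params(WithPlainList()), count_params(WithModuleList()))  # 0 60
```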

Let's test our implementation with the hyperparameters...

In [10]:
hidden_dim = 20
embedding_dim = hidden_dim
ffn_dim = 100
num_heads = 4
num_encoder_layers = 6
batch_size = 2
x_len = 10

... and check if it produces the correct output shapes.

In [11]:
x = torch.randn(batch_size, x_len, embedding_dim)
encoder = TransformerEncoder(hidden_dim, ffn_dim, num_encoder_layers, num_heads)
output = encoder(x)
assert list(output.shape) == [batch_size, x_len, hidden_dim], "Wrong output shape"

## Transformer Decoder
For the Transformer decoder, two components are missing.
1. A causal mask in the `TransformerAttention`.
2. A cross-attention module in the `TransformerLayer`.

We start by generalizing the `TransformerAttention` class so that `dot_product_attention` applies a causal mask when the module is used for decoder self-attention. We check this through an `is_decoder_self_attention` attribute of `self`, which we first have to add as an argument to `TransformerAttention`'s `__init__` method.

In [12]:
# Add an `is_decoder_self_attention` attribute to TransformerAttention.__init__
def __init__(self, hidden_dim, num_heads, is_decoder_self_attention=False):
    super(TransformerAttention, self).__init__()
    self.hidden_dim = hidden_dim
    self.num_heads = num_heads
    self.is_decoder_self_attention = is_decoder_self_attention
    self.query_projection = nn.Linear(hidden_dim, hidden_dim)
    self.key_projection = nn.Linear(hidden_dim, hidden_dim)
    self.value_projection = nn.Linear(hidden_dim, hidden_dim)
    self.output_projection = nn.Linear(hidden_dim, hidden_dim)

TransformerAttention.__init__ = __init__

In [13]:
# Change dot_product_attention to use a causal mask when it is used for decoder self-attention.
def dot_product_attention(self, query, key, value):
    head_dim = query.size(-1)
    attention_scores = torch.matmul(query, key.transpose(-1, -2))
    attention_scores = attention_scores / math.sqrt(head_dim)
    if self.is_decoder_self_attention:
        batch_size, seq_length, _ = query.size()

        # The causal mask is a lower triangular matrix (see lecture slides):
        # it is 0 for attention into the future and 1 otherwise.
        # It is created on the same device as the inputs so the model also runs on a GPU.
        causal_mask = torch.tril(torch.ones(batch_size, seq_length, seq_length, device=query.device))

        # Or: Hugging Face's construction of a causal mask (same result; overwrites the one above).
        seq_ids = torch.arange(seq_length, device=query.device)
        causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
        causal_mask = causal_mask.to(torch.float32)  # convert the boolean mask to float for the arithmetic below

        # The attention scores can lie anywhere in (-inf, inf). Subtracting a large number
        # from every score that would look into the future makes it so small that its
        # softmax probability becomes 0.
        scores_to_mask = 1 - causal_mask
        attention_scores = attention_scores - scores_to_mask * 1e8
    attention_probs = F.softmax(attention_scores, dim=-1)
    return torch.matmul(attention_probs, value)

TransformerAttention.dot_product_attention = dot_product_attention
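The two mask constructions in the cell above can be checked against each other in isolation. The sketch below also compares the notebook's subtract-a-large-number approach with a common alternative that the notebook does not use: `masked_fill` with `-inf` before the softmax. Both assign zero probability to future positions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_length = 2, 4

# Construction 1: lower triangular matrix of ones.
tril_mask = torch.tril(torch.ones(batch_size, seq_length, seq_length))
# Construction 2: Hugging Face-style comparison of position ids.
seq_ids = torch.arange(seq_length)
hf_mask = (seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]).float()
print(torch.equal(tril_mask, hf_mask))  # True

scores = torch.randn(batch_size, seq_length, seq_length)
# Notebook approach: subtract a large number at future positions.
probs_subtract = F.softmax(scores - (1 - tril_mask) * 1e8, dim=-1)
# Alternative: set future positions to -inf before the softmax.
probs_fill = F.softmax(scores.masked_fill(tril_mask == 0, float("-inf")), dim=-1)

print(probs_subtract[0])  # upper triangle (the future) is exactly 0; first row is [1, 0, 0, 0]
print(torch.allclose(probs_subtract, probs_fill))  # True
```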

Now we add cross-attention. We do this by updating the `TransformerAttention`'s `forward` method to take `encoder_hidden_states` as an optional input. Check the lecture slides to see which input gets projected into queries, keys and values.

In [14]:
def forward(self, x, encoder_hidden_states=None):
    query = self.query_projection(x)
    
    # In cross-attention, queries come from the decoder input `x`, while keys and values come from encoder_hidden_states.
    if encoder_hidden_states is not None:
        key = self.key_projection(encoder_hidden_states)
        value = self.value_projection(encoder_hidden_states)
    else:
        key = self.key_projection(x)
        value = self.value_projection(x)

    head_queries = self.split_to_heads(query)
    head_keys = self.split_to_heads(key)
    head_values = self.split_to_heads(value)
    attention_outputs = []
    for head_query, head_key, head_value in zip(head_queries, head_keys, head_values):
        attention_outputs.append(self.dot_product_attention(head_query, head_key, head_value))
    attention_output_tensor = torch.cat(attention_outputs, dim=-1)
    return self.output_projection(attention_output_tensor)
    
TransformerAttention.forward = forward
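In cross-attention, the query and key sequences can have different lengths. A minimal shape sketch for a single head, with random tensors standing in for the projected decoder input (length `y_len`) and encoder hidden states (length `x_len`), using the same lengths as the seq2seq test at the end of this notebook:

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, x_len, y_len, head_dim = 2, 10, 7, 5
query = torch.randn(batch_size, y_len, head_dim)  # projected from the decoder-side input
key = torch.randn(batch_size, x_len, head_dim)    # projected from encoder_hidden_states
value = torch.randn(batch_size, x_len, head_dim)  # projected from encoder_hidden_states

scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(head_dim)
probs = F.softmax(scores, dim=-1)
output = torch.matmul(probs, value)
print(scores.shape)  # torch.Size([2, 7, 10]): one row per decoder position, one column per encoder position
print(output.shape)  # torch.Size([2, 7, 5]): one output per decoder position
```

Every decoder position may attend to all encoder positions, so cross-attention needs no causal mask. This is why the mask in `dot_product_attention` is only applied when `is_decoder_self_attention` is set, and why it can assume a square `[seq_length, seq_length]` score matrix.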

We have to extend `TransformerAttentionBlock` so that its `forward` method accepts this additional argument.

In [15]:
def forward(self, x, encoder_hidden_states=None):
    output = self.attention(x, encoder_hidden_states)
    return self.layer_norm(x + output)

TransformerAttentionBlock.forward = forward

Now we implement a `TransformerDecoderLayer` that consists of decoder self-attention, cross-attention and a feed-forward network. In the `forward` method, use the encoder hidden states as inputs to the cross-attention module.

In [16]:
class TransformerDecoderLayer(nn.Module):
    
    def __init__(self, hidden_dim, ffn_inner_dim, num_heads):
        super().__init__()
        self.self_attention = TransformerAttentionBlock(hidden_dim, num_heads)
        self.self_attention.attention.is_decoder_self_attention = True
        self.cross_attention = TransformerAttentionBlock(hidden_dim, num_heads)
        self.ffn = FeedForwardNetwork(hidden_dim, ffn_inner_dim)
    
    def forward(self, x, encoder_hidden_states):
        out = self.self_attention(x)
        out = self.cross_attention(out, encoder_hidden_states)
        out = self.ffn(out)
        return out
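Written out, one decoder layer computes the following, where $y$ is the layer input (`x` in `forward`), $H_\text{enc}$ are the encoder hidden states, and $\text{FFN}(h) = W_2\,\text{ReLU}(W_1 h + b_1) + b_2$ denotes the two linear layers of `FeedForwardNetwork` without its residual and normalization:

$$
\begin{aligned}
h_1 &= \text{LayerNorm}\big(y + \text{MaskedSelfAttention}(y)\big) \\
h_2 &= \text{LayerNorm}\big(h_1 + \text{CrossAttention}(h_1, H_\text{enc})\big) \\
\text{out} &= \text{LayerNorm}\big(h_2 + \text{FFN}(h_2)\big)
\end{aligned}
$$

In $\text{CrossAttention}(h_1, H_\text{enc})$, the queries are projected from $h_1$ and the keys and values from $H_\text{enc}$.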

Add a `TransformerDecoder` that holds the decoder layers.

In [17]:
class TransformerDecoder(nn.Module):
    
    def __init__(self, hidden_dim, ffn_inner_dim, num_layers, num_heads):
        super().__init__()
        self.layers = nn.ModuleList([TransformerDecoderLayer(hidden_dim, ffn_inner_dim, num_heads) for _ in range(num_layers)])
    
    def forward(self, x, encoder_hidden_states):
        for layer in self.layers:
            x = layer(x, encoder_hidden_states)
        return x

## Transformer Seq2seq Model
We can now put everything together. Create a Transformer model that encodes an input `x` (random in our test) and then produces an output hidden representation for each position of the decoder input `y`. These representations could then be fed into a classifier to predict the output words.

In [18]:
class TransformerModel(nn.Module):
    
    def __init__(self, hidden_dim, ffn_dim, num_encoder_layers, num_decoder_layers, num_heads):
        super().__init__()
        self.encoder = TransformerEncoder(hidden_dim, ffn_dim, num_encoder_layers, num_heads)
        self.decoder = TransformerDecoder(hidden_dim, ffn_dim, num_decoder_layers, num_heads)
    
    def forward(self, x, y):
        encoder_hidden_states = self.encoder(x)
        return self.decoder(y, encoder_hidden_states)

We will use the following hyperparameters.

In [19]:
hidden_dim = 20
embedding_dim = hidden_dim
ffn_dim = 100
num_heads = 4
num_encoder_layers = 6
num_decoder_layers = 2
batch_size = 2
x_len = 10
y_len = 7

Now we can run our model and test that the output dimensions are correct.

In [20]:
x = torch.randn(batch_size, x_len, embedding_dim)
y = torch.randn(batch_size, y_len, embedding_dim)
model = TransformerModel(hidden_dim, ffn_dim, num_encoder_layers, num_decoder_layers, num_heads)
output = model(x, y)
assert list(output.shape) == [batch_size, y_len, hidden_dim], "Wrong output shape"
num_model_params = sum(param.numel() for param in model.parameters())
assert num_model_params == 50480, f"Wrong number of parameters: {num_model_params}"
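The expected count of 50480 can be derived by hand. A short arithmetic sketch in plain Python (the helper functions are hypothetical): each `nn.Linear(n_in, n_out)` has `n_in * n_out` weights plus `n_out` biases, and each `nn.LayerNorm(d)` has `d` weights and `d` biases.

```python
def linear_params(n_in, n_out):
    return n_in * n_out + n_out  # weight matrix + bias


def layer_norm_params(dim):
    return 2 * dim  # elementwise weight + bias


hidden_dim, ffn_dim = 20, 100
num_encoder_layers, num_decoder_layers = 6, 2

# TransformerAttentionBlock: query/key/value/output projections + LayerNorm.
attention_block = 4 * linear_params(hidden_dim, hidden_dim) + layer_norm_params(hidden_dim)
# FeedForwardNetwork: two linear layers + LayerNorm.
ffn_block = (linear_params(hidden_dim, ffn_dim) + linear_params(ffn_dim, hidden_dim)
             + layer_norm_params(hidden_dim))
encoder_layer = attention_block + ffn_block        # self-attention + FFN
decoder_layer = 2 * attention_block + ffn_block    # self-attention + cross-attention + FFN

total = num_encoder_layers * encoder_layer + num_decoder_layers * decoder_layer
print(attention_block, ffn_block, encoder_layer, decoder_layer, total)  # 1720 4160 5880 7600 50480
```

Note that `num_heads` does not appear: splitting into heads changes how the projections are used, not how many parameters they have.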

## What is missing for a real implementation?
Look at the [implementation of the Transformer layer for BERT by Hugging Face](https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/models/bert/modeling_bert.py#L223), lines 223 to 641.

**Question:** Name the features of Hugging Face's implementation that are still missing from your own implementation.

**Answer:** 
- Batching/masks: when a batch of sequences with different lengths is processed in vectorized form, an additional `attention_mask` argument masks out the padded positions
- Dropout
- Embeddings
- Absolute/relative position embeddings
- Caching past keys and values for faster generation
- Pruning attention heads (an advanced technique for efficiency)
- Outputting attention probabilities and hidden states for analysis (e.g. visualization)
- BERT-specific details
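The first item can be sketched in a few lines. Assuming a padding mask with 1 for real tokens and 0 for padding (the convention Hugging Face documents for `attention_mask`), the padded key positions are excluded before the softmax. This is a simplified sketch, not Hugging Face's implementation:

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_len, head_dim = 2, 4, 5
query = key = value = torch.randn(batch_size, seq_len, head_dim)
# Hypothetical padding mask: 1 = real token, 0 = padding.
# The second sequence has 2 real tokens followed by 2 padding positions.
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 0, 0]])

scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(head_dim)
# Shape [batch, 1, seq_len] broadcasts over the query dimension, i.e. masks key positions.
scores = scores.masked_fill(attention_mask[:, None, :] == 0, float("-inf"))
probs = F.softmax(scores, dim=-1)
output = torch.matmul(probs, value)
print(probs[1])  # columns 2 and 3 (the padding keys) are 0 in every row
```

The outputs at the padded query positions are still computed, but they are usually ignored downstream (e.g. excluded from the loss).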