# The Transformer
In this lab scenario, you will implement *causal attention* for a transformer decoder model.
The transformer architecture was introduced in the [Attention Is All You Need](https://arxiv.org/abs/1706.03762) paper, and has dominated the field of language modeling.  
Here we will go through different parts of the transformer architecture and explain each of them briefly.

The whole notebook works fine on Colab CPU (~5min).
On the other hand, it's instantaneous on GPU (except for downloading the weights, which can take most of the time anyway).


## Transformer Overview

Transformer decoder models (such as LLaMa 3.1 and Mistral) are popular text-processing models.   

One can distinguish two versions of such models: **base** and **instruction-tuned**. The base models are usually trained on predicting the continuation of a given text (for each prefix they output a probability distribution over the next text fragment). In contrast, the instruction-tuned ones are base models that were additionally fine-tuned to follow instructions (often with a form of reinforcement learning from human feedback to generated text ([RLHF](https://en.wikipedia.org/wiki/Reinforcement_learning_from_human_feedback))).

### Tokenizer

The text is presented to the transformer as a sequence of **tokens**.  
Tokens are integers used to represent pieces of text.  
To be more precise: to convert text to tokens, we first prepare a dictionary of common text fragments.   
We usually want to have all possible letters in this dictionary, so that all texts can be tokenized.   
We then assign to each text piece from the dictionary an integer and use the dictionary to convert text into a sequence of tokens (integers).  
The class that converts text into tokens is called a tokenizer.  

In this lab scenario, we will use the OpenLLaMAv2 tokenizer and the HuggingFace library to tokenize text.   
HuggingFace contains a vast collection of transformer model weights and implementations along with training and inference code.  

In [1]:
# !pip install transformers==4.57.2  # This is the version this notebook was prepared for.

In [2]:
import warnings
from functools import partial
from pprint import pp

import torch
from torch import Tensor
from transformers import LlamaTokenizer, LlamaForCausalLM
from transformers.masking_utils import create_causal_mask
from tqdm import tqdm

# Suppress a useless warning from HuggingFace.
warnings.filterwarnings("ignore", message="(?s:.)*authentication is recommended but still optional to access public models")

In [3]:
tokenizer = LlamaTokenizer.from_pretrained("openlm-research/open_llama_3b_v2", legacy=False)

text = "This is an example text that we will tokenize"
tokens_mask = tokenizer(text)
pp(tokens_mask)

detokenized = tokenizer.batch_decode(tokens_mask["input_ids"])
pp(detokenized)

tokenizer_config.json:   0%|          | 0.00/593 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/512k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/330 [00:00<?, ?B/s]

{'input_ids': [1, 660, 325, 371, 1938, 1880, 347, 389, 477, 8206, 753],
 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
['<s>',
 'This',
 'is',
 'an',
 'example',
 'text',
 'that',
 'we',
 'will',
 'token',
 'ize']


After tokenization the the HuggingFace tokenizer returns a sequence of tokens (`input_ids`) and information on whether the model should look at the ith element of the input (`attention_mask`).  
The other part is useful when we want to tokenize several sequences into one batch of elements of the same length. Then the attention mask can be used to hide the padding from the model.

In [4]:
text = ["This is an example text that we will tokenize", "Hello"]

# We set the padding token to be the same as the end-of-sequence token (EOS).
# The EOS token (</s> in this case) can mark the end of the sequence in training and can be used by a model to indicate it finishes its response.
# The BOS token (here <s>) can be used to mark the beginning of the input.
# Details vary between different models and implementations.
tokenizer.pad_token = tokenizer.eos_token

tokens_mask = tokenizer(text, return_tensors="pt", padding=True, truncation=False)
pp(tokens_mask)

detokenized = tokenizer.batch_decode(tokens_mask["input_ids"])
pp(detokenized)

{'input_ids': tensor([[   1,  660,  325,  371, 1938, 1880,  347,  389,  477, 8206,  753],
        [   1, 8479,    2,    2,    2,    2,    2,    2,    2,    2,    2]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]])}
['<s> This is an example text that we will tokenize',
 '<s> Hello</s></s></s></s></s></s></s></s></s>']


### Embedding
The input to the model is a batch of integer (token) sequences of shape `(batch, seq_len)` where:
* `batch` is the size of the batch;
* `seq_len` is the length of the longest input sequence inside the batch (the attention mask is used to mask the padding).

The first layer of the model replaces each integer with an embedding vector of length `hidden_size`.  
(Inside the model, there is a matrix of trainable parameters, randomly initialized, of shape `(num_dictionary_elements, hidden_size)`).  

After the embedding step, we pass a tensor of shape `(batch, seq_len, hidden_size)` through the remaining layers of the model.
In decoder-only models, the sequence length does not change.

### Transformer layer
The internal parts of the transformer are grouped into transformer layers.  
Usually, each layer consists of: layer norm, attention, layer norm, and a feed-forward layer.  
To be more precise the computation progresses roughly as presented below:
```python
def transformer_layer(x):
 x = attention(layer_norm_attn(x)) + x
 x = feed_forward(layer_norm_ff(x)) + x
 return x
```
Here:  
* **feed_forward** is an MLP (typically just two layers: linear-activation-linear) that acts on each token independently, in the same way. That is, it treats the sequence-length dimension like the batch dimension, and operates on the `hidden_size` dimension of an input of shape `(batch, seq_len, hidden_size)`.
* **layer_norm** – in LLaMa models replaced by [RMSNorm](https://pytorch.org/docs/stable/generated/torch.nn.modules.normalization.RMSNorm.html) (which are like LayerNorm, but without centering, i.e. without subtracting the mean). Similarly as `feed_forward` it operates only on the `hidden_size` dimension, treating other dimensions as a batch.
* **attention** – causal multi-head attention that you will implement in further parts of this notebook.

Let $t^{(1)}$ be an input tensor of shape `(batch, seq_len, hidden_size)`.
Attention will output a tensor $t^{(2)}$ of the same shape with the following property:
calculation of $t^{(2)}[b,s,h]$ depends only on values from $t^{(2)}[b,s',h']$ such that  `s' <= s`. In other words, calculation is done independently per batch entry and dependency is *causal* (the past can influence the future but the future cannot influence the past of the sequence).

### LM head
In the end, a linear projection is used to create weights for each element of the input dictionary.
To be more precise we take a tensor of shape `(batch, seq_len, hidden_dim)` and use norm + a linear projection from `hidden_dim` to `vocab_size`, in order to change it into tensor of shape `(batch, seq_len, vocab_size)`.  
Then we apply softmax over the last dimension (`vocab_size`) to get a probability distribution, for each element of the sequence.
The training loss will be cross entropy over next-token prediction.  
That is, for a tokenized input sequence $x$, the $i$-th output of the model (which only depends on $x[0], \dots, x[i]$), should be $x[i+1]$.
In other words, the ground-truth output is the input sequence shifted by one (with EOS added at the end).

### Example
Below we show the steps described above using OpenLLaMAv2 3B.

For full details, see: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py

In [5]:
## Tokenize the input.
tokenizer = LlamaTokenizer.from_pretrained("openlm-research/open_llama_3b_v2", legacy=False)
tokenizer.pad_token = tokenizer.eos_token
text = ["2 + 7 = ", "2+7="]

tokens_mask = tokenizer(text, return_tensors="pt", padding=True)
tokens = tokens_mask["input_ids"]
attention_mask = tokens_mask["attention_mask"]
pp(tokens_mask)

{'input_ids': tensor([[    1, 29500, 29536,   835, 29500, 29574,   419, 29500],
        [    1, 29500, 29536, 29589, 29574, 29554,     2,     2]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 0, 0]])}


In [6]:
## Load the model from HuggingFace.
device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)

# Downloads ~6.8GB in bfloat16 (2 bytes per parameter).
model = LlamaForCausalLM.from_pretrained(
    "openlm-research/open_llama_3b_v2", dtype=torch.bfloat16, device_map=device
)
model.eval()
print(sum(p.numel() for p in model.parameters()) * 2 / 1000 / 1000 / 1000)

config.json:   0%|          | 0.00/506 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/6.85G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/6.85G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

6.852947200000001


In [7]:
## Embed the tokenizer input.
print(model.model.embed_tokens)
with torch.no_grad():
    embedded_tokens = model.model.embed_tokens(tokens.to(device))
print(f"{tokens.shape=}\n{embedded_tokens.shape=}")
batch, seq_len, hidden_size = embedded_tokens.shape

Embedding(32000, 3200, padding_idx=0)
tokens.shape=torch.Size([2, 8])
embedded_tokens.shape=torch.Size([2, 8, 3200])


In [8]:
## Take the positions of each token [0, 1, 2, ...]
position_ids = torch.arange(seq_len, device=embedded_tokens.device)[None, ...]

# Embed them as (cos, sin, ...) rotations.
# In Llama, instead of concatenating these to the embedded input tokens,
# they will be applied to keys and queries inside each attention layer.
# ( RoPE: https://arxiv.org/pdf/2104.09864 )
# (The first tensor here is only used for its .device and .dtype).
with torch.no_grad():
    position_embeddings = model.model.rotary_emb(torch.zeros_like(embedded_tokens), position_ids)

In [9]:
# Compute the causal mask
# (including information about the attention_mask from padding).
causal_mask = create_causal_mask(
    config=model.config,
    input_embeds=embedded_tokens,
    attention_mask=attention_mask,
    position_ids=position_ids,
    past_key_values=None,
    cache_position=position_ids.squeeze(),
)
print(causal_mask)

tensor([[[[ True, False, False, False, False, False, False, False],
          [ True,  True, False, False, False, False, False, False],
          [ True,  True,  True, False, False, False, False, False],
          [ True,  True,  True,  True, False, False, False, False],
          [ True,  True,  True,  True,  True, False, False, False],
          [ True,  True,  True,  True,  True,  True, False, False],
          [ True,  True,  True,  True,  True,  True,  True, False],
          [ True,  True,  True,  True,  True,  True,  True,  True]]],


        [[[ True, False, False, False, False, False, False, False],
          [ True,  True, False, False, False, False, False, False],
          [ True,  True,  True, False, False, False, False, False],
          [ True,  True,  True,  True, False, False, False, False],
          [ True,  True,  True,  True,  True, False, False, False],
          [ True,  True,  True,  True,  True,  True, False, False],
          [ True,  True,  True,  True,  True

In [10]:
## Go through the layers of the model.
x = embedded_tokens

with torch.no_grad():
    for layer in tqdm(model.model.layers):
        x = layer(
            x,
            attention_mask=causal_mask,
            position_ids=position_ids,
            past_key_values=None,  # can be used to continue generation
            use_cache=False,
            cache_position=position_ids.squeeze(),
            position_embeddings=position_embeddings,
        )

    x = model.model.norm(x)
    x = model.lm_head(x)
    x = torch.nn.functional.softmax(x, dim=-1)
    print(f"{x.shape=}")
    next_token = torch.argmax(x[:, -1], dim=-1)
    print(next_token)
    print([tokenizer.decode(t) for t in next_token])


  0%|          | 0/26 [00:00<?, ?it/s][A
  4%|▍         | 1/26 [00:00<00:08,  2.95it/s][A
100%|██████████| 26/26 [00:00<00:00, 57.42it/s]


x.shape=torch.Size([2, 8, 32000])
tensor([29567,    13], device='cuda:0')
['9', '\n']


In [11]:
# Using HuggingFace generate().
with torch.no_grad():
    text = "The largest animal on earth is"
    tokens_mask = tokenizer(text, return_tensors="pt")
    output = model.generate(
        inputs=tokens_mask["input_ids"].to(device),
        max_new_tokens=8,
        num_beams=1,
        num_return_sequences=4,
        do_sample=True, # sample from the distribution created by softmax
        temperature=0.7, # divide pre softmax score by this value
        top_p=0.9 # cut out improbable tokens from sampling
    )
    pp(tokenizer.batch_decode(output))


['<s> The largest animal on earth is a whale. A whale is',
 '<s> The largest animal on earth is the Blue Whale. This mammal',
 '<s> The largest animal on earth is the elephant. They are often referred to',
 '<s> The largest animal on earth is the blue whale. It is the']


## Tools for implementing attention

#### `einsum`
[torch.einsum](https://docs.pytorch.org/docs/stable/generated/torch.einsum.html) is a useful tool for computing various forms of *contractions*, that is, expressions like $\sum_i A_{\dots,i,\dots} B_{\dots,i,\dots}$.
For example:
* `torch.einsum("ij,jk->ik", A, B)` is matrix multiplication.
* `torch.einsum('bij,bjk->bik', As, Bs)` is batched matrix multiplication.

In general, you go over every tuple (b,i,j,k,…) of all occuring letters, multiply the specified left and specified right element and accumulate that into the specified target element.
In other words, you sum over dimension that don't appear in output.

#### `where`
Use [torch.where(B,X,Y)](https://docs.pytorch.org/docs/stable/generated/torch.where.html) to implement an expression like $\begin{cases}
    X_{i,j} & \text{if }B_{i,j}\\
    Y_{i,j} & \text{otherwise}
\end{cases}$

#### `tril`
[torch.tril(A, diagonal=d)](https://docs.pytorch.org/docs/stable/generated/torch.tril.html) returns the lower-diagonal part of a matrix (not necessarily square), obtained be zeroing-out everything strictly above the main diagonal.
If $d \neq 0$ is given, everything above the diagonal `main + d` is zeroed out.

## Causal Attention Implementation
Your task is to finish the implementation of the attention mechanism below. In case of problems, you can refer to the original implementation that can be found [here](https://github.com/huggingface/transformers/blob/7f95372c6267d3163fd2aa74aeff9d84ddb6cc35/src/transformers/models/llama/modeling_llama.py#L258).

To be more precise, you are given query,key,value tensors (with positional encoding already applied), each of shape:
>`(batch, seq_len, num_heads, head_size)`  

Your task is to compute for each head a scaled dot product between each query and each key that is either at the same position as the query or precedes the query in the sequence.
That is, calculate a tensor `a` of shape:
> `(batch, num_heads, seq_len, seq_len)`

where  
$$
    a[b, h, q, k]=
\begin{cases}
    \sum_{d}{\mathrm{query}[b, q, h, d] \cdot \mathrm{key}[b, k, h, d]} / \sqrt{\mathrm{head\_size}}, & \text{if }k \leq q\\
     -\mathrm{large\_number},              & \text{otherwise}
\end{cases}
$$

Then you should calculate the softmax over the last dimension of `a` creating `p`.  

$$p = \mathrm{SoftMax}(a)$$
Then you should calculate
$$ out [b, q, h, d] = \sum_{k}{ p [b, h, q, k] \cdot \mathrm{value} [b, k, h, d] } $$

That is, for each query you should gather the `value`s using the probability distribution defined by `p`.  
In the end, you should reshape `out` to
> `(batch, seq_len, num_heads * head_size)`

and apply a linear projection `output_projection`.  

To compute the attention mask, use `tril` on `torch.ones((?, ?), device=x.device, dtype=torch.bool)`.
For simplicity, you may assume that the number of queries is equal to the number of keys.  
This is not always true: for example, when we run `generate()` from HuggingFace transformers library, it caches previous keys and values and create queries only for the new token(s).

In [15]:
def attention_forward(
    query: Tensor, key: Tensor, value: Tensor, output_projection: torch.nn.Linear,
) -> Tensor:
    batch, q_seq_len, num_heads, head_dim = query.shape
    batch, k_seq_len, num_heads, head_dim = key.shape

    assert value.shape == key.shape

    assert q_seq_len <= k_seq_len
    assert query.shape[0] == key.shape[0]
    assert query.shape[2:] == key.shape[2:]

    # TODO {

    # Dot products of every query with every key
    dot_products = torch.einsum('bkhd', 'bqhd', 'bhqk', query, key)

    # Causality: set upper diagonal part to -infty.
    lower_diag_part = torch.tril(dot_products)
    a = torch.where(lower_diag_part > 0, lower_diag_part/torch.sqrt(head_dim), float('-inf'))

    # Softmax for each query (over keys).
    p = torch.softmax(a, 3)

	  # Collect the values for each query, as sum over keys/values weighted by attention.
    out = torch.einsum('bhqk', 'bkhd', 'bqhd', p, value)

    # Reshape and apply the output projection.


    # TODO }
    assert out.shape == (batch, q_seq_len, num_heads * head_dim)
    return out

### Integration with OpenLLaMA
The code below integrates your solution from above into OpenLLaMA.

In [16]:
# Copied from https://github.com/huggingface/transformers/blob/7f95372c6267d3163fd2aa74aeff9d84ddb6cc35/src/transformers/models/llama/modeling_llama.py
def rotate_half(x: Tensor) -> Tensor:
    """Turns concat(x1, x2, dim=-1) into concat(-x2, x1, dim=-1)."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# Copied from  https://github.com/huggingface/transformers/blob/7f95372c6267d3163fd2aa74aeff9d84ddb6cc35/src/transformers/models/llama/modeling_llama.py
def apply_rotary_pos_emb(q: Tensor, k: Tensor, cos: Tensor, sin: Tensor) -> tuple[Tensor, Tensor]:
    """Applies Rotary Position Embedding (RoPE) to the query and key tensors.

    Args:
        q: The query tensor, shape (batch_size, heads, seq_len, head_dim).
        k: The key tensor, shape (batch_size, heads, seq_len, head_dim).
        cos: The cosine part of the rotary embedding, shape (batch_size, seq_len, head_dim).
        sin: The sine part of the rotary embedding, shape (batch_size, seq_len, head_dim).
    Returns: (q, k) tuple after rotation.

    """
    cos = cos.unsqueeze(1)  # Add the 'heads' dimension (1 but broadcastable to the actual number).
    sin = sin.unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


# modified version of https://github.com/huggingface/transformers/blob/7f95372c6267d3163fd2aa74aeff9d84ddb6cc35/src/transformers/models/llama/modeling_llama.py
def custom_attention_forward(
    self,
    hidden_states: Tensor,
    position_embeddings: tuple[Tensor, Tensor],
    attention_mask: Tensor | None = None,
    position_ids: Tensor | None = None,
    past_key_values=None,
    use_cache: bool = False,
    cache_position: torch.LongTensor | None = None,
    **kwargs,
):
    x = hidden_states
    bsz, q_len, _ = x.shape

    query_states = self.q_proj(x)
    key_states = self.k_proj(x)
    value_states = self.v_proj(x)

    # Split last dim into num_heads & head_dim; then swap seq_len and num_head dims.
    query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

    # Apply RoPE.
    cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    # Use KV-cache.
    if past_key_values is not None:
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_values.update(
            key_states, value_states, self.layer_idx, cache_kwargs
        )

    # Swap back num_head and seq_len dims.
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    attn_output = attention_forward(
        query=query_states,
        key=key_states,
        value=value_states,
        output_projection=self.o_proj,
    )

    return attn_output, None

### Testing
You can briefly test your solution below.

In [17]:
for l in model.model.layers:
    l.self_attn.forward = partial(custom_attention_forward, self=l.self_attn)

text = ["2 + 7 = "]

tokens_mask = tokenizer(text, return_tensors="pt")
tokens = tokens_mask["input_ids"]
attention_mask = tokens_mask["attention_mask"]

with torch.no_grad():
    output = model(input_ids=tokens.to(device))
    next_token = torch.argmax(output.logits[0, -1])
    print(next_token)
    decoded = tokenizer.decode(next_token)
    print(f"Model answer: {decoded}")
    assert decoded == "9"

ValueError: Number of einsum subscripts, 1, must be equal to the number of operands, 4.

In [None]:
## If you have implemented the attention that can handle token-by-token generation you can check your solution using the code below.

text = "Solve x + 3 = 7"
tokens_mask = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    output = model.generate(
        inputs=tokens_mask["input_ids"].to(device),
        max_new_tokens=8,
        num_beams=1,
        do_sample=False,
    )

    print(tokenizer.batch_decode(output))