In [1]:
%pip install -q transformers huggingface_hub
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers

### Build-a-transformer

In this section, you will implement a transformer language model layer by layer, then use it to generate (hopefully) coherent text.

To understand how these layers work, please check out our guide to transformers from [nlp course for you -> transformers](https://lena-voita.github.io/nlp_course/seq2seq_and_attention.html#transformer_intro).


First, we download pre-trained weights for the [GPT2 model by OpenAI](https://openai.com/research/better-language-models) - a prominent model from 2019.



Idea & code by: Ilya Beletsky

In [2]:
from huggingface_hub import hf_hub_download

# Fetch the GPT-2 checkpoint from the Hugging Face Hub and read it into a
# {parameter name: tensor} dict.
# torch.load unpickles the file, so restrict it to plain tensors/containers with
# weights_only=True; map_location keeps every tensor on CPU.
state_dict = torch.load(
    hf_hub_download("gpt2", filename="pytorch_model.bin"),
    map_location="cpu",
    weights_only=True,
)

# Iterate over a snapshot (tuple) because entries are removed from the dict inside the loop.
for key, value in tuple(state_dict.items()):
    if key.startswith('h.') and key.endswith('.weight') and value.ndim == 2:
        # nn.Linear keeps its weight as [out_features, in_features]; flip the 2-D
        # layer weights in place so they can be loaded into nn.Linear modules below.
        value.transpose_(1, 0)  # <-- for compatibility with modern PyTorch modules
    if key.startswith('h.') and key.endswith('.attn.bias') and value.ndim == 4:
        state_dict.pop(key)  # <-- triangular binary masks, not needed in this code

print('Weights:', repr(sorted(state_dict.keys()))[:320], '...')

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


pytorch_model.bin:   0%|          | 0.00/548M [00:00<?, ?B/s]

Weights: ['h.0.attn.c_attn.bias', 'h.0.attn.c_attn.weight', 'h.0.attn.c_proj.bias', 'h.0.attn.c_proj.weight', 'h.0.ln_1.bias', 'h.0.ln_1.weight', 'h.0.ln_2.bias', 'h.0.ln_2.weight', 'h.0.mlp.c_fc.bias', 'h.0.mlp.c_fc.weight', 'h.0.mlp.c_proj.bias', 'h.0.mlp.c_proj.weight', 'h.1.attn.c_attn.bias', 'h.1.attn.c_attn.weight', 'h.1. ...


In the next few cells, we shall implement the model layer by layer to make use of those weights.

As you might recall, transformers contain two main layer types: attention and fully-connected layers.

The fully connected layers are by far easier to understand, so we shall begin there:

Please implement fully-connected layer __without residual or layer normalization__ (we'll add those in a bit).

In [4]:
class GeLUThatWasUsedInGPT2(nn.Module):
    """Tanh-based approximation of the GELU activation, applied element-wise."""

    def forward(self, x):
        """Return 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))), same shape as x."""
        cubic_term = x + 0.044715 * x ** 3
        gate = 1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * cubic_term)
        return 0.5 * x * gate

class FullyConnected(nn.Module):
    """Feed-forward block: linear expansion to 4*dim, GELU, linear projection back to dim.

    Contains no residual connection and no layer normalization.
    """

    def __init__(self, dim: int):
        super().__init__()
        hidden_dim = 4 * dim
        self.c_fc = nn.Linear(dim, hidden_dim)
        self.gelu = GeLUThatWasUsedInGPT2()
        self.c_proj = nn.Linear(hidden_dim, dim)

    def forward(self, x):
        """Map x of shape [batch_size, seq_length, dim] to a tensor of the same shape."""
        hidden = self.gelu(self.c_fc(x))  # [batch_size, seq_length, 4*dim]
        return self.c_proj(hidden)        # [batch_size, seq_length, dim]


Now, let's test that it works with GPT-2 weights:

In [5]:
# Build the MLP and fill it with the MLP weights of layer 0 from the checkpoint.
mlp = FullyConnected(dim=768)
mlp_weights = {
    'c_fc.weight': state_dict['h.0.mlp.c_fc.weight'],
    'c_fc.bias': state_dict['h.0.mlp.c_fc.bias'],
    'c_proj.weight': state_dict['h.0.mlp.c_proj.weight'],
    'c_proj.bias': state_dict['h.0.mlp.c_proj.bias'],
}
mlp.load_state_dict(mlp_weights)

# Fixed seed so the random input (and hence the reference checksum) is reproducible.
torch.manual_seed(1337)
x = torch.randn(1, 2, 768)  # [batch_size, sequence_length, dim]

# Collapse the output into one scalar and compare it with the reference value.
checksum = torch.sum(mlp(x) * x)
assert abs(checksum.item() - 1282.3315) < 0.1, "layer outputs do not match reference"

# Swap the two positions on the way in and swap them back on the way out:
# the result must match the plain forward pass.
swapped_order = (1, 0)
restored_output = mlp(x[:, swapped_order, :])[:, swapped_order, :]
assert torch.allclose(restored_output, mlp(x)), "mlp must be permutation-invariant"
print("Seems legit!")

Seems legit!


Now, let's get to attention layers.

Since GPT-2 needs to generate text from left to right, each generated token can only attend to tokens on the left (and itself). This kind of attention is called "Masked" self-attention, because it hides tokens to the right.

As before, please implement masked self-attention __without layernorm or residual connections.__

In [6]:
class MaskedSelfAttention(nn.Module):
    """Multi-head masked (causal) self-attention, without layernorm or residual connection.

    Every position attends only to itself and to positions on its left. All heads are
    computed in parallel with basic tensor operations (no loops, no fused attention call).

    Args:
        dim: size of the input and output feature dimension.
        num_heads: number of attention heads; must divide ``dim`` evenly.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide ``dim``.
    """

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        # Fail early: with an uneven split the heads would not cover all `dim` channels.
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        self.c_attn = nn.Linear(dim, dim * 3)  # query + key + value, combined
        self.c_proj = nn.Linear(dim, dim)  # output projection
        self.dim, self.num_heads = dim, num_heads
        self.head_size = dim // num_heads

    def forward(self, x):
        """Apply causal multi-head self-attention.

        Args:
            x: tensor of shape [..., seq_length, dim] (typically [batch_size, seq_length, dim]).

        Returns:
            Tensor of the same shape as ``x``.
        """
        q, k, v = self.c_attn(x).split(dim=-1, split_size=self.dim)
        assert q.shape == k.shape == v.shape == x.shape, "q, k and v must have the same shape as x"

        # Split channels into heads: [..., seq, dim] -> [..., heads, seq, head_size]
        q, k, v = (
            t.unflatten(-1, (self.num_heads, self.head_size)).transpose(-3, -2)
            for t in (q, k, v)
        )

        # Scaled dot-product scores for all heads at once: [..., heads, seq, seq]
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_size)

        # is_future[i, j] is True when key position j lies to the right of query position i.
        positions = torch.arange(x.shape[-2], device=x.device)
        is_future = positions.unsqueeze(0) > positions.unsqueeze(1)
        # -inf scores become exactly zero attention weight after softmax.
        scores = scores.masked_fill(is_future, float('-inf'))

        attn_weights = torch.softmax(scores, dim=-1)
        head_outputs = torch.matmul(attn_weights, v)  # [..., heads, seq, head_size]

        # Merge heads back: [..., heads, seq, head_size] -> [..., seq, dim]
        combined_head_outputs = head_outputs.transpose(-3, -2).flatten(start_dim=-2)
        return self.c_proj(combined_head_outputs)


Test that it works

In [7]:
# Build the attention layer and fill it with the attention weights of layer 0.
attn = MaskedSelfAttention(dim=768, num_heads=12)
attn_weights = {
    'c_attn.weight': state_dict['h.0.attn.c_attn.weight'],
    'c_attn.bias': state_dict['h.0.attn.c_attn.bias'],
    'c_proj.weight': state_dict['h.0.attn.c_proj.weight'],
    'c_proj.bias': state_dict['h.0.attn.c_proj.bias'],
}
attn.load_state_dict(attn_weights)

# Fixed seed so the random input (and hence the reference checksum) is reproducible.
torch.manual_seed(1337)
x = torch.randn(1, 10, 768)  # [batch_size, sequence_length, dim]

# Collapse the output into one scalar and compare it with the reference value.
checksum = torch.sum(attn(x) * x)
assert abs(checksum.item() - 2703.6772) < 0.1, "layer outputs do not match reference"

# Feed the first two positions in swapped order and swap the outputs back; because of the
# causal mask the result must differ from feeding them in the original order.
swapped_then_restored = attn(x[:, (1, 0), :])[:, (1, 0), :]
original_order = attn(x[:, (0, 1), :])
assert not torch.allclose(swapped_then_restored, original_order), "masked attention must *not* be permutation-invariant"
print("It works!")

It works!


We can now combine attention and MLP to build the full transformer layer:

![img](https://i.imgur.com/1sq2vHO.png)

In [8]:
class TransformerLayer(nn.Module):
    """One transformer block: an attention sub-layer followed by an MLP sub-layer.

    Each sub-layer applies LayerNorm to its input first and adds its output back
    onto that input (residual connection).
    """

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.ln_1 = nn.LayerNorm(dim)
        self.attn = MaskedSelfAttention(dim, num_heads)
        self.ln_2 = nn.LayerNorm(dim)
        self.mlp = FullyConnected(dim)

    def forward(self, x):
        """Run both residual sub-layers on x and return a tensor of the same shape."""
        after_attention = x + self.attn(self.ln_1(x))
        return after_attention + self.mlp(self.ln_2(after_attention))

In [9]:
# Fill a fresh transformer layer with the checkpoint entries of layer 10.
# Their keys look like 'h.10.<parameter name>'; the prefix is stripped before loading.
layer_prefix = 'h.10.'
layer = TransformerLayer(dim=768, num_heads=12)
layer_weights = {
    name[len(layer_prefix):]: tensor
    for name, tensor in state_dict.items()
    if name.startswith(layer_prefix)
}
layer.load_state_dict(layer_weights)

# NOTE: x is not defined in this cell; it is the random input left over from the attention check.
layer_checksum = torch.sum(layer(x) * x).item()
assert abs(layer_checksum - 9874.7383) < 0.1
print("Good job!")

Good job!


In [10]:
class GPT2(nn.Module):
    """Language model: token + position embeddings, a stack of transformer layers,
    a final LayerNorm, and output logits computed with the token embedding matrix.

    Args:
        vocab_size: number of rows in the token embedding table.
        dim: embedding / hidden size.
        num_heads: number of attention heads passed to every transformer layer.
        num_layers: number of transformer layers in the stack.
        max_position_embeddings: number of rows in the position embedding table.
    """

    def __init__(self, vocab_size: int, dim: int, num_heads: int, num_layers: int, max_position_embeddings: int = 1024):
        super().__init__()
        self.wte = nn.Embedding(vocab_size, dim)  # token embeddings
        self.wpe = nn.Embedding(max_position_embeddings, dim)  # position embeddings
        self.ln_f = nn.LayerNorm(dim)  # applied after the last transformer layer, before logits

        blocks = [TransformerLayer(dim, num_heads) for _ in range(num_layers)]
        self.h = nn.Sequential(*blocks)

    def forward(self, input_ids):
        """Compute next-token logits.

        Args:
            input_ids: int64 token ids of shape [batch_size, sequence_length].

        Returns:
            Logits of shape [batch_size, sequence_length, vocab_size].
        """
        seq_length = input_ids.shape[1]
        position_ids = torch.arange(seq_length, device=input_ids.device)[None, :]

        # Position embeddings have a leading dimension of 1 and broadcast over the batch.
        hidden = self.wte(input_ids) + self.wpe(position_ids)
        hidden = self.ln_f(self.h(hidden))

        # The token embedding matrix doubles as the weight of the output layer.
        return torch.matmul(hidden, self.wte.weight.T)


In [11]:
# Build the full model and load the complete checkpoint into it.
tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2', add_prefix_space=True)
model = GPT2(vocab_size=50257, dim=768, num_heads=12, num_layers=12)
model.load_state_dict(state_dict)

input_ids = tokenizer("A quick", return_tensors='pt')['input_ids']

# Inference only: run the forward pass without autograd bookkeeping,
# the same way the sampling loop further down calls the model.
with torch.no_grad():
    predicted_logits = model(input_ids)

# Greedy choice: the highest-scoring token at the last position.
most_likely_token_id = predicted_logits[:, -1].argmax().item()

print("Prediction:", tokenizer.decode(most_likely_token_id))

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Prediction:  look


In [None]:
text = "The Fermi paradox "
tokens = tokenizer.encode(text)

# Print the prompt as the tokenizer decodes it and start the line-width counter from it.
prompt_echo = tokenizer.decode(tokens)
print(end=prompt_echo)
line_length = len(prompt_echo)

for step in range(500):
    # Predict logits with your model (the whole sequence so far is re-encoded every step)
    with torch.no_grad():
        logits = model(torch.as_tensor([tokens]))

    # Sample with probabilities: softmax over the logits of the last position
    p_next = torch.softmax(logits[0, -1, :], dim=-1).data.cpu().numpy()
    next_token_index = np.random.choice(len(p_next), p=p_next)

    tokens.append(int(next_token_index))

    # Print the new token right away; break the line once it grows past 120 characters.
    new_piece = tokenizer.decode(tokens[-1])
    print(end=new_piece)
    line_length += len(new_piece)
    if line_length > 120:
        line_length = 0
        print()



 The Fermi paradox ields a shock wave story where the tails are falling in gradual degrees to form lower vibrations in the
 lower layers of intuitable cruise instruments. This is well enough to explain the cessation of a long rode," he explained
. But before I had a chance to read the explanation I had to steer away from sail lamps. What I expected is that some drivers
 would be objects of local attraction and make a vow not to attempt the bank. But let's be on the safe side and not want passengers
 conforming on the boardwalk. Once radio sweeps ready, I recall some sign display from school children, calling for that small
 point on the bright yellow uniform, where spendthrift, slick-smelling people (including college students, who veer into storage
 during their nightly grueling every--day's activities without telling anyone) go to divert their attention toward debris
-laden retrieval when prompted. How french has anyone worked up a plan to hide his or her tongue? I imagine it arrange

__Reminder:__ after class, please go to `MaskedSelfAttention.forward` above and finish the job!
```

```

```

```

```

```

```

```

```

```

```

```

```

```

```

```

```

```

```

```

```

```


### Here's how you can do the same with the `transformers` library

In [None]:
tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2', add_prefix_space=True)

# Bound to its own name on purpose: `model` keeps pointing at the hand-written GPT2 built
# earlier, so the sampling loop above can still be re-run after this cell has executed.
hf_model = transformers.AutoModelForCausalLM.from_pretrained('gpt2')

# Sample up to 50 new tokens after the prompt and decode the whole sequence back to text.
print('Generated text:', tokenizer.decode(
    hf_model.generate(
        **tokenizer("The Fermi paradox ", return_tensors='pt'),
        do_sample=True, max_new_tokens=50
    ).flatten().numpy()
))


In [None]:
# BONUS: Vectorized MaskedSelfAttention without loops (for homework)
class MaskedSelfAttentionVectorized(nn.Module):
    """Multi-head masked self-attention computed for all heads at once, without loops.

    Args:
        dim: size of the input and output feature dimension.
        num_heads: number of attention heads; each head gets ``dim // num_heads`` channels.
    """

    def __init__(self, dim: int, num_heads: int):
        super().__init__()
        self.c_attn = nn.Linear(dim, dim * 3)  # query, key and value in one projection
        self.c_proj = nn.Linear(dim, dim)      # output projection
        self.dim, self.num_heads = dim, num_heads
        self.head_size = dim // num_heads

    def forward(self, x):
        """Map x of shape [batch_size, seq_length, dim] to a tensor of the same shape."""
        batch_size, seq_length, dim = x.shape
        per_head_shape = (batch_size, seq_length, self.num_heads, self.head_size)

        def to_heads(t):
            # [batch_size, seq_length, dim] -> [batch_size, num_heads, seq_length, head_size]
            return t.view(*per_head_shape).transpose(1, 2)

        q, k, v = map(to_heads, self.c_attn(x).split(dim=-1, split_size=self.dim))

        # Q @ K^T / sqrt(head_size): [batch_size, num_heads, seq_length, seq_length]
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_size)

        # True strictly above the diagonal, i.e. for keys to the right of the query.
        # Those scores are set to -inf so that softmax gives them zero weight.
        hidden_positions = torch.ones(seq_length, seq_length, device=x.device).triu(diagonal=1).bool()
        attn_weights = F.softmax(scores.masked_fill(hidden_positions, float('-inf')), dim=-1)

        # Weighted sum of values: [batch_size, num_heads, seq_length, head_size]
        per_head_output = torch.matmul(attn_weights, v)

        # Put heads next to each other again: -> [batch_size, seq_length, dim]
        merged = per_head_output.transpose(1, 2).contiguous().view(batch_size, seq_length, dim)
        return self.c_proj(merged)