# GPT

**GPT (Generative Pre-trained Transformer)** is a *decoder-only* Transformer model trained to generate text *autoregressively*, meaning it predicts the next token in a sequence given all previous tokens.

### Differences from the Standard Transformer Decoder
Architecturally, GPT is almost identical to the decoder block from the original Transformer, with one key simplification:
- GPT **removes the encoder-decoder attention layer**, since there is no encoder providing context.
- It retains **masked self-attention**, ensuring each token can only attend to past tokens (no future information).
- The feed-forward network, residual connections, and layer normalization are unchanged.

In summary:
> GPT = Transformer Decoder – (Encoder-Decoder Attention)

### Training Objective
GPT is trained as a **language model** using the *autoregressive* objective:
$$
P(x_1, x_2, ..., x_T) = \prod_{t=1}^{T} P(x_t | x_1, ..., x_{t-1})
$$
At each step, the model receives a sequence of tokens and learns to predict the next one.
Because there is no encoder, all contextual understanding is built from the left-to-right accumulation of information within the masked self-attention layers.


![GPT](img/gpt.png)


In [72]:
from dataclasses import dataclass
from typing import Tuple, List, Optional, Dict
from jaxtyping import Float, Int
import torch
from torch import Tensor
import torch.nn as nn
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformer_lens import HookedTransformer
import einops
import numpy as np
import circuitsvis as cv
from IPython.display import display, HTML

In [2]:
device = torch.device(
    "mps"
    if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available() else "cpu"
)
gpt2 = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(device)
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
hooked_gpt2 = HookedTransformer.from_pretrained(
    "gpt2-small", fold_ln=False, center_unembed=False, center_writing_weights=False
)

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


Loaded pretrained model gpt2-small into HookedTransformer


In [3]:
device

device(type='cuda')

In [4]:
text = "The raccoon sat on the mat."
token_ids = tokenizer.encode(text)
print(f"Token (ids): {token_ids}")
print(f"Tokens (string): {tokenizer.tokenize(text)}")
print(f"Text string: {tokenizer.decode(token_ids, skip_special_tokens=False)}")
print(f"Stuff to input a model: {tokenizer(text, return_tensors='pt')}")

Token (ids): [464, 3444, 20912, 3332, 319, 262, 2603, 13]
Tokens (string): ['The', 'Ġrac', 'coon', 'Ġsat', 'Ġon', 'Ġthe', 'Ġmat', '.']
Text string: The raccoon sat on the mat.
Stuff to input a model: {'input_ids': tensor([[  464,  3444, 20912,  3332,   319,   262,  2603,    13]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}


In [5]:
text = "Once upon a"
input_ids = tokenizer(text, return_tensors="pt").to(device)
with torch.inference_mode():
    output_logits = gpt2(**input_ids)["logits"]
print(f"Logits: {output_logits}")
print(f"Logits shape: {output_logits.shape}")

Logits: tensor([[[ -34.5645,  -34.4081,  -38.3079,  ...,  -41.6996,  -39.7801,
           -35.0521],
         [ -84.7256,  -82.9327,  -87.0166,  ...,  -91.6668,  -86.2355,
           -84.7095],
         [-109.0798, -105.7259, -109.9116,  ..., -114.2847, -107.6934,
          -105.3613]]], device='cuda:0')
Logits shape: torch.Size([1, 3, 50257])


In [6]:
output_probas = output_logits.softmax(dim=-1)
print(f"Probabilites over vocabulary: {output_probas}")

Probabilites over vocabulary: tensor([[[8.9157e-04, 1.0424e-03, 2.1106e-05,  ..., 7.1022e-07,
          4.8419e-06, 5.4752e-04],
         [3.2892e-06, 1.9759e-05, 3.3276e-07,  ..., 3.1810e-09,
          7.2668e-07, 3.3428e-06],
         [4.0889e-07, 1.1700e-05, 1.7799e-07,  ..., 2.2447e-09,
          1.6359e-06, 1.6848e-05]]], device='cuda:0')


In [7]:
most_likely_next_tokens = tokenizer.batch_decode(output_logits.argmax(dim=-1)[0])
print(list(zip(tokenizer.tokenize(text), most_likely_next_tokens)))

[('Once', ' the'), ('Ġupon', ' a'), ('Ġa', ' time')]


In [8]:
next_token = output_logits[0, -1].argmax(dim=-1)
next_char = tokenizer.decode(next_token)
print(
    "The next token is:", repr(next_char)
)  # repr is to show special tokens and spaces
print("How the sentence becomes: ", text + next_char)

The next token is: ' time'
How the sentence becomes:  Once upon a time


In [9]:
# Initialize text
text = "Once upon a"
# Convert text to tensor format
tokens = tokenizer(text, return_tensors="pt").to(device)
print("Generating text...\n")
# Generate 10 characters iteratively
for i in range(10):
    with torch.inference_mode():
        # Get model predictions
        output_logits = gpt2(**tokens).logits
        # Select the most likely next token
        next_token = output_logits[0, -1].argmax(dim=-1)
        # Decode the token to a character
        next_char = tokenizer.decode(next_token)
    # Display the sequence so far
    current_text = tokenizer.decode(tokens["input_ids"][0])  # Reconstruct the string
    print(f"Generation step {i + 1}:")
    print(f"Sequence so far: {current_text!r}")
    print(f"{tokens['input_ids'].shape[-1] + 1}th char = {next_char!r}\n")
    # Append the new character and re-tokenize
    text += next_char
    tokens = tokenizer(text, return_tensors="pt").to(device)
print("Final text:", text)

Generating text...

Generation step 1:
Sequence so far: 'Once upon a'
4th char = ' time'

Generation step 2:
Sequence so far: 'Once upon a time'
5th char = ','

Generation step 3:
Sequence so far: 'Once upon a time,'
6th char = ' the'

Generation step 4:
Sequence so far: 'Once upon a time, the'
7th char = ' world'

Generation step 5:
Sequence so far: 'Once upon a time, the world'
8th char = ' was'

Generation step 6:
Sequence so far: 'Once upon a time, the world was'
9th char = ' a'

Generation step 7:
Sequence so far: 'Once upon a time, the world was a'
10th char = ' place'

Generation step 8:
Sequence so far: 'Once upon a time, the world was a place'
11th char = ' of'

Generation step 9:
Sequence so far: 'Once upon a time, the world was a place of'
12th char = ' great'

Generation step 10:
Sequence so far: 'Once upon a time, the world was a place of great'
13th char = ' beauty'

Final text: Once upon a time, the world was a place of great beauty


In [10]:
reference_text = "Once upon a time, there was a fox who lived in a forest."
tokens = hooked_gpt2.to_tokens(reference_text).to(device)
logits, cache = hooked_gpt2.run_with_cache(tokens)
html = cv.attention.attention_pattern(
    tokens=hooked_gpt2.to_str_tokens(reference_text),
    attention=cache["pattern", 3][0][7],
)
styled_html = f"""
<div style="width:800px; font-size:16px;">
    {html}
</div>
"""

display(HTML(styled_html))

In [30]:
print(gpt2)
print(gpt2.config)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)
GPT2Config {
  "_attn_implementation_autoset": true,


In [31]:
def rand_float_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg).to(device)
    random_input = torch.randn(shape).to(device)
    print("Input shape:", random_input.shape)
    output = layer(random_input)
    if isinstance(output, tuple):
        output = output[0]
    print("Output:", output)
    print("Output shape:", output.shape, "\n")


def rand_int_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg).to(device)
    random_input = torch.randint(100, 1000, shape).to(device)
    print("Input shape:", random_input.shape)
    output = layer(random_input)
    if isinstance(output, tuple):
        output = output[0]
    print("Output:", output)
    print("Output shape:", output.shape, "\n")

In [32]:
sequence = "Once upon a time, "
tokenized_sequence = tokenizer.tokenize(sequence)
tokens = tokenizer(sequence, return_tensors="pt").to(device)["input_ids"]
print("Tokenized sequence:", tokenized_sequence)
print("Token IDs:", tokens)

Tokenized sequence: ['Once', 'Ġupon', 'Ġa', 'Ġtime', ',', 'Ġ']
Token IDs: tensor([[7454, 2402,  257,  640,   11,  220]], device='cuda:0')


In [33]:
batch = 1  # starting with only one batch (thus 1 sentence)
seq_len = len(tokenized_sequence)  # 6


@dataclass
class Config:
    n_ctx: int = gpt2.config.n_ctx  # 1024
    d_model: int = gpt2.config.n_embd  # hidden size, or embedding dimension
    n_heads: int = gpt2.config.n_head  # number of attention heads
    n_layers: int = gpt2.config.n_layer  # number of transformer blocks
    d_mlp: int = 4 * d_model  # MLP hidden size, 3072
    d_head: int = d_model // n_heads  # dimension of each attention head, 64
    layer_norm_eps: float = gpt2.config.layer_norm_epsilon  # layer norm epsilon
    d_vocab: int = gpt2.config.vocab_size  # number of tokens in the vocabulary
    init_range: float = (
        gpt2.config.initializer_range
    )  # initialization range for weights
    debug: bool = True


cfg = Config()
print(cfg)

Config(n_ctx=1024, d_model=768, n_heads=12, n_layers=12, d_mlp=3072, d_head=64, layer_norm_eps=1e-05, d_vocab=50257, init_range=0.02, debug=True)


In [34]:
class Embed(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_E = nn.Parameter(torch.empty((cfg.d_vocab, cfg.d_model)))
        nn.init.normal_(self.W_E, std=cfg.init_range)

    def forward(
            self, int_tokens: Int[Tensor, "batch seq_len"]
    ) -> Float[Tensor, "batch seq_len d_model"]:
        # just a mapping from int tokens to float vectors
        return self.W_E[int_tokens]

rand_int_test(Embed, [batch, seq_len])

Input shape: torch.Size([1, 6])
Output: tensor([[[ 0.0038, -0.0055, -0.0157,  ...,  0.0031,  0.0222, -0.0032],
         [ 0.0239,  0.0089,  0.0070,  ..., -0.0113, -0.0085, -0.0087],
         [-0.0225, -0.0015, -0.0037,  ...,  0.0027, -0.0052,  0.0158],
         [ 0.0408, -0.0259,  0.0038,  ..., -0.0022, -0.0025,  0.0281],
         [-0.0519,  0.0172, -0.0119,  ..., -0.0072,  0.0173, -0.0104],
         [ 0.0233, -0.0210,  0.0044,  ..., -0.0187,  0.0101,  0.0057]]],
       device='cuda:0', grad_fn=<IndexBackward0>)
Output shape: torch.Size([1, 6, 768]) 



In [35]:
class PosEmbed(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_pos = nn.Parameter(torch.empty((cfg.n_ctx, cfg.d_model)))
        nn.init.normal_(self.W_pos, std=cfg.init_range)

    def forward(
            self, int_tokens: Int[Tensor, "batch seq_len"]
    ) -> Float[Tensor, "batch seq_len d_model"]:
        # take first seq_len learnt positional embeddings
        batch, seq_len = int_tokens.shape
        return einops.repeat(
            self.W_pos[:seq_len],
            "seq_len d_model -> batch seq_len d_model",
            batch=batch  # 1
        )

rand_int_test(PosEmbed, [batch, seq_len])

Input shape: torch.Size([1, 6])
Output: tensor([[[-0.0191, -0.0203,  0.0417,  ...,  0.0282, -0.0032,  0.0003],
         [ 0.0268,  0.0016, -0.0034,  ..., -0.0117, -0.0140,  0.0083],
         [-0.0132, -0.0190, -0.0066,  ...,  0.0175,  0.0058,  0.0253],
         [-0.0289,  0.0164, -0.0315,  ..., -0.0021,  0.0537,  0.0199],
         [-0.0102, -0.0250,  0.0203,  ...,  0.0277,  0.0105, -0.0184],
         [ 0.0227, -0.0239,  0.0294,  ..., -0.0652,  0.0156, -0.0050]]],
       device='cuda:0', grad_fn=<ExpandBackward0>)
Output shape: torch.Size([1, 6, 768]) 



In [36]:
class LayerNorm(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(torch.empty(cfg.d_model))
        self.b = nn.Parameter(torch.zeros(cfg.d_model))  # Bias!
        nn.init.normal_(self.w, std=cfg.init_range)

    def forward(
            self, embedding: Float[Tensor, "batch seq_len d_model"]
    ) -> Float[Tensor, "batch seq_len d_model"]:
        # compute mean
        embedding_mean = embedding.mean(dim=-1, keepdim=True)
        # compute standard deviation + eps
        embedding_std = (embedding.var(dim=-1, keepdim=True, unbiased=False) + self.cfg.layer_norm_eps).sqrt()
        # compute normalized embedding
        embedding = (embedding - embedding_mean) / embedding_std
        return embedding * self.w + self.b


rand_float_test(LayerNorm, [batch, seq_len, cfg.d_model])

Input shape: torch.Size([1, 6, 768])
Output: tensor([[[ 1.3676e-02,  1.9080e-02, -2.3117e-03,  ..., -5.4136e-03,
          -2.7389e-03,  3.5097e-02],
         [ 1.3717e-02, -3.8329e-05,  1.7807e-03,  ...,  1.5169e-03,
           1.9697e-02,  2.9513e-02],
         [ 1.8159e-02, -1.1157e-03, -1.0391e-02,  ...,  7.9021e-03,
           2.1161e-03,  1.8344e-02],
         [ 4.7796e-02, -6.9015e-03,  2.4161e-03,  ...,  2.5075e-02,
           8.2112e-03,  1.5835e-02],
         [ 1.4386e-02, -1.1134e-02,  5.8101e-03,  ...,  1.9905e-02,
           1.6334e-03, -2.5775e-02],
         [ 4.9066e-02,  3.0447e-02,  6.4585e-03,  ...,  6.0914e-04,
           7.8220e-03, -1.8753e-02]]], device='cuda:0', grad_fn=<AddBackward0>)
Output shape: torch.Size([1, 6, 768]) 



In [37]:
class Attention(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_K = nn.Parameter(torch.empty(cfg.n_heads, cfg.d_model, cfg.d_head))
        self.b_K = nn.Parameter(torch.zeros(cfg.n_heads, cfg.d_head))
        self.W_Q = nn.Parameter(torch.empty(cfg.n_heads, cfg.d_model, cfg.d_head))
        self.b_Q = nn.Parameter(torch.zeros(cfg.n_heads, cfg.d_head))
        self.W_V = nn.Parameter(torch.empty(cfg.n_heads, cfg.d_model, cfg.d_head))
        self.b_V = nn.Parameter(torch.zeros(cfg.n_heads, cfg.d_head))
        self.W_O = nn.Parameter(torch.empty(cfg.n_heads, cfg.d_head, cfg.d_model))
        self.b_O = nn.Parameter(torch.zeros(cfg.d_model))

        nn.init.normal_(self.W_K, std=cfg.init_range)
        nn.init.normal_(self.W_Q, std=cfg.init_range)
        nn.init.normal_(self.W_V, std=cfg.init_range)
        nn.init.normal_(self.W_O, std=cfg.init_range)
        self.mask = -1e4

    def forward(
            self, embedding: Float[Tensor, "batch seq_len d_model"]
    ) -> Float[Tensor, "batch seq_len d_model"]:
        batch, seq_len, _ = embedding.shape
        device = embedding.device
        # compute K, Q, V projections
        K = einops.einsum(
            embedding,
            self.W_K,
            "batch seq_len d_model, n_heads d_model d_head -> batch seq_len n_heads d_head"
        ) + self.b_K
        Q = einops.einsum(
            embedding,
            self.W_Q,
            "batch seq_len d_model, n_heads d_model d_head -> batch seq_len n_heads d_head"
        ) + self.b_Q
        V = einops.einsum(
            embedding,
            self.W_V,
            "batch seq_len d_model, n_heads d_model d_head -> batch seq_len n_heads d_head"
        ) + self.b_V
        # compute attention scores
        attention_scores = einops.einsum(
            Q,
            K,
            "batch dest_pos n_heads d_head, batch source_pos n_heads d_head -> batch n_heads dest_pos source_pos"
        )
        # scale and mask attention scores (causal attention)
        attention_scores = attention_scores / (self.cfg.d_head ** 0.5)
        mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool()
        attention_scores = attention_scores.masked_fill(mask, self.mask) # Masked positions should be close to 0, not actually 0
        # softmax attention scores
        attention_scores = attention_scores.softmax(dim=-1)
        # compute weighted sum of values
        z = einops.einsum(
            V,
            attention_scores,
            "batch seq_len n_heads d_head, batch n_heads dest_pos source_pos -> batch dest_pos n_heads d_head"
        )
        # compute output projection
        attn_output = einops.einsum(
            z,
            self.W_O,
            "batch seq_len n_heads d_head, n_heads d_head d_model -> batch seq_len d_model"
        ) + self.b_O
        return attn_output

rand_float_test(Attention, [batch, seq_len, cfg.d_model])

Input shape: torch.Size([1, 6, 768])
Output: tensor([[[ 0.7654,  0.3072, -0.0493,  ..., -0.1509, -1.6303,  0.0622],
         [ 0.7654,  0.3072, -0.0493,  ..., -0.1509, -1.6303,  0.0622],
         [ 0.7654,  0.3072, -0.0493,  ..., -0.1509, -1.6303,  0.0622],
         [ 0.7654,  0.3072, -0.0493,  ..., -0.1509, -1.6303,  0.0622],
         [ 0.7654,  0.3072, -0.0493,  ..., -0.1509, -1.6303,  0.0622],
         [ 0.7654,  0.3072, -0.0493,  ..., -0.1509, -1.6303,  0.0622]]],
       device='cuda:0', grad_fn=<AddBackward0>)
Output shape: torch.Size([1, 6, 768]) 



In [38]:
def gelu_new(
    input: Float[torch.Tensor, "batch pos d_mlp"],
) -> Float[torch.Tensor, "batch pos d_mlp"]:
    # Implementation of GeLU used by GPT2 - subtly different from PyTorch's
    return (
        0.5
        * input
        * (
            1.0
            + torch.tanh(
                np.sqrt(2.0 / np.pi) * (input + 0.044715 * torch.pow(input, 3.0))
            )
        )
    )


class MLP(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_in = nn.Parameter(torch.empty(cfg.d_model, cfg.d_mlp))
        self.b_in = nn.Parameter(torch.zeros(cfg.d_mlp))
        self.W_out = nn.Parameter(torch.empty(cfg.d_mlp, cfg.d_model))
        self.b_out = nn.Parameter(torch.zeros(cfg.d_model))
        nn.init.normal_(self.W_in, std=cfg.init_range)
        nn.init.normal_(self.W_out, std=cfg.init_range)
        self.activation = gelu_new

    def forward(
        self, embedding: Float[Tensor, "batch seq_len d_model"]
    ) -> Float[Tensor, "batch seq_len d_model"]:
        # compute in projection
        pre = einops.einsum(
            embedding,
            self.W_in,
            "batch seq_len d_model, d_model d_mlp -> batch seq_len d_mlp"
        ) + self.b_in
        # apply activation
        post = self.activation(pre)
        # compute out projection
        mlp_out = einops.einsum(
            post,
            self.W_out,
            "batch seq_len d_mlp, d_mlp d_model -> batch seq_len d_model"
        ) + self.b_out
        return mlp_out

rand_float_test(MLP, [batch, seq_len, cfg.d_model])

Input shape: torch.Size([1, 6, 768])
Output: tensor([[[-0.0502,  0.1272, -0.2881,  ...,  0.1292, -0.2617,  0.6294],
         [-0.0928, -0.1346,  0.1020,  ..., -0.5522, -0.0493,  0.0119],
         [ 0.6985,  0.4018, -0.5010,  ..., -0.0825, -0.3101,  0.0496],
         [ 0.2202,  0.3617, -0.5777,  ...,  0.2528, -0.0769, -0.9744],
         [-0.2822,  0.0543,  0.3357,  ...,  0.0627, -0.1648,  0.5498],
         [ 0.4221,  0.4063, -0.0057,  ..., -0.1639, -0.3572,  0.6323]]],
       device='cuda:0', grad_fn=<AddBackward0>)
Output shape: torch.Size([1, 6, 768]) 



In [39]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)

    def forward(
        self, input_embedding: Float[Tensor, "batch seq_len d_model"]
    ) -> Float[Tensor, "batch seq_len d_model"]:
        # normalize input
        embedding = self.ln1(input_embedding)
        # compute attention and add skip connection
        mid_embedding = self.attn(embedding) + input_embedding
        # normalize embedding
        embedding = self.ln2(mid_embedding)
        # compute MLP and add skip connection
        output_embedding = self.mlp(embedding) + mid_embedding
        # return output
        return output_embedding

rand_float_test(TransformerBlock, [batch, seq_len, cfg.d_model])

Input shape: torch.Size([1, 6, 768])
Output: tensor([[[ 1.9948, -1.0613, -0.2920,  ...,  0.1536,  0.4180,  0.0734],
         [ 0.6076,  1.2586, -0.9310,  ..., -0.2288,  2.6938, -0.3181],
         [ 0.2172, -0.1263, -1.1064,  ..., -0.4843,  0.1577,  0.5916],
         [ 0.7809, -0.3711, -1.4346,  ...,  2.3915,  0.8297,  0.5332],
         [-0.7466, -0.6414,  0.1979,  ...,  0.7472,  0.7787,  0.6182],
         [-0.3397,  0.2590, -0.6169,  ..., -0.0495, -1.0076,  0.5021]]],
       device='cuda:0', grad_fn=<AddBackward0>)
Output shape: torch.Size([1, 6, 768]) 



In [40]:
class Unembed(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_U = nn.Parameter(torch.empty(cfg.d_model, cfg.d_vocab))
        self.b_U = nn.Parameter(torch.zeros(cfg.d_vocab))
        nn.init.normal_(self.W_U, std=cfg.init_range)

    def forward(
        self, embedding: Float[Tensor, "batch seq_len d_model"]
    ) -> Float[Tensor, "batch seq_len d_vocab"]:
        # compute logits
        logits = einops.einsum(
            embedding,
            self.W_U,
            "batch seq_len d_model, d_model d_vocab -> batch seq_len d_vocab"
        ) + self.b_U
        return logits

rand_float_test(Unembed, [batch, seq_len, cfg.d_model])

Input shape: torch.Size([1, 6, 768])
Output: tensor([[[ 0.5838,  0.6147,  0.7070,  ..., -0.2241, -0.0592,  0.4205],
         [-1.2668,  1.7735, -0.1435,  ...,  0.6636,  0.7671, -0.4350],
         [ 0.2913, -1.0453, -0.6848,  ..., -0.1880,  0.4448, -0.1822],
         [-0.1750, -0.2846, -0.5337,  ...,  0.1832,  0.0181, -0.2026],
         [ 0.6756, -0.8161,  0.8145,  ...,  0.0566,  0.7791, -0.4561],
         [ 0.2309,  0.6905,  0.0889,  ..., -0.4160, -0.3475,  0.3143]]],
       device='cuda:0', grad_fn=<AddBackward0>)
Output shape: torch.Size([1, 6, 50257]) 



In [61]:
class GPT(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.embed = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        self.blocks = nn.ModuleList(
            [TransformerBlock(cfg) for _ in range(cfg.n_layers)]
        )
        self.ln_final = LayerNorm(cfg)
        self.unembed = Unembed(cfg)

    def forward(
        self, input_tokens: Int[Tensor, "batch seq_len"]
    ) -> Float[Tensor, "batch seq_len d_vocab"]:
        # compute embeddings + positional embeddings
        embedding = self.embed(input_tokens) + self.pos_embed(input_tokens)
        # compute transformer blocks outputs
        for block in self.blocks:
            embedding = block(embedding)
        # normalize output
        # compute logits
        logits = self.unembed(self.ln_final(embedding))
        return logits

    def load_gpt2_weights(self, gpt2: GPT2LMHeadModel) -> None:
        state_dict = {}

        state_dict["embed.W_E"] = gpt2.transformer.wte.weight
        state_dict["pos_embed.W_pos"] = gpt2.transformer.wpe.weight

        for l in range(cfg.n_layers):
            state_dict[f"blocks.{l}.ln1.w"] = gpt2.transformer.h[l].ln_1.weight
            state_dict[f"blocks.{l}.ln1.b"] = gpt2.transformer.h[l].ln_1.bias

            # In GPT-2, q,k,v are produced by one big linear map, whose output is
            # concat([q, k, v])
            W = gpt2.transformer.h[l].attn.c_attn.weight
            W_Q, W_K, W_V = torch.tensor_split(W, 3, dim=1)
            W_Q = einops.rearrange(W_Q, "m (i h)->i m h", i=cfg.n_heads)
            W_K = einops.rearrange(W_K, "m (i h)->i m h", i=cfg.n_heads)
            W_V = einops.rearrange(W_V, "m (i h)->i m h", i=cfg.n_heads)

            state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
            state_dict[f"blocks.{l}.attn.W_K"] = W_K
            state_dict[f"blocks.{l}.attn.W_V"] = W_V

            qkv_bias = gpt2.transformer.h[l].attn.c_attn.bias
            qkv_bias = einops.rearrange(
                qkv_bias,
                "(qkv index head)->qkv index head",
                qkv=3,
                index=cfg.n_heads,
                head=cfg.d_head,
            )
            state_dict[f"blocks.{l}.attn.b_Q"] = qkv_bias[0]
            state_dict[f"blocks.{l}.attn.b_K"] = qkv_bias[1]
            state_dict[f"blocks.{l}.attn.b_V"] = qkv_bias[2]

            W_O = gpt2.transformer.h[l].attn.c_proj.weight
            W_O = einops.rearrange(W_O, "(i h) m->i h m", i=cfg.n_heads)
            state_dict[f"blocks.{l}.attn.W_O"] = W_O
            state_dict[f"blocks.{l}.attn.b_O"] = gpt2.transformer.h[l].attn.c_proj.bias

            state_dict[f"blocks.{l}.ln2.w"] = gpt2.transformer.h[l].ln_2.weight
            state_dict[f"blocks.{l}.ln2.b"] = gpt2.transformer.h[l].ln_2.bias

            W_in = gpt2.transformer.h[l].mlp.c_fc.weight
            state_dict[f"blocks.{l}.mlp.W_in"] = W_in
            state_dict[f"blocks.{l}.mlp.b_in"] = gpt2.transformer.h[l].mlp.c_fc.bias

            W_out = gpt2.transformer.h[l].mlp.c_proj.weight
            state_dict[f"blocks.{l}.mlp.W_out"] = W_out
            state_dict[f"blocks.{l}.mlp.b_out"] = gpt2.transformer.h[l].mlp.c_proj.bias
        state_dict["unembed.W_U"] = gpt2.lm_head.weight.T

        # --- Fix: assicurarsi che il bias dell'unembed esista nello state_dict ---
        if getattr(gpt2.lm_head, "bias", None) is not None:
            state_dict["unembed.b_U"] = gpt2.lm_head.bias.to(self.unembed.b_U.device)
        else:
            state_dict["unembed.b_U"] = torch.zeros(
                self.unembed.b_U.shape,
                dtype=self.unembed.b_U.dtype,
                device=self.unembed.b_U.device,
            )

        state_dict["ln_final.w"] = gpt2.transformer.ln_f.weight
        state_dict["ln_final.b"] = gpt2.transformer.ln_f.bias
        self.load_state_dict(state_dict)

In [62]:
rand_int_test(GPT, [batch, seq_len])

Input shape: torch.Size([1, 6])
Output: tensor([[[-0.0068, -0.0099, -0.0056,  ..., -0.0149, -0.0012,  0.0143],
         [-0.0130, -0.0090, -0.0041,  ..., -0.0160, -0.0105,  0.0134],
         [-0.0067, -0.0079, -0.0072,  ..., -0.0215, -0.0041,  0.0113],
         [-0.0089, -0.0034, -0.0052,  ..., -0.0209, -0.0040,  0.0203],
         [-0.0058, -0.0114, -0.0069,  ..., -0.0173, -0.0086,  0.0145],
         [-0.0156, -0.0124, -0.0033,  ..., -0.0148, -0.0093,  0.0163]]],
       device='cuda:0', grad_fn=<AddBackward0>)
Output shape: torch.Size([1, 6, 50257]) 



In [63]:
demo_gpt2 = GPT(Config(debug=False)).to(device)
demo_gpt2.load_gpt2_weights(gpt2)
demo_gpt2.eval()
#demo_gpt2.load_state_dict(hooked_gpt2.state_dict(), strict=False)

GPT(
  (embed): Embed()
  (pos_embed): PosEmbed()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNorm()
      (attn): Attention()
      (ln2): LayerNorm()
      (mlp): MLP()
    )
  )
  (ln_final): LayerNorm()
  (unembed): Unembed()
)

In [65]:
# Initialize text
text = "Once upon a"
# Convert text to tensor format
tokens = tokenizer(text, return_tensors="pt").to(device)
# Use token ids directly to avoid re-tokenizing the decoded string each step
input_ids = tokens["input_ids"]  # shape (1, seq_len)
print("Generating text...\n")
# Generate 20 tokens iteratively
for i in range(20):
    with torch.inference_mode():
        # Get model predictions for current sequence
        output_logits = demo_gpt2(input_ids)
        # Select the most likely next token id (scalar)
        next_token_id = output_logits[0, -1].argmax(dim=-1)
        # ensure shape (1,1) for concatenation
        next_token_id = next_token_id.unsqueeze(0).unsqueeze(0).to(input_ids.device)
    # Display the sequence so far (decode current token ids)
    current_text = tokenizer.decode(input_ids[0].tolist())
    print(f"Generation step {i+1}:")
    print(f"Sequence so far: {current_text!r}")
    print(f"{input_ids.shape[-1]+1}th token id = {int(next_token_id.item())!r}\n")
    # Append the new token id (autoregressive generation)
    input_ids = torch.cat([input_ids, next_token_id], dim=1)
# final decode
final_text = tokenizer.decode(input_ids[0].tolist())
print("Final text:", final_text)

Generating text...

Generation step 1:
Sequence so far: 'Once upon a'
4th token id = 290

Generation step 2:
Sequence so far: 'Once upon a and'
5th token id = 290

Generation step 3:
Sequence so far: 'Once upon a and and'
6th token id = 290

Generation step 4:
Sequence so far: 'Once upon a and and and'
7th token id = 290

Generation step 5:
Sequence so far: 'Once upon a and and and and'
8th token id = 290

Generation step 6:
Sequence so far: 'Once upon a and and and and and'
9th token id = 290

Generation step 7:
Sequence so far: 'Once upon a and and and and and and'
10th token id = 290

Generation step 8:
Sequence so far: 'Once upon a and and and and and and and'
11th token id = 3503

Generation step 9:
Sequence so far: 'Once upon a and and and and and and and etc'
12th token id = 3503

Generation step 10:
Sequence so far: 'Once upon a and and and and and and and etc etc'
13th token id = 3503

Generation step 11:
Sequence so far: 'Once upon a and and and and and and and etc etc etc'
1

In [66]:
# Assicurati che tutto sia su device e in modalità eval
print("device models:", device)
print("demo_gpt2 device:", next(demo_gpt2.parameters()).device)
print("gpt2 device:", next(gpt2.parameters()).device)
print("demo_gpt2 eval?", not demo_gpt2.training)
print("gpt2 eval?", not gpt2.training)
# controlla shape di alcuni parametri chiave
print("embed.W_E shape (demo):", demo_gpt2.embed.W_E.shape)
print("wte weight shape (hf):", gpt2.transformer.wte.weight.shape)
print("unembed.W_U shape (demo):", demo_gpt2.unembed.W_U.shape)
print("lm_head weight shape (hf):", gpt2.lm_head.weight.shape)

device models: cuda
demo_gpt2 device: cuda:0
gpt2 device: cuda:0
demo_gpt2 eval? True
gpt2 eval? True
embed.W_E shape (demo): torch.Size([50257, 768])
wte weight shape (hf): torch.Size([50257, 768])
unembed.W_U shape (demo): torch.Size([768, 50257])
lm_head weight shape (hf): torch.Size([50257, 768])


In [67]:
prompt = "Once upon a"
inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.inference_mode():
    hf_logits = gpt2(**inputs).logits  # shape (1, seq_len, vocab)
    demo_logits = demo_gpt2(inputs["input_ids"].long().to(next(demo_gpt2.parameters()).device))
# confronta ultimi token (ultimo pos)
hf_last = hf_logits[0, -1].float()
demo_last = demo_logits[0, -1].float()
print("hf_last mean/std:", hf_last.mean().item(), hf_last.std().item())
print("demo_last mean/std:", demo_last.mean().item(), demo_last.std().item())
diff = (hf_last - demo_last).abs()
print("max abs diff:", diff.max().item(), "mean abs diff:", diff.mean().item())

# Top-5 tokens HF
topk = 5
hf_topk = torch.topk(hf_last, topk)
demo_topk = torch.topk(demo_last, topk)
print("\nHF top-k tokens and scores:")
for i,t in enumerate(hf_topk.indices.tolist()):
    print(i+1, tokenizer.decode([int(t)]), float(hf_topk.values[i]))
print("\nDEMO top-k tokens and scores:")
for i,t in enumerate(demo_topk.indices.tolist()):
    print(i+1, tokenizer.decode([int(t)]), float(demo_topk.values[i]))

hf_last mean/std: -109.96678924560547 3.1215107440948486
demo_last mean/std: -97.36473846435547 3.5802812576293945
max abs diff: 26.025436401367188 mean abs diff: 12.602171897888184

HF top-k tokens and scores:
1  time -96.10369110107422
2  certain -96.84069061279297
3  moment -97.83592224121094
4  visit -97.91505432128906
5  while -98.48494720458984

DEMO top-k tokens and scores:
1  and -81.44944763183594
2 , -82.4196548461914
3  or -82.89862823486328
4  to -82.91533660888672
5  all -83.01960754394531


In [71]:
# This will try to map keys and show shapes / mismatches between hf and demo params
demo_state = demo_gpt2.state_dict()
hf = gpt2.state_dict()

mismatches = []
missing_in_demo = []
extra_in_demo = []

# costruisco un mapping semplice per i nomi già usati nella funzione load_gpt2_weights
# confronto shape se possibile
for k_demo, v in demo_state.items():
    if k_demo in hf:
        if tuple(hf[k_demo].shape) != tuple(v.shape):
            mismatches.append((k_demo, hf[k_demo].shape, v.shape))
    else:
        # prova a trovare alcune chiavi simili (utile per debugging)
        pass

# controlla per le key che ci aspettiamo nello state_dict che abbiamo costruito a mano
expected_keys = list(demo_state.keys())
print("Number of demo keys:", len(demo_state))
print("Number of hf keys:", len(hf))
print("Mismatches (shape differences):", mismatches[:10])
# Non andiamo a stampare tutto per non inquinare il notebook; se vuoi che controlli pattern di nomi,
# stampa qui i primi 30 nomi delle state_dict per confronto:
print("Demo keys sample:", sorted(list(demo_state.keys())))
print("HF keys sample:", sorted(list(hf.keys())))

Number of demo keys: 198
Number of hf keys: 149
Mismatches (shape differences): []
Demo keys sample: ['blocks.0.attn.W_K', 'blocks.0.attn.W_O', 'blocks.0.attn.W_Q', 'blocks.0.attn.W_V', 'blocks.0.attn.b_K', 'blocks.0.attn.b_O', 'blocks.0.attn.b_Q', 'blocks.0.attn.b_V', 'blocks.0.ln1.b', 'blocks.0.ln1.w', 'blocks.0.ln2.b', 'blocks.0.ln2.w', 'blocks.0.mlp.W_in', 'blocks.0.mlp.W_out', 'blocks.0.mlp.b_in', 'blocks.0.mlp.b_out', 'blocks.1.attn.W_K', 'blocks.1.attn.W_O', 'blocks.1.attn.W_Q', 'blocks.1.attn.W_V', 'blocks.1.attn.b_K', 'blocks.1.attn.b_O', 'blocks.1.attn.b_Q', 'blocks.1.attn.b_V', 'blocks.1.ln1.b', 'blocks.1.ln1.w', 'blocks.1.ln2.b', 'blocks.1.ln2.w', 'blocks.1.mlp.W_in', 'blocks.1.mlp.W_out', 'blocks.1.mlp.b_in', 'blocks.1.mlp.b_out', 'blocks.10.attn.W_K', 'blocks.10.attn.W_O', 'blocks.10.attn.W_Q', 'blocks.10.attn.W_V', 'blocks.10.attn.b_K', 'blocks.10.attn.b_O', 'blocks.10.attn.b_Q', 'blocks.10.attn.b_V', 'blocks.10.ln1.b', 'blocks.10.ln1.w', 'blocks.10.ln2.b', 'blocks.10.ln

In [69]:
# Esempio semplice di greedy vs top-k sampling:
def generate_demo(model, tokenizer, prompt, max_tokens=20, temperature=1.0, top_k=None):
    model.eval()
    ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(next(model.parameters()).device)
    for _ in range(max_tokens):
        with torch.inference_mode():
            logits = model(ids)[0, -1]  # (vocab,)
            logits = logits / max(temperature, 1e-8)
            if top_k is not None:
                values, indices = torch.topk(logits, top_k)
                probs = torch.softmax(values, dim=-1)
                next_id = indices[torch.multinomial(probs, num_samples=1)]
            else:
                probs = torch.softmax(logits, dim=-1)
                next_id = torch.multinomial(probs, num_samples=1)
        ids = torch.cat([ids, next_id.unsqueeze(0)], dim=1)
    return tokenizer.decode(ids[0].tolist())

print("Greedy (temperature=1):", generate_demo(demo_gpt2, tokenizer, "Once upon a", max_tokens=20, temperature=1.0, top_k=None))
print("Top-k (k=40):", generate_demo(demo_gpt2, tokenizer, "Once upon a", max_tokens=20, temperature=1.0, top_k=40))
print("Sampling temp=0.7, k=40:", generate_demo(demo_gpt2, tokenizer, "Once upon a", max_tokens=20, temperature=0.7, top_k=40))

Greedy (temperature=1): Once upon a oh or favorites under all RED forsoever light and plans light pending plans solicit plan pending plans plans plans
Top-k (k=40): Once upon a & and and and under along and below and including links presently presently etc forthcomingnc etc avenue contact bid
Sampling temp=0.7, k=40: Once upon a I, and and and and and forthcoming and forthcoming forthcoming forthcoming solicit solicit solicit solicit solicit solicit solicit solicit
