In [2]:
from dataclasses import dataclass
from typing import Tuple, List, Optional, Dict
from jaxtyping import Float, Int
import torch
from torch import Tensor
import torch.nn as nn
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformer_lens import HookedTransformer
import einops
import numpy as np
import circuitsvis as cv
from IPython.display import display, HTML

In [3]:
device = torch.device(
    "mps"
    if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available() else "cpu"
)
gpt2 = GPT2LMHeadModel.from_pretrained("openai-community/gpt2").to(device)
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
hooked_gpt2 = HookedTransformer.from_pretrained(
    "gpt2-small", fold_ln=False, center_unembed=False, center_writing_weights=False
)

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loaded pretrained model gpt2-small into HookedTransformer


In [4]:
text = "The raccoon sat on the mat."
token_ids = tokenizer.encode(text)
print(f"Token (ids): {token_ids}")
print(f"Tokens (string): {tokenizer.tokenize(text)}")
print(f"Text string: {tokenizer.decode(token_ids, skip_special_tokens=False)}")
print(f"Stuff to input a model: {tokenizer(text, return_tensors='pt')}")

Token (ids): [464, 3444, 20912, 3332, 319, 262, 2603, 13]
Tokens (string): ['The', 'Ġrac', 'coon', 'Ġsat', 'Ġon', 'Ġthe', 'Ġmat', '.']
Text string: The raccoon sat on the mat.
Stuff to input a model: {'input_ids': tensor([[  464,  3444, 20912,  3332,   319,   262,  2603,    13]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}


In [6]:
text = "Once upon a"
input_ids = tokenizer(text, return_tensors="pt").to(device)
with torch.inference_mode():
    output_logits = gpt2(**input_ids)["logits"]
print(f"Logits: {output_logits}")
print(f"Logits shape: {output_logits.shape}")

Logits: tensor([[[ -34.5645,  -34.4082,  -38.3080,  ...,  -41.6997,  -39.7802,
           -35.0521],
         [ -84.7256,  -82.9326,  -87.0165,  ...,  -91.6668,  -86.2355,
           -84.7094],
         [-109.0798, -105.7259, -109.9115,  ..., -114.2847, -107.6933,
          -105.3613]]])
Logits shape: torch.Size([1, 3, 50257])


In [7]:
output_probas = output_logits.softmax(dim=-1)
print(f"Probabilites over vocabulary: {output_probas}")

Probabilites over vocabulary: tensor([[[8.9156e-04, 1.0424e-03, 2.1106e-05,  ..., 7.1022e-07,
          4.8419e-06, 5.4752e-04],
         [3.2892e-06, 1.9758e-05, 3.3276e-07,  ..., 3.1810e-09,
          7.2670e-07, 3.3428e-06],
         [4.0888e-07, 1.1700e-05, 1.7799e-07,  ..., 2.2447e-09,
          1.6358e-06, 1.6848e-05]]])


In [8]:
most_likely_next_tokens = tokenizer.batch_decode(output_logits.argmax(dim=-1)[0])
print(list(zip(tokenizer.tokenize(text), most_likely_next_tokens)))

[('Once', ' the'), ('Ġupon', ' a'), ('Ġa', ' time')]


In [9]:
next_token = output_logits[0, -1].argmax(dim=-1)
next_char = tokenizer.decode(next_token)
print(
    "The next token is:", repr(next_char)
)  # repr is to show special tokens and spaces
print("How the sentence becomes: ", text + next_char)

The next token is: ' time'
How the sentence becomes:  Once upon a time


In [10]:
# Initialize text
text = "Once upon a"
# Convert text to tensor format
tokens = tokenizer(text, return_tensors="pt").to(device)
print("Generating text...\n")
# Generate 10 characters iteratively
for i in range(10):
    with torch.inference_mode():
        # Get model predictions
        output_logits = gpt2(**tokens).logits
        # Select the most likely next token
        next_token = output_logits[0, -1].argmax(dim=-1)
        # Decode the token to a character
        next_char = tokenizer.decode(next_token)
    # Display the sequence so far
    current_text = tokenizer.decode(tokens["input_ids"][0])  # Reconstruct the string
    print(f"Generation step {i + 1}:")
    print(f"Sequence so far: {current_text!r}")
    print(f"{tokens['input_ids'].shape[-1] + 1}th char = {next_char!r}\n")
    # Append the new character and re-tokenize
    text += next_char
    tokens = tokenizer(text, return_tensors="pt").to(device)
print("Final text:", text)

Generating text...

Generation step 1:
Sequence so far: 'Once upon a'
4th char = ' time'

Generation step 2:
Sequence so far: 'Once upon a time'
5th char = ','

Generation step 3:
Sequence so far: 'Once upon a time,'
6th char = ' the'

Generation step 4:
Sequence so far: 'Once upon a time, the'
7th char = ' world'

Generation step 5:
Sequence so far: 'Once upon a time, the world'
8th char = ' was'

Generation step 6:
Sequence so far: 'Once upon a time, the world was'
9th char = ' a'

Generation step 7:
Sequence so far: 'Once upon a time, the world was a'
10th char = ' place'

Generation step 8:
Sequence so far: 'Once upon a time, the world was a place'
11th char = ' of'

Generation step 9:
Sequence so far: 'Once upon a time, the world was a place of'
12th char = ' great'

Generation step 10:
Sequence so far: 'Once upon a time, the world was a place of great'
13th char = ' beauty'

Final text: Once upon a time, the world was a place of great beauty


In [11]:
reference_text = "Once upon a time, there was a fox who lived in a forest."
tokens = hooked_gpt2.to_tokens(reference_text).to(device)
logits, cache = hooked_gpt2.run_with_cache(tokens)
html = cv.attention.attention_pattern(
    tokens=hooked_gpt2.to_str_tokens(reference_text),
    attention=cache["pattern", 3][0][7],
)
styled_html = f"""
<div style="width:800px; font-size:16px;">
    {html}
</div>
"""

display(HTML(styled_html))

In [12]:
print(gpt2)
print(gpt2.config)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)
GPT2Config {
  "activation_function": "gelu_new",
  "

In [32]:
def rand_float_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg).to(device)
    random_input = torch.randn(shape).to(device)
    print("Input shape:", random_input.shape)
    output = layer(random_input)
    if isinstance(output, tuple):
        output = output[0]
    print("Output:", output)
    print("Output shape:", output.shape, "\n")


def rand_int_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg).to(device)
    random_input = torch.randint(100, 1000, shape).to(device)
    print("Input shape:", random_input.shape)
    output = layer(random_input)
    if isinstance(output, tuple):
        output = output[0]
    print("Output:", output)
    print("Output shape:", output.shape, "\n")

In [13]:
sequence = "Once upon a time, "
tokenized_sequence = tokenizer.tokenize(sequence)
tokens = tokenizer(sequence, return_tensors="pt").to(device)["input_ids"]
print("Tokenized sequence:", tokenized_sequence)
print("Token IDs:", tokens)

Tokenized sequence: ['Once', 'Ġupon', 'Ġa', 'Ġtime', ',', 'Ġ']
Token IDs: tensor([[7454, 2402,  257,  640,   11,  220]])


In [14]:
batch = 1  # starting with only one batch (thus 1 sentence)
seq_len = len(tokenized_sequence)  # 6


@dataclass
class Config:
    n_ctx: int = gpt2.config.n_ctx  # 1024
    d_model: int = gpt2.config.n_embd  # hidden size, or embedding dimension
    n_heads: int = gpt2.config.n_head  # number of attention heads
    n_layers: int = gpt2.config.n_layer  # number of transformer blocks
    d_mlp: int = 4 * d_model  # MLP hidden size, 3072
    d_head: int = d_model // n_heads  # dimension of each attention head, 64
    layer_norm_eps: float = gpt2.config.layer_norm_epsilon  # layer norm epsilon
    d_vocab: int = gpt2.config.vocab_size  # number of tokens in the vocabulary
    init_range: float = (
        gpt2.config.initializer_range
    )  # initialization range for weights
    debug: bool = True


cfg = Config()
print(cfg)

Config(n_ctx=1024, d_model=768, n_heads=12, n_layers=12, d_mlp=3072, d_head=64, layer_norm_eps=1e-05, d_vocab=50257, init_range=0.02, debug=True)


In [33]:
class Embed(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_E = nn.Parameter(torch.empty((cfg.d_vocab, cfg.d_model)))
        nn.init.normal_(self.W_E, std=cfg.init_range)

    def forward(
            self, int_tokens: Int[Tensor, "batch seq_len"]
    ) -> Float[Tensor, "batch seq_len d_model"]:
        # just a mapping from int tokens to float vectors
        return self.W_E[int_tokens]

rand_int_test(Embed, [batch, seq_len])

Input shape: torch.Size([1, 6])
Output: tensor([[[-0.0134, -0.0178, -0.0035,  ..., -0.0042,  0.0115,  0.0149],
         [-0.0033, -0.0063,  0.0049,  ..., -0.0114,  0.0026,  0.0065],
         [-0.0195, -0.0139,  0.0228,  ...,  0.0202, -0.0011, -0.0027],
         [ 0.0449, -0.0042, -0.0135,  ..., -0.0188,  0.0015,  0.0049],
         [-0.0134,  0.0182, -0.0066,  ..., -0.0381,  0.0009,  0.0242],
         [ 0.0178, -0.0264,  0.0298,  ...,  0.0233,  0.0306,  0.0151]]],
       grad_fn=<IndexBackward0>)
Output shape: torch.Size([1, 6, 768]) 



In [35]:
class PosEmbed(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_pos = nn.Parameter(torch.empty((cfg.n_ctx, cfg.d_model)))
        nn.init.normal_(self.W_pos, std=cfg.init_range)

    def forward(
            self, int_tokens: Int[Tensor, "batch seq_len"]
    ) -> Float[Tensor, "batch seq_len d_model"]:
        # take first seq_len learnt positional embeddings
        return einops.repeat(
            self.W_pos[:seq_len],
            "seq_len d_model -> batch seq_len d_model",
            batch=batch  # 1
        )

rand_int_test(PosEmbed, [batch, seq_len])

Input shape: torch.Size([1, 6])
Output: tensor([[[-1.5321e-02, -2.0440e-03,  8.6898e-03,  ...,  6.8708e-03,
           1.7418e-02, -4.6864e-02],
         [-1.0571e-02, -7.4795e-03,  3.4819e-02,  ..., -3.7164e-03,
          -1.1367e-02, -3.4639e-02],
         [ 5.8133e-03,  5.5778e-08,  8.6300e-03,  ..., -1.9935e-02,
           1.6901e-02,  2.2116e-03],
         [ 2.5371e-02,  1.2460e-02,  2.1821e-02,  ..., -6.7912e-03,
          -1.0191e-02, -4.3987e-03],
         [-2.1841e-02, -1.2428e-02,  9.4734e-03,  ...,  2.3605e-03,
          -2.0820e-03,  4.8564e-03],
         [-1.1437e-02,  1.5380e-02,  3.3769e-03,  ..., -7.8232e-04,
           3.9412e-02, -3.0546e-03]]], grad_fn=<ExpandBackward0>)
Output shape: torch.Size([1, 6, 768]) 



In [39]:
class LayerNorm(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(torch.empty(cfg.d_model))
        self.b = nn.Parameter(torch.zeros(cfg.d_model))  # Bias!
        nn.init.normal_(self.w, std=cfg.init_range)

    def forward(
            self, embedding: Float[Tensor, "batch seq_len d_model"]
    ) -> Float[Tensor, "batch seq_len d_model"]:
        # compute mean
        embedding_mean = embedding.mean(dim=-1, keepdim=True)
        # compute standard deviation + eps
        embedding_std = (embedding.var(dim=-1, keepdim=True, unbiased=False) + self.cfg.layer_norm_eps).sqrt()
        # compute normalized embedding
        embedding = (embedding - embedding_mean) / embedding_std
        return embedding * self.w + self.b


rand_float_test(LayerNorm, [batch, seq_len, cfg.d_model])

Input shape: torch.Size([1, 6, 768])
Output: tensor([[[ 1.3315e-02, -4.1681e-04,  2.0928e-03,  ..., -4.7153e-04,
          -4.5773e-03, -1.7275e-02],
         [-4.6628e-03,  1.0136e-02,  5.6740e-03,  ..., -4.1052e-04,
          -7.3222e-03,  3.3341e-02],
         [ 1.3759e-02,  1.7675e-02,  7.0852e-03,  ..., -1.3318e-05,
           1.3761e-02,  1.0174e-02],
         [ 1.9054e-02,  2.1374e-03,  1.1312e-02,  ...,  2.5715e-04,
           1.8823e-02,  3.3176e-03],
         [ 1.1060e-03,  2.1896e-02, -5.5682e-03,  ...,  9.3507e-04,
           1.1582e-02, -5.4541e-02],
         [-1.0707e-02, -3.1679e-03,  2.1488e-02,  ..., -1.9208e-03,
           1.8057e-02,  5.2587e-02]]], grad_fn=<AddBackward0>)
Output shape: torch.Size([1, 6, 768]) 



In [41]:
class Attention(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_K = nn.Parameter(torch.empty(cfg.n_heads, cfg.d_model, cfg.d_head))
        self.b_K = nn.Parameter(torch.zeros(cfg.n_heads, cfg.d_head))
        self.W_Q = nn.Parameter(torch.empty(cfg.n_heads, cfg.d_model, cfg.d_head))
        self.b_Q = nn.Parameter(torch.zeros(cfg.n_heads, cfg.d_head))
        self.W_V = nn.Parameter(torch.empty(cfg.n_heads, cfg.d_model, cfg.d_head))
        self.b_V = nn.Parameter(torch.zeros(cfg.n_heads, cfg.d_head))
        self.W_O = nn.Parameter(torch.empty(cfg.n_heads, cfg.d_head, cfg.d_model))
        self.b_O = nn.Parameter(torch.zeros(cfg.d_model))

        nn.init.normal_(self.W_K, std=cfg.init_range)
        nn.init.normal_(self.W_Q, std=cfg.init_range)
        nn.init.normal_(self.W_V, std=cfg.init_range)
        nn.init.normal_(self.W_O, std=cfg.init_range)
        self.mask = -1e4

    def forward(
            self, embedding: Float[Tensor, "batch seq_len d_model"]
    ) -> Float[Tensor, "batch seq_len d_model"]:
        # compute K, Q, V projections
        K = einops.einsum(
            embedding,
            self.W_K,
            "batch seq_len d_model, n_heads d_model d_head -> batch seq_len n_heads d_head"
        ) + self.b_K
        Q = einops.einsum(
            embedding,
            self.W_Q,
            "batch seq_len d_model, n_heads d_model d_head -> batch seq_len n_heads d_head"
        ) + self.b_Q
        V = einops.einsum(
            embedding,
            self.W_V,
            "batch seq_len d_model, n_heads d_model d_head -> batch seq_len n_heads d_head"
        ) + self.b_V
        # compute attention scores
        attention_scores = einops.einsum(
            Q,
            K,
            "batch dest_pos n_heads d_head, batch source_pos n_heads d_head -> batch n_heads dest_pos source_pos"
        )
        # scale and mask attention scores (causal attention)
        attention_scores = attention_scores / (self.cfg.d_head ** 0.5)
        mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
        attention_scores = attention_scores.masked_fill(mask, self.mask) # Masked positions should be close to 0, not actually 0
        # softmax attention scores
        attention_scores = attention_scores.softmax(dim=-1)
        # compute weighted sum of values
        z = einops.einsum(
            V,
            attention_scores,
            "batch seq_len n_heads d_head, batch n_heads dest_pos source_pos -> batch dest_pos n_heads d_head"
        )
        # compute output projection
        attn_output = einops.einsum(
            z,
            self.W_O,
            "batch seq_len n_heads d_head, n_heads d_head d_model -> batch seq_len d_model"
        ) + self.b_O
        return attn_output

rand_float_test(Attention, [batch, seq_len, cfg.d_model])

Input shape: torch.Size([1, 6, 768])
Output: tensor([[[ 1.9943,  0.5804, -0.8176,  ...,  1.0229, -0.5229, -0.2109],
         [ 1.9943,  0.5804, -0.8176,  ...,  1.0229, -0.5229, -0.2109],
         [ 1.9943,  0.5804, -0.8176,  ...,  1.0229, -0.5229, -0.2109],
         [ 1.9943,  0.5804, -0.8176,  ...,  1.0229, -0.5229, -0.2109],
         [ 1.9943,  0.5804, -0.8176,  ...,  1.0229, -0.5229, -0.2109],
         [ 1.9943,  0.5804, -0.8176,  ...,  1.0229, -0.5229, -0.2109]]],
       grad_fn=<AddBackward0>)
Output shape: torch.Size([1, 6, 768]) 



In [52]:
def gelu_new(
    input: Float[torch.Tensor, "batch pos d_mlp"],
) -> Float[torch.Tensor, "batch pos d_mlp"]:
    # Implementation of GeLU used by GPT2 - subtly different from PyTorch's
    return (
        0.5
        * input
        * (
            1.0
            + torch.tanh(
                np.sqrt(2.0 / np.pi) * (input + 0.044715 * torch.pow(input, 3.0))
            )
        )
    )


class MLP(nn.Module):
    def __init__(self, cgf: Config):
        super().__init__()
        self.cfg = cfg
        self.W_in = nn.Parameter(torch.empty(cfg.d_model, cfg.d_mlp))
        self.b_in = nn.Parameter(torch.zeros(cfg.d_mlp))
        self.W_out = nn.Parameter(torch.empty(cfg.d_mlp, cfg.d_model))
        self.b_out = nn.Parameter(torch.zeros(cfg.d_model))
        nn.init.normal_(self.W_in, std=cfg.init_range)
        nn.init.normal_(self.W_out, std=cfg.init_range)
        self.activation = gelu_new

    def forward(
        self, embedding: Float[Tensor, "batch seq_len d_model"]
    ) -> Float[Tensor, "batch seq_len d_model"]:
        # compute in projection
        pre = einops.einsum(
            embedding,
            self.W_in,
            "batch seq_len d_model, d_model d_mlp -> batch seq_len d_mlp"
        ) + self.b_in
        # apply activation
        post = self.activation(pre)
        # compute out projection
        mlp_out = einops.einsum(
            post,
            self.W_out,
            "batch seq_len d_mlp, d_mlp d_model -> batch seq_len d_model"
        ) + self.b_out
        return mlp_out

rand_float_test(MLP, [batch, seq_len, cfg.d_model])

Input shape: torch.Size([1, 6, 768])
Output: tensor([[[-0.1366, -0.1738,  0.2074,  ..., -0.7421, -0.5898,  0.0689],
         [-0.2110,  0.0263,  0.0598,  ..., -0.1454, -0.4640, -0.2758],
         [ 0.0894, -0.5796, -0.0959,  ...,  0.0762,  0.1268,  0.2014],
         [ 0.5708, -0.1968, -0.1139,  ...,  1.1255, -0.0080,  0.4017],
         [ 0.2236, -0.2967, -0.2361,  ..., -0.2367, -0.0662, -0.2864],
         [-0.0366,  0.3278,  0.0909,  ..., -0.1041, -0.1833,  0.2286]]],
       grad_fn=<AddBackward0>)
Output shape: torch.Size([1, 6, 768]) 



In [53]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)

    def forward(
        self, input_embedding: Float[Tensor, "batch seq_len d_model"]
    ) -> Float[Tensor, "batch seq_len d_model"]:
        # normalize input
        embedding = self.ln1(input_embedding)
        # compute attention and add skip connection
        mid_embedding = self.attn(embedding) + input_embedding
        # normalize embedding
        embedding = self.ln2(mid_embedding)
        # compute MLP and add skip connection
        output_embedding = self.mlp(embedding) + mid_embedding
        # return output
        return output_embedding

rand_float_test(TransformerBlock, [batch, seq_len, cfg.d_model])

Input shape: torch.Size([1, 6, 768])
Output: tensor([[[ 0.3258, -0.5019,  0.7033,  ..., -2.2886, -1.2654,  0.4554],
         [-0.6707, -0.9330, -0.8344,  ..., -0.2078,  0.5639, -0.8845],
         [ 1.1695, -0.5519,  0.3030,  ..., -1.8187,  0.7409, -0.7072],
         [ 0.1526, -0.8722,  0.0598,  ...,  1.6024, -1.1545,  0.3237],
         [-0.2018,  0.3114,  0.7507,  ...,  0.8845,  0.6824,  1.0461],
         [ 0.2019, -0.9734,  0.4147,  ...,  1.3799, -1.1554,  1.5493]]],
       grad_fn=<AddBackward0>)
Output shape: torch.Size([1, 6, 768]) 



In [54]:
class Unembed(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_U = nn.Parameter(torch.empty(cfg.d_model, cfg.d_vocab))
        self.b_U = nn.Parameter(torch.zeros(cfg.d_vocab))
        nn.init.normal_(self.W_U, std=cfg.init_range)

    def forward(
        self, embedding: Float[Tensor, "batch seq_len d_model"]
    ) -> Float[Tensor, "batch seq_len d_vocab"]:
        # compute logits
        logits = einops.einsum(
            embedding,
            self.W_U,
            "batch seq_len d_model, d_model d_vocab -> batch seq_len d_vocab"
        ) + self.b_U
        return logits

rand_float_test(Unembed, [batch, seq_len, cfg.d_model])

Input shape: torch.Size([1, 6, 768])
Output: tensor([[[ 0.6935, -0.1383, -0.4351,  ..., -0.2934, -0.6750, -0.2848],
         [-0.1704,  0.0941,  0.4547,  ..., -0.9353,  0.9911,  0.2873],
         [ 0.1860,  0.4068, -0.0586,  ..., -0.0121,  0.2642, -0.9695],
         [ 0.0307,  0.0652, -0.2422,  ...,  0.7824,  0.6268,  0.7185],
         [-0.2772, -0.1581,  0.2110,  ..., -0.8265, -0.9453,  0.4571],
         [ 0.0063, -1.4805,  0.6233,  ..., -0.4980,  0.0653,  0.2031]]],
       grad_fn=<AddBackward0>)
Output shape: torch.Size([1, 6, 50257]) 



In [59]:
class GPT(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.embed = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        self.blocks = nn.ModuleList(
            [TransformerBlock(cfg) for _ in range(cfg.n_layers)]
        )
        self.ln_final = LayerNorm(cfg)
        self.unembed = Unembed(cfg)

    def forward(
        self, input_tokens: Int[Tensor, "batch seq_len"]
    ) -> Float[Tensor, "batch seq_len d_vocab"]:
        # compute embeddings + positional embeddings
        embedding = self.embed(input_tokens) + self.pos_embed(input_tokens)
        # compute transformer blocks outputs
        for block in self.blocks:
            embedding = block(embedding)
        # normalize output
        # compute logits
        logits = self.unembed(self.ln_final(embedding))
        return logits

    def load_gpt2_weights(self, gpt2: GPT2LMHeadModel) -> None:
        state_dict = {}

        state_dict["embed.W_E"] = gpt2.transformer.wte.weight
        state_dict["pos_embed.W_pos"] = gpt2.transformer.wpe.weight

        for l in range(cfg.n_layers):
            state_dict[f"blocks.{l}.ln1.w"] = gpt2.transformer.h[l].ln_1.weight
            state_dict[f"blocks.{l}.ln1.b"] = gpt2.transformer.h[l].ln_1.bias

            # In GPT-2, q,k,v are produced by one big linear map, whose output is
            # concat([q, k, v])
            W = gpt2.transformer.h[l].attn.c_attn.weight
            W_Q, W_K, W_V = torch.tensor_split(W, 3, dim=1)
            W_Q = einops.rearrange(W_Q, "m (i h)->i m h", i=cfg.n_heads)
            W_K = einops.rearrange(W_K, "m (i h)->i m h", i=cfg.n_heads)
            W_V = einops.rearrange(W_V, "m (i h)->i m h", i=cfg.n_heads)

            state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
            state_dict[f"blocks.{l}.attn.W_K"] = W_K
            state_dict[f"blocks.{l}.attn.W_V"] = W_V

            qkv_bias = gpt2.transformer.h[l].attn.c_attn.bias
            qkv_bias = einops.rearrange(
                qkv_bias,
                "(qkv index head)->qkv index head",
                qkv=3,
                index=cfg.n_heads,
                head=cfg.d_head,
            )
            state_dict[f"blocks.{l}.attn.b_Q"] = qkv_bias[0]
            state_dict[f"blocks.{l}.attn.b_K"] = qkv_bias[1]
            state_dict[f"blocks.{l}.attn.b_V"] = qkv_bias[2]

            W_O = gpt2.transformer.h[l].attn.c_proj.weight
            W_O = einops.rearrange(W_O, "(i h) m->i h m", i=cfg.n_heads)
            state_dict[f"blocks.{l}.attn.W_O"] = W_O
            state_dict[f"blocks.{l}.attn.b_O"] = gpt2.transformer.h[l].attn.c_proj.bias

            state_dict[f"blocks.{l}.ln2.w"] = gpt2.transformer.h[l].ln_2.weight
            state_dict[f"blocks.{l}.ln2.b"] = gpt2.transformer.h[l].ln_2.bias

            W_in = gpt2.transformer.h[l].mlp.c_fc.weight
            state_dict[f"blocks.{l}.mlp.W_in"] = W_in
            state_dict[f"blocks.{l}.mlp.b_in"] = gpt2.transformer.h[l].mlp.c_fc.bias

            W_out = gpt2.transformer.h[l].mlp.c_proj.weight
            state_dict[f"blocks.{l}.mlp.W_out"] = W_out
            state_dict[f"blocks.{l}.mlp.b_out"] = gpt2.transformer.h[l].mlp.c_proj.bias
        state_dict["unembed.W_U"] = gpt2.lm_head.weight.T

        state_dict["ln_final.w"] = gpt2.transformer.ln_f.weight
        state_dict["ln_final.b"] = gpt2.transformer.ln_f.bias
        self.load_state_dict(state_dict)

In [60]:
rand_int_test(GPT, [batch, seq_len])

Input shape: torch.Size([1, 6])
Output: tensor([[[ 1.5225e-02,  3.8826e-03, -1.1096e-02,  ...,  1.6017e-02,
           1.3524e-02,  4.9619e-05],
         [ 1.1275e-02,  4.5961e-03, -9.7187e-03,  ...,  1.5540e-02,
           1.1435e-02,  2.0470e-03],
         [ 1.2999e-02,  8.3015e-03, -1.4311e-02,  ...,  1.4610e-02,
           1.4996e-02,  2.6185e-03],
         [ 2.2296e-02, -1.4677e-03, -1.0998e-02,  ...,  1.8303e-02,
           1.7340e-02,  1.4473e-03],
         [ 1.4243e-02,  8.1804e-03, -7.1281e-03,  ...,  1.5325e-02,
           9.9894e-03,  1.1718e-03],
         [ 1.1972e-02,  4.0970e-03, -4.9842e-03,  ...,  1.3690e-02,
           1.1461e-02,  3.1178e-03]]], grad_fn=<AddBackward0>)
Output shape: torch.Size([1, 6, 50257]) 



In [61]:
demo_gpt2 = GPT(Config(debug=False)).to(device)
# demo_gpt2.load_gpt2_weights(gpt2)
demo_gpt2.load_state_dict(hooked_gpt2.state_dict(), strict=False)

_IncompatibleKeys(missing_keys=[], unexpected_keys=['blocks.0.attn.mask', 'blocks.0.attn.IGNORE', 'blocks.1.attn.mask', 'blocks.1.attn.IGNORE', 'blocks.2.attn.mask', 'blocks.2.attn.IGNORE', 'blocks.3.attn.mask', 'blocks.3.attn.IGNORE', 'blocks.4.attn.mask', 'blocks.4.attn.IGNORE', 'blocks.5.attn.mask', 'blocks.5.attn.IGNORE', 'blocks.6.attn.mask', 'blocks.6.attn.IGNORE', 'blocks.7.attn.mask', 'blocks.7.attn.IGNORE', 'blocks.8.attn.mask', 'blocks.8.attn.IGNORE', 'blocks.9.attn.mask', 'blocks.9.attn.IGNORE', 'blocks.10.attn.mask', 'blocks.10.attn.IGNORE', 'blocks.11.attn.mask', 'blocks.11.attn.IGNORE'])

In [62]:
# Initialize text
text = "Once upon a"
# Convert text to tensor format
tokens = tokenizer(text, return_tensors="pt").to(device)
print("Generating text...\n")
# Generate 10 characters iteratively
for i in range(20):
    with torch.inference_mode():
        # Get model predictions
        output_logits = demo_gpt2(tokens["input_ids"])
        # Select the most likely next token
        next_token = output_logits[0, -1].argmax(dim=-1)
        # Decode the token to a character
        next_char = tokenizer.decode(next_token)
    # Display the sequence so far
    current_text = tokenizer.decode(tokens["input_ids"][0])  # Reconstruct the string
    print(f"Generation step {i+1}:")
    print(f"Sequence so far: {current_text!r}")
    print(f"{tokens['input_ids'].shape[-1]+1}th char = {next_char!r}\n")
    # Append the new character and re-tokenize
    text += next_char
    tokens = tokenizer(text, return_tensors="pt").to(device)
print("Final text:", text)

Generating text...



RuntimeError: The size of tensor a (3) must match the size of tensor b (6) at non-singleton dimension 1