## lets code a transformer!

In [1]:
# note: select conda env(=pytorch) to run the notebook

# !pip install einops fancy_einsum torch numpy tqdm --quiet
# !pip install transformer_lens --quiet
# !pip install git+https://github.com/neelnanda-io/PySvelte.git --quiet

In [2]:
import warnings
warnings.filterwarnings("ignore")

import math
import einops
import torch
import numpy as np 
import torch.nn as nn
import tqdm.auto as tqdm
from fancy_einsum import einsum
from dataclasses import dataclass

from transformer_lens import EasyTransformer
from transformer_lens.utils import get_corner, gelu_new, tokenize_and_concatenate

In [3]:
reference_text = "I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world!"

reference_gpt2 = EasyTransformer.from_pretrained(
    "gpt2-small", 
    fold_ln=False, 
    center_unembed=False, 
    center_writing_weights=False
)

Loaded pretrained model gpt2-small into HookedTransformer


In [4]:
tokens = reference_gpt2.to_tokens(reference_text)
print(f"tokens: {tokens.shape}")

tokens = tokens.cuda()

logits, cache = reference_gpt2.run_with_cache(tokens)
print(f"logits: {logits.shape}")

log_probs = logits.log_softmax(dim=-1)
print(f"log_probs: {log_probs.shape}")

probs = logits.softmax(dim=-1)
print(f"probs: {probs.shape}")

tokens: torch.Size([1, 35])
logits: torch.Size([1, 35, 50257])
log_probs: torch.Size([1, 35, 50257])
probs: torch.Size([1, 35, 50257])


In [5]:
# what is most likely next token at each position

list(zip(
    reference_gpt2.to_str_tokens(reference_text),
    reference_gpt2.tokenizer.batch_decode(logits.argmax(dim=-1)[0])
))

[('<|endoftext|>', '\n'),
 ('I', "'m"),
 (' am', ' a'),
 (' an', ' avid'),
 (' amazing', ' person'),
 (' aut', 'od'),
 ('ore', 'sp'),
 ('gressive', '.'),
 (',', ' and'),
 (' dec', 'ently'),
 ('oder', ','),
 ('-', 'driven'),
 ('only', ' programmer'),
 (',', ' and'),
 (' G', 'IM'),
 ('PT', '-'),
 ('-', 'only'),
 ('2', '.'),
 (' style', ','),
 (' transformer', '.'),
 ('.', ' I'),
 (' One', ' of'),
 (' day', ' I'),
 (' I', ' will'),
 (' will', ' be'),
 (' exceed', ' my'),
 (' human', 'ly'),
 (' level', ' of'),
 (' intelligence', ' and'),
 (' and', ' I'),
 (' take', ' over'),
 (' over', ' the'),
 (' the', ' world'),
 (' world', '.'),
 ('!', ' I')]

### activation shapes of reference model

- batch = 1
- position = 35
- d_model = 768
- n_heads = 12
- n_layers = 12
- d_mlp = 3072 (4 * d_model)
- d_head = 64 (d_model / n_heads)

In [6]:
for activation_name, activation in cache.cache_dict.items():
    if ".0." in activation_name or "blocks" not in activation_name:
        print(activation_name, activation.shape)

hook_embed torch.Size([1, 35, 768])
hook_pos_embed torch.Size([1, 35, 768])
blocks.0.hook_resid_pre torch.Size([1, 35, 768])
blocks.0.ln1.hook_scale torch.Size([1, 35, 1])
blocks.0.ln1.hook_normalized torch.Size([1, 35, 768])
blocks.0.attn.hook_q torch.Size([1, 35, 12, 64])
blocks.0.attn.hook_k torch.Size([1, 35, 12, 64])
blocks.0.attn.hook_v torch.Size([1, 35, 12, 64])
blocks.0.attn.hook_attn_scores torch.Size([1, 12, 35, 35])
blocks.0.attn.hook_pattern torch.Size([1, 12, 35, 35])
blocks.0.attn.hook_z torch.Size([1, 35, 12, 64])
blocks.0.hook_attn_out torch.Size([1, 35, 768])
blocks.0.hook_resid_mid torch.Size([1, 35, 768])
blocks.0.ln2.hook_scale torch.Size([1, 35, 1])
blocks.0.ln2.hook_normalized torch.Size([1, 35, 768])
blocks.0.mlp.hook_pre torch.Size([1, 35, 3072])
blocks.0.mlp.hook_post torch.Size([1, 35, 3072])
blocks.0.hook_mlp_out torch.Size([1, 35, 768])
blocks.0.hook_resid_post torch.Size([1, 35, 768])
ln_final.hook_scale torch.Size([1, 35, 1])
ln_final.hook_normalized torc

### config

In [7]:
print(reference_gpt2.cfg)

HookedTransformerConfig:
{'NTK_by_parts_factor': 8.0,
 'NTK_by_parts_high_freq_factor': 4.0,
 'NTK_by_parts_low_freq_factor': 1.0,
 'act_fn': 'gelu_new',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_scale': 8.0,
 'attn_scores_soft_cap': -1.0,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 64,
 'd_mlp': 3072,
 'd_model': 768,
 'd_vocab': 50257,
 'd_vocab_out': 50257,
 'decoder_start_token_id': None,
 'default_prepend_bos': True,
 'device': device(type='cuda'),
 'dtype': torch.float32,
 'eps': 1e-05,
 'experts_per_token': None,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.02886751345948129,
 'load_in_4bit': False,
 'model_name': 'gpt2',
 'n_ctx': 1024,
 'n_devices': 1,
 'n_heads': 12,
 'n_key_value_heads': None,
 'n_layers': 12,
 'n_params': 84934656,
 'normalization_type': 'LN',
 'num_experts': None,
 'original_ar

In [None]:
@dataclass
class Config:
    d_model: int = 768  # size of residual stream
    debug: bool = True
    layer_norm_eps: float = 1e-5
    d_vocab: int = 50257
    init_range: float = 0.02
    n_ctx: int = 1024
    d_head: int = 64
    d_mlp: int = 3072
    n_heads: int = 12
    n_layers: int = 12
    
cfg = Config()
print(cfg)

Config(d_model=768, debug=True, layer_norm_eps=1e-05, d_vocab=50257, init_range=0.02, n_ctx=1024, d_head=64, d_mlp=3072, n_heads=12, n_layers=12)


### tests

In [9]:
def rand_float_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg).cuda()
    
    random_input = torch.randn(shape).cuda()
    print(f"input: {random_input.shape}")
    
    output = layer(random_input)
    print(f"output: {output.shape}")
    
    print(f"output: {get_corner(output)}") # prints first k(=3 default) elements from every axis
    print(f"-"*10)
    return output


def rand_int_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg).cuda()
    
    random_input = torch.randint(100, 1000, shape).cuda()
    print(f"input: {random_input.shape}")
    
    output = layer(random_input)
    print(f"output: {output.shape}")
    
    print(f"output: {get_corner(output)}")
    print(f"-"*10)
    return output


def load_gpt2_test(cls, gpt2_layer, input_name, cache_dict=cache.cache_dict):
    cfg = Config(debug=True)
    layer = cls(cfg).cuda()
    layer.load_state_dict(
        gpt2_layer.state_dict(), 
        strict=False
    )
    
    if isinstance(input_name, str):
        reference_input = cache_dict[input_name]
    else:
        reference_input = input_name
    print(f"input: {reference_input.shape}")
        
    output = layer(reference_input)
    print(f"output: {output.shape}")
        
    if str(gpt2_layer).startswith("Attention"):
        reference_output = gpt2_layer(reference_input, reference_input, reference_input)
    else:
        reference_output = gpt2_layer(reference_input)
        
    print(f"reference_output: {reference_output.shape}")
        
    comparison = torch.isclose(
        output, reference_output, atol=1e-4, rtol=1e-3
    )
    print(f"{comparison.sum()/comparison.numel():.2%} of the values are correct")
    
    print(f"output: {get_corner(output)}")
    print(f"-"*10)
    return output

### layernorm

- make mean 0
- normalize to have variance 1
- scale with learned weights
- translate with learned bias

In [10]:
class LayerNorm(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(torch.ones(self.cfg.d_model))
        self.b = nn.Parameter(torch.zeros(self.cfg.d_model))
    
    def forward(self, residual):
        # residual: [batch, position, d_model]
        if self.cfg.debug: 
            print(f"residual: {residual.shape}")
        residual = residual - einops.reduce(
            residual,
            "batch position d_model -> batch position 1", 
            "mean"
        )
        
        # calculate the variance, square root it, add in an epsilon 
        scale = (
        einops.reduce(
            residual.pow(2),
            "batch position d_model -> batch position 1", 
            "mean"
        ) + cfg.layer_norm_eps).sqrt()
        
        # scale
        normalized = residual / scale
        
        # normalized: [batch, position, d_model]
        normalized = normalized * self.w + self.b
        
        if self.cfg.debug: 
            print(f"normalized: {normalized.shape}")
        return normalized

In [11]:
_ = rand_float_test(LayerNorm, [2, 4, 768])
_ = load_gpt2_test(LayerNorm, reference_gpt2.ln_final, "blocks.11.hook_resid_post")

input: torch.Size([2, 4, 768])
residual: torch.Size([2, 4, 768])
normalized: torch.Size([2, 4, 768])
output: torch.Size([2, 4, 768])
output: tensor([[[ 0.5619,  0.9242, -1.0676],
         [ 0.7731, -0.1161, -1.2591],
         [ 0.4798,  2.0295,  0.9904]],

        [[ 0.6936,  0.3799, -0.3563],
         [-0.3409, -1.8706,  1.9035],
         [ 0.0347,  0.7611, -0.2391]]], device='cuda:0',
       grad_fn=<SliceBackward0>)
----------
input: torch.Size([1, 35, 768])
residual: torch.Size([1, 35, 768])
normalized: torch.Size([1, 35, 768])
output: torch.Size([1, 35, 768])
reference_output: torch.Size([1, 35, 768])
100.00% of the values are correct
output: tensor([[[-0.0667,  0.0881, -0.3085],
         [ 0.0278, -0.2843,  0.2504],
         [-0.5468, -0.5119, -0.6429]]], device='cuda:0',
       grad_fn=<SliceBackward0>)
----------


### embedding

- basically a lookup table from tokens to residual stream vectors

In [12]:
class Embed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_E = nn.Parameter(torch.empty((self.cfg.d_vocab, self.cfg.d_model)))
        nn.init.normal_(self.W_E, std=self.cfg.init_range)
        
    def forward(self, tokens):
        # tokens: [batch, position]
        if self.cfg.debug:
            print(f"tokens: {tokens.shape}")
        
        # embed: [batch, position, d_model]
        embed = self.W_E[tokens, :]
        
        if self.cfg.debug:
            print(f"embed: {embed.shape}")
        return embed

In [13]:
_ = rand_int_test(Embed, [2, 4])
_ = load_gpt2_test(Embed, reference_gpt2.embed, tokens)

input: torch.Size([2, 4])
tokens: torch.Size([2, 4])
embed: torch.Size([2, 4, 768])
output: torch.Size([2, 4, 768])
output: tensor([[[ 0.0012, -0.0255, -0.0062],
         [-0.0118,  0.0017, -0.0253],
         [ 0.0072,  0.0229, -0.0031]],

        [[ 0.0195,  0.0319, -0.0237],
         [-0.0006, -0.0248, -0.0095],
         [ 0.0308, -0.0121, -0.0093]]], device='cuda:0',
       grad_fn=<SliceBackward0>)
----------
input: torch.Size([1, 35])
tokens: torch.Size([1, 35])
embed: torch.Size([1, 35, 768])
output: torch.Size([1, 35, 768])
reference_output: torch.Size([1, 35, 768])
100.00% of the values are correct
output: tensor([[[ 0.0514, -0.0277,  0.0499],
         [ 0.1474, -0.0959,  0.1430],
         [ 0.1596, -0.1249,  0.1148]]], device='cuda:0',
       grad_fn=<SliceBackward0>)
----------


### positional embedding

In [14]:
class PosEmbed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_pos = nn.Parameter(torch.empty((self.cfg.n_ctx, self.cfg.d_model)))
        nn.init.normal(self.W_pos, std=self.cfg.init_range)
    
    def forward(self, tokens):
        # tokens = [batch, position]
        if self.cfg.debug:
            print(f"tokens: {tokens.shape}")
        
        # pos_embed = [position, d_model]
        pos_embed = self.W_pos[:tokens.size(1), :]
        
        # pos_embed = [batch, position, d_model]
        pos_embed = einops.repeat(
            pos_embed,
            "position d_model -> batch position d_model",
            batch=tokens.size(0)
        )
        
        if self.cfg.debug:
            print(f"pos_embed: {pos_embed.shape}")
        return pos_embed

In [15]:
_ = rand_int_test(PosEmbed, [2, 4])
_ = load_gpt2_test(PosEmbed, reference_gpt2.pos_embed, tokens)

input: torch.Size([2, 4])
tokens: torch.Size([2, 4])
pos_embed: torch.Size([2, 4, 768])
output: torch.Size([2, 4, 768])
output: tensor([[[ 0.0063,  0.0192,  0.0057],
         [ 0.0129,  0.0290, -0.0277],
         [-0.0143,  0.0313, -0.0219]],

        [[ 0.0063,  0.0192,  0.0057],
         [ 0.0129,  0.0290, -0.0277],
         [-0.0143,  0.0313, -0.0219]]], device='cuda:0',
       grad_fn=<SliceBackward0>)
----------
input: torch.Size([1, 35])
tokens: torch.Size([1, 35])
pos_embed: torch.Size([1, 35, 768])
output: torch.Size([1, 35, 768])
reference_output: torch.Size([1, 35, 768])
100.00% of the values are correct
output: tensor([[[-0.0188, -0.1974,  0.0040],
         [ 0.0240, -0.0538, -0.0949],
         [ 0.0042, -0.0848,  0.0545]]], device='cuda:0',
       grad_fn=<SliceBackward0>)
----------


### attention

* step 1: produce an attention pattern - for each destination tokens, probability distribution over previous tokens (incl current token)
    - linear map from input -> query, key. shape = [batch, position, n_heads, d_head]
    - dot product every pair of queries and keys to get attn_scores. shape = [batch, n_heads, query_pos, key_pos] (query = dest, key = source)
    - scale and mask attn_scores to make it lower triangular, ie causal
    - softmax row-wise, to get a probability distribution along each of the key_pos dimension - this is our attention pattern!

* step 2: move information from source tokens to destination token using attention pattern (move = apply linear map)
    - linear map from input -> value. shape = [batch, key_pos, n_heads, d_head]
    - mix along the key_pos with attn pattern to get z, a mixed value. shape of z = [batch, query_pos, n_heads, d_head]
    - map to output. shape = [batch, position, d_model]. position = query_pos, we've summed over all heads

In [16]:
# import pysvelte
# pysvelte.AttentionMulti(
#     tokens=reference_gpt2.to_str_tokens(reference_text), 
#     attention=cache['blocks.0.attn.hook_attn'][0].permute(1, 2, 0)
# ).show()

In [19]:
class Attention(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        
        self.W_Q = nn.Parameter(torch.empty((self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head)))
        nn.init.normal(self.W_Q, std=self.cfg.init_range)
        self.b_Q = nn.Parameter(torch.zeros((self.cfg.n_heads, self.cfg.d_head)))
        
        self.W_K = nn.Parameter(torch.empty((self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head)))
        nn.init.normal(self.W_K, std=self.cfg.init_range)
        self.b_K = nn.Parameter(torch.zeros((self.cfg.n_heads, self.cfg.d_head)))
        
        self.W_V = nn.Parameter(torch.empty((self.cfg.n_heads, self.cfg.d_model, self.cfg.d_head)))
        nn.init.normal(self.W_V, std=self.cfg.init_range)
        self.b_V = nn.Parameter(torch.zeros((self.cfg.n_heads, self.cfg.d_head)))
        
        self.W_O = nn.Parameter(torch.empty((self.cfg.n_heads, self.cfg.d_head, self.cfg.d_model)))
        nn.init.normal(self.W_O, std=self.cfg.init_range)
        self.b_O = nn.Parameter(torch.zeros((self.cfg.d_model)))
        
        self.register_buffer(
            "IGNORE",
            torch.tensor(-1e5, dtype=torch.float32, device="cuda")
        )
        
    def forward(self, normalized_resid_pre):
        # normalized_resid_pre: [batch, position, d_model]
        if self.cfg.debug:
            print(f"normalized_resid_pre: {normalized_resid_pre.shape}")
            
        q = einsum(
            "batch query_pos d_model, n_heads d_model d_head -> batch query_pos n_heads d_head",
            normalized_resid_pre,
            self.W_Q
        ) + self.b_Q
        k = einsum(
            "batch key_pos d_model, n_heads d_model d_head -> batch key_pos n_heads d_head",
            normalized_resid_pre,
            self.W_K
        ) + self.b_K
        
        attn_scores = einsum(
            "batch query_pos n_heads d_model, batch key_pos n_heads d_model -> batch n_heads query_pos key_pos",
            q,
            k
        )
        attn_scores = attn_scores / math.sqrt(self.cfg.d_head)
        attn_scores = self.apply_causal_nask(attn_scores)
        
        pattern = attn_scores.softmax(dim=-1)
        
        v = einsum(
            "batch key_pos d_model, n_heads d_model d_head -> batch key_pos n_heads d_head",
            normalized_resid_pre,
            self.W_V
        ) + self.b_V
        
        z = einsum(
            "batch n_heads query_pos key_pos, batch key_pos n_heads d_head -> batch query_pos n_heads d_head", 
            pattern, 
            v
        )

        attn_out = einsum(
            "batch query_pos n_heads d_head, n_heads d_head d_model -> batch query_pos d_model", 
            z, 
            self.W_O
        ) + self.b_O
        return attn_out
                
      
    def apply_causal_nask(self, attn_scores):
        # attn_scores: [batch, n_heads, query_pos, key_pos]  
        mask = torch.triu(torch.ones(attn_scores.size(-2), attn_scores.size(-1), device=attn_scores.device), diagonal=1).bool()
        attn_scores.masked_fill_(mask, self.IGNORE)
        return attn_scores

In [20]:
_ = rand_float_test(Attention, [2, 4, 768])
_ = load_gpt2_test(cls=Attention, gpt2_layer=reference_gpt2.blocks[0].attn, input_name=cache["blocks.0.ln1.hook_normalized"])

input: torch.Size([2, 4, 768])
normalized_resid_pre: torch.Size([2, 4, 768])
output: torch.Size([2, 4, 768])
output: tensor([[[-0.2244,  0.4576, -0.2527],
         [-0.3080,  0.2571,  0.0214],
         [-0.0862,  0.0386,  0.0249]],

        [[-0.0063, -0.6488, -0.0961],
         [ 0.0497, -0.3078,  0.0510],
         [ 0.0682, -0.0271, -0.0599]]], device='cuda:0',
       grad_fn=<SliceBackward0>)
----------
input: torch.Size([1, 35, 768])
normalized_resid_pre: torch.Size([1, 35, 768])
output: torch.Size([1, 35, 768])
reference_output: torch.Size([1, 35, 768])
100.00% of the values are correct
output: tensor([[[ 0.7966,  0.0170,  0.0348],
         [ 0.0013,  0.1575, -0.1406],
         [ 0.0897, -0.7241, -0.6987]]], device='cuda:0',
       grad_fn=<SliceBackward0>)
----------


### mlp

In [21]:
class MLP(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_in = nn.Parameter(torch.empty((self.cfg.d_model, self.cfg.d_mlp)))
        nn.init.normal_(self.W_in, std=self.cfg.init_range)
        self.b_in = nn.Parameter(torch.zeros((self.cfg.d_mlp)))
        
        self.W_out = nn.Parameter(torch.empty((self.cfg.d_mlp, self.cfg.d_model)))
        nn.init.normal_(self.W_out, std=self.cfg.init_range)
        self.b_out = nn.Parameter(torch.zeros((self.cfg.d_model)))
        
    def forward(self, normalized_resid_mid):
        # normalized_resid_mid: [batch, position, d_model]
        if self.cfg.debug:
            print(f"normalized_resid_mid: {normalized_resid_mid.shape}")
        
        pre = einsum(
            "batch position d_model, d_model d_mlp -> batch position d_mlp",
            normalized_resid_mid,
            self.W_in
        ) + self.b_in
        
        post = gelu_new(pre)

        mlp_out = einsum(
            "batch position d_mlp, d_mlp d_model -> batch position d_model",
            post,
            self.W_out 
        ) + self.b_out
        
        if self.cfg.debug:
            print(f"mlp_out: {mlp_out.shape}")
        return mlp_out

In [22]:
_ = rand_float_test(MLP, [2, 4, 768])
_ = load_gpt2_test(MLP, reference_gpt2.blocks[0].mlp, cache["blocks.0.ln2.hook_normalized"])

input: torch.Size([2, 4, 768])
normalized_resid_mid: torch.Size([2, 4, 768])
mlp_out: torch.Size([2, 4, 768])
output: torch.Size([2, 4, 768])
output: tensor([[[ 0.1754,  0.1624, -0.2302],
         [-0.4422, -0.4709,  0.0976],
         [ 0.2639, -0.2979,  0.2109]],

        [[ 0.0609,  0.2412,  0.2392],
         [ 0.1290, -0.1086,  0.1249],
         [ 0.1668, -0.4153,  0.3428]]], device='cuda:0',
       grad_fn=<SliceBackward0>)
----------
input: torch.Size([1, 35, 768])
normalized_resid_mid: torch.Size([1, 35, 768])
mlp_out: torch.Size([1, 35, 768])
output: torch.Size([1, 35, 768])
reference_output: torch.Size([1, 35, 768])
100.00% of the values are correct
output: tensor([[[-0.4380,  0.3624,  0.5117],
         [-1.0766, -0.0438,  0.3276],
         [-1.2182, -1.5481, -0.9702]]], device='cuda:0',
       grad_fn=<SliceBackward0>)
----------


### transformer block


In [23]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)
    
    def forward(self, resid_pre):
        # resid_pre = [batch, position, d_model]
        normalized_resid_pre = self.ln1(resid_pre)
        attn_out = self.attn(normalized_resid_pre)
        resid_mid = resid_pre + attn_out
        
        normalized_resid_mid = self.ln2(resid_mid)
        mlp_out = self.mlp(normalized_resid_mid)
        resid_post = resid_mid + mlp_out
        
        return resid_post

In [24]:
_ = rand_float_test(TransformerBlock, [2, 4, 768])
_ = load_gpt2_test(TransformerBlock, reference_gpt2.blocks[0], cache["resid_pre", 0])

input: torch.Size([2, 4, 768])
residual: torch.Size([2, 4, 768])
normalized: torch.Size([2, 4, 768])
normalized_resid_pre: torch.Size([2, 4, 768])
residual: torch.Size([2, 4, 768])
normalized: torch.Size([2, 4, 768])
normalized_resid_mid: torch.Size([2, 4, 768])
mlp_out: torch.Size([2, 4, 768])
output: torch.Size([2, 4, 768])
output: tensor([[[ 0.3182, -1.0828, -0.7474],
         [-0.5555, -0.1945,  0.5684],
         [ 1.9297,  0.7157, -0.6533]],

        [[-0.8014,  0.2286,  1.2266],
         [-0.1060,  0.3313, -0.3862],
         [-0.5618, -1.1045,  0.7777]]], device='cuda:0',
       grad_fn=<SliceBackward0>)
----------
input: torch.Size([1, 35, 768])
residual: torch.Size([1, 35, 768])
normalized: torch.Size([1, 35, 768])
normalized_resid_pre: torch.Size([1, 35, 768])
residual: torch.Size([1, 35, 768])
normalized: torch.Size([1, 35, 768])
normalized_resid_mid: torch.Size([1, 35, 768])
mlp_out: torch.Size([1, 35, 768])
output: torch.Size([1, 35, 768])
reference_output: torch.Size([1, 3

### unembedding

In [25]:
class Unembed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_U = nn.Parameter(torch.empty((self.cfg.d_model, self.cfg.d_vocab)))
        nn.init.normal_(self.W_U, std=self.cfg.init_range)
        self.b_U = nn.Parameter(torch.zeros((self.cfg.d_vocab), requires_grad=False))
        
    def forward(self, normalized_resid_final):
        # normalized_resid_final = [batch position, d_model]
        if self.cfg.debug:
            print(f'normalized_resid_final: {normalized_resid_final.shape}')
        
        # logits = [batch, position, d_vocab]
        logits = einsum(
            "batch position d_model, d_model d_vocab -> batch position d_vocab",
            normalized_resid_final,
            self.W_U
        ) + self.b_U
        
        if self.cfg.debug:
            print(f'logits: {logits.shape}')
        return logits

_ = rand_float_test(Unembed, [2, 4, 768])
_ = load_gpt2_test(Unembed, reference_gpt2.unembed, cache["ln_final.hook_normalized"])


input: torch.Size([2, 4, 768])
normalized_resid_final: torch.Size([2, 4, 768])
logits: torch.Size([2, 4, 50257])
output: torch.Size([2, 4, 50257])
output: tensor([[[-0.5075, -0.8049, -1.0693],
         [ 0.0154,  0.1165, -0.6128],
         [ 0.2335,  0.4006, -0.2663]],

        [[-1.2108, -0.1579,  0.1332],
         [-0.4549, -0.6426,  0.3444],
         [-0.0937,  0.2746,  0.3755]]], device='cuda:0',
       grad_fn=<SliceBackward0>)
----------
input: torch.Size([1, 35, 768])
normalized_resid_final: torch.Size([1, 35, 768])
logits: torch.Size([1, 35, 50257])
output: torch.Size([1, 35, 50257])
reference_output: torch.Size([1, 35, 50257])
100.00% of the values are correct
output: tensor([[[ -43.4317,  -39.8364,  -43.0659],
         [-128.0392, -127.9935, -130.7010],
         [-119.8521, -121.0064, -123.8819]]], device='cuda:0',
       grad_fn=<SliceBackward0>)
----------


### full transformer

In [26]:
class DemoTransformer(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.embed = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        self.blocks = nn.ModuleList(
            [TransformerBlock(cfg) for _ in range(self.cfg.n_layers)]
        )
        self.ln_final = LayerNorm(cfg)
        self.unembed = Unembed(cfg)
        
    def forward(self, tokens):
        # tokens = [batch, position]
        embed = self.embed(tokens)
        pos_embed = self.pos_embed(embed)
        residual = embed + pos_embed # shape = [batch, position, d_model]
        for block in self.blocks:
            residual = block(residual)
        normalized_resid_final = self.ln_final(residual) # shape = [batch, position, d_model]
        logits = self.unembed(normalized_resid_final) # shape = [batch, position, d_vocab]
        return logits

_ = rand_int_test(DemoTransformer, [2, 4])
_ = load_gpt2_test(DemoTransformer, reference_gpt2, tokens)

input: torch.Size([2, 4])
tokens: torch.Size([2, 4])
embed: torch.Size([2, 4, 768])
tokens: torch.Size([2, 4, 768])
pos_embed: torch.Size([2, 4, 768])
residual: torch.Size([2, 4, 768])
normalized: torch.Size([2, 4, 768])
normalized_resid_pre: torch.Size([2, 4, 768])
residual: torch.Size([2, 4, 768])
normalized: torch.Size([2, 4, 768])
normalized_resid_mid: torch.Size([2, 4, 768])
mlp_out: torch.Size([2, 4, 768])
residual: torch.Size([2, 4, 768])
normalized: torch.Size([2, 4, 768])
normalized_resid_pre: torch.Size([2, 4, 768])
residual: torch.Size([2, 4, 768])
normalized: torch.Size([2, 4, 768])
normalized_resid_mid: torch.Size([2, 4, 768])
mlp_out: torch.Size([2, 4, 768])
residual: torch.Size([2, 4, 768])
normalized: torch.Size([2, 4, 768])
normalized_resid_pre: torch.Size([2, 4, 768])
residual: torch.Size([2, 4, 768])
normalized: torch.Size([2, 4, 768])
normalized_resid_mid: torch.Size([2, 4, 768])
mlp_out: torch.Size([2, 4, 768])
residual: torch.Size([2, 4, 768])
normalized: torch.Si

### try it out!

In [27]:
demo_gpt2 = DemoTransformer(Config(debug=True))
demo_gpt2.load_state_dict(reference_gpt2.state_dict(), strict=False)
demo_gpt2.cuda()

DemoTransformer(
  (embed): Embed()
  (pos_embed): PosEmbed()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNorm()
      (attn): Attention()
      (ln2): LayerNorm()
      (mlp): MLP()
    )
  )
  (ln_final): LayerNorm()
  (unembed): Unembed()
)

In [28]:
test_string = """Mini scule is a species of microhylid frog endemic to Madagascar that was described in 2019. The scientific name of the species refers to its size, being a pun on the word minuscule. It is very small, measuring only 8.4 to 10.8 mm (0.33 to 0.43 in) in snout–vent length. It has bronze underparts with a brown groin and back of the thigh, cream upperparts with brown flecking, a dark brown side of the head, and a red iris. On the hind feet, the first toe is absent and the second and fifth toes are strongly reduced. The frog is known only from the Sainte Luce Reserve, where it inhabits areas with deep leaf litter near semi-permanent water bodies. Specimens of frogs from Mandena, the Vohimena mountains, the southern Anosy Mountains, and Tsitongambarika may also be of this species. Along with Mini mum and Mini ature, the other two species in its genus, it received media attention when first described due to the wordplay in its scientific name. (Full article...)"""

In [29]:
test_tokens = reference_gpt2.to_tokens(test_string).cuda()
demo_logits = demo_gpt2(test_tokens)

tokens: torch.Size([1, 237])
embed: torch.Size([1, 237, 768])
tokens: torch.Size([1, 237, 768])
pos_embed: torch.Size([1, 237, 768])
residual: torch.Size([1, 237, 768])
normalized: torch.Size([1, 237, 768])
normalized_resid_pre: torch.Size([1, 237, 768])
residual: torch.Size([1, 237, 768])
normalized: torch.Size([1, 237, 768])
normalized_resid_mid: torch.Size([1, 237, 768])
mlp_out: torch.Size([1, 237, 768])
residual: torch.Size([1, 237, 768])
normalized: torch.Size([1, 237, 768])
normalized_resid_pre: torch.Size([1, 237, 768])
residual: torch.Size([1, 237, 768])
normalized: torch.Size([1, 237, 768])
normalized_resid_mid: torch.Size([1, 237, 768])
mlp_out: torch.Size([1, 237, 768])
residual: torch.Size([1, 237, 768])
normalized: torch.Size([1, 237, 768])
normalized_resid_pre: torch.Size([1, 237, 768])
residual: torch.Size([1, 237, 768])
normalized: torch.Size([1, 237, 768])
normalized_resid_mid: torch.Size([1, 237, 768])
mlp_out: torch.Size([1, 237, 768])
residual: torch.Size([1, 237, 

In [30]:
def lm_cross_entropy_loss(logits, tokens):
    # Measure next token loss
    # Logits have shape [batch, position, d_vocab]
    # Tokens have shape [batch, position]
    log_probs = logits.log_softmax(dim=-1)
    pred_log_probs = log_probs[:, :-1].gather(dim=-1, index=tokens[:, 1:].unsqueeze(-1)).squeeze(-1)
    return -pred_log_probs.mean()
loss = lm_cross_entropy_loss(demo_logits, test_tokens)
print(loss)
print("Loss as average prob", (-loss).exp())
print("Loss as 'uniform over this many variables'", (loss).exp())
print("Uniform loss over the vocab", math.log(demo_gpt2.cfg.d_vocab))

tensor(3.7186, device='cuda:0', grad_fn=<NegBackward0>)
Loss as average prob tensor(0.0243, device='cuda:0', grad_fn=<ExpBackward0>)
Loss as 'uniform over this many variables' tensor(41.2079, device='cuda:0', grad_fn=<ExpBackward0>)
Uniform loss over the vocab 10.82490511970208
