In [2]:
import einops
from fancy_einsum import einsum
from dataclasses import dataclass
from easy_transformer import EasyTransformer
import torch
import torch.nn as nn
import numpy as np
import math
from easy_transformer.utils import get_corner, gelu_new, tokenize_and_concatenate
import tqdm.auto as tqdm

In [3]:
reference_gpt2 = EasyTransformer.from_pretrained("gpt2", fold_ln=False, center_unembed=False, center_writing_weights=False)

Moving model to device:  cpu
Finished loading pretrained model gpt2 into EasyTransformer!


In [4]:
reference_text = """Google decreased revenue by 50% this quarter. // Sell
Apple increased revenue by 50% this quarter. // Buy
OpenAI increased revenue by 50% this quarter. // Buy
Tesla decreased revenue by 50% this quarter. // Sell
Audi increased revenue by 50% this quarter. // Buy
Microsoft decreased revenue by 50% this quarter. // Sell
Ford decreased decreased revenue by 40% this quarter. //"""

tokens = reference_gpt2.to_tokens(reference_text)
print(tokens)
print(tokens.shape)
print(reference_gpt2.to_str_tokens(tokens))

tensor([[50256, 11708, 11832,  6426,   416,  2026,     4,   428,  3860,    13,
          3373, 25688,   198, 16108,  3220,  6426,   416,  2026,     4,   428,
          3860,    13,  3373, 11763,   198, 11505, 20185,  3220,  6426,   416,
          2026,     4,   428,  3860,    13,  3373, 11763,   198, 41351, 11832,
          6426,   416,  2026,     4,   428,  3860,    13,  3373, 25688,   198,
         16353,    72,  3220,  6426,   416,  2026,     4,   428,  3860,    13,
          3373, 11763,   198, 15905, 11832,  6426,   416,  2026,     4,   428,
          3860,    13,  3373, 25688,   198, 37308, 11832, 11832,  6426,   416,
          2319,     4,   428,  3860,    13,  3373]])
torch.Size([1, 86])
['<|endoftext|>', 'Google', ' decreased', ' revenue', ' by', ' 50', '%', ' this', ' quarter', '.', ' //', ' Sell', '\n', 'Apple', ' increased', ' revenue', ' by', ' 50', '%', ' this', ' quarter', '.', ' //', ' Buy', '\n', 'Open', 'AI', ' increased', ' revenue', ' by', ' 50', '%', ' this', ' qua

In [5]:
logits, cache = reference_gpt2.run_with_cache(tokens)
print(logits.shape)

torch.Size([1, 86, 50257])


In [6]:
log_probs = logits.log_softmax(dim=-1)
probs = logits.log_softmax(dim=-1)
print(log_probs.shape)
print(probs.shape)

torch.Size([1, 86, 50257])
torch.Size([1, 86, 50257])


In [7]:
list(zip(reference_gpt2.to_str_tokens(reference_text), reference_gpt2.tokenizer.batch_decode(probs.argmax(dim=-1)[0])))

[('<|endoftext|>', '\n'),
 ('Google', ','),
 (' decreased', ','),
 (' revenue', ','),
 (' by', ','),
 (' 50', ','),
 ('%', ','),
 (' this', ','),
 (' quarter', ','),
 ('.', ','),
 (' //', ','),
 (' Sell', ','),
 ('\n', '\n'),
 ('Apple', '\n'),
 (' increased', ','),
 (' revenue', ','),
 (' by', ','),
 (' 50', ','),
 ('%', '\n'),
 (' this', '\n'),
 (' quarter', ','),
 ('.', '\n'),
 (' //', ','),
 (' Buy', ','),
 ('\n', '\n'),
 ('Open', '\n'),
 ('AI', '\n'),
 (' increased', ','),
 (' revenue', ','),
 (' by', ','),
 (' 50', ','),
 ('%', ','),
 (' this', ','),
 (' quarter', '\n'),
 ('.', '\n'),
 (' //', ','),
 (' Buy', ','),
 ('\n', '\n'),
 ('Tesla', ','),
 (' decreased', '\n'),
 (' revenue', ','),
 (' by', ','),
 (' 50', ','),
 ('%', ','),
 (' this', '\n'),
 (' quarter', '\n'),
 ('.', '\n'),
 (' //', '\n'),
 (' Sell', '.'),
 ('\n', '\n'),
 ('Aud', ','),
 ('i', ','),
 (' increased', ','),
 (' revenue', ','),
 (' by', ','),
 (' 50', ','),
 ('%', ','),
 (' this', ','),
 (' quarter', '.'),
 ('

## Config

In [7]:
@dataclass
class Config:
    d_model: int = 768
    debug: bool = True
    layer_norm_eps: float = 1e-5
    d_vocab: int = 50257
    init_range: float = 0.02
    n_ctx: int = 1024
    d_head: int = 64
    d_mlp: int = 3072
    n_heads: int = 12
    n_layers: int = 12

cfg = Config()
print(cfg)

Config(d_model=768, debug=True, layer_norm_eps=1e-05, d_vocab=50257, init_range=0.02, n_ctx=1024, d_head=64, d_mlp=3072, n_heads=12, n_layers=12)


## Tests

In [8]:
def rand_float_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg)
    random_input = torch.randn(shape)
    print("Input shape:", random_input.shape)
    output = layer(random_input)
    print("Output shape:", output.shape)
    print()
    return output

def rand_int_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg)
    random_input = torch.randint(100, 1000, shape)
    print("Input shape:", random_input.shape)
    output = layer(random_input)
    print("Output shape:", output.shape)
    print()
    return output

def load_gpt2_test(cls, gpt2_layer, input_name, cache_dict=cache.cache_dict):
    cfg = Config(debug=True)
    layer = cls(cfg)
    layer.load_state_dict(gpt2_layer.state_dict(), strict=False)
    # Allow inputs of strings or tensors
    if isinstance(input_name, str): 
        reference_input = cache_dict[input_name]
    else:
        reference_input = input_name
    print("Input shape:", reference_input.shape)
    output = layer(reference_input)
    print("Output shape:", output.shape)
    reference_output = gpt2_layer(reference_input)
    print("Reference output shape:", reference_output.shape)

    comparison = torch.isclose(output, reference_output, atol=1e-4, rtol=1e-3)
    print(f"{comparison.sum()/comparison.numel():.2%} of the values are correct")
    return output

## Reference Shapes

In [9]:
for activation_name, activation in cache.cache_dict.items():
    # Only print for first layer
    if ".0." in activation_name or "blocks" not in activation_name:
        print(activation_name, activation.shape)

hook_embed torch.Size([1, 35, 768])
hook_pos_embed torch.Size([1, 35, 768])
blocks.0.hook_resid_pre torch.Size([1, 35, 768])
blocks.0.attn.ln1.hook_scale torch.Size([1, 35, 1])
blocks.0.attn.ln1.hook_normalized torch.Size([1, 35, 768])
blocks.0.attn.hook_q torch.Size([1, 35, 12, 64])
blocks.0.attn.hook_k torch.Size([1, 35, 12, 64])
blocks.0.attn.hook_v torch.Size([1, 35, 12, 64])
blocks.0.attn.hook_attn_scores torch.Size([1, 12, 35, 35])
blocks.0.attn.hook_attn torch.Size([1, 12, 35, 35])
blocks.0.attn.hook_z torch.Size([1, 35, 12, 64])
blocks.0.hook_attn_out torch.Size([1, 35, 768])
blocks.0.hook_resid_mid torch.Size([1, 35, 768])
blocks.0.ln2.hook_scale torch.Size([1, 35, 1])
blocks.0.ln2.hook_normalized torch.Size([1, 35, 768])
blocks.0.mlp.hook_pre torch.Size([1, 35, 3072])
blocks.0.mlp.hook_post torch.Size([1, 35, 3072])
blocks.0.hook_mlp_out torch.Size([1, 35, 768])
blocks.0.hook_resid_post torch.Size([1, 35, 768])
ln_final.hook_scale torch.Size([1, 35, 1])
ln_final.hook_normaliz

## Reference parameters

In [10]:
for name, param in reference_gpt2.named_parameters():
    # Only print for first layer
    if ".0." in name or "blocks" not in name:
        print(name, param.shape)

embed.W_E torch.Size([50257, 768])
pos_embed.W_pos torch.Size([1024, 768])
blocks.0.ln1.w torch.Size([768])
blocks.0.ln1.b torch.Size([768])
blocks.0.ln2.w torch.Size([768])
blocks.0.ln2.b torch.Size([768])
blocks.0.attn.W_Q torch.Size([12, 768, 64])
blocks.0.attn.W_K torch.Size([12, 768, 64])
blocks.0.attn.W_V torch.Size([12, 768, 64])
blocks.0.attn.W_O torch.Size([12, 64, 768])
blocks.0.attn.b_Q torch.Size([12, 64])
blocks.0.attn.b_K torch.Size([12, 64])
blocks.0.attn.b_V torch.Size([12, 64])
blocks.0.attn.b_O torch.Size([768])
blocks.0.mlp.W_in torch.Size([768, 3072])
blocks.0.mlp.b_in torch.Size([3072])
blocks.0.mlp.W_out torch.Size([3072, 768])
blocks.0.mlp.b_out torch.Size([768])
ln_final.w torch.Size([768])
ln_final.b torch.Size([768])
unembed.W_U torch.Size([768, 50257])
unembed.b_U torch.Size([50257])


## Layer Norm

In [12]:
class LayerNorm(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(torch.ones(cfg.d_model))
        self.b = nn.Parameter(torch.zeros(cfg.d_model))
    
    def forward(self, residual):
        # residual: [batch, position, d_model]
        reduced = einops.reduce(residual, 'batch position d_model -> batch position', 'mean')

        
        
        centered_res = residual - einops.repeat(reduced, 'batch position -> batch position d_model', d_model=self.cfg.d_model)
        
        # normalize
        scale = einops.reduce(centered_res.pow(2), 'batch position model -> batch position', 'mean') +  self.cfg.layer_norm_eps

        
        normalized = centered_res / einops.repeat(scale.sqrt(), 'batch position -> batch position d_model', d_model=self.cfg.d_model)

        return normalized * self.w + self.b





In [13]:
_ = rand_float_test(LayerNorm, [2, 4, 768])
_ = load_gpt2_test(LayerNorm, reference_gpt2.ln_final, "blocks.11.hook_resid_post")

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
100.00% of the values are correct


## Embedding

In [15]:
class Embed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_E = nn.Parameter(torch.empty((cfg.d_vocab, cfg.d_model)))
        nn.init.normal_(self.W_E, std=self.cfg.init_range)
    
    def forward(self, tokens):
        # tokens: [batch, position]
        return self.W_E[tokens, :]

rand_int_test(Embed, [2, 4])
load_gpt2_test(Embed, reference_gpt2.embed, tokens)

Input shape: torch.Size([2, 4])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
100.00% of the values are correct


tensor([[[ 0.0514, -0.0277,  0.0499,  ...,  0.0070,  0.1552,  0.1207],
         [ 0.1474, -0.0959,  0.1430,  ...,  0.1030, -0.0625, -0.1131],
         [ 0.1596, -0.1249,  0.1148,  ...,  0.2558,  0.0196,  0.0145],
         ...,
         [-0.0393,  0.0050,  0.0421,  ..., -0.0477,  0.0670, -0.0471],
         [-0.1488,  0.1519,  0.0056,  ..., -0.3107,  0.2073,  0.0377],
         [-0.1101, -0.0393,  0.0331,  ..., -0.1364,  0.0151,  0.0453]]],
       grad_fn=<IndexBackward0>)

## Positional Embedding

In [16]:
class PosEmbed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_pos = nn.Parameter(torch.empty((cfg.n_ctx, cfg.d_model)))
        nn.init.normal_(self.W_pos, std=self.cfg.init_range)
    
    def forward(self, tokens):
        if self.cfg.debug: print("Tokens:", tokens.shape)
        pos_embed = self.W_pos[:tokens.size(1), :] # [position, d_model]
        pos_embed = einops.repeat(pos_embed, "position d_model -> batch position d_model", batch=tokens.size(0))
        if self.cfg.debug: print("pos_embed:", pos_embed.shape)
        return pos_embed

rand_int_test(PosEmbed, [2, 4])
load_gpt2_test(PosEmbed, reference_gpt2.pos_embed, tokens)

Input shape: torch.Size([2, 4])
Tokens: torch.Size([2, 4])
pos_embed: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35])
Tokens: torch.Size([1, 35])
pos_embed: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
100.00% of the values are correct


tensor([[[-1.8821e-02, -1.9742e-01,  4.0267e-03,  ..., -4.3044e-02,
           2.8267e-02,  5.4490e-02],
         [ 2.3959e-02, -5.3792e-02, -9.4879e-02,  ...,  3.4170e-02,
           1.0172e-02, -1.5573e-04],
         [ 4.2161e-03, -8.4764e-02,  5.4515e-02,  ...,  1.9745e-02,
           1.9325e-02, -2.1424e-02],
         ...,
         [ 4.6277e-04,  2.3037e-02,  4.1227e-02,  ..., -1.9287e-03,
          -2.3037e-03, -4.3189e-03],
         [-2.7136e-03,  2.1724e-02,  3.9675e-02,  ...,  4.2048e-04,
          -4.8160e-03, -9.2252e-04],
         [ 6.6815e-03,  2.0595e-02,  3.6596e-02,  ..., -9.5090e-04,
          -3.2512e-03, -9.6509e-04]]], grad_fn=<ExpandBackward0>)

## Attention

In [17]:
class Attention(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_Q = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        self.b_Q = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        self.W_K = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_K, std=self.cfg.init_range)
        self.b_K = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        self.W_V = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_V, std=self.cfg.init_range)
        self.b_V = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        
        self.W_O = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
        nn.init.normal_(self.W_O, std=self.cfg.init_range)
        self.b_O = nn.Parameter(torch.zeros((cfg.d_model)))
        
        self.register_buffer("IGNORE", torch.tensor(-1e5, dtype=torch.float32))
    
    def forward(self, normalized_resid_pre):
        # normalized_resid_pre: [batch, position, d_model]

        q = einsum('batch query_pos d_model, n_head d_model d_head -> batch query_pos n_head d_head', normalized_resid_pre, self.W_Q) + self.b_Q
        k = einsum('batch key_pos d_model, n_head d_model d_head -> batch key_pos n_head d_head', normalized_resid_pre, self.W_K) + self.b_K
    
        attn = einsum('batch query_pos n_head d_head, batch key_pos n_head d_head -> batch n_head query_pos key_pos', q, k)
        attn = attn / math.sqrt(self.cfg.d_head) # idk maybe n_heads
        attn = self.apply_causal_mask(attn)
        attn = attn.softmax(dim=-1)

        # take values
        v = einsum('batch key_pos d_model, n_head d_model d_head -> batch key_pos n_head d_head', normalized_resid_pre, self.W_V) + self.b_V
        
        z = einsum('batch key_pos n_head d_head, batch n_head query_pos key_pos -> batch query_pos n_head d_head', v, attn)

        return einsum('batch position n_head d_head, n_head d_head d_model -> batch position d_model', z, self.W_O) + self.b_O

    
    def apply_causal_mask(self, attn_scores):
        # attn_scores: [batch, n_heads, query_pos, key_pos]
        mask = torch.triu(torch.ones(attn_scores.size(-2), attn_scores.size(-1), device=attn_scores.device), diagonal=1).bool()
        attn_scores.masked_fill_(mask, self.IGNORE)
        return attn_scores

rand_float_test(Attention, [2, 4, 768])
load_gpt2_test(Attention, reference_gpt2.blocks[0].attn, cache["blocks.0.attn.ln1.hook_normalized"])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
91.77% of the values are correct


tensor([[[ 1.3649e+00,  2.1711e+00,  7.0825e+00,  ..., -1.4679e-01,
           2.6480e-01,  9.8746e-01],
         [-1.3159e+01, -4.1196e+00,  8.6870e+00,  ..., -4.7698e-01,
          -2.4685e-01,  3.7986e-01],
         [-1.7002e+01,  4.8321e+00, -6.2118e-01,  ..., -7.1945e-01,
           1.0781e+00,  5.4464e-01],
         ...,
         [-1.3211e+01,  7.5175e-01,  8.9662e+00,  ..., -4.2861e-01,
           4.6559e-01, -9.4983e-01],
         [-1.3922e-03,  6.5740e+00,  1.9785e+01,  ..., -6.7092e-01,
          -1.0935e-01,  7.8008e-02],
         [-6.0138e+00, -1.8512e-01,  1.8866e+01,  ..., -5.4550e-01,
          -4.9667e-02, -1.4721e-01]]], grad_fn=<AddBackward0>)

## MLP

In [18]:
class MLP(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_in = nn.Parameter(torch.empty((cfg.d_model, cfg.d_mlp)))
        nn.init.normal_(self.W_in, std=self.cfg.init_range)
        self.b_in = nn.Parameter(torch.zeros((cfg.d_mlp)))
        self.W_out = nn.Parameter(torch.empty((cfg.d_mlp, cfg.d_model)))
        nn.init.normal_(self.W_out, std=self.cfg.init_range)
        self.b_out = nn.Parameter(torch.zeros((cfg.d_model)))
    
    def forward(self, normalized_resid_mid):
        # normalized_resid_mid: [batch, position, d_model]
        fir = einsum("batch position d_model, d_model d_mlp -> batch position d_mlp", normalized_resid_mid, self.W_in) + self.b_in
        sec = gelu_new(fir)
        out = einsum("batch position d_model, d_model d_mlp -> batch position d_mlp", sec, self.W_out) + self.b_out
        return out

rand_float_test(MLP, [2, 4, 768])
load_gpt2_test(MLP, reference_gpt2.blocks[0].mlp, cache["blocks.0.ln2.hook_normalized"])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
100.00% of the values are correct


tensor([[[-0.5493,  1.2235,  1.7083,  ...,  0.3605, -0.3244, -1.0762],
         [-0.5139,  1.5578,  1.1685,  ...,  0.1968, -0.4189,  0.2061],
         [ 0.8885, -0.2943,  1.2722,  ...,  0.1263, -0.1853, -0.6608],
         ...,
         [ 0.3520,  0.6832,  0.3432,  ...,  0.4583,  0.2621, -1.3458],
         [ 1.5537,  1.5894,  0.7426,  ..., -0.2869, -0.6618,  0.4345],
         [ 1.1247,  1.8329, -0.5603,  ...,  0.4421, -0.3301, -0.4539]]],
       grad_fn=<AddBackward0>)

## Transformer Block

In [None]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)
    
    def forward(self, resid_pre):
        # resid_pre [batch, position, d_model]
        norm_resid_pre = self.ln1(resid_pre)
        resid_pre += self.attn(norm_resid_pre)
        resid_mid = self.ln2(resid_pre)
        return resid_pre + self.mlp(resid_mid)


rand_float_test(TransformerBlock, [2, 4, 768])
load_gpt2_test(TransformerBlock, reference_gpt2.blocks[0], cache["resid_pre", 0])

## Unembedding

In [27]:
class Unembed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_U = nn.Parameter(torch.empty((cfg.d_model, cfg.d_vocab)))
        nn.init.normal_(self.W_U, std=self.cfg.init_range)
        self.b_U = nn.Parameter(torch.zeros((cfg.d_vocab), requires_grad=False))
    
    def forward(self, normalized_resid_final):
        # normalized_resid_final [batch, position, d_model]
        
        return einsum("batch position d_model, d_model d_vocab -> batch position d_vocab", normalized_resid_final, self.W_U) + self.b_U

rand_float_test(Unembed, [2, 4, 768])
load_gpt2_test(Unembed, reference_gpt2.unembed, cache["ln_final.hook_normalized"])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 50257])

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 50257])
Reference output shape: torch.Size([1, 35, 50257])
100.00% of the values are correct


tensor([[[-52.2799, -51.3340, -53.2627,  ..., -60.2047, -58.1461, -52.3683],
         [-22.4456, -22.5785, -24.8806,  ..., -28.1892, -27.3050, -22.6219],
         [-20.1106, -20.3379, -22.7917,  ..., -24.9888, -24.2043, -20.6488],
         ...,
         [-25.4227, -25.0626, -27.5736,  ..., -29.2495, -29.2111, -26.1180],
         [-20.1737, -20.5050, -23.1387,  ..., -25.3509, -24.2196, -20.6803],
         [-49.3418, -49.4090, -51.9608,  ..., -56.9139, -56.5944, -51.0995]]],
       grad_fn=<AddBackward0>)

## Full Transformer

In [None]:
class DemoTransformer(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.embed = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.ln_final = LayerNorm(cfg)
        self.unembed = Unembed(cfg)
    
    def forward(self, tokens):
        # tokens [batch, position]
        embed = self.embed(tokens)
        pos_embed = self.pos_embed(tokens)
        residual = embed + pos_embed
        for block in self.blocks:
            residual = block(residual)
        normalized_resid_final = self.ln_final(residual)
        logits = self.unembed(normalized_resid_final)
        # logits have shape [batch, position, logits]
        return logits

rand_int_test(DemoTransformer, [2, 4])
load_gpt2_test(DemoTransformer, reference_gpt2, tokens)

In [1]:
import einops
from fancy_einsum import einsum
from dataclasses import dataclass
from easy_transformer import EasyTransformer
import torch
import torch.nn as nn
import numpy as np
import math
from easy_transformer.utils import get_corner, gelu_new, tokenize_and_concatenate
import tqdm.auto as tqdm

In [5]:
reference_gpt2 = EasyTransformer.from_pretrained("gpt2", fold_ln=False, center_unembed=False, center_writing_weights=False)

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

Moving model to device:  cpu
Finished loading pretrained model gpt2 into EasyTransformer!


In [15]:
reference_text = """Apple revenue has increased by 5% this quarter. // Buy
Google missed earning by 10% this quarter. // Sell
OpenAI has replaced the CEO Sam Altman after the board voted. // Sell
Tesla has reduced salaries by 20% after meeting with the unions. // Buy
Ford anticipated its operating profit to improve. //
"""
tokens = reference_gpt2.to_tokens(reference_text)
print(tokens)
print(tokens.shape)
print(reference_gpt2.to_str_tokens(tokens))

tensor([[50256, 16108,  6426,   468,  3220,   416,   642,     4,   428,  3860,
            13,  3373, 11763,   198, 11708,  6825, 13748,   416,   838,     4,
           428,  3860,    13,  3373, 25688,   198, 11505, 20185,   468,  6928,
           262,  6123,  3409, 12344,   805,   706,   262,  3096,  7052,    13,
          3373, 25688,   198, 41351,   468,  5322, 17058,   416,  1160,     4,
           706,  3249,   351,   262, 11936,    13,  3373, 11763,   198, 37308,
         14486,   663,  5361,  7630,   284,  2987,    13,  3373,   198]])
torch.Size([1, 69])
['<|endoftext|>', 'Apple', ' revenue', ' has', ' increased', ' by', ' 5', '%', ' this', ' quarter', '.', ' //', ' Buy', '\n', 'Google', ' missed', ' earning', ' by', ' 10', '%', ' this', ' quarter', '.', ' //', ' Sell', '\n', 'Open', 'AI', ' has', ' replaced', ' the', ' CEO', ' Sam', ' Alt', 'man', ' after', ' the', ' board', ' voted', '.', ' //', ' Sell', '\n', 'Tesla', ' has', ' reduced', ' salaries', ' by', ' 20', '%', ' afte

In [16]:
logits, cache = reference_gpt2.run_with_cache(tokens)
print(logits.shape)

torch.Size([1, 69, 50257])


In [17]:
log_probs = logits.log_softmax(dim=-1)
probs = logits.log_softmax(dim=-1)
print(log_probs.shape)
print(probs.shape)


torch.Size([1, 69, 50257])
torch.Size([1, 69, 50257])


In [18]:
list(zip(reference_gpt2.to_str_tokens(reference_text), reference_gpt2.tokenizer.batch_decode(probs.argmax(dim=-1)[0])))

[('<|endoftext|>', '\n'),
 ('Apple', ','),
 (' revenue', ','),
 (' has', ','),
 (' increased', ','),
 (' by', ','),
 (' 5', ','),
 ('%', ','),
 (' this', ','),
 (' quarter', ','),
 ('.', '\n'),
 (' //', ','),
 (' Buy', ','),
 ('\n', '\n'),
 ('Google', ','),
 (' missed', ','),
 (' earning', ','),
 (' by', ' the'),
 (' 10', ' the'),
 ('%', ','),
 (' this', ' the'),
 (' quarter', ' the'),
 ('.', '\n'),
 (' //', ','),
 (' Sell', ','),
 ('\n', '\n'),
 ('Open', ','),
 ('AI', ','),
 (' has', ','),
 (' replaced', ','),
 (' the', ','),
 (' CEO', ','),
 (' Sam', ' the'),
 (' Alt', ' the'),
 ('man', ','),
 (' after', ' the'),
 (' the', ' the'),
 (' board', ','),
 (' voted', ','),
 ('.', '\n'),
 (' //', ' and'),
 (' Sell', ','),
 ('\n', '\n'),
 ('Tesla', ','),
 (' has', ','),
 (' reduced', ','),
 (' salaries', ','),
 (' by', ','),
 (' 20', ','),
 ('%', ' and'),
 (' after', ' the'),
 (' meeting', ','),
 (' with', ','),
 (' the', ','),
 (' unions', ' and'),
 ('.', '\n'),
 (' //', ','),
 (' Buy', ','

## Config

In [7]:
@dataclass
class Config:
    d_model: int = 768
    debug: bool = True
    layer_norm_eps: float = 1e-5
    d_vocab: int = 50257
    init_range: float = 0.02
    n_ctx: int = 1024
    d_head: int = 64
    d_mlp: int = 3072
    n_heads: int = 12
    n_layers: int = 12

cfg = Config()
print(cfg)

Config(d_model=768, debug=True, layer_norm_eps=1e-05, d_vocab=50257, init_range=0.02, n_ctx=1024, d_head=64, d_mlp=3072, n_heads=12, n_layers=12)


## Tests

In [8]:
def rand_float_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg)
    random_input = torch.randn(shape)
    print("Input shape:", random_input.shape)
    output = layer(random_input)
    print("Output shape:", output.shape)
    print()
    return output

def rand_int_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg)
    random_input = torch.randint(100, 1000, shape)
    print("Input shape:", random_input.shape)
    output = layer(random_input)
    print("Output shape:", output.shape)
    print()
    return output

def load_gpt2_test(cls, gpt2_layer, input_name, cache_dict=cache.cache_dict):
    cfg = Config(debug=True)
    layer = cls(cfg)
    layer.load_state_dict(gpt2_layer.state_dict(), strict=False)
    # Allow inputs of strings or tensors
    if isinstance(input_name, str): 
        reference_input = cache_dict[input_name]
    else:
        reference_input = input_name
    print("Input shape:", reference_input.shape)
    output = layer(reference_input)
    print("Output shape:", output.shape)
    reference_output = gpt2_layer(reference_input)
    print("Reference output shape:", reference_output.shape)

    comparison = torch.isclose(output, reference_output, atol=1e-4, rtol=1e-3)
    print(f"{comparison.sum()/comparison.numel():.2%} of the values are correct")
    return output

## Reference Shapes

In [9]:
for activation_name, activation in cache.cache_dict.items():
    # Only print for first layer
    if ".0." in activation_name or "blocks" not in activation_name:
        print(activation_name, activation.shape)

hook_embed torch.Size([1, 35, 768])
hook_pos_embed torch.Size([1, 35, 768])
blocks.0.hook_resid_pre torch.Size([1, 35, 768])
blocks.0.attn.ln1.hook_scale torch.Size([1, 35, 1])
blocks.0.attn.ln1.hook_normalized torch.Size([1, 35, 768])
blocks.0.attn.hook_q torch.Size([1, 35, 12, 64])
blocks.0.attn.hook_k torch.Size([1, 35, 12, 64])
blocks.0.attn.hook_v torch.Size([1, 35, 12, 64])
blocks.0.attn.hook_attn_scores torch.Size([1, 12, 35, 35])
blocks.0.attn.hook_attn torch.Size([1, 12, 35, 35])
blocks.0.attn.hook_z torch.Size([1, 35, 12, 64])
blocks.0.hook_attn_out torch.Size([1, 35, 768])
blocks.0.hook_resid_mid torch.Size([1, 35, 768])
blocks.0.ln2.hook_scale torch.Size([1, 35, 1])
blocks.0.ln2.hook_normalized torch.Size([1, 35, 768])
blocks.0.mlp.hook_pre torch.Size([1, 35, 3072])
blocks.0.mlp.hook_post torch.Size([1, 35, 3072])
blocks.0.hook_mlp_out torch.Size([1, 35, 768])
blocks.0.hook_resid_post torch.Size([1, 35, 768])
ln_final.hook_scale torch.Size([1, 35, 1])
ln_final.hook_normaliz

## Reference parameters

In [10]:
for name, param in reference_gpt2.named_parameters():
    # Only print for first layer
    if ".0." in name or "blocks" not in name:
        print(name, param.shape)

embed.W_E torch.Size([50257, 768])
pos_embed.W_pos torch.Size([1024, 768])
blocks.0.ln1.w torch.Size([768])
blocks.0.ln1.b torch.Size([768])
blocks.0.ln2.w torch.Size([768])
blocks.0.ln2.b torch.Size([768])
blocks.0.attn.W_Q torch.Size([12, 768, 64])
blocks.0.attn.W_K torch.Size([12, 768, 64])
blocks.0.attn.W_V torch.Size([12, 768, 64])
blocks.0.attn.W_O torch.Size([12, 64, 768])
blocks.0.attn.b_Q torch.Size([12, 64])
blocks.0.attn.b_K torch.Size([12, 64])
blocks.0.attn.b_V torch.Size([12, 64])
blocks.0.attn.b_O torch.Size([768])
blocks.0.mlp.W_in torch.Size([768, 3072])
blocks.0.mlp.b_in torch.Size([3072])
blocks.0.mlp.W_out torch.Size([3072, 768])
blocks.0.mlp.b_out torch.Size([768])
ln_final.w torch.Size([768])
ln_final.b torch.Size([768])
unembed.W_U torch.Size([768, 50257])
unembed.b_U torch.Size([50257])


## Layer Norm

In [12]:
class LayerNorm(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(torch.ones(cfg.d_model))
        self.b = nn.Parameter(torch.zeros(cfg.d_model))
    
    def forward(self, residual):
        # residual: [batch, position, d_model]
        reduced = einops.reduce(residual, 'batch position d_model -> batch position', 'mean')

        
        
        centered_res = residual - einops.repeat(reduced, 'batch position -> batch position d_model', d_model=self.cfg.d_model)
        
        # normalize
        scale = einops.reduce(centered_res.pow(2), 'batch position model -> batch position', 'mean') +  self.cfg.layer_norm_eps

        
        normalized = centered_res / einops.repeat(scale.sqrt(), 'batch position -> batch position d_model', d_model=self.cfg.d_model)

        return normalized * self.w + self.b





In [13]:
_ = rand_float_test(LayerNorm, [2, 4, 768])
_ = load_gpt2_test(LayerNorm, reference_gpt2.ln_final, "blocks.11.hook_resid_post")

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
100.00% of the values are correct


## Embedding

In [15]:
class Embed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_E = nn.Parameter(torch.empty((cfg.d_vocab, cfg.d_model)))
        nn.init.normal_(self.W_E, std=self.cfg.init_range)
    
    def forward(self, tokens):
        # tokens: [batch, position]
        return self.W_E[tokens, :]

rand_int_test(Embed, [2, 4])
load_gpt2_test(Embed, reference_gpt2.embed, tokens)

Input shape: torch.Size([2, 4])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
100.00% of the values are correct


tensor([[[ 0.0514, -0.0277,  0.0499,  ...,  0.0070,  0.1552,  0.1207],
         [ 0.1474, -0.0959,  0.1430,  ...,  0.1030, -0.0625, -0.1131],
         [ 0.1596, -0.1249,  0.1148,  ...,  0.2558,  0.0196,  0.0145],
         ...,
         [-0.0393,  0.0050,  0.0421,  ..., -0.0477,  0.0670, -0.0471],
         [-0.1488,  0.1519,  0.0056,  ..., -0.3107,  0.2073,  0.0377],
         [-0.1101, -0.0393,  0.0331,  ..., -0.1364,  0.0151,  0.0453]]],
       grad_fn=<IndexBackward0>)

## Positional Embedding

In [16]:
class PosEmbed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_pos = nn.Parameter(torch.empty((cfg.n_ctx, cfg.d_model)))
        nn.init.normal_(self.W_pos, std=self.cfg.init_range)
    
    def forward(self, tokens):
        if self.cfg.debug: print("Tokens:", tokens.shape)
        pos_embed = self.W_pos[:tokens.size(1), :] # [position, d_model]
        pos_embed = einops.repeat(pos_embed, "position d_model -> batch position d_model", batch=tokens.size(0))
        if self.cfg.debug: print("pos_embed:", pos_embed.shape)
        return pos_embed

rand_int_test(PosEmbed, [2, 4])
load_gpt2_test(PosEmbed, reference_gpt2.pos_embed, tokens)

Input shape: torch.Size([2, 4])
Tokens: torch.Size([2, 4])
pos_embed: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35])
Tokens: torch.Size([1, 35])
pos_embed: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
100.00% of the values are correct


tensor([[[-1.8821e-02, -1.9742e-01,  4.0267e-03,  ..., -4.3044e-02,
           2.8267e-02,  5.4490e-02],
         [ 2.3959e-02, -5.3792e-02, -9.4879e-02,  ...,  3.4170e-02,
           1.0172e-02, -1.5573e-04],
         [ 4.2161e-03, -8.4764e-02,  5.4515e-02,  ...,  1.9745e-02,
           1.9325e-02, -2.1424e-02],
         ...,
         [ 4.6277e-04,  2.3037e-02,  4.1227e-02,  ..., -1.9287e-03,
          -2.3037e-03, -4.3189e-03],
         [-2.7136e-03,  2.1724e-02,  3.9675e-02,  ...,  4.2048e-04,
          -4.8160e-03, -9.2252e-04],
         [ 6.6815e-03,  2.0595e-02,  3.6596e-02,  ..., -9.5090e-04,
          -3.2512e-03, -9.6509e-04]]], grad_fn=<ExpandBackward0>)

## Attention

In [17]:
class Attention(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_Q = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        self.b_Q = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        self.W_K = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_K, std=self.cfg.init_range)
        self.b_K = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        self.W_V = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_V, std=self.cfg.init_range)
        self.b_V = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        
        self.W_O = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
        nn.init.normal_(self.W_O, std=self.cfg.init_range)
        self.b_O = nn.Parameter(torch.zeros((cfg.d_model)))
        
        self.register_buffer("IGNORE", torch.tensor(-1e5, dtype=torch.float32))
    
    def forward(self, normalized_resid_pre):
        # normalized_resid_pre: [batch, position, d_model]

        q = einsum('batch query_pos d_model, n_head d_model d_head -> batch query_pos n_head d_head', normalized_resid_pre, self.W_Q) + self.b_Q
        k = einsum('batch key_pos d_model, n_head d_model d_head -> batch key_pos n_head d_head', normalized_resid_pre, self.W_K) + self.b_K
    
        attn = einsum('batch query_pos n_head d_head, batch key_pos n_head d_head -> batch n_head query_pos key_pos', q, k)
        attn = attn / math.sqrt(self.cfg.d_head) # idk maybe n_heads
        attn = self.apply_causal_mask(attn)
        attn = attn.softmax(dim=-1)

        # take values
        v = einsum('batch key_pos d_model, n_head d_model d_head -> batch key_pos n_head d_head', normalized_resid_pre, self.W_V) + self.b_V
        
        z = einsum('batch key_pos n_head d_head, batch n_head query_pos key_pos -> batch query_pos n_head d_head', v, attn)

        return einsum('batch position n_head d_head, n_head d_head d_model -> batch position d_model', z, self.W_O) + self.b_O

    
    def apply_causal_mask(self, attn_scores):
        # attn_scores: [batch, n_heads, query_pos, key_pos]
        mask = torch.triu(torch.ones(attn_scores.size(-2), attn_scores.size(-1), device=attn_scores.device), diagonal=1).bool()
        attn_scores.masked_fill_(mask, self.IGNORE)
        return attn_scores

rand_float_test(Attention, [2, 4, 768])
load_gpt2_test(Attention, reference_gpt2.blocks[0].attn, cache["blocks.0.attn.ln1.hook_normalized"])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
91.77% of the values are correct


tensor([[[ 1.3649e+00,  2.1711e+00,  7.0825e+00,  ..., -1.4679e-01,
           2.6480e-01,  9.8746e-01],
         [-1.3159e+01, -4.1196e+00,  8.6870e+00,  ..., -4.7698e-01,
          -2.4685e-01,  3.7986e-01],
         [-1.7002e+01,  4.8321e+00, -6.2118e-01,  ..., -7.1945e-01,
           1.0781e+00,  5.4464e-01],
         ...,
         [-1.3211e+01,  7.5175e-01,  8.9662e+00,  ..., -4.2861e-01,
           4.6559e-01, -9.4983e-01],
         [-1.3922e-03,  6.5740e+00,  1.9785e+01,  ..., -6.7092e-01,
          -1.0935e-01,  7.8008e-02],
         [-6.0138e+00, -1.8512e-01,  1.8866e+01,  ..., -5.4550e-01,
          -4.9667e-02, -1.4721e-01]]], grad_fn=<AddBackward0>)

## MLP

In [18]:
class MLP(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_in = nn.Parameter(torch.empty((cfg.d_model, cfg.d_mlp)))
        nn.init.normal_(self.W_in, std=self.cfg.init_range)
        self.b_in = nn.Parameter(torch.zeros((cfg.d_mlp)))
        self.W_out = nn.Parameter(torch.empty((cfg.d_mlp, cfg.d_model)))
        nn.init.normal_(self.W_out, std=self.cfg.init_range)
        self.b_out = nn.Parameter(torch.zeros((cfg.d_model)))
    
    def forward(self, normalized_resid_mid):
        # normalized_resid_mid: [batch, position, d_model]
        fir = einsum("batch position d_model, d_model d_mlp -> batch position d_mlp", normalized_resid_mid, self.W_in) + self.b_in
        sec = gelu_new(fir)
        out = einsum("batch position d_model, d_model d_mlp -> batch position d_mlp", sec, self.W_out) + self.b_out
        return out

rand_float_test(MLP, [2, 4, 768])
load_gpt2_test(MLP, reference_gpt2.blocks[0].mlp, cache["blocks.0.ln2.hook_normalized"])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
100.00% of the values are correct


tensor([[[-0.5493,  1.2235,  1.7083,  ...,  0.3605, -0.3244, -1.0762],
         [-0.5139,  1.5578,  1.1685,  ...,  0.1968, -0.4189,  0.2061],
         [ 0.8885, -0.2943,  1.2722,  ...,  0.1263, -0.1853, -0.6608],
         ...,
         [ 0.3520,  0.6832,  0.3432,  ...,  0.4583,  0.2621, -1.3458],
         [ 1.5537,  1.5894,  0.7426,  ..., -0.2869, -0.6618,  0.4345],
         [ 1.1247,  1.8329, -0.5603,  ...,  0.4421, -0.3301, -0.4539]]],
       grad_fn=<AddBackward0>)

## Transformer Block

In [26]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)
    
    def forward(self, resid_pre):
        # resid_pre [batch, position, d_model]
        norm_resid_pre = self.ln1(resid_pre)
        resid_pre += self.attn(norm_resid_pre)
        resid_mid = self.ln2(resid_pre)
        return resid_pre + self.mlp(resid_mid)


rand_float_test(TransformerBlock, [2, 4, 768])
load_gpt2_test(TransformerBlock, reference_gpt2.blocks[0], cache["resid_pre", 0])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
0.02% of the values are correct


tensor([[[ 0.7221, -0.0078,  2.6568,  ...,  1.3642,  1.4635,  0.6851],
         [-1.7113, -0.8840,  2.0401,  ...,  0.8022,  0.7299,  0.1252],
         [-1.7759, -2.4956,  1.2355,  ...,  1.2241,  0.6944,  0.1922],
         ...,
         [-2.4313, -1.1537, -0.3350,  ...,  0.7296,  0.4505, -0.3572],
         [-3.0099, -1.0473,  0.6373,  ...,  0.5975,  0.4731, -0.0175],
         [-1.8263, -1.9858,  1.1629,  ...,  0.8742,  0.6700,  0.0920]]],
       grad_fn=<AddBackward0>)

## Unembedding

In [27]:
class Unembed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_U = nn.Parameter(torch.empty((cfg.d_model, cfg.d_vocab)))
        nn.init.normal_(self.W_U, std=self.cfg.init_range)
        self.b_U = nn.Parameter(torch.zeros((cfg.d_vocab), requires_grad=False))
    
    def forward(self, normalized_resid_final):
        # normalized_resid_final [batch, position, d_model]
        
        return einsum("batch position d_model, d_model d_vocab -> batch position d_vocab", normalized_resid_final, self.W_U) + self.b_U

rand_float_test(Unembed, [2, 4, 768])
load_gpt2_test(Unembed, reference_gpt2.unembed, cache["ln_final.hook_normalized"])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 50257])

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 50257])
Reference output shape: torch.Size([1, 35, 50257])
100.00% of the values are correct


tensor([[[-52.2799, -51.3340, -53.2627,  ..., -60.2047, -58.1461, -52.3683],
         [-22.4456, -22.5785, -24.8806,  ..., -28.1892, -27.3050, -22.6219],
         [-20.1106, -20.3379, -22.7917,  ..., -24.9888, -24.2043, -20.6488],
         ...,
         [-25.4227, -25.0626, -27.5736,  ..., -29.2495, -29.2111, -26.1180],
         [-20.1737, -20.5050, -23.1387,  ..., -25.3509, -24.2196, -20.6803],
         [-49.3418, -49.4090, -51.9608,  ..., -56.9139, -56.5944, -51.0995]]],
       grad_fn=<AddBackward0>)

## Full Transformer

In [28]:
class DemoTransformer(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.embed = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.ln_final = LayerNorm(cfg)
        self.unembed = Unembed(cfg)
    
    def forward(self, tokens):
        # tokens [batch, position]
        embed = self.embed(tokens)
        pos_embed = self.pos_embed(tokens)
        residual = embed + pos_embed
        for block in self.blocks:
            residual = block(residual)
        normalized_resid_final = self.ln_final(residual)
        logits = self.unembed(normalized_resid_final)
        # logits have shape [batch, position, logits]
        return logits

rand_int_test(DemoTransformer, [2, 4])
load_gpt2_test(DemoTransformer, reference_gpt2, tokens)

Input shape: torch.Size([2, 4])
Tokens: torch.Size([2, 4])
pos_embed: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 50257])

Input shape: torch.Size([1, 35])
Tokens: torch.Size([1, 35])
pos_embed: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 50257])
Reference output shape: torch.Size([1, 35, 50257])
0.00% of the values are correct


tensor([[[ -43.4318,  -39.8365,  -43.0660,  ...,  -54.0878,  -54.3452,
           -42.3645],
         [-128.0392, -127.9935, -130.7010,  ..., -136.7122, -129.9262,
          -129.3966],
         [-119.8521, -121.0064, -123.8819,  ..., -128.5180, -126.6027,
          -121.9061],
         ...,
         [-112.9815, -112.7749, -117.0633,  ..., -121.2914, -117.6574,
          -114.5005],
         [ -98.6724, -104.4888, -108.7361,  ..., -118.3552, -113.8766,
          -106.3604],
         [-126.8285, -128.9596, -128.3941,  ..., -140.1971, -138.5883,
          -122.3698]]], grad_fn=<AddBackward0>)