<a href="https://colab.research.google.com/github/TDulka/transformers/blob/main/GPT2_Reimplementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# Practice reimplementation of GPT-2

## Setup libraries

In [21]:
%pip install git+https://github.com/neelnanda-io/Easy-Transformer.git@clean-transformer-demo
!curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
%pip install git+https://github.com/neelnanda-io/PySvelte.git

Collecting git+https://github.com/neelnanda-io/Easy-Transformer.git@clean-transformer-demo
  Cloning https://github.com/neelnanda-io/Easy-Transformer.git (to revision clean-transformer-demo) to /tmp/pip-req-build-faab4ymz
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/Easy-Transformer.git /tmp/pip-req-build-faab4ymz
  Running command git checkout -b clean-transformer-demo --track origin/clean-transformer-demo
  Switched to a new branch 'clean-transformer-demo'
  Branch 'clean-transformer-demo' set up to track remote branch 'clean-transformer-demo' from 'origin'.
  Resolved https://github.com/neelnanda-io/Easy-Transformer.git to commit 1f25219e631aeb478d17075d47274db32c874e88
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting einops (from easy-transformer==0.1.0)
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
Co

In [22]:
import einops
from fancy_einsum import einsum
from dataclasses import dataclass
from easy_transformer import EasyTransformer
import torch
import torch.nn as nn
import numpy as np
import math
from easy_transformer.utils import get_corner, gelu_new, tokenize_and_concatenate
import tqdm.auto as tqdm

In [23]:
reference_gpt2 = EasyTransformer.from_pretrained('gpt2-small', fold_ln=False, center_unembed=False, center_writing_weights=False)

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Moving model to device:  cuda
Finished loading pretrained model gpt2-small into EasyTransformer!


In [None]:
sorted(list(reference_gpt2.tokenizer.vocab.items()), key=lambda i:i[1])[-30:]

[('×©', 50227),
 ('romy', 50228),
 ('JM', 50229),
 ('ĠEnhancement', 50230),
 ('bush', 50231),
 ('Skip', 50232),
 ('Ġrappers', 50233),
 ('Ġgazing', 50234),
 ('pedia', 50235),
 ('athlon', 50236),
 ('Revolution', 50237),
 ('Ġsnipers', 50238),
 ('Ġreverted', 50239),
 ('Ġconglomerate', 50240),
 ('Terry', 50241),
 ('794', 50242),
 ('Ġharsher', 50243),
 ('Ġdesolate', 50244),
 ('ĠHitman', 50245),
 ('Commission', 50246),
 ('Ġ(/', 50247),
 ('âĢ¦."', 50248),
 ('Compar', 50249),
 ('Ġamplification', 50250),
 ('ominated', 50251),
 ('Ġregress', 50252),
 ('ĠCollider', 50253),
 ('Ġinformants', 50254),
 ('Ġgazed', 50255),
 ('<|endoftext|>', 50256)]

In [None]:
test_text = "Mary and ralph had a miniscule lamb"
tokens = reference_gpt2.to_tokens(test_text)
tokens

tensor([[50256, 24119,   290,   374, 17307,   550,   257,   949,  2304,  2261,
         19343]])

In [None]:
tokens = tokens.cuda()
logits, cache = reference_gpt2.run_with_cache(tokens)

In [None]:
print(logits.shape)

torch.Size([1, 11, 50257])


In [None]:
log_probs = logits.log_softmax(dim=-1)

In [None]:
logits.argmax(dim=-1)

tensor([[  198,  1044,   314, 10757,   494,   587,  1049, 16241,  2261,  2033,
            11]], device='cuda:0')

In [None]:
reference_gpt2.tokenizer.batch_decode(logits.argmax(dim=-1)[0])

['\n',
 'land',
 ' I',
 'andy',
 'ie',
 ' been',
 ' great',
 'usc',
 'ule',
 ' amount',
 ',']

In [None]:
list(zip(reference_gpt2.to_str_tokens(test_text), reference_gpt2.tokenizer.batch_decode(logits.argmax(dim=-1)[0])))

[('<|endoftext|>', '\n'),
 ('Mary', 'land'),
 (' and', ' I'),
 (' r', 'andy'),
 ('alph', 'ie'),
 (' had', ' been'),
 (' a', ' great'),
 (' min', 'usc'),
 ('isc', 'ule'),
 ('ule', ' amount'),
 (' lamb', ',')]

In [None]:
next_token = logits[0, -1].argmax(dim=-1)
next_token

tensor(11, device='cuda:0')

In [None]:
next_tokens = torch.cat([tokens, torch.tensor(next_token, device='cuda', dtype=torch.int64)[None, None]], dim=-1)
new_logits = reference_gpt2(next_tokens)
print("New Input: ", reference_gpt2.tokenizer.decode(next_tokens[0]))

print(new_logits.shape)
print(reference_gpt2.tokenizer.decode(new_logits[-1,-1].argmax(-1)))


New Input:  <|endoftext|>Mary and ralph had a miniscule lamb,
torch.Size([1, 12, 50257])
 but


  next_tokens = torch.cat([tokens, torch.tensor(next_token, device='cuda', dtype=torch.int64)[None, None]], dim=-1)


In [None]:
# batch = 1
# position = 35
# d_model = 768
# n_heads = 12
# n_layers = 12
# d_mlp = 3072 (4 * d_model) This is apparently arbitrary
# d_head = 64 (d_model / n_heads)

In [None]:
for activation_name, activation in cache.cache_dict.items():
  if '.0.' in activation_name or 'blocks' not in activation_name:
    print(activation_name, activation.shape)

hook_embed torch.Size([1, 11, 768])
hook_pos_embed torch.Size([1, 11, 768])
blocks.0.hook_resid_pre torch.Size([1, 11, 768])
blocks.0.ln1.hook_scale torch.Size([1, 11, 1])
blocks.0.ln1.hook_normalized torch.Size([1, 11, 768])
blocks.0.attn.hook_q torch.Size([1, 11, 12, 64])
blocks.0.attn.hook_k torch.Size([1, 11, 12, 64])
blocks.0.attn.hook_v torch.Size([1, 11, 12, 64])
blocks.0.attn.hook_attn_scores torch.Size([1, 12, 11, 11])
blocks.0.attn.hook_attn torch.Size([1, 12, 11, 11])
blocks.0.attn.hook_z torch.Size([1, 11, 12, 64])
blocks.0.hook_attn_out torch.Size([1, 11, 768])
blocks.0.hook_resid_mid torch.Size([1, 11, 768])
blocks.0.ln2.hook_scale torch.Size([1, 11, 1])
blocks.0.ln2.hook_normalized torch.Size([1, 11, 768])
blocks.0.mlp.hook_pre torch.Size([1, 11, 3072])
blocks.0.mlp.hook_post torch.Size([1, 11, 3072])
blocks.0.hook_mlp_out torch.Size([1, 11, 768])
blocks.0.hook_resid_post torch.Size([1, 11, 768])
ln_final.hook_scale torch.Size([1, 11, 1])
ln_final.hook_normalized torch.S

In [None]:
for name, param in reference_gpt2.named_parameters():
  if '.0.' in name or 'blocks' not in name:
    print(name, param.shape)

embed.W_E torch.Size([50257, 768])
pos_embed.W_pos torch.Size([1024, 768])
blocks.0.ln1.w torch.Size([768])
blocks.0.ln1.b torch.Size([768])
blocks.0.ln2.w torch.Size([768])
blocks.0.ln2.b torch.Size([768])
blocks.0.attn.W_Q torch.Size([12, 768, 64])
blocks.0.attn.W_K torch.Size([12, 768, 64])
blocks.0.attn.W_V torch.Size([12, 768, 64])
blocks.0.attn.W_O torch.Size([12, 64, 768])
blocks.0.attn.b_Q torch.Size([12, 64])
blocks.0.attn.b_K torch.Size([12, 64])
blocks.0.attn.b_V torch.Size([12, 64])
blocks.0.attn.b_O torch.Size([768])
blocks.0.mlp.W_in torch.Size([768, 3072])
blocks.0.mlp.b_in torch.Size([3072])
blocks.0.mlp.W_out torch.Size([3072, 768])
blocks.0.mlp.b_out torch.Size([768])
ln_final.w torch.Size([768])
ln_final.b torch.Size([768])
unembed.W_U torch.Size([768, 50257])
unembed.b_U torch.Size([50257])


In [None]:
print(reference_gpt2.cfg)

EasyTransformerConfig(n_layers=12, d_model=768, n_ctx=1024, d_head=64, model_name='gpt2-small', n_heads=12, d_mlp=3072, act_fn='gelu_new', d_vocab=50257, eps=1e-05, use_attn_result=False, use_attn_scale=True, use_local_attn=False, model_family='gpt2', checkpoint=None, tokenizer_name='gpt2', window_size=None, attn_types=None, init_mode='gpt2', normalization_type='LN', device='cuda', attention_dir='causal', attn_only=False, seed=42, initializer_range=0.02886751345948129, init_weights=False, scale_attn_by_inverse_layer_idx=False, positional_embedding_type='standard', final_rms=False, d_vocab_out=50257, parallel_attn_mlp=False, rotary_dim=64, dtype=torch.float32)


In [None]:
@dataclass
class Config:
  d_model: int = 768
  debug: bool = True
  layer_norm_eps: float = 1e-5
  d_vocab: int = 50257
  init_range: float = 0.02
  n_ctx: int = 1024
  d_head: int = 64
  d_mlp: int = 3072
  n_heads: int = 12
  n_layers: int = 12

cfg = Config()
print(cfg)


Config(d_model=768, debug=True, layer_norm_eps=1e-05, d_vocab=50257, init_range=0.02, n_ctx=1024, d_head=64, d_mlp=3072, n_heads=12, n_layers=12)


In [None]:
# Tests
def rand_float_test(cls, shape):
  cfg = Config(debug=True)
  layer = cls(cfg).cuda()
  random_input = torch.randn(shape).cuda()
  print("Input shape:", random_input.shape)
  output = layer(random_input)
  print("Output shape:", output.shape)
  print()
  return output

def rand_int_test(cls, shape):
  cfg = Config(debug=True)
  layer = cls(cfg).cuda()
  random_input = torch.randint(100, 1000, shape).cuda()
  print("Input shape:", random_input.shape)
  output = layer(random_input)
  print("Output shape:", output.shape)
  print()
  return output

def load_gpt2_test(cls, gpt2_layer, input_name, cache_dict=cache.cache_dict):
  cfg = Config(debug=True)
  layer = cls(cfg).cuda()
  layer.load_state_dict(gpt2_layer.state_dict(), strict=False)
  # Allow inputs of strings or tensors
  if isinstance(input_name, str):
    reference_input = cache_dict[input_name]
  else:
    reference_input = input_name

  print("Input shape:", reference_input.shape)
  output = layer(reference_input)
  print("Output shape:", output.shape)
  reference_output = gpt2_layer(reference_input)
  print("Reference output shape:", reference_output.shape)

  comparison = torch.isclose(output, reference_output, atol=1e-4,rtol=1e-3)
  print(f"{comparison.sum()/comparison.numel():.2%} of the values are correct")
  return output

In [None]:
class LayerNorm(nn.Module):
  def __init__(self, cfg):
    super().__init__()
    self.cfg = cfg
    self.w = nn.Parameter(torch.ones(cfg.d_model))
    self.b = nn.Parameter(torch.zeros(cfg.d_model))
  def forward(self, residual):
    # residual: [batch, position, d_model]
    residual = residual - residual.mean(dim = 2, keepdim = True)
    std = (residual.var(dim = 2, correction = 0, keepdim = True) + self.cfg.layer_norm_eps).sqrt()
    normalized = self.w * residual / std + self.b
    return normalized


In [None]:
rand_float_test(LayerNorm, [2,4,768])
_ = load_gpt2_test(LayerNorm, reference_gpt2.ln_final, "blocks.11.hook_resid_post")

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 11, 768])
Output shape: torch.Size([1, 11, 768])
Reference output shape: torch.Size([1, 11, 768])
100.00% of the values are correct


In [None]:
class Embed(nn.Module):
  def __init__(self, cfg):
    super().__init__()
    self.cfg = cfg
    self.W_E = nn.Parameter(torch.empty((cfg.d_vocab, cfg.d_model)))
    nn.init.normal_(self.W_E, std = cfg.init_range)

  def forward(self, tokens):
    # tokens: [batch, position]
    embedded = self.W_E[tokens, :]
    return embedded

rand_int_test(Embed, [5,10])
_ = load_gpt2_test(Embed, reference_gpt2.embed, tokens)

Input shape: torch.Size([5, 10])
Output shape: torch.Size([5, 10, 768])

Input shape: torch.Size([1, 11])
Output shape: torch.Size([1, 11, 768])
Reference output shape: torch.Size([1, 11, 768])
100.00% of the values are correct


In [None]:
class PosEmbed(nn.Module):
  def __init__(self, cfg):
     super().__init__()
     self.cfg = cfg
     self.W_pos = nn.Parameter(torch.empty((cfg.n_ctx, cfg.d_model)))
     nn.init.normal_(self.W_pos, std = cfg.init_range)

  def forward(self, tokens):
    # tokens: [batch, position]
    batch_size, seq_length = tokens.shape

    pos_embed = self.W_pos[:seq_length, :]
    pos_embed = einops.repeat(pos_embed, 'position d_model -> batch position d_model', batch = batch_size)
    return pos_embed

rand_int_test(PosEmbed, [2,4])
_ = load_gpt2_test(PosEmbed, reference_gpt2.pos_embed, tokens)



Input shape: torch.Size([2, 4])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 11])
Output shape: torch.Size([1, 11, 768])
Reference output shape: torch.Size([1, 11, 768])
100.00% of the values are correct


In [None]:
class Attention(nn.Module):
  def __init__(self, cfg):
    super().__init__()
    self.cfg = cfg
    self.W_Q = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
    nn.init.normal_(self.W_Q, std = cfg.init_range)
    self.b_Q = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
    self.W_K = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
    nn.init.normal_(self.W_K, std = cfg.init_range)
    self.b_K = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
    self.W_V = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
    nn.init.normal_(self.W_V, std = cfg.init_range)
    self.b_V = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))

    self.W_O = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
    nn.init.normal_(self.W_O, std = cfg.init_range)
    self.b_O = nn.Parameter(torch.zeros((cfg.d_model)))

    self.register_buffer("IGNORE", torch.tensor(-1e5, dtype=torch.float32, device='cuda'))

  def forward(self, normalized_resid_pre):
  # normalized_resid_pre: [batch, position, d_model]
    Q = einops.einsum(normalized_resid_pre, self.W_Q, 'batch position d_model, n_heads d_model d_head -> batch position n_heads d_head') + self.b_Q
    K = einops.einsum(normalized_resid_pre, self.W_K, 'batch position d_model, n_heads d_model d_head -> batch position n_heads d_head') + self.b_K
    V = einops.einsum(normalized_resid_pre, self.W_V, 'batch position d_model, n_heads d_model d_head -> batch position n_heads d_head') + self.b_V

    attn_scores = einops.einsum(Q, K, 'batch query_pos n_heads d_head, batch key_pos n_heads d_head -> batch n_heads query_pos key_pos')
    attn_scores = attn_scores / math.sqrt(self.cfg.d_head)
    attn_scores = self.apply_causal_mask(attn_scores)
    attn_pattern = attn_scores.softmax(dim=-1)

    z = einops.einsum(attn_pattern, V, 'batch n_heads query_pos key_pos, batch key_pos n_heads d_head -> batch query_pos n_heads d_head')

    output = einops.einsum(z, self.W_O, 'batch query_pos n_heads d_head, n_heads d_head d_model -> batch query_pos d_model') + self.b_O

    return output

  def apply_causal_mask(self, attn_scores):
    # attn_scores: [batch, n_heads, query_pos, key_pos]
    batch, n_heads, query_pos, key_pos = attn_scores.shape
    mask = torch.triu(torch.ones((query_pos, key_pos), device=attn_scores.device), diagonal=1).bool()
    mask = einops.repeat(mask, 'query_pos key_pos -> batch n_heads query_pos key_pos', batch = batch, n_heads = n_heads)
    masked_scores = attn_scores.masked_fill(mask, float('-inf'))

    return masked_scores

rand_float_test(Attention, [2,9,768])
_ = load_gpt2_test(Attention, reference_gpt2.blocks[0].attn, cache["blocks.0.ln1.hook_normalized"])

Input shape: torch.Size([2, 9, 768])
Output shape: torch.Size([2, 9, 768])

Input shape: torch.Size([1, 11, 768])
Output shape: torch.Size([1, 11, 768])
Reference output shape: torch.Size([1, 11, 768])
100.00% of the values are correct


In [None]:
class MLP(nn.Module):
  def __init__(self, cfg):
    super().__init__()
    self.cfg = cfg
    self.W_in = nn.Parameter(torch.empty((cfg.d_model, cfg.d_mlp)))
    nn.init.normal_(self.W_in, std=cfg.init_range)
    self.b_in = nn.Parameter(torch.zeros(cfg.d_mlp))
    self.W_out = nn.Parameter(torch.empty((cfg.d_mlp, cfg.d_model)))
    nn.init.normal_(self.W_out, std=cfg.init_range)
    self.b_out = nn.Parameter(torch.zeros(cfg.d_model))

  def forward(self, normalized_resid_mid):
    # normalized_resid_mid: [batch, position, d_model]
    inside = einops.einsum(normalized_resid_mid, self.W_in, 'batch position d_model, d_model d_mlp -> batch position d_mlp') + self.b_in
    activation = gelu_new(inside)
    outside = einops.einsum(activation, self.W_out, 'batch position d_mlp, d_mlp d_model -> batch position d_model') + self.b_out

    return outside



rand_float_test(MLP, [2,4,768])
_ = load_gpt2_test(MLP, reference_gpt2.blocks[0].mlp, cache['blocks.0.ln2.hook_normalized'])



Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 11, 768])
Output shape: torch.Size([1, 11, 768])
Reference output shape: torch.Size([1, 11, 768])
100.00% of the values are correct


In [None]:
class TransformerBlock(nn.Module):
  def __init__(self, cfg):
    super().__init__()
    self.cfg = cfg

    self.ln1 = LayerNorm(cfg)
    self.attn = Attention(cfg)
    self.ln2 = LayerNorm(cfg)
    self.mlp = MLP(cfg)
  def forward(self, resid_pre):
    # resid_pre: [batch, position, d_model]
    normalized_pre = self.ln1.forward(resid_pre)
    attn_out = self.attn.forward(normalized_pre)
    resid_mid = attn_out + resid_pre
    normalized_mid = self.ln2.forward(resid_mid)
    mlp_out = self.mlp.forward(normalized_mid)
    resid_post = mlp_out + resid_mid
    return resid_post


rand_float_test(TransformerBlock, [2,4,768])
_ = load_gpt2_test(TransformerBlock, reference_gpt2.blocks[0], cache["resid_pre", 0])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 11, 768])
Output shape: torch.Size([1, 11, 768])
Reference output shape: torch.Size([1, 11, 768])
100.00% of the values are correct


In [None]:
class Unembed(nn.Module):
  def __init__(self, cfg):
    super().__init__()
    self.cfg = cfg
    self.W_U = nn.Parameter(torch.empty((cfg.d_model, cfg.d_vocab)))
    nn.init.normal_(self.W_U, std=cfg.init_range)
    self.b_U = nn.Parameter(torch.zeros(cfg.d_vocab))

  def forward(self, normalized_resid_final):
    logits = einops.einsum(normalized_resid_final, self.W_U, 'batch position d_model, d_model d_vocab -> batch position d_vocab') + self.b_U
    return logits

rand_float_test(Unembed, [2,4,768])
_ = load_gpt2_test(Unembed, reference_gpt2.unembed, cache["ln_final.hook_normalized"])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 50257])

Input shape: torch.Size([1, 11, 768])
Output shape: torch.Size([1, 11, 50257])
Reference output shape: torch.Size([1, 11, 50257])
100.00% of the values are correct


In [None]:
class DemoTransformer(nn.Module):
  def __init__(self, cfg):
    super().__init__()
    self.cfg = cfg
    self.embed = Embed(cfg)
    self.pos_embed = PosEmbed(cfg)
    self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
    self.ln_final = LayerNorm(cfg)
    self.unembed = Unembed(cfg)

  def forward(self, tokens):
    embedded = self.embed.forward(tokens) + self.pos_embed.forward(tokens)

    for block in self.blocks:
      embedded = block.forward(embedded)

    final_norm = self.ln_final.forward(embedded)
    logits = self.unembed.forward(final_norm)

    return logits

rand_int_test(DemoTransformer, [2,4])
_ = load_gpt2_test(DemoTransformer, reference_gpt2, tokens)



Input shape: torch.Size([2, 4])
Output shape: torch.Size([2, 4, 50257])

Input shape: torch.Size([1, 11])
Output shape: torch.Size([1, 11, 50257])
Reference output shape: torch.Size([1, 11, 50257])
100.00% of the values are correct


In [None]:
demo_gpt2 = DemoTransformer(Config(debug=False))
demo_gpt2.load_state_dict(reference_gpt2.state_dict(), strict=False)
demo_gpt2.cuda()

DemoTransformer(
  (embed): Embed()
  (pos_embed): PosEmbed()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNorm()
      (attn): Attention()
      (ln2): LayerNorm()
      (mlp): MLP()
    )
  )
  (ln_final): LayerNorm()
  (unembed): Unembed()
)

In [None]:
test_string = "Murasaki Shikibu (c. 973 – c. 1014 or 1025) was a Japanese novelist, poet and lady-in-waiting at the Imperial court in the Heian period. She is best known as the author of The Tale of Genji, written in Japanese between about 1000 and 1012. She became a lady-in-waiting to Empress Shōshi at the Imperial court around 1005, and continued to write during her service, adding scenes from court life to her work, reflected in The Diary of Lady Murasaki."


In [None]:
test_tokens = reference_gpt2.to_tokens(test_string).cuda()
demo_logits = demo_gpt2(test_tokens)

In [None]:
def lm_cross_entropy_loss(logits, tokens):
  log_probs = logits.log_softmax(dim=-1)
  pred_log_probs = log_probs[:, :-1].gather(dim=-1, index=tokens[:, 1:].unsqueeze(-1)).squeeze(-1)
  return -pred_log_probs.mean()

loss = lm_cross_entropy_loss(demo_logits, test_tokens)
print(loss)
print("Loss as average prob", (-loss).exp())
print("Loss as 'uniform over this many variables'", (loss).exp())
print("Uniform loss over the vocab", math.log(demo_gpt2.cfg.d_vocab))

tensor(3.2001, device='cuda:0', grad_fn=<NegBackward0>)
Loss as average prob tensor(0.0408, device='cuda:0', grad_fn=<ExpBackward0>)
Loss as 'uniform over this many variables' tensor(24.5357, device='cuda:0', grad_fn=<ExpBackward0>)
Uniform loss over the vocab 10.82490511970208


In [None]:
for i in tqdm.tqdm(range(100)):
    test_tokens = reference_gpt2.to_tokens(test_string).cuda()
    demo_logits = demo_gpt2(test_tokens)
    test_string += reference_gpt2.tokenizer.decode(demo_logits[-1, -1].argmax())
print(test_string)

  0%|          | 0/100 [00:00<?, ?it/s]

Murasaki Shikibu (c. 973 – c. 1014 or 1025) was a Japanese novelist, poet and lady-in-waiting at the Imperial court in the Heian period. She is best known as the author of The Tale of Genji, written in Japanese between about 1000 and 1012. She became a lady-in-waiting to Empress Shōshi at the Imperial court around 1005, and continued to write during her service, adding scenes from court life to her work, reflected in The Diary of Lady Murasaki.


Murasaki Shikibu was born in Tokyo on October 9, 1883. She was educated at the Imperial College of Art and Design, and studied at the Imperial College of Art and Design, Tokyo. She was educated at the Imperial College of Art and Design, Tokyo, and studied at the Imperial College of Art and Design, Tokyo. She was also a member of the Imperial College of Art and Design, Tokyo.


Murasaki Shikibu was a member of


In [None]:
import datasets
import transformers
import plotly.express as px

In [None]:
batch_size = 8
num_epochs = 1
max_steps = 1000
log_every = 10
lr = 1e-3
weight_decay = 1e-2
model_cfg = Config(debug=False, d_model=256, n_heads=4, d_head=64, d_mlp=1024, n_layers=2, n_ctx=256, d_vocab=reference_gpt2.cfg.d_vocab)

In [None]:
dataset = datasets.load_dataset("NeelNanda/pile-10k", split="train")
print(dataset)
print(dataset[0]['text'][:100])
tokens_dataset = tokenize_and_concatenate(dataset, reference_gpt2.tokenizer, streaming=False, max_length=model_cfg.n_ctx, column_name="text", add_bos_token=True, num_proc=2)
data_loader = torch.utils.data.DataLoader(tokens_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)


Dataset({
    features: ['text', 'meta'],
    num_rows: 10000
})
It is done, and submitted. You can play “Survival of the Tastiest” on Android, and on the web. Playi


Map (num_proc=2):   0%|          | 0/10000 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (80023 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (229134 > 1024). Running this sequence through the model will result in indexing errors


In [None]:
model = DemoTransformer(model_cfg)
model.cuda()

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

In [None]:
losses = []
print("Number of batches:", len(data_loader))
for epoch in range(num_epochs):
  for c, batch in tqdm.tqdm(enumerate(data_loader)):
    tokens = batch['tokens'].cuda()
    logits = model(tokens)
    loss = lm_cross_entropy_loss(logits, tokens)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    losses.append(loss.item())
    if c % log_every == 0:
      print(f"Step: {c}, Loss: {loss.item():.4f}")
    if c > max_steps:
      break

Number of batches: 8506


0it [00:00, ?it/s]

Step: 0, Loss: 10.8520
Step: 10, Loss: 8.8035
Step: 20, Loss: 7.7282
Step: 30, Loss: 7.8087
Step: 40, Loss: 7.3833
Step: 50, Loss: 7.6357
Step: 60, Loss: 6.7563
Step: 70, Loss: 7.4314
Step: 80, Loss: 7.8774
Step: 90, Loss: 6.4978
Step: 100, Loss: 7.3814
Step: 110, Loss: 6.9360
Step: 120, Loss: 5.6193
Step: 130, Loss: 6.4876
Step: 140, Loss: 7.1882
Step: 150, Loss: 6.1695
Step: 160, Loss: 6.8181
Step: 170, Loss: 4.8425
Step: 180, Loss: 6.0882
Step: 190, Loss: 5.2153
Step: 200, Loss: 6.5425
Step: 210, Loss: 6.9538
Step: 220, Loss: 7.3720
Step: 230, Loss: 7.4605
Step: 240, Loss: 5.5480
Step: 250, Loss: 6.7337
Step: 260, Loss: 6.1478
Step: 270, Loss: 6.7170
Step: 280, Loss: 7.2844
Step: 290, Loss: 6.8094
Step: 300, Loss: 5.2000
Step: 310, Loss: 6.3915
Step: 320, Loss: 6.4649
Step: 330, Loss: 6.5783
Step: 340, Loss: 6.2717
Step: 350, Loss: 5.3378
Step: 360, Loss: 5.4763
Step: 370, Loss: 6.5944
Step: 380, Loss: 6.1850
Step: 390, Loss: 6.4572
Step: 400, Loss: 4.8032
Step: 410, Loss: 5.9367
St

In [None]:
px.line(y = losses, x=np.arange(len(losses))*(model_cfg.n_ctx*batch_size), labels={'y':'Loss', 'x':'Tokens'}, title='Demo model training curve')

In [None]:
test_string = "one, two, three, four,"

for i in tqdm.tqdm(range(100)):
    test_tokens = reference_gpt2.to_tokens(test_string).cuda()
    demo_logits = demo_gpt2(test_tokens)
    test_string += reference_gpt2.tokenizer.decode(demo_logits[-1, -1].argmax())
print(test_string)

  0%|          | 0/100 [00:00<?, ?it/s]

one, two, three, four, five, six, seven, eight, nine, ten, eleven, twelve, thirteen, fourteen, fifteen, sixteen, seventeen, eighteen, nineteen, or twenty-one, twenty-two, twenty-three, twenty-four, twenty-five, twenty-six, twenty-seven, twenty-eight, twenty-nine, twenty-ten, twenty-eleven, twenty-one, twenty-two, twenty-three, twenty-four, twenty-five, twenty-six,
