<a href="https://colab.research.google.com/github/LuhanMikaelson/ARENA_3.0/blob/main/ARENA_Transformer_from_Scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Transformers from scratch


## Setup (don't read, just run!)


In [None]:
try:
    import google.colab # type: ignore
    IN_COLAB = True
except:
    IN_COLAB = False

import os, sys
chapter = "chapter1_transformer_interp"
repo = "ARENA_3.0"

if IN_COLAB:
    # Install packages
    %pip install transformer_lens
    %pip install einops
    %pip install jaxtyping
    %pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python

    # Code to download the necessary files (e.g. solutions, test funcs)
    if not os.path.exists(f"/content/{chapter}"):
        !wget https://github.com/callummcdougall/ARENA_3.0/archive/refs/heads/main.zip
        !unzip /content/main.zip 'ARENA_3.0-main/chapter1_transformer_interp/exercises/*'
        sys.path.append(f"/content/{repo}-main/{chapter}/exercises")
        os.remove("/content/main.zip")
        os.rename(f"{repo}-main/{chapter}", chapter)
        os.rmdir(f"{repo}-main")
        os.chdir(f"{chapter}/exercises")
else:
    chapter_dir = r"./" if chapter in os.listdir() else os.getcwd().split(chapter)[0]
    sys.path.append(chapter_dir + f"{chapter}/exercises")

Collecting transformer_lens
  Downloading transformer_lens-1.17.0-py3-none-any.whl (137 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m137.1/137.1 kB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting accelerate>=0.23.0 (from transformer_lens)
  Downloading accelerate-0.30.1-py3-none-any.whl (302 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m302.6/302.6 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting beartype<0.15.0,>=0.14.1 (from transformer_lens)
  Downloading beartype-0.14.1-py3-none-any.whl (739 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m739.7/739.7 kB[0m [31m12.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting better-abc<0.0.4,>=0.0.3 (from transformer_lens)
  Downloading better_abc-0.0.3-py3-none-any.whl (3.5 kB)
Collecting datasets>=2.7.1 (from transformer_lens)
  Downloading datasets-2.19.1-py3-none-any.whl (542 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m542.0/5

In [None]:
import os; os.environ['ACCELERATE_DISABLE_RICH'] = "1"
import sys
import einops
from dataclasses import dataclass
from transformer_lens import HookedTransformer
from transformer_lens.utils import gelu_new, tokenize_and_concatenate
import torch as t
from torch import Tensor
import torch.nn as nn
import numpy as np
import math
from tqdm.notebook import tqdm
from typing import Tuple, List, Optional, Dict, Callable
from jaxtyping import Float, Int
from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast
from collections import defaultdict
from rich.table import Table
from rich import print as rprint
import datasets
from torch.utils.data import DataLoader
import wandb
from pathlib import Path
import webbrowser

# Make sure exercises are in the path
exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
section_dir = (exercises_dir / "part1_transformer_from_scratch").resolve()
if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

from plotly_utils import imshow
import part1_transformer_from_scratch.solutions as solutions

device = t.device("cuda" if t.cuda.is_available() else "cpu")

MAIN = __name__ == '__main__'

reference_gpt2 = HookedTransformer.from_pretrained("gpt2-small", fold_ln=False, center_unembed=False, center_writing_weights=False)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loaded pretrained model gpt2-small into HookedTransformer


In [None]:
device = t.device("cuda" if t.cuda.is_available() else "cpu")

In [None]:
@dataclass
class Config:
    d_model: int = 768
    debug: bool = True
    layer_norm_eps: float = 1e-5
    d_vocab: int = 50257
    init_range: float = 0.02
    n_ctx: int = 1024
    d_head: int = 64
    d_mlp: int = 3072
    n_heads: int = 12
    n_layers: int = 12

cfg = Config()
print(cfg)

Config(d_model=768, debug=True, layer_norm_eps=1e-05, d_vocab=50257, init_range=0.02, n_ctx=1024, d_head=64, d_mlp=3072, n_heads=12, n_layers=12)


## Tests


In [None]:
def rand_float_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg).to(device)
    random_input = t.randn(shape).to(device)
    print("Input shape:", random_input.shape)
    output = layer(random_input)
    if isinstance(output, tuple): output = output[0]
    print("Output shape:", output.shape, "\n")

def rand_int_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg).to(device)
    random_input = t.randint(100, 1000, shape).to(device)
    print("Input shape:", random_input.shape)
    output = layer(random_input)
    if isinstance(output, tuple): output = output[0]
    print("Output shape:", output.shape, "\n")

def load_gpt2_test(cls, gpt2_layer, input):
    cfg = Config(debug=True)
    layer = cls(cfg).to(device)
    layer.load_state_dict(gpt2_layer.state_dict(), strict=False)
    print("Input shape:", input.shape)
    output = layer(input)
    if isinstance(output, tuple): output = output[0]
    print("Output shape:", output.shape)
    try: reference_output = gpt2_layer(input)
    except: reference_output = gpt2_layer(input, input, input)
    print("Reference output shape:", reference_output.shape, "\n")
    comparison = t.isclose(output, reference_output, atol=1e-4, rtol=1e-3)
    print(f"{comparison.sum()/comparison.numel():.2%} of the values are correct\n")

## LayerNorm

In [None]:
class LayerNorm(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(t.ones(cfg.d_model))
        self.b = nn.Parameter(t.zeros(cfg.d_model))

    def forward(self, residual: Float[Tensor, "batch posn d_model"]) -> Float[Tensor, "batch posn d_model"]:
      mean = t.mean(residual, dim=2, keepdim=True)
      variance = t.var(residual, dim=2, keepdim=True, unbiased=False)
      residual = ((residual - mean)/t.sqrt(variance + self.cfg.layer_norm_eps)) * self.w + self.b
      return residual


rand_float_test(LayerNorm, [2, 4, 768])
load_gpt2_test(LayerNorm, reference_gpt2.ln_final, cache["resid_post", 11])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768]) 

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768]) 

100.00% of the values are correct



## Embedding


In [None]:
class Embed(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_E = nn.Parameter(t.empty((cfg.d_vocab, cfg.d_model)))
        nn.init.normal_(self.W_E, std=self.cfg.init_range)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_model"]:
      print(tokens.shape)
      return self.W_E[tokens]

rand_int_test(Embed, [2, 4])
load_gpt2_test(Embed, reference_gpt2.embed, tokens)

Input shape: torch.Size([2, 4])
torch.Size([2, 4])
Output shape: torch.Size([2, 4, 768]) 

Input shape: torch.Size([1, 45])
torch.Size([1, 45])
Output shape: torch.Size([1, 45, 768])
Reference output shape: torch.Size([1, 45, 768]) 

100.00% of the values are correct



## Positional Embedding



In [None]:
class PosEmbed(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_pos = nn.Parameter(t.empty((cfg.n_ctx, cfg.d_model)))
        nn.init.normal_(self.W_pos, std=self.cfg.init_range)

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_model"]:
      seq_length = list(np.arange(tokens.size()[1]))
      return self.W_pos[seq_length]

rand_int_test(PosEmbed, [2, 4])
load_gpt2_test(PosEmbed, reference_gpt2.pos_embed, tokens)

Input shape: torch.Size([2, 4])
Output shape: torch.Size([4, 768]) 

Input shape: torch.Size([1, 45])
Output shape: torch.Size([45, 768])
Reference output shape: torch.Size([1, 45, 768]) 

100.00% of the values are correct



## Attention



In [None]:
import circuitsvis as cv
from IPython.display import display

html = cv.attention.attention_patterns(
    tokens=reference_gpt2.to_str_tokens(reference_text),
    attention=cache["pattern", 0][0]
)
display(html)

In [None]:
from torch.nn import functional as F

In [None]:
class Attention(nn.Module):
    IGNORE: Float[Tensor, ""]

    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_Q = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_K = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_V = nn.Parameter(t.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        self.W_O = nn.Parameter(t.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
        self.b_Q = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_K = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_V = nn.Parameter(t.zeros((cfg.n_heads, cfg.d_head)))
        self.b_O = nn.Parameter(t.zeros((cfg.d_model)))
        nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        nn.init.normal_(self.W_K, std=self.cfg.init_range)
        nn.init.normal_(self.W_V, std=self.cfg.init_range)
        nn.init.normal_(self.W_O, std=self.cfg.init_range)
        self.register_buffer("IGNORE", t.tensor(-1e5, dtype=t.float32, device=device))

    def forward(
        self, normalized_resid_pre: Float[Tensor, "batch posn d_model"]
    ) -> Float[Tensor, "batch posn d_model"]:
      device = normalized_resid_pre.device
      keys = t.einsum('bke,neh->bknh', normalized_resid_pre, self.W_K) + self.b_K
      queries = t.einsum('bqe,neh->bqnh', normalized_resid_pre, self.W_Q) + self.b_Q
      values = t.einsum('bke,neh->bknh', normalized_resid_pre, self.W_V) + self.b_V
      QK = t.einsum('bqnh,bkhn->bnqk', queries, t.transpose(keys, 2, 3))
      scaled_QK = QK / (t.sqrt(t.tensor(queries.size()[3], device=device)).unsqueeze(0).expand(QK.size()[2], QK.size()[3]))
      masked_QK = self.apply_causal_mask(scaled_QK)
      attn_probs = F.softmax(masked_QK, dim=3)
      z_scores = t.einsum('bnqk,bknh->bqnh', attn_probs, values)
      extended_z = t.einsum('bqnh,nhe->bqne', z_scores, self.W_O)
      result = t.einsum('bqne->bqe', extended_z) + self.b_O
      return result


    def apply_causal_mask(self, attn_scores: Float[Tensor, "batch n_heads query_pos key_pos"]) -> Float[Tensor, "batch n_heads query_pos key_pos"]:
      '''
      Applies a causal mask to attention scores, and returns masked scores.
      '''
      device = attn_scores.device
      batch_size, n_heads, query_pos, key_pos = attn_scores.size()
      upper_triangular_mask = t.triu(t.ones((query_pos, key_pos), device=attn_scores.device), diagonal=1).bool()
      expanded_upper_triangular_mask = upper_triangular_mask.unsqueeze(0).unsqueeze(0).expand(batch_size, n_heads, -1, -1)
      masked_attn_scores = t.masked_fill(attn_scores, expanded_upper_triangular_mask, self.IGNORE)
      return masked_attn_scores





rand_float_test(Attention, [2, 4, 768])
load_gpt2_test(Attention, reference_gpt2.blocks[0].attn, cache["normalized", 0, "ln1"])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768]) 

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768]) 

100.00% of the values are correct



## MLP


In [None]:
class MLP(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_in = nn.Parameter(t.empty((cfg.d_model, cfg.d_mlp)))
        self.W_out = nn.Parameter(t.empty((cfg.d_mlp, cfg.d_model)))
        self.b_in = nn.Parameter(t.zeros((cfg.d_mlp)))
        self.b_out = nn.Parameter(t.zeros((cfg.d_model)))
        nn.init.normal_(self.W_in, std=self.cfg.init_range)
        nn.init.normal_(self.W_out, std=self.cfg.init_range)

    def forward(
        self, normalized_resid_mid: Float[Tensor, "batch posn d_model"]
    ) -> Float[Tensor, "batch posn d_model"]:
      res_1 = normalized_resid_mid @ self.W_in + self.b_in
      res_2 = gelu_new(res_1)
      res_3 = res_2 @ self.W_out + self.b_out
      return res_3


rand_float_test(MLP, [2, 4, 768])
load_gpt2_test(MLP, reference_gpt2.blocks[0].mlp, cache["normalized", 0, "ln2"])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768]) 

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768]) 

100.00% of the values are correct



## Transformer Block




In [None]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)

    def forward(
        self, resid_pre: Float[Tensor, "batch position d_model"]
    ) -> Float[Tensor, "batch position d_model"]:
      mid_res_1 = self.ln1(resid_pre)
      mid_res_2 = self.attn(mid_res_1)
      mid_resid = resid_pre + mid_res_2
      mid_res_3 = self.ln2( mid_resid)
      mid_res_4 = self.mlp(mid_res_3)
      final_resid = mid_res_4 + mid_resid
      return final_resid





rand_float_test(TransformerBlock, [2, 4, 768])
load_gpt2_test(TransformerBlock, reference_gpt2.blocks[0], cache["resid_pre", 0])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768]) 

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768]) 

100.00% of the values are correct



## Unembedding




In [None]:
class Unembed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_U = nn.Parameter(t.empty((cfg.d_model, cfg.d_vocab)))
        nn.init.normal_(self.W_U, std=self.cfg.init_range)
        self.b_U = nn.Parameter(t.zeros((cfg.d_vocab), requires_grad=False))

    def forward(
        self, normalized_resid_final: Float[Tensor, "batch position d_model"]
    ) -> Float[Tensor, "batch position d_vocab"]:
      return normalized_resid_final @ self.W_U + self.b_U

rand_float_test(Unembed, [2, 4, 768])
load_gpt2_test(Unembed, reference_gpt2.unembed, cache["ln_final.hook_normalized"])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 50257]) 

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 50257])
Reference output shape: torch.Size([1, 35, 50257]) 

100.00% of the values are correct



## Full Transformer

```c
Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵🔵🔵⚪⚪

You should spend up to ~10 minutes on this exercise.
```


In [None]:
class DemoTransformer(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.embed = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.ln_final = LayerNorm(cfg)
        self.unembed = Unembed(cfg)


    def blocks_forward(self, x):
      for block in self.blocks:
        x = block(x)
      return x

    def forward(self, tokens: Int[Tensor, "batch position"]) -> Float[Tensor, "batch position d_vocab"]:
      embedded_res = self.embed(tokens)
      pos_embedded_res = self.pos_embed(tokens)
      tot_embedded = embedded_res + pos_embedded_res
      post_blocks = self.blocks_forward(tot_embedded)
      norm_final = self.ln_final(post_blocks)
      unembeded_res = self.unembed(norm_final)
      return unembeded_res



rand_int_test(DemoTransformer, [2, 4])
load_gpt2_test(DemoTransformer, reference_gpt2, tokens)

Input shape: torch.Size([2, 4])
torch.Size([2, 4])
Output shape: torch.Size([2, 4, 50257]) 

Input shape: torch.Size([1, 45])
torch.Size([1, 45])
Output shape: torch.Size([1, 45, 50257])
Reference output shape: torch.Size([1, 45, 50257]) 

100.00% of the values are correct



In [None]:
def get_log_probs(
    logits: Float[Tensor, "batch posn d_vocab"],
    tokens: Int[Tensor, "batch posn"]
) -> Float[Tensor, "batch posn-1"]:

    log_probs = logits.log_softmax(dim=-1)
    # Get logprobs the first seq_len-1 predictions (so we can compare them with the actual next tokens)
    log_probs_for_tokens = log_probs[:, :-1].gather(dim=-1, index=tokens[:, 1:].unsqueeze(-1)).squeeze(-1)

    return log_probs_for_tokens


pred_log_probs = get_log_probs(demo_logits, tokens)
print(f"Avg cross entropy loss: {-pred_log_probs.mean():.4f}")
print(f"Avg cross entropy loss for uniform distribution: {math.log(demo_gpt2.cfg.d_vocab):4f}")
print(f"Avg probability assigned to correct token: {pred_log_probs.exp().mean():4f}")

Avg cross entropy loss: 4.0441
Avg cross entropy loss for uniform distribution: 10.824905
Avg probability assigned to correct token: 0.098629


In [None]:
test_string = '''The Total Perspective Vortex derives its picture of the whole Universe on the principle of'''
for i in tqdm(range(100)):
    test_tokens = reference_gpt2.to_tokens(test_string).to(device)
    demo_logits = demo_gpt2(test_tokens)
    test_string += reference_gpt2.tokenizer.decode(demo_logits[-1, -1].argmax())

print(test_string)

  0%|          | 0/100 [00:00<?, ?it/s]

torch.Size([1, 16])
tokens: tensor([[50256,   464,  7472, 42051, 49790, 37453,   663,  4286,   286,   262,
          2187, 11950,   319,   262,  7989,   286]], device='cuda:0')
torch.Size([1, 17])
tokens: tensor([[50256,   464,  7472, 42051, 49790, 37453,   663,  4286,   286,   262,
          2187, 11950,   319,   262,  7989,   286,   262]], device='cuda:0')
torch.Size([1, 18])
tokens: tensor([[50256,   464,  7472, 42051, 49790, 37453,   663,  4286,   286,   262,
          2187, 11950,   319,   262,  7989,   286,   262,  2472]],
       device='cuda:0')
torch.Size([1, 19])
tokens: tensor([[50256,   464,  7472, 42051, 49790, 37453,   663,  4286,   286,   262,
          2187, 11950,   319,   262,  7989,   286,   262,  2472,  6650]],
       device='cuda:0')
torch.Size([1, 20])
tokens: tensor([[50256,   464,  7472, 42051, 49790, 37453,   663,  4286,   286,   262,
          2187, 11950,   319,   262,  7989,   286,   262,  2472,  6650,    13]],
       device='cuda:0')
torch.Size([1, 21])
toke

In later sections, we'll learn to generate text in slightly more interesting ways than just argmaxing the output.


# 3️⃣ Training a Transformer


## Create Model

In [None]:
model_cfg = Config(
    debug=False,
    d_model=256,
    n_heads=4,
    d_head=64,
    d_mlp=1024,
    n_layers=2,
    n_ctx=256,
    d_vocab=reference_gpt2.cfg.d_vocab
)
model = DemoTransformer(model_cfg)

## Training Args



In [None]:
@dataclass
class TransformerTrainingArgs():
	batch_size = 16
	epochs = 10
	max_steps_per_epoch = 200
	lr = 1e-3
	weight_decay = 1e-2
	wandb_project: Optional[str] = "day1-demotransformer"
	wandb_name: Optional[str] = None


args = TransformerTrainingArgs()

## Create Data

We load in a tiny dataset made by Neel Nanda, with the first 10K entries in the Pile (inspired by Stas' version for OpenWebText!)

In [None]:
dataset = datasets.load_dataset("NeelNanda/pile-10k", split="train").remove_columns("meta")
print(dataset)
print(dataset[0]['text'][:100])

Downloading readme:   0%|          | 0.00/373 [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/921 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/33.3M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/10000 [00:00<?, ? examples/s]

Dataset({
    features: ['text'],
    num_rows: 10000
})
It is done, and submitted. You can play “Survival of the Tastiest” on Android, and on the web. Playi


`tokenize_and_concatenate` is a useful function which takes our dataset of strings, and returns a dataset of token IDs ready to feed into the model. We then create a dataloader from this tokenized dataset. The useful method `train_test_split` can give us a training and testing set.


In [None]:
import torch.multiprocessing as mp

# Set start method to 'spawn'
mp.set_start_method('spawn', force=True)



In [None]:
tokenized_dataset = tokenize_and_concatenate(dataset, reference_gpt2.tokenizer, streaming=False, max_length=model.cfg.n_ctx, column_name="text", add_bos_token=True, num_proc=4)

dataset_dict = tokenized_dataset.train_test_split(test_size=1000)
train_loader = DataLoader(dataset_dict["train"], batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True) # Lowered num_workers to accomodate system specifications
test_loader = DataLoader(dataset_dict["test"], batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True) # Lowered num_workers to accomodate system specifications

In [None]:
first_batch = train_loader.dataset[:args.batch_size]

print(first_batch.keys())
print(first_batch['tokens'].shape)
for i in first_batch.values():
  print(i)
  print(type(i))


dict_keys(['tokens'])
torch.Size([16, 256])
tensor([[50256,   220,   220,  ...,   220,   220,   220],
        [50256,  8321, 44148,  ..., 41339,     7, 11340],
        [50256,   656,   262,  ...,   621,   644,   345],
        ...,
        [50256,  6163,   355,  ...,   352, 15168,   362],
        [50256,  1971,  4332,  ...,  6786,  2873,     1],
        [50256,  5105, 32936,  ..., 17050,    11,   599]])
<class 'torch.Tensor'>


## Training Loop


In [None]:


class TransformerTrainer:
    def __init__(self, args: TransformerTrainingArgs, model: DemoTransformer):
        super().__init__()
        self.model = model
        self.args = args
        self.optimizer = t.optim.AdamW(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        self.step = 0
        self.trainloader = self.train_loader()
        self.testloader = self.test_loader()
        print(f"Total number of batches in train loader: {len(self.trainloader)}")

    def training_step(self, batch: Dict[str, Int[Tensor, "batch seq"]]) -> Float[Tensor, ""]:
        tokens = batch["tokens"]
        tokens = tokens.to(device)
        logits = self.model(tokens)
        loss = - get_log_probs(logits, tokens).mean()
        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        return loss

    @t.inference_mode()
    def validation_step(self, batch: Dict[str, Int[Tensor, "batch seq"]]):
        tokens = batch["tokens"]
        tokens = tokens.to(device)
        logits = self.model(tokens)
        predicted_tokens = t.argmax(logits, dim=-1)
        targets = t.cat([tokens[:, 1:], tokens.new_zeros(tokens.size(0), 1)], dim=1)
        correct_predictions = (predicted_tokens == targets)
        correct_predictions_flat = correct_predictions[:, :-1].reshape(-1)
        return correct_predictions_flat

    def train(self):
      wandb.init(project=self.args.wandb_project, name=self.args.wandb_name)
      self.counter = 0

      for epoch in range(self.args.epochs):
          self.step = 0
          progress_bar = tqdm(self.trainloader, total=min(len(self.trainloader), self.args.max_steps_per_epoch))

          for batch in self.trainloader:
              loss = self.training_step(batch)

              # Updating the progress bar and logging
              self.step += 1
              self.counter += 1
              wandb.log(dict(loss=loss.item()), step=self.counter)
              progress_bar.set_description(f"Epoch {epoch}, Loss: {loss:.4f}, Examples seen: {self.counter}")
              progress_bar.update(1)

              if self.step >= self.args.max_steps_per_epoch:
                  break

          all_correct_predictions = []
          for batch in self.testloader:
              correct_predictions_flat = self.validation_step(batch)
              all_correct_predictions.append(correct_predictions_flat)

          epoch_accuracy_tensor = t.cat(all_correct_predictions, dim=0)
          epoch_accuracy = t.mean(epoch_accuracy_tensor.float()).item()
          wandb.log({"Accuracy": epoch_accuracy}, step=self.counter)


      wandb.finish()


    def train_loader(self) -> DataLoader:
        return DataLoader(dataset_dict["train"], batch_size=self.args.batch_size, shuffle=True, num_workers=4, pin_memory=True)

    def test_loader(self) -> DataLoader:
        return DataLoader(dataset_dict["test"], batch_size=self.args.batch_size, shuffle=False, num_workers=4, pin_memory=True)


In [None]:
print(t.sqrt(t.tensor(4)))

tensor(2.)


In [None]:

model = DemoTransformer(model_cfg).to(device)
args = TransformerTrainingArgs()
trainer = TransformerTrainer(args, model)
trainer.train()

Total number of batches in train loader: 4191


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


  0%|          | 0/200 [00:00<?, ?it/s]



torch.Size([16, 256])
tokens: tensor([[50256,   705,    83,  ...,   685,   201,   198],
        [50256,   503,   319,  ...,   340,  1701,   366],
        [50256,  5699,   621,  ...,  1422,   470,   651],
        ...,
        [50256, 17487,   198,  ..., 13050,    11,   290],
        [50256, 15652,   988,  ...,   326,   477,   262],
        [50256,    83,   559,  ...,  4808,    90,   657]], device='cuda:0')
torch.Size([16, 256])
tokens: tensor([[50256,   373, 10770,  ..., 35960,  3822,  3220],
        [50256,   356,   550,  ...,  6678, 45602,   286],
        [50256,  3419,  1391,  ...,   475,   477,   663],
        ...,
        [50256,   220,   220,  ..., 36912,   416,   257],
        [50256,   220,   220,  ...,   220,   220,   220],
        [50256, 24177,  9319,  ...,     8,   198,   198]], device='cuda:0')
torch.Size([16, 256])
tokens: tensor([[50256,   198,   198,  ...,   347,     5,    39],
        [50256, 45061, 14067,  ..., 24448, 19321,  1766],
        [50256,   287,   262,  ..., 

  0%|          | 0/200 [00:00<?, ?it/s]

torch.Size([16, 256])
tokens: tensor([[50256,  2927,   420,  ..., 10145,    13,   554],
        [50256,  1175,    11,  ...,  9851,   511,  8242],
        [50256,   366,  1639,  ...,   220,   220,   220],
        ...,
        [50256,   257,  5752,  ...,   422,   262, 16834],
        [50256,  4167,   447,  ...,   286,   543,   318],
        [50256,  1872,  9313,  ...,   561,   787,   257]], device='cuda:0')
torch.Size([16, 256])
tokens: tensor([[50256,    11,   257,  ...,    16,    12,  5705],
        [50256,   284,   514,  ...,  3118,  1102,  2536],
        [50256,   366,   314,  ...,   366,  1639, 17753],
        ...,
        [50256,  2867,  8259,  ...,  4855,   416, 19701],
        [50256,  6946,   351,  ...,    25, 10903,    13],
        [50256,  3212,    11,  ...,  1635,    34,    13]], device='cuda:0')
torch.Size([16, 256])
tokens: tensor([[50256,   477, 25975,  ..., 20105,    13,   679],
        [50256,  9204,  2597,  ...,   284, 11059,    11],
        [50256,    13,   383,  ..., 

  0%|          | 0/200 [00:00<?, ?it/s]

torch.Size([16, 256])
tokens: tensor([[50256,   430,  3769,  ..., 18477,    87,    92],
        [50256,  7324,    13,  ...,  1257, 13174, 20515],
        [50256,     0,    59,  ...,  4871,  4941,   329],
        ...,
        [50256,   262, 18398,  ..., 10414, 28216,  3403],
        [50256,    18,    67,  ...,  7407,    13,    17],
        [50256,  1295,   286,  ...,   743,   307,  6412]], device='cuda:0')
torch.Size([16, 256])
tokens: tensor([[50256,  1877, 15432,  ...,  3812,   262, 23422],
        [50256,   267,   796,  ...,   796, 15495,  3901],
        [50256,   220,   402,  ...,  3025,  2951,   836],
        ...,
        [50256,   366,  2949,  ...,   640,   526,   366],
        [50256,   337,   577,  ...,   554,  1109,    11],
        [50256,  2319,    11,  ...,   262,  4725,  3611]], device='cuda:0')
torch.Size([16, 256])
tokens: tensor([[50256,   991,  1336,  ...,   968,  4492,  4152],
        [50256,  4686,    87,  ...,    25,   657,    87],
        [50256,  6067,   553,  ..., 

  0%|          | 0/200 [00:00<?, ?it/s]

torch.Size([16, 256])
tokens: tensor([[50256,  5475,   329,  ...,  3853,   543,  2594],
        [50256,  1925,   528,  ...,   198,  3347,   714],
        [50256,   657,    87,  ...,    87,   940,    67],
        ...,
        [50256,   198,   220,  ...,  2882,    13,  1439],
        [50256,   550,   284,  ...,  2746,   447,   251],
        [50256,  1997,  7165,  ...,   287,   262,  3223]], device='cuda:0')
torch.Size([16, 256])
tokens: tensor([[50256,  5661,   737,  ...,   220,   220, 29568],
        [50256,    85,  1789,  ...,   339,   918, 33743],
        [50256,   828,   281,  ..., 13879,  2185, 11968],
        ...,
        [50256,   198, 10919,  ...,   447,   250, 44959],
        [50256, 24057, 18644,  ...,   220,   220,   220],
        [50256,   532, 27988,  ...,   513,   198,  2061]], device='cuda:0')
torch.Size([16, 256])
tokens: tensor([[50256,   326,   198,  ...,   355,  7695, 11973],
        [50256,   428,  5642,  ...,   281,  9848,   764],
        [50256,    45, 23330,  ..., 

  0%|          | 0/200 [00:00<?, ?it/s]

torch.Size([16, 256])
tokens: tensor([[50256,   247,   297,  ...,    13,   367,  9795],
        [50256,   289, 24247,  ...,     0,   198,   198],
        [50256,   198,   220,  ...,   220,   220,   220],
        ...,
        [50256,   345,   561,  ...,  7926,    11,   345],
        [50256,   220,   220,  ...,    87,  4531,   318],
        [50256,  2077,   777,  ...,    13,  1482, 41981]], device='cuda:0')
torch.Size([16, 256])
tokens: tensor([[50256,    13,   198,  ...,   284,  1440,  6490],
        [50256, 11232, 19995,  ..., 14216,  1414, 26360],
        [50256,   262, 23212,  ...,   616,  7034,  4286],
        ...,
        [50256,  4417,    13,  ...,    13,   198,   198],
        [50256,  4248,  1065,  ...,   274, 38320,   290],
        [50256,  4477,   284,  ...,  1511,  2091,    26]], device='cuda:0')
torch.Size([16, 256])
tokens: tensor([[50256,   220,   220,  ...,    87,    19,  2388],
        [50256,   220,   220,  ...,    15,    26,    73],
        [50256,   428, 11483,  ..., 

  0%|          | 0/200 [00:00<?, ?it/s]

torch.Size([16, 256])
tokens: tensor([[50256,    11,   198,  ...,    49,    62,    47],
        [50256,   220,   220,  ...,   220,   220,   220],
        [50256,   198,  5990,  ..., 24297,  1174,  4248],
        ...,
        [50256,   262,  2746,  ...,   262,  3048,   286],
        [50256, 47354,   577,  ...,  2953,    34,   274],
        [50256,   878,   340,  ...,    11,   642,  1776]], device='cuda:0')
torch.Size([16, 256])
tokens: tensor([[50256,   477,   340,  ...,    13,  2631, 20651],
        [50256,   376, 16151,  ..., 37346,    13,  1174],
        [50256,   470,  1337,  ...,   366, 34094,   326],
        ...,
        [50256,  1471,   272,  ...,  5066,  2681,    13],
        [50256,  3950,    13,  ...,  1900,   355,   978],
        [50256,   262,  2566,  ...,  1249,   345,   284]], device='cuda:0')
torch.Size([16, 256])
tokens: tensor([[50256, 17486,  2672,  ...,  4871,    13,  1136],
        [50256,    15,  2791,  ...,   262, 12396,   825],
        [50256,    37,     8,  ..., 

  0%|          | 0/200 [00:00<?, ?it/s]

torch.Size([16, 256])
tokens: tensor([[50256,   198,   464,  ...,   319,  2693,   362],
        [50256,    13,  5407,  ...,   355,   517, 11255],
        [50256,   447,   231,  ...,   575,  3528,    11],
        ...,
        [50256,  1270,   438,  ...,   268, 30486, 37654],
        [50256,   198,  1870,  ..., 26159, 10329,  5095],
        [50256,    25,  1120,  ..., 25104,   287,   616]], device='cuda:0')
torch.Size([16, 256])
tokens: tensor([[50256,    59,    82,  ...,    59, 14415, 43839],
        [50256,   845,  2562,  ...,   286, 40551,   308],
        [50256,  2520,  2253,  ...,  3503, 12179,   326],
        ...,
        [50256,   220,   720,  ...,   220,   720,  5705],
        [50256,   220,   220,  ...,   220,   220,   220],
        [50256,   530,   286,  ...,    32,  2733,    13]], device='cuda:0')
torch.Size([16, 256])
tokens: tensor([[50256, 15462,  6848,  ...,  2000,   683,   852],
        [50256,   389, 25583,  ...,  4522,   379, 21020],
        [50256,  3106, 21507,  ..., 

  0%|          | 0/200 [00:00<?, ?it/s]

torch.Size([16, 256])
tokens: tensor([[50256,   286,   607,  ..., 14268,   396,    11],
        [50256,   262,  2420,  ...,  7266, 12522,   379],
        [50256,   561,   307,  ...,   517,   355,   281],
        ...,
        [50256,   656,   663,  ...,  1339,   286, 14174],
        [50256,  8426,  4077,  ...,   290,   357,    65],
        [50256,  2183,    11,  ...,  1306,   284,   257]], device='cuda:0')
torch.Size([16, 256])
tokens: tensor([[50256,    72,    59,  ...,   530,   460,  3538],
        [50256,     8,  5218,  ...,    13, 15643,   680],
        [50256,   220,   220,  ...,   220,   220,   220],
        ...,
        [50256,  2625,  5647,  ...,   486, 31020,    12],
        [50256,    12, 11245,  ..., 22914,    92,   198],
        [50256,  3556,    64,  ...,  4858,  2989,   329]], device='cuda:0')
torch.Size([16, 256])
tokens: tensor([[50256, 20418,    83,  ...,   287, 22359,  6115],
        [50256,  2292,    11,  ...,   357,    66,     8],
        [50256,   890,  3892,  ..., 

  0%|          | 0/200 [00:00<?, ?it/s]

torch.Size([16, 256])
tokens: tensor([[50256,  1021, 23776,  ...,    13,  5747,   465],
        [50256,   220,   220,  ...,   220,   220,   220],
        [50256,    13,   837,  ...,  2304,  1000,  2927],
        ...,
        [50256,   474, 48940,  ...,   281, 13122,    52],
        [50256,  1404, 25603,  ...,   262,  4217,  1404],
        [50256,   220,   220,  ...,   220,   220,   220]], device='cuda:0')
torch.Size([16, 256])
tokens: tensor([[50256,   345,   460,  ...,   340,   373,   523],
        [50256, 42053,   328,  ...,   355,   281, 35555],
        [50256, 37256,    11,  ..., 20844,   406, 12721],
        ...,
        [50256, 10338,    11,  ...,   286, 14937, 44381],
        [50256,   198,    38,  ...,   290,  4132, 12769],
        [50256,   266,    13,  ...,    72,  5512,    72]], device='cuda:0')
torch.Size([16, 256])
tokens: tensor([[50256,   220,   220,  ...,  5662, 41473,  1314],
        [50256,   428,   705,  ...,   772,   351,  2839],
        [50256,   284,  2962,  ..., 

  0%|          | 0/200 [00:00<?, ?it/s]

torch.Size([16, 256])
tokens: tensor([[50256,   262,  4417,  ...,  6608,   286,   262],
        [50256,   220,   220,  ...,   437,    14,    47],
        [50256,   485,  1165,  ...,   883,   508,   836],
        ...,
        [50256,   735,  3894,  ...,    11, 18639,    11],
        [50256,   703,   340,  ...,  2504,   338,   780],
        [50256,   818,   685,  ..., 11537,   198,   198]], device='cuda:0')
torch.Size([16, 256])
tokens: tensor([[50256,   314,   550,  ...,  5586,  1306,   284],
        [50256, 35944,   198,  ..., 10693,   547, 10945],
        [50256,    12, 19577,  ...,   284,  7239,   319],
        ...,
        [50256, 10731,   326,  ..., 27690,   913, 28449],
        [50256,   290,  1577,  ...,    11,   290, 19008],
        [50256,   880,    12,  ...,   309, 21870,    12]], device='cuda:0')
torch.Size([16, 256])
tokens: tensor([[50256,    11,   484,  ...,    11,   788,  5871],
        [50256,   290,   543,  ...,   471,    14,    34],
        [50256,   284,   262,  ..., 

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Accuracy,▁▃▄▅▅▆▆▇██
loss,█▅▆▆▅▃▄▃▃▅▅▄▃▃▂▂▄▄▄▃▃▅▂▂▄▂▃▂▄▃▃▄▄▃▁▂▂▂▃▂

0,1
Accuracy,0.2724
loss,4.61668


<iframe src="https://wandb.ai/luhanexperiments/day1-demotransformer/reports/loss-24-05-16-13-30-06---Vmlldzo3OTcxMTYx?accessToken=zg57pz59wkmm69qv7u9nlpdaebqw8mbnp6pib3ehvwkh57a7c8gjnl5toae1xup6" style="border:none;height:1024px;width:100%">

In [None]:

from IPython.display import HTML

# Define the iframe HTML codes
iframe_html_1 = '<iframe src="https://wandb.ai/luhanexperiments/day1-demotransformer/reports/loss-24-05-16-13-30-06---Vmlldzo3OTcxMTYx?accessToken=zg57pz59wkmm69qv7u9nlpdaebqw8mbnp6pib3ehvwkh57a7c8gjnl5toae1xup6" style="border:none;height:1024px;width:50%"></iframe>'

# Display the combined HTML content using IPython.display.HTML
HTML(iframe_html_1)




In [None]:

# Save model weights after training
current_directory = os.getcwd()
print(current_directory)
t.save(model.state_dict(), os.path.join(current_directory, "model_final.pth"))


/content/chapter1_transformer_interp/exercises


In [None]:
import os
file_name = "model_final.pth"
file_path = os.path.join(os.getcwd(), file_name)

if os.path.exists(file_path):
    print(f"Found {file_name} at: {file_path}")
else:
    print(f"{file_name} not found in current directory.")


Found model_final.pth at: /content/chapter1_transformer_interp/exercises/model_final.pth


# 4️⃣ Sampling from a Transformer


In [None]:
model_cfg = Config()
model = DemoTransformer(model_cfg).to(device)
model.load_state_dict(reference_gpt2.state_dict(), strict=False)

tokenizer = reference_gpt2.tokenizer
print(tokenizer.vocab_size)

50257


## Main Sampling Function


In [None]:
class TransformerSampler:

    def __init__(self, model, tokenizer: GPT2TokenizerFast):
        self.model = model.to(device)
        self.cfg = model.cfg
        self.tokenizer = tokenizer


    @t.inference_mode()
    def sample(self, prompt: str, max_tokens_generated=100, verbose=False, **kwargs):
        '''
        Returns a string of autoregressively generated text, starting from the prompt.

        Sampling terminates at max_tokens_generated, or when the model generates an
        end-of-sequence token.

        kwargs are passed to sample_next_token, to give detailed instructions on how
        new tokens are chosen.
        '''
        self.model.eval()
        eos_token_id = self.tokenizer.eos_token_id  # Get the EOS token ID
        if verbose:
            print(prompt)

        tokenized_prompt = self.tokenizer.encode(prompt, return_tensors="pt").to(device)
        if verbose:
            print("tokenized:", tokenized_prompt.size())

        for _ in range(max_tokens_generated):
            prompt_logits = self.model(tokenized_prompt)  # Assuming model output has 'logits'
            next_token_id = self.sample_next_token(tokenized_prompt.squeeze(0), prompt_logits, **kwargs)
            next_token = t.tensor([[next_token_id]]).to(device)

            if verbose:
                print('next token:', next_token.size())
                print('tokenized prompts', tokenized_prompt.size())
            tokenized_prompt = t.cat((tokenized_prompt, next_token), dim=1)

            if verbose:
                print(self.tokenizer.decode(next_token.squeeze(0)), sep=' ')
            if next_token_id == eos_token_id:
                break
        result = self.tokenizer.decode(tokenized_prompt.squeeze(0).tolist())
        if verbose:
            print('result:', result)
        return result

    @staticmethod
    def sample_next_token(
        input_ids: t.Tensor,
        logits: t.Tensor,
        temperature=1.0,
        top_k=0,
        top_p=0.0,
        frequency_penalty=0.0,
        seed=None
    ):
        assert input_ids.ndim == 1, "input_ids should be a 1D sequence of token ids"
        assert temperature >= 0, "Temperature should be non-negative"
        assert 0 <= top_p <= 1.0, "Top-p must be a probability"
        assert 0 <= top_k, "Top-k must be non-negative"
        assert not (top_p != 0 and top_k != 0), "At most one of top-p and top-k supported"

        # Set random seeds for reproducibility
        if seed is not None:
            t.manual_seed(seed)
            t.cuda.manual_seed_all(seed)
            np.random.seed(seed)

        logits = logits[-1]  # Assuming logits shape is (seq_len, vocab_size)

        # Apply all the specialized sampling methods
        if temperature == 0:
            return TransformerSampler.greedy_search(logits)
        elif temperature != 1.0:
            logits = TransformerSampler.apply_temperature(logits, temperature)
        if frequency_penalty != 0.0:
            logits = TransformerSampler.apply_frequency_penalty(input_ids, logits, frequency_penalty)
        if top_k > 0:
            return TransformerSampler.sample_top_k(logits, top_k)
        if top_p > 0.0:
            return TransformerSampler.sample_top_p(logits, top_p)
        return TransformerSampler.sample_basic(logits)


    @staticmethod
    def greedy_search(logits: t.Tensor) -> int:
        '''
        Returns the most likely token (as an int).
        '''
        return logits.argmax().item()

    @staticmethod
    def apply_temperature(logits: t.Tensor, temperature: float) -> t.Tensor:
        '''
        Applies temperature scaling to the logits.
        '''
        return logits / temperature

    @staticmethod
    def apply_frequency_penalty(input_ids: t.Tensor, logits: t.Tensor, freq_penalty: float) -> t.Tensor:
        '''
        Applies a frequency penalty to the logits.
        '''
        freq_keys, frequencies = t.unique(input_ids, return_counts=True, sorted=False)
        for token, freq in zip(freq_keys, frequencies):
            logits[:, token] -= (freq_penalty * freq)
        return logits

    @staticmethod
    def sample_basic(logits: t.Tensor) -> int:
        '''
        Samples from the distribution defined by the logits.
        '''
        dist = t.distributions.Categorical(logits=logits)
        return dist.sample().item()

    @staticmethod
    def sample_top_k(logits: t.Tensor, k: int) -> int:
        '''
        Samples from the top k most likely tokens.
        '''
        top_k_values, top_k_indices = t.topk(logits, k, dim=-1)
        dist = t.distributions.Categorical(logits=top_k_values)
        sampled_index = dist.sample().item()
        return top_k_indices[sampled_index].item()


    @staticmethod
    def sample_top_p(logits: t.Tensor, top_p: float, min_tokens_to_keep: int = 1) -> int:
        '''
        Samples from the most likely tokens which make up at least p cumulative probability.
        '''
        sorted_logits, sorted_indices = t.sort(logits, descending=True)
        cumulative_probs = t.cumsum(t.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = -float('Inf')
        dist = t.distributions.Categorical(logits=logits)
        return dist.sample().item()




In [None]:
sampler = TransformerSampler(model, tokenizer)

prompt = "Jingle bells, jingle bells, jingle all the way"
print(f"Greedy decoding with prompt: {prompt!r}\n")

output = sampler.sample(prompt, max_tokens_generated=8, temperature=0.0)
print(f"Your model said: {output!r}\n")

expected = "Jingle bells, jingle bells, jingle all the way up to the top of the mountain."
assert output == expected

print("Tests passed!")

Greedy decoding with prompt: 'Jingle bells, jingle bells, jingle all the way'

tokenized: torch.Size([1, 13])
torch.Size([1, 13])
next token: torch.Size([1, 1])
tokenized prompts  torch.Size([1, 13])
torch.Size([1, 14])
next token: torch.Size([1, 1])
tokenized prompts  torch.Size([1, 14])
torch.Size([1, 15])
next token: torch.Size([1, 1])
tokenized prompts  torch.Size([1, 15])
torch.Size([1, 16])
next token: torch.Size([1, 1])
tokenized prompts  torch.Size([1, 16])
torch.Size([1, 17])
next token: torch.Size([1, 1])
tokenized prompts  torch.Size([1, 17])
torch.Size([1, 18])
next token: torch.Size([1, 1])
tokenized prompts  torch.Size([1, 18])
torch.Size([1, 19])
next token: torch.Size([1, 1])
tokenized prompts  torch.Size([1, 19])
torch.Size([1, 20])
next token: torch.Size([1, 1])
tokenized prompts  torch.Size([1, 20])
result: Jingle bells, jingle bells, jingle all the way up to the top of the mountain.
Your model said: 'Jingle bells, jingle bells, jingle all the way up to the top of th

## Sampling with Categorical


In [None]:
prompt = "John and Mary went to the"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
logits = model(input_ids)

expected_top_5 = {
    " church": 0.0648,
    " house": 0.0367,
    " temple": 0.0145,
    " same": 0.0104,
    " Church": 0.0097
}
frequency_of_top_5 = defaultdict(int)

N = 10_000
for _ in tqdm(range(N)):
    token = TransformerSampler.sample_next_token(input_ids.squeeze(), logits)
    frequency_of_top_5[tokenizer.decode(token)] += 1

for word in expected_top_5:
    expected_freq = expected_top_5[word]
    observed_freq = frequency_of_top_5[word] / N
    print(f"Word: {word!r:<9}. Expected freq {expected_freq:.4f}, observed freq {observed_freq:.4f}")
    assert abs(observed_freq - expected_freq) < 0.01, "Try increasing N if this fails by a small amount."

print("Tests passed!")

torch.Size([1, 6])


  0%|          | 0/10000 [00:00<?, ?it/s]

Word: ' church'. Expected freq 0.0648, observed freq 0.0662
Word: ' house' . Expected freq 0.0367, observed freq 0.0374
Word: ' temple'. Expected freq 0.0145, observed freq 0.0152
Word: ' same'  . Expected freq 0.0104, observed freq 0.0093
Word: ' Church'. Expected freq 0.0097, observed freq 0.0111
Tests passed!


### Sampling With Temperature


In [None]:
logits = t.tensor([1, 2]).log()

cold_logits = TransformerSampler.apply_temperature(logits, temperature=0.001)
print('A low temperature "sharpens" or "peaks" the distribution: ', cold_logits)
t.testing.assert_close(cold_logits, 1000.0 * logits)

hot_logits = TransformerSampler.apply_temperature(logits, temperature=1000.0)
print("A high temperature flattens the distribution: ", hot_logits)
t.testing.assert_close(hot_logits, 0.001 * logits)

print("Tests passed!")

A low temperature "sharpens" or "peaks" the distribution:  tensor([  0.0000, 693.1472])
A high temperature flattens the distribution:  tensor([0.0000, 0.0007])
Tests passed!


<details>
<summary>Question - what is the limit of applying 'sample_basic' after adjusting with temperature, when temperature goes to zero? How about when temperature goes to infinity?</summary>

The limit when temperature goes to zero is greedy search (because dividing by a small number makes the logits very big, in other words the difference between the maximum logit one and all the others will grow).

The limit when temperature goes to infinity is uniform random sampling over all words (because all logits will be pushed towards zero).")
</details>


### Frequency Penalty Test




In [None]:
bieber_prompt = "And I was like Baby, baby, baby, oh Like, Baby, baby, baby, no Like, Baby, baby, baby, oh I thought you'd always be mine, mine"
input_ids = tokenizer.encode(bieber_prompt, return_tensors="pt")
logits = t.ones(tokenizer.vocab_size)
penalized_logits = TransformerSampler.apply_frequency_penalty(input_ids.squeeze(), logits, 2.0)
print(penalized_logits.size())

assert penalized_logits[5156].item() == -11, "Expected 6 occurrences of ' baby' with leading space, 1-2*6=-11"
assert penalized_logits[14801].item() == -5, "Expected 3 occurrences of ' Baby' with leading space, 1-2*3=-5"

print("Tests passed!")

torch.Size([50257])
Tests passed!


### Sampling - Manual Testing


In [None]:
sampler = TransformerSampler(model, tokenizer)

N_RUNS = 1
your_prompt = "Jingle bells, jingle bells, jingle all the way"
cases = [
    ("High freq penalty", dict(frequency_penalty=100.0)),
    ("Negative freq penalty", dict(frequency_penalty=-3.0)),
    ("Too hot!", dict(temperature=2.0)),
    ("Pleasantly cool", dict(temperature=0.7)),
    ("Pleasantly warm", dict(temperature=0.9)),
    ("Too cold!", dict(temperature=0.01)),
]

table = Table("Name", "Kwargs", "Output", title="Sampling - Manual Testing")

for (name, kwargs) in cases:
    for i in range(N_RUNS):
        output = sampler.sample(your_prompt, max_tokens_generated=24, **kwargs)
        table.add_row(name, repr(kwargs), repr(output) + "\n")

rprint(table)

torch.Size([1, 13])
torch.Size([1, 14])
torch.Size([1, 15])
torch.Size([1, 16])
torch.Size([1, 17])
torch.Size([1, 18])
torch.Size([1, 19])
torch.Size([1, 20])
torch.Size([1, 21])
torch.Size([1, 22])
torch.Size([1, 23])
torch.Size([1, 24])
torch.Size([1, 25])
torch.Size([1, 26])
torch.Size([1, 27])
torch.Size([1, 28])
torch.Size([1, 29])
torch.Size([1, 30])
torch.Size([1, 31])
torch.Size([1, 32])
torch.Size([1, 33])
torch.Size([1, 34])
torch.Size([1, 35])
torch.Size([1, 36])
torch.Size([1, 13])
torch.Size([1, 14])
torch.Size([1, 15])
torch.Size([1, 16])
torch.Size([1, 17])
torch.Size([1, 18])
torch.Size([1, 19])
torch.Size([1, 20])
torch.Size([1, 21])
torch.Size([1, 22])
torch.Size([1, 23])
torch.Size([1, 24])
torch.Size([1, 25])
torch.Size([1, 26])
torch.Size([1, 27])
torch.Size([1, 28])
torch.Size([1, 29])
torch.Size([1, 30])
torch.Size([1, 31])
torch.Size([1, 32])
torch.Size([1, 33])
torch.Size([1, 34])
torch.Size([1, 35])
torch.Size([1, 36])
torch.Size([1, 13])
torch.Size([1, 14])


## Top-K Sampling

Conceptually, the steps in top-k sampling are:
- Find the `top_k` largest probabilities (you can use [`torch.topk`](https://pytorch.org/docs/stable/generated/torch.topk.html))
- Set all other probabilities to zero
- Normalize and sample


### Exercise - implement `sample_top_k`

```c
Difficulty: 🔴🔴⚪⚪⚪
Importance: 🔵⚪⚪⚪⚪

You should spend up to 5-10 minutes on this exercise.
```

Implement the method `sample_top_k` now. Your implementation should stay in log-space throughout (don't exponentiate to obtain probabilities). This means you don't actually need to worry about normalizing, because `Categorical` accepts unnormalised logits.


In [None]:
prompt = "John and Mary went to the"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
logits = model(input_ids)[0, -1]

expected_top_5 = {
    " church": 0.0648,
    " house": 0.0367,
    " temple": 0.0145,
    " same": 0.0104,
    " Church": 0.0097
}
topk_5_sum = sum(expected_top_5.values())

observed_freqs = defaultdict(int)

N = 10000
for _ in tqdm(range(N)):
    token = TransformerSampler.sample_next_token(input_ids.squeeze(), logits, top_k=5)
    observed_freqs[tokenizer.decode(token)] += 1

for word in expected_top_5:
    expected_freq = expected_top_5[word] / topk_5_sum
    observed_freq = observed_freqs[word] / N
    print(f"Word: {word!r:<9}. Expected freq = {expected_freq:.4f}, observed freq = {observed_freq:.4f}")
    assert abs(observed_freq - expected_freq) < 0.015, "Try increasing N if this fails by a small amount."

torch.Size([1, 6])


  0%|          | 0/10000 [00:00<?, ?it/s]

RuntimeError: selected index k out of range

### Top-K Sampling - Example

The [GPT-2 paper](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) famously included an example prompt about unicorns. Now it's your turn to see just how cherry picked this example was.

The paper claims they used `top_k=40` and best of 10 samples.


In [None]:
sampler = TransformerSampler(model, tokenizer)

your_prompt = "In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English."
output = sampler.sample(your_prompt, temperature=0.7, top_k=40, max_tokens_generated=64)
rprint(f"Your model said:\n\n[bold dark_orange]{output}")

This is pretty incredible! For some perspective on how much of a paradigm shift even basic models like this represented, we recommend reading [this section from Simulators](https://www.lesswrong.com/posts/vJFdjigzmcXMhNTsx/simulators#The_limit_of_sequence_modeling).

## Top-p aka Nucleus Sampling

The basic idea is that we choose the most likely words, up until the total probability of words we've chosen crosses some threshold. Then we sample from those chosen words based on their logits.

The steps are:

- Sort the probabilities from largest to smallest
- Find the cutoff point where the cumulative probability first equals or exceeds `top_p`. We do the cutoff inclusively, keeping the first probability above the threshold.
- If the number of kept probabilities is less than `min_tokens_to_keep`, keep that many tokens instead.
- Set all other probabilities to zero
- Normalize and sample

Optionally, refer to the paper [The Curious Case of Neural Text Degeneration](https://arxiv.org/pdf/1904.09751.pdf) for some comparison of different methods.


### Exercise - implement `sample_top_p`

```c
Difficulty: 🔴🔴🔴⚪⚪
Importance: 🔵⚪⚪⚪⚪

You should spend up to 15-20 minutes on this exercise.
```

<details>
<summary>Example of top-p sampling (if you're confused)</summary>

If our probabilities were `(0.4, 0.3, 0.2, 0.1)` and our cutoff was `top_p=0.8`, then we'd sample from the first three elements (because their total probability is `0.9` which is over the threshold, but the first two only have a total prob of `0.7` which is under the threshold). Once we've chosen to sample from those three, we would renormalise them by dividing by their sum (so the probabilities we use when sampling are `(4/9, 3/9, 2/9)`.
</details>

<details>
<summary>Help - I'm stuck on how to implement this function.</summary>

First, sort the logits using the `sort(descending=True)` method (this returns values and indices). Then you can get `cumulative_probs` by applying softmax to these logits and taking the cumsum. Then, you can decide how many probabilities to keep by using the `t.searchsorted` function.
    
Once you've decided which probabilities to keep, it's easiest to sample from them using the original logits (you should have preserved the indices when you called `logits.sort`). This way, you don't need to worry about renormalising like you would if you were using probabilities.
</details>


In [None]:
prompt = "John and Mary went to the"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
logits = model(input_ids)[0, -1]

expected_top_10pct = {
    " church": 0.0648,
    " house": 0.0367, # These are the two most likely tokens, and add up to >10%
}
top_10pct_sum = sum(expected_top_10pct.values())

observed_freqs = defaultdict(int)

N = 10000
for _ in tqdm(range(N)):
    token = TransformerSampler.sample_next_token(input_ids.squeeze(), logits, top_p=0.1)
    observed_freqs[tokenizer.decode(token)] += 1

for word in expected_top_10pct:
    expected_freq = expected_top_10pct[word] / top_10pct_sum
    observed_freq = observed_freqs[word] / N
    print(f"Word: {word!r:<9}. Expected freq {expected_freq:.4f}, observed freq {observed_freq:.4f}")
    assert abs(observed_freq - expected_freq) < 0.01, "Try increasing N if this fails by a small amount."

### Top-p Sampling - Example


In [None]:
sampler = TransformerSampler(model, tokenizer)

your_prompt = "Eliezer Shlomo Yudkowsky (born September 11, 1979) is an American decision and artificial intelligence (AI) theorist and writer, best known for"
output = sampler.sample(your_prompt, temperature=0.7, top_p=0.95, max_tokens_generated=64)
rprint(f"Your model said:\n\n[bold dark_orange]{output}")

<details>
<summary>Log probabilities are equal to the logit output after being translated by some amount X (where X is a function of the original logit output). Can you prove this?</summary>

Suppose our vector of logits is $x$, and we take softmax to get a vector of probabilities $p$, then log again to get a vector of log probabilities $l$. Then the $i$-th element of this vector of logprobs is:

$$
\begin{align}
l_i &= \log p_i \\
&= \log \frac{\exp(x_i)}{\sum_j \exp(x_j)} \\
&= x_i - \log \sum_j \exp(x_j) \\
&= x_i - C
\end{align}
$$

where $C = \log \sum_j \exp(x_j)$ is the same for all elements. So we can see that $l_i$ is equal to the logit output $x_i$ after being translated by $C$.

It's important not to mix up logits and logprobs!
</details>

<details>
<summary>Why do you think we use log softmax rather than logit output?</summary>

Logit output is translation invariant. If we had two different beams and we were generating the next tokens in those beams, there would be no reasonable way to compare the two beams to each other, because we could shift the logit vector for one beam by a constant amount without changing the distribution.

</details>


Note how after each "generate" stage, we have `num_beams ** 2` possible completions, which we then filter down to `num_beams`. Can you see why we need to generate this many (and what might happen if we generated fewer)?


How do we deal with sequences that terminate early (i.e. by generating an EOS token)? Answer - we append them to the list of completions which we'll return at the end, and remove them from the generation tree. Our algorithm terminates when either all our sequences have length `max_new_tokens` larger than the initial prompt length, or we've generated `num_returns_sequences` terminating sequences.
