In [1]:
import math
import os
import sys
import webbrowser
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Callable

import datasets
import einops
import numpy as np
import torch as t
import torch.nn as nn
import wandb
from jaxtyping import Float, Int
from rich import print as rprint
from rich.table import Table
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from transformer_lens import HookedTransformer
from transformer_lens.utils import gelu_new, tokenize_and_concatenate
from transformers.models.gpt2.tokenization_gpt2_fast import GPT2TokenizerFast

# Pick the compute device in priority order: Apple MPS, then the first CUDA GPU, then CPU.
if t.backends.mps.is_available():
    device = t.device("mps")
elif t.cuda.is_available():
    device = t.device("cuda:0")
else:
    device = t.device("cpu")



In [8]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config

# Tokenizer, weights and config are all loaded from the same checkpoint id.
checkpoint = "gpt2"

tokenizer = GPT2Tokenizer.from_pretrained(checkpoint)
model = GPT2LMHeadModel.from_pretrained(checkpoint)
gpt2_config = GPT2Config.from_pretrained(checkpoint)
print(gpt2_config)

GPT2Config {
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_inner": null,
  "n_layer": 12,
  "n_positions": 1024,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "transformers_version": "4.53.0",
  "use_cache": true,
  "vocab_size": 50257
}



In [3]:
# Display the module tree of the loaded checkpoint (12 GPT2Block layers, hidden size 768).
model # gpt2-small

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [4]:
# reference_gpt2 = HookedTransformer.from_pretrained(
#     "gpt2-small",
#     fold_ln=False,
#     center_unembed=False,
#     center_writing_weights=False,  # you'll learn about these arguments later!
# )

Loaded pretrained model gpt2-small into HookedTransformer


In [5]:
# @dataclass
# class Config:
#     d_model: int = 768
#     debug: bool = True
#     layer_norm_eps: float = 1e-5
#     d_vocab: int = 50257
#     init_range: float = 0.02
#     n_ctx: int = 1024
#     d_head: int = 64
#     d_mlp: int = 3072
#     n_heads: int = 12
#     n_layers: int = 12

# model_cfg = Config(
#     debug=False,
#     d_model=256,
#     n_heads=4,
#     d_head=64,
#     d_mlp=1024,
#     n_layers=2,
#     n_ctx=256,
#     d_vocab=reference_gpt2.cfg.d_vocab,
# )
# # model = DemoTransformer(model_cfg)

@dataclass
class TransformerTrainingArgs:
    batch_size = 16
    epochs = 20
    max_steps_per_epoch = 200
    lr = 1e-3
    weight_decay = 1e-2
    wandb_project: str | None = "transformer"
    wandb_name: str | None = None


args = TransformerTrainingArgs()

In [None]:
# create datasets

# Fixed seed for the train/test split so the held-out set (and the accuracy measured on it)
# is the same after every kernel restart.
SPLIT_SEED = 42

dataset = datasets.load_dataset("NeelNanda/pile-10k", split="train").remove_columns("meta")
print(dataset)
print(dataset[0]["text"][:100])

# Tokenize with the Hugging Face GPT-2 tokenizer; `max_length` comes from the loaded GPT-2 config (`n_ctx`).
tokenized_dataset = tokenize_and_concatenate(
    dataset,
    tokenizer,
    streaming=False,
    max_length=gpt2_config.n_ctx,
    column_name="text",
    add_bos_token=True,
    num_proc=4,
)

dataset_dict = tokenized_dataset.train_test_split(test_size=1000, seed=SPLIT_SEED)
train_loader = DataLoader(
    dataset_dict["train"], batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True
)
test_loader = DataLoader(
    dataset_dict["test"], batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True
)

Dataset({
    features: ['text'],
    num_rows: 10000
})
It is done, and submitted. You can play “Survival of the Tastiest” on Android, and on the web. Playi


Map (num_proc=4):   0%|          | 0/10000 [00:00<?, ? examples/s]

In [None]:
# Take the first `batch_size` rows straight from the training split (no shuffling involved).
first_batch = train_loader.dataset[: args.batch_size]

# Each batch is a dictionary with the single key 'tokens', which maps to a tensor of token IDs
# with shape (batch, seq_len).
print(first_batch.keys())
token_ids = first_batch["tokens"]
print(token_ids.shape)

# Training Loop

In [None]:
class TransformerTrainer:
    """
    Minimal training loop for next-token prediction with a causal language model.

    The model's forward pass may return either a raw logits tensor or an output object exposing a
    `.logits` attribute (as the Hugging Face `GPT2LMHeadModel` loaded in this notebook does).
    The class reads the module-level `dataset_dict` and `device`.
    """

    def __init__(self, args: TransformerTrainingArgs, model, use_wandb: bool = False):
        """
        Args:
            args: Training hyperparameters (batch size, epochs, learning rate, ...).
            model: Module that maps a (batch, seq) tensor of token IDs to logits.
            use_wandb: If True, metrics are logged to Weights & Biases. Off by default.
        """
        self.model = model
        self.args = args

        self.wandb = use_wandb

        self.optimizer = t.optim.AdamW(self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        self.step = 0

        self.train_loader = DataLoader(
            dataset_dict["train"], batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True
        )
        self.test_loader = DataLoader(
            dataset_dict["test"], batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True
        )

    def _get_logits(self, tokens: Int[Tensor, "batch seq"]) -> Float[Tensor, "batch seq d_vocab"]:
        """
        Runs the model and returns the logits tensor, unwrapping output objects that carry `.logits`.
        """
        output = self.model(tokens)
        return getattr(output, "logits", output)

    @staticmethod
    def _next_token_log_probs(
        logits: Float[Tensor, "batch seq d_vocab"], tokens: Int[Tensor, "batch seq"]
    ) -> Float[Tensor, "batch seq_minus_1"]:
        """
        Returns the log-probability the model assigns to each actual next token.

        The prediction at position i is scored against the token at position i + 1, so the last
        position (which has no following token) is dropped.
        """
        log_probs = logits[:, :-1].log_softmax(dim=-1)
        targets = tokens[:, 1:].unsqueeze(-1)
        return log_probs.gather(dim=-1, index=targets).squeeze(-1)

    def training_step(self, batch: dict[str, Int[Tensor, "batch seq"]]) -> Float[Tensor, ""]:
        """
        Calculates the loss on the tokens in the batch, performs a gradient update step, and logs the loss.

        Remember that `batch` is a dictionary with the single key 'tokens'.

        Returns:
            The scalar loss for this batch, detached from the autograd graph.
        """
        tokens = batch["tokens"].to(device)  # shape (batch, seq)

        logits = self._get_logits(tokens)  # shape (batch, seq, d_vocab)
        # Cross-entropy loss: there is no separate label, the targets are the input tokens shifted by one.
        loss = -self._next_token_log_probs(logits, tokens).mean()

        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()
        self.step += 1
        if self.wandb:
            wandb.log({"train_loss": loss.item()}, step=self.step)
        return loss.detach()

    @t.inference_mode()
    def evaluate(self) -> float:
        """
        Evaluate the model on the test set and return the next-token accuracy.

        Accuracy is the fraction of correct predictions over all evaluated positions, so a smaller
        final batch is weighted by its size. Returns NaN if nothing was evaluated. The model is put
        back into train mode before returning.
        """
        self.model.eval()

        n_correct = 0
        n_total = 0
        # Iterating the tqdm wrapper already advances the bar; no manual `update()` is needed.
        progress_bar = tqdm(self.test_loader, desc="Evaluating")
        for batch in progress_bar:
            tokens = batch["tokens"].to(device)  # shape (batch, seq)
            # The last position has no next token to compare against, so it is dropped.
            preds = self._get_logits(tokens)[:, :-1].argmax(dim=-1)  # shape (batch, seq-1)
            n_correct += (preds == tokens[:, 1:]).sum().item()
            n_total += preds.numel()
            if n_total:
                progress_bar.set_description(f"accuracy so far: {n_correct / n_total:.3f}")

        accuracy = n_correct / n_total if n_total else float("nan")

        # Log one value per evaluation. The wandb run stays open here: `train` owns its lifetime.
        if self.wandb:
            wandb.log({"accuracy": accuracy}, step=self.step)

        self.model.train()

        return accuracy

    def train(self):
        """
        Trains the model, for `self.args.epochs` epochs. Also handles wandb initialisation, and early stopping
        for each epoch at `self.args.max_steps_per_epoch` steps.
        """
        if self.wandb:
            wandb.init(project=self.args.wandb_project, name=self.args.wandb_name, config=self.args)
        accuracy = np.nan

        # Make sure dropout is active even if the model was handed over in eval mode.
        self.model.train()

        progress_bar = tqdm(total=self.args.max_steps_per_epoch * self.args.epochs)

        try:
            for epoch in range(self.args.epochs):
                for i, batch in enumerate(self.train_loader):
                    # Check before stepping so that exactly `max_steps_per_epoch` steps run per epoch.
                    if i >= self.args.max_steps_per_epoch:
                        break
                    loss = self.training_step(batch)
                    progress_bar.update()
                    progress_bar.set_description(f"Epoch {epoch + 1}, loss: {loss:.3f}, accuracy: {accuracy:.3f}")

                accuracy = self.evaluate()
        finally:
            progress_bar.close()
            # Close the wandb run even if training is interrupted.
            if self.wandb:
                wandb.finish()


# Move the pretrained model to the selected device, then build a trainer with default hyperparameters.
model = model.to(device)
args = TransformerTrainingArgs()
trainer = TransformerTrainer(args=args, model=model)


In [None]:

# Run the training loop for `args.epochs` epochs, evaluating on the test split after each epoch.
trainer.train()

In [None]:
# First iteration: get the training baseline running and confirm that wandb is correctly configured in the notebook.
# If the wandb setup fails, fall back to logging the metrics manually.