# Setup

In [18]:
from dataclasses import dataclass
import math
import os
from pprint import pprint

import datasets
from jaxtyping import Float, Int
import torch as t
from torch import Tensor
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer
from transformer_lens.utils import tokenize_and_concatenate
from tqdm.notebook import tqdm
import wandb

import C1P1_solutions as solutions
from C1P1__mj_implementation import Config, DemoTransformer, get_log_probs

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

## Set device

In [2]:
device = t.device(
    "mps"
    if t.backends.mps.is_available()
    else "cuda"
    if t.cuda.is_available()
    else "cpu"
)
print(device)

cuda


## Load GPT-2 Small

In [3]:
reference_gpt2 = HookedTransformer.from_pretrained(
    "gpt2-small",
    fold_ln=False,
    center_unembed=False,
    center_writing_weights=False,
    device=device,
)



Loaded pretrained model gpt2-small into HookedTransformer


  return t.to(


# Training a Transformer

### Learning Objectives

- Understand how to train a transformer from scratch.

- Write a basic transformer training loop.

- Interpret the transformer's failing cross entropy loss with reference to features of the training data
  - E.g., bigram frequencies.

## Create Model

- We'll train a 2L, 4 heads per layer model, with context length 256, for 10*200 steps of batch size 16, just to show what it looks like

In [4]:
model_cfg = Config(
    debug=False,
    d_model=256,
    n_heads=4,
    d_head=64,
    d_mlp=1024,
    n_layers=2,
    n_ctx=256,
    d_vocab=reference_gpt2.cfg.d_vocab,
)
model = DemoTransformer(model_cfg)

## Training Arguments 

In [5]:
@dataclass
class TransformerTrainingArgs:
    batch_size = 16
    epochs = 10
    max_steps_per_epoch = 200
    lr = 1e-3
    weight_decay = 1e-2
    wandb_project: str | None = "day1-demotransformer"
    wandb_name: str | None = None


args = TransformerTrainingArgs()

## Create Data

We load in a tiny dataset made by Neel Nanda, with the first 10K entries in the Pile (inspired by Stas' version for OpenWebText!)

In [6]:
dataset = datasets.load_dataset("NeelNanda/pile-10k", split="train").remove_columns(
    "meta"
)
print(dataset)
print(dataset[0]["text"][:100])

Dataset({
    features: ['text'],
    num_rows: 10000
})
It is done, and submitted. You can play “Survival of the Tastiest” on Android, and on the web. Playi


- `tokenise_and_concatenate` is a useful function that takes our dataset of strings, and returns a dataset of token IPs ready to feed into the model.
  - We then create a data loader from this tokenised dataset.

- The useful method `train_test_split` can give us a training and test set

In [7]:
tokenized_dataset = tokenize_and_concatenate(
    dataset,
    reference_gpt2.tokenizer,
    streaming=False,
    max_length=model.cfg.n_ctx,
    column_name="text",
    add_bos_token=True,
    num_proc=4,
)

dataset_dict = tokenized_dataset.train_test_split(test_size=1000)
train_loader = DataLoader(
    dataset_dict["train"],
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)
test_loader = DataLoader(
    dataset_dict["test"],
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

When we iterate through these data loaders, we will find dictionaries with the single key `tokens`, which maps to a tensor of token IDs with shape `(batch, seq_len)`

In [8]:
first_batch = train_loader.dataset[: args.batch_size]

pprint(first_batch.keys())
pprint(first_batch["tokens"].shape)

dict_keys(['tokens'])
torch.Size([16, 256])


## Training Loop

- The key parts of the gradient update setup are:
  - Calculating the (cross-entropy) loss between a model's output and the true labels,

  - `loss.backward()` calculates gradients of the loss with respect to the model parameters,

  - `optimizer.step()` updates the model parameters using the gradients,

  - `optimizer.zero_grad()` zeros the gradients so they don't accumulate.

- The training loops can be packaged up into a class that includes methods for training and validation steps among other things.

- We can use dataclasses to store all arguments relevant to training in one place, and pass them to our trainer class.
  - Watch out for scope! Want to refer to `self.args` within the trainer class, not the global `args`.

- You can use Weights and Biases to track experiments and log relevant variables. The three essential functions are:
  - `wandb.init()` initialises a new run, takes arguments `project`, `name`, and `config`, among others.

  - `wandb.log()` logs a dictionary of variables. E.g., `{"loss": loss}`. Also takes a `step` argument.

  - `wandb.finish()` is called at the end of training.

### Exercise: write training loop

- Calculate cross entropy loss using `get_log_probs` from the previous section.

- Use the optimiser `t.optim.AdamW` (Adam with weight decay), and with hyperparameters `lr` and `weight_decay` taken from your `TransformerTrainingArgs` dataclass instance.

- Easy to calculate accuracy by having `validation_step` return a 1D boolean tensor indicating the positions where the model's prediction was correct.
  - Can concatenate all of these tensor together and take the mean to get the overall accuracy for the epoch.

- `max_steps_per_epoch` is provided as a hack to make sure the training phase in each epoch doesn't last too long.
  - You can terminate the training phase after this many steps.

- Remember to move tokens to your device via `tokens.to(device)`.

- Feel free to refer to the [training loops from Chapter 0](https://arena-ch0-fundamentals.streamlit.app/%5B0.3%5D_ResNets#training-loop).

In [9]:
class TransformerTrainer:
    def __init__(self, args: TransformerTrainingArgs, model: DemoTransformer):
        super().__init__()
        self.model = model
        self.args = args

        # AdamW is a variant of Adam that combines the advantages of Adam's adaptive learning rates
        # with the benefits of direct weight decay (decoupled from the optimization steps).
        # This leads to more effective regularisation and better generalisation.
        self.optimizer = t.optim.AdamW(
            self.model.parameters(), lr=args.lr, weight_decay=args.weight_decay
        )

        self.step = 0

    def training_step(
        self, batch: dict[str, Int[Tensor, "batch seq"]]
    ) -> Float[Tensor, ""]:
        """
        Calculates the loss on the tokens in the batch, performs a gradient update step, and logs the loss.

        Remember that `batch` is a dictionary with the single key 'tokens'.
        """

        tokens = batch["tokens"].to(device)
        logits = self.model(tokens)

        # Calculate the loss. Mean to give a single scalar value indicating performance.abs
        # Negative to make it a minimisation problem.
        # Logarithm to simplify the calculation - sums instead of products.
        loss = -get_log_probs(logits, tokens).mean()

        # Computes the gradients of the loss with respect to all model parameters that have requires_grad=True.
        # Updates the grad attributes of the model's parameters.
        loss.backward()

        # Adjust model parameters according to new gradients
        self.optimizer.step()

        # Zero the gradients for the next iteration (PyTorch accumulates gradients by default)
        self.optimizer.zero_grad()

        self.step += 1

        wandb.log({"train_loss": loss.item(), "step": self.step})

        return loss

    def validation_step(self, batch: dict[str, Int[Tensor, "batch seq"]]):
        """
        Calculates & returns the accuracy on the tokens in the batch (i.e. how often the model's prediction
        is correct). Logging should happen in the `train` function (after we've computed the accuracy for
        the whole validation set).
        """
        tokens = batch["tokens"].to(device)

        logits = self.model(tokens)[:, :-1]
        predicted_tokens = logits.argmax(dim=-1)

        return (predicted_tokens == tokens[:, 1:]).flatten()

    def train(self):
        """
        Trains the model, for `self.args.epochs` epochs. Also handles wandb initialisation, and early stopping
        for each epoch at `self.args.max_steps_per_epoch` steps.
        """
        wandb.init(
            project=self.args.wandb_project, name=self.args.wandb_name, config=self.args
        )

        accuracy = float("nan")

        progress_bar = tqdm(total=self.args.max_steps_per_epoch * self.args.epochs)

        for epoch in range(self.args.epochs):
            for i, batch in enumerate(self.train_loader()):
                loss = self.training_step(batch)
                progress_bar.update()
                progress_bar.set_description(
                    f"Epoch {epoch+1}, loss: {loss:.3f}, accuracy: {accuracy:.2f}"
                )
                if i >= self.args.max_steps_per_epoch:
                    break

            correct_predictions = t.concat(
                [self.validation_step(batch) for batch in self.test_loader()]
            )
            accuracy = correct_predictions.float().mean().item()
            wandb.log({"accuracy": accuracy}, step=self.step)

        wandb.finish()

    def train_loader(self) -> DataLoader:
        """Returns train loader (as in code above)."""
        return DataLoader(
            dataset_dict["train"],
            batch_size=self.args.batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=True,
        )

    def test_loader(self) -> DataLoader:
        """Returns test loader (as in code above)."""
        return DataLoader(
            dataset_dict["test"],
            batch_size=self.args.batch_size,
            shuffle=False,
            num_workers=4,
            pin_memory=True,
        )

In [10]:
model = DemoTransformer(model_cfg).to(device)
args = TransformerTrainingArgs()
trainer = TransformerTrainer(args, model)
# trainer.train()

#### Observations

- Comparing our run ([WandB link](https://wandb.ai/matthew-jennings/day1-demotransformer/runs/bqo9iwnz?nw=nwusermatthewjennings)) with one from the solutions ([WandB link](https://wandb.ai/callum-mcdougall/day1-demotransformer/reports/wandb-training-run--Vmlldzo2NTEyMzg4?accessToken=1xmja86i6ugeii40hvlefe1sp90v7cn9p4rrkhxobryjviuxfsydak4fx3gqx18q)) confirms similar performance.

### The loss curve

- The [loss curve](https://wandb.ai/matthew-jennings/day1-demotransformer/reports/train_loss-24-09-15-11-46-38---Vmlldzo5MzcxMDAx) seems to start at a value around 10-11, decreases quickly and then levels out.

- This is related to the kinds of algorithms that the model learns during training.

- At the beginning, the model outputs random noise; something like "predict each token with approximately uniform probability". I.e., $Q(x) = 1 / d_{vocab}$ for all $x$. This gives us a cross-entropy loss equal to $\log(d_{vocab})$.

In [14]:
d_vocab = model.cfg.d_vocab

print(f"d_vocab = {d_vocab}")
print(f"Cross entropy loss on uniform distribution = {math.log(d_vocab):.3f}")

d_vocab = 50257
Cross entropy loss on uniform distribution = 10.825


The next thing we might expect the model to learn is the frequencies of words in the English language (i.e., **unigram frequencies**). Small common tokens like `" and"` or `" the"` might appear much more frequently than others. This was give an average cross-entropy loss of:

$$
-\sum_x{p_x \log{p_x}}
$$

where $p_x$ is the actual frequency of the word/token in our training data.

We can evaluate this as follows:

In [16]:
tokens = tokenized_dataset[:]["tokens"].flatten()

freqs = t.bincount(tokens, minlength=model.cfg.d_vocab).float()

probs = freqs / freqs.sum()

distn = t.distributions.Categorical(probs=probs)
entropy = distn.entropy().item()

print(f"Entropy of the token distribution = {entropy:.3f}")

Entropy of the token distribution = 7.349


**Note the lower value of 7.349**

After unigram frequencies, the next thing our model usually learns is **bigram frequencies**; the frequencies of pairs of adjacent tokens in the training data. E.g., `"I"` and `" am"` are common tokens, but their bigram frequency is much higher than it would be if they occurred independently.

Bigram frequencies actually take you pretty far, since they help with:
- Some simple grammatical rules
  - E.g., a full stop being followed by a capitalised word.

- Weird quirks of tokenisation
  - E.g., ` "manip"` being followed by `"ulative"`

- Common names
  - E.g., `"Barack"` being followed by `" Obama"`

##### After approxmating bigram frequencies, we need smarter techniques, like:
  - Trigrams, which require attention heads
  
  - Induction heads

  - Fact memorisation

  - Other grammar/syntax rules.

Marginal improvement gets much harder, which flattens the loss curve.

### Exercise: log completions

- Choose a handle of prompts, and log the model's completions on those sentences.

- Log at a lower frequency than loss (e.g., once every 10-100 batches)

In [19]:
def sampling_fn(model: DemoTransformer, prompt: str) -> str:
    sampler = solutions.TransformerSampler(model, reference_gpt2.tokenizer)
    output = sampler.sample(
        prompt, temperature=0.7, top_p=0.95, max_tokens_generated=16
    )
    return output


model = DemoTransformer(model_cfg).to(device)

# Should be entirely random, because it uses a newly initialized model
print(sampling_fn(model, prompt="John and Mary went to the"))

John and Mary went to the Stewart Chrysar seals alien phenomenon unlockingNOTE intercourseiegemoderateCTographwu hunter75


In [22]:
@dataclass
class TransformerTrainingArgsLogText(TransformerTrainingArgs):
    text_sample_freq: int = 20
    table_log_freq: int = 200


def train_with_text_logs(self, sampling_fn, prompts):
    """
    Trains the model, for `self.args.epochs` epochs. Also handles wandb initialisation, and early stopping
    for each epoch at `self.args.max_steps_per_epoch` steps.

    This also takes 2 extra arguments:
        sampling_fn: function which takes model & a single prompt (i.e. text string) and returns text string output
        prompt_list: list of prompts we'll log output on
    """
    wandb.init(
        project=self.args.wandb_project, name=self.args.wandb_name, config=self.args
    )

    accuracy = float("nan")

    progress_bar = tqdm(total=self.args.max_steps_per_epoch * self.args.epochs)

    completions = []

    for epoch in range(1, self.args.epochs + 1):
        for i, batch in enumerate(self.train_loader()):
            loss = self.training_step(batch)
            progress_bar.update()
            progress_bar.set_description(
                f"Epoch {epoch}, loss: {loss:.3f}, accuracy: {accuracy:.2f}"
            )

            if self.step % self.args.text_sample_freq == 0:
                completions_this_step = [
                    sampling_fn(self.model, prompt) for prompt in prompts
                ]
                completions.append([epoch, self.step, *completions_this_step])

            if self.step % self.args.table_log_freq == 0:
                wandb.log(
                    {
                        "completions_table": wandb.Table(
                            data=completions,
                            columns=[
                                "Epoch",
                                "Step",
                                *[f"Prompt_{i}" for i in range(1, len(prompts) + 1)],
                            ],
                        )
                    }
                )

            if i >= self.args.max_steps_per_epoch:
                break

        correct_predictions = t.concat(
            [self.validation_step(batch) for batch in self.test_loader()]
        )
        accuracy = correct_predictions.float().mean().item()
        wandb.log({"accuracy": accuracy}, step=self.step)

    wandb.finish()


TransformerTrainer.train_with_text_logs = train_with_text_logs

Error in callback <bound method _WandbInit._resume_backend of <wandb.sdk.wandb_init._WandbInit object at 0x00000202D2462F90>> (for pre_run_cell), with arguments args (<ExecutionInfo object at 202decd3c90, raw_cell="@dataclass
class TransformerTrainingArgsLogText(Tr.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell:/c%3A/Users/matth/Workspace/arena3/mj/C1P1_transformer_from_scratch__S3_training.ipynb#X50sZmlsZQ%3D%3D>,),kwargs {}:


TypeError: _WandbInit._resume_backend() takes 1 positional argument but 2 were given

Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x00000202D2462F90>> (for post_run_cell), with arguments args (<ExecutionResult object at 202dea1e250, execution_count=22 error_before_exec=None error_in_exec=None info=<ExecutionInfo object at 202decd3c90, raw_cell="@dataclass
class TransformerTrainingArgsLogText(Tr.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell:/c%3A/Users/matth/Workspace/arena3/mj/C1P1_transformer_from_scratch__S3_training.ipynb#X50sZmlsZQ%3D%3D> result=None>,),kwargs {}:


TypeError: _WandbInit._pause_backend() takes 1 positional argument but 2 were given

In [23]:
prompts = [
    "John and Mary went to the",
    "The cat sat on the mat",
    "The quick brown fox jumped over the",
]

model = DemoTransformer(model_cfg).to(device)

args = TransformerTrainingArgsLogText()
trainer = TransformerTrainer(args, model)
trainer.train_with_text_logs(sampling_fn, prompts)

Error in callback <bound method _WandbInit._resume_backend of <wandb.sdk.wandb_init._WandbInit object at 0x00000202D2462F90>> (for pre_run_cell), with arguments args (<ExecutionInfo object at 202decd5450, raw_cell="prompts = [
    "John and Mary went to the",
    ".." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell:/c%3A/Users/matth/Workspace/arena3/mj/C1P1_transformer_from_scratch__S3_training.ipynb#X51sZmlsZQ%3D%3D>,),kwargs {}:


TypeError: _WandbInit._resume_backend() takes 1 positional argument but 2 were given

VBox(children=(Label(value='0.001 MB of 0.009 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.137212…

0,1
step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss,█▆▄▄▂▄▄▂▄▃▄▃▃▂▃▂▃▁▂▃▂▃▃▂▃▃▂▃▂▃▂▁▁▁▂▃▂▁▁▂

0,1
step,200.0
train_loss,6.59604


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

  0%|          | 0/2000 [00:00<?, ?it/s]

VBox(children=(Label(value='0.305 MB of 0.311 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.978864…

0,1
step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss,█▆▅▆▄▅▄▄▃▅▄▅▂▃▂▄▃▁▂▄▄▃▃▃▃▂▃▂▃▂▂▁▂▃▂▂▃▂▃▂

0,1
step,2010.0
train_loss,5.30513
