### Training an LLM

1. Iterate over training epochs (one epoch is one complete pass over the training set).
2. Iterate over the batches in each epoch (number of batches ≈ training set size / batch size).
3. Reset the loss gradients left over from the previous batch iteration.
4. Calculate the loss on the current batch.
5. Run the backward pass to calculate the loss gradients.
6. Update the model weights using the loss gradients.
7. Periodically print the training and validation set losses.
8. After each epoch, generate sample text for visual inspection.

In [None]:
import torch
from Chapter04 import GPTModel, create_dataloader_v1, generate_text_simple
from Chapter05 import calc_loss_batch, calc_loss_loader, text_to_token_ids, token_ids_to_text, tokeniser

# Hyperparameters for the 124M model configuration.
GPT_CONFIG_124M = dict(
    vocab_size=50257,
    context_length=256,  # shortened from 1,024
    emb_dim=768,
    n_heads=12,
    n_layers=12,
    drop_rate=0.1,
    qkv_bias=False,
)

# Seed the RNG *before* building the model so any random initialisation
# performed inside GPTModel is reproducible across runs.
torch.manual_seed(123)
model = GPTModel(GPT_CONFIG_124M)

In [None]:
def train_model_simple(model, train_loader, val_loader,
                       optimiser, device, num_epochs,
                       eval_freq, eval_iter, start_context, tokeniser):
    """Train ``model`` for ``num_epochs`` epochs with periodic evaluation.

    After every optimiser step the global step counter is incremented
    (the first batch is step 0); whenever it is a multiple of ``eval_freq``
    the train/val losses are estimated on ``eval_iter`` batches via
    ``evaluate_model``, recorded, and printed. At the end of each epoch a
    text sample continuing ``start_context`` is generated and printed.

    Returns:
        A tuple ``(train_losses, val_losses, track_tokens_seen)`` of lists, one
        entry per evaluation, where ``track_tokens_seen`` holds the cumulative
        number of input tokens processed at that evaluation.
    """
    train_losses = []
    val_losses = []
    track_tokens_seen = []
    tokens_seen = 0
    global_step = -1  # becomes 0 after the first optimiser step

    for epoch_idx in range(num_epochs):  # step 1: one pass over the training set
        model.train()
        for inputs, targets in train_loader:  # step 2: iterate over batches
            optimiser.zero_grad()  # step 3: clear gradients from the last batch
            batch_loss = calc_loss_batch(inputs, targets, model, device)  # step 4
            batch_loss.backward()  # step 5: compute gradients
            optimiser.step()  # step 6: update weights

            tokens_seen += inputs.numel()
            global_step += 1

            # step 7: only evaluate every `eval_freq` steps
            if global_step % eval_freq != 0:
                continue

            train_loss, val_loss = evaluate_model(
                model, train_loader, val_loader, device, eval_iter
            )
            train_losses.append(train_loss)
            val_losses.append(val_loss)
            track_tokens_seen.append(tokens_seen)
            print(
                f"Epoch {epoch_idx + 1} (Step {global_step:06d}): "
                f"Train loss {train_loss:.3f}, "
                f"Val loss {val_loss:.3f}"
            )

        # step 8: qualitative check of the model's output after each epoch
        generate_and_print_sample(model, tokeniser, device, start_context)

    return train_losses, val_losses, track_tokens_seen


In [None]:
def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    """Estimate the training and validation loss on up to ``eval_iter`` batches each.

    The model is switched to eval mode (disabling dropout) and gradient
    tracking is turned off while the losses are computed with
    ``calc_loss_loader``; the model is put back into train mode before
    returning.

    Returns:
        ``(train_loss, val_loss)`` as produced by ``calc_loss_loader``.
    """
    model.eval()  # dropout off so the loss estimates are stable
    with torch.no_grad():  # no gradients are needed for evaluation
        losses = [
            calc_loss_loader(loader, model, device, num_batches=eval_iter)
            for loader in (train_loader, val_loader)
        ]
    model.train()
    train_loss, val_loss = losses
    return train_loss, val_loss

In [None]:
def generate_and_print_sample(model, tokeniser, device, start_context,
                              max_new_tokens=50):
    """Generate a continuation of ``start_context`` and print it on one line.

    Used by the training loop after each epoch as a quick visual check of the
    model's output.

    Args:
        model: The GPT model. The number of rows of ``model.pos_emb.weight`` is
            used as the context size for generation.
        tokeniser: Tokenizer passed to ``text_to_token_ids`` and
            ``token_ids_to_text``.
        device: Device the encoded prompt is moved to.
        start_context: Prompt text to continue.
        max_new_tokens: Number of tokens to generate (default 50).

    The model is put into eval mode for generation and is always switched back
    to train mode afterwards, even if generation raises.
    """
    model.eval()
    try:
        # The first dimension of the positional-embedding table is treated as
        # the model's context size (presumably `context_length`; verify
        # against GPTModel).
        context_size = model.pos_emb.weight.shape[0]
        encoded = text_to_token_ids(start_context, tokeniser).to(device)
        with torch.no_grad():
            token_ids = generate_text_simple(
                model=model,
                idx=encoded,
                max_new_tokens=max_new_tokens,
                context_size=context_size,
            )
        decoded_text = token_ids_to_text(token_ids, tokeniser)
        # Flatten newlines so the sample prints compactly on a single line.
        print(decoded_text.replace("\n", " "))
    finally:
        model.train()