In [None]:
import numpy as np
import pandas as pd
import tiktoken as tk
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

### Training Loop

- **Setup:**  
  - Instantiate model: `model = Transformer(vocab_size, d_model, num_heads, d_ff, num_layers).to(device)`  
  - Loss function: `nn.CrossEntropyLoss()` (applied to flattened logits `[B*T, V]` and flattened targets `[B*T]`; see the reshape below)  
  - Optimizer: `AdamW`  

- **Batching:**  
  - Use `batch_loader(raw_dataset, T=seq_len, B=batch_size, device=device)` to get `(x_batch, y_batch)`  
  - `x_batch` and `y_batch` shape: `[B, T]`  

- **Forward & Backward:**  
  - `logits = model(x_batch)` → shape `[B, T, vocab_size]`  
  - Reshape for loss: `logits.view(-1, vocab_size)` vs `y_batch.view(-1)`  
  - Compute loss and backpropagate  
  - Apply gradient clipping: `clip_grad_norm_(model.parameters(), 1.0)`  
  - `optimizer.step()` to update weights  

- **Epoch loop:**  
  - Track total loss per epoch  
  - Optional: sample generated text with `generate(model, start_text="The Emperor")` to qualitatively monitor how training is progressing  


In [None]:
# NOTE: this cell calls `generate`, which is defined in a later cell of this
# notebook; that cell must be executed before this one.
model = Transformer(vocab_size, d_model, num_heads, d_ff, num_layers).to(device)

# CrossEntropyLoss takes class scores as [N, C] and targets as [N], so the
# [B, T, V] logits are flattened to [B*T, V] before the loss is computed.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=lr)

for epoch in range(epochs):
    # Ensure training mode: generate() puts the model in eval mode while sampling.
    model.train()
    total_loss = 0.0

    # A single batch is drawn from raw_dataset per "epoch", so each epoch
    # performs exactly one optimizer step.
    x_batch, y_batch = batch_loader(raw_dataset, T=seq_len, B=batch_size, device=device)

    optimizer.zero_grad()
    logits = model(x_batch)  # [B, T, V]

    # Flatten for CrossEntropy: [B*T, V] vs [B*T]. reshape (unlike view) also
    # accepts non-contiguous tensors, and the class dimension is read from the
    # logits themselves.
    loss = criterion(logits.reshape(-1, logits.size(-1)), y_batch.reshape(-1))
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # gradient clipping

    optimizer.step()

    total_loss += loss.item()

    print(f"Epoch {epoch+1}/{epochs} | Loss: {total_loss:.4f}")

    # Optional: sample a few tokens every epoch as a qualitative check.
    # Pass the training device explicitly so sampling runs where the model
    # lives instead of relying on generate()'s default device.
    sample_text = generate(model, start_text="The Emperor", device=device)
    print("Sample:", sample_text[:200], "...\n")

### Text Generation (`generate` function)

- **Purpose:** Generate text autoregressively from a starting string using the trained Transformer.  
- **Inputs:**  
  - `model`: trained Transformer model  
  - `start_text`: string to seed generation  
  - `tokenizer`: encoding/decoding utility (`encoder`)  
  - `max_tokens`: max number of tokens to generate  
  - `temperature`: controls randomness; higher → more diverse output  
  - `device`: `"cuda"` or `"cpu"`  

- **Process:**  
  1. Encode `start_text` to token IDs.  
  2. Loop up to `max_tokens`:  
     - Forward pass through model to get logits  
     - Take the logits at the last position and divide them by `temperature`  
     - Apply softmax → probabilities  
     - Sample next token with `multinomial`  
     - Append token to sequence  
     - Stop if `<|endoftext|>` token is generated  

- **Output:**  
  - Decoded string containing generated text.  
  - Shape internally: `[1, seq_len]` → converted back to string.


In [None]:
@torch.no_grad()
def generate(model, start_text, tokenizer=None, max_tokens=50, temperature=1.0, device=None):
    """Autoregressively generate text from ``start_text``.

    Args:
        model: Module called as ``model(x)`` with ``x`` of shape ``[1, seq_len]``;
            its output is indexed as ``[:, -1, :]`` to get next-token logits.
        start_text: Prompt string; must encode to at least one token.
        tokenizer: Object providing ``encode``, ``decode`` and ``eot_token``.
            Defaults to the notebook-level ``encoder``, looked up when the
            function is called rather than when it is defined.
        max_tokens: Maximum number of tokens to append to the prompt.
        temperature: Divisor applied to the logits before sampling. ``0``
            selects the highest-scoring token at every step (greedy decoding).
        device: Device for the token tensor. Defaults to the device of the
            model's first parameter (``"cpu"`` if the model has no parameters).

    Returns:
        The decoded string for the prompt plus all generated tokens, including
        the end-of-text token if one was produced.

    Raises:
        ValueError: If ``temperature`` is negative or ``start_text`` encodes
            to no tokens.
    """
    if tokenizer is None:
        tokenizer = encoder  # notebook-level tokenizer, resolved at call time
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")

    prompt_ids = tokenizer.encode(start_text)
    if not prompt_ids:
        raise ValueError("start_text must encode to at least one token")

    if device is None:
        # Keep the input on the same device as the model's weights.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cpu"

    # Remember the current mode so sampling in the middle of training does not
    # leave the model stuck in eval mode afterwards.
    was_training = model.training
    model.eval()
    try:
        x = torch.tensor([prompt_ids], dtype=torch.long, device=device)  # [1, seq_len]

        for _ in range(max_tokens):
            logits = model(x)[:, -1, :]  # scores for the last position only

            if temperature == 0:
                # Dividing by zero would produce NaN probabilities; decode greedily.
                next_token = logits.argmax(dim=-1, keepdim=True)
            else:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

            x = torch.cat([x, next_token], dim=1)  # append to sequence

            # Stop once <|endoftext|> has been generated.
            if next_token.item() == tokenizer.eot_token:
                break
    finally:
        model.train(was_training)

    # Decode back to text
    return tokenizer.decode(x[0].tolist())