# PikoGPT Pretraining Pipeline
Training a GPT model on the OpenWebText dataset.

## 1. Setup & Environment

In [None]:
import sys
import pathlib
import json
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2TokenizerFast
from tqdm.auto import tqdm
import numpy as np

# Resolve the project root. A kernel started inside "notebooks/" has its
# working directory one level below the root, so step up in that case.
ROOT = pathlib.Path.cwd()
ROOT = ROOT.parent if ROOT.name == "notebooks" else ROOT
# Prepend the root to the import search path.
sys.path.insert(0, str(ROOT))

# Report the runtime environment.
print(
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
    sep="\n",
)
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 2. Configuration

In [None]:
# Training configuration
class Config:
    """Static hyperparameters and paths for the pretraining run.

    Every setting is a class-level attribute; the class defines no
    ``__init__``, so ``Config()`` instances read these shared values.
    """

    # --- Data ---
    data_path = ROOT.joinpath("data", "dataset_final", "openwebtext_clean.jsonl")

    # --- Model architecture ---
    vocab_size = 50_257  # GPT-2 vocabulary size
    n_layer = 6          # transformer blocks
    n_head = 6           # attention heads per block
    n_embd = 384         # embedding width
    block_size = 1_024   # maximum sequence length
    dropout = 0.1        # dropout probability

    # --- Optimisation ---
    batch_size = 8       # sequences per batch (tune to GPU memory)
    max_iters = 5_000    # total optimisation steps
    learning_rate = 3e-4
    weight_decay = 0.1
    beta1 = 0.9
    beta2 = 0.95
    grad_clip = 1.0      # max gradient norm used for clipping

    # --- Evaluation ---
    eval_interval = 500  # report every N iterations
    eval_iters = 100     # batches used per evaluation

    # --- Checkpointing ---
    checkpoint_dir = ROOT.joinpath("runs", "checkpoints")
    save_interval = 1_000  # write a checkpoint every N iterations

    # --- System ---
    device = "cuda" if torch.cuda.is_available() else "cpu"
    seed = 42

# Instantiate the settings and make sure the checkpoint directory exists.
config = Config()
config.checkpoint_dir.mkdir(exist_ok=True, parents=True)

# Seed both PyTorch and NumPy with the configured seed.
for _seed_rng in (torch.manual_seed, np.random.seed):
    _seed_rng(config.seed)

print(f"Device: {config.device}", f"Data path: {config.data_path}", sep="\n")

## 3. Tokenizer Setup

In [None]:
# Load the pretrained "gpt2" tokenizer.
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

# GPT-2 ships without a padding token, so the end-of-sequence token is
# reused for padding.
tokenizer.pad_token = tokenizer.eos_token

print(
    f"Vocabulary size: {tokenizer.vocab_size}",
    f"Special tokens: {tokenizer.special_tokens_map}",
    f"EOS token: '{tokenizer.eos_token}' (ID: {tokenizer.eos_token_id})",
    sep="\n",
)

In [None]:
# Round-trip a sample sentence through the tokenizer as a sanity check.
example_text = "This is an example sentence for tokenization."
tokens = tokenizer(example_text, return_tensors="pt")
print(
    f"\nOriginal text: {example_text}",
    f"Token IDs: {tokens['input_ids']}",
    f"Decoded: {tokenizer.decode(tokens['input_ids'][0])}",
    sep="\n",
)

## 4. Dataset Loading

In [None]:
class TextDataset(Dataset):
    """Dataset that serves one tokenized JSONL document per index.

    Each non-blank line of the file is parsed as JSON and the value stored
    under its ``'text'`` key is kept in memory as a raw string. Tokenization
    is deferred to ``__getitem__``.
    """

    def __init__(self, jsonl_path, tokenizer, block_size, max_samples=None):
        """Read documents from ``jsonl_path``.

        Args:
            jsonl_path: Path of a UTF-8 JSONL file.
            tokenizer: Callable tokenizer used later by ``__getitem__``.
            block_size: Length of the input/label sequences returned per item.
            max_samples: If not ``None``, stop after this many lines of the
                file have been read (skipped lines count too). ``0`` loads
                nothing.
        """
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.examples = []

        print(f"Loading data from {jsonl_path}...")
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            for i, line in enumerate(tqdm(f)):
                # Compare against None so that max_samples=0 is an actual
                # limit instead of silently meaning "no limit".
                if max_samples is not None and i >= max_samples:
                    break
                # Blank lines (e.g. a trailing newline) are not valid JSON.
                if not line.strip():
                    continue
                data = json.loads(line)
                text = data.get('text', '')
                # Keep only non-empty string payloads; a null or numeric
                # 'text' value would otherwise crash on .strip().
                if isinstance(text, str) and text.strip():
                    self.examples.append(text)

        print(f"Loaded {len(self.examples)} documents")

    def __len__(self):
        """Return the number of documents held in memory."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Tokenize document ``idx`` and return ``(inputs, labels)``.

        The document is truncated/padded to ``block_size + 1`` token ids;
        ``inputs`` is that sequence without its last id and ``labels`` is the
        same sequence without its first id (next-token prediction targets).
        """
        text = self.examples[idx]

        encoding = self.tokenizer(
            text,
            max_length=self.block_size + 1,  # +1 so inputs and labels both have block_size ids
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )

        # Drop only the leading batch axis; a bare squeeze() would also
        # collapse the sequence axis whenever it has length 1.
        input_ids = encoding['input_ids'].squeeze(0)

        # Inputs and labels are the same sequence shifted by one position.
        x = input_ids[:-1]
        y = input_ids[1:]

        return x, y

# Build the training set. max_samples keeps this a quick test run; drop it
# to train on the full file.
train_dataset = TextDataset(
    jsonl_path=config.data_path,
    tokenizer=tokenizer,
    block_size=config.block_size,
    max_samples=1000,  # Remove this for full dataset
)

print(
    f"\nDataset size: {len(train_dataset)}",
    f"Block size: {config.block_size}",
    sep="\n",
)

In [None]:
# Wrap the dataset in a shuffling DataLoader.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=0,  # keep at 0 on Windows to avoid multiprocessing issues
    pin_memory=(config.device == "cuda"),
)

print(
    f"Number of batches: {len(train_loader)}",
    f"Batch size: {config.batch_size}",
    sep="\n",
)

# Pull a single batch to confirm the tensor shapes.
x_sample, y_sample = next(iter(train_loader))
print(
    f"\nSample batch shapes:",
    f"  Input (x): {x_sample.shape}",   # [batch_size, block_size]
    f"  Labels (y): {y_sample.shape}",  # [batch_size, block_size]
    sep="\n",
)

## 5. Model Architecture

In [None]:
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention in which a position only sees itself and
    earlier positions."""

    def __init__(self, config):
        super().__init__()
        width, heads = config.n_embd, config.n_head
        assert width % heads == 0

        # One fused linear layer produces queries, keys and values.
        self.c_attn = nn.Linear(width, 3 * width)
        # Projection applied to the concatenated head outputs.
        self.c_proj = nn.Linear(width, width)
        # Dropout on the attention weights and on the block output.
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        self.n_head = heads
        self.n_embd = width

        # Lower-triangular mask of ones, stored as a buffer named "bias".
        ctx = config.block_size
        lower_triangle = torch.tril(torch.ones(ctx, ctx))
        self.register_buffer("bias", lower_triangle.view(1, 1, ctx, ctx))

    def forward(self, x):
        batch, seq_len, width = x.size()
        head_dim = width // self.n_head

        def to_heads(t):
            # (batch, seq, width) -> (batch, heads, seq, head_dim)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        q, k, v = (to_heads(part) for part in self.c_attn(x).split(self.n_embd, dim=2))

        # Scaled dot-product scores; masked-out (future) positions get -inf
        # so that softmax assigns them zero weight.
        scores = (q @ k.transpose(-2, -1)) * (1.0 / (head_dim ** 0.5))
        blocked = self.bias[:, :, :seq_len, :seq_len] == 0
        scores = scores.masked_fill(blocked, float('-inf'))
        weights = self.attn_dropout(torch.softmax(scores, dim=-1))

        # Merge the heads back into a single (batch, seq, width) tensor.
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.resid_dropout(self.c_proj(merged))

class MLP(nn.Module):
    """Position-wise feed-forward network: expand to four times the embedding
    width, apply GELU, project back, then apply dropout."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        self.c_fc = nn.Linear(width, hidden)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, width)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        expanded = self.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(expanded))

class Block(nn.Module):
    """Transformer block: layer-normalised attention and feed-forward
    sub-layers, each added back onto its input (residual connection)."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        self.ln_1 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(width)
        self.mlp = MLP(config)

    def forward(self, x):
        # The normalisation is applied before each sub-layer, not after.
        attended = x + self.attn(self.ln_1(x))
        return attended + self.mlp(self.ln_2(attended))

class GPT(nn.Module):
    """GPT language model: token + position embeddings, a stack of
    transformer blocks, a final LayerNorm and a linear head over the
    vocabulary whose weight is shared with the token embedding.
    """

    def __init__(self, config):
        """Build the model from ``config``.

        Reads ``vocab_size``, ``n_embd``, ``block_size``, ``dropout`` and
        ``n_layer`` from ``config``. An optional ``config.ignore_index`` sets
        the target id excluded from the loss; when it is absent, ``forward``
        falls back to the pad id of the module-level ``tokenizer``.
        """
        super().__init__()
        self.config = config
        # Optional explicit loss mask id, so the model can compute a loss
        # without relying on a notebook-global tokenizer.
        self.ignore_index = getattr(config, "ignore_index", None)

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),  # Token embeddings
            wpe = nn.Embedding(config.block_size, config.n_embd),  # Position embeddings
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = nn.LayerNorm(config.n_embd),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying: the embedding and the output head share one tensor.
        self.transformer.wte.weight = self.lm_head.weight

        # Initialize weights
        self.apply(self._init_weights)
        print(f"Model initialized with {self.get_num_params()/1e6:.2f}M parameters")

    def _init_weights(self, module):
        """Draw Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def get_num_params(self):
        """Return the total number of elements over ``self.parameters()``."""
        return sum(p.numel() for p in self.parameters())

    def forward(self, idx, targets=None):
        """Run the model on token ids ``idx`` of shape ``(b, t)``.

        Args:
            idx: Integer token ids, with ``t <= config.block_size``.
            targets: Optional target ids with the same number of elements as
                ``idx``; when given, a cross-entropy loss is computed.

        Returns:
            ``(logits, loss)`` where ``logits`` has one score per vocabulary
            entry for every position and ``loss`` is ``None`` if ``targets``
            is ``None``.
        """
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Sequence length {t} exceeds block size {self.config.block_size}"

        # Embed tokens and positions, then run the transformer stack.
        pos = torch.arange(0, t, dtype=torch.long, device=device)
        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)
        x = self.transformer.drop(tok_emb + pos_emb)

        for block in self.transformer.h:
            x = block(x)

        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)

        # Calculate loss if targets provided
        loss = None
        if targets is not None:
            ignore_index = self.ignore_index
            if ignore_index is None:
                # Backward-compatible default: the pad id of the
                # module-level tokenizer.
                ignore_index = tokenizer.pad_token_id
            loss = torch.nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=ignore_index
            )

        return logits, loss

# Build the model and place it on the configured device in one step.
model = GPT(config).to(config.device)
print(f"Model moved to {config.device}")

## 6. Training Setup

In [None]:
# AdamW over every model parameter, using the configured hyperparameters.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=config.learning_rate,
    betas=(config.beta1, config.beta2),
    weight_decay=config.weight_decay,
)

print(
    "Optimizer: AdamW",
    f"Learning rate: {config.learning_rate}",
    f"Weight decay: {config.weight_decay}",
    sep="\n",
)

In [None]:
@torch.no_grad()
def estimate_loss(model, data_loader, max_iters):
    """Average the model's loss over at most ``max_iters`` batches.

    The model is switched to eval mode for the measurement and is always put
    back into train mode afterwards, even if a batch raises.

    Args:
        model: Callable as ``model(x, y)`` returning ``(_, loss)``.
        data_loader: Iterable of ``(x, y)`` tensor pairs.
        max_iters: Maximum number of batches to consume.

    Returns:
        The mean of the per-batch losses, or ``nan`` when no batch was
        evaluated (empty loader or ``max_iters <= 0``).
    """
    losses = []
    model.eval()
    try:
        if max_iters > 0:
            for i, (x, y) in enumerate(data_loader):
                x, y = x.to(config.device), y.to(config.device)
                _, loss = model(x, y)
                losses.append(loss.item())
                # Stop right after the last wanted batch so the loader is
                # not asked for one more batch that would be thrown away.
                if i + 1 >= max_iters:
                    break
    finally:
        model.train()

    if not losses:
        # np.mean([]) would also give nan, but with a RuntimeWarning.
        return float("nan")
    return np.mean(losses)

def _config_snapshot(cfg):
    """Collect the public, non-callable settings of ``cfg`` into a plain dict.

    ``dir()`` is used instead of ``cfg.__dict__`` because settings declared
    as class attributes do not appear in the instance ``__dict__``.
    """
    snapshot = {}
    for name in dir(cfg):
        if name.startswith('_'):
            continue
        value = getattr(cfg, name)
        if callable(value):
            continue
        # Store anything that is not a plain scalar (e.g. a pathlib.Path) as
        # its string form, so the checkpoint holds only basic types and
        # stays loadable by torch.load's restricted unpickler.
        if not isinstance(value, (bool, int, float, str, type(None))):
            value = str(value)
        snapshot[name] = value
    return snapshot


def save_checkpoint(model, optimizer, iteration, loss, path):
    """Write a training checkpoint to ``path``.

    The file holds the iteration counter, the model and optimizer state
    dicts, the given loss value, and a snapshot of the module-level
    ``config`` settings under the ``'config'`` key.

    Args:
        model: Object exposing ``state_dict()``.
        optimizer: Object exposing ``state_dict()``.
        iteration: Iteration counter to record.
        loss: Loss value to record.
        path: Destination passed to ``torch.save``.
    """
    torch.save({
        'iteration': iteration,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        # config.__dict__ would be empty here: every setting lives on the
        # class, not on the instance.
        'config': _config_snapshot(config)
    }, path)
    print(f"Checkpoint saved to {path}")

def load_checkpoint(model, optimizer, path):
    """Restore model and optimizer state from the checkpoint at ``path``.

    Returns the stored ``(iteration, loss)`` pair.
    """
    state = torch.load(path, map_location=config.device)
    restore_plan = (
        (model, 'model_state_dict'),
        (optimizer, 'optimizer_state_dict'),
    )
    for target, key in restore_plan:
        target.load_state_dict(state[key])
    return state['iteration'], state['loss']

## 7. Training Loop

In [None]:
# Training loop
model.train()
train_losses = []
iterations = []

print("\n" + "="*50)
print("Starting Training")
print("="*50)

# With no batches the inner `for` never runs, iter_num never advances and
# the outer `while` would spin forever.
if len(train_loader) == 0:
    raise ValueError("train_loader yields no batches; cannot train")

iter_num = 0          # number of completed optimisation steps
running_loss = 0.0    # sum of losses since the last report
last_loss = float("nan")  # loss of the most recent step (nan if none ran)
progress_bar = tqdm(total=config.max_iters, desc="Training")

# The loader is re-iterated (one pass per epoch) until max_iters steps ran.
while iter_num < config.max_iters:
    for x, y in train_loader:
        # Move data to device
        x, y = x.to(config.device), y.to(config.device)

        # Forward pass
        logits, loss = model(x, y)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)

        # Update weights
        optimizer.step()

        # Track loss
        last_loss = loss.item()
        running_loss += last_loss

        # Count the step that just finished BEFORE the periodic checks, so
        # that each report averages exactly eval_interval losses and a file
        # named checkpoint_iter_N holds the state after exactly N steps.
        iter_num += 1
        progress_bar.update(1)

        # Periodic training-loss report
        if iter_num % config.eval_interval == 0:
            avg_loss = running_loss / config.eval_interval
            train_losses.append(avg_loss)
            iterations.append(iter_num)

            progress_bar.set_postfix({
                'loss': f'{avg_loss:.4f}', 
                'iter': iter_num
            })

            running_loss = 0.0

        # Save checkpoint
        if iter_num % config.save_interval == 0:
            checkpoint_path = config.checkpoint_dir / f"checkpoint_iter_{iter_num}.pt"
            save_checkpoint(model, optimizer, iter_num, last_loss, checkpoint_path)

        if iter_num >= config.max_iters:
            break

progress_bar.close()
print("\n" + "="*50)
print("Training Complete!")
print("="*50)

# Save final model. last_loss is used instead of loss.item() so this does
# not fail with an undefined `loss` when max_iters is 0.
final_checkpoint_path = config.checkpoint_dir / "final_model.pt"
save_checkpoint(model, optimizer, iter_num, last_loss, final_checkpoint_path)

## 8. Text Generation & Evaluation

In [None]:
@torch.no_grad()
def generate_text(model, tokenizer, prompt, max_new_tokens=100, temperature=0.8, top_k=50):
    """Sample a continuation of ``prompt`` from ``model``.

    Args:
        model: Callable as ``model(tokens)`` returning ``(logits, _)`` with
            one row of scores per position.
        tokenizer: Provides ``encode``, ``decode`` and ``eos_token_id``.
        prompt: Text to continue; must encode to at least one token.
        max_new_tokens: Upper bound on the number of sampled tokens.
        temperature: Divisor applied to the last-position logits.
        top_k: If not ``None``, sampling is restricted to the ``top_k``
            highest-scoring tokens.

    Returns:
        The decoded prompt plus the sampled tokens. Sampling stops early
        once the tokenizer's EOS id is drawn.

    Raises:
        ValueError: If ``prompt`` encodes to zero tokens.
    """
    # Remember the caller's mode: generation must not leave a model that is
    # still being trained stuck in eval mode (dropout disabled).
    was_training = model.training
    model.eval()
    try:
        # Encode prompt
        tokens = tokenizer.encode(prompt, return_tensors='pt').to(config.device)
        if tokens.numel() == 0:
            raise ValueError("prompt produced no tokens; provide a non-empty prompt")

        for _ in range(max_new_tokens):
            # Keep at most the last block_size tokens as context; the slice
            # is a no-op while the sequence is still short enough.
            tokens_cond = tokens[:, -config.block_size:]

            # Scores for the next token come from the last position only.
            logits, _unused = model(tokens_cond)
            logits = logits[:, -1, :] / temperature

            # Top-k sampling: everything below the k-th best score is
            # set to -inf so softmax gives it zero probability.
            if top_k is not None:
                kth_best, _unused = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < kth_best[:, [-1]]] = -float('Inf')

            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            tokens = torch.cat([tokens, next_token], dim=1)

            # Stop if EOS token is generated
            if next_token.item() == tokenizer.eos_token_id:
                break
    finally:
        model.train(was_training)

    # Decode and return
    return tokenizer.decode(tokens[0].tolist())

# Sample a short continuation for a few fixed prompts.
prompts = [
    "Once upon a time",
    "The future of artificial intelligence",
    "In a world where",
]

print("\n" + "="*50, "Generated Text Samples", "="*50 + "\n", sep="\n")

for prompt in prompts:
    generated = generate_text(model, tokenizer, prompt, max_new_tokens=50)
    print(
        f"Prompt: {prompt}",
        f"Generated: {generated}",
        "-" * 50 + "\n",
        sep="\n",
    )

In [None]:
# Plot the recorded training-loss averages against their iteration numbers.
import matplotlib.pyplot as plt

if not train_losses:
    # Nothing was recorded during training, so there is nothing to draw.
    print("Not enough iterations for loss plot")
else:
    plt.figure(figsize=(10, 6))
    plt.plot(iterations, train_losses, label='Training Loss')
    plt.title('Training Loss over Time')
    plt.xlabel('Iteration')
    plt.ylabel('Loss')
    plt.grid(True)
    plt.legend()
    plt.show()

## 9. Next Steps

**To continue training:**
- Remove the `max_samples=1000` limit in the dataset-loading cell (the `TextDataset(...)` call)
- Increase `max_iters` to 10000+ for better results
- Adjust `batch_size` based on your GPU memory

**To load a checkpoint:**
```python
checkpoint_path = config.checkpoint_dir / "checkpoint_iter_1000.pt"
iter_num, loss = load_checkpoint(model, optimizer, checkpoint_path)
print(f"Loaded checkpoint from iteration {iter_num} with loss {loss:.4f}")
```

**Improvements:**
- Learning rate scheduling (cosine decay, warmup)
- Validation set evaluation
- Gradient accumulation for larger effective batch size
- Mixed precision training (torch.cuda.amp)