# 🚀 Small Language Model Training - Enhanced Edition
## 125M Parameter Model with Large Dataset Support
- ✅ T4 GPU Optimized (15GB Memory)
- ✅ Support for large-scale datasets (OpenWebText, Wikipedia, C4)
- ✅ All bugs fixed (RoPE dimensions, FP16 overflow)
- ✅ Production-ready implementation

In [None]:
# Install required packages
# The `!` lines are IPython shell escapes, so this cell only runs inside a
# notebook/IPython session. The first command pulls the PyTorch wheels from the
# CUDA 11.8 index; the second installs the remaining libraries used below.
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 -q
!pip install transformers datasets tokenizers accelerate einops matplotlib numpy tqdm -q
print("‚úÖ Dependencies installed")

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import GradScaler, autocast
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import json
import gc
import math
from dataclasses import dataclass
from typing import Optional, Tuple
import warnings
from datetime import datetime
from datasets import load_dataset
from transformers import AutoTokenizer
warnings.filterwarnings('ignore')

# Memory optimizations for T4
# Allocator option handed to PyTorch through the environment. It is set here,
# before any CUDA call made in this file.
# NOTE(review): presumably it has to be in place before the CUDA allocator is
# first initialised - confirm against the PyTorch CUDA memory-management docs.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

# Startup banner
print("="*80)
print("üöÄ ENHANCED SLM TRAINER - LARGE DATASET SUPPORT")
print("="*80)

In [None]:
def check_gpu():
    """Report CUDA availability and apply one-time CUDA settings.

    When a CUDA device is present, prints the name and total memory of
    device 0, caps this process at 90% of the device memory, allows TF32
    matmul/cuDNN kernels and turns on cuDNN autotuning. Otherwise prints a
    warning.

    Returns:
        bool: True when ``torch.cuda.is_available()`` is true, else False.
    """
    if not torch.cuda.is_available():
        print("‚ö†Ô∏è No GPU available. Using CPU (very slow)")
        return False

    name = torch.cuda.get_device_name(0)
    memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"‚úÖ GPU Detected: {name}")
    print(f"   Total Memory: {memory_gb:.2f} GB")

    # Leave some headroom on the device for stability.
    torch.cuda.set_per_process_memory_fraction(0.9)

    # TF32 only matters on GPUs that implement it (Ampere and newer).
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # Let cuDNN benchmark kernels and keep the fastest one.
    torch.backends.cudnn.benchmark = True

    return True


# Module-level device selection used by the rest of the file.
USE_GPU = check_gpu()
device = torch.device('cuda') if USE_GPU else torch.device('cpu')
print(f"üîß Using device: {device}")

def clear_memory():
    """Run the Python garbage collector and, if CUDA is available, empty the
    CUDA caching allocator and wait for outstanding GPU work to finish."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()


# Start from a clean slate before any data or model is loaded.
clear_memory()

## 📚 Dataset Selection
Choose from multiple large-scale datasets for training:

In [None]:
# SELECT YOUR DATASET HERE
# Options: 'tinystories', 'openwebtext', 'wikipedia', 'bookcorpus', 'c4', 'pile'
# DATASET_NAME is passed to load_large_dataset() and is also embedded in the
# checkpoint and plot file names written later in this file.
DATASET_NAME = 'tinystories'  # Change this to use larger datasets
# Upper bound on the number of training rows that load_large_dataset() keeps.
MAX_SAMPLES = 100000  # Set to None for full dataset

# Static overview of the supported corpora (sizes are hard-coded descriptions,
# nothing here is measured).
print(f"üìä Dataset Information:")
print(f"   {'TinyStories':<15} - 2M stories, ~500MB (good for testing)")
print(f"   {'OpenWebText':<15} - 8M documents, ~40GB (GPT-2 quality)")
print(f"   {'Wikipedia':<15} - 6M articles, ~20GB (clean, factual)")
print(f"   {'BookCorpus':<15} - 74M sentences, ~5GB (books)")
print(f"   {'C4':<15} - 365M documents, ~300GB (massive web crawl)")
print(f"   {'The Pile':<15} - 800GB diverse text (best quality)")
print(f"\n‚úÖ Selected: {DATASET_NAME}")
# Any falsy MAX_SAMPLES (None or 0) skips the "limited" notice.
if MAX_SAMPLES:
    print(f"   Limited to {MAX_SAMPLES:,} samples for demo")

In [None]:
@dataclass
class ModelConfig:
    """Hyperparameters for the decoder-only model and its training run.

    The defaults target a single T4-class GPU. ``batch_size`` and
    ``use_mixed_precision`` take their defaults from the module-level
    ``USE_GPU`` flag, evaluated once when this class is defined.
    """

    # Model architecture
    vocab_size: int = 50257  # GPT-2 vocabulary
    hidden_size: int = 768   # Hidden dimension
    num_layers: int = 12     # Number of transformer layers
    num_heads: int = 12      # Number of attention heads
    ff_dim: int = 3072       # Feedforward dimension
    max_seq_len: int = 512   # Maximum sequence length

    # Training configuration
    batch_size: int = 8 if USE_GPU else 2
    gradient_accumulation_steps: int = 4  # Effective batch = 32
    learning_rate: float = 6e-4
    num_epochs: int = 1  # Start with 1 for large datasets
    warmup_steps: int = 1000
    max_grad_norm: float = 1.0
    dropout: float = 0.1
    weight_decay: float = 0.01

    # Memory optimizations
    gradient_checkpointing: bool = True
    use_mixed_precision: bool = USE_GPU

    # Logging
    log_interval: int = 50
    eval_interval: int = 500
    save_interval: int = 1000

    def model_size(self):
        """Estimate the number of unique parameters, in millions.

        The estimate mirrors the architecture defined in this file: one
        embedding matrix shared by the input embedding and the output head
        (weight tying), four bias-free attention projections, three bias-free
        SwiGLU matrices and two RMSNorm weight vectors per layer, plus the
        final RMSNorm. With the default fields it evaluates to about 151.9.

        Returns:
            float: Parameter count divided by 1e6.
        """
        # The output head reuses the embedding matrix, so it is counted once.
        embedding_params = self.vocab_size * self.hidden_size

        # Per-layer parameters
        attention_params = 4 * self.hidden_size * self.hidden_size  # Q,K,V,O projections
        ff_params = 3 * self.hidden_size * self.ff_dim  # W1, W2, W3 for SwiGLU
        norm_params = 2 * self.hidden_size  # Two RMSNorms per layer
        layer_params = attention_params + ff_params + norm_params

        # Total: embeddings + all layers + final RMSNorm weight
        total_params = embedding_params + (layer_params * self.num_layers) + self.hidden_size
        return total_params / 1e6

# Shared configuration instance used by every later step in this file.
config = ModelConfig()
print(f"\nüìä Model Configuration:")
print(f"   Parameters: ~{config.model_size():.1f}M")
# millions of parameters * 4 bytes = megabytes; / 1000 converts to gigabytes
print(f"   Memory footprint: ~{config.model_size() * 4 / 1000:.2f} GB (FP32)")
print(f"   Effective batch size: {config.batch_size * config.gradient_accumulation_steps}")

## 📖 Large Dataset Loading

In [None]:
def load_large_dataset(dataset_name, max_samples=None):
    """Load the training and validation splits of one supported corpus.

    Corpora that ship a validation split ('tinystories', 'c4', 'pile') use it
    directly. The others ('openwebtext', 'wikipedia', 'bookcorpus') hold out
    the last 5000 rows of their train split for validation; those rows are
    excluded from the returned training data, so the two sets never overlap
    regardless of ``max_samples``.

    Args:
        dataset_name: One of 'tinystories', 'openwebtext', 'wikipedia',
            'bookcorpus', 'c4', 'pile'.
        max_samples: Maximum number of training rows to keep. A falsy value
            means "no limit", except for 'c4' (1,000,000 rows) and 'pile'
            (100,000 rows), which are always capped.

    Returns:
        tuple: ``(train_dataset, val_dataset)`` as returned by ``load_dataset``.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.

    NOTE(review): the hub identifiers below ("wikipedia", "bookcorpusopen",
    "c4", "EleutherAI/pile") are kept as found; whether they still resolve
    depends on the installed `datasets` version and on the hub - confirm.
    """
    print(f"üìö Loading {dataset_name} dataset...")

    # Rows taken from the end of `train` when a corpus has no validation split.
    holdout = 5000

    def first_rows(split_name, limit):
        """Split expression for the first `limit` rows (whole split if falsy)."""
        return f"{split_name}[:{limit}]" if limit else split_name

    def load_with_tail_holdout(*hub_args):
        """Load a corpus without a validation split, holding out its tail."""
        # Everything except the held-out tail is eligible for training.
        train = load_dataset(*hub_args, split=f"train[:-{holdout}]")
        if max_samples and max_samples < len(train):
            train = train.select(range(max_samples))
        val = load_dataset(*hub_args, split=f"train[-{holdout}:]")
        return train, val

    if dataset_name == 'tinystories':
        # Small dataset for testing
        dataset = load_dataset("roneneldan/TinyStories", split=first_rows("train", max_samples))
        val_dataset = load_dataset("roneneldan/TinyStories", split="validation[:1000]")

    elif dataset_name == 'openwebtext':
        # High-quality web text (GPT-2 training data)
        dataset, val_dataset = load_with_tail_holdout("Skylion007/openwebtext")

    elif dataset_name == 'wikipedia':
        # Wikipedia English
        dataset, val_dataset = load_with_tail_holdout("wikipedia", "20220301.en")

    elif dataset_name == 'bookcorpus':
        # Books dataset
        dataset, val_dataset = load_with_tail_holdout("bookcorpusopen")

    elif dataset_name == 'c4':
        # Colossal Clean Crawled Corpus (always capped, see docstring)
        samples = max_samples if max_samples else 1000000
        dataset = load_dataset("c4", "en", split=f"train[:{samples}]", streaming=False)
        val_dataset = load_dataset("c4", "en", split="validation[:5000]")

    elif dataset_name == 'pile':
        # The Pile (always capped, see docstring)
        samples = max_samples if max_samples else 100000
        dataset = load_dataset("EleutherAI/pile", split=f"train[:{samples}]", streaming=False)
        val_dataset = load_dataset("EleutherAI/pile", split="validation[:5000]")

    else:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    print(f"‚úÖ Loaded {len(dataset):,} training samples")
    print(f"‚úÖ Loaded {len(val_dataset):,} validation samples")

    return dataset, val_dataset

# Load the selected dataset
train_dataset, val_dataset = load_large_dataset(DATASET_NAME, MAX_SAMPLES)

In [None]:
# Initialize tokenizer
print("üî§ Initializing tokenizer...")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Reuse the end-of-sequence token as the padding token, so padded positions
# carry the EOS id.
tokenizer.pad_token = tokenizer.eos_token
# Keep the model's vocabulary size in sync with the tokenizer actually loaded.
config.vocab_size = len(tokenizer)

# Tokenization function
def tokenize_function(examples):
    """Tokenize one batch of rows for ``Dataset.map(..., batched=True)``.

    Reads the 'text' column when the batch has one and the 'story' column
    otherwise, then truncates/pads every entry to ``config.max_seq_len``
    using the module-level ``tokenizer``.

    Args:
        examples: Mapping from column name to the batch's column values.

    Returns:
        The tokenizer's output for the selected column (plain Python lists,
        since ``return_tensors`` is None).
    """
    # Handle different field names
    column = 'story'
    if 'text' in examples:
        column = 'text'

    encoded = tokenizer(
        examples[column],
        max_length=config.max_seq_len,
        padding='max_length',
        truncation=True,
        return_tensors=None,
    )
    return encoded

print("‚öôÔ∏è Tokenizing datasets...")
# The original columns are dropped so that only the tokenizer's output fields
# remain in the mapped datasets.
tokenized_train = train_dataset.map(tokenize_function, batched=True, remove_columns=train_dataset.column_names)
tokenized_val = val_dataset.map(tokenize_function, batched=True, remove_columns=val_dataset.column_names)

print(f"‚úÖ Tokenization complete")

In [None]:
# Create PyTorch datasets
class TextDataset(Dataset):
    """Map-style dataset yielding ``(input_ids, labels)`` pairs for LM training.

    Each row of ``tokenized_dataset`` must provide 'input_ids'. When a row also
    provides 'attention_mask', padded label positions are replaced by
    ``IGNORE_INDEX`` so the loss is not computed over padding. The first padded
    position after the text is kept as a target: the pad token is the EOS
    token in this file, so that single position teaches the model to end the
    text. Rows without 'attention_mask' get labels identical to the inputs.
    """

    # Default `ignore_index` of torch.nn.functional.cross_entropy.
    IGNORE_INDEX = -100

    def __init__(self, tokenized_dataset):
        """Store the indexable collection of tokenized rows."""
        self.dataset = tokenized_dataset

    def __len__(self):
        """Return the number of rows."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(input_ids, labels)`` as 1-D ``torch.long`` tensors."""
        item = self.dataset[idx]
        input_ids = torch.tensor(item['input_ids'], dtype=torch.long)
        labels = input_ids.clone()

        if 'attention_mask' in item:
            is_pad = torch.tensor(item['attention_mask'], dtype=torch.long) == 0
            # A position is ignored only if it AND its predecessor are padding,
            # which leaves the first pad of a run in place as the EOS target.
            prev_is_pad = torch.zeros_like(is_pad)
            prev_is_pad[1:] = is_pad[:-1]
            labels[is_pad & prev_is_pad] = self.IGNORE_INDEX

        return input_ids, labels

# NOTE: these assignments rebind train_dataset / val_dataset from the raw
# loaded datasets to TextDataset wrappers around the tokenized versions.
train_dataset = TextDataset(tokenized_train)
val_dataset = TextDataset(tokenized_val)

# Create data loaders
# Training batches are shuffled; worker processes and pinned memory are only
# requested when a GPU is in use.
train_loader = DataLoader(
    train_dataset, 
    batch_size=config.batch_size, 
    shuffle=True,
    num_workers=2 if USE_GPU else 0,
    pin_memory=USE_GPU
)

# Validation keeps dataset order and uses twice the training batch size.
val_loader = DataLoader(
    val_dataset, 
    batch_size=config.batch_size * 2, 
    shuffle=False,
    num_workers=2 if USE_GPU else 0,
    pin_memory=USE_GPU
)

print(f"\nüìä Data Statistics:")
print(f"   Training batches: {len(train_loader):,}")
print(f"   Validation batches: {len(val_loader):,}")
# Upper estimate: every batch is counted as full and every sequence as
# max_seq_len tokens (padding included).
print(f"   Tokens per epoch: ~{len(train_loader) * config.batch_size * config.max_seq_len:,}")

## 🏗️ Model Architecture (with fixes)

In [None]:
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Each vector along the last axis is divided by ``sqrt(mean(x**2) + eps)``
    and then multiplied element-wise by a learned ``weight`` (initialised to
    ones). There is no mean subtraction and no bias.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Return ``x`` rescaled to unit RMS along the last axis, times ``weight``."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normalized = x * inv_rms
        return normalized * self.weight

In [None]:
class RotaryPositionalEmbedding(nn.Module):
    """Precomputed cosine/sine tables for rotary position embeddings (RoPE).

    Buffers registered on the module:
        inv_freq:   ``dim // 2`` inverse frequencies ``1 / 10000**(2i / dim)``.
        cos_cached: ``(max_seq_len, dim // 2)`` cosines of position * inv_freq.
        sin_cached: ``(max_seq_len, dim // 2)`` sines of position * inv_freq.
    """

    def __init__(self, dim, max_seq_len=2048):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len

        # One inverse frequency per pair of channels.
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer('inv_freq', 1.0 / (10000 ** exponents))

        # Angle for every (position, frequency) combination up to max_seq_len.
        positions = torch.arange(max_seq_len).float()
        angles = torch.outer(positions, self.inv_freq)
        self.register_buffer('cos_cached', angles.cos())
        self.register_buffer('sin_cached', angles.sin())

    def forward(self, x, seq_len):
        """Return the first ``seq_len`` rows of the cos and sin tables.

        ``x`` is accepted for call-site symmetry but is not read.
        """
        cos = self.cos_cached[:seq_len]
        sin = self.sin_cached[:seq_len]
        return cos, sin

def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate queries and keys by the given rotary angles.

    ``cos`` and ``sin`` cover half of the last dimension of ``q``/``k``; they
    are concatenated with themselves along the last axis so they line up with
    the full width, then broadcast against ``q`` and ``k``.

    Returns:
        tuple: ``(q_rotated, k_rotated)`` with the same shapes as ``q``, ``k``.
    """
    def swap_halves_negated(t):
        # Split the last axis in two and return (-second_half, first_half).
        front, back = torch.chunk(t, 2, dim=-1)
        return torch.cat((-back, front), dim=-1)

    # Tile the half-width tables to the full last-dimension width.
    cos_full = torch.cat((cos, cos), dim=-1)
    sin_full = torch.cat((sin, sin), dim=-1)

    q_rotated = (q * cos_full) + (swap_halves_negated(q) * sin_full)
    k_rotated = (k * cos_full) + (swap_halves_negated(k) * sin_full)
    return q_rotated, k_rotated

In [None]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings.

    Uses four bias-free ``hidden_size x hidden_size`` projections (q, k, v, o),
    a RoPE table sized for ``hidden_size // num_heads`` channels, and dropout
    on the attention weights.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = self.hidden_size // self.num_heads

        width = self.hidden_size
        self.q_proj = nn.Linear(width, width, bias=False)
        self.k_proj = nn.Linear(width, width, bias=False)
        self.v_proj = nn.Linear(width, width, bias=False)
        self.o_proj = nn.Linear(width, width, bias=False)

        self.rope = RotaryPositionalEmbedding(self.head_dim, config.max_seq_len)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape (batch, seq_len, hidden_size).

        Args:
            x: Input activations; unpacked as (batch, seq_len, hidden_size).
            mask: Boolean tensor, True where attention must be blocked. When
                omitted, a strictly-upper-triangular (causal) mask is built.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch, seq_len, width = x.shape
        per_head = (batch, seq_len, self.num_heads, self.head_dim)

        # Project and split the channel axis into heads.
        queries = self.q_proj(x).view(per_head)
        keys = self.k_proj(x).view(per_head)
        values = self.v_proj(x).view(per_head)

        # Rotary tables: (seq_len, head_dim // 2) -> (1, seq_len, 1, head_dim // 2)
        cos, sin = self.rope(x, seq_len)
        cos = cos[None, :, None, :]
        sin = sin[None, :, None, :]
        queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin)

        # Move heads in front of the sequence axis: (batch, heads, seq_len, head_dim)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # Scaled dot-product scores
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if mask is None:
            mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
        # -1e9 is not representable in float16, so a smaller fill is used there.
        if scores.dtype == torch.float16:
            fill_value = -1e4
        else:
            fill_value = -1e9
        scores = scores.masked_fill(mask, fill_value)

        weights = self.dropout(F.softmax(scores, dim=-1))

        # Weighted sum of values, then merge heads back into the channel axis.
        context = torch.matmul(weights, values)
        context = context.transpose(1, 2).contiguous().view(batch, seq_len, width)

        return self.o_proj(context)

In [None]:
class FeedForward(nn.Module):
    """Position-wise SwiGLU feed-forward network.

    Computes ``w2(dropout(silu(w1(x)) * w3(x)))`` with three bias-free linear
    layers: ``w1`` and ``w3`` map hidden_size -> ff_dim, ``w2`` maps back.
    """

    def __init__(self, config):
        super().__init__()
        model_width, inner_width = config.hidden_size, config.ff_dim
        self.w1 = nn.Linear(model_width, inner_width, bias=False)
        self.w2 = nn.Linear(inner_width, model_width, bias=False)
        self.w3 = nn.Linear(model_width, inner_width, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply the gated feed-forward transform to ``x``."""
        gate = F.silu(self.w1(x))      # swish-activated branch
        gated = gate * self.w3(x)      # element-wise gating by the linear branch
        return self.w2(self.dropout(gated))

In [None]:
class TransformerBlock(nn.Module):
    """One pre-norm transformer layer: attention and feed-forward sub-layers,
    each fed an RMS-normalized input and added back through a residual."""

    def __init__(self, config):
        super().__init__()
        self.attention = MultiHeadAttention(config)
        self.feed_forward = FeedForward(config)
        self.ln1 = RMSNorm(config.hidden_size)
        self.ln2 = RMSNorm(config.hidden_size)

    def forward(self, x, mask=None):
        """Run both residual sub-layers; ``mask`` is forwarded to attention."""
        # Attention sub-layer (normalize first, then add the residual)
        attended = self.attention(self.ln1(x), mask)
        x = x + attended

        # Feed-forward sub-layer
        transformed = self.feed_forward(self.ln2(x))
        return x + transformed

In [None]:
class SmallLanguageModel(nn.Module):
    """Decoder-only transformer language model with tied embeddings.

    Structure: token embedding -> dropout -> ``num_layers`` TransformerBlocks
    -> final RMSNorm -> linear head whose weight is shared with the token
    embedding. Position information comes from the rotary embeddings inside
    each attention layer; there is no separate position embedding.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Token embeddings
        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.num_layers)
        ])

        # Output layers
        self.ln_f = RMSNorm(config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Weight tying: the head reuses the embedding matrix.
        self.lm_head.weight = self.token_embedding.weight

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Normal(0, 0.02) init for Linear/Embedding weights; zero biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids, labels=None):
        """Compute next-token logits and, optionally, the LM loss.

        Args:
            input_ids: Integer tensor of shape (batch, seq_len).
            labels: Optional integer tensor shaped like ``input_ids``. Targets
                are shifted by one position inside this method; positions
                equal to -100 are skipped by ``F.cross_entropy``'s default
                ``ignore_index``.

        Returns:
            tuple: ``(logits, loss)`` where ``logits`` has shape
            (batch, seq_len, vocab_size) and ``loss`` is a scalar tensor or
            None when ``labels`` is not given.
        """
        # Token embeddings
        x = self.token_embedding(input_ids)
        x = self.dropout(x)

        # Causal mask: True above the diagonal, i.e. at future positions.
        B, L = input_ids.shape
        mask = torch.triu(torch.ones(L, L, device=input_ids.device), diagonal=1).bool()

        # Apply transformer blocks
        for block in self.blocks:
            if self.config.gradient_checkpointing and self.training:
                # Recompute activations in backward to save memory.
                # use_reentrant is passed explicitly, as PyTorch recommends.
                x = torch.utils.checkpoint.checkpoint(block, x, mask, use_reentrant=False)
            else:
                x = block(x, mask)

        # Output layer
        x = self.ln_f(x)
        logits = self.lm_head(x)

        # Calculate loss if labels provided
        loss = None
        if labels is not None:
            # Position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )

        return logits, loss

    @torch.no_grad()
    def generate(self, input_ids, max_length=100, temperature=0.8, top_p=0.9):
        """Sample a continuation with temperature and nucleus (top-p) sampling.

        The model is switched to eval mode and left there. At every step only
        the last ``config.max_seq_len`` tokens are fed to the model, because
        the rotary tables are precomputed for that many positions; the
        returned sequence still contains every token.

        Args:
            input_ids: Integer tensor of shape (batch, prompt_len).
            max_length: Total length (prompt included) at which to stop.
            temperature: Positive divisor applied to the logits.
            top_p: Cumulative-probability threshold for nucleus sampling.

        Returns:
            Tensor of shape (batch, <= max(max_length, prompt_len)).
            Generation stops early once every row samples the EOS id of the
            module-level ``tokenizer`` in the same step.

        Raises:
            ValueError: If ``temperature`` is not strictly positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")

        self.eval()
        context_len = self.config.max_seq_len

        for _ in range(max_length - input_ids.shape[1]):
            # Forward pass on the most recent window only.
            logits, _ = self(input_ids[:, -context_len:])

            # Get next token logits
            next_token_logits = logits[:, -1, :] / temperature

            # Apply top-p (nucleus) sampling
            sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Drop tokens once the cumulative probability exceeds top_p; the
            # shift by one keeps the token that crosses the threshold, and the
            # most likely token is always kept.
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0

            # Map the removal flags from sorted order back to vocabulary order.
            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
            next_token_logits[indices_to_remove] = float('-inf')

            # Sample
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Append to sequence
            input_ids = torch.cat([input_ids, next_token], dim=-1)

            # Stop if EOS token (for every row, so batches > 1 do not crash)
            if bool((next_token == tokenizer.eos_token_id).all()):
                break

        return input_ids

In [None]:
# Initialize model
model = SmallLanguageModel(config).to(device)

# Count parameters
# model.parameters() is the source of truth for the sizes printed below.
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("‚úÖ Model initialized:")
print(f"   Total parameters: {total_params/1e6:.1f}M")
print(f"   Trainable parameters: {trainable_params/1e6:.1f}M")
# 4 bytes per parameter (FP32 weights only; gradients and optimizer state are
# not included in this figure).
print(f"   Memory footprint: ~{total_params * 4 / 1e9:.2f} GB (FP32)")

clear_memory()

## 🚀 Training Setup

In [None]:
# Optimizer
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config.learning_rate,
    weight_decay=config.weight_decay,
    betas=(0.9, 0.95)
)

# Learning rate scheduler
# The scheduler is stepped once per optimizer update, and train_epoch performs
# one update every `gradient_accumulation_steps` batches. The schedule horizon
# is therefore counted in optimizer updates, not in batches; counting batches
# would stretch the cosine decay far past the end of training.
updates_per_epoch = math.ceil(len(train_loader) / config.gradient_accumulation_steps)
total_steps = updates_per_epoch * config.num_epochs


def get_lr_lambda(current_step):
    """Return the LR multiplier for ``current_step`` optimizer updates.

    Linear warmup from 0 to 1 over ``config.warmup_steps`` updates, followed
    by cosine decay from 1 to 0 over the remaining updates up to
    ``total_steps``. Steps beyond ``total_steps`` keep the multiplier at 0.
    """
    # Warmup
    if current_step < config.warmup_steps:
        return current_step / config.warmup_steps
    # Cosine decay; the max() guards against a zero or negative decay span
    # when warmup_steps >= total_steps, the min() clamps overshoot.
    decay_steps = max(1, total_steps - config.warmup_steps)
    progress = min(1.0, (current_step - config.warmup_steps) / decay_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=get_lr_lambda)

# Mixed precision scaler (None when mixed precision is disabled)
scaler = GradScaler() if config.use_mixed_precision else None

print("‚úÖ Training setup complete")
print(f"   Optimizer: AdamW")
print(f"   Learning rate: {config.learning_rate}")
print(f"   Warmup steps: {config.warmup_steps}")
print(f"   Mixed precision: {config.use_mixed_precision}")

## 🎯 Training Loop

In [None]:
def train_epoch(model, loader, optimizer, scheduler, scaler, config, epoch):
    """Train ``model`` for one pass over ``loader``.

    Gradients are accumulated over ``config.gradient_accumulation_steps``
    batches before each optimizer update. The last batch of the epoch always
    triggers an update, so a trailing partial group is applied rather than
    silently carried into the next epoch. (That final group is still scaled by
    1 / gradient_accumulation_steps, so it takes a proportionally smaller
    step.)

    Args:
        model: Module returning ``(logits, loss)`` for ``(input_ids, labels)``.
        loader: Iterable of ``(input_ids, labels)`` batches with a length.
        optimizer: Optimizer over ``model.parameters()``.
        scheduler: LR scheduler, stepped once per optimizer update.
        scaler: GradScaler; only used when ``config.use_mixed_precision``.
        config: Object providing ``use_mixed_precision``,
            ``gradient_accumulation_steps``, ``max_grad_norm`` and
            ``log_interval``.
        epoch: Epoch number, used only in the progress-bar label.

    Returns:
        float: Mean per-batch loss over the epoch (0 for an empty loader).
    """
    model.train()
    total_loss = 0
    accum_steps = config.gradient_accumulation_steps
    num_batches = len(loader)

    # Begin the epoch without gradients left over from earlier work.
    optimizer.zero_grad()

    progress_bar = tqdm(loader, desc=f"Training Epoch {epoch}")

    for i, (input_ids, labels) in enumerate(progress_bar):
        input_ids = input_ids.to(device)
        labels = labels.to(device)

        # Forward pass
        if config.use_mixed_precision:
            with autocast():
                logits, loss = model(input_ids, labels)
        else:
            logits, loss = model(input_ids, labels)
        # Scale so that the accumulated gradient is an average over the group.
        loss = loss / accum_steps

        # Backward pass
        if config.use_mixed_precision:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        # Optimizer update at the end of each group and on the final batch.
        is_last_batch = (i + 1) == num_batches
        if (i + 1) % accum_steps == 0 or is_last_batch:
            if config.use_mixed_precision:
                # Gradients must be unscaled before they are clipped.
                scaler.unscale_(optimizer)

            torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)

            if config.use_mixed_precision:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()

            scheduler.step()
            optimizer.zero_grad()

        # Undo the accumulation scaling so the log shows the true batch loss.
        total_loss += loss.item() * accum_steps

        # Update progress bar
        if i % config.log_interval == 0:
            avg_loss = total_loss / (i + 1)
            current_lr = scheduler.get_last_lr()[0]
            progress_bar.set_postfix({
                'loss': f'{avg_loss:.4f}',
                'lr': f'{current_lr:.2e}',
                'ppl': f'{math.exp(avg_loss):.2f}'
            })

        # Free memory periodically
        if i % 100 == 0:
            clear_memory()

    return total_loss / max(1, num_batches)

In [None]:
@torch.no_grad()
def evaluate(model, loader, config):
    """Compute the mean per-batch loss of ``model`` over ``loader``.

    The model is put in eval mode. ``config`` is accepted but not read.

    Returns:
        tuple: ``(avg_loss, perplexity)`` where perplexity is ``exp(avg_loss)``.
    """
    model.eval()
    running_loss = 0

    batches = tqdm(loader, desc="Evaluating", leave=False)
    for input_ids, labels in batches:
        _, batch_loss = model(input_ids.to(device), labels.to(device))
        running_loss += batch_loss.item()

    # Every batch carries equal weight in the mean.
    mean_loss = running_loss / len(loader)
    return mean_loss, math.exp(mean_loss)

In [None]:
@torch.no_grad()
def generate_samples(model, tokenizer, prompts, max_length=100, temperature=0.8):
    """Print one sampled continuation for each prompt.

    Args:
        model: Object exposing ``eval()`` and
            ``generate(input_ids, max_length=..., temperature=...)``.
        tokenizer: Object exposing ``encode(text, return_tensors='pt')`` and
            ``decode(ids, skip_special_tokens=True)``.
        prompts: Iterable of prompt strings.
        max_length: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.
    """
    model.eval()

    for prompt in prompts:
        print(f"\nüìù Prompt: {prompt}")

        # Encode the prompt and move it to the module-level device.
        encoded = tokenizer.encode(prompt, return_tensors='pt')
        sampled = model.generate(encoded.to(device), max_length=max_length, temperature=temperature)

        # Only the first row of the output is decoded.
        text = tokenizer.decode(sampled[0], skip_special_tokens=True)
        print(f"üí¨ Generated: {text}")

## 🎓 Train the Model

In [None]:
# Training configuration
# Summary banner of the run about to start.
print("\n" + "="*80)
print("üöÄ Starting training...")
print("="*80)
print(f"Dataset: {DATASET_NAME}")
print(f"Training samples: {len(train_dataset):,}")
print(f"Epochs: {config.num_epochs}")
print(f"Batch size: {config.batch_size} x {config.gradient_accumulation_steps} = {config.batch_size * config.gradient_accumulation_steps}")
print("="*80)

# Test generation before training
# Baseline samples from the freshly initialised weights.
print("\nüîÆ Testing generation (untrained model)...")
test_prompts = ["Once upon a time", "The future of AI is"]
generate_samples(model, tokenizer, test_prompts, max_length=50)

# Training loop
# One entry per epoch is appended to each history list; they are plotted later.
train_losses = []
val_losses = []
best_val_loss = float('inf')

for epoch in range(1, config.num_epochs + 1):
    print(f"\nüìÖ Epoch {epoch}/{config.num_epochs}")

    # Train
    train_loss = train_epoch(model, train_loader, optimizer, scheduler, scaler, config, epoch)
    train_losses.append(train_loss)

    # Evaluate
    val_loss, val_ppl = evaluate(model, val_loader, config)
    val_losses.append(val_loss)

    print(f"\nüìä Epoch {epoch} Results:")
    print(f"   Train Loss: {train_loss:.4f} | Perplexity: {math.exp(train_loss):.2f}")
    print(f"   Val Loss: {val_loss:.4f} | Perplexity: {val_ppl:.2f}")

    # Save best model
    # A checkpoint is written only when validation loss improves, always to the
    # same per-dataset file, so the file holds the best epoch so far.
    # NOTE(review): the checkpoint stores the ModelConfig instance itself, so
    # loading it presumably requires ModelConfig to be importable - confirm.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'config': config
        }, f'best_model_{DATASET_NAME}.pt')
        print(f"   üíæ Saved best model (val_loss: {val_loss:.4f})")

    # Generate samples
    # Same prompts as before training, to eyeball progress epoch by epoch.
    print("\nüîÆ Generating samples...")
    generate_samples(model, tokenizer, test_prompts, max_length=50)

    clear_memory()

print("\n‚úÖ Training complete!")

In [None]:
# Plot training curves
# Two side-by-side panels built from the per-epoch histories: loss on the
# left, perplexity (exp of the loss) on the right.
plt.figure(figsize=(12, 4))

# Left panel: loss per epoch
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Val Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Progress')
plt.legend()
plt.grid(True)

# Right panel: perplexity per epoch
plt.subplot(1, 2, 2)
plt.plot([math.exp(l) for l in train_losses], label='Train Perplexity')
plt.plot([math.exp(l) for l in val_losses], label='Val Perplexity')
plt.xlabel('Epoch')
plt.ylabel('Perplexity')
plt.title('Perplexity Progress')
plt.legend()
plt.grid(True)

# The figure is saved to a per-dataset PNG before being shown.
plt.tight_layout()
plt.savefig(f'training_curves_{DATASET_NAME}.png')
plt.show()

print(f"\nüìà Final Results:")
print(f"   Best Val Loss: {best_val_loss:.4f}")
print(f"   Best Val Perplexity: {math.exp(best_val_loss):.2f}")

## 🎨 Interactive Generation

In [None]:
def interactive_generation():
    """Run a console loop that generates text for user-supplied prompts.

    Reads a prompt, an optional maximum length and an optional temperature
    from standard input, then prints the model's continuation. The loop ends
    when the user types 'quit' or when standard input is closed.

    Input handling:
        * Blank prompts are rejected and asked for again.
        * Each number is parsed on its own; an unparsable value falls back to
          its default (100 / 0.8) without discarding the other one.
        * Temperature is clamped to the advertised 0.1-2.0 range.
        * Max length and the prompt itself are limited to the model's
          ``max_seq_len``, the number of positions the model supports.
    """
    print("\nüé® Interactive Text Generation")
    print("Type 'quit' to exit")
    print("-" * 40)

    context_limit = model.config.max_seq_len

    def read_number(message, cast, default):
        """Read one value from stdin; return `default` if blank or invalid."""
        try:
            raw = input(message)
            return cast(raw) if raw.strip() else default
        except (ValueError, EOFError):
            return default

    while True:
        try:
            prompt = input("\nEnter prompt: ")
        except EOFError:
            # Non-interactive session or closed stdin: leave quietly.
            break
        if prompt.strip().lower() == 'quit':
            break
        if not prompt.strip():
            print("Prompt is empty, please enter some text.")
            continue

        max_len = read_number("Max length (default 100): ", int, 100)
        temp = read_number("Temperature 0.1-2.0 (default 0.8): ", float, 0.8)

        # Keep both values inside ranges the sampler can handle.
        max_len = max(1, min(max_len, context_limit))
        if not math.isfinite(temp):
            temp = 0.8
        temp = min(max(temp, 0.1), 2.0)

        print("\nü§ñ Generating...")
        input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
        # Keep only the most recent tokens if the prompt exceeds the context.
        input_ids = input_ids[:, -context_limit:]
        output_ids = model.generate(input_ids, max_length=max_len, temperature=temp)
        generated = tokenizer.decode(output_ids[0], skip_special_tokens=True)

        print(f"\nüìù Generated text:\n{generated}")

# Run interactive generation
interactive_generation()

## 📊 Model Analysis

In [None]:
# Analyze model
print("\nüìä Model Analysis:")
print("="*50)

# Parameter count by component
# Parameters are bucketed by substring match on their qualified names.
# NOTE(review): lm_head.weight is tied to token_embedding.weight, so
# named_parameters() presumably reports it only once (under the embedding
# name) and the 'Output' bucket reads 0 - confirm.
components = {
    'Embeddings': sum(p.numel() for n, p in model.named_parameters() if 'embedding' in n),
    'Attention': sum(p.numel() for n, p in model.named_parameters() if 'attention' in n or 'q_proj' in n or 'k_proj' in n or 'v_proj' in n or 'o_proj' in n),
    'FeedForward': sum(p.numel() for n, p in model.named_parameters() if 'w1' in n or 'w2' in n or 'w3' in n),
    'Normalization': sum(p.numel() for n, p in model.named_parameters() if 'ln' in n or 'norm' in n),
    'Output': sum(p.numel() for n, p in model.named_parameters() if 'lm_head' in n)
}

# NOTE: this rebinds the module-level total_params to the sum of the buckets.
total_params = sum(components.values())
for name, count in components.items():
    pct = (count / total_params) * 100
    print(f"{name:<15}: {count/1e6:>8.2f}M ({pct:>5.1f}%)")

print("="*50)
print(f"{'Total':<15}: {total_params/1e6:>8.2f}M")

# Memory usage
# Reported only when running on a GPU; values are converted from bytes to GB.
if USE_GPU:
    print(f"\nüíæ GPU Memory:")
    print(f"   Allocated: {torch.cuda.memory_allocated()/1e9:.2f} GB")
    print(f"   Reserved: {torch.cuda.memory_reserved()/1e9:.2f} GB")

## 💡 Tips for Large-Scale Training

### Dataset Recommendations by Model Size:
- **125M parameters**: 1-10B tokens minimum
- **350M parameters**: 10-50B tokens 
- **1B parameters**: 50-200B tokens
- **3B+ parameters**: 200B+ tokens

### Training Time Estimates (T4 GPU):
- **TinyStories** (2M samples): ~30 minutes/epoch
- **OpenWebText** (8M samples): ~8 hours/epoch
- **Wikipedia** (6M articles): ~6 hours/epoch
- **C4** (365M samples): ~7 days/epoch

### Memory Optimization Tips:
1. Use gradient checkpointing (already enabled)
2. Reduce batch size if OOM
3. Use mixed precision training (FP16)
4. Consider gradient accumulation
5. Use DeepSpeed or FSDP for multi-GPU

### Quality Improvements:
1. Train for more epochs (3-10)
2. Use larger, cleaner datasets
3. Implement learning rate decay
4. Add dropout and weight decay
5. Use validation for early stopping