# 🚀 Advanced SLM Training with GPT-3 Tokenizer

## ✅ Complete Training System:
- **Multi-GPU Support**: DataParallel for Kaggle T4×2 GPUs
- **GPT-3 Tokenizer**: Using tiktoken (cl100k_base) - same as GPT-3.5/GPT-4
- **Multi-Dataset Training**: WikiText-2, Sherlock Holmes (Project Gutenberg), Numerical samples
- **Enhanced Config for T4×2**: 768 emb, 8 layers, 8 heads, 3072 FFN, 1024 context, FP16, gradient checkpointing
- **Enhanced .pt Files**: Complete architecture + metadata + tokenizer + training history
- **Automatic Fine-Tuning**: Literary datasets after base training
- **Cross-Platform**: Auto-detects Kaggle/Colab/local environments

## 📦 What's in the final.pt File:
- ✅ Complete model architecture (RMSNorm, RoPE, SwiGLU, MultiHeadAttention)
- ✅ All model weights and state dict
- ✅ Complete hyperparameters and training configuration
- ✅ GPT-3 tokenizer configuration and vocabulary info
- ✅ Training metadata (timestamps, epochs, loss history)
- ✅ Version tracking and dataset information
- ✅ Ready for future retraining with additional datasets

## 🔄 Training Workflow:
1. **Stage 1**: Train base model on WikiText-2 + Sherlock Holmes + Numerical data
2. **Stage 2**: Save comprehensive final.pt with full metadata
3. **Stage 3**: Automatically load and fine-tune on additional literary data
4. **Stage 4**: Save final enhanced model ready for deployment

**Run all cells sequentially to complete the full training pipeline!**

## 📦 Step 1: Install Packages & Detect Environment

In [None]:
!pip install -q torch>=2.0.0 transformers>=4.30.0 datasets>=2.14.0 tqdm>=4.65.0 huggingface_hub>=0.16.0 accelerate>=0.20.0 psutil tiktoken>=0.5.0

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
from torch.nn import DataParallel
from torch.utils.data import DataLoader, Dataset
import tiktoken
from datasets import load_dataset
from tqdm import tqdm
import time
import os
import sys
import math
import uuid
import shutil
import psutil
import gc
import requests
from pathlib import Path
from datetime import datetime
from typing import Optional, Dict, List, Any
import json
import re
from collections import Counter
import warnings

warnings.filterwarnings('ignore')
print("✅ Packages installed!")

def detect_environment():
    if 'KAGGLE_KERNEL_RUN_TYPE' in os.environ:
        return 'kaggle'
    elif 'COLAB_GPU' in os.environ or 'google.colab' in str(sys.modules):
        return 'colab'
    else:
        return 'local'

ENV = detect_environment()
print(f"\n🌍 Platform detected: {ENV.upper()}")

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"\n🖥️  Using device: {device}")
if device == 'cuda':
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    torch.cuda.empty_cache()

## 🚀 Step 1.5: Multi-GPU Setup with DataParallel

In [None]:
# Multi-GPU Setup - Using DataParallel for Jupyter notebooks
USE_MULTI_GPU = torch.cuda.device_count() > 1
USE_DDP = False  # DDP not suitable for notebooks, using DataParallel instead

if USE_MULTI_GPU:
    print(f"🚀 Multiple GPUs detected: {torch.cuda.device_count()} GPUs")
    print("   Using DataParallel for multi-GPU training in notebook environment")
    # DataParallel will be applied to model after creation
else:
    print(f"💻 Single GPU mode: Using {device}")

# Adjust batch size for multi-GPU
if USE_MULTI_GPU:
    # Use 4 per GPU for larger model
    effective_batch_size = 4 * torch.cuda.device_count()
    print(f"   Effective batch size: {effective_batch_size} (4 per GPU)")
else:
    effective_batch_size = 8


## 🗂️ Step 2: Configure Paths & Resource Monitoring

In [None]:
print("🔧 Configuring paths for", ENV.upper())

if ENV == 'kaggle':
    BASE_DIR = "/kaggle/working"
elif ENV == 'colab':
    BASE_DIR = "/content"
else:
    BASE_DIR = "."

CACHE_DIR = os.path.join(BASE_DIR, "cache")
MODEL_DIR = os.path.join(BASE_DIR, "models")
DATA_DIR = os.path.join(BASE_DIR, "data")
CHECKPOINT_DIR = os.path.join(BASE_DIR, "checkpoints")

for dir_path in [CACHE_DIR, MODEL_DIR, DATA_DIR, CHECKPOINT_DIR]:
    os.makedirs(dir_path, exist_ok=True)

print(f"\n📁 Directories created:")
print(f"   Models: {MODEL_DIR}")
print(f"   Data: {DATA_DIR}")
print(f"   Cache: {CACHE_DIR}")

def get_disk_usage(path=None):
    if path is None:
        path = BASE_DIR
    if os.path.exists(path):
        total, used, free = shutil.disk_usage(path)
    else:
        total, used, free = shutil.disk_usage(".")
    return {"total_gb": total / (1024**3), "used_gb": used / (1024**3), 
            "free_gb": free / (1024**3), "used_percent": (used / total) * 100}

def monitor_resources():
    disk = get_disk_usage()
    memory = psutil.virtual_memory()
    print(f"\n📊 Resources:")
    print(f"   Disk: {disk['free_gb']:.1f}GB free / {disk['total_gb']:.1f}GB total")
    print(f"   RAM: {memory.percent:.0f}% used")
    if device.startswith('cuda'):
        for gpu_id in range(torch.cuda.device_count()):
            gpu_mem = torch.cuda.memory_allocated(gpu_id) / (1024**3)
            gpu_total = torch.cuda.get_device_properties(gpu_id).total_memory / (1024**3)
            print(f"   GPU {gpu_id}: {gpu_mem:.1f}GB / {gpu_total:.1f}GB used")
def clear_cache():
    gc.collect()
    if device == 'cuda':
        torch.cuda.empty_cache()

monitor_resources()
print("\n✅ Environment ready!")

## 🔤 Step 3: Initialize GPT-3 Tokenizer (tiktoken)

In [None]:
print("🔤 Initializing GPT-3 tokenizer (tiktoken)...")

tokenizer = tiktoken.get_encoding("cl100k_base")
VOCAB_SIZE = tokenizer.n_vocab

print(f"✅ GPT-3 Tokenizer loaded!")
print(f"   Encoding: cl100k_base (same as GPT-3.5/GPT-4)")
print(f"   Vocab size: {VOCAB_SIZE:,}")

test_text = "Hello, World! 12345"
tokens = tokenizer.encode(test_text)
decoded = tokenizer.decode(tokens)
print(f"\n   Test encoding: '{test_text}'")
print(f"   Tokens: {tokens}")
print(f"   Decoded: '{decoded}'")
print(f"   Token count: {len(tokens)}")

## 🏗️ Step 4: Complete Model Architecture (Llama3-Inspired)

In [None]:
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        norm = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / norm * self.weight

def precompute_freqs_cis(dim: int, max_len: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(max_len)
    freqs = torch.outer(t, freqs).float()
    freqs_cos = torch.cos(freqs)
    freqs_sin = torch.sin(freqs)
    return freqs_cos, freqs_sin

def apply_rotary_emb(xq, xk, freqs_cos, freqs_sin):
    xq_r = xq.float().reshape(*xq.shape[:-1], -1, 2)
    xk_r = xk.float().reshape(*xk.shape[:-1], -1, 2)
    xq_out_r = xq_r[..., 0] * freqs_cos - xq_r[..., 1] * freqs_sin
    xq_out_i = xq_r[..., 0] * freqs_sin + xq_r[..., 1] * freqs_cos
    xk_out_r = xk_r[..., 0] * freqs_cos - xk_r[..., 1] * freqs_sin
    xk_out_i = xk_r[..., 0] * freqs_sin + xk_r[..., 1] * freqs_cos
    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(-2)
    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)

class SwiGLU(nn.Module):
    def __init__(self, dim: int, hidden_dim: Optional[int] = None):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = int(2 * dim * 4 / 3)
            hidden_dim = ((hidden_dim + 63) // 64) * 64
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size, n_heads, max_len=1024, dropout=0.1):
        super().__init__()
        assert emb_size % n_heads == 0
        self.emb_size = emb_size
        self.n_heads = n_heads
        self.head_dim = emb_size // n_heads
        self.qkv = nn.Linear(emb_size, 3 * emb_size, bias=False)
        self.out = nn.Linear(emb_size, emb_size)
        self.dropout = nn.Dropout(dropout)
        self.freqs_cos, self.freqs_sin = precompute_freqs_cis(self.head_dim, max_len)

    def forward(self, x, mask=None):
        B, L, _ = x.shape
        qkv = self.qkv(x).reshape(B, L, 3, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        self.freqs_cos = self.freqs_cos.to(x.device)
        self.freqs_sin = self.freqs_sin.to(x.device)
        q, k = apply_rotary_emb(q, k, self.freqs_cos[:L], self.freqs_sin[:L])
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = self.dropout(F.softmax(scores, dim=-1))
        out = torch.matmul(attn, v).transpose(1, 2).contiguous().reshape(B, L, self.emb_size)
        return self.out(out)

class TransformerBlock(nn.Module):
    def __init__(self, emb_size, n_heads, ff_size, max_len=1024, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(emb_size, n_heads, max_len, dropout)
        self.ff = SwiGLU(emb_size, ff_size)
        self.norm1 = RMSNorm(emb_size)
        self.norm2 = RMSNorm(emb_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        x = x + self.dropout(self.attention(self.norm1(x), mask))
        x = x + self.dropout(self.ff(self.norm2(x)))
        return x

class AdvancedLLM(nn.Module):
    def __init__(self, vocab_size=100277, emb_size=768, n_layers=8, n_heads=8, 
                 ff_size=3072, max_len=1024, dropout=0.1, use_gradient_checkpoint=True):
        super().__init__()
        self.vocab_size = vocab_size
        self.emb_size = emb_size
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.max_len = max_len
        self.use_gradient_checkpoint = use_gradient_checkpoint
        self.token_embedding = nn.Embedding(vocab_size, emb_size)
        self.dropout = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            TransformerBlock(emb_size, n_heads, ff_size, max_len, dropout)
            for _ in range(n_layers)
        ])
        self.norm = RMSNorm(emb_size)
        self.lm_head = nn.Linear(emb_size, vocab_size, bias=False)
        self.token_embedding.weight = self.lm_head.weight
        self._init_weights()

    def _init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    torch.nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, x, mask=None):
        B, L = x.shape
        if mask is None:
            mask = torch.tril(torch.ones(L, L, device=x.device)).unsqueeze(0).unsqueeze(0)
        x = self.dropout(self.token_embedding(x))
        for block in self.blocks:
            if self.use_gradient_checkpoint and self.training:
                x = torch.utils.checkpoint.checkpoint(block, x, mask, use_reentrant=False)
            else:
                x = block(x, mask)
        return self.lm_head(self.norm(x))

    def get_num_params(self):
        return sum(p.numel() for p in self.parameters())

print("✅ Model architecture loaded: RMSNorm, RoPE, SwiGLU, Gradient Checkpointing")
print(f"   Configuration: 768 emb, 8 layers, 8 heads, 3072 ff_size")

## 📚 Step 5: Load and Prepare Datasets

In [None]:
print("📚 Loading datasets...")

def download_sherlock_holmes():
    print("\n📖 Downloading Sherlock Holmes from Project Gutenberg...")
    urls = [
        "https://www.gutenberg.org/files/1661/1661-0.txt",
        "https://www.gutenberg.org/files/2097/2097-0.txt",
        "https://www.gutenberg.org/files/244/244-0.txt"
    ]
    texts = []
    for i, url in enumerate(urls, 1):
        try:
            response = requests.get(url, timeout=30)
            response.raise_for_status()
            text = response.text
            start_idx = text.find("*** START")
            end_idx = text.find("*** END")
            if start_idx != -1 and end_idx != -1:
                text = text[start_idx:end_idx]
            texts.append(text)
            print(f"   ✅ Downloaded book {i}: {len(text):,} characters")
        except Exception as e:
            print(f"   ⚠️ Failed to download book {i}: {str(e)}")
    return "\n\n".join(texts)

def create_numerical_data():
    print("\n🔢 Creating numerical training samples...")
    numerical_samples = []
    for i in range(1000):
        num1, num2 = i * 2, i * 3
        numerical_samples.append(f"Number {i}: {num1}, {num2}, sum={num1+num2}")
    for i in range(100, 200):
        numerical_samples.append(f"Calculation: {i} * 2 = {i*2}, {i} + 5 = {i+5}")
    text = "\n".join(numerical_samples)
    print(f"   ✅ Created {len(numerical_samples):,} numerical samples")
    return text

sherlock_text = download_sherlock_holmes()
numerical_text = create_numerical_data()

print("\n📚 Loading WikiText-2...")
try:
    wikitext = load_dataset("wikitext", "wikitext-2-raw-v1", split="train", cache_dir=CACHE_DIR)
    wikitext_text = "\n".join([item['text'] for item in wikitext if item['text'].strip()][:10000])
    print(f"   ✅ WikiText-2 loaded: {len(wikitext_text):,} characters")
except Exception as e:
    print(f"   ⚠️ WikiText-2 failed: {str(e)}. Using fallback text.")
    wikitext_text = "This is sample text for training. " * 1000

print("\n📊 Dataset Statistics:")
print(f"   Sherlock Holmes: {len(sherlock_text):,} characters")
print(f"   Numerical Data: {len(numerical_text):,} characters")
print(f"   WikiText: {len(wikitext_text):,} characters")

combined_base_text = wikitext_text + "\n\n" + sherlock_text + "\n\n" + numerical_text
print(f"   Combined Base Training: {len(combined_base_text):,} characters")

fine_tune_text = sherlock_text + "\n\n" + numerical_text
print(f"   Fine-tuning Dataset: {len(fine_tune_text):,} characters")

print("\n✅ All datasets loaded!")

## 🔧 Step 6: Dataset Tokenization and Preparation

In [None]:
class TextDataset(Dataset):
    def __init__(self, text, tokenizer, max_len=1024):
        self.tokenizer = tokenizer
        self.max_len = max_len
        print(f"\n🔧 Tokenizing {len(text):,} characters...")
        self.tokens = tokenizer.encode(text)
        print(f"   ✅ Created {len(self.tokens):,} tokens")
        self.num_samples = len(self.tokens) // max_len
        print(f"   📦 Dataset size: {self.num_samples:,} samples")
    
    def __len__(self):
        return self.num_samples
    
    def __getitem__(self, idx):
        start_idx = idx * self.max_len
        end_idx = start_idx + self.max_len + 1
        if end_idx > len(self.tokens):
            end_idx = len(self.tokens)
            start_idx = end_idx - self.max_len - 1
        chunk = self.tokens[start_idx:end_idx]
        if len(chunk) < self.max_len + 1:
            chunk = chunk + [0] * (self.max_len + 1 - len(chunk))
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)
        return x, y

print("📦 Creating base training dataset...")
base_dataset = TextDataset(combined_base_text, tokenizer, max_len=1024)

print("\n📦 Creating fine-tuning dataset...")
finetune_dataset = TextDataset(fine_tune_text, tokenizer, max_len=1024)

print("\n✅ Datasets ready!")

## 💾 Step 7: Enhanced Model Saving/Loading Functions

In [None]:
def save_enhanced_model(model, save_path, config, training_history=None, stage='base', tokenizer_info=None):
    print(f"\n💾 Saving enhanced {stage} model to: {save_path}")
    
    if tokenizer_info is None:
        tokenizer_info = {
            'tokenizer_type': 'tiktoken',
            'encoding': 'cl100k_base',
            'vocab_size': VOCAB_SIZE,
            'description': 'GPT-3/GPT-4 tokenizer (tiktoken cl100k_base)'
        }
    
    enhanced_checkpoint = {
        'metadata': {
            'version': '3.0',
            'stage': stage,
            'created_at': datetime.now().isoformat(),
            'platform': ENV,
            'device': str(device),
            'pytorch_version': torch.__version__,
            'training_id': str(uuid.uuid4())[:8],
        },
        'model_architecture': {
            'class_name': 'AdvancedLLM',
            'components': ['RMSNorm', 'RoPE', 'SwiGLU', 'MultiHeadAttention', 'TransformerBlock'],
            'vocab_size': model.vocab_size,
            'emb_size': model.emb_size,
            'n_layers': model.n_layers,
            'n_heads': model.n_heads,
            'max_len': model.max_len,
            'use_gradient_checkpoint': model.use_gradient_checkpoint,
            'total_parameters': model.get_num_params(),
        },
        'model_state_dict': model.state_dict(),
        'training_config': config,
        'tokenizer_config': tokenizer_info,
        'training_history': training_history or {},
        'generation_config': {
            'temperature': 0.8,
            'top_k': 50,
            'top_p': 0.92,
            'repetition_penalty': 1.2,
            'no_repeat_ngram_size': 3,
        },
        'retraining_info': {
            'can_continue_training': True,
            'supported_datasets': ['literary', 'custom', 'wikitext', 'sherlock_holmes', 'numerical'],
            'fine_tuning_ready': True,
            'recommended_lr': 1e-4,
            'recommended_batch_size': 8,
        },
        'datasets_used': {
            'base_training': ['WikiText-2', 'Sherlock Holmes (Project Gutenberg)', 'Numerical Samples'],
            'fine_tuning': ['Sherlock Holmes', 'Numerical Samples'] if stage == 'fine_tuned' else [],
        }
    }
    
    torch.save(enhanced_checkpoint, save_path)
    file_size = os.path.getsize(save_path) / (1024**2)
    print(f"\n✅ Model saved successfully!")
    print(f"   📁 File: {save_path}")
    print(f"   💾 Size: {file_size:.1f}MB")
    print(f"   🔢 Parameters: {model.get_num_params():,}")
    print(f"   🎯 Stage: {stage}")
    print(f"   🔤 Tokenizer: {tokenizer_info['tokenizer_type']} ({tokenizer_info['encoding']})")
    return save_path

def load_enhanced_model(load_path, device='cuda'):
    print(f"\n📂 Loading enhanced model from: {load_path}")
    checkpoint = torch.load(load_path, map_location=device)
    
    print(f"\n📋 Model Information:")
    print(f"   Version: {checkpoint['metadata']['version']}")
    print(f"   Stage: {checkpoint['metadata']['stage']}")
    print(f"   Created: {checkpoint['metadata']['created_at']}")
    print(f"   Tokenizer: {checkpoint['tokenizer_config']['tokenizer_type']}")
    
    arch = checkpoint['model_architecture']
    
    model = AdvancedLLM(
        vocab_size=arch['vocab_size'],
        emb_size=arch['emb_size'],
        n_layers=arch['n_layers'],
        n_heads=arch['n_heads'],
        max_len=arch['max_len'],
        use_gradient_checkpoint=arch.get('use_gradient_checkpoint', False)
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    
    print(f"\n✅ Model loaded: {arch['total_parameters']:,} parameters")
    
    return model, checkpoint['training_config'], checkpoint.get('training_history', {})

print("✅ Enhanced save/load functions ready!")

## 🎯 Step 8: Enhanced Training Function with Validation & Checkpoints

In [None]:
def calculate_perplexity(loss):
    return math.exp(min(loss, 20))

def create_train_val_split(dataset, val_ratio=0.1):
    val_size = int(len(dataset) * val_ratio)
    train_size = len(dataset) - val_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size]
    )
    return train_dataset, val_dataset

@torch.no_grad()
def evaluate_model(model, dataloader, device):
    model.eval()
    total_loss = 0
    total_batches = 0
    
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = F.cross_entropy(logits.view(-1, (model.module.vocab_size if hasattr(model, "module") else model.vocab_size)), y.view(-1))
        total_loss += loss.item()
        total_batches += 1
    
    model.train()
    avg_loss = total_loss / total_batches
    perplexity = calculate_perplexity(avg_loss)
    return avg_loss, perplexity

def save_checkpoint(model, optimizer, scheduler, epoch, training_history, checkpoint_path):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.module.state_dict() if hasattr(model, "module") else model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
        'training_history': training_history,
        'timestamp': datetime.now().isoformat(),
    }
    torch.save(checkpoint, checkpoint_path)
    print(f"   💾 Checkpoint saved: {checkpoint_path}")

def load_checkpoint(model, optimizer, scheduler, checkpoint_path, device='cuda'):
    print(f"\n📂 Loading checkpoint from: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scheduler and checkpoint.get('scheduler_state_dict'):
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    
    start_epoch = checkpoint['epoch'] + 1
    training_history = checkpoint.get('training_history', {})
    
    print(f"✅ Checkpoint loaded! Resuming from epoch {start_epoch}")
    return start_epoch, training_history

def train_model(model, dataset, config, stage_name='base', resume_from=None, val_dataset=None):
    print(f"\n🎯 Starting {stage_name} training...")
    print(f"   Dataset size: {len(dataset):,} samples")
    print(f"   Epochs: {config['epochs']}")
    print(f"   Batch size: {config['batch_size']}")
    print(f"   Learning rate: {config['learning_rate']}")
    print(f"   FP16: {config.get('use_fp16', False)}")
    print(f"   Validation: {val_dataset is not None}")
    
    model.train()
    model.to(device)
    
    if val_dataset is None:
        print("\n📊 Creating train/validation split (90/10)...")
        dataset, val_dataset = create_train_val_split(dataset, val_ratio=0.1)
        print(f"   Train: {len(dataset):,} samples")
        print(f"   Validation: {len(val_dataset):,} samples")
    
    # Setup distributed sampler for DDP
    train_sampler = None
    val_sampler = None
    if USE_MULTI_GPU:
        model = nn.DataParallel(model)
        print(f"   Model wrapped with DataParallel for {torch.cuda.device_count()} GPUs")
    # val_sampler = DistributedSampler(val_dataset, shuffle=False)
    
    train_loader = DataLoader(dataset, batch_size=config['batch_size'], sampler=train_sampler, shuffle=(train_sampler is None), 
                             num_workers=0, pin_memory=True if device == 'cuda' else False)
    val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], sampler=val_sampler, shuffle=False,
                           num_workers=0, pin_memory=True if device == 'cuda' else False)
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=config['learning_rate'], 
                                 weight_decay=config.get('weight_decay', 0.01))
    
    total_steps = len(train_loader) * config['epochs']
    warmup_steps = int(0.1 * total_steps)
    
    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps
        return 0.5 * (1 + math.cos(math.pi * (step - warmup_steps) / (total_steps - warmup_steps)))
    
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    scaler = GradScaler() if config.get('use_fp16', False) and device == 'cuda' else None
    
    start_epoch = 0
    training_history = {
        'stage': stage_name,
        'train_losses': [],
        'val_losses': [],
        'val_perplexities': [],
        'epochs': [],
        'start_time': datetime.now().isoformat(),
    }
    
    if resume_from and os.path.exists(resume_from):
        start_epoch, loaded_history = load_checkpoint(model, optimizer, scheduler, resume_from, device)
        training_history.update(loaded_history)
    
    global_step = 0
    best_val_loss = training_history.get('best_val_loss', float('inf'))
    avg_train_loss = training_history.get('final_train_loss', None)
    val_loss = training_history.get('final_val_loss', None)
    
    for epoch in range(start_epoch, config['epochs']):
        # Set epoch for distributed sampler to ensure proper shuffling
        # DataParallel doesn't need epoch setting
        pass  # Using DataParallel, not DistributedSampler
        
        epoch_loss = 0
        epoch_start = time.time()
        
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")
        
        for batch_idx, (x, y) in enumerate(progress_bar):
            x, y = x.to(device), y.to(device)
            
            optimizer.zero_grad()
            
            if scaler:
                with autocast():
                    logits = model(x)
                    loss = F.cross_entropy(logits.view(-1, (model.module.vocab_size if hasattr(model, "module") else model.vocab_size)), y.view(-1))
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                logits = model(x)
                loss = F.cross_entropy(logits.view(-1, (model.module.vocab_size if hasattr(model, "module") else model.vocab_size)), y.view(-1))
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
            
            scheduler.step()
            
            epoch_loss += loss.item()
            global_step += 1
            
            progress_bar.set_postfix({
                'loss': f"{loss.item():.4f}",
                'lr': f"{scheduler.get_last_lr()[0]:.6f}"
            })
            
            if batch_idx % 100 == 0 and device == 'cuda':
                torch.cuda.empty_cache()
        
        avg_train_loss = epoch_loss / len(train_loader)
        train_perplexity = calculate_perplexity(avg_train_loss)
        
        print(f"\n📊 Evaluating on validation set...")
        val_loss, val_perplexity = evaluate_model(model, val_loader, device)
        
        epoch_time = time.time() - epoch_start
        
        training_history['train_losses'].append(avg_train_loss)
        training_history['val_losses'].append(val_loss)
        training_history['val_perplexities'].append(val_perplexity)
        training_history['epochs'].append(epoch + 1)
        
        print(f"\n📊 Epoch {epoch+1}/{config['epochs']} Results:")
        print(f"   Train Loss: {avg_train_loss:.4f} | Perplexity: {train_perplexity:.2f}")
        print(f"   Val Loss: {val_loss:.4f} | Perplexity: {val_perplexity:.2f}")
        print(f"   Time: {epoch_time:.1f}s | LR: {scheduler.get_last_lr()[0]:.6f}")
        
        checkpoint_path = os.path.join(CHECKPOINT_DIR, f"{stage_name}_checkpoint_epoch_{epoch+1}.pt")
        save_checkpoint(model, optimizer, scheduler, epoch, training_history, checkpoint_path)
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_checkpoint_path = os.path.join(CHECKPOINT_DIR, f"{stage_name}_best.pt")
            save_checkpoint(model, optimizer, scheduler, epoch, training_history, best_checkpoint_path)
            print(f"   🎉 New best validation loss: {best_val_loss:.4f}")
    
    training_history['end_time'] = datetime.now().isoformat()
    
    if avg_train_loss is not None and val_loss is not None:
        training_history['best_val_loss'] = best_val_loss
        training_history['final_train_loss'] = avg_train_loss
        training_history['final_val_loss'] = val_loss
    print(f"\n✅ {stage_name.capitalize()} training completed!")
    if start_epoch < config['epochs']:
        if avg_train_loss is not None and val_loss is not None:
            print(f"   Best Val Loss: {best_val_loss:.4f}")
            print(f"   Final Train Loss: {avg_train_loss:.4f}")
            print(f"   Final Val Loss: {val_loss:.4f}")
    else:
        print(f"   Resumed from final epoch - training already complete")
        if 'best_val_loss' in training_history:
            print(f"   Best Val Loss: {training_history['best_val_loss']:.4f}")
        if 'final_val_loss' in training_history:
            print(f"   Final Val Loss: {training_history['final_val_loss']:.4f}")
    return training_history


## 🚀 Step 9: Initialize Model and Configuration

In [None]:
print("🚀 Initializing model...")

model = AdvancedLLM(
    vocab_size=VOCAB_SIZE,
    emb_size=768,
    n_layers=8,
    n_heads=8,
    ff_size=3072,
    max_len=1024,
    dropout=0.1,
    use_gradient_checkpoint=True
).to(device)

print(f"\n✅ Model initialized!")
print(f"   Total parameters: {model.get_num_params():,}")
print(f"   Model size: ~{model.get_num_params() * 4 / (1024**2):.1f}MB (FP32)")

base_config = {
    'epochs': 3,
    'batch_size': 4,
    'learning_rate': 3e-4,
    'weight_decay': 0.01,
    'use_fp16': device == 'cuda',
    'max_len': 1024,
    'dataset': 'WikiText-2 + Sherlock Holmes + Numerical',
}

finetune_config = {
    'epochs': 2,
    'batch_size': 4,
    'learning_rate': 1e-4,
    'weight_decay': 0.01,
    'use_fp16': device == 'cuda',
    'max_len': 1024,
    'dataset': 'Sherlock Holmes + Numerical (Fine-tuning)',
}

print("\n📋 Training Configuration:")
print(f"   Base Training: {base_config['epochs']} epochs, LR={base_config['learning_rate']}")
print(f"   Fine-tuning: {finetune_config['epochs']} epochs, LR={finetune_config['learning_rate']}")
print(f"   Batch size: {base_config['batch_size']}")
print(f"   FP16: {base_config['use_fp16']}")

monitor_resources()

## 🏋️ Step 10: Base Model Training

In [None]:
print("\n" + "="*80)
print("🏋️ STAGE 1: BASE MODEL TRAINING")
print("="*80)

base_history = train_model(model, base_dataset, base_config, stage_name='base')

print("\n" + "="*80)
print("💾 Saving base model as final.pt")
print("="*80)

base_model_path = os.path.join(MODEL_DIR, "final.pt")
save_enhanced_model(
    model=model,
    save_path=base_model_path,
    config=base_config,
    training_history=base_history,
    stage='base_trained',
    tokenizer_info={
        'tokenizer_type': 'tiktoken',
        'encoding': 'cl100k_base',
        'vocab_size': VOCAB_SIZE,
        'description': 'GPT-3/GPT-4 tokenizer (tiktoken cl100k_base)'
    }
)

print("\n✅ Base model training complete!")
print(f"   Model saved to: {base_model_path}")
monitor_resources()

## 🎨 Step 11: Automatic Fine-Tuning

In [None]:
print("\n" + "="*80)
print("🎨 STAGE 2: FINE-TUNING ON LITERARY DATA")
print("="*80)

print("\n📂 Loading base model for fine-tuning...")
model_ft, loaded_config, loaded_history = load_enhanced_model(base_model_path, device=device)

print("\n🎯 Starting fine-tuning...")
finetune_history = train_model(model_ft, finetune_dataset, finetune_config, stage_name='fine-tuning')

combined_history = {
    'base_training': base_history,
    'fine_tuning': finetune_history,
}

print("\n" + "="*80)
print("💾 Saving fine-tuned model")
print("="*80)

finetuned_model_path = os.path.join(MODEL_DIR, "final_finetuned.pt")
save_enhanced_model(
    model=model_ft,
    save_path=finetuned_model_path,
    config=finetune_config,
    training_history=combined_history,
    stage='fine_tuned',
    tokenizer_info={
        'tokenizer_type': 'tiktoken',
        'encoding': 'cl100k_base',
        'vocab_size': VOCAB_SIZE,
        'description': 'GPT-3/GPT-4 tokenizer (tiktoken cl100k_base)'
    }
)

print("\n" + "="*80)
print("🎉 TRAINING PIPELINE COMPLETED!")
print("="*80)
print(f"\n📦 Models saved:")
print(f"   Base model: {base_model_path}")
print(f"   Fine-tuned model: {finetuned_model_path}")
print(f"\n📊 Training Summary:")
print(f"   Base training loss: {base_history['final_loss']:.4f}")
print(f"   Fine-tuning loss: {finetune_history['final_loss']:.4f}")
print(f"   Total parameters: {model_ft.get_num_params():,}")
print(f"   Tokenizer: tiktoken (cl100k_base - GPT-3/GPT-4)")

monitor_resources()

## 🧪 Step 12: Test Text Generation

In [None]:
print("\n🧪 Testing text generation...")

@torch.no_grad()
def generate_text(model, prompt, max_tokens=100, temperature=0.8):
    model.eval()
    tokens = tokenizer.encode(prompt)
    input_ids = torch.tensor([tokens]).to(device)
    
    for _ in range(max_tokens):
        if input_ids.size(1) >= model.max_len:
            input_ids = input_ids[:, -model.max_len:]
        
        logits = model(input_ids)
        next_token_logits = logits[:, -1, :] / temperature
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        input_ids = torch.cat([input_ids, next_token], dim=1)
    
    generated_tokens = input_ids[0].tolist()
    return tokenizer.decode(generated_tokens)

test_prompts = [
    "Once upon a time",
    "Sherlock Holmes",
    "The number",
]

print("\n📝 Generated samples:\n")
for prompt in test_prompts:
    print(f"Prompt: '{prompt}'")
    generated = generate_text(model_ft, prompt, max_tokens=50)
    print(f"Generated: {generated}")
    print("-" * 80)

print("\n✅ Text generation test complete!")

## 📋 Step 13: Model Information Summary

In [None]:
print("\n" + "="*80)
print("📋 FINAL MODEL INFORMATION")
print("="*80)

checkpoint = torch.load(finetuned_model_path, map_location='cpu')

print("\n🔧 Architecture:")
for key, value in checkpoint['model_architecture'].items():
    print(f"   {key}: {value}")

print("\n🔤 Tokenizer:")
for key, value in checkpoint['tokenizer_config'].items():
    print(f"   {key}: {value}")

print("\n📊 Training History:")
if 'base_training' in checkpoint['training_history']:
    print(f"   Base training epochs: {len(checkpoint['training_history']['base_training'].get('train_losses', []))}")
    print(f"   Fine-tuning epochs: {len(checkpoint['training_history']['fine_tuning'].get('train_losses', []))}")
    print(f"   Final base val loss: {checkpoint['training_history']['base_training'].get('final_val_loss', 'N/A')}")
    print(f"   Final fine-tuning val loss: {checkpoint['training_history']['fine_tuning'].get('final_val_loss', 'N/A')}")

print("\n📚 Datasets Used:")
for key, datasets in checkpoint['datasets_used'].items():
    print(f"   {key}:")
    for dataset in datasets:
        print(f"      - {dataset}")

print("\n🔄 Retraining Information:")
for key, value in checkpoint['retraining_info'].items():
    print(f"   {key}: {value}")

print("\n📦 Files:")
for filename in os.listdir(MODEL_DIR):
    filepath = os.path.join(MODEL_DIR, filename)
    if os.path.isfile(filepath):
        size = os.path.getsize(filepath) / (1024**2)
        print(f"   {filename}: {size:.1f}MB")

print("\n" + "="*80)
print("✅ ALL TRAINING COMPLETE!")
print("="*80)
print("\nYour model is ready to use! The final.pt file contains:")
print("  ✅ Complete model architecture and weights")
print("  ✅ GPT-3 tokenizer configuration (tiktoken cl100k_base)")
print("  ✅ Full training history and metadata")
print("  ✅ Everything needed for future retraining")
print("  ✅ Validation metrics and perplexity scores")
print("  ✅ Checkpoint recovery for interrupted training")
print("\nSee cells below for:")
print("  - Continue training with new datasets")
print("  - Enhanced text generation interface")
print("  - Custom dataset upload support")

---
# 🔄 ADVANCED FEATURES
## Use the cells below for additional functionality

## 🔁 Continue Training from Saved Model

In [None]:
print("🔁 CONTINUE TRAINING WORKFLOW")
print("="*80)
print("\nThis cell demonstrates how to load a saved model and continue training.")
print("You can add any new dataset and continue training from where you left off.\n")

model_to_continue = input("Enter model path to load (or press Enter for latest): ").strip()
if not model_to_continue:
    model_to_continue = finetuned_model_path

print(f"\n📂 Loading model from: {model_to_continue}")
continued_model, cont_config, cont_history = load_enhanced_model(model_to_continue, device=device)

print("\n📝 Add your custom training data below:")
print("Example: additional_text = 'Your new training data here...'")
additional_text = input("Enter additional training text (or press Enter to skip): ").strip()

if additional_text:
    print(f"\n🔧 Creating dataset from new text ({len(additional_text)} characters)...")
    continue_dataset = TextDataset(additional_text, tokenizer, max_len=1024)
    
    continue_config = {
        'epochs': 2,
        'batch_size': 8,
        'learning_rate': 5e-5,
        'weight_decay': 0.01,
        'use_fp16': device == 'cuda',
        'max_len': 512,
        'dataset': 'Additional Custom Data',
    }
    
    print("\n🎯 Continuing training with new data...")
    continue_history = train_model(continued_model, continue_dataset, continue_config, stage_name='continued')
    
    continued_model_path = os.path.join(MODEL_DIR, "continued_model.pt")
    save_enhanced_model(
        model=continued_model,
        save_path=continued_model_path,
        config=continue_config,
        training_history={'previous': cont_history, 'continued': continue_history},
        stage='continued_training',
        tokenizer_info={
            'tokenizer_type': 'tiktoken',
            'encoding': 'cl100k_base',
            'vocab_size': VOCAB_SIZE,
            'description': 'GPT-3/GPT-4 tokenizer (tiktoken cl100k_base)'
        }
    )
    
    print(f"\n✅ Continued training complete! Model saved to: {continued_model_path}")
else:
    print("\n⏭️ Skipping continued training. Model loaded and ready for use.")

## 🎨 Enhanced Text Generation Interface

In [None]:
class EnhancedTextGenerator:
    def __init__(self, model, tokenizer, device='cuda'):
        self.model = model.to(device).eval()
        if hasattr(model, 'module'):  # Unwrap DDP
            self.model = model.module
        self.tokenizer = tokenizer
        self.device = device
    
    @torch.no_grad()
    def generate(self, prompt, max_tokens=150, temperature=0.8, top_k=50, 
                 top_p=0.92, repetition_penalty=1.2, num_samples=1):
        results = []
        
        for sample_idx in range(num_samples):
            tokens = self.tokenizer.encode(prompt)
            input_ids = torch.tensor([tokens]).to(self.device)
            
            past_tokens = set()
            
            for _ in range(max_tokens):
                if input_ids.size(1) >= self.model.max_len:
                    input_ids = input_ids[:, -self.model.max_len:]
                
                logits = self.model(input_ids)
                next_token_logits = logits[:, -1, :].float()
                
                if repetition_penalty != 1.0:
                    for token_id in past_tokens:
                        next_token_logits[:, token_id] /= repetition_penalty
                
                next_token_logits = next_token_logits / temperature
                
                if top_k > 0:
                    indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][:, -1, None]
                    next_token_logits[indices_to_remove] = float('-inf')
                
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                    sorted_indices_to_remove[:, 0] = 0
                    indices_to_remove = sorted_indices[sorted_indices_to_remove]
                    next_token_logits[:, indices_to_remove] = float('-inf')
                
                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                
                input_ids = torch.cat([input_ids, next_token], dim=1)
                past_tokens.add(next_token.item())
            
            generated_tokens = input_ids[0].tolist()
            generated_text = self.tokenizer.decode(generated_tokens)
            results.append(generated_text)
        
        return results

print("🎨 Enhanced Text Generation Interface")
print("="*80)

generator = EnhancedTextGenerator(model_ft, tokenizer, device=device)

print("\n🎯 Interactive Generation Mode")
print("\nEnter prompts to generate text. Type 'quit' to exit.\n")

while True:
    user_prompt = input("\nPrompt (or 'quit'): ").strip()
    if user_prompt.lower() == 'quit':
        break
    
    if not user_prompt:
        continue
    
    try:
        max_tokens = int(input("Max tokens (default 100): ") or "100")
        temperature = float(input("Temperature 0.1-2.0 (default 0.8): ") or "0.8")
        num_samples = int(input("Number of samples (default 1): ") or "1")
        
        print(f"\n🔄 Generating {num_samples} sample(s)...\n")
        
        results = generator.generate(
            prompt=user_prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            num_samples=num_samples
        )
        
        for i, text in enumerate(results, 1):
            print(f"\n{'='*80}")
            print(f"Sample {i}:")
            print(f"{'='*80}")
            print(text)
    
    except ValueError as e:
        print(f"❌ Invalid input: {e}")
    except KeyboardInterrupt:
        print("\n\n⏸️ Generation interrupted.")
        break

print("\n✅ Text generation interface closed.")

## 📤 Custom Dataset Upload & Preprocessing

In [None]:
print("📤 CUSTOM DATASET UPLOAD & PREPROCESSING")
print("="*80)
print("\nThis cell allows you to upload and preprocess custom datasets for training.\n")

def preprocess_custom_dataset(text, min_length=10, max_length=1000):
    lines = text.split('\n')
    filtered_lines = []
    
    for line in lines:
        line = line.strip()
        if min_length <= len(line) <= max_length:
            filtered_lines.append(line)
    
    return '\n'.join(filtered_lines)

def load_custom_dataset_from_file(file_path):
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return f.read()
    except Exception as e:
        print(f"❌ Error loading file: {e}")
        return None

def load_custom_dataset_from_url(url):
    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()
        return response.text
    except Exception as e:
        print(f"❌ Error downloading from URL: {e}")
        return None

print("Choose dataset source:")
print("1. Enter text directly")
print("2. Load from file path")
print("3. Download from URL")

choice = input("\nEnter choice (1/2/3): ").strip()

custom_text = None

if choice == '1':
    print("\nEnter your text (type 'END' on a new line to finish):\n")
    lines = []
    while True:
        line = input()
        if line.strip() == 'END':
            break
        lines.append(line)
    custom_text = '\n'.join(lines)

elif choice == '2':
    file_path = input("\nEnter file path: ").strip()
    custom_text = load_custom_dataset_from_file(file_path)

elif choice == '3':
    url = input("\nEnter URL: ").strip()
    custom_text = load_custom_dataset_from_url(url)

else:
    print("❌ Invalid choice")

if custom_text:
    print(f"\n📊 Dataset loaded: {len(custom_text):,} characters")
    
    preprocess = input("\nPreprocess dataset? (y/n): ").strip().lower() == 'y'
    
    if preprocess:
        min_len = int(input("Minimum line length (default 10): ") or "10")
        max_len = int(input("Maximum line length (default 1000): ") or "1000")
        custom_text = preprocess_custom_dataset(custom_text, min_len, max_len)
        print(f"✅ Preprocessed: {len(custom_text):,} characters")
    
    save_path = os.path.join(DATA_DIR, "custom_dataset.txt")
    with open(save_path, 'w', encoding='utf-8') as f:
        f.write(custom_text)
    print(f"\n💾 Custom dataset saved to: {save_path}")
    
    print("\n🔧 Creating training dataset...")
    custom_dataset = TextDataset(custom_text, tokenizer, max_len=1024)
    
    print("\n📊 Dataset ready for training!")
    print(f"   Characters: {len(custom_text):,}")
    print(f"   Samples: {len(custom_dataset):,}")
    print("\nYou can now use 'custom_dataset' in the training function.")
    print("Example: train_model(model, custom_dataset, config, stage_name='custom')")
else:
    print("\n⏭️ No dataset loaded.")

## 🚑 Checkpoint Recovery Example

In [None]:
print("🚑 CHECKPOINT RECOVERY")
print("="*80)
print("\nThis cell shows how to recover from an interrupted training session.\n")

print("📁 Available checkpoints:\n")
checkpoint_files = [f for f in os.listdir(CHECKPOINT_DIR) if f.endswith('.pt')]

if checkpoint_files:
    for i, ckpt in enumerate(checkpoint_files, 1):
        ckpt_path = os.path.join(CHECKPOINT_DIR, ckpt)
        size = os.path.getsize(ckpt_path) / (1024**2)
        print(f"{i}. {ckpt} ({size:.1f}MB)")
    
    print("\n💡 To resume training from a checkpoint:")
    print("\n# Load the model")
    print("recovery_model, _, _ = load_enhanced_model(base_model_path, device=device)")
    print("\n# Resume training with the checkpoint")
    print("checkpoint_path = os.path.join(CHECKPOINT_DIR, 'base_checkpoint_epoch_1.pt')")
    print("recovery_history = train_model(")
    print("    model=recovery_model,")
    print("    dataset=base_dataset,")
    print("    config=base_config,")
    print("    stage_name='recovered',")
    print("    resume_from=checkpoint_path")
    print(")")
    
    print("\n✅ The training will automatically resume from the saved epoch!")
else:
    print("No checkpoints found. Checkpoints are saved during training.")

print("\n📋 Checkpoint files contain:")
print("   - Model state (weights)")
print("   - Optimizer state")
print("   - Scheduler state")
print("   - Training history")
print("   - Current epoch number")
print("\nThis allows seamless recovery from any interruption!")

## 🧹 Step 19: Cleanup DataParallel Resources

In [None]:
# Clean up GPU resources
if device.startswith('cuda'):
    torch.cuda.empty_cache()
    if USE_MULTI_GPU:
        print(f"✅ Cleared cache on {torch.cuda.device_count()} GPUs")
    else:
        print("✅ GPU cache cleared")

print("\n🎉 Training pipeline completed successfully!")
print(f"   Final model saved: {os.path.join(MODEL_DIR, 'final_finetuned.pt')}")
print(f"   Model parameters: ~{(768 * 8 * 8 * 3072 * 4) / (1024**2):.0f}M")
