In [None]:
 # Cell 1: Environment Setup and Imports
# Install libraries (if needed) and import all required modules

# Install required packages (run this if packages are missing)
import subprocess
import sys

def install_package(package):
    """Install ``package`` with pip unless its module can already be imported.

    Args:
        package: A pip requirement string such as ``"tqdm"``, ``"flash-attn"`` or
            ``"transformers>=4.36.0"``. Version specifiers and extras are accepted.

    Raises:
        subprocess.CalledProcessError: If the pip invocation exits with a non-zero status.
    """
    # Derive the importable module name from the requirement string: drop any version
    # specifier / extras, then map '-' to '_' (e.g. "flash-attn>=2.3.0" -> "flash_attn").
    # Distributions whose module name is unrelated to the pip name cannot be derived
    # this way and will simply be (re)installed by pip.
    module_name = package
    for separator in ("==", ">=", "<=", "~=", "!=", ">", "<", "["):
        module_name = module_name.split(separator)[0]
    module_name = module_name.strip().replace("-", "_")

    try:
        __import__(module_name)
    except ImportError:
        print(f"Installing {package}...")
        # Install with the full requirement string so version constraints are honoured.
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])

# Pip requirement strings for this notebook (package name plus optional minimum version).
# Nothing in this cell installs them automatically: the only consumer is the
# commented-out install loop below, which strips the ">=" part before calling
# install_package.
required_packages = [
    "transformers>=4.36.0",
    "datasets>=2.14.0", 
    "torch>=2.0.0",
    "accelerate>=0.24.0",
    "flash-attn>=2.3.0",
    "wandb",
    "matplotlib",
    "seaborn",
    "numpy",
    "tqdm"
]

# Install packages if needed (uncomment if running first time)
# for package in required_packages:
#     install_package(package.split(">=")[0])

# Core imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.cuda.amp import GradScaler, autocast

# Transformers and datasets
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    AutoConfig,
    get_linear_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
    TrainingArguments,
    set_seed
)
from datasets import load_dataset

# Utilities
import os
import json
import time
import random
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
# Suppress every Python warning for the rest of the kernel session.
# NOTE(review): this also hides deprecation and numerical warnings from the
# libraries imported above — consider narrowing the filter.
warnings.filterwarnings("ignore")

# Set plotting style: a named matplotlib style sheet plus seaborn's "husl" colour palette
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Report the PyTorch build; the CUDA version is only read when CUDA is available
print("✅ All imports successful!")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda if torch.cuda.is_available() else 'N/A'}")

In [None]:
# Cell 2: GPU Check and Random Seed Setup
# Log GPU info, number of GPUs, and set seeds for reproducibility

def setup_device_and_seed(seed=42):
    """Seed every RNG, force deterministic cuDNN, and report the available devices.

    Args:
        seed: Integer applied to transformers' ``set_seed``, ``random``, NumPy and
            PyTorch (CPU and all CUDA devices).

    Returns:
        tuple: ``(device, num_gpus)`` — ``torch.device("cuda")`` and the visible GPU
        count when CUDA is available, otherwise ``torch.device("cpu")`` and ``0``.
    """
    # Apply the same seed to each RNG source, in a fixed order.
    seeders = (
        set_seed,
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Deterministic cuDNN kernels, no autotuner (may slow down training).
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    print("🔧 Device and Seed Setup")
    print("=" * 50)

    if not torch.cuda.is_available():
        device = torch.device("cpu")
        num_gpus = 0
        print("⚠️  CUDA not available, using CPU")
    else:
        device = torch.device("cuda")
        num_gpus = torch.cuda.device_count()
        print("✅ CUDA is available!")
        print(f"📱 Number of GPUs: {num_gpus}")

        # One line per GPU: name and total memory in GB (decimal).
        gpu_index = 0
        while gpu_index < num_gpus:
            name = torch.cuda.get_device_name(gpu_index)
            total_gb = torch.cuda.get_device_properties(gpu_index).total_memory / 1e9
            print(f"   GPU {gpu_index}: {name} ({total_gb:.1f} GB)")
            gpu_index += 1

        # Current allocator usage per GPU.
        print("\n💾 Current GPU Memory Usage:")
        for gpu_index in range(num_gpus):
            used_gb = torch.cuda.memory_allocated(gpu_index) / 1e9
            reserved_gb = torch.cuda.memory_reserved(gpu_index) / 1e9
            print(f"   GPU {gpu_index}: {used_gb:.2f} GB allocated, {reserved_gb:.2f} GB cached")

    print(f"\n🎲 Random seed set to: {seed}")
    print(f"🖥️  Primary device: {device}")

    return device, num_gpus

# Setup device and seeds
SEED = 42  # single seed handed to every RNG seeded by setup_device_and_seed
device, num_gpus = setup_device_and_seed(SEED)

# Training configuration
# MULTI_GPU only records that more than one GPU is visible; nothing in this cell
# distributes work across devices.
MULTI_GPU = num_gpus > 1
print(f"\n🔄 Multi-GPU training: {'Enabled' if MULTI_GPU else 'Disabled'}")

# Clear GPU cache
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print("🧹 GPU cache cleared") 

In [None]:
# Cell 3: Load and Explore Dataset (Cosmopedia 100k)
# Load from Hugging Face, print structure, show sample rows, and verify cleanliness

def load_and_explore_dataset():
    """Load the Cosmopedia 100k training split and print a short exploratory report.

    The report covers dataset size and schema, the first three rows (string fields
    truncated to 200 characters), character-length statistics and an empty-text count
    over the first 1,000 rows, and a histogram of those lengths.

    Returns:
        The loaded dataset object, or ``None`` if ``load_dataset`` raised.
    """

    print("📚 Loading Cosmopedia 100k Dataset")
    print("=" * 50)

    # Load dataset from Hugging Face
    try:
        ds = load_dataset("HuggingFaceTB/cosmopedia-100k", split="train")
        print(f"✅ Dataset loaded successfully!")
    except Exception as e:
        print(f"❌ Error loading dataset: {e}")
        return None

    # Basic dataset info
    print(f"\n📊 Dataset Overview:")
    print(f"   Total samples: {len(ds):,}")
    print(f"   Features: {list(ds.features.keys())}")
    print(f"   Dataset size: {ds.data.nbytes / 1e6:.1f} MB")

    # Examine dataset structure
    print(f"\n🔍 Dataset Schema:")
    for feature_name, feature_type in ds.features.items():
        print(f"   {feature_name}: {feature_type}")

    # Show sample rows
    print(f"\n📝 Sample Rows (first 3):")
    for i in range(min(3, len(ds))):
        sample = ds[i]
        print(f"\n--- Sample {i+1} ---")
        for key, value in sample.items():
            if isinstance(value, str):
                # Truncate long text for display
                display_value = value[:200] + "..." if len(value) > 200 else value
                print(f"   {key}: {display_value}")
            else:
                print(f"   {key}: {value}")

    # Materialize the (at most) 1,000-row sample once and reuse it for both the
    # length statistics and the empty-text check.
    sample_texts = [row['text'] for row in ds.select(range(min(1000, len(ds))))]
    text_lengths = [len(text) for text in sample_texts]

    # Analyze text lengths
    print(f"\n📏 Text Length Analysis:")
    if text_lengths:
        print(f"   Average text length: {np.mean(text_lengths):.0f} characters")
        print(f"   Median text length: {np.median(text_lengths):.0f} characters")
        print(f"   Min text length: {min(text_lengths):,} characters")
        print(f"   Max text length: {max(text_lengths):,} characters")
    else:
        # min()/max() raise on an empty sequence, so report instead of crashing.
        print("   No samples available for length analysis")

    # Check for any obvious data quality issues
    print(f"\n🔍 Data Quality Check:")
    empty_texts = sum(1 for text in sample_texts if not text.strip())
    print(f"   Empty texts (in first 1000): {empty_texts}")

    # Plot text length distribution (skipped when there is nothing to plot)
    if text_lengths:
        plt.figure(figsize=(10, 6))
        plt.hist(text_lengths, bins=50, alpha=0.7, edgecolor='black')
        plt.xlabel('Text Length (characters)')
        plt.ylabel('Frequency')
        plt.title('Distribution of Text Lengths in Cosmopedia 100k (Sample)')
        plt.grid(True, alpha=0.3)
        plt.show()

    return ds

# Load and explore the dataset
dataset = load_and_explore_dataset()

# Store dataset info for later use
# The check is a truthiness test: a None result (load failure) takes the else branch,
# and DATASET_SIZE is only defined on the success path.
# NOTE(review): presumably an empty dataset is falsy as well — confirm.
if dataset:
    DATASET_SIZE = len(dataset)
    print(f"\n✅ Dataset ready for preprocessing! Total samples: {DATASET_SIZE:,}")
else:
    print("❌ Failed to load dataset. Please check your connection and try again.") 

In [None]:
# Cell 4: Tokenizer - Existing Tokenizer
# Load the pre-trained SmolLM2 tokenizer

def load_smollm2_tokenizer():
    """Fetch the pretrained SmolLM2 tokenizer and print a diagnostic summary.

    When the tokenizer has no padding token, its EOS token is reused as padding.
    The summary lists basic settings, the special tokens with their ids, an
    encode/decode round trip, and the wall-clock time to tokenize 100 short sentences.

    Returns:
        The tokenizer, or ``None`` when loading raised an exception.
    """
    print("🔤 Loading SmolLM2 Tokenizer")
    print("=" * 50)

    # Hub identifier of the checkpoint whose tokenizer is reused
    checkpoint_name = "HuggingFaceTB/SmolLM2-1.7B"

    try:
        tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
        print(f"✅ Tokenizer loaded successfully from: {checkpoint_name}")

        # Fall back to EOS as the padding token when none is defined
        missing_pad = tokenizer.pad_token is None
        if missing_pad:
            tokenizer.pad_token = tokenizer.eos_token
            print("🔧 Set pad_token to eos_token")

    except Exception as load_error:
        print(f"❌ Error loading tokenizer: {load_error}")
        return None

    # --- Basic settings ---
    print("\n📊 Tokenizer Information:")
    print(f"   Vocabulary size: {tokenizer.vocab_size:,}")
    print(f"   Model max length: {tokenizer.model_max_length:,}")
    print(f"   Padding side: {tokenizer.padding_side}")
    print(f"   Truncation side: {tokenizer.truncation_side}")

    # --- Special tokens and their ids ---
    print("\n🔑 Special Tokens:")
    token_names = ('bos_token', 'eos_token', 'pad_token', 'unk_token')
    token_values = {name: getattr(tokenizer, name) for name in token_names}
    for name in token_values:
        value = token_values[name]
        identifier = getattr(tokenizer, f"{name}_id", None)
        print(f"   {name}: '{value}' (ID: {identifier})")

    # --- Encode/decode round trip ---
    print("\n🧪 Tokenization Test:")
    probe = "Hello, this is a test sentence for SmolLM2 tokenizer."
    probe_ids = tokenizer.encode(probe)
    round_trip = tokenizer.decode(probe_ids)

    print(f"   Original: {probe}")
    print(f"   Tokens: {probe_ids[:10]}... (showing first 10)")
    print(f"   Token count: {len(probe_ids)}")
    print(f"   Decoded: {round_trip}")

    # --- Batch tokenization timing ---
    print("\n⚡ Speed Test:")
    sentences = ["This is a test sentence."] * 100

    started = time.time()
    batch = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
    finished = time.time()

    print(f"   Tokenized 100 sentences in {(finished - started)*1000:.2f}ms")
    print(f"   Batch shape: {batch['input_ids'].shape}")

    return tokenizer

# Load the tokenizer
tokenizer = load_smollm2_tokenizer()

# MAX_SEQ_LENGTH is only defined when the tokenizer loaded (truthiness test; None
# takes the else branch).
if tokenizer:
    # Set sequence length for training
    MAX_SEQ_LENGTH = 1024  # Adjust based on your needs and GPU memory
    print(f"\n📏 Max sequence length for training: {MAX_SEQ_LENGTH}")
    print("✅ Tokenizer ready for dataset preprocessing!")
else:
    print("❌ Failed to load tokenizer. Please check your connection and model ID.") 

In [None]:
# Cell 5: Dataset Preprocessing - ULTRA OPTIMIZED (Minimal Memory)
# Ultra-aggressive memory optimization to prevent checkpoint corruption

import os
# Fix tokenizer parallelism warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"

class CosmopediaDataset(Dataset):
    """Map-style dataset that tokenizes one text record per access.

    Tokenization happens lazily in ``__getitem__`` so no tokenized copy of the corpus
    is held in memory. Each item is truncated/padded to exactly ``max_length`` tokens.

    Args:
        dataset: Indexable collection whose items expose a ``'text'`` string.
        tokenizer: Callable tokenizer returning ``'input_ids'`` and ``'attention_mask'``
            tensors with a leading batch dimension of 1 when ``return_tensors='pt'``.
        max_length: Fixed sequence length of every returned item.
    """

    def __init__(self, dataset, tokenizer, max_length=256):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of records in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Tokenize record ``idx`` and return ``input_ids``, ``attention_mask`` and ``labels``.

        ``labels`` equal ``input_ids`` except at padding positions (attention mask 0),
        which are set to ``-100`` so they are excluded from the language-modeling loss.
        """
        # Get text from dataset
        text = self.dataset[idx]['text']

        # Tokenize, truncating/padding to the fixed length
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        # Drop only the batch dimension; a bare squeeze() would also collapse the
        # sequence dimension when max_length == 1.
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)

        # For causal language modeling, labels are the same as input_ids, but padded
        # positions must not be trained on: -100 is the default ignore_index of
        # PyTorch's cross-entropy loss. Masking by attention_mask (rather than by the
        # pad token id) keeps genuine EOS tokens when pad_token == eos_token.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }

def preprocess_dataset_ultra_optimized(dataset, tokenizer, max_length=256, batch_size=1):
    """Wrap ``dataset`` in a ``CosmopediaDataset`` and a shuffling, single-process DataLoader.

    Args:
        dataset: Source dataset handed to ``CosmopediaDataset``.
        tokenizer: Tokenizer handed to ``CosmopediaDataset``.
        max_length: Fixed token length of every sample.
        batch_size: Samples per DataLoader batch; incomplete final batches are dropped.

    Returns:
        tuple: ``(torch_dataset, dataloader)``, or ``(None, None)`` when either input is
        ``None`` or when drawing a test batch from the DataLoader raises.
    """
    print("🔄 Dataset Preprocessing - ULTRA OPTIMIZED")
    print("=" * 60)

    # Guard: both inputs are required.
    inputs_missing = any(obj is None for obj in (dataset, tokenizer))
    if inputs_missing:
        print("❌ Dataset or tokenizer is None. Please load them first.")
        return None, None

    print("📊 ULTRA OPTIMIZED Configuration:")
    print(f"   Max sequence length: {max_length} (ultra-reduced)")
    print(f"   Batch size: {batch_size}")
    print(f"   Total samples: {len(dataset):,}")
    print("   Memory optimization: ULTRA AGGRESSIVE")

    # Build the lazily-tokenizing dataset wrapper.
    print("\n🔄 Creating PyTorch Dataset...")
    torch_dataset = CosmopediaDataset(dataset, tokenizer, max_length)

    # Smoke-test one item and show the tensor shapes.
    print("\n🧪 Testing Dataset Sample:")
    first_item = torch_dataset[0]
    shape_report = (
        ("Input IDs", 'input_ids'),
        ("Attention mask", 'attention_mask'),
        ("Labels", 'labels'),
    )
    for label, key in shape_report:
        print(f"   {label} shape: {first_item[key].shape}")

    # Loader settings: everything runs in the main process, no pinned memory.
    print("\n🚀 Creating Ultra-Optimized DataLoader...")
    loader_options = dict(
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=False,
        drop_last=True,
        persistent_workers=False,
    )
    dataloader = DataLoader(torch_dataset, **loader_options)

    print("✅ DataLoader created successfully!")
    print(f"   Number of batches: {len(dataloader):,}")

    # Pull one batch to confirm the loader works, reporting GPU memory if present.
    print("\n🧪 Testing DataLoader:")
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        probe_batch = next(iter(dataloader))
        print(f"   Batch input_ids shape: {probe_batch['input_ids'].shape}")

        if torch.cuda.is_available():
            allocated_gb = torch.cuda.memory_allocated() / 1e9
            print(f"   GPU memory after batch: {allocated_gb:.2f} GB")

    except Exception as loader_error:
        print(f"   ❌ Error testing DataLoader: {loader_error}")
        return None, None

    return torch_dataset, dataloader

# ULTRA OPTIMIZED Configuration for preventing memory corruption
BATCH_SIZE = 1  # Keep at 1
GRADIENT_ACCUMULATION_STEPS = 64  # Increased to maintain effective batch size
# Sequence length actually passed to preprocessing below.
# NOTE(review): MAX_SEQ_LENGTH (1024, set in the tokenizer cell) is not used here,
# and the "from 512" in the comment matches neither value — confirm the intended length.
MAX_SEQ_LENGTH_ULTRA = 256  # Reduced from 512 to 256

# Calculate effective batch size
# max(1, num_gpus) keeps the product non-zero on CPU-only machines.
# NOTE(review): the GPU factor assumes data-parallel training across all visible GPUs;
# nothing in this cell sets that up — confirm.
EFFECTIVE_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS * max(1, num_gpus)

print(f"🔧 ULTRA OPTIMIZED Configuration:")
print(f"   Batch size per GPU: {BATCH_SIZE}")
print(f"   Gradient accumulation steps: {GRADIENT_ACCUMULATION_STEPS}")
print(f"   Number of GPUs: {num_gpus}")
print(f"   Effective batch size: {EFFECTIVE_BATCH_SIZE}")
print(f"   Sequence length: {MAX_SEQ_LENGTH_ULTRA}")
print(f"   Target GPU usage: <12GB (safe zone)")

# Clear GPU memory aggressively
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    print("🧹 GPU cache cleared and synchronized")

# Preprocess dataset with ultra settings
train_dataset, train_dataloader = preprocess_dataset_ultra_optimized(
    dataset, 
    tokenizer, 
    max_length=MAX_SEQ_LENGTH_ULTRA, 
    batch_size=BATCH_SIZE
)

# STEPS_PER_EPOCH and TOTAL_STEPS are only defined when the DataLoader was created;
# the model-setup cell reads TOTAL_STEPS and will fail with NameError otherwise.
if train_dataloader:
    # Calculate final training statistics
    # Both values count DataLoader batches (micro-batches of BATCH_SIZE), not optimizer
    # updates: gradient accumulation is not factored in.
    # NOTE(review): if the LR scheduler is stepped once per optimizer update, the
    # schedule length should be divided by GRADIENT_ACCUMULATION_STEPS — confirm against
    # the training loop.
    STEPS_PER_EPOCH = len(train_dataloader)
    TOTAL_STEPS = STEPS_PER_EPOCH * 3  # 3 epochs
    
    print(f"\n✅ ULTRA OPTIMIZED Preprocessing Complete!")
    print(f"📊 Final Training Statistics:")
    print(f"   Steps per epoch: {STEPS_PER_EPOCH:,}")
    print(f"   Total training steps (3 epochs): {TOTAL_STEPS:,}")
    print(f"   Samples per effective step: {EFFECTIVE_BATCH_SIZE}")
    print(f"   Memory footprint: MINIMIZED")
    print(f"   Expected GPU usage: ~10-12GB (safe)")
else:
    print("❌ Failed to create DataLoader. Please check your setup.") 

In [None]:
# Cell 6: Model Configuration and Initialization (NO MIXED PRECISION VERSION)
# Define model config, optimizer, scheduler, gradient clipping setup WITHOUT mixed precision

import os

# Set memory optimization environment variables
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

def _split_decay_parameters(model, weight_decay):
    """Build AdamW parameter groups, exempting biases and normalization weights from decay."""
    # Matched case-insensitively against parameter names so that both
    # "LayerNorm.weight"-style and "input_layernorm.weight" / "norm.weight"-style
    # names are recognised.
    no_decay_markers = ("bias", "layernorm", "layer_norm", "norm.weight")

    decay_params = []
    no_decay_params = []
    for name, param in model.named_parameters():
        lowered = name.lower()
        if any(marker in lowered for marker in no_decay_markers):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    return [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]


def setup_model_and_training_components_no_amp():
    """Initialize SmolLM2 model with random weights and setup training components - NO MIXED PRECISION

    Loads only the SmolLM2-1.7B configuration (no pretrained weights), builds the model
    from it with gradient checkpointing enabled, moves it to the module-level ``device``,
    and creates an AdamW optimizer plus a cosine schedule with 10% warmup.

    Reads the module-level names ``device``, ``TOTAL_STEPS``,
    ``GRADIENT_ACCUMULATION_STEPS`` and ``EFFECTIVE_BATCH_SIZE``.

    Returns:
        tuple: ``(model, optimizer, scheduler, scaler, config)``. ``scaler`` is always
        ``None`` (FP32 training). If loading the config or building the model fails,
        five ``None`` values are returned instead.
    """

    print("🤖 Model Configuration and Initialization (NO MIXED PRECISION)")
    print("=" * 60)

    # Clear GPU cache before model initialization
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print("🧹 GPU cache cleared before model setup")

    # Load SmolLM2 configuration
    model_id = "HuggingFaceTB/SmolLM2-1.7B"

    try:
        # Load config only (not weights)
        config = AutoConfig.from_pretrained(model_id)
        print(f"✅ Config loaded from: {model_id}")

        # MEMORY OPTIMIZATION: Enable gradient checkpointing
        config.use_cache = False  # Disable KV cache to save memory
        config.gradient_checkpointing = True

        # Display model configuration
        print(f"\n📊 Model Configuration:")
        print(f"   Model type: {config.model_type}")
        print(f"   Hidden size: {config.hidden_size:,}")
        print(f"   Number of layers: {config.num_hidden_layers}")
        print(f"   Number of attention heads: {config.num_attention_heads}")
        print(f"   Vocabulary size: {config.vocab_size:,}")
        print(f"   Max position embeddings: {config.max_position_embeddings:,}")
        print(f"   Use cache: {config.use_cache} (disabled for memory)")
        print(f"   Gradient checkpointing: {config.gradient_checkpointing}")

        # Rough parameter estimate from the config alone; the exact count is
        # printed after the model is built.
        approx_params = (config.vocab_size * config.hidden_size +  # Embedding
                        config.num_hidden_layers * (
                            4 * config.hidden_size * config.hidden_size +  # MLP
                            3 * config.hidden_size * config.hidden_size    # Attention
                        ) +
                        config.vocab_size * config.hidden_size)  # Output layer

        print(f"   Approximate parameters: {approx_params / 1e9:.2f}B")

    except Exception as e:
        print(f"❌ Error loading config: {e}")
        return None, None, None, None, None

    # Flash Attention 2 is not requested: this setup trains in FP32, and assigning a
    # flag on the config object cannot report whether the kernels are usable, so no
    # "enabled" claim is made.
    # NOTE(review): FA2 is documented as fp16/bf16-only — confirm against the installed
    # transformers / flash-attn versions before enabling it.
    print(f"⚠️  Flash Attention 2: Not available, using standard attention")

    # Initialize model with random weights
    print(f"\n🎲 Initializing model with random weights...")
    try:
        # Create model with memory optimizations
        model = AutoModelForCausalLM.from_config(config)

        # Enable gradient checkpointing for memory efficiency
        model.gradient_checkpointing_enable()
        print(f"✅ Gradient checkpointing enabled")

        print(f"✅ Model initialized successfully!")

        # Count actual parameters
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

        print(f"\n📈 Model Parameters:")
        print(f"   Total parameters: {total_params / 1e9:.2f}B ({total_params:,})")
        print(f"   Trainable parameters: {trainable_params / 1e9:.2f}B ({trainable_params:,})")

    except Exception as e:
        print(f"❌ Error initializing model: {e}")
        return None, None, None, None, None

    # Move model to device
    model = model.to(device)
    print(f"📱 Model moved to: {device}")

    # MEMORY CHECK after model loading
    if torch.cuda.is_available():
        model_memory = torch.cuda.memory_allocated() / 1e9
        print(f"💾 GPU memory after model loading: {model_memory:.2f} GB")

    # Setup optimizer (AdamW)
    print(f"\n⚙️  Setting up optimizer...")

    learning_rate = 3e-5
    weight_decay = 0.01

    # No weight decay for bias and normalization parameters
    optimizer_grouped_parameters = _split_decay_parameters(model, weight_decay)

    optimizer = torch.optim.AdamW(
        optimizer_grouped_parameters,
        lr=learning_rate,
        betas=(0.9, 0.95),
        eps=1e-8,
        foreach=False
    )

    print(f"✅ AdamW optimizer configured:")
    print(f"   Learning rate: {learning_rate}")
    print(f"   Weight decay: {weight_decay}")

    # Setup learning rate scheduler
    print(f"\n📈 Setting up learning rate scheduler...")

    num_warmup_steps = int(0.1 * TOTAL_STEPS)  # 10% warmup

    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=TOTAL_STEPS
    )

    print(f"✅ Cosine scheduler with warmup configured:")
    print(f"   Warmup steps: {num_warmup_steps:,}")
    print(f"   Total steps: {TOTAL_STEPS:,}")

    # NO MIXED PRECISION - Use None for scaler
    print(f"\n🎯 Training setup (NO MIXED PRECISION):")
    scaler = None  # No scaler needed for FP32 training
    print(f"✅ Mixed precision: DISABLED (using FP32 for compatibility)")
    print(f"   Memory usage: Higher but more stable")
    print(f"   Compatibility: Maximum (no precision issues)")

    # Training configuration summary
    print(f"\n🎛️  Training Configuration Summary:")
    print(f"   Mixed precision: DISABLED (FP32 training)")
    print(f"   Gradient checkpointing: ENABLED")
    print(f"   Gradient accumulation steps: {GRADIENT_ACCUMULATION_STEPS}")
    print(f"   Max gradient norm: 1.0 (for clipping)")
    print(f"   Effective batch size: {EFFECTIVE_BATCH_SIZE}")
    print(f"   Precision: FP32 (maximum compatibility)")

    # Final memory check
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        final_memory = torch.cuda.memory_allocated() / 1e9
        print(f"\n💾 Final GPU memory usage: {final_memory:.2f} GB")
        free_memory = (torch.cuda.get_device_properties(0).total_memory / 1e9) - final_memory
        print(f"💾 Available GPU memory: {free_memory:.2f} GB")

    return model, optimizer, scheduler, scaler, config

# Initialize model and training components WITHOUT mixed precision
# On failure the setup function returns five None values, so `model is None` signals an error.
print("🚀 Initializing training components (NO MIXED PRECISION)...")
model, optimizer, scheduler, scaler, model_config = setup_model_and_training_components_no_amp()

if model is not None:
    print(f"\n✅ Model and training components ready!")
    
    # Final memory report
    # Figures are in GB (bytes / 1e9); the totals come from GPU 0's device properties.
    if torch.cuda.is_available():
        memory_allocated = torch.cuda.memory_allocated() / 1e9
        memory_reserved = torch.cuda.memory_reserved() / 1e9
        total_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
        
        print(f"\n💾 FINAL GPU Memory Report:")
        print(f"   Total GPU memory: {total_memory:.2f} GB")
        print(f"   Allocated: {memory_allocated:.2f} GB ({memory_allocated/total_memory*100:.1f}%)")
        print(f"   Reserved: {memory_reserved:.2f} GB ({memory_reserved/total_memory*100:.1f}%)")
        print(f"   Free: {total_memory - memory_reserved:.2f} GB")
        
else:
    print("❌ Failed to initialize model. Please check your setup.")

# Gradient clipping value
# Defined unconditionally, i.e. even when model initialization failed above.
MAX_GRAD_NORM = 1.0
print(f"\n🔧 Gradient clipping norm: {MAX_GRAD_NORM}")
print(f"\n💡 Note: This version uses FP32 training for maximum compatibility")
print(f"   Memory usage will be higher but training should be stable") 

In [None]:
# Cell 7: Ultra-Safe Training Loop with Robust Checkpointing
# Ultra-aggressive memory management and corruption-resistant checkpointing

import glob
import os
import gc
from pathlib import Path
import tempfile
import shutil

# Set environment variables for maximum stability
# Both variables are also set in earlier cells with the same values; repeating them
# here is redundant but harmless.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Training tracking variables
# Module-level history, one list per tracked quantity. It is stored inside every
# checkpoint by safe_checkpoint_save and rebound by load_checkpoint_safe on resume.
training_stats = {
    'steps': [],
    'losses': [],
    'learning_rates': [],
    'epochs': [],
    'gpu_memory': []
}

def aggressive_memory_cleanup():
    """Run Python garbage collection, then release cached CUDA memory.

    Garbage collection runs first so that tensors kept alive only by unreachable
    reference cycles are returned to the CUDA caching allocator before its cache is
    emptied; in the opposite order that memory would remain cached. Without CUDA, only
    the garbage collection runs.
    """
    gc.collect()  # Python garbage collection
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

def safe_checkpoint_save(model, optimizer, scheduler, step, epoch, loss, checkpoint_dir="./checkpoints"):
    """Write a training checkpoint via a verified temporary file.

    The checkpoint is saved to ``temp_checkpoint_step_{step}.pt``, loaded back on CPU to
    confirm it is readable and carries the expected step, and only then renamed to
    ``checkpoint_step_{step}.pt``. The module-level ``training_stats`` dict is stored
    alongside the model, optimizer and scheduler state.

    Args:
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        scheduler: LR scheduler whose ``state_dict()`` is saved.
        step: Global step, used in the file name and for verification.
        epoch: Epoch index stored in the checkpoint.
        loss: Loss value stored in the checkpoint.
        checkpoint_dir: Target directory; created (including parents) if missing.

    Returns:
        The final checkpoint path, or ``None`` if saving or verification failed
        (the temporary file is removed in that case).
    """

    # parents=True so a nested checkpoint directory is created instead of raising.
    Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)

    # Aggressive memory cleanup before saving
    aggressive_memory_cleanup()

    checkpoint = {
        'step': step,
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': loss,
        'training_stats': training_stats
    }

    checkpoint_path = f"{checkpoint_dir}/checkpoint_step_{step}.pt"
    temp_path = f"{checkpoint_dir}/temp_checkpoint_step_{step}.pt"

    try:
        # Save to temporary file first
        torch.save(checkpoint, temp_path)

        # Verify the file was written correctly, then drop the loaded copy right away
        # so it does not stay in host memory for the rest of the call.
        test_load = torch.load(temp_path, map_location='cpu')
        verified = 'step' in test_load and test_load['step'] == step
        del test_load

        if not verified:
            raise Exception("Checkpoint verification failed")

        # Atomic rename within the same directory; replaces an existing file of the
        # same name on every platform.
        os.replace(temp_path, checkpoint_path)
        print(f"💾 Checkpoint saved safely: {checkpoint_path}")

    except Exception as e:
        print(f"❌ Checkpoint save failed: {e}")
        # Clean up temp file
        if os.path.exists(temp_path):
            os.remove(temp_path)
        return None

    # Final cleanup after saving
    aggressive_memory_cleanup()

    return checkpoint_path

def load_checkpoint_safe(checkpoint_path, model, optimizer, scheduler):
    """Restore model, optimizer and scheduler state from a checkpoint file.

    The file is read onto the CPU and its tensors are copied into the existing
    model/optimizer by ``load_state_dict``, so no second full copy of the checkpoint is
    staged on the GPU. All required keys are checked before anything is modified, so a
    malformed checkpoint leaves the objects untouched. If the checkpoint carries
    ``training_stats``, the module-level ``training_stats`` is rebound to it.

    Args:
        checkpoint_path: Path of a file written by ``safe_checkpoint_save``.
        model: Module to load weights into.
        optimizer: Optimizer to load state into.
        scheduler: LR scheduler to load state into.

    Returns:
        tuple: ``(step, epoch, loss)`` from the checkpoint, or ``(0, 0, float('inf'))``
        if loading failed for any reason (the error is printed).
    """
    global training_stats

    required_keys = (
        'model_state_dict',
        'optimizer_state_dict',
        'scheduler_state_dict',
        'step',
        'epoch',
        'loss',
    )

    try:
        # Load on CPU and verify the checkpoint before touching any live object
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        if any(key not in checkpoint for key in required_keys):
            raise Exception("Invalid checkpoint format")

        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        start_step = checkpoint['step']
        start_epoch = checkpoint['epoch']
        best_loss = checkpoint['loss']

        # Keep the current history when the checkpoint does not carry one
        if 'training_stats' in checkpoint:
            training_stats = checkpoint['training_stats']

        print(f"📂 Checkpoint loaded successfully: {checkpoint_path}")
        print(f"   Resuming from step: {start_step}")
        print(f"   Previous loss: {best_loss:.4f}")

        return start_step, start_epoch, best_loss

    except Exception as e:
        print(f"❌ Error loading checkpoint: {e}")
        return 0, 0, float('inf')

def find_latest_checkpoint(checkpoint_dir="./checkpoints"):
    """Return the path of the newest loadable checkpoint in ``checkpoint_dir``.

    Candidates are files named ``checkpoint_step_<N>.pt``; files whose ``<N>`` is not an
    integer are ignored. Candidates are tried from the highest step downwards and the
    first one that loads and contains a ``'step'`` entry is returned, so older
    checkpoints are not read at all in the common case. A candidate that fails to load
    is deleted.

    Args:
        checkpoint_dir: Directory to search.

    Returns:
        The checkpoint path, or ``None`` if the directory is missing or holds no valid
        checkpoint.

    Raises:
        MemoryError: Propagated rather than treated as file corruption.
    """

    if not os.path.exists(checkpoint_dir):
        return None

    prefix = "checkpoint_step_"
    suffix = ".pt"

    # Pair each candidate with its numeric step, parsed from the file name only.
    candidates = []
    for path in glob.glob(f"{checkpoint_dir}/checkpoint_step_*.pt"):
        step_text = os.path.basename(path)[len(prefix):-len(suffix)]
        try:
            candidates.append((int(step_text), path))
        except ValueError:
            # e.g. "checkpoint_step_final.pt" — not produced by safe_checkpoint_save
            continue

    # Newest first: stop at the first checkpoint that validates.
    candidates.sort(reverse=True)
    for _, path in candidates:
        try:
            # Quick validation
            contents = torch.load(path, map_location='cpu')
            has_step = 'step' in contents
            del contents  # release the loaded copy before moving on
        except MemoryError:
            # Running out of host memory says nothing about the file; keep it.
            raise
        except Exception:
            # Exception (not a bare except) so Ctrl-C never deletes a checkpoint.
            print(f"⚠️ Removing corrupted checkpoint: {path}")
            os.remove(path)
            continue

        if has_step:
            return path

    return None

def _prune_old_checkpoints(checkpoint_dir="./checkpoints", keep=2):
    """Delete all but the `keep` highest-numbered step checkpoints in `checkpoint_dir`."""
    prefix, suffix = "checkpoint_step_", ".pt"
    numbered = []
    for path in glob.glob(f"{checkpoint_dir}/checkpoint_step_*.pt"):
        step_text = os.path.basename(path)[len(prefix):-len(suffix)]
        if step_text.isdecimal():
            numbered.append((int(step_text), path))
    numbered.sort()

    for _, old_checkpoint in numbered[:max(len(numbered) - keep, 0)]:
        try:
            os.remove(old_checkpoint)
            print(f"🗑️ Removed old checkpoint: {old_checkpoint}")
        except OSError as remove_error:
            # Failing to free disk space must not abort the training run.
            print(f"⚠️ Could not remove old checkpoint {old_checkpoint}: {remove_error}")


def train_model_ultra_safe(num_epochs=3, checkpoint_interval=200, memory_warn_gb=12):
    """Run the training loop with periodic checkpoints and memory cleanup.

    Reads ``model``, ``optimizer``, ``scheduler``, ``train_dataloader``,
    ``device``, ``training_stats`` and the constants ``TOTAL_STEPS``,
    ``GRADIENT_ACCUMULATION_STEPS``, ``EFFECTIVE_BATCH_SIZE``,
    ``MAX_SEQ_LENGTH_ULTRA`` and ``MAX_GRAD_NORM`` from the notebook namespace.
    ``global_step`` counts processed batches; the optimizer steps once every
    ``GRADIENT_ACCUMULATION_STEPS`` batches.

    If a checkpoint is found the user is asked (via ``input``) whether to
    resume from it.

    Args:
        num_epochs: Number of passes over ``train_dataloader``.
        checkpoint_interval: Save a checkpoint every this many batches.
        memory_warn_gb: Allocated-GPU-memory level (GB) above which an extra
            cleanup is triggered at the 100-batch memory check.

    Returns:
        The notebook-level ``training_stats`` dict, appended to every 10 batches.

    Raises:
        ValueError: If ``train_dataloader`` has no batches.
        Exception: Any error raised during training is re-raised after an
            emergency checkpoint attempt. ``KeyboardInterrupt`` is handled by
            saving a checkpoint and returning normally.
    """
    print("🚀 Starting ULTRA-SAFE Training Loop")
    print("=" * 80)

    # Initial aggressive cleanup
    aggressive_memory_cleanup()

    # Check for existing checkpoints
    latest_checkpoint = find_latest_checkpoint()
    start_step = 0
    start_epoch = 0

    if latest_checkpoint:
        response = input(f"Found checkpoint: {latest_checkpoint}. Resume? (y/n): ")
        if response.strip().lower() == 'y':
            start_step, start_epoch, _ = load_checkpoint_safe(
                latest_checkpoint, model, optimizer, scheduler
            )

    # Calculate starting position
    steps_per_epoch = len(train_dataloader)
    if steps_per_epoch == 0:
        raise ValueError("train_dataloader yields no batches")
    resume_offset = start_step % steps_per_epoch

    print(f"\n🎯 ULTRA-SAFE Training Configuration:")
    print(f"   Starting from step: {start_step}")
    print(f"   Total epochs: {num_epochs}")
    print(f"   Steps per epoch: {steps_per_epoch:,}")
    print(f"   Total steps: {TOTAL_STEPS:,}")
    print(f"   Gradient accumulation: {GRADIENT_ACCUMULATION_STEPS}")
    print(f"   Effective batch size: {EFFECTIVE_BATCH_SIZE}")
    print(f"   Sequence length: {MAX_SEQ_LENGTH_ULTRA}")
    print(f"   Memory management: ULTRA AGGRESSIVE")
    print(f"   Checkpoint frequency: Every {checkpoint_interval} steps (safer)")

    model.train()
    # Drop gradients left over from an interrupted earlier run in this kernel
    # so they are not folded into the first accumulation window.
    optimizer.zero_grad()

    global_step = start_step
    running_loss = 0.0
    running_count = 0  # batches accumulated in running_loss since the last log
    log_interval = 10
    current_loss = 0.0
    epoch = start_epoch  # keeps the exception handlers valid before the loop binds it

    try:
        for epoch in range(start_epoch, num_epochs):
            print(f"\n🔄 Epoch {epoch + 1}/{num_epochs}")
            print("-" * 60)

            epoch_start_time = time.time()
            epoch_loss = 0.0
            epoch_steps = 0

            # Ultra-aggressive cleanup at epoch start
            aggressive_memory_cleanup()

            # Only the epoch we resume into starts part-way through; every
            # later epoch begins at batch 0.
            # NOTE(review): with a shuffling sampler the skipped batches are
            # not necessarily the ones consumed before the interruption.
            skip_batches = resume_offset if epoch == start_epoch else 0
            dataloader_iter = iter(train_dataloader)
            for _ in range(skip_batches):
                next(dataloader_iter)

            # `total` is the full epoch length; `initial` already accounts for
            # the skipped part of the bar.
            progress_bar = tqdm(
                dataloader_iter,
                total=steps_per_epoch,
                desc=f"Epoch {epoch+1}",
                initial=skip_batches
            )

            for step, batch in enumerate(progress_bar, start=skip_batches):
                step_start_time = time.time()

                # Move batch to device
                input_ids = batch['input_ids'].to(device, non_blocking=True)
                attention_mask = batch['attention_mask'].to(device, non_blocking=True)
                labels = batch['labels'].to(device, non_blocking=True)

                # Forward pass (FP32)
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                    use_cache=False
                )
                # Scale so the accumulated gradient is the mean over the window.
                loss = outputs.loss / GRADIENT_ACCUMULATION_STEPS

                # Backward pass
                loss.backward()
                current_loss = loss.item() * GRADIENT_ACCUMULATION_STEPS

                # Gradient accumulation and optimization
                if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)

                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()

                    # Cleanup only on optimizer steps that land on a multiple of 10
                    if global_step % 10 == 0:
                        aggressive_memory_cleanup()

                # Update statistics
                running_loss += current_loss
                running_count += 1
                epoch_loss += current_loss
                epoch_steps += 1
                global_step += 1

                # Log progress
                if global_step % log_interval == 0:
                    # Divide by the batches actually seen: after a resume the
                    # first window can be shorter than log_interval.
                    avg_loss = running_loss / running_count
                    current_lr = scheduler.get_last_lr()[0]
                    step_time = time.time() - step_start_time

                    # GPU memory usage
                    gpu_memory = 0
                    if torch.cuda.is_available():
                        gpu_memory = torch.cuda.memory_allocated() / 1e9

                    # Store statistics
                    training_stats['steps'].append(global_step)
                    training_stats['losses'].append(avg_loss)
                    training_stats['learning_rates'].append(current_lr)
                    training_stats['epochs'].append(epoch + 1)
                    training_stats['gpu_memory'].append(gpu_memory)

                    # Update progress bar
                    progress_bar.set_postfix({
                        'Loss': f'{avg_loss:.4f}',
                        'LR': f'{current_lr:.2e}',
                        'GPU': f'{gpu_memory:.1f}GB',
                        'Time': f'{step_time:.2f}s'
                    })

                    print(f"\n📊 Step {global_step:,}/{TOTAL_STEPS:,} | "
                          f"Loss: {avg_loss:.4f} | "
                          f"LR: {current_lr:.2e} | "
                          f"GPU: {gpu_memory:.1f}GB")

                    running_loss = 0.0
                    running_count = 0

                # Periodic checkpoint
                if global_step % checkpoint_interval == 0:
                    print(f"\n💾 Saving checkpoint at step {global_step}...")
                    checkpoint_path = safe_checkpoint_save(
                        model, optimizer, scheduler,
                        global_step, epoch, current_loss
                    )

                    if checkpoint_path:
                        # Keep only last 2 checkpoints to save space
                        _prune_old_checkpoints("./checkpoints", keep=2)

                # Memory monitoring with warnings
                if global_step % 100 == 0:
                    if torch.cuda.is_available():
                        memory_used = torch.cuda.memory_allocated() / 1e9
                        if memory_used > memory_warn_gb:
                            print(f"⚠️ Memory usage: {memory_used:.2f}GB - performing cleanup")
                            aggressive_memory_cleanup()

                # Early stopping check
                if global_step >= TOTAL_STEPS:
                    print(f"\n🎯 Reached maximum steps ({TOTAL_STEPS})")
                    break

            progress_bar.close()

            # Epoch summary
            epoch_time = time.time() - epoch_start_time
            avg_epoch_loss = epoch_loss / epoch_steps if epoch_steps > 0 else 0

            print(f"\n📈 Epoch {epoch + 1} Summary:")
            print(f"   Average Loss: {avg_epoch_loss:.4f}")
            print(f"   Time: {epoch_time / 60:.2f} minutes")
            print(f"   Steps: {epoch_steps}")
            print(f"   Global Step: {global_step:,}/{TOTAL_STEPS:,}")

            # Save end-of-epoch checkpoint; `epoch + 1` makes a resume start
            # at the next epoch.
            safe_checkpoint_save(
                model, optimizer, scheduler,
                global_step, epoch + 1, avg_epoch_loss
            )

            if global_step >= TOTAL_STEPS:
                break

    except KeyboardInterrupt:
        print(f"\n⚠️ Training interrupted by user")
        safe_checkpoint_save(model, optimizer, scheduler, global_step, epoch, current_loss)
    except Exception as e:
        print(f"\n❌ Training error: {e}")
        print(f"📊 Error occurred at step: {global_step}")
        print(f"💾 Attempting emergency checkpoint save...")
        safe_checkpoint_save(model, optimizer, scheduler, global_step, epoch, current_loss)
        raise

    print(f"\n🎉 Training completed!")
    print(f"   Total steps: {global_step:,}")
    print(f"   Final loss: {training_stats['losses'][-1] if training_stats['losses'] else 'N/A'}")

    return training_stats

# Collect the display label of every training prerequisite that is either
# undefined in the notebook namespace or set to None.
missing_components = [
    label
    for global_name, label in (
        ("model", "model"),
        ("optimizer", "optimizer"),
        ("scheduler", "scheduler"),
        ("train_dataloader", "dataloader"),
    )
    if globals().get(global_name) is None
]

if missing_components:
    # Something from an earlier cell is absent: explain instead of training.
    print("❌ Some components are missing. Please run previous cells first.")
    print(f"Missing: {', '.join(missing_components)}")
    print("\n".join([
        "\n🔧 This ultra-safe version should prevent checkpoint corruption",
        "   - Reduced sequence length: 256",
        "   - Safer checkpointing: every 200 steps",
        "   - Aggressive memory management",
        "   - Corruption-resistant file operations",
    ]))
else:
    print("🔍 All components ready for ULTRA-SAFE training!")
    print("🚀 Starting ultra-safe training with corruption prevention...")

    # Free as much memory as possible before the first batch
    aggressive_memory_cleanup()

    if torch.cuda.is_available():
        initial_memory = torch.cuda.memory_allocated() / 1e9
        print(f"💾 Initial GPU memory: {initial_memory:.2f} GB")
        print(f"🎯 Target: Keep under 12GB to prevent corruption")

    # Failures are reported rather than propagated so the cell itself
    # finishes; training_stats is only rebound on success.
    try:
        training_stats = train_model_ultra_safe()
    except KeyboardInterrupt:
        print("Training cancelled by user")
    except Exception as e:
        print(f"Training failed: {e}")
        print("💡 Try restarting kernel if corruption persists")

In [None]:
# Cell 8: Loss Visualization and Statistics Logging
# Plot training loss curves, log metrics, and save visualizations/artifacts

def plot_training_metrics(training_stats, save_plots=True, save_path='training_metrics.png'):
    """Draw a 2x2 dashboard of the logged training statistics.

    Panels: loss over steps (log scale, with an exponential trend line when
    more than 10 points are logged), learning rate (log scale), GPU memory
    with its average, and the mean logged loss per epoch.

    Args:
        training_stats: Dict with ``'steps'``, ``'losses'`` and
            ``'learning_rates'`` lists of equal length; ``'gpu_memory'`` and
            ``'epochs'`` are optional.
        save_plots: When true, write the figure to ``save_path``.
        save_path: Output image path used when ``save_plots`` is true.

    Returns:
        None. Prints a notice and returns early when no steps are logged.
    """
    print("📊 Visualizing Training Metrics")
    print("=" * 50)

    if not training_stats['steps']:
        print("❌ No training statistics available. Please run training first.")
        return

    steps = training_stats['steps']
    losses = training_stats['losses']
    gpu_memory = training_stats.get('gpu_memory') or []
    epochs = training_stats.get('epochs') or []

    # Create subplots
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle('SmolLM2 Training Metrics', fontsize=16, fontweight='bold')
    loss_ax, lr_ax = axes[0, 0], axes[0, 1]
    memory_ax, epoch_ax = axes[1, 0], axes[1, 1]

    # 1. Training Loss
    loss_ax.plot(steps, losses, 'b-', linewidth=2, alpha=0.8)
    loss_ax.set_xlabel('Training Step')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Training Loss Over Time')
    loss_ax.grid(True, alpha=0.3)
    loss_ax.set_yscale('log')  # Log scale for better visualization

    # Add trend line: a straight-line fit in log space, i.e. exponential decay.
    if len(steps) > 10:
        step_values = np.asarray(steps, dtype=float)
        loss_values = np.asarray(losses, dtype=float)
        # log() is only defined for positive finite losses; a single NaN or
        # zero would otherwise poison the whole fit.
        usable = np.isfinite(loss_values) & (loss_values > 0)
        if np.count_nonzero(usable) >= 2:
            coefficients = np.polyfit(step_values[usable], np.log(loss_values[usable]), 1)
            trend = np.poly1d(coefficients)
            loss_ax.plot(steps, np.exp(trend(step_values)),
                         'r--', alpha=0.7, label='Trend')
            loss_ax.legend()

    # 2. Learning Rate Schedule
    lr_ax.plot(steps, training_stats['learning_rates'], 'g-', linewidth=2)
    lr_ax.set_xlabel('Training Step')
    lr_ax.set_ylabel('Learning Rate')
    lr_ax.set_title('Learning Rate Schedule')
    lr_ax.grid(True, alpha=0.3)
    lr_ax.set_yscale('log')

    # 3. GPU Memory Usage (all-zero readings mean no GPU was measured)
    if gpu_memory and any(mem > 0 for mem in gpu_memory):
        memory_ax.plot(steps, gpu_memory, 'r-', linewidth=2)
        memory_ax.set_xlabel('Training Step')
        memory_ax.set_ylabel('GPU Memory (GB)')
        memory_ax.set_title('GPU Memory Usage')
        memory_ax.grid(True, alpha=0.3)

        # Add average line
        avg_memory = np.mean(gpu_memory)
        memory_ax.axhline(y=avg_memory, color='orange', linestyle='--',
                          label=f'Average: {avg_memory:.1f}GB')
        memory_ax.legend()
    else:
        memory_ax.text(0.5, 0.5, 'GPU Memory\nData Not Available',
                       ha='center', va='center', transform=memory_ax.transAxes,
                       fontsize=12)
        memory_ax.set_title('GPU Memory Usage')

    # 4. Loss per Epoch: mean of the logged losses tagged with each epoch
    if epochs:
        unique_epochs = sorted(set(epochs))
        epoch_losses = [
            np.mean([losses[i] for i, e in enumerate(epochs) if e == epoch])
            for epoch in unique_epochs
        ]

        epoch_ax.bar(unique_epochs, epoch_losses, alpha=0.7, color='purple')
        epoch_ax.set_xlabel('Epoch')
        epoch_ax.set_ylabel('Average Loss')
        epoch_ax.set_title('Average Loss per Epoch')
        epoch_ax.grid(True, alpha=0.3)

        # Add values on bars
        for epoch, value in zip(unique_epochs, epoch_losses):
            epoch_ax.text(epoch, value, f'{value:.3f}',
                          ha='center', va='bottom', fontweight='bold')

    plt.tight_layout()

    if save_plots:
        # Save through the figure object rather than pyplot's implicit
        # "current figure".
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"📁 Training metrics saved as '{save_path}'")

    plt.show()

def generate_training_report(training_stats, target_epochs=3):
    """Print a text summary of the logged training statistics.

    Covers loss, learning-rate, GPU-memory and epoch figures plus a stability
    verdict from a linear fit over the last 10 logged losses.

    Args:
        training_stats: Dict with ``'steps'`` and ``'losses'`` lists;
            ``'learning_rates'``, ``'gpu_memory'`` and ``'epochs'`` are
            optional.
        target_epochs: Planned number of epochs, used for the progress line.

    Returns:
        None. Prints a notice and returns early when no steps are logged.
    """
    print("\n📋 Training Report")
    print("=" * 50)

    if not training_stats['steps']:
        print("❌ No training data available")
        return

    losses = training_stats['losses']
    learning_rates = training_stats.get('learning_rates') or []
    gpu_memory = training_stats.get('gpu_memory') or []
    epochs = training_stats.get('epochs') or []

    # Basic statistics
    total_steps = len(training_stats['steps'])
    final_loss = losses[-1] if losses else 0
    initial_loss = losses[0] if losses else 0
    loss_reduction = ((initial_loss - final_loss) / initial_loss * 100) if initial_loss > 0 else 0

    print(f"📊 Training Summary:")
    print(f"   Total training steps: {training_stats['steps'][-1]:,}")
    print(f"   Logged data points: {total_steps:,}")
    print(f"   Initial loss: {initial_loss:.4f}")
    print(f"   Final loss: {final_loss:.4f}")
    print(f"   Loss reduction: {loss_reduction:.1f}%")

    # Loss statistics
    if losses:
        print(f"\n📈 Loss Statistics:")
        print(f"   Minimum loss: {min(losses):.4f}")
        print(f"   Maximum loss: {max(losses):.4f}")
        print(f"   Average loss: {np.mean(losses):.4f}")
        print(f"   Standard deviation: {np.std(losses):.4f}")

    # Learning rate statistics
    if learning_rates:
        print(f"\n📈 Learning Rate:")
        print(f"   Initial LR: {learning_rates[0]:.2e}")
        print(f"   Final LR: {learning_rates[-1]:.2e}")
        print(f"   Maximum LR: {max(learning_rates):.2e}")

    # GPU Memory statistics (all-zero readings mean no GPU was measured)
    if gpu_memory and any(mem > 0 for mem in gpu_memory):
        print(f"\n💾 GPU Memory Usage:")
        print(f"   Average: {np.mean(gpu_memory):.2f} GB")
        print(f"   Maximum: {max(gpu_memory):.2f} GB")
        print(f"   Minimum: {min(gpu_memory):.2f} GB")

    # Epoch statistics: the highest epoch number that appears in the log
    if epochs:
        completed_epochs = max(epochs)
        print(f"\n🔄 Training Progress:")
        print(f"   Completed epochs: {completed_epochs}")
        print(f"   Target epochs: {target_epochs}")
        print(f"   Progress: {completed_epochs/target_epochs*100:.1f}%")

    # Training stability analysis
    if len(losses) > 10:
        recent_losses = losses[-10:]  # Last 10 losses

        # Slope of a straight line through the recent losses, per log entry.
        recent_trend = np.polyfit(range(len(recent_losses)), recent_losses, 1)[0]

        print(f"\n📉 Training Stability:")
        print(f"   Recent trend: {'Decreasing' if recent_trend < 0 else 'Increasing'}")
        print(f"   Trend slope: {recent_trend:.6f}")

        # Convergence indicator
        if abs(recent_trend) < 0.001:
            print(f"   Status: 🟢 Model appears to be converging")
        elif recent_trend < -0.01:
            print(f"   Status: 🟡 Model is still learning rapidly")
        elif recent_trend < 0:
            # A moderately falling loss is progress, not divergence.
            print(f"   Status: 🟡 Model is still learning")
        else:
            print(f"   Status: 🔴 Model may be diverging or overfitting")

def save_training_stats(training_stats, filename='training_stats.json'):
    """Write the training statistics to `filename` as indented JSON.

    The data is serialized in memory before the file is opened, so a
    serialization error leaves an existing file untouched instead of
    truncating it.

    Args:
        training_stats: JSON-serializable statistics object.
        filename: Destination path.

    Returns:
        None. Errors are reported on stdout rather than raised.
    """
    try:
        import json
        # Serialize first: open(..., 'w') truncates immediately, so dumping
        # straight into the handle would destroy the previous file on failure.
        payload = json.dumps(training_stats, indent=2)
        with open(filename, 'w') as f:
            f.write(payload)
        print(f"💾 Training statistics saved to: {filename}")
    except Exception as e:
        print(f"❌ Error saving training stats: {e}")

# Report on real statistics when the training cell produced any; otherwise
# fall back to a synthetic demonstration plot.
if 'training_stats' not in globals() or not training_stats['steps']:
    print("❌ No training statistics available.")
    print("Please run the training loop (Cell 7) first to generate data.")

    # Create dummy data for demonstration (remove this in actual training)
    print("\n🔧 Creating sample visualization for demonstration...")

    # Built key by key in the same order as the metrics are logged; the two
    # random series draw from NumPy's global generator.
    dummy_stats = {'steps': list(range(0, 1000, 10))}
    dummy_stats['losses'] = [
        4.5 * np.exp(-i/500) + 0.5 + 0.1*np.random.random()
        for i in dummy_stats['steps']
    ]
    dummy_stats['learning_rates'] = [
        5e-5 * (0.5 + 0.5*np.cos(i*np.pi/1000))
        for i in dummy_stats['steps']
    ]
    dummy_stats['epochs'] = [min(3, i//333 + 1) for i in dummy_stats['steps']]
    dummy_stats['gpu_memory'] = [
        8.5 + 0.5*np.random.random()
        for _ in dummy_stats['steps']
    ]

    plot_training_metrics(dummy_stats, save_plots=False)
    print("📝 This is sample data. Run training to see real metrics.")

else:
    print("📊 Training statistics found! Generating visualizations...")

    # Figure, text report and JSON dump, in that order
    plot_training_metrics(training_stats)
    generate_training_report(training_stats)
    save_training_stats(training_stats)

    print("\n✅ Training analysis complete!")

In [None]:
# Cell 9: Sample Generation and Output Check
# Generate some text using the trained model to qualitatively assess performance

def generate_text_samples(model, tokenizer, prompts=None, max_length=200, num_samples=3, temperature=0.7):
    """Sample continuations for a list of prompts and print them.

    Args:
        model: Object exposing ``eval()`` and ``generate(...)``.
        tokenizer: Callable tokenizer exposing ``decode`` and ``eos_token_id``.
        prompts: Prompt strings; a built-in list of five is used when None.
        max_length: Upper bound on the total sequence length (prompt included);
            it is further capped at prompt length + 150 tokens.
        num_samples: Number of independent samples drawn per prompt.
        temperature: Sampling temperature passed to ``generate``.

    Returns:
        List of dicts with keys ``'prompt'``, ``'generated_text'`` (the decoded
        sequence, prompt included) and ``'sample_idx'`` (1-based). An empty
        list is returned when the model or tokenizer is missing or when
        generation raises.

    Side effects:
        Leaves the model in evaluation mode.
    """
    print("🎭 Text Generation - Model Performance Check")
    print("=" * 60)

    if model is None or tokenizer is None:
        print("❌ Model or tokenizer not available. Please run training first.")
        # Same type as every other exit path, so callers can always iterate.
        return []

    # Default prompts if none provided
    if prompts is None:
        prompts = [
            "The future of artificial intelligence",
            "In a world where technology",
            "Scientists have recently discovered",
            "The most important lesson in life",
            "Climate change is"
        ]

    model.eval()  # Set to evaluation mode

    print(f"🎯 Generation Parameters:")
    print(f"   Max length: {max_length} tokens")
    print(f"   Temperature: {temperature}")
    print(f"   Number of samples per prompt: {num_samples}")
    print(f"   Total generations: {len(prompts) * num_samples}")

    all_generations = []

    try:
        with torch.no_grad():
            for i, prompt in enumerate(prompts):
                print(f"\n📝 Prompt {i+1}: '{prompt}'")
                print("-" * 40)

                # Tokenize once per prompt; every sample reuses the same inputs.
                inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
                input_ids = inputs['input_ids'].to(device)
                attention_mask = inputs['attention_mask'].to(device)
                length_cap = min(max_length, input_ids.shape[1] + 150)

                for sample_idx in range(num_samples):
                    # Generate text
                    with autocast():
                        generated_ids = model.generate(
                            input_ids=input_ids,
                            attention_mask=attention_mask,
                            max_length=length_cap,
                            temperature=temperature,
                            do_sample=True,
                            top_p=0.9,
                            top_k=50,
                            pad_token_id=tokenizer.eos_token_id,
                            repetition_penalty=1.1,
                            no_repeat_ngram_size=3
                        )

                    # The decoded sequence starts with the prompt and is kept
                    # that way in the stored sample.
                    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

                    print(f"\nSample {sample_idx + 1}:")
                    print(f"🤖 {generated_text}")
                    print()

                    # Store generation for analysis
                    all_generations.append({
                        'prompt': prompt,
                        'generated_text': generated_text,
                        'sample_idx': sample_idx + 1
                    })

    except Exception as e:
        print(f"❌ Error during text generation: {e}")
        return []

    print(f"✅ Generated {len(all_generations)} text samples successfully!")
    return all_generations

def analyze_generation_quality(generations, tokenizer):
    """Print heuristic quality metrics for a batch of generated samples.

    Reports size statistics, trigram repetition, a simple coherence check,
    vocabulary diversity and a 0-100 score (25 points per passed check).

    Args:
        generations: Sequence of dicts, each with a ``'generated_text'`` string.
        tokenizer: Object whose ``encode(text)`` returns a sized token sequence.

    Returns:
        None. Prints a notice and returns early when `generations` is empty.
    """
    print("\n🔍 Generation Quality Analysis")
    print("=" * 50)

    if not generations:
        print("❌ No generations to analyze")
        return

    def _has_repeated_trigrams(text):
        # Repetitive = more than 10 words and under 80% distinct word trigrams.
        tokens = text.split()
        if len(tokens) <= 10:
            return False
        trigrams = [' '.join(tokens[pos:pos + 3]) for pos in range(len(tokens) - 2)]
        return len(set(trigrams)) / len(trigrams) < 0.8

    def _looks_coherent(text):
        # Needs a period, more than 50 characters and mixed letter case.
        if '.' not in text or len(text) <= 50:
            return False
        return not (text.isupper() or text.islower())

    texts = [entry['generated_text'] for entry in generations]
    sample_count = len(generations)
    char_lengths = [len(text) for text in texts]
    mean_chars = np.mean(char_lengths)
    mean_tokens = np.mean([len(tokenizer.encode(text)) for text in texts])

    print(f"📊 Basic Statistics:")
    print(f"   Total generations: {sample_count}")
    print(f"   Average character length: {mean_chars:.1f}")
    print(f"   Average token length: {mean_tokens:.1f}")

    print(f"\n🔍 Quality Checks:")

    repetitive_total = sum(1 for text in texts if _has_repeated_trigrams(text))
    print(f"   Repetitive generations: {repetitive_total}/{sample_count} ({repetitive_total/sample_count*100:.1f}%)")

    length_spread = np.std(char_lengths)
    print(f"   Length variation (std): {length_spread:.1f}")

    coherent_total = sum(1 for text in texts if _looks_coherent(text))
    print(f"   Potentially coherent: {coherent_total}/{sample_count} ({coherent_total/sample_count*100:.1f}%)")

    # Case-insensitive, whitespace-split vocabulary across all samples
    word_pool = [word for text in texts for word in text.lower().split()]
    distinct_words = len(set(word_pool))
    word_total = len(word_pool)
    diversity = distinct_words / word_total if word_total > 0 else 0

    print(f"   Vocabulary diversity: {diversity:.3f}")
    print(f"   Unique words: {distinct_words:,}")
    print(f"   Total words: {word_total:,}")

    print(f"\n🏆 Overall Assessment:")

    # Each passed check is worth 25 points.
    passed_checks = [
        repetitive_total / sample_count < 0.3,
        coherent_total / sample_count > 0.7,
        diversity > 0.5,
        length_spread > 20,
    ]
    score = 25 * sum(1 for passed in passed_checks if passed)

    verdict = "🔴 Poor - Model needs significant improvement"
    for floor, label in (
        (75, "🟢 Excellent - Model generating high-quality, diverse text"),
        (50, "🟡 Good - Model showing decent performance"),
        (25, "🟠 Fair - Model needs more training"),
    ):
        if score >= floor:
            verdict = label
            break

    print(f"   Quality Score: {score}/100")
    print(f"   Assessment: {verdict}")

def interactive_generation(model, tokenizer):
    """Prompt the user in a loop and print one sampled continuation per prompt.

    The loop ends when the user types ``quit``, ``exit`` or ``q``, presses
    Ctrl+C, or when standard input is closed. Any other error is printed and
    the loop continues with the next prompt.

    Args:
        model: Object exposing ``eval()`` and ``generate(...)``.
        tokenizer: Callable tokenizer exposing ``decode`` and ``eos_token_id``.

    Returns:
        None.

    Side effects:
        Leaves the model in evaluation mode.
    """
    print("\n🎮 Interactive Generation Mode")
    print("=" * 50)
    print("Enter prompts to test the model (type 'quit' to exit)")

    model.eval()

    while True:
        try:
            prompt = input("\n🎯 Enter prompt: ").strip()

            if prompt.lower() in ['quit', 'exit', 'q']:
                print("👋 Exiting interactive mode")
                break

            if not prompt:
                print("Please enter a valid prompt")
                continue

            print(f"\n🤖 Generating response for: '{prompt}'")

            # Generate single response
            inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
            input_ids = inputs['input_ids'].to(device)
            attention_mask = inputs['attention_mask'].to(device)

            with torch.no_grad():
                with autocast():
                    generated_ids = model.generate(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        # Total length: at most 200, and at most prompt + 100
                        max_length=min(200, input_ids.shape[1] + 100),
                        temperature=0.7,
                        do_sample=True,
                        top_p=0.9,
                        pad_token_id=tokenizer.eos_token_id
                    )

            generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
            print(f"📝 Generated: {generated_text}")

        except (KeyboardInterrupt, EOFError):
            # EOFError means stdin is closed (e.g. a non-interactive run);
            # retrying input() would fail the same way forever.
            print("\n👋 Exiting interactive mode")
            break
        except Exception as e:
            print(f"❌ Error: {e}")

# Exercise the trained model when one exists; otherwise only illustrate the
# output format.
if 'model' not in globals() or model is None:
    print("❌ Model not available. Please run training first (Cells 1-7)")
    print("🔧 Showing what text generation output would look like...")

    # Static mock-up of the per-prompt output
    print("\n".join([
        "\n📝 Example generation output:",
        "=" * 50,
        "🎯 Prompt: 'The future of artificial intelligence'",
        "-" * 40,
        "🤖 Sample 1: The future of artificial intelligence will be shaped by...",
        "🤖 Sample 2: The future of artificial intelligence holds great promise...",
        "\n🔍 Quality analysis would appear here after real training.",
    ]))

else:
    print("🤖 Model found! Running text generation tests...")

    # Two samples for each of five open-ended prompts
    sample_prompts = [
        "The future of artificial intelligence",
        "In a world where technology",
        "Scientists have recently discovered",
        "The most important lesson in life is",
        "Climate change represents"
    ]

    generations = generate_text_samples(
        model,
        tokenizer,
        prompts=sample_prompts,
        max_length=150,
        num_samples=2,
        temperature=0.8
    )

    if generations:
        # Heuristic quality summary of what was just generated
        analyze_generation_quality(generations, tokenizer)

        # Persist the samples; a failed write is reported, not raised
        try:
            import json
            with open('generated_samples.json', 'w') as f:
                json.dump(generations, f, indent=2)
            print("\n💾 Generated samples saved to 'generated_samples.json'")
        except Exception as e:
            print(f"❌ Error saving generations: {e}")

        # Optional manual testing, only on an explicit 'y'
        print("\n🎮 Want to try interactive generation? (y/n)")
        response = input().strip().lower()
        if response == 'y':
            interactive_generation(model, tokenizer)

In [None]:
# Cell 10: Save Final Model, Tokenizer, and Artifacts
# Save the trained model, tokenizer files, and config to disk

import shutil
from pathlib import Path

def save_final_model_and_artifacts(model, tokenizer, model_config, training_stats, save_dir="./smollm2_trained"):
    """Save the trained model and all related artifacts under ``save_dir``.

    Written in order: model weights (``model/``), tokenizer (``tokenizer/``),
    the model config, ``training_stats.json`` (only when step data is present),
    ``training_metadata.json``, a ``README.md`` model card, and a copy of
    ``training_metrics.png`` if that file exists in the working directory.

    Args:
        model: Object exposing ``save_pretrained`` and ``parameters``. If it has
            a ``module`` attribute (a parallel wrapper), the wrapped model is saved.
        tokenizer: Object exposing ``save_pretrained``.
        model_config: Object exposing ``save_pretrained``, or ``None`` to skip it.
        training_stats: Dict with optional ``'steps'`` / ``'losses'`` lists; may
            be empty or ``None``.
        save_dir: Destination directory; created (including parents) if missing.

    Returns:
        bool: ``True`` if every artifact was written; ``False`` if ``model`` is
        ``None`` or any step raised (the error is printed, not re-raised).

    Note:
        ``json`` is deliberately NOT imported inside this function. A conditional
        local ``import json`` turns ``json`` into a function-local name, so any
        later ``json.dump`` raises ``UnboundLocalError`` whenever that branch is
        skipped (e.g. empty ``training_stats``). The module-level import is used.
    """

    print("💾 Saving Trained Model and Artifacts")
    print("=" * 60)

    if model is None:
        print("❌ No model to save. Please run training first.")
        return False

    # Create save directory (parents=True so nested targets also work)
    save_path = Path(save_dir)
    save_path.mkdir(parents=True, exist_ok=True)
    print(f"📁 Save directory: {save_path.absolute()}")

    try:
        # 1. Save the model
        print("\n🤖 Saving model...")
        model_save_path = save_path / "model"

        # If the model is wrapped (exposes `.module`), save the underlying model
        model_to_save = model.module if hasattr(model, 'module') else model

        model_to_save.save_pretrained(model_save_path)
        print(f"✅ Model saved to: {model_save_path}")

        # 2. Save the tokenizer
        print("\n🔤 Saving tokenizer...")
        tokenizer_save_path = save_path / "tokenizer"
        tokenizer.save_pretrained(tokenizer_save_path)
        print(f"✅ Tokenizer saved to: {tokenizer_save_path}")

        # 3. Save model configuration
        print("\n⚙️  Saving model configuration...")
        config_save_path = save_path / "config.json"
        if model_config:
            model_config.save_pretrained(save_path)
            print(f"✅ Config saved to: {config_save_path}")

        # 4. Save training statistics
        print("\n📊 Saving training statistics...")
        stats_save_path = save_path / "training_stats.json"
        if training_stats and training_stats.get('steps'):
            with open(stats_save_path, 'w') as f:
                json.dump(training_stats, f, indent=2)
            print(f"✅ Training stats saved to: {stats_save_path}")

        # 5. Save training metadata
        # NOTE(review): epochs, learning rate, optimizer, scheduler and precision are
        # hard-coded here rather than read from the run's config — confirm they match.
        print("\n📋 Saving training metadata...")
        metadata = {
            "model_name": "SmolLM2-1.7B-Custom",
            "training_dataset": "HuggingFaceTB/cosmopedia-100k",
            "training_date": time.strftime("%Y-%m-%d %H:%M:%S"),
            "total_epochs": 3,
            "effective_batch_size": EFFECTIVE_BATCH_SIZE if 'EFFECTIVE_BATCH_SIZE' in globals() else "Unknown",
            "max_sequence_length": MAX_SEQ_LENGTH if 'MAX_SEQ_LENGTH' in globals() else 1024,
            "learning_rate": 5e-5,
            "optimizer": "AdamW",
            "scheduler": "cosine_with_warmup",
            "mixed_precision": "FP16",
            "gradient_accumulation_steps": GRADIENT_ACCUMULATION_STEPS if 'GRADIENT_ACCUMULATION_STEPS' in globals() else 8,
            "flash_attention": "Enabled" if hasattr(model_config, 'use_flash_attention_2') and model_config.use_flash_attention_2 else "Disabled",
            "total_parameters": sum(p.numel() for p in model.parameters()),
            "trainable_parameters": sum(p.numel() for p in model.parameters() if p.requires_grad)
        }

        # Add final training statistics if available
        has_losses = bool(training_stats and training_stats.get('losses'))
        if has_losses:
            losses = training_stats['losses']
            # 'steps' may be missing or empty even when losses were logged
            steps = training_stats.get('steps') or []
            metadata.update({
                "final_loss": losses[-1],
                "initial_loss": losses[0],
                "total_training_steps": steps[-1] if steps else 0,
                "loss_reduction_percent": ((losses[0] - losses[-1]) / losses[0] * 100) if losses[0] > 0 else 0
            })

        metadata_save_path = save_path / "training_metadata.json"
        with open(metadata_save_path, 'w') as f:
            json.dump(metadata, f, indent=2)
        print(f"✅ Training metadata saved to: {metadata_save_path}")

        # 6. Save model card (README)
        print("\n📄 Creating model card...")
        model_card_content = f"""# SmolLM2-1.7B Custom Trained Model

## Model Description

This is a custom-trained version of SmolLM2-1.7B, trained from scratch on the Cosmopedia-100k dataset.

## Training Details

- **Base Architecture**: SmolLM2-1.7B
- **Training Dataset**: HuggingFaceTB/cosmopedia-100k
- **Training Date**: {metadata['training_date']}
- **Total Parameters**: {metadata['total_parameters']:,}
- **Trainable Parameters**: {metadata['trainable_parameters']:,}

## Training Configuration

- **Epochs**: {metadata['total_epochs']}
- **Effective Batch Size**: {metadata['effective_batch_size']}
- **Learning Rate**: {metadata['learning_rate']}
- **Optimizer**: {metadata['optimizer']}
- **Scheduler**: {metadata['scheduler']}
- **Mixed Precision**: {metadata['mixed_precision']}
- **Max Sequence Length**: {metadata['max_sequence_length']}
- **Flash Attention**: {metadata['flash_attention']}

## Training Results

"""

        # The loss keys exist in `metadata` exactly when `has_losses` is true, so
        # they are indexed directly (a string fallback would break the numeric formats).
        if has_losses:
            model_card_content += f"""- **Initial Loss**: {metadata['initial_loss']:.4f}
- **Final Loss**: {metadata['final_loss']:.4f}
- **Loss Reduction**: {metadata['loss_reduction_percent']:.1f}%
- **Total Training Steps**: {metadata['total_training_steps']:,}
"""

        model_card_content += f"""
## Usage

```python
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("{save_path}/tokenizer")
model = AutoModelForCausalLM.from_pretrained("{save_path}/model")

# Generate text
prompt = "The future of artificial intelligence"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_length=100, do_sample=True, temperature=0.7)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)
```

## Training Infrastructure

- **Framework**: PyTorch with Transformers
- **GPU**: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}
- **CUDA Version**: {torch.version.cuda if torch.cuda.is_available() else 'N/A'}

## Files in this Model

- `model/`: Contains the trained model weights and configuration
- `tokenizer/`: Contains the tokenizer files
- `training_stats.json`: Detailed training statistics and loss curves
- `training_metadata.json`: Training configuration and metadata
- `README.md`: This file

## Notes

This model was trained from scratch (random initialization) rather than fine-tuning from pre-trained weights.
"""

        readme_save_path = save_path / "README.md"
        # Explicit UTF-8: the save path or GPU name may contain non-ASCII characters
        with open(readme_save_path, 'w', encoding='utf-8') as f:
            f.write(model_card_content)
        print(f"✅ Model card saved to: {readme_save_path}")

        # 7. Copy training plots if they exist
        plot_files = ['training_metrics.png']
        for plot_file in plot_files:
            if Path(plot_file).exists():
                shutil.copy(plot_file, save_path / plot_file)
                print(f"✅ Copied {plot_file} to save directory")

        # 8. Calculate total size (regular files only, so the count excludes directories)
        saved_files = [f for f in save_path.rglob('*') if f.is_file()]
        total_size = sum(f.stat().st_size for f in saved_files)
        size_gb = total_size / (1024**3)

        print(f"\n📊 Save Summary:")
        print(f"   Save directory: {save_path.absolute()}")
        print(f"   Total size: {size_gb:.2f} GB")
        print(f"   Files saved: {len(saved_files)}")

        print(f"\n✅ Model and artifacts saved successfully!")
        print(f"🎯 To use this model later:")
        print(f"   from transformers import AutoTokenizer, AutoModelForCausalLM")
        print(f"   tokenizer = AutoTokenizer.from_pretrained('{save_path}/tokenizer')")
        print(f"   model = AutoModelForCausalLM.from_pretrained('{save_path}/model')")

        return True

    except Exception as e:
        # Best-effort by design: report the failure and signal it via the return value
        print(f"❌ Error saving model: {e}")
        return False

def create_kaggle_dataset_metadata(save_dir="./smollm2_trained"):
    """Write a ``dataset-metadata.json`` stub for Kaggle Datasets into ``save_dir``.

    Nothing is written if ``save_dir`` does not exist; a message is printed and
    the function returns ``None``. The ``id`` field holds a placeholder username
    that must be edited before uploading.
    """

    print("\n📦 Creating Kaggle Dataset Metadata")
    print("-" * 40)

    target_dir = Path(save_dir)
    if not target_dir.exists():
        print("❌ Save directory doesn't exist")
        return

    kaggle_metadata_path = target_dir / "dataset-metadata.json"

    # Serialize first, then write the finished text in one go
    serialized = json.dumps(
        {
            "title": "SmolLM2-1.7B Trained on Cosmopedia-100k",
            "id": "your-username/smollm2-cosmopedia-trained",  # placeholder username
            "licenses": [{"name": "apache-2.0"}],
            "keywords": ["nlp", "language-model", "transformer", "pytorch", "smollm"],
            "collaborators": [],
            "data": []
        },
        indent=2,
    )
    with open(kaggle_metadata_path, 'w') as handle:
        handle.write(serialized)

    print(f"✅ Kaggle metadata saved to: {kaggle_metadata_path}")
    print("📝 Update the 'id' field with your Kaggle username before uploading")

# Driver for the save cell: persist everything when a trained `model` is in the
# notebook namespace, otherwise describe the directory layout that would be written.
if globals().get('model') is None:
    print("\n".join([
        "❌ No trained model found to save.",
        "Please run the training loop (Cell 7) first.",
        # Preview of the artifact tree
        "\n🔧 Files that would be saved:",
        "📁 smollm2_trained/",
        "  ├── model/",
        "  │   ├── pytorch_model.bin",
        "  │   └── config.json",
        "  ├── tokenizer/",
        "  │   ├── tokenizer.json",
        "  │   ├── tokenizer_config.json",
        "  │   └── special_tokens_map.json",
        "  ├── training_stats.json",
        "  ├── training_metadata.json",
        "  ├── training_metrics.png",
        "  ├── README.md",
        "  └── dataset-metadata.json (for Kaggle)",
    ]))
else:
    print("🤖 Model found! Proceeding with save...")

    # Optional inputs: fall back to neutral defaults when earlier cells did not define them
    stats_to_save = globals().get('training_stats', {})
    config_to_save = globals().get('model_config')

    success = save_final_model_and_artifacts(
        model=model,
        tokenizer=tokenizer,
        model_config=config_to_save,
        training_stats=stats_to_save,
        save_dir="./smollm2_trained"
    )

    if success:
        # The Kaggle stub is only useful once the artifacts are on disk
        create_kaggle_dataset_metadata("./smollm2_trained")

        print("\n".join([
            "\n🎉 All artifacts saved successfully!",
            "\n📋 Next Steps:",
            "1. Test the saved model by loading it back",
            "2. Upload to Kaggle Datasets for sharing",
            "3. Transfer to EC2 for continued training",
            "4. Evaluate on additional benchmarks",
        ]))

In [None]:
# Cell 11: Visualization and Results - Final Analysis
# Plot training loss curves, display training statistics, and comprehensive results summary

def create_comprehensive_training_report(training_stats, model, tokenizer):
    """Print a multi-section text report summarizing a training run.

    Sections: training overview, loss analysis, stability (volatility and the
    linear trend of the last 20 logged losses), learning-rate schedule, GPU
    memory usage, and model parameter counts. Optional sections are skipped
    when their data is absent.

    Args:
        training_stats: Dict with ``'steps'`` and ``'losses'`` lists and optional
            ``'learning_rates'`` / ``'gpu_memory'`` lists.
        model: Object exposing ``parameters()``, or ``None`` to skip the model section.
        tokenizer: Accepted for interface compatibility; not used by this function.

    Returns:
        None. All output goes to stdout.
    """

    print("📊 COMPREHENSIVE TRAINING REPORT")
    print("=" * 80)
    print("🎯 SmolLM2-1.7B Training from Scratch on Cosmopedia-100k")
    print("=" * 80)

    # Both series are indexed below, so both must be present and non-empty
    if not training_stats or not training_stats.get('steps') or not training_stats.get('losses'):
        print("❌ No training data available for comprehensive report")
        return

    steps = training_stats['steps']
    losses = training_stats['losses']

    # 1. Training Overview
    print("\n🔍 TRAINING OVERVIEW")
    print("-" * 50)

    total_steps = steps[-1]
    total_logged_points = len(steps)
    training_duration = "~3 epochs"  # hard-coded description, not measured

    print(f"📈 Training Completed Successfully!")
    print(f"   Total Training Steps: {total_steps:,}")
    print(f"   Logged Data Points: {total_logged_points:,}")
    print(f"   Training Duration: {training_duration}")
    print(f"   Dataset: Cosmopedia-100k (100,000 samples)")
    print(f"   Model: SmolLM2-1.7B (trained from scratch)")

    # 2. Loss Analysis
    print(f"\n📉 LOSS ANALYSIS")
    print("-" * 50)

    initial_loss = losses[0]
    final_loss = losses[-1]
    min_loss = min(losses)
    max_loss = max(losses)
    avg_loss = np.mean(losses)

    loss_reduction = ((initial_loss - final_loss) / initial_loss * 100) if initial_loss > 0 else 0
    convergence_ratio = min_loss / initial_loss if initial_loss > 0 else 0

    print(f"📊 Loss Statistics:")
    print(f"   Initial Loss: {initial_loss:.4f}")
    print(f"   Final Loss: {final_loss:.4f}")
    print(f"   Minimum Loss: {min_loss:.4f}")
    print(f"   Maximum Loss: {max_loss:.4f}")
    print(f"   Average Loss: {avg_loss:.4f}")
    print(f"   Loss Reduction: {loss_reduction:.2f}%")
    print(f"   Convergence Ratio: {convergence_ratio:.4f}")

    # Loss improvement assessment
    if loss_reduction > 50:
        loss_assessment = "🟢 Excellent - Strong learning progress"
    elif loss_reduction > 30:
        loss_assessment = "🟡 Good - Solid improvement"
    elif loss_reduction > 10:
        loss_assessment = "🟠 Fair - Some improvement"
    else:
        loss_assessment = "🔴 Poor - Limited learning"

    print(f"   Assessment: {loss_assessment}")

    # 3. Training Stability Analysis
    print(f"\n⚖️  TRAINING STABILITY")
    print("-" * 50)

    # Volatility = standard deviation over all logged losses
    loss_volatility = np.std(losses)

    # Trend analysis: slope of a linear fit over the last 20 logged losses
    if len(losses) >= 20:
        recent_losses = losses[-20:]
        trend_slope = np.polyfit(range(len(recent_losses)), recent_losses, 1)[0]

        if abs(trend_slope) < 0.001:
            trend_status = "🟢 Converging - Loss stabilizing"
        elif trend_slope < -0.01:
            trend_status = "🟡 Still Learning - Loss decreasing rapidly"
        elif trend_slope < 0:
            # Moderately negative slope: still improving. Previously this range
            # fell through to the "diverging" label even though loss was falling.
            trend_status = "🟡 Still Learning - Loss decreasing"
        else:
            trend_status = "🔴 Potentially Diverging - Loss increasing"
    else:
        trend_status = "📊 Insufficient data for trend analysis"
        trend_slope = 0

    print(f"📈 Stability Metrics:")
    print(f"   Loss Volatility (σ): {loss_volatility:.4f}")
    print(f"   Recent Trend Slope: {trend_slope:.6f}")
    print(f"   Trend Status: {trend_status}")

    # 4. Learning Rate Analysis
    if training_stats.get('learning_rates'):
        print(f"\n📈 LEARNING RATE SCHEDULE")
        print("-" * 50)

        initial_lr = training_stats['learning_rates'][0]
        final_lr = training_stats['learning_rates'][-1]
        max_lr = max(training_stats['learning_rates'])

        print(f"📊 Learning Rate Statistics:")
        print(f"   Initial LR: {initial_lr:.2e}")
        print(f"   Final LR: {final_lr:.2e}")
        print(f"   Maximum LR: {max_lr:.2e}")
        # A warmup schedule can log an initial LR of exactly 0; avoid dividing by it
        if initial_lr != 0:
            print(f"   LR Decay Ratio: {final_lr/initial_lr:.4f}")
        else:
            print("   LR Decay Ratio: N/A (initial LR is 0)")

    # 5. Resource Utilization
    if training_stats.get('gpu_memory') and any(mem > 0 for mem in training_stats['gpu_memory']):
        print(f"\n💾 RESOURCE UTILIZATION")
        print("-" * 50)

        avg_memory = np.mean(training_stats['gpu_memory'])
        max_memory = max(training_stats['gpu_memory'])
        min_memory = min(training_stats['gpu_memory'])
        memory_efficiency = avg_memory / max_memory if max_memory > 0 else 0

        print(f"📊 GPU Memory Statistics:")
        print(f"   Average Usage: {avg_memory:.2f} GB")
        print(f"   Peak Usage: {max_memory:.2f} GB")
        print(f"   Minimum Usage: {min_memory:.2f} GB")
        print(f"   Memory Efficiency: {memory_efficiency:.2%}")

    # 6. Model Performance Summary
    if model is not None:
        print(f"\n🤖 MODEL PERFORMANCE SUMMARY")
        print("-" * 50)

        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        # Guard the ratio for a parameter-less model
        trainable_fraction = trainable_params / total_params if total_params > 0 else 0

        print(f"📊 Model Statistics:")
        print(f"   Total Parameters: {total_params/1e9:.2f}B ({total_params:,})")
        print(f"   Trainable Parameters: {trainable_params/1e9:.2f}B ({trainable_params:,})")
        print(f"   Parameter Efficiency: {trainable_fraction:.2%}")

        # Calculate parameters per loss improvement
        if loss_reduction > 0:
            param_efficiency = total_params / loss_reduction
            print(f"   Parameters per % Loss Reduction: {param_efficiency/1e6:.1f}M")

def create_final_visualization_suite(training_stats):
    """Render and save a multi-panel figure summarizing the training run.

    Panels (on a 4x3 grid): loss curve with log-linear trend line, loss
    histogram, learning-rate schedule, moving-average loss, GPU memory, loss vs
    learning rate, mean loss per epoch, per-step loss change, and a text
    summary. Optional panels are skipped when their series is missing or does
    not have one value per logged step. The figure is written to
    ``comprehensive_training_analysis.png`` and then shown.

    Args:
        training_stats: Dict with ``'steps'`` and ``'losses'`` lists and optional
            ``'learning_rates'``, ``'gpu_memory'`` and ``'epochs'`` lists.

    Returns:
        None.
    """

    print("\n🎨 Creating Final Visualization Suite")
    print("=" * 60)

    # Both series are indexed below, so both must be present and non-empty
    if not training_stats or not training_stats.get('steps') or not training_stats.get('losses'):
        print("❌ No training data for visualization")
        return

    steps = training_stats['steps']
    losses = training_stats['losses']
    learning_rates = training_stats.get('learning_rates')
    gpu_memory = training_stats.get('gpu_memory')
    n_points = len(steps)

    # Create a large comprehensive plot
    fig = plt.figure(figsize=(20, 16))
    gs = fig.add_gridspec(4, 3, hspace=0.3, wspace=0.3)

    # Main title
    fig.suptitle('SmolLM2-1.7B Training Analysis - Comprehensive Report',
                 fontsize=20, fontweight='bold', y=0.98)

    # 1. Training Loss (Large plot)
    ax1 = fig.add_subplot(gs[0, :2])
    ax1.plot(steps, losses, 'b-', linewidth=2, alpha=0.8, label='Training Loss')
    ax1.set_xlabel('Training Step', fontsize=12)
    ax1.set_ylabel('Loss', fontsize=12)
    ax1.set_title('Training Loss Over Time', fontsize=14, fontweight='bold')
    ax1.grid(True, alpha=0.3)
    ax1.set_yscale('log')

    # Add trend line: a linear fit in log space, so it needs strictly positive losses
    if n_points > 10 and min(losses) > 0:
        z = np.polyfit(steps, np.log(losses), 1)
        p = np.poly1d(z)
        ax1.plot(steps, np.exp(p(steps)),
                'r--', alpha=0.7, linewidth=2, label='Trend Line')

    # Add annotations for key points
    min_loss_idx = np.argmin(losses)
    ax1.annotate(f'Min Loss: {min(losses):.4f}',
                xy=(steps[min_loss_idx], losses[min_loss_idx]),
                xytext=(10, 20), textcoords='offset points',
                bbox=dict(boxstyle='round,pad=0.3', facecolor='yellow', alpha=0.7),
                arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'))

    ax1.legend(fontsize=12)

    # 2. Loss Distribution
    ax2 = fig.add_subplot(gs[0, 2])
    ax2.hist(losses, bins=30, alpha=0.7, color='skyblue', edgecolor='black')
    ax2.set_xlabel('Loss Value', fontsize=12)
    ax2.set_ylabel('Frequency', fontsize=12)
    ax2.set_title('Loss Distribution', fontsize=14, fontweight='bold')
    ax2.grid(True, alpha=0.3)

    # 3. Learning Rate Schedule (needs one LR per logged step, as the scatter panel requires)
    if learning_rates and len(learning_rates) == n_points:
        ax3 = fig.add_subplot(gs[1, 0])
        ax3.plot(steps, learning_rates, 'g-', linewidth=2)
        ax3.set_xlabel('Training Step', fontsize=12)
        ax3.set_ylabel('Learning Rate', fontsize=12)
        ax3.set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
        ax3.grid(True, alpha=0.3)
        ax3.set_yscale('log')

    # 4. Loss Smoothed (Moving Average)
    ax4 = fig.add_subplot(gs[1, 1])
    window_size = max(1, len(losses) // 20)  # 5% of data points
    if len(losses) >= window_size:
        smoothed_losses = np.convolve(losses,
                                    np.ones(window_size)/window_size, mode='valid')
        # 'valid' drops the first window_size-1 points; align the x-axis accordingly
        smoothed_steps = steps[window_size-1:]
        ax4.plot(smoothed_steps, smoothed_losses, 'purple', linewidth=2)
    ax4.set_xlabel('Training Step', fontsize=12)
    ax4.set_ylabel('Smoothed Loss', fontsize=12)
    ax4.set_title(f'Smoothed Loss (Window: {window_size})', fontsize=14, fontweight='bold')
    ax4.grid(True, alpha=0.3)

    # 5. GPU Memory Usage (needs one sample per logged step)
    if gpu_memory and len(gpu_memory) == n_points and any(mem > 0 for mem in gpu_memory):
        ax5 = fig.add_subplot(gs[1, 2])
        ax5.plot(steps, gpu_memory, 'r-', linewidth=2)
        ax5.set_xlabel('Training Step', fontsize=12)
        ax5.set_ylabel('GPU Memory (GB)', fontsize=12)
        ax5.set_title('GPU Memory Usage', fontsize=14, fontweight='bold')
        ax5.grid(True, alpha=0.3)

        # Add average line
        avg_memory = np.mean(gpu_memory)
        ax5.axhline(y=avg_memory, color='orange', linestyle='--',
                   label=f'Avg: {avg_memory:.1f}GB', linewidth=2)
        ax5.legend()

    # 6. Loss vs Learning Rate (if both available and aligned with the steps used for color)
    if learning_rates and len(learning_rates) == len(losses) == n_points:
        ax6 = fig.add_subplot(gs[2, 0])
        scatter = ax6.scatter(learning_rates, losses,
                            c=steps, cmap='viridis', alpha=0.6, s=20)
        ax6.set_xlabel('Learning Rate', fontsize=12)
        ax6.set_ylabel('Loss', fontsize=12)
        ax6.set_title('Loss vs Learning Rate', fontsize=14, fontweight='bold')
        ax6.set_xscale('log')
        ax6.set_yscale('log')
        ax6.grid(True, alpha=0.3)
        plt.colorbar(scatter, ax=ax6, label='Training Step')

    # 7. Training Progress per Epoch
    if training_stats.get('epochs'):
        ax7 = fig.add_subplot(gs[2, 1])
        epochs = training_stats['epochs']
        unique_epochs = sorted(set(epochs))
        epoch_losses = []

        for epoch in unique_epochs:
            epoch_indices = [i for i, e in enumerate(epochs) if e == epoch]
            if epoch_indices:
                epoch_loss = np.mean([losses[i] for i in epoch_indices])
                epoch_losses.append(epoch_loss)

        if epoch_losses:
            bars = ax7.bar(unique_epochs, epoch_losses, alpha=0.7,
                          color=['#FF6B6B', '#4ECDC4', '#45B7D1'][:len(unique_epochs)])
            ax7.set_xlabel('Epoch', fontsize=12)
            ax7.set_ylabel('Average Loss', fontsize=12)
            ax7.set_title('Average Loss per Epoch', fontsize=14, fontweight='bold')
            ax7.grid(True, alpha=0.3)

            # Add value labels on bars
            for bar, val in zip(bars, epoch_losses):
                ax7.text(bar.get_x() + bar.get_width()/2, bar.get_height(),
                        f'{val:.3f}', ha='center', va='bottom', fontweight='bold')

    # 8. Loss Improvement Rate
    ax8 = fig.add_subplot(gs[2, 2])
    if len(losses) > 1:
        loss_diffs = np.diff(losses)
        ax8.plot(steps[1:], loss_diffs, 'orange', linewidth=1, alpha=0.7)
        ax8.axhline(y=0, color='red', linestyle='--', alpha=0.8)
        ax8.set_xlabel('Training Step', fontsize=12)
        ax8.set_ylabel('Loss Change per Step', fontsize=12)
        ax8.set_title('Loss Improvement Rate', fontsize=14, fontweight='bold')
        ax8.grid(True, alpha=0.3)

    # 9. Training Summary Text
    ax9 = fig.add_subplot(gs[3, :])
    ax9.axis('off')

    # Same zero guard on the initial loss as the other report functions use
    initial_loss = losses[0]
    loss_reduction = ((initial_loss - losses[-1]) / initial_loss * 100) if initial_loss > 0 else 0

    # Create summary text
    summary_text = f"""
TRAINING SUMMARY
• Dataset: Cosmopedia-100k (100,000 samples) • Model: SmolLM2-1.7B (trained from scratch)
• Total Steps: {steps[-1]:,} • Loss Reduction: {loss_reduction:.1f}%
• Initial Loss: {losses[0]:.4f} • Final Loss: {losses[-1]:.4f} • Min Loss: {min(losses):.4f}
• Training Configuration: FP16 Mixed Precision, Flash Attention, Cosine LR Schedule, AdamW Optimizer
• Optimization: Gradient Accumulation, Gradient Clipping, Multi-GPU Support
"""

    ax9.text(0.5, 0.5, summary_text, ha='center', va='center',
            fontsize=14, bbox=dict(boxstyle='round,pad=1', facecolor='lightblue', alpha=0.8),
            transform=ax9.transAxes, fontweight='bold')

    # Finalize the layout BEFORE saving so the PNG matches the displayed figure
    # (previously tight_layout ran after savefig and only affected the display).
    plt.tight_layout()

    # Save the comprehensive plot
    plt.savefig('comprehensive_training_analysis.png', dpi=300, bbox_inches='tight',
                facecolor='white', edgecolor='none')
    print("✅ Comprehensive visualization saved as 'comprehensive_training_analysis.png'")

    plt.show()

def generate_final_metrics_summary(training_stats):
    """Compute, print and return headline metrics plus a 0-100 training score.

    Args:
        training_stats: Dict with ``'steps'`` and ``'losses'`` lists and optional
            ``'learning_rates'`` / ``'gpu_memory'`` lists.

    Returns:
        dict: Loss metrics (always), learning-rate metrics and GPU-memory metrics
        (when those series are present). Float metrics are plain Python
        ``float`` values so the dict can be passed to ``json.dump``. Ratios whose
        denominator is zero are reported as ``0``. Returns ``{}`` when step or
        loss data is missing.
    """

    print("\n📋 FINAL METRICS SUMMARY")
    print("=" * 60)

    # Both series are indexed below, so both must be present and non-empty
    if not training_stats or not training_stats.get('steps') or not training_stats.get('losses'):
        print("❌ No training data available")
        return {}

    losses = training_stats['losses']
    # float() strips NumPy scalar types (e.g. np.float32), which json cannot serialize
    initial_loss = float(losses[0])
    final_loss = float(losses[-1])
    min_loss = float(min(losses))

    # Calculate all key metrics
    metrics = {
        'total_steps': training_stats['steps'][-1],
        'initial_loss': initial_loss,
        'final_loss': final_loss,
        'min_loss': min_loss,
        'max_loss': float(max(losses)),
        'avg_loss': float(np.mean(losses)),
        'loss_std': float(np.std(losses)),
        'loss_reduction_percent': ((initial_loss - final_loss) / initial_loss * 100) if initial_loss > 0 else 0,
        'convergence_ratio': min_loss / initial_loss if initial_loss > 0 else 0,
    }

    # Add learning rate metrics if available
    if training_stats.get('learning_rates'):
        learning_rates = training_stats['learning_rates']
        initial_lr = float(learning_rates[0])
        final_lr = float(learning_rates[-1])
        metrics.update({
            'initial_lr': initial_lr,
            'final_lr': final_lr,
            'max_lr': float(max(learning_rates)),
            # A warmup schedule can log an initial LR of exactly 0; report 0 instead
            # of dividing by it, matching the fallback used for the loss ratios.
            'lr_decay_ratio': final_lr / initial_lr if initial_lr != 0 else 0
        })

    # Add GPU memory metrics if available
    if training_stats.get('gpu_memory') and any(mem > 0 for mem in training_stats['gpu_memory']):
        gpu_memory = training_stats['gpu_memory']
        metrics.update({
            'avg_gpu_memory': float(np.mean(gpu_memory)),
            'peak_gpu_memory': float(max(gpu_memory)),
            'min_gpu_memory': float(min(gpu_memory))
        })

    # Print formatted metrics
    print("🔢 Key Performance Indicators:")
    print(f"   Loss Reduction: {metrics['loss_reduction_percent']:.2f}%")
    print(f"   Convergence Ratio: {metrics['convergence_ratio']:.4f}")
    print(f"   Training Stability (σ): {metrics['loss_std']:.4f}")
    print(f"   Final Performance: {metrics['final_loss']:.4f}")

    # Overall grade: up to 40 (loss reduction) + 30 (convergence) + 20 (stability) + 10 (final loss)
    score = 0
    if metrics['loss_reduction_percent'] > 50: score += 40
    elif metrics['loss_reduction_percent'] > 30: score += 30
    elif metrics['loss_reduction_percent'] > 10: score += 20

    if metrics['convergence_ratio'] < 0.3: score += 30
    elif metrics['convergence_ratio'] < 0.5: score += 20
    elif metrics['convergence_ratio'] < 0.7: score += 10

    if metrics['loss_std'] < 0.5: score += 20
    elif metrics['loss_std'] < 1.0: score += 15
    elif metrics['loss_std'] < 2.0: score += 10

    if metrics['final_loss'] < 2.0: score += 10
    elif metrics['final_loss'] < 3.0: score += 5

    if score >= 90:
        grade = "A+ 🏆 Exceptional Training"
    elif score >= 80:
        grade = "A 🥇 Excellent Training"
    elif score >= 70:
        grade = "B+ 🥈 Very Good Training"
    elif score >= 60:
        grade = "B 🥉 Good Training"
    elif score >= 50:
        grade = "C+ ⭐ Acceptable Training"
    else:
        grade = "C ⚠️ Needs Improvement"

    print(f"\n🏆 OVERALL TRAINING GRADE: {grade}")
    print(f"📊 Training Score: {score}/100")

    return metrics

# Driver for the final-analysis cell: run the report, figure and metrics export
# when logged training data exists; otherwise describe what the cell would produce.
if not (globals().get('training_stats') or {}).get('steps'):
    print("\n".join([
        "❌ No training statistics available for comprehensive analysis.",
        "Please run the training loop (Cell 7) first to generate training data.",
        # Overview of the outputs this cell generates
        "\n🔧 This cell would generate:",
        "📊 Comprehensive training report with detailed statistics",
        "📈 Advanced visualization suite with multiple plots",
        "🏆 Final performance grading and assessment",
        "💾 Complete metrics export for future reference",
        "🎯 Professional training summary for presentations",
        # Illustrative sample, not real results
        "\n📝 Example output:",
        "🏆 OVERALL TRAINING GRADE: A 🥇 Excellent Training",
        "📊 Training Score: 85/100",
        "🔢 Loss Reduction: 65.3%",
        "⚖️ Training Stability: Excellent",
        "🎯 Model Performance: High Quality",
    ]))
else:
    print("🎯 Training data found! Generating comprehensive analysis...")

    # Model and tokenizer are optional inputs for the report
    analysis_model = globals().get('model')
    analysis_tokenizer = globals().get('tokenizer')

    # Text report, then the figure, then the numeric summary
    create_comprehensive_training_report(training_stats, analysis_model, analysis_tokenizer)
    create_final_visualization_suite(training_stats)
    final_metrics = generate_final_metrics_summary(training_stats)

    # Export the metrics; a failed write is reported but does not stop the cell
    try:
        import json
        with open('final_training_metrics.json', 'w') as f:
            json.dump(final_metrics, f, indent=2)
        print("\n💾 Final metrics saved to 'final_training_metrics.json'")
    except Exception as e:
        print(f"❌ Error saving final metrics: {e}")

    banner = "=" * 80
    print("\n".join([
        "\n" + banner,
        "🎉 COMPREHENSIVE ANALYSIS COMPLETE!",
        "🎯 SmolLM2-1.7B has been successfully trained from scratch",
        "📊 All visualizations and metrics have been generated",
        "💾 Model and artifacts are ready for deployment",
        banner,
    ]))