In [None]:
# Cell 1: Environment Setup and Imports
import subprocess
import sys

def install_package(package):
    """Install ``package`` with pip unless it can already be imported.

    Args:
        package: A pip requirement string, e.g. ``"numpy"`` or
            ``"flash-attn>=2.3.0"``. Version specifiers, extras and
            environment markers are accepted and forwarded to pip unchanged.

    Raises:
        subprocess.CalledProcessError: If the ``pip install`` command fails.
    """
    # Function-scope imports keep this helper usable on its own.
    import importlib
    import re

    requirement = package.strip()

    # Everything before the first specifier / extras / marker character is
    # the distribution name ("transformers>=4.36.0" -> "transformers").
    dist_name = re.split(r"[<>=!~;\[\s]", requirement, maxsplit=1)[0]

    # Module names cannot contain "-", so "flash-attn" is probed as
    # "flash_attn". NOTE(review): distributions whose import name differs in
    # other ways (e.g. scikit-learn -> sklearn) will still be reinstalled.
    module_name = dist_name.replace("-", "_")

    try:
        importlib.import_module(module_name)
    except ImportError:
        print(f"Installing {requirement}...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", requirement])

# Required packages
# Each entry is a pip requirement string; entries containing ">=" carry a
# minimum-version pin.
required_packages = [
    "transformers>=4.36.0",
    "datasets>=2.14.0", 
    "torch>=2.0.0",
    "accelerate>=0.24.0",
    "flash-attn>=2.3.0",
    "wandb",
    "matplotlib",
    "seaborn",
    "numpy",
    "tqdm"
]

# Install packages if needed (uncomment if running first time)
# NOTE: the loop below passes only the text before ">=" to install_package,
# so the minimum-version pins listed above are not forwarded to pip.
# for package in required_packages:
#     install_package(package.split(">=")[0])

# Core imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.cuda.amp import GradScaler, autocast

# Transformers and datasets
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    AutoConfig,
    get_linear_schedule_with_warmup,
    get_cosine_schedule_with_warmup,
    TrainingArguments,
    set_seed
)
from datasets import load_dataset

# Utilities
import os
import json
import time
import random
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
# Installs a global "ignore" filter: every warning category is suppressed
# from this point on, including deprecation notices from the libraries above.
warnings.filterwarnings("ignore")

print("✅ Environment setup complete") 

In [None]:
# Cell 2: GPU Check and Random Seed Setup
def setup_device_and_seed(seed=42):
    """Seed every RNG in use, force deterministic cuDNN, and pick a device.

    Args:
        seed: Integer seed applied to transformers, ``random``, NumPy and
            PyTorch (CPU and all CUDA devices).

    Returns:
        tuple: ``(device, num_gpus)`` where ``device`` is ``cuda`` when CUDA
        is available and ``cpu`` otherwise, and ``num_gpus`` is the CUDA
        device count (0 on CPU).
    """
    # transformers' helper first, then each library's own seeding call.
    set_seed(seed)
    for apply_seed in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        apply_seed(seed)

    # Trade cuDNN autotuning for repeatable kernel selection.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

    # CPU fallback is handled up front; the CUDA case is the main path.
    if not torch.cuda.is_available():
        return torch.device("cpu"), 0

    cuda_device = torch.device("cuda")
    return cuda_device, torch.cuda.device_count()

# Seed all RNGs and detect the compute device in one call
SEED = 42
device, num_gpus = setup_device_and_seed(seed=SEED)

# True only when more than one CUDA device is visible
MULTI_GPU = num_gpus > 1

# Hand cached CUDA blocks back to the driver before anything is allocated
if torch.cuda.is_available():
    torch.cuda.empty_cache()

print(f"✅ Device setup complete - Using {device} with {num_gpus} GPUs") 

In [None]:
# Cell 3: Load and Explore Dataset (Cosmopedia 100k)
def load_and_explore_dataset():
    """Fetch the train split of the Cosmopedia 100k dataset.

    Returns:
        The loaded dataset, or ``None`` when loading raises for any reason
        (the error is printed rather than propagated).
    """
    try:
        return load_dataset("HuggingFaceTB/cosmopedia-100k", split="train")
    except Exception as load_error:
        print(f"❌ Error loading dataset: {load_error}")

    # Only reached when the load above failed.
    return None

# Load the dataset
dataset = load_and_explore_dataset()

# Store dataset info for later use.
# Compare against None explicitly: a successfully loaded but empty dataset has
# len() == 0 and is therefore falsy, which would be misreported as a load
# failure. This also matches the `is None` check used during preprocessing.
if dataset is not None:
    DATASET_SIZE = len(dataset)
    print(f"✅ Dataset loaded - {DATASET_SIZE:,} samples")
else:
    print("❌ Failed to load dataset. Please check your connection and try again.") 

In [None]:
# Cell 4: Tokenizer - Existing Tokenizer
def load_smollm2_tokenizer():
    """Load the pretrained SmolLM2-1.7B tokenizer.

    Returns:
        The tokenizer, with its EOS token reused as the pad token when no pad
        token is defined, or ``None`` if loading raises (the error is
        printed rather than propagated).
    """
    checkpoint_name = "HuggingFaceTB/SmolLM2-1.7B"

    try:
        smollm_tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)

        # Fixed-length padding needs a pad token; fall back to EOS.
        if smollm_tokenizer.pad_token is None:
            smollm_tokenizer.pad_token = smollm_tokenizer.eos_token

        return smollm_tokenizer

    except Exception as load_error:
        print(f"❌ Error loading tokenizer: {load_error}")

    return None

# Fetch the tokenizer once; later cells read the module-level `tokenizer`
tokenizer = load_smollm2_tokenizer()

if not tokenizer:
    print("❌ Failed to load tokenizer. Please check your connection and model ID.") 
else:
    # Default sequence length, defined only when a tokenizer is available
    MAX_SEQ_LENGTH = 1024
    print("✅ Tokenizer loaded successfully")

In [None]:
# Cell 5: Dataset Preprocessing - ULTRA OPTIMIZED (Minimal Memory)
import os
# Sets TOKENIZERS_PARALLELISM=false for this process before any tokenization
# happens in this cell.
# NOTE(review): presumably meant to silence the Hugging Face tokenizers
# fork/parallelism warning — confirm it is still needed with num_workers=0.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

class CosmopediaDataset(Dataset):
    """Map-style dataset that tokenizes one text sample per lookup.

    Nothing is pre-tokenized or cached: each ``__getitem__`` call reads the
    raw ``'text'`` field and tokenizes it to a fixed length.

    Args:
        dataset: Indexable collection whose items expose a ``'text'`` entry.
        tokenizer: Callable accepting ``truncation``, ``padding``,
            ``max_length`` and ``return_tensors`` and returning a mapping with
            ``'input_ids'`` and ``'attention_mask'``.
            (assumes both come back with a leading batch axis of size 1 —
            TODO confirm for non-Hugging-Face tokenizers)
        max_length: Fixed token length every sample is truncated/padded to.
    """

    # Label value skipped by the cross-entropy loss (PyTorch's default
    # ``ignore_index``), used to mask padded positions.
    IGNORE_INDEX = -100

    def __init__(self, dataset, tokenizer, max_length=256):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Tokenize sample ``idx`` and return model-ready tensors.

        Returns:
            dict with ``'input_ids'``, ``'attention_mask'`` and ``'labels'``.
            ``labels`` equals ``input_ids`` except at padded positions
            (``attention_mask == 0``), which are set to ``IGNORE_INDEX``.
        """
        # Get text from dataset
        text = self.dataset[idx]['text']

        # Tokenize with aggressive truncation
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        # Drop only the batch axis. A bare squeeze() would also remove the
        # sequence axis when max_length == 1 and yield a 0-d tensor.
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)

        # For causal language modeling the labels mirror the inputs, but the
        # padded tail must not contribute to the loss: otherwise the model is
        # trained to predict pad tokens.
        labels = input_ids.clone()
        labels[attention_mask == 0] = self.IGNORE_INDEX

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }

def preprocess_dataset_ultra_optimized(dataset, tokenizer, max_length=256, batch_size=1):
    """Wrap the raw dataset in a ``CosmopediaDataset`` and a shuffling DataLoader.

    Args:
        dataset: Raw dataset to wrap; ``None`` aborts.
        tokenizer: Tokenizer handed to ``CosmopediaDataset``; ``None`` aborts.
        max_length: Fixed token length per sample.
        batch_size: Samples per DataLoader batch.

    Returns:
        tuple: ``(torch_dataset, dataloader)``, or ``(None, None)`` when
        either input is missing.
    """
    if any(component is None for component in (dataset, tokenizer)):
        print("❌ Dataset or tokenizer is None. Please load them first.")
        return None, None

    # Single-process loading, no pinned memory, incomplete final batch dropped.
    loader_options = dict(
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=False,
        drop_last=True,
        persistent_workers=False,
    )

    wrapped_dataset = CosmopediaDataset(dataset, tokenizer, max_length)
    batch_loader = DataLoader(wrapped_dataset, **loader_options)
    return wrapped_dataset, batch_loader

# Memory-constrained training configuration:
# one sample per micro-batch, gradients accumulated over 64 micro-batches,
# sequences capped at 256 tokens.
BATCH_SIZE = 1
GRADIENT_ACCUMULATION_STEPS = 64
MAX_SEQ_LENGTH_ULTRA = 256

# Samples contributing to one optimizer update (scaled by GPU count, min 1)
EFFECTIVE_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS * max(1, num_gpus)

# Release cached CUDA memory and wait for pending work before building loaders
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

# Build the tokenizing dataset and its DataLoader
train_dataset, train_dataloader = preprocess_dataset_ultra_optimized(
    dataset,
    tokenizer,
    max_length=MAX_SEQ_LENGTH_ULTRA,
    batch_size=BATCH_SIZE,
)

if not train_dataloader:
    print("❌ Failed to create DataLoader. Please check your setup.") 
else:
    # Both counters are in micro-batches (DataLoader iterations), not
    # optimizer updates; TOTAL_STEPS covers 3 epochs.
    STEPS_PER_EPOCH = len(train_dataloader)
    TOTAL_STEPS = STEPS_PER_EPOCH * 3

    print(f"✅ Dataset preprocessing complete - {STEPS_PER_EPOCH:,} batches")

In [None]:
# Cell 6: Model Configuration and Initialization (NO MIXED PRECISION VERSION)
import os
# Requests the "expandable_segments" mode of PyTorch's CUDA caching allocator
# through its environment variable.
# NOTE(review): presumably this must be set before the allocator is first
# used; Cell 5 already calls torch.cuda.synchronize(), so confirm the setting
# still takes effect at this point in the notebook.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

def setup_model_and_training_components_no_amp():
    """Build a randomly initialised SmolLM2 model, AdamW optimizer and LR schedule (FP32).

    Only the SmolLM2-1.7B configuration is downloaded; the weights are
    freshly initialised. Reads the module-level ``device``, ``TOTAL_STEPS``
    and ``GRADIENT_ACCUMULATION_STEPS``.

    Returns:
        tuple: ``(model, optimizer, scheduler, scaler, config)``. ``scaler``
        is always ``None`` (no mixed precision). If the config cannot be
        loaded or the model cannot be built, five ``None`` values are
        returned and the error is printed.
    """

    # Clear GPU cache before model initialization
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Load SmolLM2 configuration
    model_id = "HuggingFaceTB/SmolLM2-1.7B"

    try:
        # Load config only (not weights)
        config = AutoConfig.from_pretrained(model_id)

        # MEMORY OPTIMIZATION: no KV cache during training, checkpoint activations
        config.use_cache = False
        config.gradient_checkpointing = True

        # Best-effort request for Flash Attention 2. A failure here is
        # reported instead of being silently swallowed.
        # NOTE(review): this only sets a flag on the config; whether the
        # model builder honours it (and whether FA2 works in FP32) should be
        # confirmed against the installed transformers version.
        try:
            config.use_flash_attention_2 = True
        except Exception as e:
            print(f"⚠️ Could not request Flash Attention 2: {e}")

    except Exception as e:
        print(f"❌ Error loading config: {e}")
        return None, None, None, None, None

    # Initialize model with random weights
    try:
        model = AutoModelForCausalLM.from_config(config)
        model.gradient_checkpointing_enable()

    except Exception as e:
        print(f"❌ Error initializing model: {e}")
        return None, None, None, None, None

    # Move model to device
    model = model.to(device)

    # Setup optimizer (AdamW)
    learning_rate = 3e-5
    weight_decay = 0.01

    # No weight decay for biases and normalisation weights. Matching on names
    # alone misses norm layers that are not literally called "LayerNorm" /
    # "layer_norm", so every parameter with fewer than 2 dimensions
    # (biases, norm scales) is also excluded from decay.
    no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight"]
    decay_params = []
    no_decay_params = []
    for name, param in model.named_parameters():
        if param.ndim < 2 or any(nd in name for nd in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    optimizer_grouped_parameters = [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]

    optimizer = torch.optim.AdamW(
        optimizer_grouped_parameters,
        lr=learning_rate,
        betas=(0.9, 0.95),
        eps=1e-8,
        foreach=False
    )

    # Setup learning rate scheduler.
    # TOTAL_STEPS counts micro-batches, but the training loop calls
    # scheduler.step() only once per GRADIENT_ACCUMULATION_STEPS micro-batches.
    # The schedule must therefore be sized in optimizer updates; sizing it in
    # micro-batches would leave the LR stuck in the warmup ramp for the whole run.
    total_optimizer_steps = max(1, TOTAL_STEPS // GRADIENT_ACCUMULATION_STEPS)
    num_warmup_steps = int(0.1 * total_optimizer_steps)  # 10% warmup

    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=total_optimizer_steps
    )

    # NO MIXED PRECISION - Use None for scaler
    scaler = None

    # Final memory check
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return model, optimizer, scheduler, scaler, config

# Build the FP32 model together with its optimizer and LR scheduler
(
    model,
    optimizer,
    scheduler,
    scaler,
    model_config,
) = setup_model_and_training_components_no_amp()

if model is None:
    print("❌ Failed to initialize model. Please check your setup.")
else:
    total_params = sum(param.numel() for param in model.parameters())
    print(f"✅ Model initialized - {total_params/1e9:.2f}B parameters")

# Upper bound on the global gradient norm applied before each optimizer step
MAX_GRAD_NORM = 1.0 

In [None]:
# Cell 7: Ultra-Safe Training Loop with Robust Checkpointing
import glob
import os
import gc
from pathlib import Path
import shutil

# Allocator and tokenizer switches, applied in one update
os.environ.update({
    "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    "TOKENIZERS_PARALLELISM": "false",
})

# Running training history: one independent list per tracked metric
training_stats = {
    metric_name: []
    for metric_name in ('steps', 'losses', 'learning_rates', 'epochs', 'gpu_memory')
}

def aggressive_memory_cleanup():
    """Run Python garbage collection, then release cached CUDA memory.

    Garbage collection runs first so that tensors kept alive only by
    reference cycles are freed before the CUDA cache is emptied; otherwise
    their blocks would still be held when ``empty_cache()`` runs. On machines
    without CUDA only the garbage collection step is performed.
    """
    gc.collect()
    if torch.cuda.is_available():
        # Wait for queued kernels so their memory can actually be released.
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

def safe_checkpoint_save(model, optimizer, scheduler, step, epoch, loss, checkpoint_dir="./checkpoints"):
    """Write a training checkpoint via a verified temporary file.

    The checkpoint is first saved under a temporary name, read back to
    confirm it deserialises and carries the expected step, and only then
    moved onto its final name. The module-level ``training_stats`` is stored
    alongside the model, optimizer and scheduler state.

    Args:
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        scheduler: LR scheduler whose ``state_dict()`` is saved.
        step: Global step, also used in the file name.
        epoch: Epoch value recorded in the checkpoint.
        loss: Loss value recorded in the checkpoint.
        checkpoint_dir: Target directory; created (with parents) if missing.

    Returns:
        The final checkpoint path, or ``None`` if saving or verification
        failed (the error is printed and the temporary file removed).
    """

    # parents=True so nested checkpoint directories work as well.
    Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)
    aggressive_memory_cleanup()

    checkpoint = {
        'step': step,
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': loss,
        'training_stats': training_stats
    }

    checkpoint_path = f"{checkpoint_dir}/checkpoint_step_{step}.pt"
    temp_path = f"{checkpoint_dir}/temp_checkpoint_step_{step}.pt"

    try:
        torch.save(checkpoint, temp_path)

        # Read the file back to prove it is loadable, then drop the copy
        # immediately: it is a full duplicate of the training state.
        test_load = torch.load(temp_path, map_location='cpu')
        verified = 'step' in test_load and test_load['step'] == step
        del test_load

        if not verified:
            raise RuntimeError("Checkpoint verification failed")

        # Atomic rename within the same directory; also overwrites an
        # existing checkpoint for the same step on every platform.
        os.replace(temp_path, checkpoint_path)
    except Exception as e:
        # Report the reason instead of failing silently.
        print(f"❌ Error saving checkpoint at step {step}: {e}")
        if os.path.exists(temp_path):
            os.remove(temp_path)
        return None

    aggressive_memory_cleanup()
    return checkpoint_path

def load_checkpoint_safe(checkpoint_path, model, optimizer, scheduler):
    """Restore model, optimizer, scheduler and training history from a checkpoint.

    Args:
        checkpoint_path: Path of a file written by ``safe_checkpoint_save``.
        model: Module that receives the saved weights.
        optimizer: Optimizer that receives the saved state.
        scheduler: LR scheduler that receives the saved state.

    Returns:
        tuple: ``(start_step, start_epoch, best_loss)`` taken from the
        checkpoint, or ``(0, 0, float('inf'))`` if anything goes wrong (the
        error is printed rather than propagated).

    Side effects:
        Rebinds the module-level ``training_stats`` to the saved history when
        the checkpoint contains one.
    """
    global training_stats

    try:
        # Deserialise onto the CPU: load_state_dict copies values into the
        # existing (already placed) parameters and optimizer state, so there
        # is no need to stage a second full copy on the training device.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        # Validate every required entry before touching any state, so a
        # malformed file cannot leave the model restored but the optimizer not.
        required_keys = (
            'model_state_dict',
            'optimizer_state_dict',
            'scheduler_state_dict',
            'step',
            'epoch',
            'loss',
        )
        missing_keys = [key for key in required_keys if key not in checkpoint]
        if missing_keys:
            raise Exception(f"Invalid checkpoint format (missing: {', '.join(missing_keys)})")

        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        start_step = checkpoint['step']
        start_epoch = checkpoint['epoch']
        best_loss = checkpoint['loss']

        training_stats = checkpoint.get('training_stats', training_stats)

        print(f"📂 Checkpoint loaded - resuming from step {start_step}")
        return start_step, start_epoch, best_loss

    except Exception as e:
        print(f"❌ Error loading checkpoint: {e}")
        return 0, 0, float('inf')

def find_latest_checkpoint(checkpoint_dir="./checkpoints"):
    """Return the path of the newest loadable checkpoint, or ``None``.

    Candidates are files named ``checkpoint_step_<number>.pt`` inside
    ``checkpoint_dir``. They are tried from the highest step downwards and
    the first one that deserialises and contains a ``'step'`` entry wins, so
    older checkpoints are never loaded unnecessarily.

    Args:
        checkpoint_dir: Directory to search.

    Returns:
        Path of the newest valid checkpoint, or ``None`` when the directory
        does not exist or holds no valid checkpoint.

    Side effects:
        A candidate that fails to deserialise is deleted. Interrupts and
        out-of-memory errors are propagated and never cause a deletion.
    """
    import re

    if not os.path.exists(checkpoint_dir):
        return None

    # Files whose suffix is not a plain step number are ignored rather than
    # crashing the numeric sort.
    step_pattern = re.compile(r"^checkpoint_step_(\d+)\.pt$")
    candidates = []
    for cp in glob.glob(f"{checkpoint_dir}/checkpoint_step_*.pt"):
        match = step_pattern.match(os.path.basename(cp))
        if match:
            candidates.append((int(match.group(1)), cp))

    # Newest first.
    candidates.sort(key=lambda item: item[0], reverse=True)

    for _, cp in candidates:
        try:
            test = torch.load(cp, map_location='cpu')
            is_valid = 'step' in test
            del test  # a full checkpoint copy; release it right away
        except MemoryError:
            # Running out of RAM says nothing about the file's integrity.
            raise
        except Exception as e:
            # `except Exception` (not a bare except) so Ctrl-C during a long
            # load can never delete a healthy checkpoint.
            print(f"⚠️ Removing unreadable checkpoint {cp}: {e}")
            os.remove(cp)
            continue

        if is_valid:
            return cp

    return None

def train_model_ultra_safe():
    """Run the 3-epoch FP32 training loop with periodic checkpointing.

    Works on the module-level ``model``, ``optimizer``, ``scheduler``,
    ``train_dataloader`` and ``device``, and reads
    ``GRADIENT_ACCUMULATION_STEPS``, ``MAX_GRAD_NORM`` and ``TOTAL_STEPS``.
    If a checkpoint exists the user is asked (via ``input``) whether to
    resume from it.

    Step counters (``global_step``, ``TOTAL_STEPS``) are in micro-batches;
    the optimizer and scheduler advance once every
    ``GRADIENT_ACCUMULATION_STEPS`` micro-batches of the current epoch.

    Returns:
        The module-level ``training_stats`` dict, extended with one entry per
        logging interval.

    Raises:
        ValueError: If ``train_dataloader`` yields no batches.
        Exception: Any error raised during training is re-raised after an
            emergency checkpoint has been attempted. ``KeyboardInterrupt``
            is caught, checkpointed and not re-raised.
    """

    print("🚀 Starting training...")

    # Initial aggressive cleanup
    aggressive_memory_cleanup()

    # Check for existing checkpoints
    latest_checkpoint = find_latest_checkpoint()
    start_step = 0
    start_epoch = 0

    if latest_checkpoint:
        response = input(f"Found checkpoint: {latest_checkpoint}. Resume? (y/n): ")
        if response.lower() == 'y':
            start_step, start_epoch, _ = load_checkpoint_safe(
                latest_checkpoint, model, optimizer, scheduler
            )

    # Calculate starting position
    steps_per_epoch = len(train_dataloader)
    if steps_per_epoch == 0:
        raise ValueError("train_dataloader yields no batches; nothing to train on")

    # The step counter is the source of truth for the position. A periodic
    # checkpoint written on the very last step of an epoch records the old
    # epoch number; trusting it would replay that whole epoch on resume.
    start_epoch = max(start_epoch, start_step // steps_per_epoch)
    current_step_in_epoch = start_step % steps_per_epoch

    # Training loop
    model.train()
    global_step = start_step
    running_loss = 0.0
    log_interval = 50
    current_loss = 0.0
    # Bound up front so the emergency-checkpoint handlers can always use it.
    epoch = start_epoch

    try:
        for epoch in range(start_epoch, 3):
            print(f"\nEpoch {epoch + 1}/3")

            epoch_start_time = time.time()
            epoch_loss = 0.0
            epoch_steps = 0

            aggressive_memory_cleanup()

            # Only the resumed epoch starts part-way through; every later
            # epoch starts at batch 0. Resetting BEFORE the fast-forward
            # avoids iterating (and tokenizing) batches that would be
            # discarded anyway.
            if epoch > start_epoch:
                current_step_in_epoch = 0

            # Skip batches if resuming mid-epoch.
            # NOTE: the loader shuffles, so the skipped batches are not
            # necessarily the ones consumed before the interruption; this
            # only restores the position, not the exact data order.
            dataloader_iter = iter(train_dataloader)
            for _ in range(current_step_in_epoch):
                next(dataloader_iter)

            progress_bar = tqdm(
                dataloader_iter,
                total=steps_per_epoch - current_step_in_epoch,
                desc=f"Epoch {epoch+1}"
            )

            for step, batch in enumerate(progress_bar, start=current_step_in_epoch):
                # Move batch to device
                input_ids = batch['input_ids'].to(device, non_blocking=True)
                attention_mask = batch['attention_mask'].to(device, non_blocking=True)
                labels = batch['labels'].to(device, non_blocking=True)

                # Forward pass (FP32)
                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                    use_cache=False
                )
                # Scale so the accumulated gradient is an average over the window
                loss = outputs.loss
                loss = loss / GRADIENT_ACCUMULATION_STEPS

                # Backward pass
                loss.backward()
                # Undo the scaling for reporting
                current_loss = loss.item() * GRADIENT_ACCUMULATION_STEPS

                # Gradient accumulation and optimization
                if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()

                    if global_step % 10 == 0:
                        aggressive_memory_cleanup()

                # Update statistics
                running_loss += current_loss
                epoch_loss += current_loss
                epoch_steps += 1
                global_step += 1

                # Log progress
                if global_step % log_interval == 0:
                    avg_loss = running_loss / log_interval
                    current_lr = scheduler.get_last_lr()[0]

                    gpu_memory = 0
                    if torch.cuda.is_available():
                        gpu_memory = torch.cuda.memory_allocated() / 1e9

                    # Store statistics
                    training_stats['steps'].append(global_step)
                    training_stats['losses'].append(avg_loss)
                    training_stats['learning_rates'].append(current_lr)
                    training_stats['epochs'].append(epoch + 1)
                    training_stats['gpu_memory'].append(gpu_memory)

                    print(f"Step {global_step:,}/{TOTAL_STEPS:,} - Loss: {avg_loss:.4f}")
                    running_loss = 0.0

                # Checkpoint saving every 200 steps
                if global_step % 200 == 0:
                    checkpoint_path = safe_checkpoint_save(
                        model, optimizer, scheduler,
                        global_step, epoch, current_loss
                    )

                    if checkpoint_path:
                        # Keep only last 2 checkpoints
                        checkpoints = glob.glob("./checkpoints/checkpoint_step_*.pt")
                        if len(checkpoints) > 2:
                            checkpoints.sort(key=lambda x: int(x.split('_')[-1].split('.')[0]))
                            for old_checkpoint in checkpoints[:-2]:
                                os.remove(old_checkpoint)

                # Memory monitoring: clean up when more than 12 GB is allocated
                if global_step % 100 == 0:
                    if torch.cuda.is_available():
                        memory_used = torch.cuda.memory_allocated() / 1e9
                        if memory_used > 12:
                            aggressive_memory_cleanup()

                # Early stopping check
                if global_step >= TOTAL_STEPS:
                    break

            # Epoch summary
            epoch_time = time.time() - epoch_start_time
            avg_epoch_loss = epoch_loss / epoch_steps if epoch_steps > 0 else 0

            print(f"Epoch {epoch + 1} complete - Avg Loss: {avg_epoch_loss:.4f} - Time: {epoch_time/60:.1f}m")

            # Save end-of-epoch checkpoint (recorded as the NEXT epoch so a
            # resume starts there)
            safe_checkpoint_save(
                model, optimizer, scheduler,
                global_step, epoch + 1, avg_epoch_loss
            )

            if global_step >= TOTAL_STEPS:
                break

    except KeyboardInterrupt:
        print("Training interrupted by user")
        safe_checkpoint_save(model, optimizer, scheduler, global_step, epoch, current_loss)
    except Exception as e:
        print(f"Training error: {e}")
        safe_checkpoint_save(model, optimizer, scheduler, global_step, epoch, current_loss)
        raise

    print(f"Training completed! Total steps: {global_step:,}")
    return training_stats

# Collect the label of every training prerequisite that is either undefined
# at module level or still None
missing_components = [
    label
    for global_name, label in (
        ("model", "model"),
        ("optimizer", "optimizer"),
        ("scheduler", "scheduler"),
        ("train_dataloader", "dataloader"),
    )
    if globals().get(global_name) is None
]

if missing_components:
    print(f"❌ Missing components: {', '.join(missing_components)}")
    print("Please run previous cells first.") 
else:
    print("✅ All components ready for training")

    aggressive_memory_cleanup()

    # Failures are reported here instead of aborting the notebook cell
    try:
        training_stats = train_model_ultra_safe()
    except KeyboardInterrupt:
        print("Training cancelled by user")
    except Exception as e:
        print(f"Training failed: {e}")