# Wikipedia LLM Training Notebook

This notebook contains the complete pipeline for fine-tuning Phi-2 on Wikipedia articles with LoRA.

## Table of Contents
1. [Package Installation](#Package-Installation)
2. [GPU Check](#GPU-Check)
3. [Data Download](#Data-Download)
4. [Data Preparation](#Data-Preparation)
5. [Model Setup](#Model-Setup)
6. [LoRA Configuration](#LoRA-Configuration)
7. [Training Progress Check](#Training-Progress-Check)
8. [Model Training](#Model-Training)
9. [VRAM Cleanup](#VRAM-Cleanup)
10. [Model Testing](#Model-Testing)
11. [Chat Interface](#Chat-Interface)
12. [Training Restart](#Training-Restart)
13. [Hyperparameter Tuning](#Hyperparameter-Tuning)
14. [Model Evaluation](#Model-Evaluation)
15. [Data Analysis](#Data-Analysis)

## Package Installation

Installs all required Python packages for the training pipeline, including PyTorch, Transformers, PEFT, and other dependencies.

In [None]:
# This is for RTX 5060 Ti GPU (Uncomment to use in a Jupyter notebook)
# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
!pip install transformers datasets peft accelerate bitsandbytes trl tqdm protobuf scipy sentencepiece psutil matplotlib mlflow rouge-score nltk wordcloud seaborn pandas tensorboard

## GPU Check

Verifies CUDA availability and GPU information to ensure the system is ready for GPU-accelerated training.

In [None]:
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'N/A'}")

## Data Download

Downloads a subset of the Wikipedia dataset (100,000 articles) from Hugging Face and saves it locally for training.

In [None]:
from datasets import load_dataset
import os

os.makedirs("data", exist_ok=True)

print("Downloading Wikipedia dataset (this will take a while)...")
dataset = load_dataset(
    "wikimedia/wikipedia",
    "20231101.en",  # November 2023 snapshot
    split="train",
    streaming=False  # Set to True to stream instead of downloading everything (the subset code below assumes False)
)

# Take subset for testing (adjust as needed)
print("Creating subset...")
subset = dataset.select(range(min(100000, len(dataset))))

print("Saving dataset locally...")
subset.save_to_disk("data/wikipedia_100k")
print(f"Dataset saved! Total articles: {len(subset)}")
print(f"Sample article title: {subset[0]['title']}")
print(f"Sample text: {subset[0]['text'][:200]}")

## Data Preparation

Processes the raw Wikipedia data with text cleaning, quality filtering, and multiple instruction formats for better training data quality and diversity.

In [None]:
from datasets import load_from_disk
import random
import re

# Load saved dataset
dataset = load_from_disk("data/wikipedia_100k")

def clean_text(text):
    """Clean and preprocess text"""
    # Remove HTML tags
    text = re.sub(r'<[^>]+>', '', text)
    # Remove URLs
    text = re.sub(r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+', '', text)
    # Remove wiki markup
    text = re.sub(r'\[\[([^\]|]*\|)?([^\]]*)\]\]', r'\2', text)  # [[target|label]] -> label
    text = re.sub(r"''+", '', text)  # Remove italic/bold markup
    # Collapse whitespace last, so gaps left by the removals above are cleaned up too
    text = re.sub(r'\s+', ' ', text).strip()
    return text

def format_for_training(example):
    """Convert Wikipedia articles to multiple instruction formats"""
    text = clean_text(example['text'])
    title = example['title']
    
    # Skip very short articles
    if len(text) < 200:
        return {"text": ""}  # Will be filtered out
    
    # Create multiple instruction formats for variety
    formats = [
        f"### Instruction:\nProvide information about {title}.\n\n### Response:\n{text[:1000]}",
        f"### Instruction:\nExplain what {title} is.\n\n### Response:\n{text[:1000]}",
        f"### Instruction:\nTell me about {title}.\n\n### Response:\n{text[:1000]}",
        f"### Instruction:\nGive me details on {title}.\n\n### Response:\n{text[:1000]}"
    ]
    
    # Randomly select one format
    return {"text": random.choice(formats)}

# Format dataset
print("Formatting and cleaning dataset...")
formatted_dataset = dataset.map(
    format_for_training,
    remove_columns=dataset.column_names
)

# Filter out empty entries
formatted_dataset = formatted_dataset.filter(lambda x: len(x['text']) > 50)

# Split into train/validation
split_dataset = formatted_dataset.train_test_split(test_size=0.05, seed=42)

# Save formatted data
split_dataset.save_to_disk("data/formatted_wikipedia")
print("Data preparation complete!")
print(f"Training samples: {len(split_dataset['train'])}")
print(f"Validation samples: {len(split_dataset['test'])}")
print(f"Filtered out {len(dataset) - len(formatted_dataset)} short/invalid articles")
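
The cleaning step can be checked on a small example. The following is a minimal sketch rather than the notebook's exact function: `clean_sample` and the sample string are hypothetical, and only the HTML, wiki-link, quote-markup, and whitespace substitutions from `clean_text` are reproduced (the URL pattern is left out).

```python
import re

def clean_sample(text):
    """Apply a subset of the clean_text substitutions (illustrative only)."""
    text = re.sub(r'<[^>]+>', '', text)                          # drop HTML tags
    text = re.sub(r'\[\[([^\]|]*\|)?([^\]]*)\]\]', r'\2', text)  # [[target|label]] -> label
    text = re.sub(r"''+", '', text)                              # drop ''italic'' / '''bold''' quotes
    text = re.sub(r'\s+', ' ', text).strip()                     # collapse whitespace
    return text

# Hypothetical input mixing bold markup, piped and plain links, a tag, and extra spaces
sample = "'''Paris''' is the [[Capital city|capital]] of [[France]].<br/>  See   more."
cleaned = clean_sample(sample)
print(cleaned)
```

Running a handful of real article snippets through the function this way is a quick check that the patterns do not strip text that should be kept.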

## Model Setup

Checks system resources, estimates memory use, and loads the Phi-2 model with 4-bit quantization. Merging LoRA weights for faster inference is covered in the optional section that follows.

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import psutil

# Choose a base model (recommendations for your hardware)
MODEL_OPTIONS = {
    "small": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # 1.1B - easiest to train
    "medium": "microsoft/phi-2",  # 2.7B - good balance
    "large": "mistralai/Mistral-7B-v0.1"  # 7B - needs quantization
}

model_name = MODEL_OPTIONS["medium"]  # Start with Phi-2

# Quantization config for memory efficiency
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

def estimate_memory_usage(model_name, quantized=True):
    """Estimate memory usage for the model"""
    # Rough estimates (in GB)
    base_sizes = {
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0": 2.2,
        "microsoft/phi-2": 5.0,
        "mistralai/Mistral-7B-v0.1": 14.0
    }
    
    base_size = base_sizes.get(model_name, 5.0)
    if quantized:
        estimated_vram = base_size * 0.3  # 4-bit quantization reduces to ~30%
    else:
        estimated_vram = base_size
    
    return estimated_vram

def check_system_resources():
    """Check available system resources"""
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3 if torch.cuda.is_available() else 0
    cpu_memory = psutil.virtual_memory().total / 1024**3
    
    # These are total capacities, not the amount currently free
    print(f"Total GPU Memory: {gpu_memory:.1f} GB")
    print(f"Total CPU Memory: {cpu_memory:.1f} GB")
    return gpu_memory, cpu_memory

# Check resources before loading
print("Checking system resources...")
gpu_mem, cpu_mem = check_system_resources()

estimated_vram = estimate_memory_usage(model_name, quantized=True)
print(f"Estimated VRAM usage: {estimated_vram:.1f} GB")

if gpu_mem > 0 and estimated_vram > gpu_mem * 0.8:
    print("WARNING: Estimated VRAM usage is high. Consider using smaller model or more aggressive quantization.")
elif gpu_mem == 0:
    print("WARNING: No GPU detected. Training will be very slow on CPU.")

# Load model
print(f"\nLoading model: {model_name}")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

print("Model loaded successfully!")
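# bitsandbytes packs two 4-bit weights per stored byte, so numel() on a quantized
# layer reports half its true size; count those layers twice
param_count = sum(
    p.numel() * 2 if p.__class__.__name__ == "Params4bit" else p.numel()
    for p in model.parameters()
)
print(f"Model size: ~{param_count / 1e9:.2f}B parameters")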
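
The figures in `estimate_memory_usage` can be sanity-checked with simple arithmetic: weight memory is roughly the parameter count times the bytes per parameter. The sketch below is only an approximation under stated assumptions: `weight_memory_gib` is a hypothetical helper, 2.7e9 is a rounded parameter count for Phi-2, and activations, optimizer state, and quantization metadata are ignored.

```python
def weight_memory_gib(num_params, bits_per_param):
    """Memory needed to hold the weights alone, in GiB (no activations, no optimizer state)."""
    return num_params * bits_per_param / 8 / 1024**3

PHI2_PARAMS = 2.7e9  # assumed round figure for Phi-2

fp16_gib = weight_memory_gib(PHI2_PARAMS, 16)
nf4_gib = weight_memory_gib(PHI2_PARAMS, 4)
print(f"fp16 weights:  {fp16_gib:.2f} GiB")
print(f"4-bit weights: {nf4_gib:.2f} GiB ({nf4_gib / fp16_gib:.0%} of fp16)")
```

Raw 4-bit weights come to 25% of the fp16 size. The 0.3 factor used in the cell presumably leaves some headroom for layers that stay unquantized and for quantization constants.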

## Optional: Model Merging

This section provides functionality to merge LoRA weights back into the base model for faster inference after training.

### Usage Instructions:

1. Train your LoRA model using the training cells below.
2. After training completes, run the code cell below to merge the LoRA weights.
3. Provide the base model (loaded above), the path to your trained LoRA weights, and a save path for the merged model.
4. Example: `merged_model = merge_lora_weights(model, './wikipedia_model/final', './wikipedia_model/merged')`
5. The merged model will be saved and can be used for faster inference without LoRA adapters.

**Note:** This step is optional. If you prefer to keep the model in LoRA format for flexibility, you can skip this.

In [None]:
def merge_lora_weights(base_model, lora_path, save_path):
    """Merge LoRA weights back into base model for faster inference"""
    from peft import PeftModel
    
    print("Loading LoRA weights...")
    lora_model = PeftModel.from_pretrained(base_model, lora_path)
    
    print("Merging weights...")
    merged_model = lora_model.merge_and_unload()
    
    print(f"Saving merged model to {save_path}...")
    merged_model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)
    
    print("Model merging complete!")
    return merged_model

# Example usage (uncomment after training has produced ./wikipedia_model/final):
# merged_model = merge_lora_weights(model, "./wikipedia_model/final", "./wikipedia_model/merged")
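
What merging does to a single layer can be shown on a toy example. A standard LoRA adapter adds a low-rank update to a frozen weight, and merging folds it in as `W + (alpha / r) * B @ A`. The sketch below uses plain Python lists, and the helper names `matmul` and `merge_lora` are hypothetical. It illustrates the arithmetic only and is not what `merge_and_unload()` runs internally.

```python
def matmul(a, b):
    """Multiply two matrices given as nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def merge_lora(W, A, B, alpha, r):
    """Fold a LoRA update into the base weight: W + (alpha / r) * B @ A."""
    scale = alpha / r
    delta = matmul(B, A)  # (out x r) @ (r x in) -> (out x in)
    return [[w + scale * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(W, delta)]

W = [[1.0, 0.0], [0.0, 1.0]]  # frozen base weight (2 x 2)
A = [[1.0, 2.0]]              # LoRA A matrix (r=1, in=2)
B = [[0.5], [0.0]]            # LoRA B matrix (out=2, r=1)
x = [[1.0], [1.0]]            # input column vector

merged = merge_lora(W, A, B, alpha=2, r=1)
y_merged = matmul(merged, x)

# Same result computed the unmerged way: base output plus the scaled adapter output
base_out = matmul(W, x)
adapter_out = matmul(B, matmul(A, x))
y_adapter = [[b[0] + 2.0 * a[0]] for b, a in zip(base_out, adapter_out)]
print(y_merged, y_adapter)
```

Both paths give the same output. This is why a merged model needs no adapter at inference time and avoids the extra matrix multiplications.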

## LoRA Configuration

Configures Low-Rank Adaptation (LoRA): the rank is chosen from the model's parameter count, adapters are attached to the attention and MLP projection layers, and a dropout of 0.1 provides regularization.

In [None]:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import torch

# Dynamic LoRA rank based on model size
def get_dynamic_lora_rank(model):
    """Determine LoRA rank based on model parameter count"""
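    # 4-bit (bitsandbytes) weights are stored packed two per byte, so numel()
    # reports half their true size; compensate to count real parameters
    param_count = sum(
        p.numel() * 2 if p.__class__.__name__ == "Params4bit" else p.numel()
        for p in model.parameters()
    )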
    
    if param_count < 2e9:  # < 2B params
        rank = 8
    elif param_count < 7e9:  # 2-7B params
        rank = 16
    else:  # > 7B params
        rank = 32
    
    return rank

# Prepare model for training
model = prepare_model_for_kbit_training(model)

# Get dynamic rank
lora_rank = get_dynamic_lora_rank(model)
print(f"Using LoRA rank: {lora_rank} (based on model size)")

# Enhanced LoRA configuration with more target modules
lora_config = LoraConfig(
    r=lora_rank,  # Dynamic rank
    lora_alpha=lora_rank * 2,  # Typically 2x rank
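    target_modules=[
        "q_proj",   # Query projection
        "k_proj",   # Key projection
        "v_proj",   # Value projection
        "o_proj",   # Output projection (Llama/Mistral naming)
        "gate_proj",  # Gate projection (Llama/Mistral naming)
        "up_proj",    # Up projection (Llama/Mistral naming)
        "down_proj",  # Down projection (Llama/Mistral naming)
        # Assumption: the transformers Phi-2 implementation names its attention
        # output and MLP layers "dense", "fc1", "fc2"; names that match nothing
        # in the loaded model are skipped as long as at least one name matches
        "dense", "fc1", "fc2",
    ],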
    lora_dropout=0.1,  # Slightly higher dropout for better regularization
    bias="none",
    task_type="CAUSAL_LM"
)

# Apply LoRA
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# LoRA dropout scheduling (optional - requires custom training loop)
# This would need to be implemented in the training loop to decrease dropout over time
def get_scheduled_dropout(epoch, max_epochs, initial_dropout=0.1, final_dropout=0.01):
    """Linearly decrease dropout from initial to final over training"""
    return initial_dropout - (initial_dropout - final_dropout) * (epoch / max_epochs)

print(f"\nLoRA Configuration Summary:")
print(f"- Rank: {lora_rank}")
print(f"- Alpha: {lora_rank * 2}")
print(f"- Target Modules: {len(lora_config.target_modules)} modules")
print(f"- Dropout: {lora_config.lora_dropout}")
print("- Bias: none")
print("- Task: Causal LM")
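
The effect of the rank on trainable parameters can be estimated by hand: for one linear layer, LoRA adds an `A` matrix of shape `(r, in_features)` and a `B` matrix of shape `(out_features, r)`. The sketch below is a rough illustration with a hypothetical helper, `lora_param_count`, and an assumed hidden size of 2560 for Phi-2. The figure reported by `print_trainable_parameters()` depends on which modules actually match.

```python
def lora_param_count(in_features, out_features, rank):
    """Trainable parameters LoRA adds to one linear layer: A is (rank x in), B is (out x rank)."""
    return rank * (in_features + out_features)

HIDDEN = 2560  # assumed hidden size for Phi-2
rank = 16

added = lora_param_count(HIDDEN, HIDDEN, rank)
full = HIDDEN * HIDDEN
print(f"LoRA adds {added:,} parameters vs {full:,} in the frozen weight ({added / full:.2%})")
```

Doubling the rank doubles the added parameters, which is the trade-off behind the size-based thresholds in `get_dynamic_lora_rank`.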

## Training Progress Check

Analyzes existing checkpoints with loss curve visualization, training speed metrics, and live model performance preview through sample generation.

In [None]:
import os
import glob
import json
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch

def check_training_progress():
    """Check current training progress with enhanced metrics and visualization"""
    
    output_dir = "./wikipedia_model"
    
    # Find all checkpoints
    checkpoints = glob.glob(os.path.join(output_dir, "checkpoint-*"))
    
    if not checkpoints:
        print("No checkpoints found. Training hasn't started yet.")
        return
    
    # Sort by step number
    checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[-1]))
    
    print(f"Total checkpoints found: {len(checkpoints)}")
    print("\nCheckpoint history:")
    
    # Collect data for plotting
    steps = []
    losses = []
    learning_rates = []
    eval_losses = []
    
    for cp in checkpoints:
        step = cp.split("-")[-1]
        
        # Try to read trainer_state.json for more info
        state_file = os.path.join(cp, "trainer_state.json")
        if os.path.exists(state_file):
            with open(state_file, 'r') as f:
                state = json.load(f)
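                epoch = state.get('epoch')
                # Format only when a numeric epoch is present; ':.2f' fails on a string
                epoch_str = f"{epoch:.2f}" if isinstance(epoch, (int, float)) else "N/A"
                print(f"  - Step {step} (Epoch {epoch_str})")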
                
                # Collect training metrics
                log_history = state.get('log_history', [])
                training_entries = [entry for entry in log_history if 'loss' in entry]
                eval_entries = [entry for entry in log_history if 'eval_loss' in entry]
                
                if training_entries:
                    last_train = training_entries[-1]
                    steps.append(int(step))
                    losses.append(last_train.get('loss', 0))
                    learning_rates.append(last_train.get('learning_rate', 0))
                
                if eval_entries:
                    last_eval = eval_entries[-1]
                    eval_losses.append(last_eval.get('eval_loss', 0))
        else:
            print(f"  - Step {step}")
    
    # Latest checkpoint
    latest = checkpoints[-1]
    print(f"\nLatest checkpoint: {latest}")
    
    # Read detailed info from latest
    state_file = os.path.join(latest, "trainer_state.json")
    if os.path.exists(state_file):
        with open(state_file, 'r') as f:
            state = json.load(f)
            print(f"\nDetailed Progress:")
            print(f"  Current Step: {state.get('global_step', 'N/A')}")
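            latest_epoch = state.get('epoch')
            latest_epoch_str = f"{latest_epoch:.2f}" if isinstance(latest_epoch, (int, float)) else "N/A"
            print(f"  Current Epoch: {latest_epoch_str}")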
            
            # Get latest training metrics (find last entry with training loss)
            log_history = state.get('log_history', [])
            training_entries = [entry for entry in log_history if 'loss' in entry]
            eval_entries = [entry for entry in log_history if 'eval_loss' in entry]
            
            if training_entries:
                last_training = training_entries[-1]
                training_loss = last_training.get('loss', 'N/A')
                learning_rate = last_training.get('learning_rate', 'N/A')
            else:
                training_loss = 'N/A'
                learning_rate = 'N/A'
            
            if eval_entries:
                last_eval = eval_entries[-1]
                validation_loss = last_eval.get('eval_loss', 'N/A')
            else:
                validation_loss = 'N/A'
            
            print(f"  Training Loss: {training_loss}")
            print(f"  Validation Loss: {validation_loss}")
            print(f"  Learning Rate: {learning_rate}")
            
            # Calculate progress percentage
            max_steps = state.get('max_steps', None)
            if max_steps:
                current_step = state.get('global_step', 0)
                percentage = (current_step / max_steps) * 100
                print(f"  Progress: {percentage:.1f}% ({current_step}/{max_steps} steps)")
                
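                # Rough speed estimate from checkpoint directory modification times
                # (idle time between interrupted and resumed runs is included)
                if len(checkpoints) > 1:
                    elapsed_hours = (os.path.getmtime(checkpoints[-1]) - os.path.getmtime(checkpoints[0])) / 3600
                    step_span = int(checkpoints[-1].split("-")[-1]) - int(checkpoints[0].split("-")[-1])
                    if elapsed_hours > 0:
                        print(f"  Estimated Speed: ~{step_span / elapsed_hours:.1f} steps/hour")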
            else:
                print("  Progress: Unable to calculate (max_steps not available)")
            
            # Plot loss curves if we have data
            if len(losses) > 1:
                plt.figure(figsize=(12, 4))
                
                plt.subplot(1, 3, 1)
                plt.plot(steps, losses, 'b-', label='Training Loss')
                plt.xlabel('Steps')
                plt.ylabel('Loss')
                plt.title('Training Loss Curve')
                plt.legend()
                
                plt.subplot(1, 3, 2)
                plt.plot(steps, learning_rates, 'r-', label='Learning Rate')
                plt.xlabel('Steps')
                plt.ylabel('Learning Rate')
                plt.title('Learning Rate Schedule')
                plt.legend()
                
                plt.subplot(1, 3, 3)
                if eval_losses:
                    # Pair each eval loss with a checkpoint step; truncate to the shorter list
                    n = min(len(steps), len(eval_losses))
                    plt.plot(steps[:n], eval_losses[:n], 'g-', label='Validation Loss')
                    plt.legend()
                plt.xlabel('Steps')
                plt.ylabel('Loss')
                plt.title('Validation Loss')
                
                plt.tight_layout()
                plt.show()
            
            # Show recent training history (last 3 training entries with validation loss)
            if training_entries:
                print(f"\nRecent Training History (last {min(3, len(training_entries))} entries):")
                for entry in training_entries[-3:]:
                    step = entry.get('step', 'N/A')
                    loss = entry.get('loss', 'N/A')
                    lr = entry.get('learning_rate', 'N/A')
                    epoch = entry.get('epoch', 'N/A')
                    
                    # Find corresponding validation loss for this step
                    val_loss = 'N/A'
                    for eval_entry in eval_entries:
                        if eval_entry.get('step') == step:
                            val_loss = eval_entry.get('eval_loss', 'N/A')
                            break
                    
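                    # Format the epoch only when it is numeric
                    epoch_str = f"{epoch:.2f}" if isinstance(epoch, (int, float)) else epoch
                    print(f"    Step {step}: Train Loss={loss}, Val Loss={val_loss}, LR={lr}, Epoch={epoch_str}")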
            else:
                print("\nNo training history available yet.")
            
            # Generate sample output from latest checkpoint
            print(f"\nSample Generation from Latest Checkpoint:")
            try:
                # Load model for inference
                base_model = AutoModelForCausalLM.from_pretrained(
                    "microsoft/phi-2",
                    torch_dtype=torch.float16,
                    device_map="auto"
                )
                lora_model = PeftModel.from_pretrained(base_model, latest)
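                # Intermediate checkpoints hold no tokenizer files unless a tokenizer was
                # passed to the Trainer, so load the base model's tokenizer instead
                test_tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")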
                
                # Test prompt
                prompt = "### Instruction:\nProvide information about Artificial Intelligence.\n\n### Response:\n"
                inputs = test_tokenizer(prompt, return_tensors="pt").to(lora_model.device)
                
                with torch.no_grad():
                    outputs = lora_model.generate(
                        **inputs,
                        max_new_tokens=100,  # count generated tokens only, not the prompt
                        temperature=0.7,
                        do_sample=True,
                        pad_token_id=test_tokenizer.eos_token_id
                    )
                
                response = test_tokenizer.decode(outputs[0], skip_special_tokens=True)
                # Extract just the response part
                answer = response.split("### Response:\n")[-1][:200] + "..."
                print(f"AI: {answer}")
                
                # Clean up
                del base_model, lora_model
                torch.cuda.empty_cache()
                
            except Exception as e:
                print(f"Could not generate sample: {e}")
    
    print("\nTo resume training, re-run the Model Training cell; it picks up the latest checkpoint.")

if __name__ == "__main__":
    check_training_progress()
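
The structure the function reads from `trainer_state.json` is easier to see on a small example. The sketch below is illustrative only: the JSON excerpt uses made-up numbers, and `split_log_history` is a hypothetical helper, but the keys (`log_history`, `step`, `loss`, `eval_loss`, `learning_rate`, `epoch`) are the ones the cell above relies on.

```python
import json

# Hypothetical excerpt of a trainer_state.json file (values are invented)
state_json = """
{
  "global_step": 200,
  "epoch": 0.27,
  "log_history": [
    {"step": 100, "loss": 2.10, "learning_rate": 0.0002, "epoch": 0.13},
    {"step": 100, "eval_loss": 2.05, "epoch": 0.13},
    {"step": 200, "loss": 1.95, "learning_rate": 0.00019, "epoch": 0.27},
    {"step": 200, "eval_loss": 1.98, "epoch": 0.27}
  ]
}
"""

def split_log_history(state):
    """Return {step: train_loss} and {step: eval_loss} from a trainer state dict."""
    train, evals = {}, {}
    for entry in state.get("log_history", []):
        if "loss" in entry:
            train[entry["step"]] = entry["loss"]
        if "eval_loss" in entry:
            evals[entry["step"]] = entry["eval_loss"]
    return train, evals

train_losses, eval_losses = split_log_history(json.loads(state_json))
for step in sorted(train_losses):
    print(step, train_losses[step], eval_losses.get(step, "N/A"))
```

Keying both dictionaries by step keeps training and validation losses aligned even when evaluation runs less often than logging.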

## Model Training

Executes the main training loop with optimized settings including gradient checkpointing, mixed precision, and automatic checkpoint resumption.

In [None]:
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling
from datasets import load_from_disk
import os
import glob
import gc
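import time
import torch

def clear_vram():
    """Clear VRAM before training"""
    # Collect garbage first so freed tensors can be released from the CUDA cache
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print("VRAM cleared")

# Release cached GPU memory before training starts
clear_vram()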

# Load formatted data
dataset = load_from_disk("data/formatted_wikipedia")

# Tokenize function
def tokenize_function(examples):
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=512,  # Adjust based on your needs
        padding="max_length"
    )

tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=dataset["train"].column_names
)

# Data collator
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False  # Causal LM, not masked
)

# Training arguments
training_args = TrainingArguments(
    output_dir="./wikipedia_model",
    num_train_epochs=3,
    per_device_train_batch_size=8,     # Adjust based on GPU memory
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=16,     # Effective batch size = 128
    learning_rate=2e-4,
    fp16=True,                         # Mixed precision training
    save_steps=100,                    # Save every 100 steps
    eval_steps=100,                    # Evaluate every 100 steps
    logging_steps=100,
    eval_strategy="steps",
    save_total_limit=3,                # Keep only last 3 checkpoints (saves disk space)
    load_best_model_at_end=True,
    warmup_steps=100,
    lr_scheduler_type="cosine",
    optim="paged_adamw_8bit",          # Memory efficient optimizer
    # Resuming is handled below via trainer.train(resume_from_checkpoint=...)
    gradient_checkpointing=True,       # Enable gradient checkpointing for memory efficiency
    dataloader_num_workers=2,          # Use 2 workers for data loading
    logging_dir="./logs",              # Directory for TensorBoard logs
    report_to=["tensorboard"],         # Enable TensorBoard logging
    seed=42,                           # Set seed for reproducibility
    weight_decay=0.01,                 # Add weight decay for regularization
    max_grad_norm=1.0,                 # Gradient clipping
)

# Initialize trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    data_collator=data_collator,
)

# Check for existing checkpoints
def find_latest_checkpoint(output_dir):
    """Find the latest checkpoint in the output directory"""
    checkpoints = glob.glob(os.path.join(output_dir, "checkpoint-*"))
    if not checkpoints:
        return None
    # Sort by step number
    checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[-1]))
    latest = checkpoints[-1]
    return latest

# Look for existing checkpoint
checkpoint_path = find_latest_checkpoint("./wikipedia_model")

start_time = time.time()

if checkpoint_path:
    print(f"Found existing checkpoint: {checkpoint_path}")
    print("Resuming training from checkpoint...")
    print(f"Progress: Step {checkpoint_path.split('-')[-1]}")
    
    # Resume training from checkpoint
    try:
        trainer.train(resume_from_checkpoint=checkpoint_path)
    except KeyboardInterrupt:
        print("Training interrupted by user")
else:
    print("No existing checkpoint found. Starting fresh training...")
    # Start training from scratch
    try:
        trainer.train()
    except KeyboardInterrupt:
        print("Training interrupted by user")

end_time = time.time()
training_duration = end_time - start_time

# Put the model in eval mode explicitly (disables dropout for inference)
model.eval()

# Save final model
print("Saving final model...")
trainer.save_model("./wikipedia_model/final")
tokenizer.save_pretrained("./wikipedia_model/final")
# Also save the LoRA weights explicitly
model.save_pretrained("./wikipedia_model/final")

# Training summary
print("\n" + "="*50)
print("TRAINING SUMMARY")
print("="*50)
print(f"Total training time: {training_duration:.2f} seconds ({training_duration/3600:.2f} hours)")
print(f"Model saved to: ./wikipedia_model/final")
print(f"TensorBoard logs: ./logs (run 'tensorboard --logdir ./logs' to view)")
print(f"Checkpoints saved every {training_args.save_steps} steps")
print(f"Final step: {trainer.state.global_step}")
print(f"Final epoch: {trainer.state.epoch:.2f}")
print("Training complete!")
print("="*50)
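
The relationship between batch size, gradient accumulation, and the learning-rate schedule can be worked through by hand. The sketch below is an approximation, not the Trainer's internal code: `steps_per_epoch` and `cosine_lr` are hypothetical helpers, 95,000 is an assumed upper bound on the training-set size (95% of 100,000 articles before filtering), and the schedule is the usual linear warmup followed by cosine decay to zero.

```python
import math

def steps_per_epoch(num_samples, per_device_batch, grad_accum, num_devices=1):
    """Approximate optimizer steps per epoch for a given effective batch size."""
    effective_batch = per_device_batch * grad_accum * num_devices
    return math.ceil(num_samples / effective_batch)

def cosine_lr(step, total_steps, base_lr=2e-4, warmup_steps=100):
    """Linear warmup followed by cosine decay to zero."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

NUM_TRAIN_SAMPLES = 95_000  # assumed upper bound on the training-set size
per_epoch = steps_per_epoch(NUM_TRAIN_SAMPLES, per_device_batch=8, grad_accum=16)
total = per_epoch * 3  # num_train_epochs=3

print(f"Effective batch size: {8 * 16}")
print(f"Steps per epoch: ~{per_epoch}, total: ~{total}")
for step in (0, 50, 100, total // 2, total):
    print(f"step {step:>5}: lr = {cosine_lr(step, total):.2e}")
```

With roughly 2,200 optimizer steps in total, saving and evaluating every 100 steps yields about 22 checkpoints over the run, of which `save_total_limit=3` keeps only the most recent few.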

## VRAM Cleanup

Clears GPU memory cache and forces garbage collection to free up VRAM after training or testing.

In [None]:
import torch
import gc

def clear_vram():
    """Clear VRAM completely"""
    print("Clearing VRAM...")
    
    # Force garbage collection first so unreferenced tensors are actually freed
    gc.collect()
    
    # Then release PyTorch's cached blocks and report what is still in use
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print(f"VRAM allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
        print(f"VRAM reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
    
    print("VRAM cleared!")

if __name__ == "__main__":
    clear_vram()

## Model Testing

Loads the latest trained checkpoint, generates responses to a set of test prompts, and scores them with perplexity, BLEU, and ROUGE.

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
import numpy as np
from rouge_score import rouge_scorer
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import glob
import os
import json

def calculate_perplexity(model, tokenizer, text, max_length=512):
    """Calculate perplexity for a given text"""
    inputs = tokenizer(text, return_tensors="pt", max_length=max_length, truncation=True).to(model.device)
    
    with torch.no_grad():
        outputs = model(**inputs, labels=inputs["input_ids"])
        loss = outputs.loss
        perplexity = torch.exp(loss).item()
    
    return perplexity

def calculate_bleu_rouge(generated, reference):
    """Calculate BLEU and ROUGE scores"""
    # BLEU
    smoothing = SmoothingFunction().method4
    bleu = sentence_bleu([reference.split()], generated.split(), smoothing_function=smoothing)
    
    # ROUGE
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    rouge_scores = scorer.score(reference, generated)
    
    return bleu, rouge_scores

def load_checkpoint_model(checkpoint_path):
    """Load model from checkpoint"""
    base_model = AutoModelForCausalLM.from_pretrained(
        "microsoft/phi-2",
        torch_dtype=torch.float16,
        device_map="auto"
    )
    model = PeftModel.from_pretrained(base_model, checkpoint_path)
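    # Intermediate checkpoints hold no tokenizer files unless a tokenizer was
    # passed to the Trainer, so load the base model's tokenizer instead
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")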
    return model, tokenizer

def evaluate_model(model, tokenizer, test_prompts):
    """Comprehensive model evaluation"""
    results = {
        'perplexity': [],
        'bleu': [],
        'rouge1': [],
        'rouge2': [],
        'rougeL': [],
        'responses': []
    }
    
    for prompt_data in test_prompts:
        prompt = prompt_data['prompt']
        reference = prompt_data['reference']
        
        # Generate response
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=100,  # generate up to 100 tokens beyond the prompt
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id
            )
        
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        generated_answer = response.split("### Response:\n")[-1].strip()
        
        # Calculate metrics
        perplexity = calculate_perplexity(model, tokenizer, generated_answer)
        bleu, rouge_scores = calculate_bleu_rouge(generated_answer, reference)
        
        results['perplexity'].append(perplexity)
        results['bleu'].append(bleu)
        results['rouge1'].append(rouge_scores['rouge1'].fmeasure)
        results['rouge2'].append(rouge_scores['rouge2'].fmeasure)
        results['rougeL'].append(rouge_scores['rougeL'].fmeasure)
        results['responses'].append({
            'prompt': prompt,
            'generated': generated_answer,
            'reference': reference
        })
    
    # Calculate averages
    for key in ['perplexity', 'bleu', 'rouge1', 'rouge2', 'rougeL']:
        results[f'avg_{key}'] = np.mean(results[key])
    
    return results

# Test prompts with references for evaluation
test_prompts = [
    {
        'prompt': "### Instruction:\nProvide information about Python programming language.\n\n### Response:\n",
        'reference': "Python is a high-level programming language known for its simplicity and readability. It supports multiple programming paradigms including procedural, object-oriented, and functional programming."
    },
    {
        'prompt': "### Instruction:\nExplain what Artificial Intelligence is.\n\n### Response:\n",
        'reference': "Artificial Intelligence (AI) is the simulation of human intelligence processes by machines, especially computer systems. It includes learning, reasoning, and self-correction."
    },
    {
        'prompt': "### Instruction:\nTell me about Machine Learning.\n\n### Response:\n",
        'reference': "Machine Learning is a subset of AI that enables computers to learn and improve from experience without being explicitly programmed. It uses algorithms to identify patterns in data."
    },
    {
        'prompt': "### Instruction:\nWhat is the capital of France?\n\n### Response:\n",
        'reference': "The capital of France is Paris, which is located in the north-central part of the country along the Seine River."
    },
    {
        'prompt': "### Instruction:\nDescribe the process of photosynthesis.\n\n### Response:\n",
        'reference': "Photosynthesis is the process by which plants convert light energy into chemical energy. It involves chlorophyll absorbing sunlight and using it to convert carbon dioxide and water into glucose and oxygen."
    }
]

# Find all checkpoints for comparison
checkpoints = glob.glob("./wikipedia_model/checkpoint-*")
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[-1]))

if checkpoints:
    print("Available checkpoints:")
    for i, cp in enumerate(checkpoints):
        print(f"{i+1}. {cp}")
    
    # Test latest checkpoint
    latest_checkpoint = checkpoints[-1]
    print(f"\nTesting latest checkpoint: {latest_checkpoint}")
    
    model, tokenizer = load_checkpoint_model(latest_checkpoint)
    
    # Run evaluation
    results = evaluate_model(model, tokenizer, test_prompts)
    
    print("\n" + "="*60)
    print("MODEL EVALUATION RESULTS")
    print("="*60)
    print(f"Average Perplexity: {results['avg_perplexity']:.2f}")
    print(f"Average BLEU Score: {results['avg_bleu']:.4f}")
    print(f"Average ROUGE-1 F1: {results['avg_rouge1']:.4f}")
    print(f"Average ROUGE-2 F1: {results['avg_rouge2']:.4f}")
    print(f"Average ROUGE-L F1: {results['avg_rougeL']:.4f}")
    print("="*60)
    
    # Define markers for splitting
    instr_marker = '### Instruction:\n'
    resp_marker = '\n\n### Response:\n'
    
    # Show sample responses
    print("\nSample Responses:")
    for i, resp in enumerate(results['responses'][:3]):
        print(f"\n{i+1}. Prompt: {resp['prompt'].split(instr_marker)[1].split(resp_marker)[0]}")
        print(f"   Generated: {resp['generated'][:100]}...")
        print(f"   Reference: {resp['reference'][:100]}...")
    
    # Checkpoint comparison (if multiple checkpoints exist)
    if len(checkpoints) > 1:
        print(f"\nCheckpoint Comparison (last {min(3, len(checkpoints))} checkpoints):")
        comparison_results = []
        
        for cp in checkpoints[-3:]:
            try:
                model_comp, tokenizer_comp = load_checkpoint_model(cp)
                results_comp = evaluate_model(model_comp, tokenizer_comp, test_prompts[:2])  # Quick comparison
                step = cp.split("-")[-1]
                comparison_results.append({
                    'step': step,
                    'perplexity': results_comp['avg_perplexity'],
                    'bleu': results_comp['avg_bleu']
                })
                del model_comp
                torch.cuda.empty_cache()
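            except Exception as e:
                # Report the failure instead of silently dropping the checkpoint
                print(f"Skipping {cp}: {e}")
                continue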
        
        for res in comparison_results:
            print(f"Step {res['step']}: Perplexity={res['perplexity']:.2f}, BLEU={res['bleu']:.4f}")
    
    # Cleanup
    del model
    torch.cuda.empty_cache()
    
else:
    print("No checkpoints found. Please train the model first.")

print("\nModel testing complete!")
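
The checkpoint list above is sorted by the integer step suffix rather than by plain string order. A minimal sketch of why that matters, using made-up directory names (the paths and the `step_of` helper are illustrative assumptions, not part of the notebook):

```python
def step_of(path: str) -> int:
    """Return the trailing step number of a 'checkpoint-<step>' path."""
    return int(path.rsplit("-", 1)[-1])

# Hypothetical checkpoint directories, in arbitrary order
paths = [
    "./wikipedia_model/checkpoint-1000",
    "./wikipedia_model/checkpoint-500",
    "./wikipedia_model/checkpoint-1500",
]

lexicographic = sorted(paths)          # string order puts '500' after '1500'
numeric = sorted(paths, key=step_of)   # step order: 500 < 1000 < 1500

print(lexicographic[-1])  # not the latest step
print(numeric[-1])        # the latest step
```

With string ordering, the "latest" checkpoint would be chosen incorrectly as soon as step numbers differ in digit count, so the numeric key is the safer choice.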

## Chat Interface

Creates an interactive chat interface for real-time conversation with the trained Wikipedia model.

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
import re
import warnings
warnings.filterwarnings("ignore")

class OrganizedChatInterface:
    def __init__(self, model_path="./wikipedia_model/final"):
        print("Loading organized model and tokenizer...")
        try:
            # Load base model with simpler device mapping for speed
            base_model = AutoModelForCausalLM.from_pretrained(
                "microsoft/phi-2",
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map={"": 0} if torch.cuda.is_available() else {"": "cpu"},  # Simpler mapping
                trust_remote_code=True
            )

            # Load PEFT adapter
            self.model = PeftModel.from_pretrained(base_model, model_path)
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)

            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            print("Organized model loaded successfully!")

        except Exception as e:
            print(f"Error loading model: {e}")
            raise

        # Conversation history
        self.history = []
        self.max_history = 10

    def organize_response(self, response, query):
        """Organize and clean up the response to sound more coherent"""
        # Remove the original query from the response
        if response.lower().startswith(query.lower()):
            response = response[len(query):].strip()

        # Collapse every whitespace run (spaces, tabs, newlines) into a single space
        response = re.sub(r'\s+', ' ', response)

        # Split into sentences and clean up
        sentences = re.split(r'(?<=[.!?])\s+', response.strip())

        # Filter and organize sentences
        organized_sentences = []
        for sentence in sentences[:5]:  # Limit to 5 sentences
            sentence = sentence.strip()
            if len(sentence) > 10 and not sentence.startswith(('See also', 'References', 'External links', 'Category:')):
                # Capitalize first letter
                if sentence:
                    sentence = sentence[0].upper() + sentence[1:]
                organized_sentences.append(sentence)

        # Join with spaces; each sentence already keeps its end punctuation from the split
        organized_response = ' '.join(organized_sentences)
        if organized_response and not organized_response.endswith(('.', '!', '?')):
            organized_response += '.'

        return organized_response.strip()

    def generate_organized_response(self, query, max_length=300, temperature=0.8, top_p=0.9):
        """Generate organized, coherent response"""
        try:
            # Create a better prompt that works with Wikipedia training data
            # Guide the model toward encyclopedic explanations
            if query.lower().startswith(('what is', 'who is', 'where is', 'when', 'how', 'why')):
                # For questions, create a prompt that leads to explanations
                prompt = f"{query[0].upper()}{query[1:]} "
            elif query.lower().startswith(('tell me about', 'explain', 'describe')):
                # For explanation requests
                prompt = f"{query[0].upper()}{query[1:]} "
            else:
                # For general queries, assume it's asking for information
                prompt = f"{query} is "

            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=50)

            if torch.cuda.is_available():
                inputs = {k: v.cuda() for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_length,  # Budget for generated tokens only, excluding the prompt
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    repetition_penalty=1.2,  # Add some repetition penalty for coherence
                    no_repeat_ngram_size=3,  # Prevent 3-gram repetition
                    num_beams=1
                )

            full_response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Organize and clean the response
            organized_response = self.organize_response(full_response, prompt.strip())

            # If response is too short or empty, try a different approach
            if len(organized_response) < 20:
                # Fallback: try with "is" prefix
                fallback_prompt = f"{query} is"
                inputs2 = self.tokenizer(fallback_prompt, return_tensors="pt", truncation=True, max_length=50)
                if torch.cuda.is_available():
                    inputs2 = {k: v.cuda() for k, v in inputs2.items()}

                # Disable gradient tracking here too, as in the first generation call
                with torch.no_grad():
                    outputs2 = self.model.generate(
                        **inputs2,
                        max_new_tokens=max_length,
                        temperature=temperature,
                        top_p=top_p,
                        do_sample=True,
                        pad_token_id=self.tokenizer.eos_token_id,
                        eos_token_id=self.tokenizer.eos_token_id,
                        repetition_penalty=1.2,
                        no_repeat_ngram_size=3,
                        num_beams=1
                    )

                fallback_response = self.tokenizer.decode(outputs2[0], skip_special_tokens=True)
                organized_response = self.organize_response(fallback_response, fallback_prompt)

            # Add to history
            self.history.append((query, organized_response))
            if len(self.history) > self.max_history:
                self.history.pop(0)

            return organized_response if organized_response else f"I don't have enough information about '{query}' in my training data."

        except Exception as e:
            return f"Error generating response: {str(e)}"

    def interactive_chat(self):
        """Interactive chat with organized responses"""
        print("ORGANIZED WIKIPEDIA LLM CHAT")
        print("=" * 60)
        print("This model provides organized, coherent responses from Wikipedia knowledge.")
        print("Type 'quit' to exit, 'history' to view chat history, 'clear' to reset")
        print("=" * 60)

        while True:
            try:
                query = input("\nYou: ").strip()

                if query.lower() == 'quit':
                    print("Goodbye!")
                    break
                elif query.lower() == 'history':
                    print("\nChat History:")
                    for i, (q, r) in enumerate(self.history[-5:], 1):
                        print(f"{i}. Q: {q}")
                        print(f"   A: {r[:150]}...")
                    continue
                elif query.lower() == 'clear':
                    self.history.clear()
                    print("History cleared.")
                    continue

                if not query:
                    continue

                print("Assistant: ", end="", flush=True)
                response = self.generate_organized_response(query)
                print(response)

            except KeyboardInterrupt:
                print("\nInterrupted by user.")
                break
            except Exception as e:
                print(f"Error: {e}")

    def get_history(self):
        """Get conversation history"""
        return self.history

    def clear_history(self):
        """Clear conversation history"""
        self.history.clear()
        print("Conversation history cleared.")

# Create the organized interface
organized_chat = OrganizedChatInterface()

# Start interactive chat
organized_chat.interactive_chat()
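
The response clean-up step can be tried without loading a model. The sketch below is a standalone approximation of that step, not the class method itself: `tidy_response` and `SKIP_PREFIXES` are names invented for this example, and it joins the kept sentences with a single space because each sentence keeps its own end punctuation after the split.

```python
import re

# Section headers that often trail a Wikipedia-style generation
SKIP_PREFIXES = ("See also", "References", "External links", "Category:")

def tidy_response(text: str, max_sentences: int = 5) -> str:
    """Collapse whitespace, keep the first few real sentences, and rejoin them."""
    text = re.sub(r"\s+", " ", text).strip()
    # Split after '.', '!' or '?' so each sentence keeps its end punctuation
    sentences = re.split(r"(?<=[.!?])\s+", text)
    kept = []
    for sentence in sentences[:max_sentences]:
        # Drop fragments and leftover section headers
        if len(sentence) > 10 and not sentence.startswith(SKIP_PREFIXES):
            kept.append(sentence[0].upper() + sentence[1:])
    result = " ".join(kept)
    if result and not result.endswith((".", "!", "?")):
        result += "."
    return result

raw = "paris is the capital of France.\n\nIt lies on the Seine!  See also: Lyon. ok."
print(tidy_response(raw))
```

Testing the clean-up on fixed strings like this makes it easier to see whether odd chat output comes from the model or from the post-processing.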

## Training Restart

Deletes all training checkpoints and logs after a yes/no confirmation so that training can restart from scratch. Note that `./wikipedia_model` is removed in full, which includes any saved final adapter.

In [None]:
import os
import shutil

def restart_training():
    """Delete all training files to restart from scratch"""
    
    # Directories to delete
    dirs_to_delete = ["./wikipedia_model", "./logs"]
    
    for dir_path in dirs_to_delete:
        if os.path.exists(dir_path):
            print(f"Deleting {dir_path}...")
            shutil.rmtree(dir_path)
            print(f"Deleted {dir_path}")
        else:
            print(f"{dir_path} does not exist")
    
    print("\nAll training files deleted. You can now run the training cell to start fresh.")

if __name__ == "__main__":
    # Ask for confirmation
    confirm = input("This will delete all training checkpoints and logs. Are you sure? (yes/no): ")
    if confirm.lower() == 'yes':
        restart_training()
    else:
        print("Restart cancelled.")
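
Because the deletion is irreversible, it can help to preview what would be removed first. The following is a minimal dry-run sketch under stated assumptions: `dir_size_bytes` and `preview_deletion` are hypothetical helpers, and the demo runs on a temporary directory rather than the real training output.

```python
import os
import tempfile

def dir_size_bytes(path: str) -> int:
    """Sum the sizes of all regular files under path."""
    total = 0
    for root, _dirs, files in os.walk(path):
        for name in files:
            total += os.path.getsize(os.path.join(root, name))
    return total

def preview_deletion(dir_paths):
    """Return (path, size_in_bytes) for each directory that exists; delete nothing."""
    return [(p, dir_size_bytes(p)) for p in dir_paths if os.path.isdir(p)]

# Demo on a throwaway directory instead of the real training output
with tempfile.TemporaryDirectory() as tmp:
    fake_model_dir = os.path.join(tmp, "wikipedia_model")
    os.makedirs(os.path.join(fake_model_dir, "checkpoint-50"))
    with open(os.path.join(fake_model_dir, "checkpoint-50", "weights.bin"), "wb") as f:
        f.write(b"\0" * 1024)

    missing_dir = os.path.join(tmp, "logs")  # never created, so it is skipped
    preview = preview_deletion([fake_model_dir, missing_dir])
    for path, size in preview:
        print(f"Would delete {path} ({size} bytes)")
```

Printing such a preview before the confirmation prompt would show exactly which directories are about to be lost.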

## Hyperparameter Tuning

Runs a grid search over learning rate, per-device batch size, LoRA rank and alpha, and weight decay. Each combination is scored by its best evaluation loss on a short one-epoch trial that uses 1,000 training and 200 test examples.

In [None]:
import itertools
from transformers import TrainingArguments, Trainer
from datasets import load_from_disk
import torch
import os
import json
from tqdm import tqdm

# Load model and tokenizer
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

MODEL_OPTIONS = {
    "small": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # 1.1B - easiest to train
    "medium": "microsoft/phi-2",  # 2.7B - good balance
    "large": "mistralai/Mistral-7B-v0.1"  # 7B - needs quantization
}

model_name = MODEL_OPTIONS["medium"]  # Start with Phi-2

# Quantization config for memory efficiency
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

print(f"Loading model: {model_name}")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

print("Model and tokenizer loaded successfully!")

def objective(params):
    """Objective function for hyperparameter evaluation"""
    
    learning_rate, per_device_batch_size, lora_rank, lora_alpha, weight_decay = params
    
    # Load data
    dataset = load_from_disk("data/formatted_wikipedia")
    
    # Tokenize function
    def tokenize_function(examples):
        tokenized = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)
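        # Labels mirror input_ids; padding positions are set to -100 so the loss ignores them
        tokenized["labels"] = [
            [tok if keep else -100 for tok, keep in zip(ids, mask)]
            for ids, mask in zip(tokenized["input_ids"], tokenized["attention_mask"])
        ]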
        return tokenized
    
    # Tokenize datasets
    tokenized_train = dataset["train"].select(range(1000)).map(tokenize_function, batched=True)
    tokenized_test = dataset["test"].select(range(200)).map(tokenize_function, batched=True)
    
    # Set format for PyTorch
    tokenized_train.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
    tokenized_test.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
    
    # Calculate effective batch size
    gradient_accumulation_steps = max(1, 16 // per_device_batch_size)  # Target effective batch of 16
    
    # Training arguments
    training_args = TrainingArguments(
        output_dir="./grid_trials",
        num_train_epochs=1,  # Quick trial
        per_device_train_batch_size=per_device_batch_size,
        per_device_eval_batch_size=per_device_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        fp16=True,
        save_steps=50,
        eval_steps=50,
        logging_steps=50,
        eval_strategy="steps",
        save_total_limit=1,
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        seed=42,
        report_to=[],  # Disable integrations to avoid MLflow issues
        remove_unused_columns=False,  # Keep all columns
    )
    
    # Configure LoRA with parameters
    from peft import LoraConfig, get_peft_model
    lora_config = LoraConfig(
        r=lora_rank,
        lora_alpha=lora_alpha,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )
    
    # Wrap the base model loaded above with LoRA adapters (get_peft_model modifies it in place)
    model_with_lora = get_peft_model(model, lora_config)
    
    # Initialize trainer
    trainer = Trainer(
        model=model_with_lora,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_test,
    )
    
    # Train
    trainer.train()
    
    # Get best eval loss
    best_eval_loss = trainer.state.best_metric
    
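    # Cleanup: strip the injected LoRA layers so the next trial starts from the clean base model
    model_with_lora.unload()
    del model_with_lora, trainer
    torch.cuda.empty_cache()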
    
    return best_eval_loss, params

def run_hyperparameter_tuning():
    """Run grid search hyperparameter optimization"""
    print("Starting hyperparameter tuning with grid search...")
    print("This may take several hours depending on the number of combinations.")
    
    # Define parameter grid
    param_grid = {
        'learning_rate': [1e-4, 5e-4, 1e-3],
        'per_device_batch_size': [4, 8],
        'lora_rank': [8, 16],
        'lora_alpha': [16, 32],
        'weight_decay': [0.01, 0.05]
    }
    
    # Generate all combinations
    keys = param_grid.keys()
    values = param_grid.values()
    combinations = list(itertools.product(*values))
    
    print(f"Testing {len(combinations)} parameter combinations...")
    
    best_loss = float('inf')
    best_params = None
    
    for i, combo in enumerate(tqdm(combinations, desc="Hyperparameter Tuning", unit="trial")):
        params = dict(zip(keys, combo))
        print(f"\nTrial {i+1}/{len(combinations)}: {params}")
        
        try:
            loss, _ = objective(combo)
            print(f"Eval loss: {loss:.4f}")
            
            if loss < best_loss:
                best_loss = loss
                best_params = params
                print(f"New best loss: {best_loss:.4f} (Trial {i+1})")
                
        except Exception as e:
            print(f"Error in trial {i+1}: {e}")
            continue
    
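    # Stop here if every trial failed, since there is nothing to report or save
    if best_params is None:
        print("\nAll trials failed; no hyperparameters to report.")
        return None

    # Print results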
    print("\n" + "="*50)
    print("HYPERPARAMETER TUNING RESULTS")
    print("="*50)
    print(f"Best eval loss: {best_loss:.4f}")
    print("\nBest hyperparameters:")
    for key, value in best_params.items():
        print(f"  {key}: {value}")
    
    # Save best parameters
    with open("best_hyperparams.json", "w") as f:
        json.dump(best_params, f, indent=2)
    
    print(f"\nBest parameters saved to best_hyperparams.json")
    print("Use these parameters in your main training script.")
    
    return best_params

# Runs the tuning (takes significant time and resources); comment out the call below to skip it
if __name__ == "__main__":
    best_params = run_hyperparameter_tuning()
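
Before starting a run, it can be useful to check how large the grid is and what effective batch size each setting produces. This sketch reuses the grid values from the cell above; `accumulation_steps` is a helper name introduced here for illustration and mirrors the `max(1, 16 // per_device_batch_size)` rule.

```python
import itertools

# Same grid values as in the tuning cell
param_grid = {
    "learning_rate": [1e-4, 5e-4, 1e-3],
    "per_device_batch_size": [4, 8],
    "lora_rank": [8, 16],
    "lora_alpha": [16, 32],
    "weight_decay": [0.01, 0.05],
}

def accumulation_steps(per_device_batch_size: int, target: int = 16) -> int:
    """Gradient accumulation steps needed to approximate the target effective batch size."""
    return max(1, target // per_device_batch_size)

# One dict per trial, in the order itertools.product yields them
trials = [dict(zip(param_grid, combo)) for combo in itertools.product(*param_grid.values())]
print(f"{len(trials)} trials")

for batch_size in param_grid["per_device_batch_size"]:
    steps = accumulation_steps(batch_size)
    print(f"batch {batch_size} x accumulation {steps} = effective {batch_size * steps}")
```

The grid multiplies out to 3 x 2 x 2 x 2 x 2 = 48 trials, each a one-epoch run, so trimming one dimension halves the total cost.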

## Model Evaluation

Comprehensive model evaluation with multiple metrics including BLEU, ROUGE, perplexity, and custom benchmarks.
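
As a reminder of what the perplexity figure means, it is the exponential of the average per-token negative log-likelihood. The sketch below shows that definition in plain Python; the `perplexity` function and the example probabilities are illustrative assumptions, not the evaluator's own code.

```python
import math

def perplexity(token_log_probs):
    """Exponential of the mean negative log-likelihood over all tokens."""
    if not token_log_probs:
        raise ValueError("need at least one token log-probability")
    mean_nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(mean_nll)

# Hypothetical case: the model gives probability 0.25 to every true next token
uniform_over_four = [math.log(0.25)] * 8
print(perplexity(uniform_over_four))  # about 4: as uncertain as a 4-way guess
```

The evaluator in the cell below averages `exp(loss)` per batch, which can differ slightly from exponentiating the loss averaged over all tokens, so its numbers are best compared only against other runs of the same code.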

In [None]:
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from datasets import load_from_disk
import numpy as np
from rouge_score import rouge_scorer
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import json
from tqdm import tqdm
import matplotlib.pyplot as plt
import glob
import os
import traceback
from typing import List, Dict, Any, Optional, Union

class RobustModelEvaluator:
    """Robust and comprehensive model evaluation suite with extensive error handling"""

    def __init__(self, model_path: str = "./wikipedia_model/final", device: str = "auto"):
        """
        Initialize the evaluator with robust error handling.

        Args:
            model_path: Path to the model directory
            device: Device to use ('auto', 'cuda', 'cpu', or specific device number)
        """
        self.model_path = model_path
        self.device = device
        self.model = None
        self.tokenizer = None
        self.generator = None
        self.rouge_scorer = None
        self.smooth = None

        print(f"Initializing robust model evaluator for {model_path}...")

        try:
            self._load_model()
            self._setup_tokenizer()
            self._setup_pipeline()
            self._setup_scorers()
            print("✓ Model evaluator initialized successfully!")
        except Exception as e:
            print(f"✗ Failed to initialize evaluator: {e}")
            traceback.print_exc()
            raise

    def _load_model(self):
        """Load the model with comprehensive error handling"""
        print("Loading model...")

        try:
            # Determine device mapping
            if self.device == "auto":
                if torch.cuda.is_available():
                    device_map = {"": 0}  # Use first GPU
                    torch_dtype = torch.float16
                else:
                    device_map = {"": "cpu"}
                    torch_dtype = torch.float32
            elif self.device == "cpu":
                device_map = {"": "cpu"}
                torch_dtype = torch.float32
            else:
                # Specific GPU device
                device_num = int(self.device) if isinstance(self.device, str) and self.device.isdigit() else 0
                device_map = {"": device_num}
                torch_dtype = torch.float16

            # Load base model
            base_model = AutoModelForCausalLM.from_pretrained(
                "microsoft/phi-2",
                torch_dtype=torch_dtype,
                device_map=device_map,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )

            # Load PEFT adapter
            self.model = PeftModel.from_pretrained(base_model, self.model_path)

            print("✓ Model loaded successfully!")

        except Exception as e:
            print(f"✗ Error loading model: {e}")
            raise

    def _setup_tokenizer(self):
        """Setup tokenizer with proper configuration"""
        print("Setting up tokenizer...")

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path)

            # Ensure an EOS token exists first, since the pad token falls back to it
            if self.tokenizer.eos_token is None:
                self.tokenizer.eos_token = "<|endoftext|>"  # Phi-2's end-of-text token

            # Ensure pad token is set
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            print("✓ Tokenizer configured successfully!")

        except Exception as e:
            print(f"✗ Error setting up tokenizer: {e}")
            raise

    def _setup_pipeline(self):
        """Setup text generation pipeline with device handling"""
        print("Setting up text generation pipeline...")

        try:
            pipeline_kwargs = {
                "model": self.model,
                "tokenizer": self.tokenizer,
                "max_new_tokens": 100,
                "temperature": 0.7,
                "do_sample": True,
                "pad_token_id": self.tokenizer.eos_token_id,
                "eos_token_id": self.tokenizer.eos_token_id,
            }

            # Only specify device if not using accelerate device_map
            if not hasattr(self.model, 'hf_device_map') or self.model.hf_device_map is None:
                if torch.cuda.is_available():
                    pipeline_kwargs["device"] = 0
                else:
                    pipeline_kwargs["device"] = -1

            self.generator = pipeline("text-generation", **pipeline_kwargs)
            print("✓ Pipeline configured successfully!")

        except Exception as e:
            print(f"✗ Error setting up pipeline: {e}")
            raise

    def _setup_scorers(self):
        """Setup evaluation scorers"""
        try:
            self.rouge_scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
            self.smooth = SmoothingFunction().method4
            print("✓ Scorers configured successfully!")
        except Exception as e:
            print(f"✗ Error setting up scorers: {e}")
            raise

    def calculate_perplexity(self, texts: List[str], batch_size: int = 4) -> float:
        """
        Calculate perplexity with robust error handling.

        Args:
            texts: List of text strings to evaluate
            batch_size: Batch size for processing

        Returns:
            Mean of the per-batch perplexities (inf if no batch could be scored)
        """
        if not texts:
            print("Warning: No texts provided for perplexity calculation")
            return float('inf')

        perplexities = []
        valid_texts = []

        # Filter and validate texts
        for text in texts:
            if isinstance(text, str) and len(text.strip()) > 0:
                valid_texts.append(text.strip())
            else:
                print(f"Warning: Skipping invalid text: {type(text)}")

        if not valid_texts:
            print("Warning: No valid texts for perplexity calculation")
            return float('inf')

        print(f"Calculating perplexity for {len(valid_texts)} texts...")

        try:
            for i in range(0, len(valid_texts), batch_size):
                batch_texts = valid_texts[i:i+batch_size]

                try:
                    # Tokenize with error handling
                    encodings = self.tokenizer(
                        batch_texts,
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                        max_length=512  # Limit length to avoid memory issues
                    )

                    # Move to appropriate device
                    device = next(self.model.parameters()).device
                    input_ids = encodings.input_ids.to(device)
                    attention_mask = encodings.attention_mask.to(device)

                    with torch.no_grad():
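                        # Mask padding positions with -100 so they do not skew the loss
                        labels = input_ids.masked_fill(attention_mask == 0, -100)
                        outputs = self.model(
                            input_ids=input_ids,
                            attention_mask=attention_mask,
                            labels=labels
                        )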

                        if outputs.loss is not None:
                            loss = outputs.loss.item()
                            if not np.isinf(loss) and not np.isnan(loss):
                                perplexity = np.exp(loss)
                                if not np.isinf(perplexity) and not np.isnan(perplexity):
                                    perplexities.append(perplexity)

                except Exception as e:
                    print(f"Warning: Error processing batch {i//batch_size}: {e}")
                    continue

            if perplexities:
                avg_perplexity = np.mean(perplexities)
                print(f"Average perplexity: {avg_perplexity:.4f}")
                return avg_perplexity
            else:
                print("Warning: No valid perplexity calculations")
                return float('inf')

        except Exception as e:
            print(f"Error in perplexity calculation: {e}")
            return float('inf')

    def _safe_extract_generated_text(self, full_text: str, prompt: str) -> str:
        """
        Safely extract generated text from full pipeline output.

        Args:
            full_text: Complete text returned by pipeline
            prompt: Original prompt

        Returns:
            Extracted generated text
        """
        try:
            if not isinstance(full_text, str) or not isinstance(prompt, str):
                return ""

            # Remove prompt from the beginning if present
            if full_text.startswith(prompt):
                generated = full_text[len(prompt):].strip()
            else:
                generated = full_text.strip()

            # Clean up common artifacts
            generated = generated.split('\n\n')[0]  # Take only first paragraph
            generated = generated.split('\n')[0]    # Take only first line if multi-line

            return generated

        except Exception as e:
            print(f"Warning: Error extracting generated text: {e}")
            return ""

    def calculate_bleu_rouge(self, generated_texts: List[str], reference_texts: List[str],
                           prompts: Optional[List[str]] = None) -> Dict[str, float]:
        """
        Calculate BLEU and ROUGE scores with comprehensive error handling.

        Args:
            generated_texts: List of generated text strings
            reference_texts: List of reference text strings
            prompts: Optional list of prompts used for generation

        Returns:
            Dictionary with BLEU and ROUGE scores
        """
        if not generated_texts or not reference_texts:
            print("Warning: Empty text lists provided")
            return {'bleu': 0.0, 'rouge1': 0.0, 'rouge2': 0.0, 'rougeL': 0.0}

        bleu_scores = []
        rouge1_scores = []
        rouge2_scores = []
        rougeL_scores = []

        min_length = min(len(generated_texts), len(reference_texts))
        print(f"Calculating BLEU/ROUGE for {min_length} text pairs...")

        for i in range(min_length):
            try:
                gen_raw = generated_texts[i] if i < len(generated_texts) else ""
                ref_raw = reference_texts[i] if i < len(reference_texts) else ""
                prompt = prompts[i] if prompts and i < len(prompts) else ""

                # Extract clean generated text
                gen_clean = self._safe_extract_generated_text(gen_raw, prompt)

                # Validate inputs
                if not isinstance(gen_clean, str) or not isinstance(ref_raw, str):
                    print(f"Warning: Non-string inputs at index {i}")
                    continue

                gen_clean = gen_clean.strip()
                ref_clean = ref_raw.strip()

                if len(gen_clean) == 0 or len(ref_clean) == 0:
                    continue

                # BLEU calculation
                try:
                    gen_tokens = gen_clean.split()
                    ref_tokens = ref_clean.split()

                    if len(gen_tokens) > 0 and len(ref_tokens) > 0:
                        bleu = sentence_bleu([ref_tokens], gen_tokens, smoothing_function=self.smooth)
                        if not np.isnan(bleu):
                            bleu_scores.append(bleu)
                except Exception as e:
                    print(f"Warning: BLEU calculation failed for sample {i}: {e}")

                # ROUGE calculation
                try:
                    rouge_scores = self.rouge_scorer.score(ref_clean, gen_clean)
                    if hasattr(rouge_scores, 'get') and 'rouge1' in rouge_scores:
                        rouge1_scores.append(rouge_scores['rouge1'].fmeasure)
                        rouge2_scores.append(rouge_scores['rouge2'].fmeasure)
                        rougeL_scores.append(rouge_scores['rougeL'].fmeasure)
                except Exception as e:
                    print(f"Warning: ROUGE calculation failed for sample {i}: {e}")

            except Exception as e:
                print(f"Warning: General error processing sample {i}: {e}")
                continue

        # Calculate averages
        result = {
            'bleu': np.mean(bleu_scores) if bleu_scores else 0.0,
            'rouge1': np.mean(rouge1_scores) if rouge1_scores else 0.0,
            'rouge2': np.mean(rouge2_scores) if rouge2_scores else 0.0,
            'rougeL': np.mean(rougeL_scores) if rougeL_scores else 0.0
        }

        print(f"✓ Calculated scores - BLEU: {result['bleu']:.4f}, ROUGE-1: {result['rouge1']:.4f}")
        return result

    def generate_samples(self, prompts: List[str], num_samples: int = 10) -> List[Dict[str, Any]]:
        """
        Generate text samples with robust error handling.

        Args:
            prompts: List of prompt strings
            num_samples: Number of samples to generate

        Returns:
            List of sample dictionaries
        """
        if not prompts:
            print("Warning: No prompts provided")
            return []

        samples = []
        valid_prompts = [p for p in prompts if isinstance(p, str) and len(p.strip()) > 0][:num_samples]

        if not valid_prompts:
            print("Warning: No valid prompts found")
            return []

        print(f"Generating {len(valid_prompts)} text samples...")

        for i, prompt in enumerate(tqdm(valid_prompts, desc="Generating samples")):
            try:
                # Generate one sampled continuation (no timeout is applied; errors are caught below)
                outputs = self.generator(
                    prompt,
                    max_new_tokens=100,
                    num_return_sequences=1,
                    temperature=0.8,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                )

                # Extract generated text safely
                if isinstance(outputs, list) and len(outputs) > 0:
                    if isinstance(outputs[0], dict) and 'generated_text' in outputs[0]:
                        generated_text = outputs[0]['generated_text']
                    elif isinstance(outputs[0], str):
                        generated_text = outputs[0]
                    else:
                        generated_text = str(outputs[0])
                else:
                    generated_text = str(outputs)

                # Clean the generated text
                clean_generated = self._safe_extract_generated_text(generated_text, prompt)

                samples.append({
                    'prompt': prompt,
                    'generated': generated_text,  # Keep full text
                    'generated_clean': clean_generated,  # Clean version for evaluation
                    'length': len(clean_generated.split()) if clean_generated else 0
                })

            except Exception as e:
                print(f"Warning: Error generating for prompt {i}: {e}")
                samples.append({
                    'prompt': prompt,
                    'generated': "Error: Failed to generate text",
                    'generated_clean': "",
                    'length': 0
                })

        print(f"✓ Generated {len(samples)} samples successfully")
        return samples

    def benchmark_model(self, test_dataset=None, num_samples: int = 50) -> Optional[Dict[str, Any]]:
        """
        Run comprehensive benchmark with robust error handling.

        Args:
            test_dataset: Optional dataset to use for evaluation
            num_samples: Number of samples to evaluate

        Returns:
            Dictionary with evaluation results or None if failed
        """
        print("\n" + "="*70)
        print("STARTING COMPREHENSIVE MODEL EVALUATION")
        print("="*70)

        try:
            # Load test dataset
            if test_dataset is None:
                print("Loading test dataset...")
                try:
                    dataset = load_from_disk("data/formatted_wikipedia")
                    test_dataset = dataset["test"]
                    print("✓ Dataset loaded successfully")
                except Exception as e:
                    print(f"✗ Error loading dataset: {e}")
                    print("Please ensure the dataset is available at 'data/formatted_wikipedia'")
                    return None

            # Sample test data
            available_samples = len(test_dataset)
            actual_samples = min(num_samples, available_samples)
            print(f"Using {actual_samples} samples out of {available_samples} available")

            test_items = test_dataset.select(range(actual_samples))

            # Extract texts and create prompts
            test_texts = []
            test_prompts = []

            for item in test_items:
                try:
                    text = item['text']
                    if isinstance(text, str) and len(text.strip()) > 0:
                        test_texts.append(text.strip())
                        # Create prompt from first 200 characters
                        prompt = text[:200].strip()
                        if not prompt.endswith(('...', '.', '!', '?')):
                            prompt += "..."
                        test_prompts.append(prompt)
                except Exception as e:
                    print(f"Warning: Error processing dataset item: {e}")
                    continue

            if not test_texts or not test_prompts:
                print("✗ No valid test data found")
                return None

            # Calculate perplexity
            print("\nCalculating perplexity...")
            perplexity = self.calculate_perplexity(test_texts)

            # Generate samples
            print("\nGenerating samples for BLEU/ROUGE evaluation...")
            generated_samples = self.generate_samples(test_prompts, num_samples=len(test_prompts))

            if not generated_samples:
                print("✗ No samples generated")
                return None

            # Extract texts for evaluation
            generated_texts = [sample['generated'] for sample in generated_samples]
            reference_texts = test_texts[:len(generated_samples)]

            # Calculate BLEU and ROUGE
            print("\nCalculating BLEU and ROUGE scores...")
            text_metrics = self.calculate_bleu_rouge(generated_texts, reference_texts, test_prompts)

            # Compile results
            from datetime import datetime  # local import so the timestamp needs nothing from the cell's import block

            results = {
                # 999.0 is a sentinel for an undefined (inf or NaN) perplexity
                'perplexity': float(perplexity) if np.isfinite(perplexity) else 999.0,
                'bleu_score': float(text_metrics['bleu']),
                'rouge1_score': float(text_metrics['rouge1']),
                'rouge2_score': float(text_metrics['rouge2']),
                'rougeL_score': float(text_metrics['rougeL']),
                'samples': generated_samples[:5],  # Save first 5 samples
                'num_samples_evaluated': len(generated_samples),
                'evaluation_timestamp': datetime.now().isoformat(timespec='seconds')  # Wall-clock time of this run
            }

            # Print results
            print("\n" + "="*70)
            print("MODEL EVALUATION RESULTS")
            print("="*70)
            print(f"Perplexity: {results['perplexity']:.4f}")
            print(f"BLEU Score: {results['bleu_score']:.4f}")
            print(f"ROUGE-1 Score: {results['rouge1_score']:.4f}")
            print(f"ROUGE-2 Score: {results['rouge2_score']:.4f}")
            print(f"ROUGE-L Score: {results['rougeL_score']:.4f}")
            print(f"Samples evaluated: {results['num_samples_evaluated']}")
            print("="*70)

            # Save results
            try:
                with open("evaluation_results.json", "w", encoding='utf-8') as f:
                    # Create JSON-serializable version
                    json_results = {
                        'perplexity': results['perplexity'],
                        'bleu_score': results['bleu_score'],
                        'rouge1_score': results['rouge1_score'],
                        'rouge2_score': results['rouge2_score'],
                        'rougeL_score': results['rougeL_score'],
                        'num_samples_evaluated': results['num_samples_evaluated'],
                        'evaluation_timestamp': results['evaluation_timestamp'],
                        'samples': results['samples']
                    }
                    json.dump(json_results, f, indent=2, ensure_ascii=False)
                print("✓ Results saved to evaluation_results.json")
            except Exception as e:
                print(f"Warning: Error saving results: {e}")

            return results

        except Exception as e:
            print(f"✗ Error in benchmark: {e}")
            traceback.print_exc()
            return None

    def plot_evaluation_history(self, checkpoint_dirs=None):
        """Plot evaluation metrics across checkpoints"""
        try:
            if checkpoint_dirs is None:
                checkpoint_dirs = [d for d in os.listdir("./wikipedia_model")
                                 if d.startswith("checkpoint-") and os.path.isdir(f"./wikipedia_model/{d}")]
                checkpoint_dirs.sort(key=lambda x: int(x.split("-")[1]) if x.split("-")[1].isdigit() else 0)

            perplexities = []
            steps = []

            for checkpoint in checkpoint_dirs:
                checkpoint_path = f"./wikipedia_model/{checkpoint}"
                results_file = f"{checkpoint_path}/evaluation_results.json"

                if os.path.exists(results_file):
                    try:
                        with open(results_file, "r", encoding='utf-8') as f:
                            data = json.load(f)
                            perplexity = data.get('perplexity')  # A missing value (None) is skipped by the type check below
                            if isinstance(perplexity, (int, float)) and not np.isinf(perplexity):
                                perplexities.append(perplexity)
                                step = int(checkpoint.split("-")[1]) if checkpoint.split("-")[1].isdigit() else 0
                                steps.append(step)
                    except Exception as e:
                        print(f"Warning: Error reading {results_file}: {e}")

            if perplexities and steps:
                plt.figure(figsize=(12, 8))
                plt.plot(steps, perplexities, marker='o', linewidth=2, markersize=8)
                plt.xlabel('Training Steps', fontsize=12)
                plt.ylabel('Perplexity', fontsize=12)
                plt.title('Model Perplexity Across Training Checkpoints', fontsize=14, fontweight='bold')
                plt.grid(True, alpha=0.3)
                plt.tight_layout()
                plt.savefig('evaluation_history.png', dpi=300, bbox_inches='tight')
                plt.show()
                print("✓ Evaluation history plot saved as 'evaluation_history.png'")
            else:
                print("No valid evaluation results found in checkpoints")

        except Exception as e:
            print(f"Error plotting evaluation history: {e}")
            traceback.print_exc()

def run_robust_evaluation(model_path: str = "./wikipedia_model/final",
                         num_samples: int = 50) -> Optional[Dict[str, Any]]:
    """
    Run robust model evaluation with comprehensive error handling.

    Args:
        model_path: Path to the model directory
        num_samples: Number of samples to evaluate

    Returns:
        Evaluation results dictionary or None if failed
    """
    try:
        print(f"Starting robust evaluation of model: {model_path}")
        evaluator = RobustModelEvaluator(model_path)
        results = evaluator.benchmark_model(num_samples=num_samples)

        if results:
            # Plot evaluation history
            evaluator.plot_evaluation_history()

        return results

    except Exception as e:
        print(f"✗ Evaluation failed: {e}")
        traceback.print_exc()
        return None

# Run the evaluation (in a notebook, __name__ is "__main__", so this executes with the cell)
if __name__ == "__main__":
    results = run_robust_evaluation()
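
Note that `benchmark_model` writes `evaluation_results.json` to the working directory, while `plot_evaluation_history` looks for a file of that name inside each `checkpoint-N` directory. The following is a minimal sketch of one way to bridge the two, assuming each checkpoint is evaluated separately. `save_results_for_checkpoint` is a hypothetical helper and is not defined anywhere in this notebook.

```python
import json
import os


def save_results_for_checkpoint(results, checkpoint_dir):
    """Write the scalar metrics from a results dict into checkpoint_dir.

    Hypothetical helper: the file name matches what plot_evaluation_history
    reads, but the notebook itself does not define this function.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    # Keep only JSON-friendly scalar values; generated samples are skipped.
    metrics = {k: v for k, v in results.items() if isinstance(v, (int, float, str))}
    path = os.path.join(checkpoint_dir, "evaluation_results.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)
    return path
```

Calling it with the dictionary returned by `benchmark_model` and a path such as `./wikipedia_model/checkpoint-500` (an example path, not one taken from the training run) would place the metrics where the plotting method expects them.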

## Data Analysis

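Analyzes the formatted training dataset: per-split sample counts and length statistics, word- and character-count distributions, and word frequencies. Saves the plots and a JSON report with simple recommendations to the `data_analysis` folder.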
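
As a rough illustration of what the cell below measures, here is a minimal sketch on a three-sentence toy corpus. The corpus and the `describe_corpus` name are made up for this example and use only the standard library; the real cell computes the same kind of counts with NumPy over the saved dataset.

```python
import re
import statistics
from collections import Counter


def describe_corpus(texts):
    """Return word-count statistics and word frequencies for a list of strings."""
    # Words per text, using whitespace splitting as in basic_statistics
    lengths = [len(t.split()) for t in texts]
    # Word frequencies, using the same regex tokenizer as analyze_vocabulary
    counts = Counter()
    for t in texts:
        counts.update(re.findall(r"\b\w+\b", t.lower()))
    return {
        "num_samples": len(texts),
        "avg_words": statistics.mean(lengths),
        "median_words": statistics.median(lengths),
        "min_words": min(lengths),
        "max_words": max(lengths),
        "vocab_size": len(counts),
        "top_words": counts.most_common(3),
    }


toy = ["The cat sat.", "The dog sat on the mat.", "A bird flew."]
summary = describe_corpus(toy)
print(summary)
```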

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from datasets import load_from_disk
import numpy as np
from collections import Counter
import re
from tqdm import tqdm
import pandas as pd
from wordcloud import WordCloud
import json
import os

class DataAnalyzer:
    """Comprehensive dataset analysis tools"""
    
    def __init__(self, dataset_path="data/formatted_wikipedia"):
        try:
            self.dataset = load_from_disk(dataset_path)
            print(f"Loaded dataset with {len(self.dataset)} splits")
            
            # Debug: Check dataset structure
            if len(self.dataset) > 0:
                first_split = list(self.dataset.keys())[0]
                first_item = self.dataset[first_split][0]
                print(f"First item type: {type(first_item)}")
                if isinstance(first_item, dict):
                    print(f"First item keys: {list(first_item.keys())}")
                else:
                    print(f"First item (first 100 chars): {str(first_item)[:100]}")
        except Exception as e:
            print(f"Error loading dataset: {e}")
            self.dataset = None
    
    def basic_statistics(self):
        """Calculate basic dataset statistics"""
        if self.dataset is None:
            return None
        
        stats = {}
        
        for split_name, split_data in self.dataset.items():
            # Handle both dict format {'text': '...'} and string format
            try:
                # Try dict format first
                texts = [item['text'] for item in split_data]
            except (TypeError, KeyError):
                # If that fails, assume items are strings directly
                texts = [item for item in split_data if isinstance(item, str)]
            
            # Text lengths
            text_lengths = [len(text.split()) for text in texts]
            
            # Cast NumPy scalars to built-in types so the exported report stores numbers, not strings
            stats[split_name] = {
                'num_samples': len(texts),
                'avg_words': float(np.mean(text_lengths)),
                'min_words': int(np.min(text_lengths)),
                'max_words': int(np.max(text_lengths)),
                'median_words': float(np.median(text_lengths)),
                'total_words': sum(text_lengths),
                'avg_chars': float(np.mean([len(text) for text in texts]))
            }
        
        # Print statistics
        print("\n" + "="*60)
        print("DATASET STATISTICS")
        print("="*60)
        
        for split, stat in stats.items():
            print(f"\n{split.upper()} SPLIT:")
            print(f"  Samples: {stat['num_samples']:,}")
            print(f"  Avg words per sample: {stat['avg_words']:.1f}")
            print(f"  Word count range: {stat['min_words']}-{stat['max_words']}")
            print(f"  Median words: {stat['median_words']:.1f}")
            print(f"  Total words: {stat['total_words']:,}")
            print(f"  Avg characters per sample: {stat['avg_chars']:.1f}")
        
        return stats
    
    def plot_distributions(self, stats):
        """Plot data distributions"""
        if stats is None:
            return
        
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        fig.suptitle('Dataset Analysis', fontsize=16)
        
        splits = list(stats.keys())
        colors = ['blue', 'green', 'red', 'orange']
        
        # Sample counts
        axes[0, 0].bar(splits, [stats[s]['num_samples'] for s in splits], color=colors[:len(splits)])
        axes[0, 0].set_title('Number of Samples per Split')
        axes[0, 0].set_ylabel('Count')
        
        # Average words
        axes[0, 1].bar(splits, [stats[s]['avg_words'] for s in splits], color=colors[:len(splits)])
        axes[0, 1].set_title('Average Words per Sample')
        axes[0, 1].set_ylabel('Words')
        
        # Word count distributions (words per sample)
        for i, split in enumerate(splits):
            # Handle both dict format {'text': '...'} and string format
            try:
                # Try dict format first
                texts = [item['text'] for item in self.dataset[split]]
            except (TypeError, KeyError):
                # If that fails, assume items are strings directly
                texts = [item for item in self.dataset[split] if isinstance(item, str)]
            lengths = [len(text.split()) for text in texts[:1000]]  # Sample for plotting
            axes[1, 0].hist(lengths, alpha=0.7, label=split, bins=30, color=colors[i])
        
        axes[1, 0].set_title('Word Count Distribution (Sample)')
        axes[1, 0].set_xlabel('Words per Sample')
        axes[1, 0].set_ylabel('Frequency')
        axes[1, 0].legend()
        
        # Character length distributions
        for i, split in enumerate(splits):
            # Handle both dict format {'text': '...'} and string format
            try:
                # Try dict format first
                texts = [item['text'] for item in self.dataset[split]]
            except (TypeError, KeyError):
                # If that fails, assume items are strings directly
                texts = [item for item in self.dataset[split] if isinstance(item, str)]
            lengths = [len(text) for text in texts[:1000]]  # Sample for plotting
            axes[1, 1].hist(lengths, alpha=0.7, label=split, bins=30, color=colors[i])
        
        axes[1, 1].set_title('Character Length Distribution (Sample)')
        axes[1, 1].set_xlabel('Characters per Sample')
        axes[1, 1].set_ylabel('Frequency')
        axes[1, 1].legend()
        
        plt.tight_layout()
        plt.savefig('data_analysis/dataset_analysis.png', dpi=300, bbox_inches='tight')
        plt.show()
    
    def analyze_vocabulary(self, num_top_words=20):
        """Analyze vocabulary and word frequencies"""
        if self.dataset is None:
            return None
        
        print("\nAnalyzing vocabulary...")
        
        # Combine all texts
        all_texts = []
        for split_data in self.dataset.values():
            # Handle both dict format {'text': '...'} and string format
            try:
                # Try dict format first
                texts = [item['text'] for item in split_data]
            except (TypeError, KeyError):
                # If that fails, assume items are strings directly
                texts = [item for item in split_data if isinstance(item, str)]
            all_texts.extend(texts)
        
        # Tokenize and count words
        word_counts = Counter()
        
        for text in tqdm(all_texts[:5000], desc="Processing texts"):  # Sample for speed
            # Simple tokenization: lowercase the text, then extract runs of word characters
            words = re.findall(r'\b\w+\b', text.lower())
            word_counts.update(words)
        
        # Get top words
        top_words = word_counts.most_common(num_top_words)
        
        print(f"\nTop {num_top_words} most frequent words:")
        for word, count in top_words:
            print(f"  {word}: {count:,}")
        
        # Vocabulary size (measured on the sampled texts only, not the full dataset)
        vocab_size = len(word_counts)
        print(f"\nUnique words in the sampled texts (vocabulary size): {vocab_size:,}")
        
        return {
            'top_words': top_words,
            'vocab_size': vocab_size,
            'word_counts': word_counts
        }
    
    def create_wordcloud(self, word_counts, max_words=100):
        """Create word cloud visualization"""
        # Filter out common stop words
        stop_words = {'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by', 'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would', 'could', 'should', 'may', 'might', 'must', 'can', 'this', 'that', 'these', 'those', 'i', 'you', 'he', 'she', 'it', 'we', 'they', 'me', 'him', 'her', 'us', 'them'}
        
        filtered_words = {word: count for word, count in word_counts.items() 
                         if word not in stop_words and len(word) > 2}
        
        # Create word cloud
        wordcloud = WordCloud(
            width=800, 
            height=400, 
            background_color='white',
            max_words=max_words,
            colormap='viridis'
        ).generate_from_frequencies(filtered_words)
        
        # Plot
        plt.figure(figsize=(15, 8))
        plt.imshow(wordcloud, interpolation='bilinear')
        plt.axis('off')
        plt.title('Word Cloud of Training Data', fontsize=16)
        plt.savefig('data_analysis/wordcloud.png', dpi=300, bbox_inches='tight')
        plt.show()
    
    def export_analysis_report(self, stats, vocab_info):
        """Export comprehensive analysis report"""
        report = {
            'dataset_statistics': stats,
            # vocab_info is None when the dataset failed to load
            'vocabulary_analysis': {
                'vocab_size': vocab_info['vocab_size'],
                'top_words': vocab_info['top_words'][:10]
            } if vocab_info else None,
            'recommendations': []
        }
        
        # Generate recommendations
        if stats:
            total_samples = sum(s['num_samples'] for s in stats.values())
            if total_samples < 10000:
                report['recommendations'].append("Consider increasing dataset size for better model performance")
            
            avg_words = np.mean([s['avg_words'] for s in stats.values()])
            if avg_words < 50:
                report['recommendations'].append("Texts are quite short - consider longer passages for better context")
        
        if vocab_info and vocab_info['vocab_size'] < 10000:
            report['recommendations'].append("Limited vocabulary - model may struggle with diverse topics")
        
        # Save report
        with open('data_analysis/data_analysis_report.json', 'w') as f:
            json.dump(report, f, indent=2, default=str)
        
        print("\nAnalysis report saved to data_analysis/data_analysis_report.json")
        
        return report

def run_complete_analysis():
    """Run complete dataset analysis"""
    os.makedirs("data_analysis", exist_ok=True)
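    analyzer = DataAnalyzer()
    if analyzer.dataset is None:
        # Nothing to analyze; the loader has already printed the error
        print("Dataset could not be loaded; skipping analysis.")
        return None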
    
    # Basic statistics
    stats = analyzer.basic_statistics()
    
    # Plot distributions
    analyzer.plot_distributions(stats)
    
    # Vocabulary analysis
    vocab_info = analyzer.analyze_vocabulary()
    
    # Create word cloud
    if vocab_info:
        analyzer.create_wordcloud(vocab_info['word_counts'])
    
    # Export report
    report = analyzer.export_analysis_report(stats, vocab_info)
    
    print("\nAnalysis complete! Check the generated plots and report in the data_analysis folder.")
    
    return report

if __name__ == "__main__":
    run_complete_analysis()
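
The dict-or-string extraction logic appears four times in `DataAnalyzer`. A minimal sketch of how it could be factored out is shown here; `extract_texts` is a hypothetical helper, not something the notebook defines, and its optional `limit` parameter is an assumption added so that plotting code can stop after a sample instead of materializing every text.

```python
def extract_texts(split_data, limit=None):
    """Return the text of each item, accepting {'text': ...} dicts or plain strings."""
    texts = []
    for item in split_data:
        # Stop early when only a sample is needed (e.g. for plotting)
        if limit is not None and len(texts) >= limit:
            break
        if isinstance(item, dict) and isinstance(item.get("text"), str):
            texts.append(item["text"])
        elif isinstance(item, str):
            texts.append(item)
        # Items of any other shape are skipped
    return texts
```

With such a helper, each repeated `try`/`except` block would reduce to a single call, for example `extract_texts(self.dataset[split], limit=1000)` in the histogram loops.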