# DiffusionBERT Model Evaluation

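This notebook evaluates a trained DiffusionBERT checkpoint by generating samples from random prompts and scoring each one with perplexity and a word-frequency score. It covers:
1. Model loading and setup
2. Data preparation (word-frequency statistics)
3. Evaluation metrics computation
4. Results visualization
5. Sample generation analysis
6. Saving results to disk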

## 1. Setup and Imports

In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
import logging
from datetime import datetime
from tqdm.notebook import tqdm

from transformers import AutoConfig, AutoModel, AutoTokenizer
from models.modeling_diffusion_bert import DiffusionBertForMaskedLM

# Emit INFO-and-above log records; `logger` is the notebook-wide logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Fix the NumPy and torch RNG seeds so sampled prompts repeat across runs.
np.random.seed(42)
torch.manual_seed(42)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

## 2. Configuration

In [None]:
# Central notebook settings. Paths marked "Update this" must point at local files.
config = dict(
    # --- model ---
    model_name='bert-base-uncased',
    model_checkpoint_path='path/to/your/checkpoint.th',  # Update this
    max_position_embeddings=512,

    # --- data ---
    word_freq_path='word_freq/bert-base-uncased_lm1b.pt',
    test_data_path='data/test.txt',  # Update this
    max_seq_length=128,
    batch_size=32,

    # --- sampling / evaluation ---
    num_samples=1000,
    temperature=1.0,
    top_k=50,
    top_p=0.9,

    # --- output ---
    output_dir='evaluation_results',
    save_samples=True,
)

## 3. Model Loading and Setup

In [None]:
def load_model_and_tokenizer(config):
    """Build the DiffusionBERT masked-LM from a checkpoint and load its tokenizer.

    Args:
        config: Notebook configuration dict. Reads ``model_name`` (Hugging Face
            id used for both the tokenizer and the base architecture config),
            ``max_position_embeddings`` and ``model_checkpoint_path``.

    Returns:
        ``(model, tokenizer)``, with the model moved to the global ``device``
        and switched to eval mode.

    Raises:
        Any error from loading the tokenizer, config or checkpoint is logged
        (with traceback) and re-raised unchanged.

    NOTE(review): overriding ``max_position_embeddings`` changes the size of
    the position-embedding table, so it presumably has to match the
    checkpoint or ``load_state_dict`` will reject it — confirm.
    """
    try:
        # No Auto* registration is needed: the model class is instantiated
        # directly below rather than resolved through AutoModel.
        tokenizer = AutoTokenizer.from_pretrained(config['model_name'])

        # Start from the base architecture config and apply the overrides.
        model_config = AutoConfig.from_pretrained(config['model_name'])
        model_config.max_position_embeddings = config['max_position_embeddings']

        model = DiffusionBertForMaskedLM(model_config)

        # Deserialize on CPU so GPU-saved checkpoints also open on CPU-only hosts.
        checkpoint = torch.load(config['model_checkpoint_path'], map_location='cpu')
        # The expected checkpoint layout keeps weights under 'model'; a bare
        # state dict is accepted as well.
        if isinstance(checkpoint, dict) and 'model' in checkpoint:
            state_dict = checkpoint['model']
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict)

        # Move to the target device and disable dropout for evaluation.
        model = model.to(device)
        model.eval()

        return model, tokenizer

    except Exception as e:
        logger.exception("Error loading model and tokenizer: %s", e)
        raise

# Load model and tokenizer
# Creates the notebook-global `model` and `tokenizer` used by every later cell;
# any loading error is re-raised by load_model_and_tokenizer and stops here.
model, tokenizer = load_model_and_tokenizer(config)
print("Model and tokenizer loaded successfully")

## 4. Data Preparation

In [None]:
def load_word_frequencies(config):
    """Load per-token frequencies and map them to log-scaled scores in [0, 1].

    Args:
        config: Configuration dict. ``config['word_freq_path']`` points either
            to a ``.pt`` file holding a tensor, or to a JSON file holding a
            (possibly nested) list of numbers that ``torch.tensor`` can convert.

    Returns:
        Tensor on ``device`` equal to ``log(freq + 1) / max(log(freq + 1))``.
        If the maximum is not positive (e.g. every count is zero), the
        unscaled values are returned instead of dividing by zero and
        producing NaN.

    Raises:
        Any loading or conversion error is logged (with traceback) and
        re-raised.
    """
    try:
        word_freq_path = Path(config['word_freq_path'])
        if word_freq_path.suffix == '.pt':
            # map_location='cpu' lets tensors saved from a GPU load on CPU-only
            # hosts; the result is moved to `device` below anyway.
            word_freq = torch.load(word_freq_path, map_location='cpu')
        else:
            with open(word_freq_path, encoding='utf-8') as f:
                word_freq = torch.tensor(json.load(f))

        # Add-one smoothing keeps log() finite for zero counts.
        word_freq = (word_freq + 1).log()

        # Scale into [0, 1]; skip the division when it would produce NaN/inf.
        max_val = word_freq.max()
        if max_val > 0:
            word_freq = word_freq / max_val

        return word_freq.to(device)

    except Exception as e:
        logger.exception("Error loading word frequencies: %s", e)
        raise

# Load word frequencies
# Creates the notebook-global `word_freq` tensor that evaluate_samples consumes.
word_freq = load_word_frequencies(config)
print(f"Word frequencies loaded with shape: {word_freq.shape}")

## 5. Evaluation Metrics

In [None]:
def compute_perplexity(model, input_ids, attention_mask):
    """Return ``exp(loss)`` of ``model`` scoring ``input_ids`` against themselves.

    The model is called with ``labels=input_ids`` under ``torch.no_grad()``,
    and the returned ``.loss`` is exponentiated into a Python float.

    NOTE(review): for a masked LM, using unmasked inputs as their own labels
    measures reconstruction of visible tokens rather than true perplexity —
    confirm this is the intended metric for DiffusionBERT.
    """
    with torch.no_grad():
        forward_out = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=input_ids,
        )
    mean_nll = forward_out.loss
    return mean_nll.exp().item()

def compute_word_freq_score(generated_ids, word_freq):
    """Return the mean word-frequency score over all token ids in ``generated_ids``.

    Args:
        generated_ids: int64 tensor of token ids of any shape. It is flattened
            before the lookup, so all rows of a batch are pooled into one mean.
        word_freq: 1-D tensor of per-token scores indexed by token id, on the
            same device as ``generated_ids``.

    Returns:
        The mean score as a Python float (NaN if ``generated_ids`` is empty).
    """
    # reshape (not view) so non-contiguous id tensors, such as transposed or
    # strided slices of a batch, are accepted as well.
    flat_ids = generated_ids.reshape(-1)
    return word_freq.gather(0, flat_ids).mean().item()

def evaluate_samples(model, tokenizer, word_freq, config):
    """Generate samples from random prompts and score each one.

    Args:
        model: Model exposing ``generate`` and a forward pass that accepts
            ``labels`` (see ``compute_perplexity``).
        tokenizer: Used only to decode ids for the stored samples.
        word_freq: Per-token score tensor indexed by token id (passed to
            ``compute_word_freq_score``).
        config: Reads ``num_samples``, ``max_seq_length``, ``temperature``,
            ``top_k``, ``top_p`` and ``save_samples``. The optional
            ``prompt_token_range`` ``(low, high)`` sets the half-open id range
            that prompt tokens are drawn from (default ``(100, 1000)``).

    Returns:
        Dict with per-sample lists ``'perplexities'`` and
        ``'word_freq_scores'``, plus ``'samples'`` (decoded input/output and
        both scores), which is only populated when ``config['save_samples']``
        is truthy.
    """
    low, high = config.get('prompt_token_range', (100, 1000))
    results = {
        'perplexities': [],
        'word_freq_scores': [],
        'samples': []
    }

    for _ in tqdm(range(config['num_samples']), desc="Generating samples"):
        # Draw prompt ids on the CPU (so the seeded CPU RNG stream is used),
        # then move them to the target device.
        # NOTE(review): for bert-base-uncased most ids in [100, 1000) are
        # special or [unused*] tokens, so these prompts are essentially
        # noise — confirm this is intended.
        input_ids = torch.randint(low, high, (1, config['max_seq_length'])).to(device)
        attention_mask = torch.ones_like(input_ids)

        with torch.no_grad():
            # NOTE(review): under transformers' GenerationMixin, max_length
            # counts the prompt, so a max_seq_length-token prompt leaves no
            # room for new tokens. Confirm how
            # DiffusionBertForMaskedLM.generate interprets it.
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=config['max_seq_length'],
                temperature=config['temperature'],
                top_k=config['top_k'],
                top_p=config['top_p'],
                do_sample=True
            )

        # Score with a mask shaped like the generated ids: the output length
        # need not equal the prompt length, so the prompt mask can't be reused.
        output_mask = torch.ones_like(outputs)
        perplexity = compute_perplexity(model, outputs, output_mask)
        word_freq_score = compute_word_freq_score(outputs, word_freq)

        results['perplexities'].append(perplexity)
        results['word_freq_scores'].append(word_freq_score)

        if config['save_samples']:
            results['samples'].append({
                'input': tokenizer.decode(input_ids[0]),
                'generated': tokenizer.decode(outputs[0]),
                'perplexity': perplexity,
                'word_freq_score': word_freq_score
            })

    return results

## 6. Run Evaluation

In [None]:
# Generate and score all samples.
print("Starting evaluation...")
evaluation_results = evaluate_samples(model, tokenizer, word_freq, config)

# Summary statistics: mean and population standard deviation (NumPy's default
# ddof=0) of each per-sample metric.
metrics = dict(
    avg_perplexity=np.mean(evaluation_results['perplexities']),
    std_perplexity=np.std(evaluation_results['perplexities']),
    avg_word_freq_score=np.mean(evaluation_results['word_freq_scores']),
    std_word_freq_score=np.std(evaluation_results['word_freq_scores']),
)

# Attach the summary so it is saved together with the per-sample results.
evaluation_results.update(metrics=metrics)

## 7. Visualize Results

In [None]:
# Set up the visualization style. Matplotlib 3.6 renamed the bundled 'seaborn'
# style to 'seaborn-v0_8', and 3.8 removed the old name (plt.style.use raises
# OSError), so use whichever name this installation provides.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')
sns.set_palette("husl")

# Create figure with subplots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# Plot perplexity distribution, with the mean marked by a dashed red line
sns.histplot(evaluation_results['perplexities'], ax=ax1, kde=True)
ax1.set_title('Perplexity Distribution')
ax1.set_xlabel('Perplexity')
ax1.axvline(metrics['avg_perplexity'], color='r', linestyle='--',
            label=f'Mean: {metrics["avg_perplexity"]:.2f}')
ax1.legend()

# Plot word frequency score distribution, with the mean marked likewise
sns.histplot(evaluation_results['word_freq_scores'], ax=ax2, kde=True)
ax2.set_title('Word Frequency Score Distribution')
ax2.set_xlabel('Word Frequency Score')
ax2.axvline(metrics['avg_word_freq_score'], color='r', linestyle='--',
            label=f'Mean: {metrics["avg_word_freq_score"]:.4f}')
ax2.legend()

plt.tight_layout()
plt.show()

## 8. Sample Analysis

In [None]:
def _print_top_samples(samples, key, reverse=False, k=3):
    """Print the first ``k`` samples ranked by ``sample[key]`` (ascending unless ``reverse``)."""
    ranked = sorted(samples, key=lambda s: s[key], reverse=reverse)[:k]
    for i, sample in enumerate(ranked, 1):
        print(f"\nSample {i}:")
        print(f"Generated: {sample['generated']}")
        print(f"Perplexity: {sample['perplexity']:.2f}")
        print(f"Word Freq Score: {sample['word_freq_score']:.4f}")


# Samples are only stored when config['save_samples'] is truthy; say so
# explicitly instead of printing empty sections.
if not evaluation_results['samples']:
    print("No samples to display (samples are stored only when config['save_samples'] is True).")
else:
    # Display best samples (lowest perplexity)
    print("Best Samples (Lowest Perplexity):")
    _print_top_samples(evaluation_results['samples'], 'perplexity')

    # Display samples with highest word frequency scores
    print("\nSamples with Highest Word Frequency Scores:")
    _print_top_samples(evaluation_results['samples'], 'word_freq_score', reverse=True)

## 9. Save Results

In [None]:
def save_results(results, config):
    """Write evaluation results, a timestamp and the config to a JSON file.

    The file is ``<config['output_dir']>/eval_results_<timestamp>.json``, where
    the timestamp has the format ``%Y-%m-%d_%H-%M-%S``. The output directory
    (and any missing parents) is created if needed.

    Args:
        results: JSON-serializable results dict. It is not modified; the
            written document is a copy with ``'timestamp'`` and ``'config'``
            keys added.
        config: Configuration dict, embedded in the output under ``'config'``.

    Returns:
        ``pathlib.Path`` of the written file.

    Raises:
        Any error from directory creation, serialization or writing is
        logged (with traceback) and re-raised.
    """
    try:
        output_dir = Path(config['output_dir'])
        output_dir.mkdir(parents=True, exist_ok=True)

        timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
        # Build a new dict rather than adding keys to the caller's results.
        payload = {**results, 'timestamp': timestamp, 'config': config}

        # Serialize before touching the file, so a non-serializable value
        # cannot leave a truncated JSON file behind.
        text = json.dumps(payload, indent=2)

        output_file = output_dir / f"eval_results_{timestamp}.json"
        output_file.write_text(text, encoding='utf-8')

        print(f"Results saved to {output_file}")
        return output_file

    except Exception as e:
        logger.exception("Error saving results: %s", e)
        raise

# Save results
# Writes evaluation_results (including the summary metrics attached in section 6)
# plus the config to a timestamped JSON file under config['output_dir'].
save_results(evaluation_results, config)