# DiffusionBERT Evaluation

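This notebook evaluates a pre-trained DiffusionBERT checkpoint on the LM1B validation split. It reports average perplexity and a word-frequency score, prints a few sample predictions, and saves the results as a timestamped JSON file in Google Drive.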

In [None]:
# Install required third-party packages (the leading `!` runs a shell command
# in IPython/Colab; this line is not plain Python).
!pip install transformers datasets torch tqdm accelerate

In [None]:
# Mount Google Drive so the checkpoint, word-frequency and output paths used by
# Config (all under /content/drive/MyDrive) are reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
import sys
import torch
import logging
from transformers import BertTokenizer, BertConfig
from tqdm.notebook import tqdm
import json
import numpy as np
from datetime import datetime

# Setup logging.
# force=True replaces any handlers already attached to the root logger. In
# Colab/Jupyter the root logger often has handlers before this cell runs, and
# then a plain basicConfig() is a silent no-op, so INFO messages never show.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
    level=logging.INFO,
    force=True,
)
logger = logging.getLogger(__name__)

In [None]:
class Config:
    """Settings for one evaluation run: file locations, sizes and device.

    Every value defaults to the Colab / Google Drive layout this notebook was
    written for. Pass keyword arguments to override any of them, e.g. to point
    at a local checkpoint. ``Config()`` with no arguments behaves exactly as
    before.

    NOTE(review): ``batch_size`` and ``max_seq_length`` are recorded in the saved
    results, but the evaluation loop in this notebook scores one example at a
    time and never reads them.

    Side effect: ``output_dir`` is created (including parents) if missing.
    """

    def __init__(
        self,
        *,
        model_name="bert-base-uncased",
        checkpoint_path="/content/drive/MyDrive/DiffusionBERT/checkpoints/diffusion_bert_lm1b_final.pt",
        word_freq_path="/content/drive/MyDrive/DiffusionBERT/word_freqs/word_freq.pt",
        output_dir="/content/drive/MyDrive/DiffusionBERT/evaluation_results",
        batch_size=32,
        max_seq_length=128,
        num_eval_samples=1000,
        device=None,
    ):
        self.model_name = model_name
        self.checkpoint_path = checkpoint_path
        self.word_freq_path = word_freq_path
        self.output_dir = output_dir
        self.batch_size = batch_size
        self.max_seq_length = max_seq_length
        # Upper bound on the number of validation examples evaluated.
        self.num_eval_samples = num_eval_samples
        # Default: use the GPU when available. Accepts a str or a torch.device.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        # Create output directory
        os.makedirs(self.output_dir, exist_ok=True)

# Shared settings instance; main() and the calls it makes read this global.
config = Config()

In [None]:
def load_model_and_tokenizer(config):
    """Build the tokenizer and restore the masked-LM weights from a checkpoint.

    Args:
        config: object exposing ``model_name``, ``checkpoint_path`` and ``device``.

    Returns:
        ``(model, tokenizer)``. The model is moved to ``config.device`` and put
        in eval mode.

    Raises:
        Any exception raised while loading; it is logged and then re-raised.
    """
    try:
        logger.info(f"Loading tokenizer from {config.model_name}")
        tokenizer = BertTokenizer.from_pretrained(config.model_name)

        logger.info(f"Loading model from {config.checkpoint_path}")
        # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True;
        # if the checkpoint stores non-tensor objects this call may raise.
        checkpoint = torch.load(config.checkpoint_path, map_location=config.device)

        # Base BERT architecture, resized to the tokenizer's vocabulary.
        bert_config = BertConfig.from_pretrained(config.model_name)
        bert_config.vocab_size = tokenizer.vocab_size

        # Project-local model class.
        from models.modeling_bert import BertForMaskedLM
        net = BertForMaskedLM(bert_config)

        # The weights are either wrapped under 'model_state_dict' or are the bare
        # state dict itself.
        is_wrapped = isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint
        state_dict = checkpoint['model_state_dict'] if is_wrapped else checkpoint
        net.load_state_dict(state_dict)

        net = net.to(config.device)
        net.eval()
    except Exception as e:
        logger.error(f"Error loading model and tokenizer: {str(e)}")
        raise
    return net, tokenizer

In [None]:
def load_word_frequencies(config):
    """Load a word-frequency table and rescale it onto a normalized log scale.

    Each entry ``f`` becomes ``log(f + 1) / max(log(f + 1))``. Add-one smoothing
    keeps the log finite for zero counts. Dividing by the maximum puts the
    largest entry at 1.0, so the result lies in [0, 1] for non-negative counts.

    Args:
        config: object exposing ``word_freq_path`` and ``device``.

    Returns:
        A floating-point tensor on ``config.device``, same shape as the stored
        table. If every smoothed log value is 0 (e.g. an all-zero table), the
        zeros are returned unchanged instead of NaNs.

    Raises:
        Any exception raised while loading/processing; it is logged and re-raised.
    """
    try:
        logger.info(f"Loading word frequencies from {config.word_freq_path}")
        word_freq = torch.load(config.word_freq_path, map_location=config.device)

        # Count tables are typically integer; promote explicitly before the log.
        if not torch.is_floating_point(word_freq):
            word_freq = word_freq.float()

        # Add smoothing, then move to log scale.
        word_freq = (word_freq + 1).log()

        peak = word_freq.max()
        # An all-zero table gives peak == 0; dividing would make every entry NaN.
        if peak > 0:
            word_freq = word_freq / peak

        return word_freq

    except Exception as e:
        logger.error(f"Error loading word frequencies: {str(e)}")
        raise

In [None]:
def _score_example(model, input_ids, attention_mask, word_freq):
    """Score one example; return ``(perplexity, freq_score, pred_tokens)``.

    The caller passes ``input_ids``/``attention_mask`` with a leading batch
    dimension of 1. Positions where ``attention_mask == 0`` are treated as
    padding and excluded from both the loss and the word-frequency score.
    """
    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

    # -100 is cross_entropy's ignore label: padded positions must not count
    # toward the loss, otherwise short sentences are scored on [PAD] tokens.
    labels = input_ids.masked_fill(attention_mask == 0, -100)
    loss = torch.nn.functional.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        labels.reshape(-1),
        ignore_index=-100,
    )

    pred_tokens = torch.argmax(logits, dim=-1)
    # Average the frequency score over real (non-padding) positions only.
    freq_score = word_freq[pred_tokens[attention_mask.bool()]].mean()

    return torch.exp(loss), freq_score, pred_tokens


def evaluate_model(model, tokenizer, word_freq, config):
    """Evaluate the model on up to ``config.num_eval_samples`` LM1B validation examples.

    Each example is scored on its own (batch of one). For each example the
    function computes:

    * perplexity = exp(cross-entropy of the model's predictions against the
      input tokens), over non-padding positions;
    * word-frequency score = mean of ``word_freq`` at the argmax-predicted
      tokens, over non-padding positions.

    NOTE(review): the inputs are fed to the model un-noised/un-masked, so this
    measures reconstruction of visible tokens rather than a diffusion ELBO.
    Confirm this is the intended metric.

    Args:
        model: masked LM whose output exposes ``.logits``.
        tokenizer: tokenizer used for loading data and decoding samples.
        word_freq: 1-D tensor indexed by token id, on the same device as the model.
        config: object exposing ``num_eval_samples`` and ``device``.

    Returns:
        Dict with ``avg_perplexity``, ``std_perplexity``, ``avg_word_freq_score``,
        ``std_word_freq_score`` and ``samples``. ``samples`` holds up to 10
        ``{'input', 'generated'}`` decodes with padding positions removed.

    Raises:
        Any exception raised during evaluation; it is logged and re-raised.
    """
    try:
        from dataloader import DiffusionLoader
        loader = DiffusionLoader(tokenizer)
        eval_data = loader.my_load("lm1b", splits=["validation"])[0]

        perplexities = []
        freq_scores = []
        samples = []

        with torch.no_grad():
            for i, example in enumerate(tqdm(eval_data, desc="Evaluating")):
                if i >= config.num_eval_samples:
                    break

                # Add a batch dimension of 1. Token ids must be int64 for cross_entropy.
                input_ids = torch.as_tensor(example['input_ids'], dtype=torch.long).unsqueeze(0).to(config.device)
                attention_mask = torch.as_tensor(example['attention_mask']).unsqueeze(0).to(config.device)

                perplexity, freq_score, pred_tokens = _score_example(
                    model, input_ids, attention_mask, word_freq
                )
                perplexities.append(perplexity.item())
                freq_scores.append(freq_score.item())

                # Store first 10 samples, with padding positions stripped.
                if i < 10:
                    real = attention_mask[0].bool()
                    samples.append({
                        'input': tokenizer.decode(input_ids[0][real]),
                        'generated': tokenizer.decode(pred_tokens[0][real]),
                    })

        return {
            'avg_perplexity': np.mean(perplexities),
            'std_perplexity': np.std(perplexities),
            'avg_word_freq_score': np.mean(freq_scores),
            'std_word_freq_score': np.std(freq_scores),
            'samples': samples,
        }

    except Exception as e:
        logger.error(f"Error during evaluation: {str(e)}")
        raise

In [None]:
def save_results(results, config):
    """Write evaluation results plus the run configuration to a timestamped JSON file.

    The output file is ``<config.output_dir>/eval_results_<YYYYmmdd_HHMMSS>.json``.
    It contains everything in ``results`` plus a ``'config'`` entry summarizing
    the run. The caller's ``results`` dict is left unmodified.

    Args:
        results: JSON-serializable metrics dict (e.g. from ``evaluate_model``).
        config: object exposing ``output_dir``, ``model_name``, ``checkpoint_path``,
            ``batch_size``, ``max_seq_length`` and ``num_eval_samples``.

    Returns:
        Path of the written JSON file.

    Raises:
        Any exception raised while writing; it is logged and re-raised.
    """
    try:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_file = os.path.join(config.output_dir, f"eval_results_{timestamp}.json")

        # Build a new dict instead of mutating the caller's results.
        payload = {
            **results,
            'config': {
                'model_name': config.model_name,
                'checkpoint_path': config.checkpoint_path,
                'batch_size': config.batch_size,
                'max_seq_length': config.max_seq_length,
                'num_eval_samples': config.num_eval_samples,
            },
        }

        # UTF-8 with ensure_ascii=False keeps decoded sample text human-readable.
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(payload, f, indent=2, ensure_ascii=False)

        logger.info(f"Results saved to {output_file}")
        return output_file

    except Exception as e:
        logger.error(f"Error saving results: {str(e)}")
        raise

In [None]:
def main():
    """Run the evaluation end to end and print a short summary.

    Uses the module-level ``config`` to load the model, tokenizer and word
    frequencies. It then evaluates, saves the JSON report, and prints the
    headline metrics plus up to three sample generations. Any failure is logged
    and re-raised.
    """
    try:
        model, tokenizer = load_model_and_tokenizer(config)
        word_freq = load_word_frequencies(config)

        metrics = evaluate_model(model, tokenizer, word_freq, config)
        save_results(metrics, config)

        # Headline metrics.
        print("\nEvaluation Results:")
        print("Average Perplexity: {:.2f} ± {:.2f}".format(
            metrics['avg_perplexity'], metrics['std_perplexity']))
        print("Average Word Frequency Score: {:.4f} ± {:.4f}".format(
            metrics['avg_word_freq_score'], metrics['std_word_freq_score']))

        # A few qualitative examples.
        print("\nSample Generations:")
        for number, sample in enumerate(metrics['samples'][:3], start=1):
            print(f"\nSample {number}:")
            print(f"Input: {sample['input']}")
            print(f"Generated: {sample['generated']}")

    except Exception as e:
        logger.error(f"Error in main execution: {str(e)}")
        raise

In [None]:
if __name__ == "__main__":
    # __name__ is "__main__" both for a script and inside a Jupyter/Colab kernel,
    # so executing this cell starts the full evaluation.
    main()