# Diversity Comparison with BERT

This notebook computes the diversity of paraphrases for each prompt combination using BERT embeddings. We use a modern BERT model to calculate semantic similarity and diversity metrics.

## Features

- **BERT-based Semantic Diversity**: Uses transformer models to compute semantic similarity between paraphrases
- **Multiple Model Support**: Automatically tries modern sentence-transformers models, falling back to standard BERT
- **Comprehensive Metrics**: Computes both traditional (Jaccard-style, match-based) and semantic (BERT-based) diversity metrics
- **Per-Prompt Analysis**: Calculates diversity for each prompt combination in your dataset

## Prerequisites

1. Install the required packages:
   ```bash
   pip install -r requirements.txt
   ```
2. Ensure you have generated diversity data using the appropriate generation script.
3. Internet access is required the first time a model is downloaded from Hugging Face.

## How to Use

1. **Update Data Path**: In the "Configure data paths" cell, adjust the `root` path to point to your dataset.
2. **Run All Cells**: Execute the cells in order. The notebook will:
   - Load a BERT model (trying modern models first and falling back to `bert-base-uncased`)
   - Process each diversity file in your dataset
   - Calculate BERT-based semantic diversity for the paraphrases
   - Compare it with the traditional metrics
   - Display results and correlations
3. **Review Results**: The final cells display:
   - Per-file diversity scores
   - Correlation analysis between the different metrics
   - The 5 most and 5 least diverse prompt combinations

## Understanding the Metrics

- **BERT Diversity**: `1 - average_cosine_similarity` between paraphrase embeddings
  - Higher values (close to 1.0) = more semantic variation
  - Lower values (close to 0.0) = paraphrases are semantically similar
- **Jaccard Diversity**: `OR_matches / AND_matches`. This is the number of questions answered correctly under *either* prompt divided by the number answered correctly under *both*. It is the reciprocal of the Jaccard index of the two prompts' correct-answer sets, so higher means more diverse; it is reported as 0 when no question is answered correctly under both.
- **Consistency Score**: Proportion of questions for which the predictions under the two prompts match each other

## Notes

- The notebook gracefully handles missing data files
- GPU acceleration is used if available
- The first run may be slow due to the model download
- Subsequent runs use cached models

In [None]:
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
import torch
from sklearn.metrics.pairwise import cosine_similarity
from utils import partial_match_scores, is_matched_str, partial_match

In [None]:
# Load a transformer encoder for semantic similarity.
# Candidates are tried in order of preference; the first one whose tokenizer
# and weights both load successfully is kept.
model_options = [
    'sentence-transformers/all-MiniLM-L6-v2',  # Efficient sentence-transformers model
    'sentence-transformers/all-mpnet-base-v2',  # High quality sentence-transformers
    'bert-base-uncased',  # Standard BERT
]

bert_model = None
tokenizer = None
model_name = None

for candidate in model_options:
    print(f"Trying to load: {candidate}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(candidate)
        bert_model = AutoModel.from_pretrained(candidate)
    except Exception as load_error:
        # Loading boundary: report the (truncated) reason and move on to the next candidate.
        print(f"✗ Failed to load {candidate}: {str(load_error)[:100]}")
        continue
    model_name = candidate
    print(f"✓ Successfully loaded: {model_name}")
    break

if bert_model is None:
    raise RuntimeError("Failed to load any BERT model. Please ensure you have internet access and try again.")

# Prefer the GPU when one is visible; inference only, so switch to eval mode.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
bert_model = bert_model.to(device)
bert_model.eval()

print(f"\nModel loaded successfully on {device}")

In [None]:
def get_bert_embeddings(texts, batch_size=32):
    """
    Embed ``texts`` with the globally loaded ``tokenizer`` / ``bert_model``.

    Texts are encoded in batches (padded, truncated to 512 tokens), run through
    the model on ``device`` without gradients, and mean-pooled over their
    non-padding tokens using the attention mask.

    Args:
        texts: sequence of strings (list, tuple, numpy array, ...).
        batch_size: number of texts encoded per forward pass.

    Returns:
        numpy array with one row per text and ``hidden_size`` columns; an empty
        ``(0, hidden_size)`` array when ``texts`` is empty.
    """
    # HF tokenizers expect a str or a list/tuple of str, so normalise other
    # sequence types (e.g. numpy object arrays read from feather) to a list.
    texts = list(texts)
    if not texts:
        # np.vstack([]) would raise ValueError; return an empty matrix instead.
        return np.empty((0, bert_model.config.hidden_size), dtype=np.float32)

    embeddings = []

    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch_texts = texts[start:start + batch_size]

            # Tokenize and encode
            encoded = tokenizer(
                batch_texts,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors='pt'
            ).to(device)

            # Get model output
            outputs = bert_model(**encoded)

            # Mean pooling: average token vectors, ignoring padding positions.
            attention_mask = encoded['attention_mask']
            token_embeddings = outputs.last_hidden_state
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
            # Clamp guards against division by zero for an all-padding row.
            sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
            batch_embeddings = (sum_embeddings / sum_mask).cpu().numpy()

            embeddings.append(batch_embeddings)

    return np.vstack(embeddings)

def compute_diversity_score(paraphrases):
    """
    Compute a semantic diversity score for a list of paraphrases.

    Diversity is ``1 - mean pairwise cosine similarity`` over all unordered
    pairs of distinct paraphrases, using the embeddings returned by
    ``get_bert_embeddings``. Higher values mean more semantic variation.

    The mean is computed without materialising the n x n similarity matrix,
    so memory stays O(n * hidden_size) even for large pooled inputs.

    Args:
        paraphrases: sequence of strings.

    Returns:
        float diversity score; 0.0 when fewer than two paraphrases are given.
    """
    n = len(paraphrases)
    if n < 2:
        return 0.0

    # Accumulate in float64 to keep the pair-sum identity below accurate.
    embeddings = np.asarray(get_bert_embeddings(paraphrases), dtype=np.float64)

    # L2-normalise each row; all-zero rows stay zero, so they contribute a
    # cosine similarity of 0 to every pair they are part of.
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit = np.divide(embeddings, norms, out=np.zeros_like(embeddings), where=norms > 0)

    # |sum_i u_i|^2 = sum_i |u_i|^2 + 2 * sum_{i<j} u_i . u_j, hence the sum of
    # cosine similarities over all pairs i < j is:
    total = unit.sum(axis=0)
    pair_similarity_sum = (total @ total - np.einsum("ij,ij->", unit, unit)) / 2.0
    avg_similarity = pair_similarity_sum / (n * (n - 1) / 2)

    # Diversity is inverse of similarity
    return float(1.0 - avg_similarity)

print("BERT embedding and diversity functions defined")

In [None]:
# Configure data paths
# You can change this to point to your specific dataset.
# Candidates cover running from the repository root ("datasets/...") and from
# a subdirectory ("../datasets/..."), for each supported model output folder.
root_options = [
    f"{prefix}datasets/myriadlama/{model_dir}"
    for prefix in ("../", "")
    for model_dir in (
        "llama3.2_1b_it",
        "llama3.2_3b_it",
        "llama3.1_8b_it",
        "qwen2.5_3b_it",
        "qwen2.5_7b_it",
    )
]

# Pick the first candidate that exists on disk, or None if none do.
root = next((path for path in root_options if os.path.exists(path)), None)

if root is not None:
    print(f"Using data root: {root}")
else:
    print("Warning: No valid data root found. Available paths:")
    for candidate_path in root_options:
        print(f"  - {candidate_path} (exists: {os.path.exists(candidate_path)})")
    print("\nYou may need to adjust the root paths or generate data first.")
    # Fall back to a default path for demonstration
    root = "datasets/myriadlama/llama3.2_1b_it"
    print(f"Using default path: {root}")

In [None]:
# Load confidence data if available
confidence_file = os.path.join(root, "confidence.feather")

if not os.path.exists(confidence_file):
    print(f"Confidence file not found: {confidence_file}")
    confi_df = None
else:
    confi_df = pd.read_feather(confidence_file)
    # Each lemma cell holds a sequence of sequences (presumably numpy arrays as
    # read from feather); convert every inner item to a plain list.
    for lemma_column in ("sample_lemmas", "answer_lemmas"):
        confi_df[lemma_column] = confi_df[lemma_column].apply(lambda xs: [list(x) for x in xs])
    print(f"Loaded confidence data: {len(confi_df)} rows")
    print(f"Columns: {confi_df.columns.tolist()}")

In [None]:
# Analyze diversity for each prompt combination
droot = os.path.join(root, "diversity")

# Storage for metrics. These are created before the directory check so the
# correlation and summary cells can still run (and report "no data") when the
# diversity directory is missing.
results = []

ensemble_scores = []
consistency_scores = []
or_matches_scores = []
diversity_scores = []  # Original diversity (OR/AND match ratio, labelled "Jaccard")
bert_diversity_scores = []  # BERT-based semantic diversity, one entry per processed file
avg_match_scores = []
new_ratio_scores = []
moe_ratio_scores = []

results_df = pd.DataFrame(results)  # Placeholder; rebuilt after the files are processed


def _as_text_list(paraphrases):
    """Return one row's paraphrases as a list, or [] if the cell is not a sequence.

    ``pd.read_feather`` returns list-typed columns as numpy arrays, so
    ``np.ndarray`` is accepted alongside ``list``/``tuple``. Rejecting it would
    silently drop every paraphrase and pin BERT diversity to 0.0.
    """
    if isinstance(paraphrases, (list, tuple, np.ndarray)):
        return list(paraphrases)
    return []


def _build_confidence_lookup(confidence_df):
    """Map each paraphrase to ``(greedy_lemma, answer_lemmas)`` from its first row.

    A dict lookup replaces a full scan of ``confidence_df`` for every paraphrase.
    """
    lookup = {}
    for paraphrase, greedy_lemma, answer_lemmas in zip(
        confidence_df["paraphrase"].tolist(),
        confidence_df["greedy_lemma"].tolist(),
        confidence_df["answer_lemmas"].tolist(),
    ):
        # setdefault keeps the first occurrence of a duplicated paraphrase.
        lookup.setdefault(paraphrase, (greedy_lemma, answer_lemmas))
    return lookup


def _aligned_predictions(df, lookup):
    """Collect index-aligned [predict1, predict2, ensemble_predict, answer_lemmas] lists.

    A row contributes only when it has exactly two paraphrases (presumably one
    per prompt of the pair, since file names are filtered to a single comma)
    and both are found in ``lookup``. Rows failing either test are skipped as a
    whole, so the four lists stay aligned when zipped together.
    """
    predict_by_set = [[], [], [], []]
    for paraphrases, predict_lemma in zip(df["paraphrases"].tolist(), df["predict_lemma"].tolist()):
        texts = _as_text_list(paraphrases)
        if len(texts) != 2:
            continue
        matched = [lookup.get(text) for text in texts]
        if any(entry is None for entry in matched):
            continue
        predict_by_set[0].append(matched[0][0].tolist())
        predict_by_set[1].append(matched[1][0].tolist())
        predict_by_set[2].append(predict_lemma)
        # Answer lemmas are taken from the second paraphrase's confidence row.
        predict_by_set[3].append(matched[1][1])
    return predict_by_set


if not os.path.exists(droot):
    print(f"Diversity directory not found: {droot}")
    print("Please generate diversity data first using the appropriate script.")
else:
    print(f"Processing diversity data from: {droot}")

    # Sorted so the processing and report order is reproducible across runs.
    file_list = sorted(
        fn for fn in os.listdir(droot) if fn.endswith(".feather") and len(fn.split(",")) == 2
    )

    if len(file_list) == 0:
        print("No diversity files found matching the pattern (files with exactly one comma).")
    else:
        print(f"Found {len(file_list)} diversity files to process")

    confi_lookup = _build_confidence_lookup(confi_df) if confi_df is not None else None

    for fn in tqdm(file_list, desc="Processing diversity files"):
        df = pd.read_feather(os.path.join(droot, fn))
        df["answer_lemmas"] = df["answer_lemmas"].apply(lambda xs: [list(x) for x in xs])

        # Compute partial match scores
        scores = partial_match_scores(df['predict_lemma'].tolist(), df["answer_lemmas"].tolist())

        has_paraphrases = "paraphrases" in df.columns

        # Extract paraphrases for diversity calculation
        all_paraphrases = []
        if has_paraphrases:
            for paraphrases in df["paraphrases"].tolist():
                all_paraphrases.extend(_as_text_list(paraphrases))

        # Compute BERT-based diversity for this prompt combination.
        # NOTE(review): this pools the paraphrases of every question in the file,
        # so it also reflects variation *between* questions, not only between prompts.
        bert_diversity = compute_diversity_score(all_paraphrases) if len(all_paraphrases) >= 2 else 0.0
        bert_diversity_scores.append(bert_diversity)

        predict_by_set = None
        if confi_lookup is not None and has_paraphrases:
            predict_by_set = _aligned_predictions(df, confi_lookup)

        if predict_by_set is None or len(predict_by_set[0]) == 0:
            # No traditional metrics for this file; still record its BERT score
            # so the file appears in the results table.
            results.append({
                'file': fn,
                'bert_diversity': bert_diversity,
                'num_paraphrases': len(all_paraphrases)
            })
            continue

        consistency_matches = []
        or_matches = []
        and_matches = []
        avg_matches = []
        ensemble_matches = []

        for predict1, predict2, ensemble_predict, answer_lemmas in zip(*predict_by_set):
            match1 = partial_match(predict1, answer_lemmas, birdirectional=False)
            match2 = partial_match(predict2, answer_lemmas, birdirectional=False)
            or_matches.append(match1 or match2)
            and_matches.append(match1 and match2)
            avg_matches.append(float(int(match1) + int(match2)) / 2)
            ensemble_matches.append(partial_match(ensemble_predict, answer_lemmas, birdirectional=False))
            consistency_matches.append(is_matched_str(predict1, predict2, birdirectional=True))

        n_rows = len(and_matches)  # > 0 here, since predict_by_set[0] is non-empty
        flags = list(zip(and_matches, or_matches, ensemble_matches))
        # "moe": ensemble correct where exactly one of the two prompts was correct.
        moe_cnt = sum(1 for and_m, or_m, ens_m in flags if not and_m and or_m and ens_m)
        # "new": ensemble correct where neither prompt was correct.
        new_cnt = sum(1 for and_m, or_m, ens_m in flags if not and_m and not or_m and ens_m)

        moe_ratio = moe_cnt / n_rows
        new_ratio = new_cnt / n_rows

        moe_ratio_scores.append(moe_ratio)
        new_ratio_scores.append(new_ratio)

        ensemble_scores.append(scores)
        consistency_scores.append(sum(consistency_matches) / n_rows)
        or_matches_scores.append(sum(or_matches) / n_rows)

        # Original diversity (OR/AND ratio, labelled "Jaccard")
        jaccard_diversity = sum(or_matches) / sum(and_matches) if sum(and_matches) > 0 else 0
        diversity_scores.append(jaccard_diversity)

        avg_match_scores.append(sum(avg_matches) / n_rows)

        # Store result for this file
        results.append({
            'file': fn,
            'ensemble_score': scores,
            'consistency_score': consistency_scores[-1],
            'or_match_score': or_matches_scores[-1],
            'jaccard_diversity': jaccard_diversity,
            'bert_diversity': bert_diversity,
            'avg_match_score': avg_match_scores[-1],
            'moe_ratio': moe_ratio,
            'new_ratio': new_ratio,
            'num_paraphrases': len(all_paraphrases)
        })

    # Create results dataframe
    results_df = pd.DataFrame(results)
    print("\n" + "="*80)
    print("DIVERSITY ANALYSIS RESULTS")
    print("="*80)
    print(results_df.to_string())
    print("\n" + "="*80)

In [None]:
# Correlation analysis
if len(ensemble_scores) > 1 and len(bert_diversity_scores) > 0:
    print("\nCORRELATION ANALYSIS")
    print("="*80)

    ensemble_series = pd.Series(ensemble_scores)

    # bert_diversity_scores has one entry per processed file, while ensemble_scores
    # only covers files that produced traditional metrics. When the lengths differ,
    # pair the two metrics by file via results_df instead of skipping the
    # correlation altogether.
    bert_corr = None
    if len(ensemble_scores) == len(bert_diversity_scores):
        bert_corr = ensemble_series.corr(pd.Series(bert_diversity_scores))
    elif 'ensemble_score' in results_df.columns:
        aligned = results_df.dropna(subset=['ensemble_score', 'bert_diversity'])
        bert_corr = aligned['ensemble_score'].corr(aligned['bert_diversity'])
    if bert_corr is not None:
        print(f"Pearson correlation between ensemble and BERT diversity scores: {bert_corr:.4f}")

    # These lists are appended in lockstep with ensemble_scores, so positional pairing is valid.
    for label, values in (
        ("consistency", consistency_scores),
        ("OR match", or_matches_scores),
        ("Jaccard diversity", diversity_scores),
        ("avg match", avg_match_scores),
    ):
        if len(values) > 0:
            print(f"Pearson correlation between ensemble and {label} scores: {ensemble_series.corr(pd.Series(values)):.4f}")

    if len(moe_ratio_scores) > 0:
        print(f"\nMean of MOE ratio scores: {sum(moe_ratio_scores)/len(moe_ratio_scores):.4f}")

    if len(new_ratio_scores) > 0:
        print(f"Mean of new ratio scores: {sum(new_ratio_scores)/len(new_ratio_scores):.4f}")

    if len(bert_diversity_scores) > 0:
        print(f"\nMean BERT diversity score: {np.mean(bert_diversity_scores):.4f}")
        print(f"Std BERT diversity score: {np.std(bert_diversity_scores):.4f}")

    print("="*80)
else:
    print("\nInsufficient data for correlation analysis.")

In [None]:
# Summary of the run: configuration plus BERT-diversity extremes per prompt combination
print("\n" + "=" * 80)
print("SUMMARY")
print("=" * 80)
for summary_label, summary_value in (
    ("Total files processed", len(results)),
    ("BERT model used", model_name),
    ("Device", device),
):
    print(f"{summary_label}: {summary_value}")

if results:
    bert_column = results_df['bert_diversity']
    print(f"\nAverage BERT diversity across all prompt combinations: {bert_column.mean():.4f}")
    print(f"Min BERT diversity: {bert_column.min():.4f}")
    print(f"Max BERT diversity: {bert_column.max():.4f}")

    # Five highest and five lowest BERT-diversity prompt combinations
    shown_columns = ['file', 'bert_diversity', 'num_paraphrases']
    for extreme, pick in (("most", results_df.nlargest), ("least", results_df.nsmallest)):
        print(f"\nTop 5 {extreme} diverse prompt combinations (by BERT diversity):")
        print(pick(5, 'bert_diversity')[shown_columns].to_string(index=False))

print("\n" + "=" * 80)
print("Analysis complete!")