# FlexAttention Generation Results Analysis

This notebook analyzes the results of FlexAttention-based ensemble generation and compares its accuracy with the traditional ensemble methods (`avg`, `max`, `weighted_avg`, `weighted_max`).

## Setup

In [None]:
import os
import sys
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Add parent directory to path
sys.path.insert(0, '..')
from utils import partial_match, partial_match_scores

# Plot styling shared by every figure in this notebook.
sns.set_style('whitegrid')
plt.rcParams.update({'figure.figsize': (12, 6)})

## Configuration

Set the dataset, model, and number of paraphrases to analyze:

In [None]:
# Run configuration -- edit these to point the notebook at another result set.
DATASET = "myriadlama"   # alternative: "webqa"
MODEL = "qwen2.5_7b_it"  # e.g. "llama3.2_3b_it", "qwen2.5_7b_it"
NUM_PARAPHRASES = 5      # paraphrase count embedded in the result file names

# Directory holding every result file for this dataset/model pair.
root = "../datasets/{}/{}".format(DATASET, MODEL)
for label, value in (("Dataset root", root), ("Exists", os.path.exists(root))):
    print(f"{label}: {value}")

## Load FlexAttention Results

In [None]:
# Load the FlexAttention generation results for the configured run.
# If the file is missing, fall back to an empty DataFrame so the downstream
# cells degrade to "not available" messages instead of raising NameError on
# `df_flex` (or silently reusing a frame left over from an earlier run).
flex_file = os.path.join(root, f"flex_attention-{NUM_PARAPHRASES}.feather")

if os.path.exists(flex_file):
    df_flex = pd.read_feather(flex_file)
    print(f"✅ Loaded FlexAttention results: {len(df_flex)} samples")
    print(f"\nColumns: {df_flex.columns.tolist()}")
    print("\nFirst few rows:")
    display(df_flex.head())
else:
    df_flex = pd.DataFrame()
    print(f"❌ FlexAttention results not found: {flex_file}")
    print(f"   Run: python flex_attention_generate.py --dataset {DATASET} --model {MODEL} --num_paraphrases {NUM_PARAPHRASES}")

## Compute Accuracy

In [None]:
# Score the FlexAttention predictions against the lemmatized gold answers.
# Requires the `predict_lemma` / `answer_lemmas` columns; without them only a
# hint on how to regenerate the results is printed.
if "predict_lemma" in df_flex.columns and "answer_lemmas" in df_flex.columns:
    # Feather round-trips nested list columns as numpy arrays (of arrays), so
    # normalise any list-like cell -- not only Python lists -- to a list of lists.
    df_flex["answer_lemmas"] = df_flex["answer_lemmas"].apply(
        lambda xs: [list(x) for x in xs] if pd.api.types.is_list_like(xs) else xs
    )

    answers = df_flex["answer_lemmas"].tolist()
    predictions = df_flex["predict_lemma"].tolist()

    # Aggregate accuracy as computed by utils.partial_match_scores.
    flex_acc = partial_match_scores(predictions, answers)
    print(f"FlexAttention Accuracy: {flex_acc:.3f}")

    # Per-sample match flags, stored on the frame for the error-analysis cell.
    matches = [partial_match(pred, ans, False) for pred, ans in zip(predictions, answers)]
    df_flex["correct"] = matches

    n_correct = sum(matches)
    n_total = len(matches)
    # Guard against an empty result file (would otherwise divide by zero).
    pct = n_correct / n_total * 100 if n_total else 0.0
    print(f"Correct predictions: {n_correct}/{n_total} ({pct:.1f}%)")
else:
    print("⚠️  Lemmatized results not available")
    print(f"   Run: python flex_attention_generate.py --dataset {DATASET} --lemmaize")

## Sample Generations

Inspect the first few generations together with their gold answers and predictions:

In [None]:
# Preview the first few FlexAttention samples: gold answers, the extracted
# prediction, the correctness flag (when the accuracy cell has added it) and
# the first 150 characters of the raw generation.
print("Sample Generations:")
print("=" * 70)

preview = df_flex.head(5)
for pos in range(len(preview)):
    row = preview.iloc[pos]
    print(f"\nSample {pos + 1}:")
    print(f"  UUID: {row['uuid']}")
    print(f"  Answer: {row['answers']}")
    print(f"  Prediction: {row['prediction']}")
    if 'correct' in df_flex.columns:
        mark = '✓' if row['correct'] else '✗'
        print(f"  Correct: {mark}")
    print(f"  Generation: {row['generation'][:150]}...")
    print("-" * 70)

## Compare with Traditional Ensemble Methods

In [None]:
# Score each traditional ensemble method for the same paraphrase count and
# collect the accuracies (plus FlexAttention, if it was scored) in `results`.
methods = ["avg", "max", "weighted_avg", "weighted_max"]
results = {}

for method in methods:
    ensemble_file = os.path.join(root, f"ensemble_{method}-{NUM_PARAPHRASES}.feather")

    # Report skipped methods instead of dropping them silently.
    if not os.path.exists(ensemble_file):
        print(f"{method:15s}: (not found: {ensemble_file})")
        continue

    df_ensemble = pd.read_feather(ensemble_file)
    if "predict_lemma" not in df_ensemble.columns or "answer_lemmas" not in df_ensemble.columns:
        print(f"{method:15s}: (no lemmatized columns)")
        continue

    # Feather returns nested list columns as numpy arrays; normalise to lists.
    df_ensemble["answer_lemmas"] = df_ensemble["answer_lemmas"].apply(
        lambda xs: [list(x) for x in xs] if pd.api.types.is_list_like(xs) else xs
    )
    # Distinct names so this cell does not overwrite the FlexAttention
    # `answers` / `predictions` globals from the accuracy cell.
    ens_answers = df_ensemble["answer_lemmas"].tolist()
    ens_predictions = df_ensemble["predict_lemma"].tolist()

    acc = partial_match_scores(ens_predictions, ens_answers)
    results[method] = acc
    print(f"{method:15s}: {acc:.3f}")

# Add FlexAttention results; `flex_acc` only exists if the accuracy cell found
# lemmatized columns (at notebook top level, locals() is the global namespace).
if 'flex_acc' in locals():
    results['flex_attention'] = flex_acc
    print(f"{'flex_attention':15s}: {flex_acc:.3f}")

print("\nResults summary:")
for method, acc in sorted(results.items(), key=lambda x: x[1], reverse=True):
    print(f"  {method:20s}: {acc:.3f}")

## Visualization: Method Comparison

In [None]:
# Bar chart of accuracy per method; FlexAttention is highlighted in orange.
if results:
    fig, ax = plt.subplots(figsize=(10, 6))

    methods_list = list(results.keys())
    accs_list = list(results.values())

    colors = ['orange' if m == 'flex_attention' else 'skyblue' for m in methods_list]
    bars = ax.bar(methods_list, accs_list, color=colors, alpha=0.8)

    # Annotate each bar with its accuracy value.
    for bar, acc in zip(bars, accs_list):
        ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height(),
                f'{acc:.3f}',
                ha='center', va='bottom')

    ax.set_ylabel('Accuracy')
    ax.set_title(f'Ensemble Method Comparison ({NUM_PARAPHRASES} paraphrases)')
    # 10% headroom for the labels; fall back to 1.0 when every accuracy is 0,
    # which would otherwise request a singular [0, 0] y-range.
    ax.set_ylim([0, max(accs_list) * 1.1 or 1.0])
    # Rotate this Axes' tick labels explicitly rather than via pyplot's current axes.
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    fig.tight_layout()
    plt.show()
else:
    print("No results to plot")

## Effect of Number of Paraphrases

Compare FlexAttention accuracy across different numbers of paraphrases (2–10; only result files that exist are included):

In [None]:
# Score every flex_attention-{n}.feather that exists for n in PARAPHRASE_COUNTS.
# Cell-local names (sweep_file, df_sweep) avoid overwriting `flex_file` and the
# FlexAttention `answers` / `predictions` globals from earlier cells.
PARAPHRASE_COUNTS = range(2, 11)  # n = 2..10; missing files are skipped
paraphrase_results = {}

for n in PARAPHRASE_COUNTS:
    sweep_file = os.path.join(root, f"flex_attention-{n}.feather")
    if not os.path.exists(sweep_file):
        continue

    df_sweep = pd.read_feather(sweep_file)
    if "predict_lemma" not in df_sweep.columns or "answer_lemmas" not in df_sweep.columns:
        continue

    # Feather returns nested list columns as numpy arrays; normalise to lists.
    df_sweep["answer_lemmas"] = df_sweep["answer_lemmas"].apply(
        lambda xs: [list(x) for x in xs] if pd.api.types.is_list_like(xs) else xs
    )
    paraphrase_results[n] = partial_match_scores(
        df_sweep["predict_lemma"].tolist(), df_sweep["answer_lemmas"].tolist()
    )

if paraphrase_results:
    ns = sorted(paraphrase_results)
    accs = [paraphrase_results[n] for n in ns]

    print("FlexAttention accuracy vs. number of paraphrases:")
    for n, acc in zip(ns, accs):
        print(f"  {n} paraphrases: {acc:.3f}")

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(ns, accs, marker='o', linewidth=2, markersize=8)
    ax.set_xlabel('Number of Paraphrases')
    ax.set_ylabel('Accuracy')
    ax.set_title('FlexAttention: Effect of Number of Paraphrases')
    ax.grid(True, alpha=0.3)
    ax.set_xticks(ns)
    fig.tight_layout()
    plt.show()

    # Ties resolve to the smallest n (first maximum in ascending insertion order).
    best_n, best_acc = max(paraphrase_results.items(), key=lambda x: x[1])
    print(f"\nBest: {best_n} paraphrases with accuracy {best_acc:.3f}")
else:
    print("No results for different numbers of paraphrases")
    print("Generate them with: python flex_attention_generate.py --num_paraphrases N")
    print("Generate them with: python flex_attention_generate.py --num_paraphrases N")

## Error Analysis

Inspect samples whose FlexAttention prediction was marked incorrect:

In [None]:
# Error analysis: list samples that the accuracy cell flagged as incorrect.
if 'correct' in df_flex.columns:
    # `~mask` rather than `== False` (flake8 E712); astype(bool) keeps the
    # selection identical for 0/1 or other numeric match flags.
    incorrect = df_flex[~df_flex['correct'].astype(bool)]

    print(f"Incorrect predictions: {len(incorrect)} / {len(df_flex)}")
    print("\nSample incorrect predictions:")
    print("=" * 70)

    for i in range(min(5, len(incorrect))):
        row = incorrect.iloc[i]
        print(f"\nExample {i + 1}:")
        print(f"  UUID: {row['uuid']}")
        print(f"  Expected: {row['answers']}")
        print(f"  Predicted: {row['prediction']}")
        print(f"  Generation: {row['generation'][:150]}...")
        print("-" * 70)
else:
    print("Correctness information not available")

## Summary Statistics

In [None]:
# Summary of the configured run and how FlexAttention compares with the
# traditional ensembles scored in the comparison cell (`results`).
print("=" * 70)
print("SUMMARY")
print("=" * 70)
print(f"\nDataset: {DATASET}")
print(f"Model: {MODEL}")
print(f"Number of paraphrases: {NUM_PARAPHRASES}")
print(f"Total samples: {len(df_flex)}")

if 'flex_acc' in locals():
    print(f"\nFlexAttention accuracy: {flex_acc:.3f}")

    traditional = (
        {m: a for m, a in results.items() if m != 'flex_attention'} if results else {}
    )
    if traditional:
        avg_traditional = sum(traditional.values()) / len(traditional)
        print(f"Average traditional ensemble: {avg_traditional:.3f}")
        improvement = flex_acc - avg_traditional
        # The relative change is undefined when the baseline average is 0.
        if avg_traditional:
            print(f"Improvement: {improvement:+.3f} ({improvement / avg_traditional * 100:+.1f}%)")
        else:
            print(f"Improvement: {improvement:+.3f} (relative change undefined: average is 0)")

        # An average baseline can hide a strong single method; report the best too.
        best_method, best_acc = max(traditional.items(), key=lambda kv: kv[1])
        print(f"Best traditional ensemble: {best_method} ({best_acc:.3f}); "
              f"FlexAttention vs. best: {flex_acc - best_acc:+.3f}")

print("\n" + "=" * 70)