## 1. Setup and Imports

In [1]:
import os
import sys
import pandas as pd
import numpy as np
from pathlib import Path
from glob import glob
import warnings
warnings.filterwarnings('ignore')

# Add parent directory to path
sys.path.insert(0, str(Path.cwd().parent))
from utils import partial_match, partial_match_scores

## 2. Configuration

In [13]:
# Configure your dataset path here.
# BASE_DIR is derived from MODEL_NAME, so changing the model name is enough to
# point the notebook at another model's results. The data root can be
# overridden with the MYRIADLAMA_DATA_ROOT environment variable instead of
# editing this cell.
MODEL_NAME = "llama3.1_8b"  # Change to your model name
DATA_ROOT = os.environ.get(
    "MYRIADLAMA_DATA_ROOT",
    "/net/tokyo100-10g/data/str01_01/y-guo/datasets/myriadlama",
)
BASE_DIR = f"{DATA_ROOT.rstrip('/')}/{MODEL_NAME}/"

# Check if path exists
if not os.path.exists(BASE_DIR):
    print(f"‚ö†Ô∏è  Warning: Base directory not found: {BASE_DIR}")
    print("   Please update MODEL_NAME / DATA_ROOT (or set MYRIADLAMA_DATA_ROOT)")
else:
    print(f"‚úÖ Base directory found: {BASE_DIR}")

‚úÖ Base directory found: /net/tokyo100-10g/data/str01_01/y-guo/datasets/myriadlama/llama3.1_8b/


## 3. Lemmatization Setup

In [14]:
# Import spacy for lemmatization. If spacy or its model cannot be loaded,
# `nlp` is set to None and lemmatize_text falls back to lowercase
# whitespace splitting.
try:
    import spacy
    nlp = spacy.load("en_core_web_lg")
    print("‚úÖ Spacy loaded successfully")
except ImportError:
    print("‚ö†Ô∏è  Spacy not found. Install with: pip install spacy")
    print("   Then download model: python -m spacy download en_core_web_lg")
    nlp = None
except OSError:
    print("‚ö†Ô∏è  Spacy model not found. Download with: python -m spacy download en_core_web_lg")
    nlp = None

def lemmatize_text(text):
    """
    Lemmatize a text string.

    Returns a list of lowercase lemmas, one per spacy token, or the lowercase
    whitespace-separated words of ``text`` when spacy is unavailable.
    """
    if nlp is None:
        return text.lower().split()
    doc = nlp(text)
    return [token.lemma_.lower() for token in doc]

def lemmatize_answers(answers):
    """
    Lemmatize a collection of gold answers.

    Args:
        answers: An iterable of answer strings, or a string. A string is first
            parsed with ``ast.literal_eval`` because list columns read back from
            CSV arrive as their repr (e.g. "['a', 'b']"). A string that is not a
            Python literal is treated as one answer.

    Returns:
        A list with one lemma list (as produced by ``lemmatize_text``) per answer.
    """
    import ast

    if isinstance(answers, str):
        try:
            parsed = ast.literal_eval(answers)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # Not a Python literal: keep the raw text as a single answer.
            parsed = answers
        if isinstance(parsed, (list, tuple)):
            answers = list(parsed)
        elif isinstance(parsed, str):
            # A quoted single answer such as "'Paris'": use the unquoted value
            # instead of iterating over its characters.
            answers = [parsed]
        else:
            # Numbers or other literals: keep the original text as one answer.
            answers = [answers]
    return [lemmatize_text(str(ans)) for ans in answers]

‚úÖ Spacy loaded successfully


## 4. Discover Result Files (Custom Attention, Baseline, FlexAttention)

In [7]:
def discover_custom_attention_files(base_dir):
    """
    Discover result files to compare in ``base_dir`` (non-recursive).

    Keys are built as:
      - ``custom_<N>paras`` for ``myriadlama_custom_<N>paras.csv``
      - ``baseline_<stem>`` for any ``*baseline*.csv`` / ``*baseline*.feather``
      - ``flex_<stem>`` for any ``*flex*.csv`` / ``*flex*.feather``

    Paths are visited in sorted order so the key order of the returned dict is
    reproducible between runs (``glob()`` order depends on the filesystem).
    If a .csv and a .feather share a stem, the .feather path wins.

    Returns:
        dict: Dictionary mapping method keys to file paths (empty when
        ``base_dir`` does not exist).
    """
    import re

    results = {}

    if not os.path.exists(base_dir):
        print(f"‚ùå Base directory not found: {base_dir}")
        return results

    def _sorted_glob(*patterns):
        # Union of several glob patterns inside base_dir, sorted for determinism.
        return sorted(
            path
            for pattern in patterns
            for path in glob(os.path.join(base_dir, pattern))
        )

    # Custom attention files (CSV), keyed by their number of paraphrases,
    # e.g. myriadlama_custom_1paras.csv -> custom_1paras
    for file_path in _sorted_glob("myriadlama_custom_*paras.csv"):
        match = re.search(r'custom_(\d+)paras', os.path.basename(file_path))
        if match:
            results[f"custom_{int(match.group(1))}paras"] = file_path

    # Baseline and flex attention files for comparison.
    for prefix in ("baseline", "flex"):
        for file_path in _sorted_glob(f"*{prefix}*.csv", f"*{prefix}*.feather"):
            # Path.stem strips only the final extension; str.replace would also
            # delete ".csv"/".feather" occurring in the middle of the name.
            results[f"{prefix}_{Path(file_path).stem}"] = file_path

    return results

# Discover all files
all_files = discover_custom_attention_files(BASE_DIR)

discovery_rule = '=' * 70
print("\n" + discovery_rule)
print("Discovered Files")
print(discovery_rule)
print(f"\nTotal files found: {len(all_files)}\n")

# Categorize by the key prefix assigned during discovery
custom_files, baseline_files, flex_files = (
    {key: path for key, path in all_files.items() if key.startswith(prefix)}
    for prefix in ('custom_', 'baseline_', 'flex_')
)

# Print each non-empty category, methods in alphabetical order
for heading, group in (
    ("üéØ Custom Attention Results:", custom_files),
    ("\nüìä Baseline Results:", baseline_files),
    ("\n‚ö° FlexAttention Results:", flex_files),
):
    if not group:
        continue
    print(heading)
    for method in sorted(group):
        print(f"   - {method}")

print("\n" + discovery_rule)


Discovered Files

Total files found: 5

üéØ Custom Attention Results:
   - custom_1paras
   - custom_5paras

üìä Baseline Results:
   - baseline_baseline_per_prompt

‚ö° FlexAttention Results:
   - flex_myriadlama_flex_1paras
   - flex_myriadlama_flex_5paras



## 5. Load Results

In [8]:
def load_file(file_path):
    """
    Load a result file (CSV or Feather) into a DataFrame.

    Args:
        file_path: ``str`` or path-like object. The reader is chosen from the
            file extension, compared case-insensitively.

    Raises:
        ValueError: If the extension is neither ``.csv`` nor ``.feather``.
    """
    # Path() accepts both str and PathLike; str.endswith would fail on Path.
    suffix = Path(file_path).suffix.lower()
    if suffix == '.csv':
        return pd.read_csv(file_path)
    if suffix == '.feather':
        return pd.read_feather(file_path)
    raise ValueError(f"Unsupported file format: {file_path}")

def load_all_files(all_files):
    """
    Read every discovered result file with ``load_file``.

    Args:
        all_files: dict mapping method key -> file path.

    Returns:
        dict mapping method key -> DataFrame. A file that fails to load is
        reported and left out instead of aborting the whole run.
    """
    frames = {}

    print("\nLoading all files...")
    for method, path in all_files.items():
        try:
            frame = load_file(path)
            frames[method] = frame
            print(f"‚úÖ Loaded {method}: {len(frame)} samples")
        except Exception as err:
            print(f"‚ùå Error loading {method}: {err}")

    return frames

# Load all files
loaded_data = load_all_files(all_files)


Loading all files...
‚úÖ Loaded custom_5paras: 2000 samples
‚úÖ Loaded custom_1paras: 2000 samples
‚úÖ Loaded baseline_baseline_per_prompt: 10000 samples
‚úÖ Loaded flex_myriadlama_flex_5paras: 2000 samples
‚úÖ Loaded flex_myriadlama_flex_1paras: 100 samples


## 6. Apply Lemmatization

In [15]:
def apply_lemmatization_to_df(df):
    """
    Return a copy of ``df`` with normalized lemma columns.

    - Existing ``predict_lemma`` / ``answer_lemmas`` cells stored as list reprs
      (e.g. "['paris']", as produced by a CSV round-trip) are parsed back into
      Python lists, so every frame reaches the matching code in the same format
      as freshly lemmatized ones.
    - If spacy is unavailable (``nlp is None``), no new lemmas are computed.
    - Otherwise, missing ``predict_lemma`` / ``answer_lemmas`` columns are
      computed from ``prediction`` / ``answers``. Missing cells (None/NaN) and
      blank predictions become empty lists.
    """
    import ast

    def _is_missing(value):
        # pd.isna on a list/array returns an array, so only scalars are tested.
        return value is None or (pd.api.types.is_scalar(value) and pd.isna(value))

    def _parse_list_cell(value):
        # Turn a stringified list/tuple back into a list; leave anything else unchanged.
        if isinstance(value, str):
            try:
                parsed = ast.literal_eval(value)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                return value
            if isinstance(parsed, (list, tuple)):
                return list(parsed)
        return value

    df = df.copy()

    for col in ('predict_lemma', 'answer_lemmas'):
        if col in df.columns:
            df[col] = df[col].map(_parse_list_cell)

    if nlp is None:
        print("‚ö†Ô∏è  Spacy not available, skipping lemmatization")
        return df

    # Lemmatize predictions
    if 'prediction' in df.columns and 'predict_lemma' not in df.columns:
        print("   Lemmatizing predictions...")
        df['predict_lemma'] = df['prediction'].apply(
            lambda x: [] if _is_missing(x) or not str(x).strip() else lemmatize_text(str(x))
        )

    # Lemmatize answers
    if 'answers' in df.columns and 'answer_lemmas' not in df.columns:
        print("   Lemmatizing answers...")
        df['answer_lemmas'] = df['answers'].apply(
            lambda x: [] if _is_missing(x) else lemmatize_answers(x)
        )

    return df

# Apply lemmatization to all loaded data, replacing each frame with its
# lemmatized copy
print("\nApplying lemmatization...")
for method_name in list(loaded_data):
    print(f"Processing {method_name}...")
    loaded_data[method_name] = apply_lemmatization_to_df(loaded_data[method_name])

print("\n‚úÖ Lemmatization complete")


Applying lemmatization...
Processing custom_5paras...
Processing custom_1paras...
Processing baseline_baseline_per_prompt...
   Lemmatizing predictions...
   Lemmatizing answers...
   Lemmatizing answers...
Processing flex_myriadlama_flex_5paras...
   Lemmatizing predictions...
Processing flex_myriadlama_flex_5paras...
   Lemmatizing predictions...
   Lemmatizing answers...
   Lemmatizing answers...
Processing flex_myriadlama_flex_1paras...
   Lemmatizing predictions...
Processing flex_myriadlama_flex_1paras...
   Lemmatizing predictions...
   Lemmatizing answers...
   Lemmatizing answers...

‚úÖ Lemmatization complete

‚úÖ Lemmatization complete


## 7. Calculate Accuracies

In [None]:
def calculate_accuracy(df):
    """
    Score a result DataFrame with ``partial_match_scores``.

    Returns:
        The score from ``partial_match_scores``, or None when either lemma
        column (``predict_lemma`` / ``answer_lemmas``) is absent or when scoring
        raises (the error is printed, not propagated).
    """
    required_columns = {'predict_lemma', 'answer_lemmas'}
    if not required_columns.issubset(df.columns):
        return None

    predicted = df['predict_lemma'].tolist()
    expected = df['answer_lemmas'].tolist()

    try:
        return partial_match_scores(predicted, expected)
    except Exception as e:
        print(f"Error calculating accuracy: {e}")
        return None

# Calculate accuracies; methods that cannot be scored are reported and left out
accuracies = {}

print("\nCalculating accuracies...")
for method_name, result_df in loaded_data.items():
    score = calculate_accuracy(result_df)
    if score is None:
        print(f"‚ö†Ô∏è  {method_name}: Could not calculate accuracy")
        continue
    accuracies[method_name] = score
    print(f"‚úÖ {method_name}: {score:.4f} ({score*100:.2f}%)")


Calculating accuracies...
‚úÖ custom_5paras: 0.5125 (51.25%)
‚úÖ custom_1paras: 0.5860 (58.60%)
‚ö†Ô∏è  baseline_baseline_per_prompt: Could not calculate accuracy
‚ö†Ô∏è  flex_myriadlama_flex_5paras: Could not calculate accuracy
‚ö†Ô∏è  flex_myriadlama_flex_1paras: Could not calculate accuracy


## 8. Generate Comparison Table

In [None]:
def generate_comparison_table(accuracies, loaded_data):
    """
    Build one row per scored method, sorted by accuracy (best first).

    Args:
        accuracies: dict mapping method key -> accuracy (multiplied by 100 for
            ``Accuracy_Pct``; presumably a fraction in [0, 1]).
        loaded_data: dict mapping method key -> result DataFrame. Used only
            for the sample count.

    Returns:
        DataFrame with columns Method, Category, Num_Paraphrases, Accuracy,
        Accuracy_Pct and Total_Samples. Category comes from the key prefix
        (custom_/baseline_/flex_, else 'Other'). Num_Paraphrases is parsed from
        an ``<N>paras`` part of the key for every category, so flex runs report
        it too, and is 'N/A' when absent. Empty input gives an empty frame that
        still has these columns.
    """
    import re

    columns = ['Method', 'Category', 'Num_Paraphrases', 'Accuracy', 'Accuracy_Pct', 'Total_Samples']
    category_by_prefix = (
        ('custom_', 'Custom Attention'),
        ('baseline_', 'Baseline'),
        ('flex_', 'FlexAttention'),
    )

    comparison_data = []
    for method, acc in accuracies.items():
        category = next(
            (label for prefix, label in category_by_prefix if method.startswith(prefix)),
            'Other',
        )
        match = re.search(r'(\d+)paras', method)
        comparison_data.append({
            'Method': method,
            'Category': category,
            'Num_Paraphrases': int(match.group(1)) if match else 'N/A',
            'Accuracy': acc,
            'Accuracy_Pct': acc * 100,
            'Total_Samples': len(loaded_data[method]),
        })

    # Passing columns explicitly keeps the schema (and sort_values) valid when empty.
    comparison_df = pd.DataFrame(comparison_data, columns=columns)

    # Sort by accuracy (descending)
    return comparison_df.sort_values('Accuracy', ascending=False)

def _print_banner(title, width=80):
    """Print ``title`` between two '=' rules, with a blank line before and after."""
    print("\n" + "=" * width)
    print(title)
    print("=" * width + "\n")


def _format_accuracy(record):
    """Render a table row's accuracy as e.g. '0.1234 (12.34%)'."""
    return f"{record['Accuracy']:.4f} ({record['Accuracy_Pct']:.2f}%)"


# Generate and display comparison table
if not accuracies:
    print("\n‚ö†Ô∏è  No accuracies available for comparison")
else:
    comparison_table = generate_comparison_table(accuracies, loaded_data)

    _print_banner("COMPARISON TABLE")
    display(comparison_table)

    # Statistical summary
    _print_banner("STATISTICAL SUMMARY")

    # Best method overall
    best_method = comparison_table.loc[comparison_table['Accuracy'].idxmax()]
    print(f"üèÜ Best Method Overall: {best_method['Method']}")
    print(f"   Accuracy: {_format_accuracy(best_method)}")
    print(f"   Category: {best_method['Category']}")

    # Best in each category: the table is already sorted by accuracy, so the
    # first row of each group (groups in order of first appearance) is its best
    for category, cat_df in comparison_table.groupby('Category', sort=False):
        best_in_cat = cat_df.iloc[0]
        print(f"\nüìä Best {category}: {best_in_cat['Method']}")
        print(f"   Accuracy: {_format_accuracy(best_in_cat)}")

    # Custom attention scaling analysis
    custom_results = comparison_table[comparison_table['Category'] == 'Custom Attention']
    if len(custom_results) > 1:
        _print_banner("CUSTOM ATTENTION SCALING ANALYSIS")
        print("Accuracy vs Number of Paraphrases:")
        for _, row in custom_results.sort_values('Num_Paraphrases').iterrows():
            print(f"   {row['Num_Paraphrases']} paraphrases: {_format_accuracy(row)}")

## 9. Detailed Examples from Custom Attention Results

In [None]:
def show_detailed_examples(method_name, df, num_examples=5, show_correct=True, show_incorrect=True):
    """
    Print the first ``num_examples`` rows of a result DataFrame in readable form.

    Args:
        method_name: Label printed in the section header.
        df: Result DataFrame. The columns uuid, templates, prediction, answers,
            predict_lemma and answer_lemmas are printed when present.
        num_examples: Maximum number of rows to print (taken from the top).
        show_correct, show_incorrect: When both lemma columns exist, keep only
            rows judged correct / incorrect by ``partial_match``. Both True (or
            both False) keeps every row. Ignored when correctness can't be checked.

    Rows whose ``partial_match`` call raised are shown with an UNKNOWN status
    rather than being reported as incorrect.
    """
    print(f"\n{'='*80}")
    print(f"DETAILED EXAMPLES: {method_name}")
    print(f"{'='*80}\n")

    # Check if we can determine correctness
    can_check = 'predict_lemma' in df.columns and 'answer_lemmas' in df.columns

    if can_check:
        df_copy = df.copy()
        correctness = []
        for pred, ans in zip(df_copy['predict_lemma'], df_copy['answer_lemmas']):
            try:
                correctness.append(partial_match(pred, ans))
            except Exception:
                # Best effort: mark as unknown instead of aborting the listing.
                correctness.append(None)
        df_copy['is_correct'] = correctness

        # "== True"/"== False" deliberately excludes the None (unknown) entries.
        if show_correct and not show_incorrect:
            filtered_df = df_copy[df_copy['is_correct'] == True]
        elif show_incorrect and not show_correct:
            filtered_df = df_copy[df_copy['is_correct'] == False]
        else:
            filtered_df = df_copy
    else:
        filtered_df = df

    sample_df = filtered_df.head(num_examples)

    for idx, (_, row) in enumerate(sample_df.iterrows(), 1):
        print(f"\n{'‚îÄ'*80}")
        print(f"Example {idx}/{len(sample_df)}")
        print(f"{'‚îÄ'*80}\n")

        if 'uuid' in row:
            print(f"UUID: {row['uuid']}")

        if 'templates' in row:
            templates = row['templates']
            # pd.notna() on a list/array returns an array whose truth value is
            # ambiguous, so only scalar cells are checked for missingness.
            if not pd.api.types.is_scalar(templates) or pd.notna(templates):
                print(f"\nTemplates: {templates}")

        if 'prediction' in row:
            print(f"\nPrediction: {row['prediction']}")

        if 'answers' in row:
            print(f"Correct Answers: {row['answers']}")

        if 'predict_lemma' in row:
            print(f"\nPrediction (lemma): {row['predict_lemma']}")

        if 'answer_lemmas' in row:
            print(f"Answer Lemmas: {row['answer_lemmas']}")

        if can_check and 'is_correct' in row:
            if pd.isna(row['is_correct']):
                status = "UNKNOWN (partial_match raised an error)"
            else:
                status = "‚úÖ CORRECT" if row['is_correct'] else "‚ùå INCORRECT"
            print(f"\nStatus: {status}")

    print(f"\n{'='*80}\n")

# Show examples from custom attention results.
# Use the first *loaded* custom method in sorted order. The choice is then
# reproducible (not tied to file-discovery order), and a custom file that
# failed to load does not silently suppress the whole section.
loaded_custom_methods = sorted(m for m in custom_files if m in loaded_data)
if loaded_custom_methods:
    method_name = loaded_custom_methods[0]
    show_detailed_examples(method_name, loaded_data[method_name], num_examples=5)
else:
    print("No custom attention files loaded")

## 10. Cross-Method Comparison

In [None]:
def compare_same_examples(methods, loaded_data, num_examples=3):
    """
    Compare the same examples (matched on ``uuid``) across different methods.

    Args:
        methods: Method keys to compare, in display order. Keys missing from
            ``loaded_data`` are skipped.
        loaded_data: dict mapping method key -> result DataFrame. Every compared
            frame must have a ``uuid`` column; otherwise only a warning is printed.
        num_examples: Maximum number of shared UUIDs to show.

    Only UUIDs present in every available method are shown. They are taken in
    sorted order, so the selection is reproducible across kernel restarts (set
    iteration order is not). Gold answers are printed once per example, from
    the first available method that has an ``answers`` column.
    """
    print(f"\n{'='*80}")
    print(f"CROSS-METHOD COMPARISON")
    print(f"{'='*80}\n")

    available = [method for method in methods if method in loaded_data]
    if not available:
        print("None of the requested methods are loaded; nothing to compare")
        return

    # Get common UUIDs
    common_uuids = None
    for method in available:
        df = loaded_data[method]
        if 'uuid' not in df.columns:
            print(f"‚ö†Ô∏è  Method {method} does not have UUID column")
            return
        uuids = set(df['uuid'].unique())
        common_uuids = uuids if common_uuids is None else common_uuids & uuids

    selected_uuids = sorted(common_uuids, key=str)[:num_examples]
    if not selected_uuids:
        print("No UUIDs are shared by all compared methods")
        return

    answer_method = next(
        (method for method in available if 'answers' in loaded_data[method].columns),
        None,
    )

    for idx, uuid in enumerate(selected_uuids, 1):
        print(f"\n{'='*80}")
        print(f"Example {idx}/{len(selected_uuids)} - UUID: {uuid}")
        print(f"{'='*80}\n")

        for method in available:
            df = loaded_data[method]
            # If a UUID occurs more than once, the first matching row is shown.
            row = df[df['uuid'] == uuid].iloc[0]

            print(f"\n{'‚îÄ'*40}")
            print(f"Method: {method}")
            print(f"{'‚îÄ'*40}\n")

            if 'prediction' in row:
                print(f"Prediction: {row['prediction']}")

            if 'predict_lemma' in row and 'answer_lemmas' in row:
                try:
                    is_correct = partial_match(row['predict_lemma'], row['answer_lemmas'])
                except Exception as e:
                    # Best effort: report the failure instead of silently skipping it.
                    print(f"Status: could not score ({e})")
                else:
                    status = "‚úÖ CORRECT" if is_correct else "‚ùå INCORRECT"
                    print(f"Status: {status}")

        # Show correct answer once
        if answer_method is not None:
            df = loaded_data[answer_method]
            row = df[df['uuid'] == uuid].iloc[0]
            print(f"\n{'‚îÄ'*40}")
            print(f"Correct Answers: {row['answers']}")
            print(f"{'‚îÄ'*40}")

# Compare custom attention with baseline/flex: take one representative per
# category (first by name, for reproducibility) instead of whatever happens to
# come first in loaded_data, which depends on file-discovery order.
if len(loaded_data) >= 2:
    methods_to_compare = []
    for prefix in ('custom_', 'baseline_', 'flex_'):
        candidates = sorted(m for m in loaded_data if m.startswith(prefix))
        if candidates:
            methods_to_compare.append(candidates[0])
    if len(methods_to_compare) < 2:
        # Fewer than two categories present: fall back to the first three methods by name.
        methods_to_compare = sorted(loaded_data)[:3]
    compare_same_examples(methods_to_compare, loaded_data, num_examples=3)
else:
    print("Need at least 2 methods to compare")

## 11. Summary Statistics

In [None]:
summary_rule = "=" * 80
print("\n" + summary_rule)
print("SUMMARY STATISTICS")
print(summary_rule + "\n")

print(f"Model: {MODEL_NAME}")
print(f"Base Directory: {BASE_DIR}\n")

print(f"Total files found: {len(all_files)}")
print(f"Successfully loaded: {len(loaded_data)}")
print(f"With accuracy metrics: {len(accuracies)}\n")

if accuracies:
    def _by_score(item):
        """Sort key for (method, accuracy) pairs."""
        return item[1]

    # (method, accuracy) tuples; ties resolve to the first method in dict order
    best_method = max(accuracies.items(), key=_by_score)
    worst_method = min(accuracies.items(), key=_by_score)
    avg_accuracy = sum(accuracies.values()) / len(accuracies)

    best_name, best_score = best_method
    worst_name, worst_score = worst_method
    print(f"üèÜ Best: {best_name} ({best_score:.4f})")
    print(f"üìâ Worst: {worst_name} ({worst_score:.4f})")
    print(f"üìä Average: {avg_accuracy:.4f}")
    print(f"üìà Range: {worst_score:.4f} - {best_score:.4f}")

print("\n" + summary_rule + "\n")