# Explainer Model Performance Analysis

This notebook analyzes the performance of different explainer models in generating explanations for SAE latent features. We compare models across multiple metrics including accuracy, F1 scores, token usage, and execution time.

## Analysis Overview

- **Accuracy Distribution**: Density plots showing accuracy distribution for each model
- **Mean Performance**: Bar charts comparing mean accuracy across models
- **Token Usage**: Analysis of computational efficiency and resource consumption
- **Performance Summary**: Comprehensive comparison tables

## 1. Setup and Configuration

Import required libraries, define model name mapping for prettier display names, and set up output directories.

In [1]:
import sys
import os
import json
from pathlib import Path
import pandas as pd
import numpy as np

# Add the parent directory to the path to import delphi modules
sys.path.append(str(Path.cwd().parent))

from delphi.log.result_analysis import (
    import_plotly,
    load_data
)

# Import plotly for plotting
px = import_plotly()

# Define model name mapping for prettier display - updated to match actual directory names
MODEL_NAME_MAPPING = {
    "gemma_3_4b_it_quantized_w4a16": "Gemma-3-4B-IT",
    "Qwen3_4B_quantized_w4a16": "Qwen3-4B", 
    "gemma_3_12b_it_quantized_w4a16": "Gemma-3-12B-IT",
    "gemma_3_27b_it_quantized_w4a16": "Gemma-3-27B-IT",
    "Qwen3_14B_quantized_w4a16": "Qwen3-14B",
}

def load_model_results(results_dir: Path, model_mapping: dict):
    """Load results and statistics for all explainer models."""
    model_results = {}
    model_stats = {}
    
    for exp_dir in results_dir.glob("*_explanation_comparison"):
        # Extract model name from directory
        model_key = exp_dir.name.replace("_explanation_comparison", "")
        display_name = model_mapping.get(model_key, model_key)
        
        scores_path = exp_dir / "scores"
        if scores_path.exists():
            # Load scoring results
            try:
                latents_path = exp_dir.parent / "cache_google_gemma-2-9b-it" / "latents"
                if not latents_path.exists():
                    latents_path = exp_dir / "latents"  # Fallback to local latents
                
                if latents_path.exists():
                    # Extract module names from the actual files
                    sample_score_dir = next(scores_path.iterdir())
                    sample_files = list(sample_score_dir.glob("*.txt"))
                    if sample_files:
                        # Extract module name from filename pattern (e.g., "layers.32_latent0.txt" -> "layers.32")
                        sample_filename = sample_files[0].stem
                        module_name = sample_filename.split('_latent')[0]
                        modules = [module_name]
                    else:
                        print(f"No score files found in {sample_score_dir}")
                        continue
                    
                    latent_df, counts = load_data(scores_path, latents_path, modules)
                    
                    # Calculate aggregate metrics similar to get_agg_metrics
                    processed_data = []
                    for score_type in latent_df["score_type"].unique():
                        score_subset = latent_df[latent_df["score_type"] == score_type]
                        
                        # Calculate metrics
                        accuracy = score_subset["correct"].mean()
                        
                        # Calculate F1, precision, recall
                        true_pos = ((score_subset["prediction"] == True) & (score_subset["activating"] == True)).sum()
                        false_pos = ((score_subset["prediction"] == True) & (score_subset["activating"] == False)).sum()
                        false_neg = ((score_subset["prediction"] == False) & (score_subset["activating"] == True)).sum()
                        
                        precision = true_pos / (true_pos + false_pos) if (true_pos + false_pos) > 0 else 0
                        recall = true_pos / (true_pos + false_neg) if (true_pos + false_neg) > 0 else 0
                        f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
                        
                        processed_data.append({
                            'score_type': score_type,
                            'accuracy': accuracy,
                            'f1_score': f1_score,
                            'precision': precision,
                            'recall': recall
                        })
                    
                    model_results[display_name] = {
                        'latent_df': latent_df,
                        'processed_df': pd.DataFrame(processed_data),
                        'counts': counts
                    }
                    
            except Exception as e:
                print(f"Error loading results for {model_key}: {e}")
                continue
        
        # Load explainer statistics
        stats_file = exp_dir / "explainer_stats.json"
        if stats_file.exists():
            try:
                with open(stats_file, 'r') as f:
                    stats = json.load(f)
                    model_stats[display_name] = stats
            except Exception as e:
                print(f"Error loading stats for {model_key}: {e}")
                model_stats[display_name] = None
        else:
            model_stats[display_name] = None
    
    return model_results, model_stats

# Set up directories
results_dir = Path.cwd().parent / "results"
visualizations_dir = results_dir / "visualizations"
visualizations_dir.mkdir(exist_ok=True, parents=True)

print(f"Results directory: {results_dir}")
print(f"Visualizations output: {visualizations_dir}")
print(f"Available result directories:")
for d in results_dir.glob("*_explanation_comparison"):
    print(f"  - {d.name}")

Results directory: /home/jeremias/projects/delphi-explanations/results
Visualizations output: /home/jeremias/projects/delphi-explanations/results/visualizations
Available result directories:
  - gemma_3_4b_it_quantized_w4a16_explanation_comparison
  - Qwen3_4B_quantized_w4a16_explanation_comparison
  - gemma_3_12b_it_quantized_w4a16_explanation_comparison
  - gemma_3_27b_it_quantized_w4a16_explanation_comparison
  - Qwen3_14B_quantized_w4a16_explanation_comparison


## 2. Load Model Results

Load explanation comparison results from all models and extract performance metrics.

In [2]:
# Load all model results and statistics
print("Loading model results...")
model_results, model_stats = load_model_results(results_dir, MODEL_NAME_MAPPING)

print(f"\nLoaded results for {len(model_results)} models:")
for model_name in model_results.keys():
    print(f"  - {model_name}")

print(f"\nToken usage statistics available for {len([k for k, v in model_stats.items() if v is not None])} models:")
for model_name, stats in model_stats.items():
    if stats:
        print(f"  - {model_name}: {list(stats.keys())}")
    else:
        print(f"  - {model_name}: No stats available")

# Display sample metrics for the first model
if model_results:
    sample_model = list(model_results.keys())[0]
    sample_data = model_results[sample_model]['processed_df']
    print(f"\nSample metrics from {sample_model}:")
    print(sample_data[['score_type', 'accuracy', 'f1_score', 'precision', 'recall']].round(3))

Loading model results...

Loaded results for 5 models:
  - Gemma-3-4B-IT
  - Qwen3-4B
  - Gemma-3-12B-IT
  - Gemma-3-27B-IT
  - Qwen3-14B

Token usage statistics available for 5 models:
  - Gemma-3-4B-IT: ['DefaultExplainer']
  - Qwen3-4B: ['DefaultExplainer']
  - Gemma-3-12B-IT: ['DefaultExplainer']
  - Gemma-3-27B-IT: ['DefaultExplainer']
  - Qwen3-14B: ['DefaultExplainer']

Sample metrics from Gemma-3-4B-IT:
  score_type  accuracy  f1_score  precision  recall
0       fuzz     0.501     0.624      0.500   0.828
1  detection     0.539     0.581      0.531   0.641

Loaded results for 5 models:
  - Gemma-3-4B-IT
  - Qwen3-4B
  - Gemma-3-12B-IT
  - Gemma-3-27B-IT
  - Qwen3-14B

Token usage statistics available for 5 models:
  - Gemma-3-4B-IT: ['DefaultExplainer']
  - Qwen3-4B: ['DefaultExplainer']
  - Gemma-3-12B-IT: ['DefaultExplainer']
  - Gemma-3-27B-IT: ['DefaultExplainer']
  - Qwen3-14B: ['DefaultExplainer']

Sample metrics from Gemma-3-4B-IT:
  score_type  accuracy  f1_score  preci

## 3. Generate Accuracy Distribution Plots

Create density plots showing accuracy distribution for each model and score type.

In [3]:
# Generate inline density plots for accuracy distribution and save to files
print("Generating accuracy distribution plots...")

# Display the data being plotted
all_data = []
for model_name, data in model_results.items():
    for score_type in data['processed_df']['score_type'].unique():
        score_data = data['processed_df'][data['processed_df']['score_type'] == score_type].iloc[0]
        all_data.append({
            'model': model_name,  # model_name is already mapped to display name
            'score_type': score_type,
            'accuracy': score_data['accuracy'],
            'f1_score': score_data['f1_score'],
            'precision': score_data['precision'],
            'recall': score_data['recall']
        })

accuracy_df = pd.DataFrame(all_data)
print("\nAccuracy data summary:")
print(accuracy_df.groupby('score_type')[['accuracy', 'f1_score']].describe().round(3))

# Create violin plots for each score type - both inline and saved
for score_type in accuracy_df['score_type'].unique():
    score_df = accuracy_df[accuracy_df['score_type'] == score_type]
    
    # Create inline plot
    fig = px.violin(
        score_df, 
        x='model', 
        y='accuracy',
        title=f'Accuracy Distribution by Model - {score_type.title()}',
        points="all"
    )
    fig.update_layout(
        yaxis_range=[0, 1],
        xaxis_title="Model",
        yaxis_title="Accuracy",
        xaxis={'tickangle': 45},
        height=500
    )
    
    # Show inline
    fig.show()
    
    # Save to file (PDF only)
    output_file = visualizations_dir / f"accuracy_density_{score_type}.pdf"
    fig.write_image(str(output_file))
    print(f"Saved density plot: {output_file}")

print(f"\nDensity plots saved to {visualizations_dir}")

Generating accuracy distribution plots...

Accuracy data summary:
           accuracy                                                  f1_score  \
              count   mean    std    min    25%    50%    75%    max    count   
score_type                                                                      
detection       5.0  0.595  0.035  0.539  0.579  0.612  0.619  0.623      5.0   
fuzz            5.0  0.591  0.059  0.501  0.561  0.620  0.636  0.637      5.0   

                                                             
             mean    std    min    25%    50%    75%    max  
score_type                                                   
detection   0.629  0.036  0.581  0.601  0.641  0.651  0.669  
fuzz        0.606  0.059  0.511  0.598  0.624  0.629  0.668  




Support for Kaleido versions less than 1.0.0 is deprecated and will be removed after September 2025.
Please upgrade Kaleido to version 1.0.0 or greater (`pip install 'kaleido>=1.0.0'` or `pip install 'plotly[kaleido]'`).




Saved density plot: /home/jeremias/projects/delphi-explanations/results/visualizations/accuracy_density_fuzz.pdf


Saved density plot: /home/jeremias/projects/delphi-explanations/results/visualizations/accuracy_density_detection.pdf

Density plots saved to /home/jeremias/projects/delphi-explanations/results/visualizations




Support for Kaleido versions less than 1.0.0 is deprecated and will be removed after September 2025.
Please upgrade Kaleido to version 1.0.0 or greater (`pip install 'kaleido>=1.0.0'` or `pip install 'plotly[kaleido]'`).




## 4. Generate Mean Accuracy Bar Charts

Create bar charts displaying mean accuracy for each model and score type.

In [4]:
# Generate inline bar charts for mean accuracy and save to files
print("Generating mean accuracy bar charts...")

# Create inline bar charts for each score type
for score_type in accuracy_df['score_type'].unique():
    score_df = accuracy_df[accuracy_df['score_type'] == score_type].sort_values('accuracy', ascending=False)
    
    # Create inline bar chart
    fig = px.bar(
        score_df,
        x='model',
        y='accuracy',
        title=f'Mean Accuracy by Model - {score_type.title()}',
        text='accuracy'
    )
    fig.update_layout(
        yaxis_range=[0, 1],
        xaxis_title="Model",
        yaxis_title="Accuracy",
        xaxis={'tickangle': 45},
        height=500
    )
    fig.update_traces(texttemplate='%{text:.3f}', textposition='outside')
    
    # Show inline
    fig.show()
    
    # Save to file (PDF only)
    output_file = visualizations_dir / f"accuracy_bar_{score_type}.pdf"
    fig.write_image(str(output_file))
    print(f"Saved bar chart: {output_file}")

# Display accuracy rankings
print("\nModel accuracy rankings:")
for score_type in accuracy_df['score_type'].unique():
    score_df = accuracy_df[accuracy_df['score_type'] == score_type].sort_values('accuracy', ascending=False)
    print(f"\n{score_type.title()} Accuracy Rankings:")
    for i, (_, row) in enumerate(score_df.iterrows(), 1):
        print(f"  {i}. {row['model']}: {row['accuracy']:.3f}")

print(f"\nBar charts saved to {visualizations_dir}")

Generating mean accuracy bar charts...


Saved bar chart: /home/jeremias/projects/delphi-explanations/results/visualizations/accuracy_bar_fuzz.pdf




Support for Kaleido versions less than 1.0.0 is deprecated and will be removed after September 2025.
Please upgrade Kaleido to version 1.0.0 or greater (`pip install 'kaleido>=1.0.0'` or `pip install 'plotly[kaleido]'`).






Support for Kaleido versions less than 1.0.0 is deprecated and will be removed after September 2025.
Please upgrade Kaleido to version 1.0.0 or greater (`pip install 'kaleido>=1.0.0'` or `pip install 'plotly[kaleido]'`).




Saved bar chart: /home/jeremias/projects/delphi-explanations/results/visualizations/accuracy_bar_detection.pdf

Model accuracy rankings:

Fuzz Accuracy Rankings:
  1. Qwen3-14B: 0.637
  2. Gemma-3-12B-IT: 0.636
  3. Gemma-3-27B-IT: 0.620
  4. Qwen3-4B: 0.561
  5. Gemma-3-4B-IT: 0.501

Detection Accuracy Rankings:
  1. Gemma-3-27B-IT: 0.623
  2. Qwen3-14B: 0.619
  3. Gemma-3-12B-IT: 0.612
  4. Qwen3-4B: 0.579
  5. Gemma-3-4B-IT: 0.539

Bar charts saved to /home/jeremias/projects/delphi-explanations/results/visualizations


## 5. Create Comprehensive Performance Summary

Generate summary tables and statistics comparing model performance across all metrics.

In [5]:
# Create comprehensive performance summary
print("Creating comprehensive performance summary...")

# Create accuracy summary
accuracy_summary = accuracy_df.groupby('model').agg({
    'accuracy': 'mean',
    'f1_score': 'mean',
    'precision': 'mean',
    'recall': 'mean'
}).round(3)

print("\nModel Performance Summary (Accuracy Metrics):")
print("=" * 60)
print(accuracy_summary)

# Save summary to CSV
summary_file = visualizations_dir / "model_accuracy_summary.csv"
accuracy_summary.to_csv(summary_file)
print(f"\nSummary saved to: {summary_file}")

# Best performing models by category
print("\nBest Performing Models by Category:")
print("=" * 50)
print(f"Highest Accuracy: {accuracy_summary['accuracy'].idxmax()} ({accuracy_summary['accuracy'].max():.3f})")
print(f"Highest F1 Score: {accuracy_summary['f1_score'].idxmax()} ({accuracy_summary['f1_score'].max():.3f})")
print(f"Highest Precision: {accuracy_summary['precision'].idxmax()} ({accuracy_summary['precision'].max():.3f})")
print(f"Highest Recall: {accuracy_summary['recall'].idxmax()} ({accuracy_summary['recall'].max():.3f})")

# Create detailed per-score-type analysis
print("\nDetailed Analysis by Score Type:")
print("=" * 40)
for score_type in accuracy_df['score_type'].unique():
    score_subset = accuracy_df[accuracy_df['score_type'] == score_type]
    print(f"\n{score_type.title()} Results:")
    print(score_subset[['model', 'accuracy', 'f1_score']].sort_values('accuracy', ascending=False))

print(f"\nAll visualizations and summaries saved to: {visualizations_dir}")
print("\nGenerated files:")
for file in sorted(visualizations_dir.glob("*")):
    print(f"  - {file.name}")

Creating comprehensive performance summary...

Model Performance Summary (Accuracy Metrics):
                accuracy  f1_score  precision  recall
model                                                
Gemma-3-12B-IT     0.624     0.625      0.630   0.633
Gemma-3-27B-IT     0.622     0.668      0.595   0.762
Gemma-3-4B-IT      0.520     0.602      0.516   0.734
Qwen3-14B          0.628     0.576      0.693   0.530
Qwen3-4B           0.570     0.615      0.558   0.689

Summary saved to: /home/jeremias/projects/delphi-explanations/results/visualizations/model_accuracy_summary.csv

Best Performing Models by Category:
Highest Accuracy: Qwen3-14B (0.628)
Highest F1 Score: Gemma-3-27B-IT (0.668)
Highest Precision: Qwen3-14B (0.693)
Highest Recall: Gemma-3-27B-IT (0.762)

Detailed Analysis by Score Type:

Fuzz Results:
            model  accuracy  f1_score
8       Qwen3-14B  0.636538  0.510532
4  Gemma-3-12B-IT  0.636154  0.598301
6  Gemma-3-27B-IT  0.619955  0.667639
2        Qwen3-4B  0.5607