# Model Accuracy Analysis

This notebook analyzes model accuracy statistics across experiments by difficulty and question type (yes/no/unanswerable).


## 1. Setup and Data Loading

First, let's import the necessary libraries and locate the experiment directories.


In [3]:
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path

# Import our refactored functions
from abstainer.src.analysis.experiment_accuracy import (
    ExperimentAnalyzer,
    find_experiment_dirs,
    process_experiment_results
)

# Set plot style
# Switch matplotlib to its 'ggplot' style sheet before any figure is created
plt.style.use('ggplot')

# Define paths
# Location passed to find_experiment_dirs() and ExperimentAnalyzer() in later cells.
# NOTE(review): this is a relative path, so it presumably resolves against the
# kernel's working directory — confirm before running from another location.
RESULTS_DIR = "../../results"


In [None]:
# Find all experiment directories and print a short preview of them.
# MAX_DIRS_SHOWN is the single place that controls how many entries are listed.
MAX_DIRS_SHOWN = 5

experiment_dirs = find_experiment_dirs(RESULTS_DIR)
print(f"Found {len(experiment_dirs)} experiment directories")

# List the leading entries, numbered from 1
for i, exp_dir in enumerate(experiment_dirs[:MAX_DIRS_SHOWN]):
    print(f"  {i+1}. {exp_dir}")

# Say how many entries were left out of the preview, if any
hidden_count = len(experiment_dirs) - MAX_DIRS_SHOWN
if hidden_count > 0:
    print(f"  ... and {hidden_count} more")


## 2. Load and Process Experiment Results

Now let's load the results from all experiments and organize them into a DataFrame.


In [None]:
# Build the analyzer on the results directory and load its experiments into results_df
analyzer = ExperimentAnalyzer(RESULTS_DIR)
results_df = analyzer.load_experiments()

# Report the number of loaded rows, then the column names, then the preview heading
print(f"Loaded {len(results_df)} question results from {len(experiment_dirs)} experiments")
print("\nDataFrame columns:", results_df.columns.tolist(), "\nSample data:", sep="\n")

# Last expression of the cell, so Jupyter renders the first rows of the frame
results_df.head()


## 3. Analyze Accuracy by Difficulty and Question Type

Now let's analyze the model's accuracy by difficulty level and question type.


In [None]:
# Compute the statistics once; individual entries are looked up by key below
stats = analyzer.calculate_statistics()

print("Overall accuracy:", stats['overall_accuracy'])

# Each breakdown is printed as a heading line followed by the stored value
for heading, stat_key in (
    ("\nAccuracy by difficulty:", 'accuracy_by_difficulty'),
    ("\nAccuracy by difficulty and question type:", 'accuracy_by_difficulty_type'),
):
    print(heading)
    print(stats[stat_key])


## 4. Visualization

Let's create visualizations to better understand the accuracy patterns.


In [None]:
# Bring in both plotting helpers used by this cell
from abstainer.src.analysis.experiment_accuracy import (
    plot_accuracy_by_difficulty,
    plot_accuracy_by_question_type_and_difficulty,
)

# Draw the per-difficulty figure first, then the question-type-by-difficulty
# figure; each one is shown before the next is created
for plot_fn in (
    plot_accuracy_by_difficulty,
    plot_accuracy_by_question_type_and_difficulty,
):
    fig, ax = plot_fn(results_df)
    plt.show()


## 5. Text Summary

Let's generate a concise text summary of the model's performance.


In [None]:
# Generate performance summary
# The analyzer builds the summary; printing it sends the text to the cell output
summary = analyzer.generate_summary()
print(summary)


## 6. Additional Analysis: Performance Across Different Forms

Let's also analyze how performance varies across different prompt forms.


In [None]:
# Bring in the per-form plotting helper and show its figure
from abstainer.src.analysis.experiment_accuracy import plot_accuracy_by_form

fig, ax = plot_accuracy_by_form(results_df)
plt.show()

# Two-stage aggregation: first the mean of is_correct for each
# (experiment_name, form, permutation) combination, then the mean / std / count
# of those per-combination means for each form
form_accuracy = (
    results_df
    .groupby(['experiment_name', 'form', 'permutation'])['is_correct']
    .mean()
    .reset_index()
    .groupby('form')['is_correct']
    .agg(['mean', 'std', 'count'])
    .reset_index()
)

# Rows holding the highest and the lowest mean accuracy
mean_accuracy = form_accuracy['mean']
best_form = form_accuracy.loc[mean_accuracy.idxmax()]
worst_form = form_accuracy.loc[mean_accuracy.idxmin()]

print(f"Best performing form: {best_form['form']} with {best_form['mean']:.2%} accuracy")
print(f"Worst performing form: {worst_form['form']} with {worst_form['mean']:.2%} accuracy")
print("\nAccuracy by form:")
print(form_accuracy)


## 7. Accuracy Variation Across Label Permutations

Let's analyze how the accuracy varies across different label permutations for each form and question type.


In [None]:
# Plot accuracy by form and question type with permutation variation
from abstainer.src.analysis.experiment_accuracy import plot_accuracy_by_form_and_question_type

fig, ax = plot_accuracy_by_form_and_question_type(results_df)
plt.show()

# Prepare data for summary table.
# The base form is the first "V<digits>" token found in the form label; it is
# only derived when results_df does not already carry a base_form column.
# Labels with no such token get NaN.
if 'base_form' not in results_df.columns:
    results_df['base_form'] = results_df['form'].astype(str).str.extract(r'(V\d+)', expand=False)

# groupby() drops rows whose grouping key is missing, so rows with a NaN in any
# of the keys below never reach the table. Report that instead of letting the
# summary silently cover fewer rows than were loaded.
group_keys = ['base_form', 'permutation', 'question_type']
key_list = ", ".join(group_keys)
excluded_rows = int(results_df[group_keys].isna().any(axis=1).sum())
if excluded_rows:
    print(
        f"Note: {excluded_rows} of {len(results_df)} rows have a missing value in "
        f"({key_list}) and are excluded from the table below"
    )

# Stage 1: accuracy of each (base form, permutation, question type) combination
permutation_stats = (
    results_df
    .groupby(group_keys, as_index=False)['is_correct']
    .agg(mean='mean', count='count')
)

# Stage 2: spread of those accuracies across permutations.
# accuracy_std is the sample standard deviation, so it is NaN when a
# (base form, question type) pair has only one permutation.
form_type_stats = (
    permutation_stats
    .groupby(['base_form', 'question_type'], as_index=False)['mean']
    .agg(avg_accuracy='mean', accuracy_std='std', num_permutations='count')
)

# Display summary table
summary_cols = ['base_form', 'question_type', 'avg_accuracy', 'accuracy_std', 'num_permutations']
print("Accuracy variation across label permutations:")
print(form_type_stats[summary_cols].to_string(index=False))