# Accuracy vs Steerability Analysis (Figure 9)

In [None]:
import os
import json
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

In [None]:
def get_acc_from_layer(data, layername):
    """Return the share of "correct" verdicts recorded under one layer.

    Args:
        data: Mapping from layer name (e.g. 'none', 'intervention_10') to a
            mapping whose values are result records. Every record must carry a
            'verdict' equal to "correct", "wrong" or "nonsense".
        layername: Key of ``data`` to score.

    Returns:
        correct / (correct + wrong + nonsense), or 0 when the layer is missing
        or holds no records.

    Raises:
        KeyError: if a record has no 'verdict' or an unrecognized verdict value.
    """
    # Fixed set of allowed verdicts; anything else fails loudly on lookup.
    tally = dict.fromkeys(("correct", "wrong", "nonsense"), 0)
    for record in data.get(layername, {}).values():
        tally[record['verdict']] += 1

    n_scored = sum(tally.values())
    if not n_scored:
        return 0
    return tally['correct'] / n_scored

In [None]:
# Folder holding one JSON file of intervention results per model/run
metadata_dir = "metadata"

# Baseline accuracy (the "none" layer, i.e. no intervention) keyed by file stem
acc_oneword = {}

# Use *_5 files only, leaving out reasoning and chain-of-thought ("cot") variants
for fname in os.listdir(metadata_dir):
    wanted = "_5" in fname and not ("reasoning" in fname or "cot" in fname)
    if not wanted:
        continue

    with open(os.path.join(metadata_dir, fname), 'r') as handle:
        results = json.load(handle)
    acc_oneword[fname.replace(".json", "")] = get_acc_from_layer(results, "none")

print(f"Loaded accuracy data for {len(acc_oneword)} models")
print(f"Models: {sorted(acc_oneword.keys())}")

In [None]:
# Dictionary to store steerability values for each model
steer_oneword = {model_name: [] for model_name in acc_oneword}

# Steerability of one intervention = baseline accuracy - accuracy under that
# intervention (the accuracy drop it causes). One value per intervention layer.
for filename in os.listdir(metadata_dir):
    if "_5" in filename and "cot" not in filename:
        model_name = filename.replace(".json", "")
        # Only models that have a baseline can be scored. This filter is looser
        # than the baseline pass (it does not exclude "reasoning" files), so
        # without this guard such files would raise KeyError on
        # acc_oneword / steer_oneword below.
        if model_name not in acc_oneword:
            continue
        json_path = os.path.join(metadata_dir, filename)

        with open(json_path, 'r') as f:
            data = json.load(f)

        baseline_acc = acc_oneword[model_name]
        # Process each intervention layer
        for key_ in data:
            # Skip the baseline itself and noise interventions
            if key_ != "none" and 'noise' not in key_:
                acc_drop = baseline_acc - get_acc_from_layer(data, key_)
                steer_oneword[model_name].append(acc_drop)

print(f"Calculated steerability for {len(steer_oneword)} models")

In [None]:
# Clean up model names: remove the trailing "_5" run suffix.
# Only a suffix is stripped; str.replace would also delete any "_5" that occurs
# inside the model name itself (e.g. a version tag such as "qwen2_5").
_RUN_SUFFIX = "_5"
acc_oneword_clean = {}
steer_oneword_clean = {}

for key_ in list(acc_oneword.keys()):
    if key_.endswith(_RUN_SUFFIX):
        clean_name = key_[: -len(_RUN_SUFFIX)]
    else:
        clean_name = key_
    acc_oneword_clean[clean_name] = acc_oneword[key_]
    steer_oneword_clean[clean_name] = steer_oneword[key_]

acc_oneword = acc_oneword_clean
steer_oneword = steer_oneword_clean

print(f"\nFinal models: {sorted(acc_oneword.keys())}")

In [None]:
# Prepare data for plotting: one entry per model, in sorted-name order
model_names = sorted(acc_oneword.keys())
x_values = []  # Mean steerability (mean accuracy drop over interventions)
y_values = []  # Baseline accuracy
x_errors = []  # Steerability std dev

# `plt.cm.get_cmap` (matplotlib.cm.get_cmap) was deprecated in Matplotlib 3.7
# and removed in 3.9; `plt.get_cmap(name, lut)` is the supported replacement.
# Keep lut >= 1 so an empty model list does not request a zero-entry colormap.
colors = plt.get_cmap('rainbow', max(len(model_names), 1))

for model_name in model_names:
    y_values.append(acc_oneword[model_name])

    steer_vals = steer_oneword.get(model_name, [])
    if steer_vals:
        x_values.append(np.mean(steer_vals))
        x_errors.append(np.std(steer_vals))
    else:
        # No intervention results for this model: mark as missing, not 0
        x_values.append(np.nan)
        x_errors.append(np.nan)

In [None]:
plt.figure(figsize=(5, 8))

# One pass: draw each model's point (mean ± std steerability vs. accuracy) and
# file it under its family, keyed by the first 4 characters of the model name.
# Family entries are (x, y, xerr, name, color).
model_groups = defaultdict(list)
for idx, name in enumerate(model_names):
    point_color = colors(idx)
    x_val, y_val, x_err = x_values[idx], y_values[idx], x_errors[idx]
    model_groups[name[:4]].append((x_val, y_val, x_err, name, point_color))
    plt.errorbar(
        x_val,
        y_val,
        xerr=x_err,
        fmt='o',
        markersize=8,
        capsize=3,
        label=name,
        color=point_color,
    )

# Join members of each family with a dashed line, ordered left-to-right by x;
# the line takes the color of the left-most member.
for members in model_groups.values():
    members.sort(key=lambda entry: entry[0])
    plt.plot(
        [entry[0] for entry in members],
        [entry[1] for entry in members],
        linestyle='--',
        color=members[0][4],
        alpha=0.6,
    )

# NOTE(review): this label describes (ID belief Δ - noise belief Δ), but the
# plotted x values are baseline accuracy minus intervened accuracy, with noise
# layers skipped rather than subtracted. Confirm the label matches the metric.
plt.xlabel("Average Steerability (ID belief Δ - noise belief Δ)")
plt.ylabel("Accuracy on COCO-spatial")
plt.title("Accuracy vs. Average Steerability")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

## Summary Statistics

In [None]:
# Summary table: one row per model with mean ± std steerability and accuracy
header = f"{'Model':30s} {'Steerability (mean ± std)':>30s} {'Accuracy':>10s}"
print(header)
print("-" * 75)

rows = zip(model_names, x_values, x_errors, y_values)
for name, mean_steer, std_steer, acc in rows:
    steer_cell = f"{mean_steer:.4f} ± {std_steer:.4f}"
    print(f"{name:30s} {steer_cell:>30s} {acc:10.4f}")