In [1]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

In [2]:
# Measured results. Every list holds one log-prob-drop value per entry of
# `k_vals`, in the same order.
k_vals = [1, 3, 5]

# Raw values keyed by call budget (the N in the "(N calls)" series name),
# largest budget first so the resulting series keep that order.
_contextcite_l_by_budget = {
    256: [1.760979, 3.263331, 4.562216],
    128: [1.484970, 3.161550, 3.791863],
    64: [1.285550, 2.569343, 3.228789],
    32: [1.262293, 2.312171, 3.016126],
}
_contextcite_by_budget = {
    256: [0.510021, 0.754932, 0.831023],
    128: [0.500903, 0.741556, 0.819800],
    64: [0.496182, 0.725651, 0.788392],
    32: [0.490307, 0.676900, 0.719050],
}

# Series name -> values, as consumed by the plotting cell.
contextcite_l_calls = {
    f"ContextCite-L ({budget} calls)": drops
    for budget, drops in _contextcite_l_by_budget.items()
}
contextcite_calls = {
    f"ContextCite ({budget} calls)": drops
    for budget, drops in _contextcite_by_budget.items()
}

# Baseline attribution methods, currently excluded from the plot.
# other_methods = {
#     "Average attention": [0.705, 0.853, 0.867],
#     "Similarity": [0.572, 0.735, 0.787],
#     "Gradient ℓ₁-norm": [0.174, 0.407, 0.466],
#     "Leave-one-out": [0.749, 0.845, 0.862]
# }

In [3]:
# Series that are plotted, in bar/legend order (ContextCite-L first).
# To include the baselines, uncomment `other_methods` in the data cell and use:
# all_methods = {**contextcite_l_calls, **contextcite_calls, **other_methods}
all_methods = dict(contextcite_l_calls)
all_methods.update(contextcite_calls)

# Bar colour per series. Each ContextCite family has a single hue; smaller
# call budgets reuse that hue with an alpha suffix (#RRGGBBAA) so they fade.
_family_hues = {"ContextCite-L": "#1f77b4", "ContextCite": "#2ca02c"}
_budget_alpha_suffix = {256: "", 128: "cc", 64: "99", 32: "66"}
colors = {
    f"{family} ({budget} calls)": hue + alpha
    for family, hue in _family_hues.items()
    for budget, alpha in _budget_alpha_suffix.items()
}
# Baseline methods; only needed when `other_methods` is plotted.
colors.update({
    "Average attention": "#ff7f0e",
    "Similarity": "#8c564b",
    "Gradient ℓ₁-norm": "#d62728",
    "Leave-one-out": "#9467bd",
})

In [4]:
# Bar chart of log-prob drop per attribution method, one panel per top-k value.
# x tick labels are hidden; bars are identified by colour through a single
# shared legend below the panels. The figure is written to `output_path`.
output_path = "results/plots/llama3_tydiqa_cc.png"

labels = list(all_methods.keys())

# Sanity checks: each panel i reads all_methods[label][i], and every bar
# needs a colour, so fail early with a readable message instead of an
# IndexError/KeyError deep inside the plotting loop.
assert k_vals, "k_vals is empty: nothing to plot"
assert labels, "all_methods is empty: nothing to plot"
assert all(len(all_methods[label]) == len(k_vals) for label in labels), (
    "each method needs exactly one value per entry in k_vals"
)
missing_colors = [label for label in labels if label not in colors]
assert not missing_colors, f"no colour defined for: {missing_colors}"

x = np.arange(len(labels))
bar_colors = [colors[label] for label in labels]

# Shared y-range: keep the default 0-5 axis, but grow it (5% headroom) if any
# value would otherwise be clipped.
y_top = max(5.0, 1.05 * max(max(all_methods[label]) for label in labels))

# One row per k (4 inches each); squeeze=False keeps `axes` indexable even
# when there is only a single k value.
fig, axes = plt.subplots(
    nrows=len(k_vals), figsize=(12, 4 * len(k_vals)), sharex=True, squeeze=False
)
axes = axes[:, 0]

handles = None
for i, (ax, k) in enumerate(zip(axes, k_vals)):
    values_at_k = [all_methods[label][i] for label in labels]
    bars = ax.bar(x, values_at_k, color=bar_colors)

    ax.set_title(f"Top-$k$ = {k}", fontsize=12)
    ax.set_ylabel("Log-prob drop", fontsize=10)
    ax.set_ylim(0, y_top)
    ax.grid(axis='y', linestyle='--', alpha=0.6)
    ax.set_xticks(x)
    ax.set_xticklabels([])  # methods are identified by the legend instead

    # Colours are identical in every panel, so the first panel's bars serve
    # as legend handles for the whole figure.
    if i == 0:
        handles = bars

axes[-1].set_xlabel("Attribution Method", fontsize=11)

# The figure intentionally has no suptitle; uncomment to add one.
# fig.suptitle("LLaMA-3 8B on TyDiQA: Attribution Comparison by Top-$k$", fontsize=14)

# Shared legend below the panels; tight_layout keeps the bottom 5% free.
fig.legend(handles, labels, loc='lower center', ncol=2, fontsize=9, bbox_to_anchor=(0.5, -0.05))
fig.tight_layout(rect=[0, 0.05, 1, 1])

# Create the output directory so a fresh checkout doesn't fail in savefig.
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path, dpi=300, bbox_inches='tight')
plt.close(fig)