In [1]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

In [18]:
# Values of k shown along the x-axis. Every score list below holds one value
# per entry of k_vals, in the same order.
k_vals = [1, 3, 5]

# ContextCite results, one series per label. Insertion order is preserved and
# determines the order in which the series are iterated downstream.
contextcite_calls = dict([
    ("ContextCite (256 calls)", [0.732, 0.951, 0.986]),
    ("ContextCite (128 calls)", [0.728, 0.942, 0.978]),
    ("ContextCite (64 calls)", [0.724, 0.928, 0.965]),
    ("ContextCite (32 calls)", [0.720, 0.901, 0.945]),
])

# Results for the other methods, in the same label -> scores layout.
other_methods = dict([
    ("Average attention", [0.705, 0.853, 0.867]),
    ("Similarity", [0.572, 0.735, 0.787]),
    ("Gradient ℓ₁-norm", [0.174, 0.407, 0.466]),
    ("Leave-one-out", [0.749, 0.845, 0.862]),
])

In [19]:
# Grouped bar chart: one group of bars per entry of k_vals, one bar per series.
bar_width = 0.09
x = np.arange(len(k_vals))

# Colour per series label. The ContextCite variants share one base hue and
# differ only in the trailing alpha byte of their 8-digit (#RRGGBBAA) hex code.
colors = {
    "ContextCite (256 calls)": "#1f77b4",
    "ContextCite (128 calls)": "#1f77b4cc",
    "ContextCite (64 calls)": "#1f77b499",
    "ContextCite (32 calls)": "#1f77b466",
    "Average attention": "#ff7f0e",
    "Similarity": "#2ca02c",
    "Gradient ℓ₁-norm": "#d62728",
    "Leave-one-out": "#9467bd",
}

# ContextCite series first, then the other methods. Dict order fixes both the
# left-to-right bar order inside each group and the legend order.
all_methods = {**contextcite_calls, **other_methods}
n_methods = len(all_methods)

# Plotting
fig, ax = plt.subplots(figsize=(10, 5))

for idx, (label, scores) in enumerate(all_methods.items()):
    # Fail with a readable message instead of a broadcast error inside bar().
    assert len(scores) == len(k_vals), (
        f"{label!r}: expected {len(k_vals)} scores, got {len(scores)}"
    )
    # Centre the cluster of bars on its tick. Derived from the series count
    # so it stays centred if a series is added or removed (for the 8 series
    # here this equals the former hard-coded (idx - 3.5) * bar_width).
    offset = (idx - (n_methods - 1) / 2) * bar_width
    ax.bar(x + offset, scores, width=bar_width, label=label, color=colors[label])

# Axes and labels (tick labels are generated from k_vals so they cannot drift
# out of sync with the data).
ax.set_xticks(x)
ax.set_xticklabels([rf'$k={k}$' for k in k_vals])
ax.set_ylabel('Log-prob drop\n(more accurate →)', fontsize=12)
ax.set_title('LLaMA-3 8B on TyDiQA (Long-context QA)', fontsize=14)
ax.set_ylim(0, 1)
ax.grid(axis='y', linestyle='--', alpha=0.6)

# Move legend to bottom (anchored below the axes, four columns wide)
ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.2), ncol=4)

# Save the figure. The output directory is created first because savefig does
# not create missing parent directories and would raise FileNotFoundError.
output_path = "results/llama3_tydiqa.png"
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
fig.tight_layout()
# bbox_inches='tight' keeps the legend, which sits outside the axes, in frame.
fig.savefig(output_path, dpi=300, bbox_inches='tight')
# Close this specific figure so it is released and not rendered inline.
plt.close(fig)