# Setup

In [None]:
%matplotlib
import os
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import pandas as pd
import numpy as np

from dataclasses import dataclass
from typing import Dict, List, Tuple

# Theming: seaborn's white-grid theme plus one consolidated set of rcParams.
# Text and math use a Computer Modern-style serif; sizes and line widths are
# tuned for small paper figures (ICML template uses 10pt).
sns.set_theme(style="whitegrid")
plt.rcParams.update({
    # --- Typeface ---
    'font.family': 'serif',
    'font.serif': ['Computer Modern', 'DejaVu Serif', 'serif'],
    'mathtext.fontset': 'cm',
    'axes.formatter.use_mathtext': True,

    # --- Font sizes ---
    "font.size": 8,
    "axes.titlesize": 8,
    "axes.labelsize": 7,
    "legend.fontsize": 7,
    "xtick.labelsize": 7,
    "ytick.labelsize": 7,

    # --- Borders, grid, tick padding, hatches ---
    "axes.linewidth": 0.5,   # plot border
    "patch.linewidth": 0.5,  # bar border
    "grid.linewidth": 0.5,
    "xtick.major.pad": 0,
    "ytick.major.pad": 0,
    "xtick.minor.pad": 0,
    "ytick.minor.pad": 0,
    "hatch.linewidth": 0.5,

    # --- Lines and markers ---
    "lines.linewidth": 1,
    "lines.markersize": 5,
    "lines.markeredgewidth": 0.5,
    "lines.markeredgecolor": "white",

    # --- Output resolution ---
    "figure.dpi": 300,
})

In [None]:
# Figure widths (inches) for the two-column paper layout.
DOUBLE_COLUMN_WIDTH = 6.75133
SINGLE_COLUMN_WIDTH = 3.25063

# Opacity used for every grid in the notebook.
GRID_ALPHA = 0.4

# Canonical legend / plotting order of acquisition functions and datasets.
ACQUISITION_ORDER = [
    'Random', 'UltraFeedback', 'MaxMin', 'DeltaQwen',
    'DeltaUCB', 'DRTS',
    'InfoMax', 'DTS', 'MaxMinLCB',
]
DATASET_ORDER = ["UltraFeedback", "Skywork", "Combined", "Tulu 3"]

# Benchmarks split into downstream (fine-tuned model) and reward-model groups;
# BENCHMARKS is their concatenation.
DOWNSTREAM_BENCHMARKS = ['gsm8k', 'ifeval', 'truthfulqa', 'alpacaeval_2']
RM_BENCHMARKS = ["rewardbench_2"]
BENCHMARKS = DOWNSTREAM_BENCHMARKS + RM_BENCHMARKS

@dataclass
class AcquisitionStyle:
    marker: str
    hatch: str
    color: str
    dashes: Tuple[int, ...] | None

HATCH_MULTIPLIER = 2
ACQUISITION_STYLES = {
    'Random': AcquisitionStyle(marker='o', hatch='' * HATCH_MULTIPLIER, color='#a63f3f', dashes=None),
    'UltraFeedback': AcquisitionStyle(marker='s', hatch='/' * HATCH_MULTIPLIER, color='#cb4d4d', dashes=(5, 2)),
    'MaxMin': AcquisitionStyle(marker='^', hatch='\\' * HATCH_MULTIPLIER, color='#e06c6c', dashes=(2, 2)),
    'DeltaQwen': AcquisitionStyle(marker='D', hatch='x' * HATCH_MULTIPLIER, color='#ef8f8f', dashes=(5, 2, 2, 2)),
    'DeltaUCB': AcquisitionStyle(marker='o', hatch='' * HATCH_MULTIPLIER, color='#3f3fa6', dashes=None),
    'DRTS': AcquisitionStyle(marker='s', hatch='/' * HATCH_MULTIPLIER, color='#4d4dcb', dashes=(5, 2)),
    'InfoMax': AcquisitionStyle(marker='o', hatch='' * HATCH_MULTIPLIER, color='#3fa63f', dashes=None),
    'DTS': AcquisitionStyle(marker='s', hatch='/' * HATCH_MULTIPLIER, color='#4dcb4d', dashes=(5, 2)),
    'MaxMinLCB': AcquisitionStyle(marker='^', hatch='\\' * HATCH_MULTIPLIER, color='#6ce06c', dashes=(2, 2)),
    'Original': AcquisitionStyle(marker='o', hatch='' * HATCH_MULTIPLIER, color='#808080', dashes=(None))
}

# Load Data

In [None]:
# Load the merged results table from full_results.csv if it exists; otherwise build
# it from results.csv plus the two UltraFeedback sample-efficiency CSVs, and write
# full_results.csv as a cache.
# NOTE(review): the cache is never invalidated — delete full_results.csv whenever
# any input CSV changes, or the stale merged table keeps being used.
if os.path.exists('full_results.csv'):
    print("Loaded full results")
    data = pd.read_csv('full_results.csv', sep=',')
else:
    # Maps the acquisition-function key parsed out of a `Method` string (see below)
    # to the display name used throughout the plots.
    acquisition_function_mapping = {
        "random": "Random",
        "ultrafeedback": "UltraFeedback",
        "maxmin": "MaxMin",
        "delta_qwen": "DeltaQwen",
        "DeltaUCB": "DeltaUCB",
        "DRTS": "DRTS",
        "InfoMax": "InfoMax",
        "DTS": "DTS",
        "MaxMinLCB": "MaxMinLCB",
    }

    # NOTE(review): not referenced anywhere in this notebook; the zero-sample anchor
    # rows below are filled with 0 rather than with these values.
    base_model_scores = {
        "gsm8k": 0.758,
        "ifeval": 0.713,
        "truthfulqa": 0.468,
        "alpacaeval_2": 0.083,
        "rewardbench_2": 0.290
    }

    data = pd.read_csv('results.csv', sep=',')

    # Sample-efficiency sweep on UltraFeedback: fine-tuned (DPO) results and
    # reward-model results, joined on `Method`. Columns present in both files
    # receive `_dpo` / `_rm` suffixes.
    uf_dpo_sample_efficiency = pd.read_csv("ultrafeedback_dpo_sample_efficiency.csv")
    uf_rm_sample_efficiency = pd.read_csv("ultrafeedback_rm_sample_efficiency.csv")

    uf_sample_efficiency = pd.merge(
        uf_dpo_sample_efficiency,
        uf_rm_sample_efficiency,
        on='Method',
        suffixes=('_dpo', '_rm')
    )

    # Drop the base-model row: its `Method` does not end in '_<num samples>', so the
    # parsing below would fail on it.
    uf_sample_efficiency = uf_sample_efficiency[uf_sample_efficiency["Method"] != "SFT Base Model"].copy().reset_index(drop=True)

    # Align column names with results.csv; the RM file's overall mean becomes rewardbench_2.
    uf_sample_efficiency = uf_sample_efficiency.rename(columns={
        'Mean_rm': 'rewardbench_2',
        'GSM8K': 'gsm8k',
        'IF Eval': 'ifeval',
        'Truthful QA': 'truthfulqa',
        'Alpaca Eval': 'alpacaeval_2',
    })

    # `Method` is parsed as '<...>-<acquisition key>_<num samples>': the token after the
    # last '_' is the sample count, and the text between the last '-' and that '_' is
    # looked up in acquisition_function_mapping (KeyError if unknown).
    uf_sample_efficiency['num_train_samples'] = uf_sample_efficiency['Method'].apply(lambda x: int(x.split('_')[-1]))
    uf_sample_efficiency['acquisition_function'] = uf_sample_efficiency['Method'].apply(lambda x: acquisition_function_mapping["_".join(x.split('_')[:-1]).split('-')[-1]])
    uf_sample_efficiency['po_algorithm'] = "DPO"
    uf_sample_efficiency['judge'] = "Qwen 3 235B"
    uf_sample_efficiency['dataset'] = "UltraFeedback"

    # Add an anchor row at num_train_samples = 0 for every acquisition function so
    # each sample-efficiency curve starts from the same point. All scores are 0,
    # presumably because the benchmark columns hold deltas to the base model (the
    # plots label them "Score Δ") — TODO confirm. Columns not listed here (e.g.
    # `Method`) are left missing in these rows and are dropped just below.
    # Rows are appended one at a time via `.loc[len(...)]`, which relies on the
    # default RangeIndex produced by reset_index above.
    for acq_name in acquisition_function_mapping.values():
        uf_sample_efficiency.loc[len(uf_sample_efficiency)] = {
            'dataset': 'UltraFeedback',
            'judge': 'Qwen 3 235B',
            'acquisition_function': acq_name,
            'po_algorithm': 'DPO',
            'num_train_samples': 0,
            'gsm8k': 0,
            'ifeval': 0,
            'truthfulqa': 0,
            'alpacaeval_2': 0,
            'rewardbench_2': 0
        }

    # Drop per-subset RewardBench columns and bookkeeping columns, then select and
    # reorder columns to match results.csv so both frames concatenate cleanly
    # (raises KeyError if results.csv has a column the sweep lacks).
    uf_sample_efficiency = uf_sample_efficiency.drop(columns=['Type_dpo', 'Mean_dpo', 'Type_rm', 'Factuality', 'Focus', 'Math', 'Precise IF', 'Safety', 'Ties', 'Method'])
    uf_sample_efficiency = uf_sample_efficiency[data.columns]
    # Sort by acquisition function (mapping order), then sample count, and save the sweep.
    acq_order = list(acquisition_function_mapping.values())
    uf_sample_efficiency['acq_func_order'] = uf_sample_efficiency['acquisition_function'].apply(lambda x: acq_order.index(x) if x in acq_order else -1)
    uf_sample_efficiency = uf_sample_efficiency.sort_values(by=['acq_func_order', 'num_train_samples']).drop(columns=['acq_func_order']).reset_index(drop=True)
    uf_sample_efficiency.to_csv("ultrafeedback_sample_efficiency.csv", index=False)

    # Append the sweep to the main results and drop exact duplicate rows.
    data = pd.concat([data, uf_sample_efficiency], ignore_index=True)
    data = data.drop_duplicates().reset_index(drop=True)

    # Canonical row order: rows without num_train_samples first (ascending=False on the
    # null flag puts True first), then DATASET_ORDER, po_algorithm (alphabetical),
    # ACQUISITION_ORDER, and sample count. Unknown datasets / acquisition functions sort
    # last within their group. The helper columns are dropped afterwards.
    data = data.assign(
        num_train_samples_null=data['num_train_samples'].isna(),
        dataset_order_idx=data['dataset'].apply(lambda x: DATASET_ORDER.index(x) if x in DATASET_ORDER else len(DATASET_ORDER)),
        acquisition_order_idx=data['acquisition_function'].apply(
            lambda x: ACQUISITION_ORDER.index(x) if x in ACQUISITION_ORDER else len(ACQUISITION_ORDER))
    ).sort_values(
        by=['num_train_samples_null', 'dataset_order_idx', 'po_algorithm', 'acquisition_order_idx', 'num_train_samples'],
        ascending=[False, True, True, True, True]
    ).drop(columns=['num_train_samples_null', 'dataset_order_idx', 'acquisition_order_idx']).reset_index(drop=True)

    data.to_csv("full_results.csv", index=False)

In [None]:
# Aggregate scores: mean over the reward-model benchmarks and over the downstream
# benchmarks. Recomputed from the raw columns, so re-running this cell is harmless.
data["rm_mean_score"] = data[RM_BENCHMARKS].mean(axis=1)
data["downstream_mean_score"] = data[DOWNSTREAM_BENCHMARKS].mean(axis=1)

# Row masks shared by the per-figure subsets. Rows with a num_train_samples value come
# from the sample-efficiency sweep; rows without one are the remaining runs.
is_sweep_run = data['num_train_samples'].notna()
is_dpo = data['po_algorithm'] == 'DPO'
is_ultrafeedback = data['dataset'] == 'UltraFeedback'

# Each subset is an independent copy, so later cells can add or modify columns without
# aliasing `data` (and without pandas' SettingWithCopyWarning).

# PO-algorithm ablation: non-sweep UltraFeedback runs for all PO algorithms.
# rewardbench_2 is dropped here; the already computed rm_mean_score column remains.
po_algo_ablation_raw_data = data.loc[is_ultrafeedback & ~is_sweep_run].drop(columns=['rewardbench_2'])

# Dataset ablation and teaser: non-sweep DPO runs across all datasets.
dataset_ablation_raw_data = data.loc[is_dpo & ~is_sweep_run].copy()
teaser_raw_data = data.loc[is_dpo & ~is_sweep_run].copy()

# Sample-efficiency sweep on UltraFeedback.
sample_efficiency_ultrafeedback_raw_data = data.loc[is_ultrafeedback & is_sweep_run].copy()

# Plots

In [None]:
# ==============================================================================
# Dataset Ablation Plot
# ==============================================================================

dataset_ablation_data = dataset_ablation_raw_data.copy()

# --- Style Setup ---
acquisition_colors = {k: v.color for k, v in ACQUISITION_STYLES.items()}
acquisition_hatches = {k: v.hatch for k, v in ACQUISITION_STYLES.items()}
# Hue levels in this figure: every acquisition function plus the 'Original' reference.
dataset_ablation_hue_order = ACQUISITION_ORDER + ["Original"]

# --- Figure Setup ---
fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(DOUBLE_COLUMN_WIDTH, 1.8))

# --- Plot Data ---
# Left: downstream scores; right: reward model scores.
# Both panels use identical bar styling, including the white edge colour. In
# Matplotlib an explicit edgecolor also sets the hatch colour, so without it here
# the panels would differ from each other and from the shared legend, which is
# built from the left panel.
for ax, y_column in ((ax_left, 'downstream_mean_score'), (ax_right, 'rm_mean_score')):
    sns.barplot(
        data=dataset_ablation_data,
        x='dataset',
        y=y_column,
        hue='acquisition_function',
        palette=acquisition_colors,
        edgecolor="white",
        order=DATASET_ORDER,
        hue_order=dataset_ablation_hue_order,
        ax=ax,
    )

# --- Apply Hatches ---
# The index arithmetic assumes one bar per (hue level, x position), emitted hue-major.
# The number of x positions is fixed by `order=DATASET_ORDER`, not by how many
# datasets happen to be present in the data.
# NOTE(review): some seaborn versions skip bars for empty (hue, x) combinations,
# which would shift this mapping — verify against the installed version.
n_hues = len(ACQUISITION_ORDER)
n_groups = len(DATASET_ORDER)
for ax in [ax_left, ax_right]:
    for i, bar in enumerate(ax.patches):
        hue_idx = i // n_groups
        if hue_idx < n_hues:  # later hue levels ('Original') keep their default hatch
            bar.set_hatch(acquisition_hatches[ACQUISITION_ORDER[hue_idx]])

# --- Legend ---
# Replace the per-axis legends with one figure legend whose hatches are drawn
# denser so they remain visible on the small legend patches.
ax_left.get_legend().remove()
ax_right.get_legend().remove()
handles, labels = ax_left.get_legend_handles_labels()
for i, handle in enumerate(handles):
    if i < len(ACQUISITION_ORDER):
        acq_func = ACQUISITION_ORDER[i]
        handle.set_hatch(acquisition_hatches[acq_func] * 2)
fig.legend(
    handles,
    labels,
    loc='upper center',
    bbox_to_anchor=(0.53, 1.15),
    ncol=(len(acquisition_colors) - 1) // 2 + 1,
    frameon=False,
)

# --- Axis Labels & Titles ---
ax_left.set_xlabel('(a) Fine-tuned Models')
ax_right.set_xlabel('(b) Reward Models')
ax_left.set_ylabel('Downstream Score $\\Delta$', fontweight="bold")
ax_right.set_ylabel('Reward Model Score $\\Delta$', fontweight="bold")

# --- Axis Ticks ---
for ax in [ax_left, ax_right]:
    for label in ax.get_xticklabels():
        label.set_fontweight('bold')

ax_left.set_yticks([0.00, 0.05, 0.10, 0.15])

# --- Grid ---
# Horizontal grid lines only.
for ax in [ax_left, ax_right]:
    ax.grid(alpha=GRID_ALPHA)
    ax.grid(axis='x', alpha=0.0)

# --- Axis Limits ---
ax_left.set_ylim(0, 0.16)
ax_right.set_ylim(0, 0.4)

# --- Save & Show ---
plt.tight_layout()
fig.savefig("dataset_ablation.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
# ==============================================================================
# PO Algorithm Ablation Plot (with broken y-axis)
# ==============================================================================

po_algo_ablation_data = po_algo_ablation_raw_data.copy()

# --- Style Setup ---
acquisition_colors = {k: v.color for k, v in ACQUISITION_STYLES.items()}
acquisition_hatches = {k: v.hatch for k, v in ACQUISITION_STYLES.items()}
# x-axis order; also fixes how many bars each hue level contributes (see hatching).
po_algorithm_order = ['DPO', 'IPO', 'SimPO']

# --- Figure Setup ---
# Broken y-axis: a tall top panel for the main range and a short bottom panel
# showing the -0.30 to -0.20 range (limits set below).
fig, (ax_top, ax_bottom) = plt.subplots(2, 1, sharex=True, figsize=(SINGLE_COLUMN_WIDTH, 1.5), gridspec_kw={
    'height_ratios': [8, 1],
    'hspace': 0.1
})

# --- Plot Data ---
# The same bars are drawn in both panels; each panel's y-limits select what is visible.
for ax in [ax_top, ax_bottom]:
    sns.barplot(
        data=po_algo_ablation_data,
        x='po_algorithm',
        y='downstream_mean_score',
        hue='acquisition_function',
        palette=acquisition_colors,
        order=po_algorithm_order,
        hue_order=ACQUISITION_ORDER,
        ax=ax
    )
    ax.get_legend().remove()

# --- Apply Hatches ---
# The index arithmetic assumes one bar per (hue level, x position), emitted hue-major.
# The number of x positions comes from `order=`, not from the PO algorithms that
# happen to be present in the data.
# NOTE(review): some seaborn versions skip bars for empty (hue, x) combinations,
# which would shift this mapping — verify against the installed version.
n_hues = len(ACQUISITION_ORDER)
n_groups = len(po_algorithm_order)
for ax in [ax_top, ax_bottom]:
    for i, bar in enumerate(ax.patches):
        hue_idx = i // n_groups
        if hue_idx < n_hues:
            bar.set_hatch(acquisition_hatches[ACQUISITION_ORDER[hue_idx]])

# --- Legend ---
# One figure legend, with denser hatches so they stay visible on small legend patches.
handles, labels = ax_top.get_legend_handles_labels()
for i, handle in enumerate(handles):
    if i < len(ACQUISITION_ORDER):
        acq_func = ACQUISITION_ORDER[i]
        handle.set_hatch(acquisition_hatches[acq_func] * 2)
fig.legend(
    handles,
    labels,
    loc='upper center',
    bbox_to_anchor=(0.51, 1.25),
    ncol=(len(acquisition_colors) - 1) // 3 + 1,
    fontsize=plt.rcParams["ytick.labelsize"],
    frameon=False,
)

# --- Axis Labels & Titles ---
ax_bottom.set_xlabel('')
ax_top.set_ylabel('Downstream Score $\\Delta$', y=0.4, fontweight="bold")
ax_bottom.set_ylabel('')

# --- Axis Ticks ---
ax_top.set_yticks(np.arange(0, 0.25, 0.05))
ax_bottom.set_yticks([-0.25])
for label in ax_bottom.get_xticklabels():
    label.set_fontweight('bold')

# --- Grid ---
for ax in [ax_top, ax_bottom]:
    ax.grid(axis='y', alpha=GRID_ALPHA)
    ax.grid(axis='x', alpha=0.0)

# --- Axis Limits ---
ax_top.set_ylim(-0.01, 0.22)
ax_bottom.set_ylim(-0.30, -0.20)

# --- Broken Axis Styling ---
# Hide the facing spines and draw small diagonal "break" marks at both ends of the gap.
ax_top.spines['bottom'].set_visible(False)
ax_bottom.spines['top'].set_visible(False)
ax_top.tick_params(bottom=False)
break_kwargs = {
    'marker': [(-1, -0.5), (1, 0.5)],
    'markersize': 6,
    'linestyle': 'none',
    'color': '0.8',
    'mew': plt.rcParams['axes.linewidth'],
    'clip_on': False,
}
ax_top.plot([0, 1], [0, 0], transform=ax_top.transAxes, **break_kwargs)
ax_bottom.plot([0, 1], [1, 1], transform=ax_bottom.transAxes, **break_kwargs)

# --- Save & Show ---
plt.tight_layout()
fig.savefig("po_algo_ablation.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
# ==============================================================================
# Sample Efficiency Plot (UltraFeedback)
# ==============================================================================

# Private copy so the raw subset stays untouched for later cells.
sample_efficiency_ultrafeedback_data = sample_efficiency_ultrafeedback_raw_data.copy()

# --- Style Setup ---
# Seaborn lookups keyed by acquisition function; a dash pattern of "" draws a solid line.
acquisition_colors = {name: style.color for name, style in ACQUISITION_STYLES.items()}
acquisition_markers = {name: style.marker for name, style in ACQUISITION_STYLES.items()}
acquisition_dashes = {
    name: ("" if style.dashes is None else style.dashes)
    for name, style in ACQUISITION_STYLES.items()
}

# --- Figure Setup ---
fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(DOUBLE_COLUMN_WIDTH, 2.5), sharey=False)

# --- Plot Data ---
# Panel (a): downstream scores; panel (b): reward model scores. Identical styling.
for ax, y_column in ((ax_left, 'downstream_mean_score'), (ax_right, 'rm_mean_score')):
    sns.lineplot(
        data=sample_efficiency_ultrafeedback_data,
        x='num_train_samples',
        y=y_column,
        hue='acquisition_function',
        style='acquisition_function',
        hue_order=ACQUISITION_ORDER,
        style_order=ACQUISITION_ORDER,
        palette=acquisition_colors,
        markers=acquisition_markers,
        dashes=acquisition_dashes,
        ax=ax,
    )
    # Force every line on this panel to the rcParams marker edge width.
    for line in ax.get_lines():
        line.set_markeredgewidth(plt.rcParams['lines.markeredgewidth'])

# --- Legend ---
# Swap the per-panel legends for a single figure legend above both panels.
for ax in (ax_left, ax_right):
    ax.get_legend().remove()
handles, labels = ax_left.get_legend_handles_labels()
fig.legend(
    handles,
    labels,
    loc='upper center',
    bbox_to_anchor=(0.5, 1.1),
    ncol=(len(acquisition_colors) - 1) // 2 + 1,
    frameon=False,
)

# --- Axis Labels & Titles ---
# Each panel gets a bold axis label pair plus a sub-figure caption placed below
# the x-label (in axes coordinates).
for ax, caption, ylabel in (
    (ax_left, '(a) Fine-tuned Models', 'Downstream Score $\\Delta$'),
    (ax_right, '(b) Reward Models', 'Reward Model Score $\\Delta$'),
):
    ax.set_xlabel('Consumed Samples', fontweight="bold")
    ax.text(0.5, -0.35, caption, transform=ax.transAxes, ha='center', fontsize=plt.rcParams['axes.labelsize'])
    ax.set_ylabel(ylabel, fontweight="bold")

ax_left.set_yticks([0.00, 0.05, 0.10, 0.15])

# Format x-ticks as '10k', '2.5k', ... instead of '10000', '2500', ...
def thousands_formatter(x, pos):
    """Abbreviate tick values of magnitude >= 1000 with a 'k' suffix.

    Intended for ``mticker.FuncFormatter``, which calls it as ``(value, position)``.

    Args:
        x: Tick value.
        pos: Tick position index (unused; required by the FuncFormatter signature).

    Returns:
        ``'10k'`` for 10000, ``'2.5k'`` for 2500, ``'500'`` for 500. Fractional
        thousands keep their decimals instead of being truncated (7500 -> '7.5k',
        not '7k'). Values below 1000 in magnitude are rounded to the nearest
        integer, so floating-point noise such as 1e-13 renders as '0'.
    """
    if abs(x) >= 1000:
        # `g` drops trailing zeros and float noise: 10000 -> '10', 7500.000000001 -> '7.5'.
        return f"{x / 1000:g}k"
    return f"{int(round(x))}"

for ax in [ax_left, ax_right]:
    ax.xaxis.set_major_formatter(mticker.FuncFormatter(thousands_formatter))

# --- Grid ---
# Light grid in both directions.
for ax in [ax_left, ax_right]:
    ax.grid(alpha=GRID_ALPHA)

# --- Axis Limits ---
# Pad each axis on both sides by a fraction of its data range. Padding by the range
# (instead of multiplying min/max by 1.1) always widens the view, whatever the sign
# of the extremes. It also keeps points at 0 (the zero-sample anchors) off the frame
# edge, where they would be clipped.
limit_padding = 0.05
sample_counts = sample_efficiency_ultrafeedback_data['num_train_samples']
x_pad = (sample_counts.max() - sample_counts.min()) * limit_padding
for ax, y_column in ((ax_left, 'downstream_mean_score'), (ax_right, 'rm_mean_score')):
    scores = sample_efficiency_ultrafeedback_data[y_column]
    y_pad = (scores.max() - scores.min()) * limit_padding
    ax.set_xlim(sample_counts.min() - x_pad, sample_counts.max() + x_pad)
    ax.set_ylim(scores.min() - y_pad, scores.max() + y_pad)

# --- Save & Show ---
plt.tight_layout()
fig.savefig("sample_efficiency.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
# ==============================================================================
# Sample Efficiency Plot (UltraFeedback) - With vs Without AlpacaEval
# ==============================================================================

# Private copy so the raw subset stays untouched for later cells.
sample_efficiency_ultrafeedback_data = sample_efficiency_ultrafeedback_raw_data.copy()

# Downstream mean computed without and with AlpacaEval 2, for a side-by-side comparison.
benchmarks_without_alpaca = ['gsm8k', 'ifeval', 'truthfulqa']
benchmarks_with_alpaca = benchmarks_without_alpaca + ['alpacaeval_2']
sample_efficiency_ultrafeedback_data['downstream_mean_no_alpaca'] = (
    sample_efficiency_ultrafeedback_data[benchmarks_without_alpaca].mean(axis=1)
)
sample_efficiency_ultrafeedback_data['downstream_mean_with_alpaca'] = (
    sample_efficiency_ultrafeedback_data[benchmarks_with_alpaca].mean(axis=1)
)

# --- Style Setup ---
# Seaborn lookups keyed by acquisition function; a dash pattern of "" draws a solid line.
acquisition_colors = {name: style.color for name, style in ACQUISITION_STYLES.items()}
acquisition_markers = {name: style.marker for name, style in ACQUISITION_STYLES.items()}
acquisition_dashes = {
    name: ("" if style.dashes is None else style.dashes)
    for name, style in ACQUISITION_STYLES.items()
}

# --- Figure Setup ---
fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(DOUBLE_COLUMN_WIDTH, 2.5), sharey=False)

# --- Plot Data ---
# Panel (a): mean including AlpacaEval; panel (b): mean excluding it.
for ax, y_column in ((ax_left, 'downstream_mean_with_alpaca'), (ax_right, 'downstream_mean_no_alpaca')):
    sns.lineplot(
        data=sample_efficiency_ultrafeedback_data,
        x='num_train_samples',
        y=y_column,
        hue='acquisition_function',
        style='acquisition_function',
        hue_order=ACQUISITION_ORDER,
        style_order=ACQUISITION_ORDER,
        palette=acquisition_colors,
        markers=acquisition_markers,
        dashes=acquisition_dashes,
        ax=ax,
    )
    # Force every line on this panel to the rcParams marker edge width.
    for line in ax.get_lines():
        line.set_markeredgewidth(plt.rcParams['lines.markeredgewidth'])

# --- Legend ---
# Swap the per-panel legends for a single figure legend above both panels.
for ax in (ax_left, ax_right):
    ax.get_legend().remove()
handles, labels = ax_left.get_legend_handles_labels()
fig.legend(
    handles,
    labels,
    loc='upper center',
    bbox_to_anchor=(0.5, 1.1),
    ncol=(len(acquisition_colors) - 1) // 2 + 1,
    frameon=False,
)

# --- Axis Labels, Titles & Ticks ---
# Both panels share labels and y-ticks; only the sub-figure caption differs.
for ax, caption in ((ax_left, '(a) With AlpacaEval'), (ax_right, '(b) Without AlpacaEval')):
    ax.set_xlabel('Consumed Samples', fontweight="bold")
    ax.text(0.5, -0.35, caption, transform=ax.transAxes, ha='center', fontsize=plt.rcParams['axes.labelsize'])
    ax.set_ylabel('Downstream Score $\\Delta$', fontweight="bold")
    ax.set_yticks([0.00, 0.05, 0.10, 0.15])

# Format x-ticks as '10k', '2.5k', ... instead of '10000', '2500', ...
def thousands_formatter(x, pos):
    """Abbreviate tick values of magnitude >= 1000 with a 'k' suffix.

    Intended for ``mticker.FuncFormatter``, which calls it as ``(value, position)``.

    Args:
        x: Tick value.
        pos: Tick position index (unused; required by the FuncFormatter signature).

    Returns:
        ``'10k'`` for 10000, ``'2.5k'`` for 2500, ``'500'`` for 500. Fractional
        thousands keep their decimals instead of being truncated (7500 -> '7.5k',
        not '7k'). Values below 1000 in magnitude are rounded to the nearest
        integer, so floating-point noise such as 1e-13 renders as '0'.
    """
    if abs(x) >= 1000:
        # `g` drops trailing zeros and float noise: 10000 -> '10', 7500.000000001 -> '7.5'.
        return f"{x / 1000:g}k"
    return f"{int(round(x))}"

for ax in [ax_left, ax_right]:
    ax.xaxis.set_major_formatter(mticker.FuncFormatter(thousands_formatter))

# --- Grid ---
# Light grid in both directions.
for ax in [ax_left, ax_right]:
    ax.grid(alpha=GRID_ALPHA)

# --- Axis Limits ---
# Pad each axis on both sides by a fraction of its data range. Padding by the range
# (instead of multiplying min/max by 1.1) always widens the view, whatever the sign
# of the extremes. It also keeps points at 0 (the zero-sample anchors) off the frame
# edge, where they would be clipped.
limit_padding = 0.05
sample_counts = sample_efficiency_ultrafeedback_data['num_train_samples']
x_pad = (sample_counts.max() - sample_counts.min()) * limit_padding
for ax, y_column in ((ax_left, 'downstream_mean_with_alpaca'), (ax_right, 'downstream_mean_no_alpaca')):
    scores = sample_efficiency_ultrafeedback_data[y_column]
    y_pad = (scores.max() - scores.min()) * limit_padding
    ax.set_xlim(sample_counts.min() - x_pad, sample_counts.max() + x_pad)
    ax.set_ylim(scores.min() - y_pad, scores.max() + y_pad)

# --- Save & Show ---
plt.tight_layout()
fig.savefig("sample_efficiency_alpacaeval_comparison.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
# ==============================================================================
# Teaser Radar Plot (Normalized)
# ==============================================================================

teaser_data = teaser_raw_data.copy()

# --- Style Setup ---
acquisition_colors = {k: v.color for k, v in ACQUISITION_STYLES.items()}
acquisition_markers = {k: v.marker for k, v in ACQUISITION_STYLES.items()}

# Legend display names with symbols
legend_names = {
    'DeltaUCB': r'DeltaUCB$^\dagger$',
    'DRTS': r'DRTS$^\dagger$',
    'DTS': r'DTS$^*$',
}

# --- Data Preparation ---
# One radar spoke per benchmark, in this order (spoke angles are computed below).
benchmark_to_label = {
    "truthfulqa": "TruthfulQA",
    "gsm8k": "GSM8K",
    "rewardbench_2": "RewardBench 2",
    "ifeval": "IFEval",
    "alpacaeval_2": "AlpacaEval 2"
}
num_labels = len(benchmark_to_label)
benchmarks = list(benchmark_to_label.keys())

# Radius at which each benchmark's maximum is drawn; leaves headroom below `y_max`.
radial_scale = 0.9

# Keep the selected acquisition functions and average each benchmark over datasets.
teaser_data = teaser_data[teaser_data["acquisition_function"].isin([
    'DeltaUCB',
    'UltraFeedback',
    'DTS',
    'DRTS',
    'DeltaQwen',
])]
teaser_data = (
    teaser_data.groupby(['acquisition_function'])[benchmarks]
    .mean()
)

# Per-benchmark maxima over the selected methods, captured before normalization so
# the tick labels can show raw (un-normalized) values.
# NOTE(review): normalization assumes every maximum is positive; a zero or negative
# maximum would give inf/NaN or flip that spoke.
benchmark_maxima = teaser_data.max()
benchmark_to_limits = {
    benchmark: (0.0, float(benchmark_maxima[benchmark])) for benchmark in benchmarks
}

# Rings at 1/3, 2/3 and 1.0 of each benchmark's maximum. The maximum is drawn at
# `radial_scale`, so the rings sit at those fractions *of radial_scale*; that way each
# label matches the data value plotted at that radius.
tick_fractions = np.linspace(1 / 3, 1, 3)
tick_positions_normalized = tick_fractions * radial_scale
benchmark_to_ticks = {
    benchmark: tick_fractions * benchmark_to_limits[benchmark][1] for benchmark in benchmarks
}

# Normalize each benchmark independently so its maximum lands at `radial_scale`.
teaser_data = teaser_data / benchmark_maxima * radial_scale

y_max = 1.05
# Rotate so the second spoke (GSM8K) points straight up.
angle_offset = np.pi / 2 - (2 * np.pi / num_labels) * 1
angles = (np.linspace(0, 2 * np.pi, num_labels, endpoint=False) + angle_offset).tolist()
angles += angles[:1]  # close the polygon

# --- Figure Setup ---
fig, ax = plt.subplots(figsize=(SINGLE_COLUMN_WIDTH, SINGLE_COLUMN_WIDTH * 1.1), subplot_kw=dict(polar=True))

# --- Plot Data ---
for acq_func in ACQUISITION_ORDER:
    if acq_func not in teaser_data.index:
        continue

    values_normalized = teaser_data.loc[acq_func, benchmarks].tolist()
    values_closed = values_normalized + [values_normalized[0]]
    color = acquisition_colors[acq_func]
    marker = acquisition_markers[acq_func]
    label = legend_names.get(acq_func, acq_func)  # Use custom name if available

    ax.plot(angles, values_closed, color=color, marker=marker, label=label, zorder=10, mew=0)

# --- Legend ---
ax.legend(
    loc='upper center',
    bbox_to_anchor=(0.5, -0.1),
    ncol=3,
    frameon=False,
)

# --- Axis Labels & Titles ---
ax.set_xticks(angles[:-1])
ax.set_xticklabels([])
# Tangential rotation per label: spoke angle minus 90 degrees, flipped by 180 degrees
# where needed to keep the text upright.
label_rotations = {
    'GSM8K': 0,
    'TruthfulQA': -72,
    'RewardBench 2': 72,
    'IFEval': -36,
    'AlpacaEval 2': 36
}
for angle, label in zip(angles[:-1], benchmark_to_label.values()):
    rotation = label_rotations.get(label, 0)
    ax.text(angle, 1.18, label, fontweight='bold',
            ha='center', va='center', rotation=rotation)

# --- Axis Ticks ---
# Shared rings for all spokes, labelled per spoke with that benchmark's raw values.
ax.set_yticks(tick_positions_normalized)
ax.set_yticklabels([])
for angle, benchmark in zip(angles[:-1], benchmarks):
    for tick_pos, tick_val in zip(tick_positions_normalized, benchmark_to_ticks[benchmark]):
        # Place labels along the spoke, slightly outside their ring.
        ax.text(angle, tick_pos + 0.02, f'{tick_val:.2f}',
                color='gray', fontsize=plt.rcParams['font.size'] * 0.65,
                ha='center', va='bottom', zorder=1)

# --- Grid ---
ax.yaxis.grid(True, linestyle='--', color='gray', alpha=GRID_ALPHA)
ax.xaxis.grid(True, linestyle='--', color='gray', alpha=GRID_ALPHA)

# --- Axis Limits ---
ax.set_ylim(0, y_max)  # radial limits on a polar axis
ax.spines['polar'].set_visible(False)

# --- Save & Show ---
plt.tight_layout()
fig.savefig("teaser.pdf", format="pdf", bbox_inches="tight")
plt.show()