# Setup

In [None]:
%matplotlib
import os
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import matplotlib.colors as mcolors
import pandas as pd
import numpy as np

from dataclasses import dataclass
from typing import Dict, List, Tuple

# Theming: start from seaborn's whitegrid theme, then layer the paper-specific
# matplotlib overrides on top in a single rcParams update.
sns.set_theme(style="whitegrid")
plt.rcParams.update({
    # Typeface: serif text, Computer Modern math
    "font.family": "serif",
    "font.serif": ["Computer Modern", "DejaVu Serif", "serif"],
    "mathtext.fontset": "cm",
    "axes.formatter.use_mathtext": True,

    # Font sizes (ICML template uses 10pt)
    "font.size": 8,
    "axes.titlesize": 8,
    "axes.labelsize": 8,
    "legend.fontsize": 8,
    "xtick.labelsize": 6,
    "ytick.labelsize": 6,

    # Stroke widths
    "axes.linewidth": 0.5,    # plot border
    "patch.linewidth": 0.5,   # bar border
    "grid.linewidth": 0.5,
    "hatch.linewidth": 0.15,

    # Tick label padding
    "xtick.major.pad": 0,
    "ytick.major.pad": 0,
    "xtick.minor.pad": 0,
    "ytick.minor.pad": 0,

    # Legend spacing
    "legend.borderpad": 0.2,      # padding inside the legend box
    "legend.labelspacing": 0.1,   # vertical spacing between legend entries

    # Lines and markers
    "lines.linewidth": 1,
    "lines.markersize": 4,
    "lines.markeredgewidth": 0.25,
    "lines.markeredgecolor": "white",

    # Output resolution (1500 for final plots)
    "figure.dpi": 1500,
})

# Custom formatter for delta values as percentages with +/- signs
def delta_percent_formatter(x, pos):
    """Format a fractional delta as a signed whole-number percentage.

    Intended for use with ``matplotlib.ticker.FuncFormatter``.

    Args:
        x: Tick value as a fraction (0.05 means 5%).
        pos: Tick position supplied by matplotlib; unused.

    Returns:
        ``"+N%"`` for positive values, ``"-N%"`` for negative values and a
        plain ``"0%"`` for anything that rounds to zero.
    """
    percent = x * 100
    label = f"{percent:.0f}"
    # Decide on the *rounded* text: tiny non-zero values would otherwise be
    # rendered as "+0%" or "-0%".
    if label in ("0", "-0"):
        return "0%"
    if percent > 0:
        return f"+{label}%"
    return f"{label}%"

In [None]:
# Figure widths; passed as the width component of matplotlib ``figsize``.
SINGLE_COLUMN_WIDTH, DOUBLE_COLUMN_WIDTH = 3.25063, 6.75133

# Alpha applied to grid lines in every plot.
GRID_ALPHA = 0.4

# Display order of hue levels (acquisition functions) and x categories (datasets).
ACQUISITION_ORDER = [
    'Random',
    'UltraFeedback',
    'MaxMin',
    'DeltaQwen',
    'DeltaUCB',
    'DRTS',
    'InfoMax',
    'DTS',
    'MaxMinLCB',
]
DATASET_ORDER = [
    "UltraFeedback",
    "Skywork",
    "Combined",
    "Tulu 3",
]

# Benchmark column names: downstream benchmarks, reward-model benchmarks, and
# the full list (downstream first, then reward-model).
DOWNSTREAM_BENCHMARKS = ['gsm8k', 'ifeval', 'truthfulqa', 'alpacaeval_2']
RM_BENCHMARKS = ["rewardbench_2"]
BENCHMARKS = DOWNSTREAM_BENCHMARKS + RM_BENCHMARKS

@dataclass
class AcquisitionStyle:
    marker: str
    hatch: str
    color: str
    zorder: int
    dashes: Tuple[int, ...] | None

HATCH_MULTIPLIER = 12

# Style table for every acquisition function, built from compact
# (name, marker, hatch character, colour, dashes, zorder) rows. Each hatch
# character is repeated HATCH_MULTIPLIER times.
ACQUISITION_STYLES = {
    name: AcquisitionStyle(
        marker=marker,
        hatch=hatch_char * HATCH_MULTIPLIER,
        color=color,
        zorder=zorder,
        dashes=dashes,
    )
    for name, marker, hatch_char, color, dashes, zorder in (
        # Red shades
        ('Random', 'o', '', '#a63f3f', None, 2.1),
        ('UltraFeedback', 's', '/', '#cb4d4d', (4, 2), 2.4),
        ('MaxMin', '^', '\\', '#e06c6c', (1, 1), 2.7),
        ('DeltaQwen', 'D', 'x', '#ef8f8f', (4, 2, 1, 2), 2.9),
        # Blue shades
        ('DeltaUCB', 'o', '', '#3f3fa6', None, 2.3),
        ('DRTS', 's', '/', '#4d4dcb', (4, 2), 2.6),
        # Green shades
        ('InfoMax', 'o', '', '#3fa63f', None, 2.2),
        ('DTS', 's', '/', '#4dcb4d', (4, 2), 2.5),
        ('MaxMinLCB', '^', '\\', '#6ce06c', (1, 1), 2.8),
        # Grey reference entry
        ('Original', 'o', '', '#808080', (4, 2), 0),
    )
}

GREEN = "green"
RED = "red"

# Load Data

In [None]:
# Build the combined results table ``data``.
# If full_results.csv already exists it is loaded as-is; otherwise the table is
# assembled from the raw CSV exports and then written to full_results.csv.
if os.path.exists('full_results.csv'):
    print("Loaded full results")
    data = pd.read_csv('full_results.csv', sep=',')
else:
    # Identifier embedded in the "Method" column -> display name.
    acquisition_function_mapping = {
        "random": "Random",
        "ultrafeedback": "UltraFeedback",
        "maxmin": "MaxMin",
        "delta_qwen": "DeltaQwen",
        "DeltaUCB": "DeltaUCB",
        "DRTS": "DRTS",
        "InfoMax": "InfoMax",
        "DTS": "DTS",
        "MaxMinLCB": "MaxMinLCB",
    }

    # NOTE(review): not referenced anywhere else in the visible part of this file.
    base_model_scores = {
        "gsm8k": 0.758,
        "ifeval": 0.713,
        "truthfulqa": 0.468,
        "alpacaeval_2": 0.083,
        "rewardbench_2": 0.290
    }

    data = pd.read_csv('results.csv', sep=',')

    # --- UltraFeedback sample-efficiency table (DPO + reward model) ---
    uf_dpo_sample_efficiency = pd.read_csv("ultrafeedback_dpo_sample_efficiency.csv")
    uf_rm_sample_efficiency = pd.read_csv("ultrafeedback_rm_sample_efficiency.csv")

    # Join the two exports on "Method"; clashing column names get _dpo / _rm suffixes.
    uf_sample_efficiency = pd.merge(
        uf_dpo_sample_efficiency,
        uf_rm_sample_efficiency,
        on='Method',
        suffixes=('_dpo', '_rm')
    )

    uf_sample_efficiency = uf_sample_efficiency[uf_sample_efficiency["Method"] != "SFT Base Model"].copy().reset_index(drop=True)

    uf_sample_efficiency = uf_sample_efficiency.rename(columns={
        'Mean_rm': 'rewardbench_2',
        'GSM8K': 'gsm8k',
        'IF Eval': 'ifeval',
        'Truthful QA': 'truthfulqa',
        'Alpaca Eval': 'alpacaeval_2',
    })

    # "Method" ends in "_<sample count>"; the acquisition key is whatever sits
    # between the last "-" and that trailing count.
    uf_sample_efficiency['num_train_samples'] = uf_sample_efficiency['Method'].apply(lambda x: int(x.split('_')[-1]))
    uf_sample_efficiency['acquisition_function'] = uf_sample_efficiency['Method'].apply(lambda x: acquisition_function_mapping["_".join(x.split('_')[:-1]).split('-')[-1]])
    uf_sample_efficiency['po_algorithm'] = "DPO"
    uf_sample_efficiency['judge'] = "Qwen 3 235B"
    uf_sample_efficiency['dataset'] = "UltraFeedback"

    # Add base model scores at num_train_samples = 0 for sample efficiency plots
    # (all benchmark columns are 0 here — presumably because scores are stored
    # as deltas against the base model; TODO confirm).
    for acq_name in acquisition_function_mapping.values():
        uf_sample_efficiency.loc[len(uf_sample_efficiency)] = {
            'dataset': 'UltraFeedback',
            'judge': 'Qwen 3 235B',
            'acquisition_function': acq_name,
            'po_algorithm': 'DPO',
            'num_train_samples': 0,
            'gsm8k': 0,
            'ifeval': 0,
            'truthfulqa': 0,
            'alpacaeval_2': 0,
            'rewardbench_2': 0
        }

    uf_sample_efficiency = uf_sample_efficiency.drop(columns=['Type_dpo', 'Mean_dpo', 'Type_rm', 'Factuality', 'Focus', 'Math', 'Precise IF', 'Safety', 'Ties', 'Method'])
    # Keep the same columns, in the same order, as results.csv.
    uf_sample_efficiency = uf_sample_efficiency[data.columns]
    # Sort by acquisition function (mapping order), then by sample count.
    acq_order = list(acquisition_function_mapping.values())
    uf_sample_efficiency['acq_func_order'] = uf_sample_efficiency['acquisition_function'].apply(lambda x: acq_order.index(x) if x in acq_order else -1)
    uf_sample_efficiency = uf_sample_efficiency.sort_values(by=['acq_func_order', 'num_train_samples']).drop(columns=['acq_func_order']).reset_index(drop=True)
    # NOTE(review): this table is only written to disk; it is not concatenated
    # into ``data`` below — confirm results.csv already contains these rows.
    uf_sample_efficiency.to_csv("ultrafeedback_sample_efficiency.csv", index=False)


    # --- IPO / SimPO sample-efficiency rows ---
    ipo_simpo_sample_efficiency = pd.read_csv("final_results4.csv")
    ipo_simpo_sample_efficiency = ipo_simpo_sample_efficiency.rename(columns={
        'GSM8K': 'gsm8k',
        'IF Eval': 'ifeval',
        'Truthful QA': 'truthfulqa',
        'Alpaca Eval': 'alpacaeval_2',
    })
    # No reward-model score for these rows. Use NaN rather than '' so the
    # column stays numeric on a fresh build; the cached CSV reloads the empty
    # field as NaN anyway, so both paths now produce the same frame.
    ipo_simpo_sample_efficiency['rewardbench_2'] = np.nan
    ipo_simpo_sample_efficiency['dataset'] = 'UltraFeedback'
    ipo_simpo_sample_efficiency['judge'] = 'Qwen 3 235B'
    # "Method" has the form "<algorithm>_<acquisition key>_<sample count>".
    ipo_simpo_sample_efficiency['num_train_samples'] = ipo_simpo_sample_efficiency['Method'].apply(lambda x: int(x.split('_')[-1]))
    ipo_simpo_sample_efficiency['po_algorithm'] = ipo_simpo_sample_efficiency['Method'].apply(lambda x: x.split('_')[0])
    ipo_simpo_sample_efficiency['acquisition_function'] = ipo_simpo_sample_efficiency['Method'].apply(lambda x: acquisition_function_mapping['_'.join(x.split('_')[1:-1])])
    ipo_simpo_sample_efficiency = ipo_simpo_sample_efficiency.drop(columns=['Type', 'Mean', 'Method'])

    data = pd.concat([data, ipo_simpo_sample_efficiency], ignore_index=True)
    data = data.drop_duplicates().reset_index(drop=True)

    # Sort order: rows without a sample count first, then by dataset order,
    # algorithm, acquisition order and sample count. Values missing from the
    # order lists sort last. The helper columns are dropped afterwards.
    data = data.assign(
        num_train_samples_null=data['num_train_samples'].isna(),
        dataset_order_idx=data['dataset'].apply(lambda x: DATASET_ORDER.index(x) if x in DATASET_ORDER else len(DATASET_ORDER)),
        acquisition_order_idx=data['acquisition_function'].apply(
            lambda x: ACQUISITION_ORDER.index(x) if x in ACQUISITION_ORDER else len(ACQUISITION_ORDER))
    ).sort_values(
        by=['num_train_samples_null', 'dataset_order_idx', 'po_algorithm', 'acquisition_order_idx', 'num_train_samples'],
        ascending=[False, True, True, True, True]
    ).drop(columns=['num_train_samples_null', 'dataset_order_idx', 'acquisition_order_idx']).reset_index(drop=True)

    data.to_csv("full_results.csv", index=False)

# Build the per-dataset statistics table ``datasets_data``.
# If dataset_statistics.csv already exists it is loaded as-is; otherwise it is
# derived from my_analysis.csv and then written to dataset_statistics.csv.
if os.path.exists('dataset_statistics.csv'):
    # Message fixed: this branch loads the dataset statistics, not the full results.
    print("Loaded dataset statistics")
    datasets_data = pd.read_csv('dataset_statistics.csv', sep=',')
else:
    datasets_data = pd.read_csv('my_analysis.csv', sep=',')

    # Dataset_Name is split on "_": the first piece becomes ``training`` and the
    # second ``acquisition_function``. Non-string or underscore-free names give None.
    datasets_data['training'] = datasets_data['Dataset_Name'].apply(lambda x: x.split('_')[0] if isinstance(x, str) and '_' in x else None)
    datasets_data['acquisition_function'] = datasets_data['Dataset_Name'].apply(lambda x: x.split('_')[1] if isinstance(x, str) and '_' in x else None)

    datasets_data.drop(columns=['Dataset_Name'], inplace=True)
    datasets_data.rename(columns={'Model': 'model'}, inplace=True)

    # Keep only the columns used downstream, in a fixed order.
    datasets_data = datasets_data[['training', 'acquisition_function', 'model', 'chosen_count', 'rejected_count']]

    datasets_data.to_csv("dataset_statistics.csv", index=False)

In [None]:
# Row-wise aggregate scores over the reward-model and downstream benchmark columns.
data["rm_mean_score"] = data[RM_BENCHMARKS].mean(axis=1)
data["downstream_mean_score"] = data[DOWNSTREAM_BENCHMARKS].mean(axis=1)

# Boolean row masks reused by the subsets below.
_on_ultrafeedback = data['dataset'] == 'UltraFeedback'
_is_dpo = data['po_algorithm'] == 'DPO'
_is_ipo_or_simpo = data['po_algorithm'].isin(['IPO', 'SimPO'])
_has_sample_count = data['num_train_samples'].notna()
_no_sample_count = data['num_train_samples'].isna()
_counted_or_original = _has_sample_count | (data['acquisition_function'] == 'Original')

# UltraFeedback rows without a sample count, minus the rewardbench_2 column.
po_algo_ablation_raw_data = data[_on_ultrafeedback & _no_sample_count].drop(columns=['rewardbench_2'])

# DPO rows without a sample count (two separate frames built from the same filter).
dataset_ablation_raw_data = data[_is_dpo & _no_sample_count]

teaser_raw_data = data[_is_dpo & _no_sample_count]

# UltraFeedback rows that have a sample count or are the "Original" entry,
# for DPO and for IPO/SimPO respectively.
sample_efficiency_raw_data = data[_on_ultrafeedback & _counted_or_original & _is_dpo].copy()
sample_efficiency_ipo_simpo_raw_data = data[_on_ultrafeedback & _counted_or_original & _is_ipo_or_simpo].copy()

# Plots

In [None]:
# ==============================================================================
# Dataset Ablation Plot
# ==============================================================================

# Two-panel grouped bar chart: one group per dataset, one bar per acquisition
# function. Left panel shows downstream_mean_score, right panel rm_mean_score.
# The figure is written to dataset_ablation.pdf.
dataset_ablation_data = dataset_ablation_raw_data.copy()

# --- Style Setup ---
# Per-acquisition-function colour and hatch lookups taken from ACQUISITION_STYLES.
acquisition_colors = {k: v.color for k, v in ACQUISITION_STYLES.items()}
acquisition_hatches = {k: v.hatch for k, v in ACQUISITION_STYLES.items()}

# --- Figure Setup ---
fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(DOUBLE_COLUMN_WIDTH, 2))

# --- Plot Data ---
# Left: downstream scores
sns.barplot(
    data=dataset_ablation_data,
    x='dataset',
    y='downstream_mean_score',
    hue='acquisition_function',
    palette=acquisition_colors,
    edgecolor="white",
    order=DATASET_ORDER,
    hue_order=ACQUISITION_ORDER + ["Original"],
    ax=ax_left
)
# Right: reward model scores
# NOTE(review): unlike the left panel, no edgecolor is passed here — confirm this is intended.
sns.barplot(
    data=dataset_ablation_data,
    x='dataset',
    y='rm_mean_score',
    hue='acquisition_function',
    palette=acquisition_colors,
    order=DATASET_ORDER,
    hue_order=ACQUISITION_ORDER + ["Original"],
    ax=ax_right,
)

# --- Apply Hatches ---
# Maps patch index -> hue level via integer division by the number of groups.
# Assumes ax.patches lists the bars hue-by-hue with exactly n_groups bars per
# hue level — this depends on seaborn's patch ordering; TODO confirm.
# Patches whose hue index falls past ACQUISITION_ORDER (e.g. "Original") are left unhatched.
n_hues = len(ACQUISITION_ORDER)
n_groups = len(dataset_ablation_data['dataset'].unique())
for ax in [ax_left, ax_right]:
    for i, bar in enumerate(ax.patches):
        hue_idx = i // n_groups
        if hue_idx < n_hues:
            acq_func = ACQUISITION_ORDER[hue_idx]
            bar.set_hatch(acquisition_hatches[acq_func])

# --- Legend ---
# Drop the per-axes legends and build one shared figure-level legend from the
# left axes' handles, hatched to match the bars.
ax_left.get_legend().remove()
ax_right.get_legend().remove()
handles, labels = ax_left.get_legend_handles_labels()
for i, handle in enumerate(handles):
    if i < len(ACQUISITION_ORDER):
        acq_func = ACQUISITION_ORDER[i]
        handle.set_hatch(acquisition_hatches[acq_func])
fig.legend(
    handles,
    labels,
    loc='upper center',
    bbox_to_anchor=(0.5, 1.1),
    ncol=(len(acquisition_colors) - 1) // 2 + 1,  # half the style entries, rounded up
    frameon=False,
)

# --- Axis Labels & Titles ---
# The x labels double as sub-figure captions.
ax_left.set_xlabel('(a) Fine-tuned Models')
ax_right.set_xlabel('(b) Reward Models')
ax_left.set_ylabel('$\\Delta$Score', fontweight="bold")
ax_right.set_ylabel('$\\Delta$Score', fontweight="bold")

# --- Axis Ticks ---
# Bold, rotated, right-aligned dataset names at axis-label font size.
for ax in [ax_left, ax_right]:
    for label in ax.get_xticklabels():
        label.set_fontweight('bold')
        label.set_fontsize(plt.rcParams["axes.labelsize"])
        label.set_rotation(20)
        label.set_ha('right')


ax_left.set_yticks([0.00, 0.05, 0.10, 0.15])
ax_right.set_yticks([0.0, 0.1, 0.2, 0.3, 0.4])

# Format y-ticks as +/- XX%
for ax in [ax_left, ax_right]:
    ax.yaxis.set_major_formatter(mticker.FuncFormatter(delta_percent_formatter))

# --- Grid ---
# Grid drawn at GRID_ALPHA, then the x-direction grid is made fully transparent.
for ax in [ax_left, ax_right]:
    ax.grid(alpha=GRID_ALPHA)
    ax.grid(axis='x', alpha=0.0)

# --- Axis Limits ---
ax_left.set_ylim(0, 0.16)
ax_right.set_ylim(0, 0.4)

# --- Save & Show ---
plt.tight_layout()
fig.savefig("dataset_ablation.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
# ==============================================================================
# Dataset Ablation Plot (Split Export)
# ==============================================================================

# Same dataset-ablation bar charts as the combined figure, but built as three
# separate figures: left panel, right panel, and a legend-only figure.
dataset_ablation_data = dataset_ablation_raw_data.copy()

# --- Style Setup ---
# Per-acquisition-function colour and hatch lookups taken from ACQUISITION_STYLES.
acquisition_colors = {k: v.color for k, v in ACQUISITION_STYLES.items()}
acquisition_hatches = {k: v.hatch for k, v in ACQUISITION_STYLES.items()}

# Define single plot width (approx half of double column)
SINGLE_PLOT_WIDTH = DOUBLE_COLUMN_WIDTH / 2
HEIGHT = 1.5

# ==========================================
# 1. Left Plot (Fine-tuned Models)
# ==========================================
fig_left, ax_left = plt.subplots(figsize=(SINGLE_PLOT_WIDTH, HEIGHT))

sns.barplot(
    data=dataset_ablation_data,
    x='dataset',
    y='downstream_mean_score',
    hue='acquisition_function',
    palette=acquisition_colors,
    edgecolor="white",
    order=DATASET_ORDER,
    hue_order=ACQUISITION_ORDER + ["Original"],
    ax=ax_left
)
ax_left.get_legend().remove() # We export legend separately
ax_left.set_xlabel('')
ax_left.set_ylabel('$\\Delta$Score', fontweight="bold")
ax_left.set_ylim(0, 0.16)
ax_left.set_yticks([0.00, 0.05, 0.10, 0.15])

# ==========================================
# 2. Right Plot (Reward Models)
# ==========================================
fig_right, ax_right = plt.subplots(figsize=(SINGLE_PLOT_WIDTH, HEIGHT))

# NOTE(review): unlike the left plot, no edgecolor is passed here — confirm this is intended.
sns.barplot(
    data=dataset_ablation_data,
    x='dataset',
    y='rm_mean_score',
    hue='acquisition_function',
    palette=acquisition_colors,
    order=DATASET_ORDER,
    hue_order=ACQUISITION_ORDER + ["Original"],
    ax=ax_right,
)
ax_right.get_legend().remove()
ax_right.set_xlabel('')
ax_right.set_ylabel('$\\Delta$Score', fontweight="bold")
ax_right.set_ylim(0, 0.4)
ax_right.set_yticks([0.0, 0.1, 0.2, 0.3, 0.4])


# ==========================================
# 3. Shared Formatting (Hatches & Ticks)
# ==========================================
n_hues = len(ACQUISITION_ORDER)
n_groups = len(dataset_ablation_data['dataset'].unique())

for ax in [ax_left, ax_right]:
    # Apply Hatches
    # Maps patch index -> hue level via integer division by the number of groups.
    # Assumes ax.patches lists the bars hue-by-hue with exactly n_groups bars
    # per hue level — depends on seaborn's patch ordering; TODO confirm.
    for i, bar in enumerate(ax.patches):
        hue_idx = i // n_groups
        if hue_idx < n_hues:
            acq_func = ACQUISITION_ORDER[hue_idx]
            bar.set_hatch(acquisition_hatches[acq_func])

    # Format Ticks
    # Bold dataset names, two points smaller than the axis-label size.
    for label in ax.get_xticklabels():
        label.set_fontweight('bold')
        label.set_fontsize(plt.rcParams["axes.labelsize"] - 2)
        # label.set_rotation(20)
        # label.set_ha('right')

    # Format Y-Axis
    ax.yaxis.set_major_formatter(mticker.FuncFormatter(delta_percent_formatter))

    # Grid
    # Grid drawn at GRID_ALPHA, then the x-direction grid is made fully transparent.
    ax.grid(alpha=GRID_ALPHA)
    ax.grid(axis='x', alpha=0.0)

# ==========================================
# 4. Legend Export
# ==========================================
# Extract handles from the left plot
handles, labels = ax_left.get_legend_handles_labels()

# Apply hatches to the legend handles to match the plot
for i, handle in enumerate(handles):
    if i < len(ACQUISITION_ORDER):
        acq_func = ACQUISITION_ORDER[i]
        handle.set_hatch(acquisition_hatches[acq_func])

# Create a dedicated figure just for the legend
# Width matches the full double column, height is small
fig_leg = plt.figure(figsize=(DOUBLE_COLUMN_WIDTH, 0.5))

fig_leg.legend(
    handles,
    labels,
    loc='center',
    ncol=(len(acquisition_colors) - 1) // 2 + 1,  # half the style entries, rounded up
    frameon=False,
)

# ==========================================
# 5. Save All Files
# ==========================================
# savefig does not create missing directories, so make sure the output
# folder exists before writing into it.
os.makedirs("dataset_ablation", exist_ok=True)

# Save Left Plot
fig_left.savefig(
    "dataset_ablation/left.pdf",
    format="pdf",
    bbox_inches="tight",
    pad_inches=0.02
)

# Save Right Plot
fig_right.savefig(
    "dataset_ablation/right.pdf",
    format="pdf",
    bbox_inches="tight",
    pad_inches=0.02
)

# Save Legend
fig_leg.savefig(
    "dataset_ablation/legend.pdf",
    format="pdf",
    bbox_inches="tight",
    pad_inches=0.02
)

plt.show()

In [None]:
# ==============================================================================
# PO Algorithm Ablation Plot (with broken y-axis)
# ==============================================================================

# Grouped bar chart of downstream_mean_score per preference-optimisation
# algorithm (DPO / IPO / SimPO), one bar per acquisition function, drawn on two
# stacked axes to give a broken y-axis. Written to po_algo_ablation.pdf.
po_algo_ablation_data = po_algo_ablation_raw_data.copy()

# --- Style Setup ---
# Per-acquisition-function colour and hatch lookups taken from ACQUISITION_STYLES.
acquisition_colors = {k: v.color for k, v in ACQUISITION_STYLES.items()}
acquisition_hatches = {k: v.hatch for k, v in ACQUISITION_STYLES.items()}

# --- Figure Setup ---
# Two rows sharing the x axis; the top axes is 8x taller than the bottom one.
fig, (ax_top, ax_bottom) = plt.subplots(2, 1, sharex=True, figsize=(SINGLE_COLUMN_WIDTH, 1.5), gridspec_kw={
    'height_ratios': [8, 1],
    'hspace': 0.1
})

# --- Plot Data ---
# The same bars are drawn on both axes; each axes later gets its own y-range.
for ax in [ax_top, ax_bottom]:
    sns.barplot(
        data=po_algo_ablation_data,
        x='po_algorithm',
        y='downstream_mean_score',
        hue='acquisition_function',
        palette=acquisition_colors,
        order=['DPO', 'IPO', 'SimPO'],
        hue_order=ACQUISITION_ORDER,
        ax=ax
    )
    ax.get_legend().remove()

# --- Apply Hatches ---
# Maps patch index -> hue level via integer division by the number of groups.
# Assumes ax.patches lists the bars hue-by-hue with exactly n_groups bars per
# hue level — depends on seaborn's patch ordering; TODO confirm.
n_hues = len(ACQUISITION_ORDER)
n_groups = len(po_algo_ablation_data['po_algorithm'].unique())
for ax in [ax_top, ax_bottom]:
    for i, bar in enumerate(ax.patches):
        hue_idx = i // n_groups
        if hue_idx < n_hues:
            acq_func = ACQUISITION_ORDER[hue_idx]
            bar.set_hatch(acquisition_hatches[acq_func])

# --- Legend ---
# One figure-level legend built from the top axes' handles, hatched to match the bars.
handles, labels = ax_top.get_legend_handles_labels()
for i, handle in enumerate(handles):
    if i < len(ACQUISITION_ORDER):
        acq_func = ACQUISITION_ORDER[i]
        handle.set_hatch(acquisition_hatches[acq_func])
fig.legend(
    handles,
    labels,
    loc='upper center',
    bbox_to_anchor=(0.45, 1.2),
    ncol=3,
    frameon=False,
    borderpad=0.2,
    columnspacing=0.7,
    handletextpad=0.3,
    handlelength=1.2,
    handleheight=0.7,
)

# --- Axis Labels & Titles ---
ax_bottom.set_xlabel('')
ax_top.set_ylabel('$\\Delta$Score', y=0.4, fontweight="bold")
ax_bottom.set_ylabel('')

# --- Axis Ticks ---
# Top axes: ticks from 0 to 0.20 in steps of 0.05; bottom axes: a single tick at -0.25.
ax_top.set_yticks(np.arange(0, 0.25, 0.05))
ax_bottom.set_yticks([-0.25])
for label in ax_bottom.get_xticklabels():
    label.set_fontweight('bold')
    label.set_fontsize(plt.rcParams["axes.labelsize"])

# Format y-ticks as +/- XX%
for ax in [ax_top, ax_bottom]:
    ax.yaxis.set_major_formatter(mticker.FuncFormatter(delta_percent_formatter))

# --- Grid ---
# y grid at GRID_ALPHA; the x-direction grid is made fully transparent.
for ax in [ax_top, ax_bottom]:
    ax.grid(axis='y', alpha=GRID_ALPHA)
    ax.grid(axis='x', alpha=0.0)

# --- Axis Limits ---
# The range between -0.20 and -0.01 is shown on neither axes (the "break").
ax_top.set_ylim(-0.01, 0.22)
ax_bottom.set_ylim(-0.30, -0.20)

# --- Broken Axis Styling ---
# Hide the spines facing the gap and draw small slanted markers at the four
# corners where the axes meet the gap.
ax_top.spines['bottom'].set_visible(False)
ax_bottom.spines['top'].set_visible(False)
ax_top.tick_params(bottom=False)
break_kwargs = {
    'marker': [(-1, -0.5), (1, 0.5)],  # custom marker given as two vertices
    'markersize': 6,
    'linestyle': 'none',
    'color': plt.rcParams['axes.edgecolor'],
    'markeredgecolor': plt.rcParams['axes.edgecolor'],
    'mew': plt.rcParams['axes.linewidth'],
    'clip_on': False,
}
# Coordinates are in axes fractions: bottom corners of the top axes, top corners of the bottom axes.
ax_top.plot([0, 1], [0, 0], transform=ax_top.transAxes, **break_kwargs)
ax_bottom.plot([0, 1], [1, 1], transform=ax_bottom.transAxes, **break_kwargs)

# --- Save & Show ---
# plt.tight_layout()
fig.savefig("po_algo_ablation.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
# ==============================================================================
# Sample Efficiency Plot (UltraFeedback)
# ==============================================================================

# Two-panel line chart of score against num_train_samples, one line per
# acquisition function: downstream_mean_score on the left, rm_mean_score on the
# right. The "Original" row is drawn as a horizontal reference line instead.
sample_efficiency_ultrafeedback_data = sample_efficiency_raw_data.copy()

# --- Style Setup ---
# Per-acquisition-function lookups taken from ACQUISITION_STYLES.
acquisition_colors = {k: v.color for k, v in ACQUISITION_STYLES.items()}
acquisition_markers = {k: v.marker for k, v in ACQUISITION_STYLES.items()}
# Styles with dashes=None are passed to seaborn as "".
acquisition_dashes = {k: v.dashes if v.dashes is not None else "" for k, v in ACQUISITION_STYLES.items()}
# NOTE(review): acquisition_zorder is not used in the rest of this cell.
acquisition_zorder = {k: v.zorder for k, v in ACQUISITION_STYLES.items()}

# --- Figure Setup ---
fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(DOUBLE_COLUMN_WIDTH, 2.25), sharey=False)

# --- Plot Data ---
# Left: downstream scores
sns.lineplot(
    data=sample_efficiency_ultrafeedback_data[sample_efficiency_ultrafeedback_data["acquisition_function"] != "Original"],
    x='num_train_samples',
    y='downstream_mean_score',
    hue='acquisition_function',
    style='acquisition_function',
    hue_order=ACQUISITION_ORDER,
    style_order=ACQUISITION_ORDER,
    palette=acquisition_colors,
    markers=acquisition_markers,
    dashes=acquisition_dashes,
    ax=ax_left
)
# Right: reward model scores
sns.lineplot(
    data=sample_efficiency_ultrafeedback_data[sample_efficiency_ultrafeedback_data["acquisition_function"] != "Original"],
    x='num_train_samples',
    y='rm_mean_score',
    hue='acquisition_function',
    style='acquisition_function',
    hue_order=ACQUISITION_ORDER,
    style_order=ACQUISITION_ORDER,
    palette=acquisition_colors,
    markers=acquisition_markers,
    dashes=acquisition_dashes,
    ax=ax_right
)

# --- Add reference lines for Original UltraFeedback scores ---
# Only the first "Original" row is used (.values[0]); this raises IndexError if there is none.
original_row = sample_efficiency_ultrafeedback_data[sample_efficiency_ultrafeedback_data["acquisition_function"] == "Original"]

for ax, y in [
    (ax_left, original_row["downstream_mean_score"].values[0]),
    (ax_right, original_row["rm_mean_score"].values[0]),
]:
    ax.axhline(
        y=y,
        color=ACQUISITION_STYLES["Original"].color,
        dashes=ACQUISITION_STYLES["Original"].dashes,
        label="Original",
        zorder=0,
    )

# --- Legend ---
# Replace the per-axes legends with one figure-level legend taken from the
# left axes (the labelled reference line was added to that axes above).
ax_left.get_legend().remove()
ax_right.get_legend().remove()
handles, labels = ax_left.get_legend_handles_labels()

fig.legend(
    handles,
    labels,
    loc='upper center',
    bbox_to_anchor=(0.5, 1.1),
    ncol=(len(acquisition_colors) - 1) // 2 + 1,  # half the style entries, rounded up
    frameon=False,
)

# --- Axis Labels & Titles ---
ax_left.set_xlabel('Consumed Samples', fontweight="bold")
ax_right.set_xlabel('Consumed Samples', fontweight="bold")
# Sub-figure captions placed below each axes, in axes-fraction coordinates.
ax_left.text(0.5, -0.275, '(a) Fine-tuned Models', transform=ax_left.transAxes, ha='center', fontsize=plt.rcParams['axes.labelsize'])
ax_right.text(0.5, -0.275, '(b) Reward Models', transform=ax_right.transAxes, ha='center', fontsize=plt.rcParams['axes.labelsize'])
ax_left.set_ylabel('Score $\\Delta$', fontweight="bold")
ax_right.set_ylabel('Score $\\Delta$', fontweight="bold")

ax_left.set_yticks([0.00, 0.05, 0.10, 0.15])
ax_right.set_yticks([0.0, 0.1, 0.2, 0.3, 0.4])

# Format y-ticks as +/- XX%
for ax in [ax_left, ax_right]:
    ax.yaxis.set_major_formatter(mticker.FuncFormatter(delta_percent_formatter))
# Format x-Ticks as '10k', '20k', ... instead of '10000', '20000', ...
def thousands_formatter(x, pos):
    """Format a tick value compactly, using a ``k`` suffix from 1000 upwards.

    Args:
        x: Tick value.
        pos: Tick position supplied by matplotlib; unused.

    Returns:
        ``"10k"`` style text for values of at least 1000 (fractional thousands
        are kept, e.g. 2500 -> ``"2.5k"``), otherwise the integer part of the
        value as text.
    """
    if x >= 1000:
        # ``:g`` drops a trailing ".0" but keeps real fractions, so 2500 is not
        # mislabelled as "2k".
        return f"{x / 1000:g}k"
    return f"{int(x):d}"

# Finish the sample-efficiency figure: x tick labels, grid, axis limits,
# marker edges and z-order, then write it to sample_efficiency.pdf.
for ax in [ax_left, ax_right]:
    ax.xaxis.set_major_formatter(mticker.FuncFormatter(thousands_formatter))

# --- Grid ---
ax_left.grid(axis='y', alpha=GRID_ALPHA)
ax_right.grid(axis='y', alpha=GRID_ALPHA)

# --- Axis Limits ---
# Both ends are the data min/max multiplied by 1.1.
# NOTE(review): scaling the minimum by 1.1 only adds padding when the minimum
# is <= 0; a positive minimum would move the lower limit above the smallest
# data point — confirm the data always includes a value <= 0.
ax_left.set_xlim(
    sample_efficiency_ultrafeedback_data['num_train_samples'].min() * 1.1,
    sample_efficiency_ultrafeedback_data['num_train_samples'].max() * 1.1
)
ax_left.set_ylim(
    sample_efficiency_ultrafeedback_data['downstream_mean_score'].min() * 1.1,
    sample_efficiency_ultrafeedback_data['downstream_mean_score'].max() * 1.1
)
ax_right.set_xlim(
    sample_efficiency_ultrafeedback_data['num_train_samples'].min() * 1.1,
    sample_efficiency_ultrafeedback_data['num_train_samples'].max() * 1.1
)
ax_right.set_ylim(
    sample_efficiency_ultrafeedback_data['rm_mean_score'].min() * 1.1,
    sample_efficiency_ultrafeedback_data['rm_mean_score'].max() * 1.1
)

# --- Grid ---
# Second grid call, this time without an axis restriction.
for ax in [ax_left, ax_right]:
    ax.grid(alpha=GRID_ALPHA)

# --- Apply Marker Edge Width ---
# Re-apply the rcParams marker edge width/colour to every line on the axes.
for ax in [ax_left, ax_right]:
    for line in ax.get_lines():
        line.set_markeredgewidth(plt.rcParams['lines.markeredgewidth'])
        line.set_markeredgecolor(plt.rcParams['lines.markeredgecolor'])

# --- Robust Z-Order Fix ---
# Lines are matched to their style by colour: each line's colour is converted
# to a lowercase hex string and looked up in a colour -> zorder map.
color_map = {v.color.lower(): v.zorder for k, v in ACQUISITION_STYLES.items()}
for ax in [ax_left, ax_right]:
    for line in ax.get_lines():
        line_color = mcolors.to_hex(line.get_color()).lower()[:7]

        if line_color in color_map:
            line.set_zorder(color_map[line_color])


# --- Save & Show ---
plt.tight_layout()
fig.savefig("sample_efficiency.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
# ==============================================================================
# Sample Efficiency Plot (Split Export)
# ==============================================================================
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.ticker as mticker
import matplotlib.colors as mcolors

# Work on a private copy so the shared raw frame is never modified.
sample_efficiency_ultrafeedback_data = sample_efficiency_raw_data.copy(deep=True)

# --- Style Setup ---
# Per-acquisition-function lookups taken from ACQUISITION_STYLES; a style
# without a dash pattern is represented by "".
acquisition_colors = {name: style.color for name, style in ACQUISITION_STYLES.items()}
acquisition_markers = {name: style.marker for name, style in ACQUISITION_STYLES.items()}
acquisition_dashes = {
    name: ("" if style.dashes is None else style.dashes)
    for name, style in ACQUISITION_STYLES.items()
}
acquisition_zorder = {name: style.zorder for name, style in ACQUISITION_STYLES.items()}

# Size of each exported panel: half the double-column width.
HEIGHT = 1.75
SINGLE_PLOT_WIDTH = DOUBLE_COLUMN_WIDTH / 2

# Helper function for common formatting
def format_efficiency_plot(ax):
    """Apply the shared tick, grid, marker and z-order styling to one axes.

    Args:
        ax: Matplotlib axes that already contains the plotted lines.
    """
    def _thousands(x, pos):
        # ``:g`` keeps fractional thousands (2500 -> "2.5k") instead of
        # truncating them to a misleading "2k".
        return f"{x / 1000:g}k" if x >= 1000 else f"{int(x):d}"

    # Ticks formatting
    ax.yaxis.set_major_formatter(mticker.FuncFormatter(delta_percent_formatter))
    ax.xaxis.set_major_formatter(mticker.FuncFormatter(_thousands))

    # Grid
    ax.grid(alpha=GRID_ALPHA)

    # Lines are matched to their style by colour (lowercase hex -> zorder).
    zorder_by_color = {style.color.lower(): style.zorder for style in ACQUISITION_STYLES.values()}
    edge_width = plt.rcParams['lines.markeredgewidth']
    edge_color = plt.rcParams['lines.markeredgecolor']

    for line in ax.get_lines():
        # Marker styles
        line.set_markeredgewidth(edge_width)
        line.set_markeredgecolor(edge_color)

        # Z-order fix. Only a colour that cannot be converted is skipped; any
        # other error is allowed to surface instead of being swallowed.
        try:
            line_color = mcolors.to_hex(line.get_color()).lower()[:7]
        except ValueError:
            continue
        if line_color in zorder_by_color:
            line.set_zorder(zorder_by_color[line_color])

# ==========================================
# 1. Left Plot (Fine-tuned Models)
# ==========================================
# Left panel as its own figure: downstream_mean_score against num_train_samples,
# one line per acquisition function.
fig_left, ax_left = plt.subplots(figsize=(SINGLE_PLOT_WIDTH, HEIGHT))

# Plot Lines
sns.lineplot(
    data=sample_efficiency_ultrafeedback_data[sample_efficiency_ultrafeedback_data["acquisition_function"] != "Original"],
    x='num_train_samples',
    y='downstream_mean_score',
    hue='acquisition_function',
    style='acquisition_function',
    hue_order=ACQUISITION_ORDER,
    style_order=ACQUISITION_ORDER,
    palette=acquisition_colors,
    markers=acquisition_markers,
    dashes=acquisition_dashes,
    ax=ax_left
)

# Reference Line
# Score of the first "Original" row (.values[0]); raises IndexError if there is none.
original_val_left = sample_efficiency_ultrafeedback_data[
    sample_efficiency_ultrafeedback_data["acquisition_function"] == "Original"
]["downstream_mean_score"].values[0]

ax_left.axhline(
    y=original_val_left,
    color=ACQUISITION_STYLES["Original"].color,
    dashes=ACQUISITION_STYLES["Original"].dashes,
    label="Original",
    zorder=0,
)

# --- Formatting ---
ax_left.get_legend().remove()
ax_left.set_xlabel('Consumed Samples', fontweight="bold")
ax_left.set_ylabel('Score $\\Delta$', fontweight="bold")
ax_left.set_yticks([0.00, 0.05, 0.10, 0.15])

# --- Inserted: Left Axis Limits ---
# Both ends are the data min/max multiplied by 1.1.
# NOTE(review): scaling the minimum by 1.1 only adds padding when the minimum
# is <= 0 — confirm the data always includes a value <= 0.
ax_left.set_xlim(
    sample_efficiency_ultrafeedback_data['num_train_samples'].min() * 1.1,
    sample_efficiency_ultrafeedback_data['num_train_samples'].max() * 1.1
)
ax_left.set_ylim(
    sample_efficiency_ultrafeedback_data['downstream_mean_score'].min() * 1.1,
    sample_efficiency_ultrafeedback_data['downstream_mean_score'].max() * 1.1
)

format_efficiency_plot(ax_left)


# ==========================================
# 2. Right Plot (Reward Models)
# ==========================================
# Same layout as the left panel, plotting rm_mean_score.
fig_right, ax_right = plt.subplots(figsize=(SINGLE_PLOT_WIDTH, HEIGHT))

# Plot Lines
sns.lineplot(
    data=sample_efficiency_ultrafeedback_data[sample_efficiency_ultrafeedback_data["acquisition_function"] != "Original"],
    x='num_train_samples',
    y='rm_mean_score',
    hue='acquisition_function',
    style='acquisition_function',
    hue_order=ACQUISITION_ORDER,
    style_order=ACQUISITION_ORDER,
    palette=acquisition_colors,
    markers=acquisition_markers,
    dashes=acquisition_dashes,
    ax=ax_right
)

# Reference Line
original_val_right = sample_efficiency_ultrafeedback_data[
    sample_efficiency_ultrafeedback_data["acquisition_function"] == "Original"
]["rm_mean_score"].values[0]

ax_right.axhline(
    y=original_val_right,
    color=ACQUISITION_STYLES["Original"].color,
    dashes=ACQUISITION_STYLES["Original"].dashes,
    label="Original",
    zorder=0,
)

# --- Formatting ---
ax_right.get_legend().remove()
ax_right.set_xlabel('Consumed Samples', fontweight="bold")
ax_right.set_ylabel('Score $\\Delta$', fontweight="bold")
ax_right.set_yticks([0.0, 0.1, 0.2, 0.3, 0.4])

# --- Inserted: Right Axis Limits ---
ax_right.set_xlim(
    sample_efficiency_ultrafeedback_data['num_train_samples'].min() * 1.1,
    sample_efficiency_ultrafeedback_data['num_train_samples'].max() * 1.1
)
ax_right.set_ylim(
    sample_efficiency_ultrafeedback_data['rm_mean_score'].min() * 1.1,
    sample_efficiency_ultrafeedback_data['rm_mean_score'].max() * 1.1
)

format_efficiency_plot(ax_right)


# ==========================================
# 3. Legend Export
# ==========================================
# Legend-only figure built from the left axes' handles (including the
# labelled "Original" reference line added above).
handles, labels = ax_left.get_legend_handles_labels()

fig_leg = plt.figure(figsize=(DOUBLE_COLUMN_WIDTH, 0.5))
fig_leg.legend(
    handles,
    labels,
    loc='center',
    ncol=(len(acquisition_colors) - 1) // 2 + 1,  # half the style entries, rounded up
    frameon=False,
)

# ==========================================
# 4. Save
# ==========================================
# savefig does not create missing directories, so make sure the output
# folder exists before writing into it.
os.makedirs("sample_efficiency", exist_ok=True)

fig_left.savefig("sample_efficiency/left.pdf", format="pdf", bbox_inches="tight", pad_inches=0.02)
fig_right.savefig("sample_efficiency/right.pdf", format="pdf", bbox_inches="tight", pad_inches=0.02)
fig_leg.savefig("sample_efficiency/legend.pdf", format="pdf", bbox_inches="tight", pad_inches=0.02)

plt.show()

In [None]:
# ==============================================================================
# Sample Efficiency (AlpacaEval Ablation) - Split Export
# ==============================================================================
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.ticker as mticker
import matplotlib.colors as mcolors

# Work on a copy so the raw sample-efficiency table is left untouched.
sample_efficiency_ultrafeedback_data = sample_efficiency_raw_data.copy()

# --- Data Prep ---
# Row-wise downstream means: first WITHOUT AlpacaEval, then WITH AlpacaEval.
core_benchmarks = ['gsm8k', 'ifeval', 'truthfulqa']
for mean_column, benchmark_columns in (
    ('downstream_mean_no_alpaca', core_benchmarks),
    ('downstream_mean_with_alpaca', core_benchmarks + ['alpacaeval_2']),
):
    sample_efficiency_ultrafeedback_data[mean_column] = (
        sample_efficiency_ultrafeedback_data[benchmark_columns].mean(axis=1)
    )

# --- Style Setup ---
# Per-acquisition-function lookups in the form seaborn's lineplot accepts.
acquisition_colors = {}
acquisition_markers = {}
acquisition_dashes = {}
for acq_name, acq_style in ACQUISITION_STYLES.items():
    acquisition_colors[acq_name] = acq_style.color
    acquisition_markers[acq_name] = acq_style.marker
    # An undefined dash pattern is passed to seaborn as "".
    acquisition_dashes[acq_name] = "" if acq_style.dashes is None else acq_style.dashes

# Each panel is exported separately at half the double-column width.
HEIGHT = 1.75
SINGLE_PLOT_WIDTH = DOUBLE_COLUMN_WIDTH / 2

# Helper function for common formatting
def format_alpaca_plot(ax):
    """Apply the shared tick, grid and marker formatting to an ablation axis.

    Args:
        ax: Matplotlib axis to format in place.
    """
    # Format y-ticks as +/- XX%
    ax.yaxis.set_major_formatter(mticker.FuncFormatter(delta_percent_formatter))

    # Format x-ticks as '10k'; non-integer thousands keep their fraction
    # ('2.5k') instead of being truncated to '2k'.
    def thousands_formatter(x, pos):
        if x >= 1000:
            thousands = x / 1000
            if float(thousands).is_integer():
                return f"{int(thousands):d}k"
            return f"{thousands:g}k"
        return f"{int(x):d}"
    ax.xaxis.set_major_formatter(mticker.FuncFormatter(thousands_formatter))

    # Grid & Markers
    ax.grid(alpha=GRID_ALPHA)
    for line in ax.get_lines():
        # Re-apply the configured marker edge width to every plotted line.
        line.set_markeredgewidth(plt.rcParams['lines.markeredgewidth'])

# ==========================================
# 1. + 2. Panels: WITH / WITHOUT AlpacaEval
# ==========================================
def _plot_alpaca_panel(y_column):
    """Draw one sample-efficiency panel for the given downstream-mean column.

    Args:
        y_column: Column of ``sample_efficiency_ultrafeedback_data`` to plot
            on the y-axis.

    Returns:
        Tuple ``(fig, ax, reference_value)`` where ``reference_value`` is the
        ``y_column`` value of the first "Original" row (drawn as a horizontal
        reference line).
    """
    data = sample_efficiency_ultrafeedback_data
    fig, ax = plt.subplots(figsize=(SINGLE_PLOT_WIDTH, HEIGHT))

    # One line per acquisition function; "Original" is drawn separately below.
    sns.lineplot(
        data=data[data["acquisition_function"] != "Original"],
        x='num_train_samples',
        y=y_column,
        hue='acquisition_function',
        style='acquisition_function',
        hue_order=ACQUISITION_ORDER,
        style_order=ACQUISITION_ORDER,
        palette=acquisition_colors,
        markers=acquisition_markers,
        dashes=acquisition_dashes,
        ax=ax
    )

    # Reference Line
    reference_value = data.loc[
        data["acquisition_function"] == "Original",
        y_column
    ].values[0]
    ax.axhline(
        y=reference_value,
        color=ACQUISITION_STYLES["Original"].color,
        dashes=ACQUISITION_STYLES["Original"].dashes,
        label="Original",
        zorder=0,  # keep the reference line behind the data lines
    )

    # Formatting
    ax.get_legend().remove()
    ax.set_xlabel('Consumed Samples', fontweight="bold")
    ax.set_ylabel('Score $\\Delta$', fontweight="bold")
    ax.set_yticks([0.00, 0.05, 0.10, 0.15])

    # Limits
    # NOTE(review): the lower bounds are multiplied by 1.1 as well, which
    # moves a positive minimum inward. Confirm this is intended.
    ax.set_xlim(
        data['num_train_samples'].min() * 1.1,
        data['num_train_samples'].max() * 1.1
    )
    ax.set_ylim(
        data[y_column].min() * 1.1,
        data[y_column].max() * 1.1
    )

    format_alpaca_plot(ax)
    return fig, ax, reference_value


# Left: WITH AlpacaEval / Right: WITHOUT AlpacaEval
fig_left, ax_left, y_left = _plot_alpaca_panel('downstream_mean_with_alpaca')
fig_right, ax_right, y_right = _plot_alpaca_panel('downstream_mean_no_alpaca')


# ==========================================
# 3. Legend Export
# ==========================================
# The legend is rendered as its own figure, using the left panel's handles.
handles, labels = ax_left.get_legend_handles_labels()

fig_leg = plt.figure(figsize=(DOUBLE_COLUMN_WIDTH, 0.5))
fig_leg.legend(
    handles,
    labels,
    loc='center',
    ncol=(len(acquisition_colors) - 1) // 2 + 1,  # ceil(n / 2) columns
    frameon=False,
)


# ==========================================
# 4. Save
# ==========================================
# Create the output directory first; savefig does not create missing parents.
output_dir = "sample_efficiency_no_alpaca_eval"
os.makedirs(output_dir, exist_ok=True)
for figure, stem in ((fig_left, "left"), (fig_right, "right"), (fig_leg, "legend")):
    figure.savefig(f"{output_dir}/{stem}.pdf", format="pdf", bbox_inches="tight", pad_inches=0)

plt.show()

In [None]:
# ==============================================================================
# Teaser Radar Plot (Normalized)
# ==============================================================================

teaser_data = teaser_raw_data.copy()

# --- Style Setup ---
acquisition_colors = {k: v.color for k, v in ACQUISITION_STYLES.items()}
acquisition_markers = {k: v.marker for k, v in ACQUISITION_STYLES.items()}

# Legend display names with symbols
legend_names = {
    'DeltaUCB': r'DeltaUCB$^\dagger$',
    'DRTS': r'DRTS$^\dagger$',
    'DTS': r'DTS$^*$',
}

# --- Data Preparation ---
# Insertion order defines the order of the radar axes.
benchmark_to_label = {
    "truthfulqa": "TruthfulQA",
    "gsm8k": "GSM8K",
    "rewardbench_2": "RewardBench 2",
    "ifeval": "IFEval",
    "alpacaeval_2": "AlpacaEval 2"
}
num_labels = len(benchmark_to_label)
benchmark_columns = list(benchmark_to_label)

# Keep the selected acquisition functions and average each benchmark per function
teaser_data = teaser_data[teaser_data["acquisition_function"].isin([
    'DeltaUCB',
    'UltraFeedback',
    'DTS',
    'DRTS',
    'DeltaQwen',
])]
teaser_data = (
    teaser_data.groupby(['acquisition_function'])[benchmark_columns]
    .mean()
)

# Radius at which the best value of every benchmark is drawn. It is below 1.0
# so the best markers sit inside the outer ring.
PEAK_RADIUS = 0.9

# Value range represented by radius 0..1 on each axis. Because the best value
# is drawn at PEAK_RADIUS, radius 1.0 corresponds to max / PEAK_RADIUS; using
# the plain max here would make the outer-ring label overstate the ring.
benchmark_to_limits = {}
for benchmark in benchmark_columns:
    benchmark_max = teaser_data[benchmark].max()
    benchmark_to_limits[benchmark] = (0.0, float(benchmark_max) / PEAK_RADIUS)
benchmark_to_ticks = {
    benchmark: np.linspace(*benchmark_to_limits[benchmark], 5)[1:]
    for benchmark in benchmark_columns
}

# Normalize now that we have saved the original limits and ticks
teaser_data = teaser_data / teaser_data.max() * PEAK_RADIUS

y_max = 1.05
angle_offset = np.pi / 2 - (2 * np.pi / num_labels) * 1
angles = (np.linspace(0, 2 * np.pi, num_labels, endpoint=False) + angle_offset).tolist()
angles += angles[:1]  # repeat the first angle to close the polygon

# --- Figure Setup ---
fig, ax = plt.subplots(figsize=(SINGLE_COLUMN_WIDTH, SINGLE_COLUMN_WIDTH * 1.1), subplot_kw=dict(polar=True))

# --- Plot Data ---
for acq_func in ACQUISITION_ORDER:
    if acq_func not in teaser_data.index:
        continue

    values_normalized = teaser_data.loc[acq_func, benchmark_columns].tolist()
    values_closed = values_normalized + [values_normalized[0]]
    color = acquisition_colors[acq_func]
    marker = acquisition_markers[acq_func]
    label = legend_names.get(acq_func, acq_func)  # Use custom name if available

    ax.plot(angles, values_closed, color=color, marker=marker, label=label, zorder=10, mew=0)

# --- Legend ---
ax.legend(
    loc='upper center',
    bbox_to_anchor=(0.5, -0.1),
    ncol=3,
    frameon=False,
)

# --- Axis Labels & Titles ---
ax.set_xticks(angles[:-1])
ax.set_xticklabels([])
label_rotations = {
    'GSM8K': 0,
    'TruthfulQA': -72,
    'RewardBench 2': 72,
    'IFEval': -36,
    'AlpacaEval 2': 36
}
for angle, label in zip(angles[:-1], benchmark_to_label.values()):
    rotation = label_rotations.get(label, 0)
    ax.text(angle, 1.25, label, fontweight='bold',
            ha='center', va='center', rotation=rotation)

# --- Axis Ticks ---
# Only label the outermost ring (radius 1.0) for readability
tick_positions_normalized = [1.0]
ax.set_yticks(tick_positions_normalized)
ax.set_yticklabels([])
for angle, benchmark in zip(angles[:-1], benchmark_columns):
    tick_val = benchmark_to_ticks[benchmark][-1]  # Value represented by radius 1.0
    tick_pos = tick_positions_normalized[0]
    # Format as +XX% with the shared delta formatter
    tick_label = delta_percent_formatter(tick_val, None)
    # Use radial offset (not vertical) so labels are pushed outward consistently
    ax.text(angle, tick_pos + 0.04, tick_label,
            color='black', fontsize=plt.rcParams['xtick.labelsize'],
            ha='center', va='center', zorder=1)

# --- Grid ---
ax.yaxis.grid(True, linestyle='--', color='gray', alpha=GRID_ALPHA)
ax.xaxis.grid(True, linestyle='--', color='gray', alpha=GRID_ALPHA)

# --- Axis Limits ---
# On polar axes the radial limits are the y-limits, so one call is enough.
ax.set_rlim(0, y_max)
ax.spines['polar'].set_visible(False)

# --- Save & Show ---
plt.tight_layout()
fig.savefig("teaser.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
# ==============================================================================
# Dataset Statistics Plot (stacked bars)
# ==============================================================================

plot_data = datasets_data.copy()

# One entry per exported plot:
# (output key, `training` value, `acquisition_function` value, display title).
# Baseline acquisition functions are stored lower-case in the data.
dataset_variants = [
    ("baseline_random", "baselines", "random", "Random"),
    ("baseline_ultrafeedback", "baselines", "ultrafeedback", "UltraFeedback"),
    ("baseline_maxmin", "baselines", "maxmin", "MaxMin"),
    ("dpo_infomax", "dpo", "InfoMax", "DPO: InfoMax"),
    ("rm_infomax", "rm", "InfoMax", "RM: InfoMax"),
    ("dpo_dts", "dpo", "DTS", "DPO: DTS"),
    ("rm_dts", "rm", "DTS", "RM: DTS"),
    ("dpo_maxminlcb", "dpo", "MaxMinLCB", "DPO: MaxMinLCB"),
    ("rm_maxminlcb", "rm", "MaxMinLCB", "RM: MaxMinLCB"),
    ("dpo_drts", "dpo", "DRTS", "DPO: DRTS"),
    ("rm_drts", "rm", "DRTS", "RM: DRTS"),
    ("dpo_deltaucb", "dpo", "DeltaUCB", "DPO: DeltaUCB"),
    ("rm_deltaucb", "rm", "DeltaUCB", "RM: DeltaUCB"),
]

# Rows of `plot_data` belonging to each variant, keyed by output key.
dataset_to_data = {
    key: plot_data[
        (plot_data['training'] == training)
        & (plot_data['acquisition_function'] == acquisition_function)
    ].copy()
    for key, training, acquisition_function, _title in dataset_variants
}

model_name_map = {
    "Qwen/Qwen2.5-0.5B-Instruct": "Qwen 2.5 0.5B",
    "Qwen/Qwen2.5-72B-Instruct": "Qwen 2.5 72B",
    "Qwen/Qwen3-0.6B": "Qwen 3 0.6B",
    "Qwen/Qwen3-1.7B": "Qwen 3 1.7B",
    "Qwen/Qwen3-14B": "Qwen 3 14B",
    "Qwen/Qwen3-30B-A3B": "Qwen 3 30B A3B",
    "Qwen/Qwen3-32B": "Qwen 3 32B",
    "Qwen/Qwen3-235B-A22B": "Qwen 3 235B A22B",
    "meta-llama/Llama-3.1-8B-Instruct": "Llama 3.1 8B",
    "meta-llama/Llama-3.2-1B-Instruct": "Llama 3.2 1B",
    "meta-llama/Llama-3.2-3B-Instruct": "Llama 3.2 3B",
    "meta-llama/Llama-3.3-70B-Instruct": "Llama 3.3 70B",
    "microsoft/Phi-4-mini-instruct": "Phi 4 Mini",
    "microsoft/phi-4": "Phi 4",
    "mistralai/Mistral-Small-24B-Instruct-2501": "Mistral Small",
    "mistralai/Mistral-Large-Instruct-2411": "Mistral Large",
    "nvidia/Llama-3_3-Nemotron-Super-49B-v1": "Nemotron Super 49B",
    "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF": "Nemotron 70B",
    "nvidia/Llama-3_1-Nemotron-Ultra-253B-v1": "Nemotron Ultra 253B",
    "google/gemma-3-1b-it": "Gemma 3 1B",
    "google/gemma-3-4b-it": "Gemma 3 4B",
    "google/gemma-3-12b-it": "Gemma 3 12B",
    "google/gemma-3-27b-it": "Gemma 3 27B",
    "HuggingFaceTB/SmolLM2-1.7B-Instruct": "SmolLM2 1.7B",
    "CohereLabs/c4ai-command-a-03-2025": "Command A",
    "deepseek-ai/DeepSeek-V3": "DeepSeek V3",
    "allenai/OLMo-2-0325-32B-Instruct": "OLMo 2 32B",
    "allenai/Llama-3.1-Tulu-3-70B": "Tulu 70B",
    "allenai/Llama-3.1-Tulu-3-405B": "Tulu 405B",
    "moonshotai/Moonlight-16B-A3B-Instruct": "Moonlight 16B A3B",
}

# Which datasets/variants to export, with their display titles.
# The titles are currently not drawn; the plots are exported without a title.
plot_datasets = [(key, title) for key, _training, _acq, title in dataset_variants]

# Output each plot individually as dataset_statistics/{dataset_key}.pdf.
# Create the output directory first; savefig does not create missing parents.
os.makedirs("dataset_statistics", exist_ok=True)
for dataset_key, _title in plot_datasets:
    fig, ax = plt.subplots(figsize=(DOUBLE_COLUMN_WIDTH, 2.4))

    # Index the rows by display name (unknown model ids are kept as-is)
    df = dataset_to_data[dataset_key].copy()
    df = df.set_index('model')
    df.index = [model_name_map.get(m, m) for m in df.index]

    models = list(df.index)
    chosen_counts = df['chosen_count']
    rejected_counts = df['rejected_count']

    # Stacked barplot using native matplotlib: rejected at the bottom,
    # chosen stacked on top of it.
    ax.bar(
        models,
        rejected_counts,
        color=RED,
        label="Rejected",
    )
    ax.bar(
        models,
        chosen_counts,
        color=GREEN,
        label="Chosen",
        bottom=rejected_counts,
    )

    ax.set_xlabel('')
    ax.set_ylabel('Counts')

    ax.set_xlim(-0.75, len(models) - 0.25)
    ax.set_xticks(range(len(models)))
    ax.set_xticklabels(models, rotation=45, ha='right')

    # No title on the exported plot
    ax.set_title("")

    # Horizontal grid lines only
    ax.grid(axis='y', alpha=GRID_ALPHA)
    ax.grid(axis='x', alpha=0.0)

    # Tight layout and save
    fig.tight_layout()
    fig.savefig(f"dataset_statistics/{dataset_key}.pdf", format="pdf", bbox_inches="tight")
    plt.close(fig)

In [None]:
# ==============================================================================
# Sample Efficiency Plot - IPO, SimPO (Split Export)
# ==============================================================================
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.ticker as mticker
import matplotlib.colors as mcolors

# Split the raw table by preference-optimization algorithm, dropping the
# "Original" rows from both subsets.
po_algorithm_column = sample_efficiency_ipo_simpo_raw_data["po_algorithm"]
is_not_original = sample_efficiency_ipo_simpo_raw_data["acquisition_function"] != "Original"
ipo_data = sample_efficiency_ipo_simpo_raw_data[(po_algorithm_column == "IPO") & is_not_original].copy()
simpo_data = sample_efficiency_ipo_simpo_raw_data[(po_algorithm_column == "SimPO") & is_not_original].copy()

# --- Style Setup ---
# Per-acquisition-function lookups derived from ACQUISITION_STYLES.
acquisition_colors = {}
acquisition_markers = {}
acquisition_dashes = {}
acquisition_zorder = {}
for acq_name, acq_style in ACQUISITION_STYLES.items():
    acquisition_colors[acq_name] = acq_style.color
    acquisition_markers[acq_name] = acq_style.marker
    # An undefined dash pattern is passed to seaborn as "".
    acquisition_dashes[acq_name] = "" if acq_style.dashes is None else acq_style.dashes
    acquisition_zorder[acq_name] = acq_style.zorder

# Each panel is exported separately at half the double-column width.
HEIGHT = 1.75
SINGLE_PLOT_WIDTH = DOUBLE_COLUMN_WIDTH / 2

# Helper function for common formatting
def format_efficiency_plot(ax):
    """Apply the shared tick, grid, marker and z-order formatting to an axis.

    Args:
        ax: Matplotlib axis to format in place.
    """
    # Format y-ticks as +/- XX%
    ax.yaxis.set_major_formatter(mticker.FuncFormatter(delta_percent_formatter))

    # Format x-ticks as '10k'; non-integer thousands keep their fraction
    # ('2.5k') instead of being truncated to '2k'.
    def sample_count_formatter(x, pos):
        if x >= 1000:
            thousands = x / 1000
            if float(thousands).is_integer():
                return f"{int(thousands):d}k"
            return f"{thousands:g}k"
        return f"{int(x):d}"
    ax.xaxis.set_major_formatter(mticker.FuncFormatter(sample_count_formatter))

    # Grid
    ax.grid(alpha=GRID_ALPHA)

    # Lines are matched to their acquisition style by colour, because the
    # z-order is looked up from the style's colour string.
    color_map = {style.color.lower(): style.zorder for style in ACQUISITION_STYLES.values()}
    for line in ax.get_lines():
        # Marker Styles
        line.set_markeredgewidth(plt.rcParams['lines.markeredgewidth'])
        line.set_markeredgecolor(plt.rcParams['lines.markeredgecolor'])

        # Z-Order Fix: best effort, a line whose colour cannot be converted
        # simply keeps its current z-order.
        try:
            line_color = mcolors.to_hex(line.get_color()).lower()
        except (ValueError, TypeError):
            continue
        if line_color in color_map:
            line.set_zorder(color_map[line_color])

# ==========================================
# 1. + 2. Panels: IPO (left) / SimPO (right)
# ==========================================
def _plot_preference_panel(data, yticks, y_bottom):
    """Draw one sample-efficiency panel of downstream mean score.

    Args:
        data: Rows to plot; must provide ``num_train_samples``,
            ``downstream_mean_score`` and ``acquisition_function`` columns.
        yticks: Explicit y-tick positions.
        y_bottom: Lower y-axis limit.

    Returns:
        Tuple ``(fig, ax)`` of the created figure and axis.
    """
    fig, ax = plt.subplots(figsize=(SINGLE_PLOT_WIDTH, HEIGHT))

    # Plot Lines
    sns.lineplot(
        data=data,
        x='num_train_samples',
        y='downstream_mean_score',
        hue='acquisition_function',
        style='acquisition_function',
        hue_order=ACQUISITION_ORDER,
        style_order=ACQUISITION_ORDER,
        palette=acquisition_colors,
        markers=acquisition_markers,
        dashes=acquisition_dashes,
        ax=ax
    )

    # No "Original" reference line is drawn here: ipo_data / simpo_data
    # exclude the "Original" rows.

    # --- Formatting ---
    ax.get_legend().remove()
    ax.set_xlabel('Consumed Samples', fontweight="bold")
    ax.set_ylabel('Score $\\Delta$', fontweight="bold")
    ax.set_yticks(yticks)

    # --- Axis Limits ---
    # NOTE(review): the lower x bound is multiplied by 1.1 as well, which
    # places it above the smallest sample count. Confirm this is intended.
    ax.set_xlim(
        data['num_train_samples'].min() * 1.1,
        data['num_train_samples'].max() * 1.1
    )
    ax.set_ylim(
        y_bottom,
        data['downstream_mean_score'].max() * 1.1
    )

    format_efficiency_plot(ax)
    return fig, ax


# Left Plot (IPO): fixed lower y-limit just below the -5% mark.
fig_left, ax_left = _plot_preference_panel(
    ipo_data,
    yticks=[0.00, 0.05, 0.10, 0.15],
    y_bottom=-0.051,
)

# Right Plot (SimPO): lower y-limit derived from the data minimum.
fig_right, ax_right = _plot_preference_panel(
    simpo_data,
    yticks=[-0.05, 0.0, 0.1, 0.2, 0.3, 0.4],
    y_bottom=simpo_data['downstream_mean_score'].min() * 1.5,
)


# ==========================================
# 3. Legend Export
# ==========================================
# The legend is rendered as its own figure, using the left panel's handles.
handles, labels = ax_left.get_legend_handles_labels()

fig_leg = plt.figure(figsize=(DOUBLE_COLUMN_WIDTH, 0.5))
fig_leg.legend(
    handles,
    labels,
    loc='center',
    ncol=(len(acquisition_colors) - 1) // 2 + 1,  # ceil(n / 2) columns
    frameon=False,
)

# ==========================================
# 4. Save
# ==========================================
# Create the output directory first; savefig does not create missing parents.
output_dir = "sample_efficiency_ipo_simpo"
os.makedirs(output_dir, exist_ok=True)
for figure, stem in ((fig_left, "left"), (fig_right, "right"), (fig_leg, "legend")):
    figure.savefig(f"{output_dir}/{stem}.pdf", format="pdf", bbox_inches="tight", pad_inches=0.02)

plt.show()