In [1]:
import sys
import importlib
import numpy as np

sys.path.append("..")

from scripts.load_results import load_results

from strategic_ttc.verifiers.boxed_number import parse_pred_from_explanation
from strategic_ttc.verifiers.gpqa import parse_pred_from_explanation_gpqa


import strategic_ttc.core.accuracy_analysis_f as aa
import strategic_ttc.core.game_dynamics_f as gm
import scripts.utils as ut

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Re-execute the already-imported project modules in place so that edits to
# their source files are picked up without restarting the kernel.
importlib.reload(aa)
importlib.reload(gm)
importlib.reload(ut)

<module 'scripts.utils' from '/NL/strategic-compute/work/strategic-ttc/notebooks/../scripts/utils.py'>

# Load results

In [3]:
# Benchmark to analyse. It selects the results directory under ../final_runs/
# and the answer parser: "GSM8K" and "AIME" use the boxed-number parser, any
# other value falls back to the GPQA parser.
BENCHMARK = "GSM8K"

In [4]:
# Load the stored results for the selected benchmark.
results = load_results(f"../final_runs/{BENCHMARK}")

# Pick the answer parser: GSM8K and AIME share the boxed-number parser, every
# other benchmark value falls back to the GPQA parser.
pred_fn = (
    parse_pred_from_explanation
    if BENCHMARK in {"GSM8K", "AIME"}
    else parse_pred_from_explanation_gpqa
)

# Derive the per-key colour mapping and the two model groups from the result keys.
model_color = ut.assign_colors(results.keys())
reasoning_models, unreasoning_models = ut.categorize_models(results.keys())

In [5]:
# Display the mapping returned by ut.assign_colors() as the cell output.
model_color

{'Llama-3-8B': '#004183',
 'Llama-3.1-8B': '#4192E3',
 'Llama-3.2-1B': '#84CDE6',
 'Llama-3.2-3B': '#16A990',
 'Qwen2-0.5B': '#786B06',
 'Qwen2-1.5B': '#E1C84C',
 'Qwen2-7B': '#F7949F',
 'Qwen2.5-3B': '#AA3377',
 'Qwen2.5-7B': '#56002B',
 'reason-R1-D-Llama-8B': '#C5C3C3',
 'reason-R1-D-Qwen-1.5B': '#6E6E6E',
 'reason-R1-D-Qwen-7B': '#000000'}

In [8]:
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np

# --- Constants & Setup ---
# Default figure width in points; used as the default `width_pt` of
# create_standalone_legend(), which converts it to inches via get_fig_dim().
# NOTE(review): presumably the ICML text width — confirm against the template.
ICML_WIDTH_TEXT_PT = 486

def get_fig_dim(width, fraction=1, aspect_ratio=None):
    """Return a ``(width, height)`` figure size in inches for a LaTeX layout.

    Sizing the figure from the document width up front means it can be
    included without being rescaled.

    Args:
        width: Reference width in points (converted at 72.27 pt per inch).
        fraction: Share of ``width`` that the figure should occupy.
        aspect_ratio: Width divided by height. The golden ratio is used
            when this is ``None``.

    Returns:
        Tuple ``(width_in, height_in)``, both in inches.
    """
    golden_ratio = (1 + 5 ** 0.5) / 2
    ratio = golden_ratio if aspect_ratio is None else aspect_ratio

    pt_to_in = 1 / 72.27
    width_in = (width * fraction) * pt_to_in
    height_in = width_in / ratio
    return width_in, height_in

def latexify(font_size=10, small_font_size=8):
    """Switch matplotlib's global rcParams to LaTeX-rendered serif text.

    The settings are written into ``plt.rcParams`` in place, so they remain
    active for every figure created after this call.

    Args:
        font_size: Value assigned to ``font.size``.
        small_font_size: Value assigned to ``legend.fontsize``.
    """
    settings = {}

    # Output backend and the LaTeX packages loaded for text rendering.
    # NOTE(review): 'backend' is set through rcParams here; confirm this is
    # still wanted, since it may not switch an already-loaded backend.
    settings['backend'] = 'ps'
    settings['text.latex.preamble'] = r'\usepackage{gensymb} \usepackage{bm}'

    # Caller-controlled sizes.
    settings['font.size'] = font_size
    settings['legend.fontsize'] = small_font_size

    # Render all text through LaTeX with a Computer Modern serif face.
    settings['text.usetex'] = True
    settings['font.family'] = 'serif'
    settings['font.serif'] = 'Computer Modern'
    settings['mathtext.fontset'] = 'cm'

    plt.rcParams.update(settings)

# --- Data ---
# Hard-coded label -> hex colour table for the standalone legends, one entry
# per line in legend order.
all_colors = {
    # Llama models
    'Llama-3-8B': '#004183',
    'Llama-3.1-8B': '#4192E3',
    'Llama-3.2-1B': '#84CDE6',
    'Llama-3.2-3B': '#16A990',
    # Qwen models
    'Qwen2-0.5B': '#786B06',
    'Qwen2-1.5B': '#E1C84C',
    'Qwen2-7B': '#F7949F',
    'Qwen2.5-3B': '#AA3377',
    'Qwen2.5-7B': '#56002B',
    # Entries carrying the "reason" prefix
    'reason-R1-D-Llama-8B': '#C5C3C3',
    'reason-R1-D-Qwen-1.5B': '#6E6E6E',
    'reason-R1-D-Qwen-7B': '#000000',
}

# --- Separation Logic ---
# Entries whose name contains "reason" form the reasoning group (3 items);
# every remaining entry is a standard Llama/Qwen model (9 items). Both dicts
# keep the ordering of `all_colors`.
# Note: this rebinds `reasoning_models`, which the results-loading cell
# assigned earlier from ut.categorize_models(...).
reasoning_models = {
    model: shade for model, shade in all_colors.items() if "reason" in model
}
standard_models = {
    model: shade
    for model, shade in all_colors.items()
    if model not in reasoning_models
}

def create_standalone_legend(data_dict, filename, ncol, width_pt=ICML_WIDTH_TEXT_PT, clean_reasoning_names=False):
    """Render a legend-only figure and save it to ``filename``.

    One line-style legend entry is created per item of ``data_dict``; each
    label is wrapped in LaTeX's texttt command. The function calls
    ``latexify()``, which updates the global matplotlib rcParams.

    Args:
        data_dict: Mapping of label -> colour, iterated in legend order.
        filename: Output path handed to ``Figure.savefig``.
        ncol: Number of legend columns; must be at least 1.
        width_pt: Figure width in points, converted to inches by
            ``get_fig_dim``.
        clean_reasoning_names: When true, remove "reason-" from each label
            and shorten "Distill-" to "D-".

    Raises:
        ValueError: If ``ncol`` is smaller than 1.
    """
    # Validate before touching global rcParams or creating a figure; a
    # non-positive column count would otherwise fail later with a
    # ZeroDivisionError or a negative figure height.
    if ncol < 1:
        raise ValueError(f"ncol must be >= 1, got {ncol}")

    latexify(font_size=10, small_font_size=8)

    handles = []
    for name, hex_color in data_dict.items():
        label = name
        if clean_reasoning_names:
            # Shorten the long reasoning names for the legend label.
            label = label.replace("reason-", "").replace("Distill-", "D-")

        # Only underscores are escaped for LaTeX; any other special
        # character in a name is passed through unchanged.
        label = label.replace("_", r"\_")

        handles.append(
            Line2D([0], [0], color=hex_color, lw=2, label=f"\\texttt{{{label}}}")
        )

    # Ceiling division: number of legend rows needed for the entries.
    rows = (len(handles) + ncol - 1) // ncol

    # Full requested width; 0.25 units of figure height per legend row.
    leg_width, _ = get_fig_dim(width_pt, fraction=1.0)
    leg_height = rows * 0.25

    fig = plt.figure(figsize=(leg_width, leg_height))
    try:
        fig.legend(
            handles=handles,
            loc='center',
            ncol=ncol,
            frameon=False,
            fontsize=8,
            columnspacing=1.0,
            handlelength=1.5
        )
        fig.savefig(filename, bbox_inches='tight', pad_inches=0.02)
    finally:
        # Always release the figure, even when legend construction or saving
        # fails (e.g. the LaTeX toolchain is missing); otherwise it stays
        # registered with pyplot.
        plt.close(fig)

    print(f"Saved: {filename}")

# --- Execute ---

# 1. Create Standard Legend (9 items -> 3 cols gives 3 rows)
create_standalone_legend(
    standard_models, 
    "legend_standard.pdf", 
    ncol=3,
    clean_reasoning_names=False
)

# 2. Create Reasoning Legend (3 items -> 3 cols gives 1 row)
create_standalone_legend(
    reasoning_models, 
    "legend_reasoning.pdf", 
    ncol=3,
    clean_reasoning_names=True # Will strip "reason-" prefix for cleaner label
)

Saved: legend_standard.pdf
Saved: legend_reasoning.pdf
