In [None]:
import os
import re
import json
import pickle
import numpy as np 


def process_episode(episode, results_folder):
    """Gather the response lengths and rewards of every prompts file of an episode.

    Args:
        episode: iterable of file names (one or more "prompts" files).
        results_folder: directory the file names are relative to.

    Returns:
        ``(lengths, rewards)``: two flat lists concatenated over all files, in
        the order the files are given. Row 0 of each loaded array supplies the
        lengths and row 1 the rewards
        (-1 for no box; -0.5 for box but wrong; 1 for correct).
    """
    all_lengths, all_rewards = [], []
    for prompts_filename in episode:
        loaded = np.load(os.path.join(results_folder, prompts_filename))
        all_lengths += loaded[0].tolist()
        all_rewards += loaded[1].tolist()
    return all_lengths, all_rewards

In [None]:
### Figure 2 replication -- varying difficulty

import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
import numpy as np
import pandas as pd


# Training runs to compare: legend label -> folder with that run's per-episode result files.
results_folders = {
    'p=0.0625': '../train/checkpoints/R1_on_aime_p_0.0625/experiments',
    'p=0.125': '../train/checkpoints/R1_on_aime_p_0.125/experiments',
    'p=0.25': '../train/checkpoints/R1_on_aime_p_0.25/experiments',
    'p=0.375': '../train/checkpoints/R1_on_aime_p_0.375/experiments'
}


def _group_files_by_episode(results_folder):
    """Map each episode number to the files of ``results_folder`` named ``episode<N>_*``.

    The folder is listed once. Entries that do not follow the naming pattern
    (stray files, hidden folders, ...) are ignored instead of aborting the run.
    """
    episode_files = {}
    for file in os.listdir(results_folder):
        match = re.match(r'episode(\d+)_', file)
        if match:
            episode_files.setdefault(int(match.group(1)), []).append(file)
    return episode_files


# results[label] = [accuracy per episode (%), mean response length per episode]
results = {}
for name, results_folder in results_folders.items():
    print(name)
    episodes = _group_files_by_episode(results_folder)

    mean_length_list = []
    accuracy_list = []

    # Walk episode numbers 0..max so that the list index always equals the
    # episode number; an episode without data yields NaN rather than a KeyError
    # (gap in the numbering) or a ZeroDivisionError (no rewards).
    for episode in range(max(episodes, default=-1) + 1):
        ep_list = episodes.get(episode, [])  # list of prompts files for this episode
        lengths, rewards = process_episode(ep_list, results_folder)
        if rewards:
            avg_length = np.mean(lengths)
            # Only a strictly positive reward counts as a correct answer.
            rewards = [1 if item > 0 else 0 for item in rewards]
            acc = 100.0 * sum(rewards) / len(rewards)
        else:
            avg_length = acc = float('nan')
        mean_length_list.append(avg_length)
        accuracy_list.append(acc)
    results[name] = [accuracy_list, mean_length_list]



# Global plot configuration for the Figure 2 replication.
np.random.seed(42)
plt.close('all')
plt.style.use("./publication.mplstyle")
smoothing_window = 50  # width of the centered rolling-mean window, in steps
steps = np.arange(1, 1800 + 1)

# Two stacked panels sharing the x axis: accuracy on top, response length below.
fig, axes = plt.subplots(2, 1, figsize=(3, 4), sharex=True, gridspec_kw={'hspace': 0.14})
ax1, ax2 = axes

for name, (accuracy_list, mean_length_list) in results.items():
    print(name)
    # Smooth both curves the same way; min_periods=1 keeps values at the edges.
    acc, mlen = (
        pd.Series(raw_curve).rolling(window=smoothing_window, min_periods=1, center=True).mean().tolist()
        for raw_curve in (accuracy_list, mean_length_list)
    )
    xsize = min(len(acc), len(steps))  # never draw past the fixed step range
    shown_steps = steps[:xsize]
    ax1.plot(shown_steps, acc[:xsize], ls='-', lw=1, label=name, alpha=0.8)
    ax2.plot(shown_steps, mlen[:xsize], ls='-', lw=1, alpha=0.8)

# Hide the top and right frame lines of both panels.
for ax in axes:
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)

# x axis: a tick at step 1, then one every 300 steps.
x_tick_positions = [1] + list(range(300, len(steps) + 1, 300))
ax1.set_xticks(x_tick_positions)
ax1.set_xlim(-1, len(steps))
ax1.set_ylabel("Accuracy (%)")
ax2.set_ylabel("Response Length (tokens)")
ax2.set_xlabel("Steps")

# y axes: accuracy in 10% steps, response length in 2000-token steps.
ax1.set_yticks(np.arange(0, 101, 10))
ax1.set_ylim(0, 100)
ax2.set_yticks(np.arange(3000, 20000, 2000))
ax2.set_ylim(3000, 20000)

def format_ticks(value, _):
    """Render a tick value as whole thousands with a ``K`` suffix (5000 -> '5K').

    The second argument (the tick position passed by the formatter) is unused.
    """
    thousands = int(value / 1000)  # int() drops any fractional part
    return str(thousands) + 'K'
# Label the response-length axis with format_ticks.
length_formatter = FuncFormatter(format_ticks)
ax2.yaxis.set_major_formatter(length_formatter)

# Single figure-level legend, centered near the top of the figure.
legend_options = dict(loc="upper center", bbox_to_anchor=(0.5, 0.94), ncol=3, frameon=False)
fig.legend(**legend_options)
plt.show()

In [None]:
### Figure 3 replication for R1-1.5B
# Running this cell takes about 15 minutes

import pandas as pd
from transformers import AutoTokenizer
import matplotlib.lines as mlines


# Model identifier whose tokenizer is used to measure response lengths in tokens.
pretrain = 'deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B'
# Loaded once at module level; get_len() below reads it as a global.
tokenizer = AutoTokenizer.from_pretrained(pretrain)

def get_len(response_list):
    """Count the non-padding tokens of each response in ``response_list``.

    The responses are tokenized together as one padded batch with the
    module-level ``tokenizer``; a response's length is the number of its token
    ids that differ from ``tokenizer.pad_token_id``.

    Returns:
        A list with one count per row of the tokenized batch.
    """
    # NOTE(review): truncation=True presumably caps every count at the
    # tokenizer's maximum length -- confirm that is intended for this metric.
    batch = tokenizer(response_list, padding=True, truncation=True, return_tensors="pt")
    is_real_token = batch["input_ids"] != tokenizer.pad_token_id
    per_response_count = is_real_token.sum(dim=1)
    return per_response_count.tolist()

def _mean_group_length(groups):
    """Average the per-prompt mean length over the prompts whose group is non-empty.

    Returns ``None`` when no prompt has a response in the group.
    """
    per_prompt = [np.mean(group) for group in groups if len(group) > 0]
    # Keep the original convention of treating a mean of 0 as "no data".
    per_prompt = [value for value in per_prompt if value != 0]
    return int(np.mean(per_prompt)) if per_prompt else None


def _summarize_eval_folder(folder):
    """Compute the summary metrics of a single evaluation folder.

    The folder must hold exactly one file whose name does not contain
    'metrics'. It is read as JSON Lines; every record provides a 'score' list
    and a 'code' list (the responses) that are paired element-wise.

    Returns:
        ``(accuracy, mean_length, mean_true_len, mean_false_len)``. The last
        two are ``None`` when no response with a truthy / falsy score exists.

    Raises:
        OSError: the folder or the file cannot be read.
        ValueError: the folder does not hold exactly one results file, or a
            line is not valid JSON.
        KeyError: the records have no 'score' or 'code' entry.
    """
    files = [name for name in os.listdir(folder) if 'metrics' not in name]
    if len(files) != 1:
        raise ValueError(f'expected exactly one results file in {folder}, found {len(files)}')
    with open(os.path.join(folder, files[0]), 'r') as f:
        # Blank lines (e.g. a trailing newline) are not records.
        data = [json.loads(line) for line in f if line.strip()]
    df = pd.DataFrame(data)

    per_prompt_accuracy = df['score'].apply(np.mean)
    response_lengths = [get_len(responses) for responses in df['code']]
    mean_lengths = [int(np.mean(lengths)) for lengths in response_lengths]

    # Split every prompt's response lengths by whether the response scored truthy.
    true_groups, false_groups = [], []
    for lengths, scores in zip(response_lengths, df['score']):
        true_groups.append([l for l, c in zip(lengths, scores) if c])
        false_groups.append([l for l, c in zip(lengths, scores) if not c])

    accuracy = float(np.mean(per_prompt_accuracy))
    mean_length = int(np.mean(mean_lengths))
    return accuracy, mean_length, _mean_group_length(true_groups), _mean_group_length(false_groups)


def prepare_results(checkpoint_folder, data_name, max_step=148):
    """Collect per-checkpoint evaluation metrics for one benchmark.

    Args:
        checkpoint_folder: directory containing the ``*global_step<N>`` checkpoint folders.
        data_name: benchmark name; results are read from ``<checkpoint>/eval/<data_name>``.
        max_step: checkpoints with a larger step number are ignored
            (``None`` disables the limit).

    Returns:
        ``(accuracy_list, mean_length_list, mean_true_len_list, mean_false_len_list)``,
        four lists of equal length with one entry per checkpoint, ordered by
        step. A value that cannot be computed for a checkpoint repeats the
        previous checkpoint's value, or is NaN when there is no previous one.
    """
    # Keep only entries that look like checkpoints and end in a step number.
    steps_by_tag = {}
    for tag in os.listdir(checkpoint_folder):
        match = re.search(r'\d+$', tag)
        if 'global_step' not in tag or match is None:
            continue
        step = int(match.group())
        if max_step is None or step <= max_step:
            steps_by_tag[tag] = step
    tags = sorted(steps_by_tag, key=steps_by_tag.get)  # making sort incremental
    eval_folders = [os.path.join(checkpoint_folder, tag, 'eval', data_name) for tag in tags]

    accuracy_list = []
    mean_length_list = []
    mean_true_len_list = []
    mean_false_len_list = []
    all_series = (accuracy_list, mean_length_list, mean_true_len_list, mean_false_len_list)
    for folder in eval_folders:
        print(folder)
        try:
            metrics = _summarize_eval_folder(folder)
        except Exception as err:  # best effort: one broken checkpoint must not abort the sweep
            print('>>> missing data')
            print(f'    reason: {type(err).__name__}: {err}')
            metrics = (None,) * len(all_series)
        # All four series grow together, only after the computation finished,
        # so a failure can never leave them with different lengths.
        for series, value in zip(all_series, metrics):
            if value is None:
                value = series[-1] if series else float('nan')
            series.append(value)
    return accuracy_list, mean_length_list, mean_true_len_list, mean_false_len_list


# Checkpoint directory of the run under study and the benchmarks evaluated on it.
checkpoint_folder = '../train/checkpoints/R1_on_math/_actor/'
benchmarks = ['math500', 'aime24', 'amc23', 'mmlu_stem']

# benchmark name -> the four per-checkpoint series returned by prepare_results.
results = {
    benchmark: prepare_results(checkpoint_folder, benchmark)
    for benchmark in benchmarks
}


# Per-benchmark line widths / line styles and the color palette of the figure.
lw_list = [1, 1, 1, 1]
ls_list = ['-', ':', '--', '-.']
palette = ['#0C5DA5', '#FF9500', '#00B945', '#FF2C00', '#845B97', '#474747', '#9e9e9e']

# Two stacked panels sharing the x axis: accuracy on top, response length below.
fig, axes = plt.subplots(2, 1, figsize=(3, 3), sharex=True, gridspec_kw={'hspace': 0.1})  # Small gap added
ax1, ax2 = axes

benchmarks = ['math500', 'aime24', 'amc23', 'mmlu_stem']

for benchmark_idx, benchmark in enumerate(benchmarks):
    if benchmark != 'mmlu_stem':  # mmlu_stem is not drawn in this figure
        accuracy_fraction, mean_length_list = results[benchmark][:2]
        accuracy_list = 100 * np.array(accuracy_fraction)  # fraction -> percent
        steps = 4 * np.arange(len(accuracy_list))  # consecutive entries are placed 4 steps apart
        # One color for every curve; the line style tells the benchmarks apart.
        line_style = dict(ls=ls_list[benchmark_idx], color=palette[0], alpha=0.8)
        ax1.plot(steps, accuracy_list, **line_style)
        ax2.plot(steps, mean_length_list, **line_style)

# Axis labels, ticks and limits (the x range comes from the last curve drawn).
ax1.set_xlim(steps[0], steps[-1])
ax1.set_ylabel("Accuracy (%)")
ax2.set_ylabel("Response Length (tokens)")
ax2.set_xlabel("Steps")
ax1.set_xticks(list(range(0, int(steps[-1]), 16)))
ax1.set_xlim(-1, steps[-1]+1)
ax1.set_yticks(np.arange(10, 101, 10))
ax1.set_ylim(10, 100)

# Hide the top and right frame lines of both panels.
for ax in (ax1, ax2):
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)

def format_ticks(value, _):
    """Render a tick value in thousands with a ``K`` suffix (5000 -> '5K').

    The ticks of this axis are chosen automatically, so they need not be
    multiples of 1000. Fractions are therefore kept (2500 -> '2.5K') instead
    of being truncated to a misleading '2K'. The second argument (the tick
    position passed by the formatter) is unused.
    """
    return f'{value / 1000:g}K'
# Label the response-length axis with format_ticks.
ax2.yaxis.set_major_formatter(FuncFormatter(format_ticks))

# Currently unused: only the commented-out legend call below refers to it.
model_legend = [mlines.Line2D([], [], color=palette[i], lw=1.5, label=exp.split('_8')[0].replace('_', ' ')+'B') for i, exp in enumerate(results.keys())]
# One legend entry per benchmark that is actually drawn. 'mmlu_stem' is skipped
# by the plotting loop, so listing it would advertise a line style that does
# not appear in the figure. Indices still refer to the full benchmark list so
# that each entry keeps the line style used for its curve.
benchmark_legend = [
    mlines.Line2D([], [], color='black', lw=1.5, ls=ls_list[i], label=benchmark_name)
    for i, benchmark_name in enumerate(benchmarks)
    if benchmark_name != 'mmlu_stem'
]
# fig.legend(handles=model_legend, loc="upper left", bbox_to_anchor=(0.2, 1.07), frameon=False, title="Models")
fig.legend(handles=benchmark_legend, loc="upper right", bbox_to_anchor=(0.8, 1.07), frameon=False, title="Benchmarks")

plt.show()