In [None]:
import os
import pickle
import re
import signal
from typing import Optional

import datasets
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

# Start from matplotlib's built-in "default" style sheet.
plt.style.use("default")
# Overrides keyed by matplotlib rcParams names: small high-DPI figures, thin lines,
# compact labels and tight padding.
# NOTE(review): no visible cell applies this dict (e.g. via plt.rcParams.update or
# plt.rc_context) — confirm it is consumed by a cell outside this excerpt.
rc = {
    # Figure size, resolution and layout engine.
    "figure.figsize": (3.2, 2.4),
    "figure.dpi": 200,
    "figure.constrained_layout.use": True,
    # Grid on, top/right spines off.
    "axes.grid": True,
    "axes.spines.right": False,
    "axes.spines.top": False,
    # Thin axes, grid lines and ticks.
    "axes.linewidth": 0.5,
    "grid.linewidth": 0.5,
    "xtick.major.width": 0.5,
    "ytick.major.width": 0.5,
    "xtick.major.size": 2.5,
    "ytick.major.size": 2.5,
    # Small fonts for labels, titles and tick labels.
    "axes.labelsize": "small",
    "axes.titlesize": "small",
    "xtick.labelsize": "small",
    "ytick.labelsize": "small",
    # Tight padding around titles, labels and ticks.
    "axes.titlepad": 2,
    "axes.labelpad": 2,
    "xtick.major.pad": 2,
    "ytick.major.pad": 2,
    # Line and patch outline widths.
    "lines.linewidth": 1,
    "patch.linewidth": 0,
}

# Load the dataset saved under ../ReProver/data/math and keep only its "test" split.
test_data = datasets.load_from_disk("../ReProver/data/math")["test"]

In [2]:
def process_results(results: str):
    """Strip end-of-sequence markers from a raw generation and normalize it.

    Every literal ``</s>`` is removed from ``results``; the remaining text is passed
    through ``normalize_final_answer`` and that function's result is returned.
    """
    without_eos = results.replace("</s>", "")
    return normalize_final_answer(without_eos)


def remove_boxed(s: str) -> str:
    r"""Strip the ``\boxed`` wrapper from a boxed LaTeX answer.

    Two spellings are accepted:

    * ``\boxed <answer>`` (space-delimited): everything after the prefix is returned.
    * ``\boxed{<answer>}`` (brace-delimited): the text between the outer braces is returned.

    Raises:
        AssertionError: if ``s`` does not start with the expected prefix, or the
            brace-delimited form does not end with ``}``.
    """
    spaced_prefix = "\\boxed "
    if spaced_prefix in s:
        # The space-delimited form is only valid when it opens the string.
        assert s.startswith(spaced_prefix)
        return s[len(spaced_prefix) :]

    braced_prefix = "\\boxed{"
    assert s.startswith(braced_prefix)
    assert s.endswith("}")
    return s[len(braced_prefix) : -1]


class timeout:
    """Context manager that raises ``TimeoutError`` when its body runs too long.

    Built on ``signal.alarm`` / ``SIGALRM``: it is only usable on platforms that
    provide ``SIGALRM`` and only from the main thread. ``seconds`` must be an
    integer, as required by ``signal.alarm``.

    Entering arms the alarm; leaving cancels it and puts back the ``SIGALRM``
    handler that was installed before, so the process-wide signal state does not
    leak out of the ``with`` block. Nesting is not supported: leaving an inner
    block cancels any alarm armed by an outer one.

    Args:
        seconds: Whole seconds to wait before raising.
        error_message: Message carried by the raised ``TimeoutError``.
    """

    def __init__(self, seconds=1, error_message="Timeout"):
        self.seconds = seconds
        self.error_message = error_message
        # Handler that was active before __enter__; restored in __exit__.
        self._previous_handler = None

    def handle_timeout(self, signum, frame):
        """SIGALRM handler: abort the guarded block with ``TimeoutError``."""
        raise TimeoutError(self.error_message)

    def __enter__(self):
        # signal.signal returns the handler it replaces (None if that handler was
        # not installed from Python).
        self._previous_handler = signal.signal(signal.SIGALRM, self.handle_timeout)
        signal.alarm(self.seconds)
        return self

    def __exit__(self, type, value, traceback):
        # Cancel the pending alarm first so it cannot fire while handlers are swapped.
        signal.alarm(0)
        if self._previous_handler is not None:
            signal.signal(signal.SIGALRM, self._previous_handler)
        self._previous_handler = None


# Ordered (old, new) pairs that normalize_final_answer applies with str.replace,
# one after another in list order. Order matters: e.g. all spaces are removed
# before the "\\text{and}" entries are tried.
SUBSTITUTIONS = [
    ("an ", ""),
    ("a ", ""),
    (".$", "$"),
    ("\\$", ""),
    (r"\ ", ""),
    (" ", ""),
    ("mbox", "text"),
    (",\\text{and}", ","),
    ("\\text{and}", ","),
    ("\\text{m}", "\\text{}"),
]
# Substrings that normalize_final_answer deletes (replaces with "") after the
# SUBSTITUTIONS pass, again in list order. Spaces are already gone by then, which is
# why multi-word entries such as "childrentickets" are written without spaces.
REMOVED_EXPRESSIONS = [
    r"\left",
    r"\right",
    "square",
    "ways",
    "integers",
    "dollars",
    "mph",
    "inches",
    "ft",
    "hours",
    "km",
    "units",
    "\\ldots",
    "sue",
    "points",
    "feet",
    "minutes",
    "digits",
    "cents",
    "degrees",
    "cm",
    "gm",
    "pounds",
    "meters",
    "meals",
    "edges",
    "students",
    "childrentickets",
    "multiples",
    "\\text{s}",
    "\\text{.}",
    "\\text{\ns}",
    "\\text{}^2",
    "\\text{}^3",
    "\\text{\n}",
    "\\text{}",
    r"\mathrm{th}",
    r"^\circ",
    r"^{\circ}",
    r"\;",
    r",\!",
    "{,}",
    '"',
    "\\dots",
]


def normalize_final_answer(final_answer: str) -> str:
    """
    Canonicalize the final answer of a quantitative reasoning question.

    Steps, in order: unwrap ``\\text{...}``, apply ``SUBSTITUTIONS``, delete every
    ``REMOVED_EXPRESSIONS`` entry, expand shorthand ``\\frac``/``\\sqrt`` arguments,
    drop ``$``, and strip thousands separators from plain integers.

    The routine is credited by its author to appendix D of Lewkowycz et al. (2022).
    """
    text = re.sub(r"\\text{(.*?)}", r"\1", final_answer)

    # A removal is just a substitution by the empty string, so run both tables in one pass.
    replacements = [*SUBSTITUTIONS, *((expr, "") for expr in REMOVED_EXPRESSIONS)]
    for old, new in replacements:
        text = text.replace(old, new)

    # Normalize shorthand TeX:
    #  \fracab -> \frac{a}{b}
    #  \frac{abc}{bef} -> \frac{abc}{bef}
    #  \fracabc -> \frac{a}{b}c
    #  \sqrta -> \sqrt{a}
    #  \sqrtab -> sqrt{a}b
    text = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", text)
    text = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", text)
    text = text.replace("$", "")

    # Normalize 100,000 -> 100000 (only when the rest is purely digits)
    without_commas = text.replace(",", "")
    return without_commas if without_commas.isdigit() else text


def last_boxed_only_string(string: str) -> Optional[str]:
    r"""Extract the last boxed expression from ``string``, wrapper included.

    * If the space-delimited form ``\boxed `` occurs anywhere, the result is
      ``\boxed `` followed by the text after its last occurrence, cut at the next ``$``.
    * Otherwise the last ``\boxed`` (or, failing that, the last ``\fbox``) is located
      and the text up to the brace that closes its argument is returned.

    Returns ``None`` when no box is found or its braces never balance.
    """
    start = string.rfind("\\boxed")
    if "\\boxed " in string:
        tail = string.split("\\boxed ")[-1]
        return "\\boxed " + tail.split("$")[0]
    if start < 0:
        start = string.rfind("\\fbox")
    if start < 0:
        return None

    # Walk forward from the command, tracking brace depth until it returns to zero.
    depth = 0
    for pos in range(start, len(string)):
        char = string[pos]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return string[start : pos + 1]

    return None


def normalized_answer(model_answer: str) -> Optional[str]:
    """Extract and normalize the final boxed answer of a model generation.

    Args:
        model_answer: Full text produced by the model.

    Returns:
        The normalized contents of the last boxed expression, or ``None`` when the
        text has no boxed expression or the box cannot be unwrapped/normalized.
    """
    last_boxed_string = last_boxed_only_string(model_answer)
    if last_boxed_string is None:
        return None
    try:
        # remove_boxed signals malformed boxes with AssertionError.
        return normalize_final_answer(remove_boxed(last_boxed_string))
    except Exception:
        # Best effort: a malformed answer counts as "no answer". Catching Exception
        # rather than everything lets KeyboardInterrupt / SystemExit propagate.
        return None

In [3]:

# Root directory of the experiment outputs; results are read from
# <exp_path>/<exp_id>/epoch_<k>/. `open()` does not expand "~", so the home
# directory is resolved here once.
exp_path = os.path.expanduser("~/refine/ReProver/exps")

# TODO: set wandb run ids
# exp_ids = (<exp_id_1>,)

In [5]:
# Reference answer string of every test example, in dataset order.
answers = [example["answer"] for example in test_data]

In [None]:
epochs = (0, 2, 4)

N_list = [1, 16, 64]

# NOTE: set to 256 or 1_024 depending on which N value test.py was called with
N_test = 1_024

# One result pickle is read per 256 samples (file index j in the file name).
n_files = N_test // 256

# Number of test problems that are scored; also the denominator of the percentages.
n_problems = 500

# Normalize each reference answer once instead of once per comparison.
reference_answers = [normalize_final_answer(answers[i]) for i in range(n_problems)]


def _extract_boxed(sample) -> str:
    """Return the contents of the last boxed answer in ``sample[0]``, or "" on failure.

    ``sample[0]`` is presumably the generated text — TODO confirm against test.py.
    """
    try:
        return remove_boxed(last_boxed_only_string(sample[0]))
    except Exception:
        # No box, malformed box, or unexpected sample layout: treat as an empty answer.
        return ""


def _score_epoch(exp, epoch):
    """Load all samples of one checkpoint and mark each as correct or not.

    Returns a list with one entry per problem; entry ``i`` is a list of booleans,
    one per sampled generation, in the order the result files store them.
    """
    correct = [[] for _ in range(n_problems)]
    for j in range(n_files):
        path = os.path.expanduser(
            f"{exp_path}/{exp}/epoch_{epoch}/results_{N_test}_1_{j}_.pkl"
        )
        # NOTE: pickle.load can execute arbitrary code; only read result files
        # that you produced yourself.
        with open(path, "rb") as f:
            results_new = pickle.load(f)
        for i in range(n_problems):
            correct[i].extend(
                normalize_final_answer(_extract_boxed(sample)) == reference_answers[i]
                for sample in results_new[i]
            )
    return correct


def _count_solved(correct, draw, N):
    """Count problems with at least one correct answer among samples [draw*N, (draw+1)*N)."""
    return sum(any(row[draw * N : (draw + 1) * N]) for row in correct)


# `exp_ids` must be defined beforehand (see the TODO in the configuration cell).
for exp in exp_ids:

    print(f"Experiment {exp}")

    # Correctness of a sample depends only on the checkpoint, not on N, so every
    # epoch is loaded and scored exactly once and reused for all values of N.
    correct_by_epoch = {
        epoch: _score_epoch(exp, epoch) for epoch in tqdm(epochs, desc=f"scoring {exp}")
    }

    meta_meta_greedy = []
    for N in N_list:
        # Number of times N inferences were collected (e.g. 16 for N_test=1_024 and N=64).
        # We're estimating the mean and standard error of the mean under this number of draws.
        n_samples = N_test // N

        # meta_greedy[e][d]: number of problems solved at epochs[e] within the d-th group of N samples.
        meta_greedy = [
            [_count_solved(correct_by_epoch[epoch], draw, N) for draw in range(n_samples)]
            for epoch in epochs
        ]
        meta_meta_greedy.append(meta_greedy)

        counts = np.array(meta_greedy)
        assert counts.shape == (len(epochs), n_samples)
        sample_means = counts.mean(axis=-1)
        assert sample_means.shape == (len(epochs),)
        # Unbiased sample standard deviation over the draws (ddof=1), then the
        # standard error of the mean.
        std_hat = counts.std(axis=-1, ddof=1)
        std_error = std_hat / np.sqrt(n_samples)

        print(f"N = {N}")
        for epoch, mean, sem in zip(epochs, sample_means, std_error):
            print(
                f"Epoch {epoch}: {100 * mean / n_problems:.1f}% "
                f"+- {100 * sem / n_problems:.1f}%"
            )