In [22]:

def compute_score(solution_str, ground_truth) -> float:
    retval = 0.0
    try:
        string_in_last_boxed = last_boxed_only_string(solution_str)
        if string_in_last_boxed is not None:
            answer = remove_boxed(string_in_last_boxed)
            if is_equiv(answer, ground_truth):
                retval = 1.0
    except Exception as e:
        print(e)

    return retval


# string normalization from https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/tasks/hendrycks_math.py
def is_equiv(str1, str2, verbose=False):
    if str1 is None and str2 is None:
        print("WARNING: Both None")
        return True
    if str1 is None or str2 is None:
        return False

    try:
        ss1 = strip_string(str1)
        ss2 = strip_string(str2)
        if verbose:
            print(ss1, ss2)
        return ss1 == ss2
    except Exception:
        return str1 == str2


def remove_boxed(s):
    if "\\boxed " in s:
        left = "\\boxed "
        assert s[: len(left)] == left
        return s[len(left) :]

    left = "\\boxed{"

    assert s[: len(left)] == left
    assert s[-1] == "}"

    return s[len(left) : -1]


def last_boxed_only_string(string):
    idx = string.rfind("\\boxed")
    if "\\boxed " in string:
        return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0]
    if idx < 0:
        idx = string.rfind("\\fbox")
        if idx < 0:
            return None

    i = idx
    right_brace_idx = None
    num_left_braces_open = 0
    while i < len(string):
        if string[i] == "{":
            num_left_braces_open += 1
        if string[i] == "}":
            num_left_braces_open -= 1
            if num_left_braces_open == 0:
                right_brace_idx = i
                break
        i += 1

    retval = None if right_brace_idx is None else string[idx : right_brace_idx + 1]

    return retval


def fix_fracs(string):
    substrs = string.split("\\frac")
    new_str = substrs[0]
    if len(substrs) > 1:
        substrs = substrs[1:]
        for substr in substrs:
            new_str += "\\frac"
            if substr[0] == "{":
                new_str += substr
            else:
                try:
                    assert len(substr) >= 2
                except:  # noqa: E722
                    return string
                a = substr[0]
                b = substr[1]
                if b != "{":
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}{" + b + "}" + post_substr
                    else:
                        new_str += "{" + a + "}{" + b + "}"
                else:
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}" + b + post_substr
                    else:
                        new_str += "{" + a + "}" + b
    string = new_str
    return string


def fix_a_slash_b(string):
    if len(string.split("/")) != 2:
        return string
    a = string.split("/")[0]
    b = string.split("/")[1]
    try:
        a = int(a)
        b = int(b)
        assert string == "{}/{}".format(a, b)
        new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
        return new_string
    except:  # noqa: E722
        return string


def remove_right_units(string):
    # "\\text{ " only ever occurs (at least in the val set) when describing units
    if "\\text{ " in string:
        splits = string.split("\\text{ ")
        assert len(splits) == 2
        return splits[0]
    else:
        return string


def remove_text_answer(string):
    # Handle \text{answer} patterns
    import re
    # Match \text{...} patterns and extract the content inside the braces
    pattern = r'\\text\{([^}]*)\}'
    match = re.search(pattern, string)
    if match:
        # Return just the content inside the \text{...}
        return match.group(1)
    return string


def fix_sqrt(string):
    if "\\sqrt" not in string:
        return string
    splits = string.split("\\sqrt")
    new_string = splits[0]
    for split in splits[1:]:
        if split[0] != "{":
            a = split[0]
            new_substr = "\\sqrt{" + a + "}" + split[1:]
        else:
            new_substr = "\\sqrt" + split
        new_string += new_substr
    return new_string


def strip_string(string):
    # linebreaks
    string = string.replace("\n", "")

    # remove inverse spaces
    string = string.replace("\\!", "")

    # replace \\ with \
    string = string.replace("\\\\", "\\")

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")

    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")

    # remove units (on the right)
    string = remove_right_units(string)
    
    # remove \text{answer} patterns and extract content
    string = remove_text_answer(string)

    # remove percentage
    string = string.replace("\\%", "")
    string = string.replace("\%", "")  # noqa: W605

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # to consider: get rid of e.g. "k = " or "q = " at beginning
    if len(string.split("=")) == 2 and len(string.split("=")[0]) <= 2:
        string = string.split("=")[1]

    # fix sqrt3 --> sqrt{3}
    string = fix_sqrt(string)

    # remove spaces
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1).
    # Also does a/b --> \\frac{a}{b}
    string = fix_fracs(string)

    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = fix_a_slash_b(string)

    return string

In [32]:
import json
t = 0
diffs = []
with open("data/improved_results_wrongs.jsonl") as f:
    for i,line in enumerate(f):
        a = json.loads(line)
        expected_answer = problem_to_answer[a["problem"]]
        score = compute_score(a["llm_output"], expected_answer)
        t += score
        if score > 0.5:
            diffs.append({"problem": a["problem"], "llm_output": a["llm_output"], "expected_answer": expected_answer})
            print(f"Problem {i}")
            print(a["problem"])
            print("LLM OUTPUT\n\n")
            print(a["llm_output"])
            print("EXPECTED ANSWER\n\n")
            print(expected_answer)
            print("--------------------------------")

print(t/500)


Problem 1
Let $p(x)$ be a polynomial of degree 5 such that
\[p(n) = \frac{n}{n^2 - 1}\]for $n = 2,$ 3, 4, $\dots,$ 7.  Find $p(8).$
LLM OUTPUT


<｜User｜>Let $p(x)$ be a polynomial of degree 5 such that
\[p(n) = \frac{n}{n^2 - 1}\]for $n = 2,$ 3, 4, $\dots,$ 7.  Find $p(8).$<｜Assistant｜><think>
Okay, so I have this problem where there's a polynomial \( p(x) \) of degree 5. It's given that for \( n = 2, 3, 4, \dots, 7 \), \( p(n) = \frac{n}{n^2 - 1} \). I need to find \( p(8) \). Hmm, okay.

First, let me understand what's going on here. We have a polynomial of degree 5, and we know its values at 6 points (from 2 to 7). Since a polynomial of degree \( k \) is uniquely determined by \( k+1 \) points, this should define the polynomial uniquely. So, there should be only one such polynomial \( p(x) \) that satisfies these conditions. Cool, so I don't have to worry about multiple solutions or anything like that.

Now, the challenge is to find \( p(8) \). How do I approach this? Well, maybe I 

In [33]:
with open("data/diffs.jsonl", "w") as f:
    for diff in diffs:
        f.write(json.dumps(diff) + "\n")

In [27]:
problem_to_answer

{'Convert the point $(0,3)$ in rectangular coordinates to polar coordinates.  Enter your answer in the form $(r,\\theta),$ where $r > 0$ and $0 \\le \\theta < 2 \\pi.$': '\\left( 3, \\frac{\\pi}{2} \\right)',
 'Define\n\\[p = \\sum_{k = 1}^\\infty \\frac{1}{k^2} \\quad \\text{and} \\quad q = \\sum_{k = 1}^\\infty \\frac{1}{k^3}.\\]Find a way to write\n\\[\\sum_{j = 1}^\\infty \\sum_{k = 1}^\\infty \\frac{1}{(j + k)^3}\\]in terms of $p$ and $q.$': 'p - q',
 'If $f(x) = \\frac{3x-2}{x-2}$, what is the value of $f(-2) +f(-1)+f(0)$? Express your answer as a common fraction.': '\\frac{14}{3}',
 'How many positive whole-number divisors does 196 have?': '9',
 'The results of a cross-country team\'s training run are graphed below. Which student has the greatest average speed? [asy]\nfor ( int i = 1; i <= 7; ++i )\n{\n\ndraw((i,0)--(i,6));\n}\n\nfor ( int i = 1; i <= 5; ++i )\n{\n\ndraw((0,i)--(8,i));\n}\ndraw((-0.5,0)--(8,0), linewidth(1));\ndraw((0,-0.5)--(0,6), linewidth(1));\nlabel("$O$", (

In [40]:
problem_to_llm_answer = {}
for diff in diffs:
    parts = diff["llm_output"].split("\\boxed", 1)
    diff["llm_output"] = parts[0] + "\\boxed" if len(parts) > 1 else diff["llm_output"]
    problem_to_llm_answer[diff["problem"]] = diff["llm_output"]


In [41]:
with open("data/improved_results_math_500_train.jsonl") as f:
    for line in f:
        a = json.loads(line)
        problem = a["problem"]
        if(problem in problem_to_llm_answer.keys()):
            print("PROBLEM")
            print(problem)
            print("BETTER ANSWER")
            print(problem_to_llm_answer[problem])
            print("WORSE ANSWER")
            print(a["llm_output"])
            print("--------------------------------")

PROBLEM
Let $p(x)$ be a polynomial of degree 5 such that
\[p(n) = \frac{n}{n^2 - 1}\]for $n = 2,$ 3, 4, $\dots,$ 7.  Find $p(8).$
BETTER ANSWER
<｜User｜>Let $p(x)$ be a polynomial of degree 5 such that
\[p(n) = \frac{n}{n^2 - 1}\]for $n = 2,$ 3, 4, $\dots,$ 7.  Find $p(8).$<｜Assistant｜><think>
Okay, so I have this problem where there's a polynomial \( p(x) \) of degree 5. It's given that for \( n = 2, 3, 4, \dots, 7 \), \( p(n) = \frac{n}{n^2 - 1} \). I need to find \( p(8) \). Hmm, okay.

First, let me understand what's going on here. We have a polynomial of degree 5, and we know its values at 6 points (from 2 to 7). Since a polynomial of degree \( k \) is uniquely determined by \( k+1 \) points, this should define the polynomial uniquely. So, there should be only one such polynomial \( p(x) \) that satisfies these conditions. Cool, so I don't have to worry about multiple solutions or anything like that.

Now, the challenge is to find \( p(8) \). How do I approach this? Well, maybe I c