In [None]:
!wget https://raw.githubusercontent.com/hendrycks/math/refs/heads/main/modeling/math_equivalence.py

In [None]:
import math_equivalence
import re
import more_itertools as mit
import datasets
import transformers
import rich.table
from tqdm import tqdm

gemma_2_2b_it_pipeline = transformers.pipeline(task="text-generation", model="google/gemma-2-2b-it", device=0, torch_dtype="bfloat16")
math = datasets.load_dataset("hendrycks/competition_math")


In [53]:

N = None
WITH_SECOND = True
WITH_THIRD = True

def extract(answer, verbose=False):
    answer = re.sub(r"\$\$", "$", answer)
    answer = re.sub(re.escape(f"[asy]") + r".*?" + re.escape(f"[/asy]"), "", answer, flags=re.DOTALL)
    first_pat = re.compile(r"\$.+?\$", re.DOTALL)
    second_pat = re.compile(re.escape(r"\begin{align*}") + r".*?" + re.escape(r"\end{align*}"), re.DOTALL)
    third_path = re.compile(re.escape(r"\[") + r".*?" + re.escape(r"\]"), re.DOTALL)


    finds = {}
    for i, pat in enumerate([first_pat, second_pat, third_path]):
        extracted = pat.findall(answer)
        if extracted:
            finds[i] = dict(answer=extracted[-1], end_pos=answer.rfind(extracted[-1]) + len(extracted[-1]))
            
    if verbose:
        print(finds)

    max_find = max(
        finds.items(), 
        key=lambda x: x[1]["end_pos"],
    )
    if verbose:
        print(f"{max_find = }")

    return max_find[1]["answer"]

def formatted(extracted):
    start_str = "\\boxed{"
    start = extracted.find(start_str) + len(start_str)
    end = extracted.rfind("}")
    
    if start == -1 or end == -1:
        return extracted
    
    return extracted[start : end]

# count_bad = 0
# total = 0

# for sol in math["train"]["solution"]:
#     total += 1
#     answer = extract(sol)
#     if not answer:
#         count_bad += 1
    

# print(count_bad)
# print(f"{count_bad / total:0.2%}")

In [None]:
formatted(r"asdasdasd\\boxed{\\frac{1}{2}} asdas")

In [None]:
import rich
import rich.rule

count_weird = 0
for idx, sol in enumerate(math["train"]["solution"][:1000]):
    answer = extract(sol)
    pos = sol.rfind(answer)

    ends_where = len(sol) - pos - len(answer)
    if not ends_where < 70:
        count_weird += 1
        print(f"{idx = }")
        print("End of sol:", sol[-max(300, 30 + len(sol) - pos):])
        print("Answer:", answer)
        extract(sol, verbose=True)
        print()
        print(len(sol) - pos)
        rich.print(rich.rule.Rule())
print(count_weird)

In [None]:


N = 100
agreement_gemma_ours = 0
agreement_gemma_ours_formatted = 0
agreement_gemma_ours_ours_formatted = 0
agreement_gemma_ours_new_extractor = 0

def new_extractor(answer):
    start = answer.rfind("\\boxed{") + len("\\boxed{")
    if start == -1:
        return answer
    # Match brackets until we find the closing bracket
    count = 1
    for i, c in enumerate(answer[start:]):
        if c == "{":
            count += 1
        elif c == "}":
            count -= 1
            if count == 0:
                return answer[start:start + i]
    return answer[start:]

seen_so_far = 0
progress = tqdm(math["train"]["solution"])
for sample in progress:
    prompt = dict(role="user", content=r"Just give the answer. Extract the final, latex math answer in the big following solution. If it's just a number, remove the latex from it: " + sample)
    gemma = gemma_2_2b_it_pipeline([prompt], max_new_tokens=100, return_full_text=False, do_sample=False)[0]["generated_text"].strip()
    ours = extract(sample)
    formatted_ours = formatted(ours)
    new_extractor_ours = new_extractor(sample)
    is_equiv_gemma_ours = math_equivalence.is_equiv(ours, gemma)
    is_equiv_gemma_ours_formatted = math_equivalence.is_equiv(formatted_ours, gemma)
    is_equiv_ours_ours_formatted = math_equivalence.is_equiv(ours, formatted_ours)
    is_equiv_new_extractor_ours_gemma = math_equivalence.is_equiv(new_extractor_ours, gemma)
    agreement_gemma_ours += is_equiv_gemma_ours
    agreement_gemma_ours_formatted += is_equiv_gemma_ours_formatted
    agreement_gemma_ours_new_extractor += is_equiv_new_extractor_ours_gemma
    seen_so_far += 1

    # if not is_equiv_new_extractor_ours_gemma:
    #     table = rich.table.Table(show_header=False, show_lines=True, )
    #     table.add_row("Prompt:", prompt["content"])
    #     table.add_row("Ours:", ours)
    #     table.add_row("Ours, fomatted:", formatted_ours)
    #     table.add_row("Gemma:", gemma)
    #     table.add_row("new_extractor_ours:", new_extractor_ours)
    #     table.add_row("is_equiv:", str(is_equiv_gemma_ours))
    #     table.add_row("is_equiv w/ formatted:", str(is_equiv_gemma_ours_formatted))
    #     table.add_row("is_equiv new ours w/ gemma", str(is_equiv_new_extractor_ours_gemma))
    #     rich.print(table)
    #     rich.print(rich.rule.Rule())
    progress.set_description(f"{agreement_gemma_ours_new_extractor / seen_so_far:0.2%}")

print(f"Agreement gemma-ours: {agreement_gemma_ours / seen_so_far:0.2%}")
print(f"Agreement gemma-ours formatted: {agreement_gemma_ours_formatted / seen_so_far:0.2%}")
print(f"Agreement gemma new-ours: {agreement_gemma_ours_new_extractor / seen_so_far:0.2%}")


  0%|          | 0/7500 [00:00<?, ?it/s]

88.49%:  10%|▉         | 747/7500 [04:59<1:05:18,  1.72it/s]

In [64]:
agreement_gemma_ours_new_extractor

85