In [6]:
import sys
import pathlib
import os

import datasets
import rich
import transformers

# Put the directory one level above the notebook's working directory on the
# import path (presumably where the local `lib_*` modules imported right
# after this live -- verify against the repo layout).
WORK_DIR = pathlib.Path.cwd()  # cwd() already yields an absolute path
PARENT = WORK_DIR.parent
print(PARENT)
sys.path.append(os.fspath(PARENT))

import lib_metric
import lib_data


/home/mila/g/gagnonju/Marg-Li-CoT/with_trl


In [2]:
# Tokenizer for the "EleutherAI/gpt-neo-125M" checkpoint; it is passed as
# `any_tokenizer` to the GSM8K dataset wrappers built further down.
t = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")

In [3]:
# One lib_data.GSM8K wrapper per dataset split, keyed by the split name.
gsm8k = dict()

for split in ("train", "test"):
    print(f"split: {split}")
    gsm8k[split] = lib_data.GSM8K(
        # Raw Hugging Face split ("main" configuration of gsm8k).
        ds=datasets.load_dataset(path="gsm8k", name="main", split=split),  # type: ignore
        any_tokenizer=t,
        device="cpu",
        # NOTE(review): None presumably disables the length limits -- confirm
        # in lib_data.GSM8K.
        tok_max_query_length=None,
        tok_max_answer_length=None,
        tok_max_total_length=None,
        # Empty prefix/suffix strings are passed for the question text.
        question_prefix="",
        question_suffix="",
    )

split: train


Found cached dataset gsm8k (/home/mila/g/gagnonju/.cache/huggingface/datasets/gsm8k/main/1.1.0/37bfb08b1d4fcbb01f06b03d9e1ef5f1fcbd4d3af3d08842c50d7305091285ba)


split: test


Found cached dataset gsm8k (/home/mila/g/gagnonju/.cache/huggingface/datasets/gsm8k/main/1.1.0/37bfb08b1d4fcbb01f06b03d9e1ef5f1fcbd4d3af3d08842c50d7305091285ba)


In [47]:
import tqdm
import rich.table
import rich.markup
import math

import re
import more_itertools

def to_num(x):
    """Convert ``x`` to a ``float``.

    Args:
        x: Anything accepted by the built-in ``float`` constructor
            (e.g. a numeric string such as ``"3.5"``).

    Returns:
        The value as a ``float``.

    Raises:
        ValueError: If ``x`` is a string that does not parse as a number.
    """
    as_float = float(x)
    return as_float

def pick_one_or_third(x):
    """Return the meaningful element of a 3-item sequence of strings.

    Exactly three fill patterns are accepted (``""`` means "empty"):

    * only the first slot is filled        -> the first slot is returned
    * the first and second slots are filled -> the first slot is returned
    * only the third slot is filled        -> the third slot is returned

    NOTE(review): this shape looks like a regex match tuple of the form
    (number, fractional part, operator) -- confirm against the caller.

    Args:
        x: Indexable sequence of length 3.

    Returns:
        ``x[0]`` when it is non-empty, otherwise ``x[2]``.

    Raises:
        AssertionError: If ``x`` does not have three items or its fill
            pattern is not one of the three listed above.
    """
    assert len(x) == 3, len(x)
    # Compare the "which slots are filled" signature against the legal ones.
    assert tuple(slot != "" for slot in x) in (
        (True, False, False),
        (True, True, False),
        (False, False, True),
    ), x
    if x[0] == "":
        assert x[2] != "", x
        return x[2]
    return x[0]

def eval_eqn(eqn):
    """Evaluate a flat arithmetic expression strictly from left to right.

    The text is scanned for numeric literals (``12``, ``3.5``, ``.5``) and the
    four operators ``* / + -``; every other character (whitespace,
    parentheses, letters, ...) is skipped. A single leading operator is
    discarded, after which the tokens are folded left to right as
    ``number (operator number)*``.

    NOTE: there is no operator precedence and parentheses are ignored, so
    ``"2+3*4"`` evaluates to ``20.0``. Because a leading operator is dropped,
    a leading minus sign is ignored as well. A trailing operator with no
    right-hand operand is ignored.

    Args:
        eqn: Expression text.

    Returns:
        The accumulated result, in the numeric type produced by ``to_num``.

    Raises:
        ValueError: If ``eqn`` contains no numeric token to start from, if a
            token in an operator position is not one of ``* / + -``, or if a
            token in an operand position is not numeric.
        ZeroDivisionError: On division by zero.
    """
    # The pattern deliberately has no capture groups, so ``findall`` returns
    # the matched token strings directly. (With capture groups it returns one
    # tuple per match, sized by the number of groups, which previously did
    # not agree with the 3-tuple post-processing and failed on every input.)
    tokenized = re.findall(r"\d+(?:\.\d+)?|\.\d+|[*\-+/]", eqn)

    if tokenized and tokenized[0] in ["*", "/", "+", "-"]:
        tokenized = tokenized[1:]

    if not tokenized:
        raise ValueError(f"no numeric token found in {eqn!r}")

    num_so_far = to_num(tokenized[0])

    # Tokens alternate operand/operator: odd positions hold the operator,
    # even positions hold the operand it applies to.
    for idx in range(2, len(tokenized), 2):
        operator = tokenized[idx - 1]
        if operator == "*":
            num_so_far *= to_num(tokenized[idx])
        elif operator == "/":
            num_so_far /= to_num(tokenized[idx])
        elif operator == "+":
            num_so_far += to_num(tokenized[idx])
        elif operator == "-":
            num_so_far -= to_num(tokenized[idx])
        else:
            raise ValueError(f"{operator}, {tokenized}, {eqn}")
    return num_so_far


# Result table: one row per reference equation whose recomputed value does
# not match the annotated answer. Column headers are given positionally.
table = rich.table.Table(
    "Split",
    "left",
    "answer",
    "computed_answer",
    "computed_answer_pre_cast",
    "casted_ref_answer",
    title="GSM8K",
)


# Recompute the left-hand side of every reference equation and compare it
# with the annotated answer. Both sides are rounded to the nearest integer
# first, so differences smaller than that rounding are not detected.
for split in ["train", "test"]:
    total = 0
    failed = 0
    print(split)
    for entry in tqdm.tqdm(gsm8k[split], desc=split):
        for obj in entry.obj_ref_equations:
            # SECURITY: eval() executes arbitrary Python. It is applied here
            # to text taken from the dataset; only run this on data you trust.
            # NOTE: eval_eqn is not used as a cross-check because it
            # evaluates left to right and skips parentheses, so it can
            # legitimately disagree with eval().
            computed_answer_pre_cast = eval(obj["left"])
            computed_answer = round(computed_answer_pre_cast)

            try:
                casted_ref_answer = round(float(obj["answer"]))
            except ValueError:
                # The annotated answer is not a plain number: show the
                # entry's scratchpad and count the equation as a mismatch
                # (None never equals the computed integer).
                rich.print(f"[red]{rich.markup.escape(entry.detok_ref_scratchpad)}")
                casted_ref_answer = None

            if computed_answer != casted_ref_answer:
                table.add_row(
                    split,
                    rich.markup.escape(obj["left"]),
                    rich.markup.escape(obj["answer"]),
                    str(computed_answer),
                    str(computed_answer_pre_cast),
                    str(casted_ref_answer),
                )
                failed += 1
            total += 1

    # Report the denominator as well; `total` was previously counted but
    # never shown.
    rich.print(f"Split: {split} - failed: {failed} / {total}")
rich.print(table)

train


train:   0%|          | 0/7473 [00:00<?, ?it/s]

train: 100%|██████████| 7473/7473 [00:00<00:00, 34263.23it/s]


test


test:   0%|          | 0/1319 [00:00<?, ?it/s]

test: 100%|██████████| 1319/1319 [00:00<00:00, 28886.96it/s]
