In [1]:
import itertools
import re
import string

import datasets
import torch
import numpy as np
import more_itertools as mi
import rich
import rich.table
import rich.markup
import transformers

In [59]:
# Zero-based position of every uppercase ASCII letter ("A" -> 0, ..., "Z" -> 25).
LETTER_MAP = dict(zip(string.ascii_uppercase, range(len(string.ascii_uppercase))))

# Regex sources, all written in lowercase and anchored at both ends.
PATS_STR = [
    # One word character, optionally followed by a period.
    r"^\w\.?$",
    # "answer is" followed by one word character.
    r"^answer is \w\.?$",
    # "option", a run of "=", ":", whitespace or "-", then word characters.
    r"^option[=\:\s\-]+\w+\.?$",
    # Same shape with "answer" as the keyword.
    r"^answer[=\:\s\-]+\w+\.?$",
    # Literal "answer: option" followed by word characters.
    r"^answer: option \w+\.?$",
    # Same shape with "choice" as the keyword.
    r"^choice[=\:\s\-]+\w+\.?$",
]
# Compiled counterparts of PATS_STR, in the same order.
PATS = list(map(re.compile, PATS_STR))


def one_of(str_):
    """Tell whether ``str_`` matches at least one pattern in ``PATS``.

    The text is stripped of surrounding whitespace and lower-cased before
    matching.

    Args:
        str_: The text to test.

    Returns:
        True if any compiled pattern in ``PATS`` matches, otherwise False.
    """
    normalized = str_.strip().lower()
    for pattern in PATS:
        if pattern.match(normalized):
            return True
    return False


def extract_answer(example):
    """Build the ``answer`` and ``rationale`` fields for one example.

    The correct option is looked up through ``LETTER_MAP``, its commas are
    removed, and leading option labels are peeled off. The rationale drops its
    last line and gets a closing "The answer is <answer>" sentence.

    Args:
        example: Mapping providing the three keys read here: ``"correct"``
            (the option letter), ``"options"`` (indexable collection of option
            strings) and ``"rationale"`` (newline-separated text).

    Returns:
        A dict with two keys: ``"answer"`` (the cleaned option text) and
        ``"rationale"`` (the rationale without its last line, followed by
        "The answer is <answer>").

    Raises:
        KeyError: If the stripped letter is not a key of ``LETTER_MAP``.
        IndexError: If ``"options"`` has no entry at the letter's position.
        AssertionError: If the selected option does not start with
            ``"<letter>)"``.
    """
    letter = example["correct"].strip()
    # Every comma is removed from the option text
    # (presumably to collapse thousands separators such as "1,200" -- TODO confirm).
    answer_text = example["options"][LETTER_MAP[letter]].replace(",", "")

    label = letter + ")"
    assert answer_text.startswith(label), (
        f"option {answer_text!r} does not start with {label!r}"
    )

    # Peel off leading option labels. Order matters: each form is stripped
    # repeatedly before moving on to the next, and earlier forms are not
    # revisited. The comparison is case-insensitive on the option text.
    for prefix in (label, letter + ".", "[" + letter + "]", letter + " "):
        while answer_text.upper().startswith(prefix):
            answer_text = answer_text[len(prefix):].strip()

    # Drop the last line of the rationale and flatten the rest onto a single line.
    body = " ".join(example["rationale"].split("\n")[:-1]).strip()
    conclusion = "The answer is " + answer_text

    if not body:
        # Nothing precedes the conclusion, so no separator is needed.
        rationale = conclusion
    elif body.endswith((".", "!", "?")):
        # The body already ends its sentence; avoid producing "..".
        rationale = body + " " + conclusion
    else:
        rationale = body + ". " + conclusion

    return {
        "answer": answer_text,
        "rationale": rationale,
    }

def only_one_int(sample):
    """Tell whether ``sample["answer"]`` contains exactly one run of digits.

    A run is a maximal sequence of consecutive digits, so "3.5" counts as two
    runs and is rejected.

    Args:
        sample: Mapping with an ``"answer"`` string.

    Returns:
        True if exactly one digit run is found, otherwise False. The result is
        always a ``bool`` (never an empty list), so it is safe to use wherever
        a strict boolean mask is expected.
    """
    return len(re.findall(r"\d+", sample["answer"])) == 1

# Load the "aqua_rat" dataset.
rat = datasets.load_dataset("aqua_rat")

# Pipeline over the "train" split:
#   1. keep examples whose last rationale line is accepted by one_of,
#   2. add the "answer" / rewritten "rationale" fields via extract_answer,
#   3. keep examples accepted by only_one_int,
#   4. drop the "options" and "correct" columns.
output = (
    rat["train"]
    .filter(lambda example: one_of(example["rationale"].split("\n")[-1]))
    .map(extract_answer, batched=False)
    .filter(only_one_int, batched=False)
    .remove_columns(["options", "correct"])
)

print(output)


No config specified, defaulting to: aqua_rat/raw
Found cached dataset aqua_rat (/home/mila/g/gagnonju/.cache/huggingface/datasets/aqua_rat/raw/0.0.0/fc47b9f437236ab96fc1fcb61096aa193819aedd76437893e2390ab0740a3381)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/mila/g/gagnonju/.cache/huggingface/datasets/aqua_rat/raw/0.0.0/fc47b9f437236ab96fc1fcb61096aa193819aedd76437893e2390ab0740a3381/cache-d02f51627fc709f8.arrow


Map:   0%|          | 0/66390 [00:00<?, ? examples/s]

Filter:   0%|          | 0/66390 [00:00<?, ? examples/s]

Dataset({
    features: ['question', 'rationale', 'answer'],
    num_rows: 46243
})


In [60]:
# Display the "rationale" column of the processed dataset.
# NOTE(review): this appears to render the whole column (the dataset summary
# printed above reports 46243 rows) -- consider slicing, e.g. the first few
# entries, to keep the notebook output small.
output["rationale"]

['Speed of the boat downstream = 25 +11 = 36 kmph = 36 * 5/18 = 10 m/s Hence time taken to cover 80 m = 80/10 = 8 seconds. The answer is 8 seconds',
 'Smallest number of five digits is 10000. Required number must be divisible by L.C.M. of 22,33,66,44 i.e 132, On dividing 10000 by 132,we get 32 as remainder. Therefore, Required number = 10000 +( 132 â€“ 32 ) = 10100. The answer is 10100',
 'Solution S.I. = Rs.(956-825 )=Rs.131 Rate = (100x131/825x3) = 524/99% New rate = (524/99 +4)% = 920/99% New S.I. = Rs.(825 x 920/99 x 3/100) Rs. 230. ∴ New amount = Rs.(825+230)= Rs. 1055. The answer is Rs. 1055',
 'It is essential to recognize that the remainder when an integer is divided by 10 is simply the units digit of that integer. To help see this, consider the following examples: 4/10 is 0 with a remainder of 4 14/10 is 1 with a remainder of 4 5/10 is 0 with a remainder of 5 105/10 is 10 with a remainder of 5 It is also essential to remember that the q is a positive integer and multiple of 2.