In [2]:
#Code origin
#Author: Alexander Valentini
# Note: the chat template is NOT applied here; this notebook only does the preliminary data processing.

import os
import re
from typing import Literal, TypedDict

import datasets
import jsonlines

In [3]:
# Regexes for the raw `options` column; each captures the five choice texts (a-e) as groups 1-5.
# Primary format, e.g. "a ) rs . 400 , b ) rs . 300 , c ) rs . 500 , d ) rs . 350 , e ) none of these".
CHOICES_PATTERN = re.compile(
    r"a \) (.+) , "
    r"b \) (.+) , "
    r"c \) (.+) , "
    r"d \) (.+) , "
    r"e \) (.+)"
)
# Fallback format: a Python-list-style string, e.g. "['a ) 1', 'b ) 2', 'c ) 3', 'd ) 4', 'e ) 5']".
# Choice texts in this format may not contain a single quote.
CHOICES_PATTERN_2 = re.compile(
    r"\["
    r"'a \) ([^']+)', "
    r"'b \) ([^']+)', "
    r"'c \) ([^']+)', "
    r"'d \) ([^']+)', "
    r"'e \) ([^']+)'"
    r"\]"
)

In [4]:
# Load all splits of MathQA from the Hugging Face Hub and show how they look before processing.
# allenai/math_qa is loaded through a dataset loading script ("Downloading builder script").
# `datasets` warns that `trust_remote_code=True` will be mandatory for such datasets in its
# next major release, so it is passed explicitly here.
# NOTE: this executes the loading script hosted in the Hub repository.
dataset = datasets.load_dataset("allenai/math_qa", trust_remote_code=True)
dataset

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


Downloading builder script:   0%|          | 0.00/3.25k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/7.44k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/7.30M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/29837 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/2985 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/4475 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['Problem', 'Rationale', 'options', 'correct', 'annotated_formula', 'linear_formula', 'category'],
        num_rows: 29837
    })
    test: Dataset({
        features: ['Problem', 'Rationale', 'options', 'correct', 'annotated_formula', 'linear_formula', 'category'],
        num_rows: 2985
    })
    validation: Dataset({
        features: ['Problem', 'Rationale', 'options', 'correct', 'annotated_formula', 'linear_formula', 'category'],
        num_rows: 4475
    })
})

In [5]:
# One processed MCQA record:
#   subject  - the MathQA category of the problem
#   question - the fully rendered "Question: ... Options: ... Answer:" prompt
#   answer   - the correct choice as an upper-case letter
ResultDict = TypedDict(
    "ResultDict",
    {
        "subject": str,
        "question": str,
        "answer": Literal["A", "B", "C", "D", "E"],
    },
)

In [6]:
# The Hugging Face data lists choices as "a ) option , b ) option , ..."; we re-render them in
# the prompt layout below, with one "A. option" line per choice.
def craft_mcqa_question(question: str, options: tuple[str, str, str, str, str], /) -> str:
    """Render a question and its five answer options as an MCQA prompt.

    The result has the layout::

        Question: <question>

        Options:
        A. <options[0]>
        ...
        E. <options[4]>

        Answer:

    Only ``options[0]`` through ``options[4]`` are rendered. Every entry of ``options`` must
    already be free of leading/trailing whitespace; this is checked with an ``assert``.
    """
    assert all(option.strip() == option for option in options)
    # Pair each label with the option at the same position (raises IndexError if fewer than five).
    option_lines = [f"{label}. {options[position]}" for position, label in enumerate("ABCDE")]
    rendered_options = "\n".join(option_lines)
    return f"Question: {question}\n\nOptions:\n{rendered_options}\n\nAnswer:"

In [7]:
# MathQA stores the correct answer as a lower-case letter; the MCQA output uses upper-case letters.
def upper_case(letter: Literal["a", "b", "c", "d", "e"], /) -> Literal["A", "B", "C", "D", "E"]:
    """Return the upper-case form of a MathQA answer letter.

    Args:
        letter: One of ``"a"``, ``"b"``, ``"c"``, ``"d"``, ``"e"``.

    Returns:
        The corresponding letter ``"A"`` .. ``"E"``.

    Raises:
        ValueError: If ``letter`` is anything else; the message includes the offending value.
    """
    # One explicit branch per letter keeps each returned value a precise Literal for type checkers.
    if letter == "a":
        return "A"
    if letter == "b":
        return "B"
    if letter == "c":
        return "C"
    if letter == "d":
        return "D"
    if letter == "e":
        return "E"
    raise ValueError(f"Expected an answer letter in 'a'-'e', got {letter!r}")

In [8]:
def parse_dataset(data: datasets.arrow_dataset.Dataset, /) -> list[ResultDict]:
    """Convert one split of the MathQA dataset into MCQA records.

    For every row:

    * the ``options`` string is parsed with ``CHOICES_PATTERN``, falling back to ``CHOICES_PATTERN_2``;
    * the choices are re-rendered together with ``Problem`` by ``craft_mcqa_question``;
    * the result is paired with the upper-cased ``correct`` letter and the row's ``category``.

    The other columns (e.g. ``Rationale``) are dropped.

    Args:
        data: A dataset split whose rows are dicts with string fields ``Problem``, ``options``,
            ``correct`` and ``category``.

    Returns:
        One ``ResultDict`` (``question``, ``answer``, ``subject``) per input row, in input order.

    Raises:
        AssertionError: If a row is not a dict, one of the four fields is not a string,
            ``correct`` is not one of ``"a"``..``"e"``, or ``options`` matches neither choices
            pattern. The message names the offending row index so bad rows can be located.
        KeyError: If a row lacks one of the four expected columns.
    """
    results: list[ResultDict] = []
    for row_index, row in enumerate(data):
        assert isinstance(row, dict), f"row {row_index}: expected a dict, got {type(row).__name__}"

        # Each row must provide these columns (a missing one raises KeyError).
        problem: str = row["Problem"]
        options: str = row["options"]
        correct: str = row["correct"]
        category: str = row["category"]

        # Verify types: the dataset should contain strings for all four fields.
        for field_name, value in (
            ("Problem", problem),
            ("options", options),
            ("correct", correct),
            ("category", category),
        ):
            assert isinstance(value, str), (
                f"row {row_index}: field {field_name!r} has type {type(value).__name__}, expected str"
            )
        assert correct in ("a", "b", "c", "d", "e"), f"row {row_index}: unexpected answer letter {correct!r}"

        # Require the raw options string to be in one of the two known Hugging Face formats
        # before re-rendering it, so a bad download/load fails loudly here instead of
        # silently producing malformed prompts.
        options_match = CHOICES_PATTERN.fullmatch(options) or CHOICES_PATTERN_2.fullmatch(options)
        assert options_match is not None, f"row {row_index}: unrecognized options format {options!r}"

        # Both patterns capture exactly five choice texts (a-e).
        groups: tuple[str, ...] = options_match.groups()
        assert len(groups) == 5, f"row {row_index}: expected 5 options, got {len(groups)}"
        assert all(isinstance(x, str) for x in groups), f"row {row_index}: an option group did not match"

        # Keep only the prompt, the answer letter and the category; the rationale is dropped
        # to keep the training process simple.
        results.append({
            "question": craft_mcqa_question(problem, groups),
            "answer": upper_case(correct),
            "subject": category,
        })

    return results

In [9]:
# Convert each split into a list of ResultDict records (one per dataset row, in row order).
train_data = parse_dataset(dataset["train"])
test_data = parse_dataset(dataset["test"])
validation_data = parse_dataset(dataset["validation"])

In [10]:
# Inspect the first converted training record to check the new format.
train_data[0]

{'question': "Question: the banker ' s gain of a certain sum due 3 years hence at 10 % per annum is rs . 36 . what is the present worth ?\n\nOptions:\nA. rs . 400\nB. rs . 300\nC. rs . 500\nD. rs . 350\nE. none of these\n\nAnswer:",
 'answer': 'A',
 'subject': 'gain'}

In [10]:
# Write the training split as JSON Lines (one record per line) under datasets/mcqa/.
# Create the output directory first: opening a file for writing fails if its parent
# directory does not exist, which is the case on a fresh checkout.
os.makedirs("datasets/mcqa", exist_ok=True)
with jsonlines.open("datasets/mcqa/mcqa_math_train_dataset.jsonl", "w") as f:
    f.write_all(train_data)

In [11]:
# Write the test and validation splits as JSON Lines (one record per line) under datasets/mcqa/.
# Ensure the output directory exists; opening a file for writing fails otherwise.
os.makedirs("datasets/mcqa", exist_ok=True)
for split_name, split_records in (("test", test_data), ("validation", validation_data)):
    with jsonlines.open(f"datasets/mcqa/mcqa_math_{split_name}_dataset.jsonl", "w") as f:
        f.write_all(split_records)