In [1]:
!pip install --upgrade datasets jsonlines



In [2]:
from datasets import Dataset, DatasetDict, load_dataset
from typing import Literal, TypedDict
import jsonlines
import os

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
MMLU_DATASET_TASK = Literal["abstract_algebra", "anatomy", "astronomy", "business_ethics", "clinical_knowledge", "college_biology", "college_chemistry", "college_computer_science", "college_mathematics", "college_medicine", "college_physics", "computer_security", "conceptual_physics", "econometrics", "electrical_engineering", "elementary_mathematics", "formal_logic", "global_facts", "high_school_biology", "high_school_chemistry", "high_school_computer_science", "high_school_european_history", "high_school_geography", "high_school_government_and_politics", "high_school_macroeconomics", "high_school_mathematics", "high_school_microeconomics", "high_school_physics", "high_school_psychology", "high_school_statistics", "high_school_us_history", "high_school_world_history", "human_aging", "human_sexuality", "international_law", "jurisprudence", "logical_fallacies", "machine_learning", "management", "marketing", "medical_genetics", "miscellaneous", "moral_disputes", "moral_scenarios", "nutrition", "philosophy", "prehistory", "professional_accounting", "professional_law", "professional_medicine", "professional_psychology", "public_relations", "security_studies", "sociology", "us_foreign_policy", "virology", "world_religions"]
ALL_MMLU_DATASET_TASKS = ("abstract_algebra", "anatomy", "astronomy", "business_ethics", "clinical_knowledge", "college_biology", "college_chemistry", "college_computer_science", "college_mathematics", "college_medicine", "college_physics", "computer_security", "conceptual_physics", "econometrics", "electrical_engineering", "elementary_mathematics", "formal_logic", "global_facts", "high_school_biology", "high_school_chemistry", "high_school_computer_science", "high_school_european_history", "high_school_geography", "high_school_government_and_politics", "high_school_macroeconomics", "high_school_mathematics", "high_school_microeconomics", "high_school_physics", "high_school_psychology", "high_school_statistics", "high_school_us_history", "high_school_world_history", "human_aging", "human_sexuality", "international_law", "jurisprudence", "logical_fallacies", "machine_learning", "management", "marketing", "medical_genetics", "miscellaneous", "moral_disputes", "moral_scenarios", "nutrition", "philosophy", "prehistory", "professional_accounting", "professional_law", "professional_medicine", "professional_psychology", "public_relations", "security_studies", "sociology", "us_foreign_policy", "virology", "world_religions")

In [1]:
ALL_CS_MATH_TASKS = (
    "abstract_algebra",
    "college_computer_science",
    "college_mathematics",
    "computer_security",
    "elementary_mathematics",
    "formal_logic",
    "high_school_computer_science",
    "high_school_mathematics",
    "high_school_statistics",
    "machine_learning",
)

In [2]:
len(ALL_CS_MATH_TASKS)

10

In [5]:
def load_mmlu_dataset(tasks: list[MMLU_DATASET_TASK] | tuple[MMLU_DATASET_TASK, ...], /) -> dict[MMLU_DATASET_TASK, DatasetDict]:
    full_dataset = {}
    for task in tasks:
        full_dataset[task] = load_dataset("cais/mmlu", task)
    return full_dataset

In [6]:
class ProcessDatasetRow(TypedDict):
    subject: MMLU_DATASET_TASK
    question: str
    answer: Literal["A", "B", "C", "D"]

In [7]:
def process_mmlu_dataset(full_dataset: dict[MMLU_DATASET_TASK, DatasetDict], /) -> list[ProcessDatasetRow]:
    processed_dataset: list[ProcessDatasetRow] = []
    for key, value in full_dataset.items():
        assert key in ALL_MMLU_DATASET_TASKS
        assert isinstance(value, DatasetDict)
        assert frozenset(value.keys()) == {"dev", "test", "validation"}
        assert isinstance(value["dev"], Dataset) and isinstance(value["test"], Dataset) and isinstance(value["validation"], Dataset)

        def check_and_add(subject: MMLU_DATASET_TASK, dataset: Dataset, /):
            for element in dataset:
                assert isinstance(element, dict)
                assert frozenset(element.keys()) == {"answer", "choices", "question", "subject"}
                assert isinstance(element["subject"], str) and element["subject"] == subject

                question: str = element["question"]
                assert isinstance(question, str)

                choices: list[str] = element["choices"]
                assert isinstance(choices, list) and all(isinstance(x, str) for x in choices) and len(choices) >= 2 and len(choices) <= 4

                answer: Literal[0, 1, 2, 3] = element["answer"]
                assert isinstance(answer, int) and answer >= 0 and answer <= 3 and answer < len(choices)

                options_str = ""
                for i, choice in enumerate(choices):
                    options_str += ("\n" + ("A", "B", "C", "D")[i] + ". " + choice)

                processed_dataset.append(
                    {
                        "subject": subject,
                        "question": f"""Question: {element["question"]}\n\nOptions:{options_str}\n\nAnswer:""",
                        "answer": ("A", "B", "C", "D")[answer]
                    }
                )

        check_and_add(key, value["dev"])
        check_and_add(key, value["test"])
        check_and_add(key, value["validation"])
    return processed_dataset

In [8]:
full_dataset = load_mmlu_dataset(ALL_CS_MATH_TASKS)
full_dataset

{'abstract_algebra': DatasetDict({
     test: Dataset({
         features: ['question', 'subject', 'choices', 'answer'],
         num_rows: 100
     })
     validation: Dataset({
         features: ['question', 'subject', 'choices', 'answer'],
         num_rows: 11
     })
     dev: Dataset({
         features: ['question', 'subject', 'choices', 'answer'],
         num_rows: 5
     })
 }),
 'college_computer_science': DatasetDict({
     test: Dataset({
         features: ['question', 'subject', 'choices', 'answer'],
         num_rows: 100
     })
     validation: Dataset({
         features: ['question', 'subject', 'choices', 'answer'],
         num_rows: 11
     })
     dev: Dataset({
         features: ['question', 'subject', 'choices', 'answer'],
         num_rows: 5
     })
 }),
 'college_mathematics': DatasetDict({
     test: Dataset({
         features: ['question', 'subject', 'choices', 'answer'],
         num_rows: 100
     })
     validation: Dataset({
         features: ['que

In [9]:
processed_dataset = process_mmlu_dataset(full_dataset)
processed_dataset

[{'subject': 'abstract_algebra',
  'question': 'Question: Find all c in Z_3 such that Z_3[x]/(x^2 + c) is a field.\n\nOptions:\nA. 0\nB. 1\nC. 2\nD. 3\n\nAnswer:',
  'answer': 'B'},
 {'subject': 'abstract_algebra',
  'question': 'Question: Statement 1 | If aH is an element of a factor group, then |aH| divides |a|. Statement 2 | If H and K are subgroups of G then HK is a subgroup of G.\n\nOptions:\nA. True, True\nB. False, False\nC. True, False\nD. False, True\n\nAnswer:',
  'answer': 'B'},
 {'subject': 'abstract_algebra',
  'question': 'Question: Statement 1 | Every element of a group generates a cyclic subgroup of the group. Statement 2 | The symmetric group S_10 has 10 elements.\n\nOptions:\nA. True, True\nB. False, False\nC. True, False\nD. False, True\n\nAnswer:',
  'answer': 'C'},
 {'subject': 'abstract_algebra',
  'question': 'Question: Statement 1| Every function from a finite set onto itself must be one to one. Statement 2 | Every subgroup of an abelian group is abelian.\n\nOpt

In [10]:
len(processed_dataset)

1823

In [12]:
with jsonlines.open(os.path.join("datasets", "mcqa_dataset.jsonl"), mode="w") as f:
    f.write_all(processed_dataset)