In [None]:
import json
import random
from datasets import load_dataset
from tqdm import tqdm
import re


# Option labels for five-way and four-way multiple-choice questions.
LETTER_5 = list("ABCDE")
LETTER_4 = LETTER_5[:4]

# Matches a leading option label such as "A)" or "E) ", including any
# whitespace that follows the parenthesis.
STRIP_PREFIX = re.compile(r"^[A-E]\)\s*")

# Translates 1-based numeric answer keys ("1".."5") to their letter labels.
NUMERIC_TO_LETTER = {
    str(position): letter for position, letter in enumerate(LETTER_5, start=1)
}

def reduce_to_four_choices(choices: list[str], correct_letter: str, seed: int | None = None):
    """
    Drop ONE incorrect option at random if we have 5 choices.
    Returns the new choices list and the *updated* answer letter (A-D).
    """
    if seed is not None:
        random.seed(seed)

    if len(choices) <= 4:
        return choices, correct_letter

    # locate the correct option
    correct_idx = LETTER_5.index(correct_letter)

    # pick a wrong option to remove
    incorrect_idxs = [i for i in range(len(choices)) if i != correct_idx]
    remove_idx = random.choice(incorrect_idxs)
    choices.pop(remove_idx)

    # if we removed something that was *before* the correct answer,
    # the correct index shifts left by 1
    if remove_idx < correct_idx:
        correct_idx -= 1

    # map new index (0-3) to A-D
    new_correct_letter = LETTER_4[correct_idx]
    return choices, new_correct_letter


def process_medmcqa(split="train", seed=42):
    """
    Process the MedMCQA dataset.

    The Hugging Face splits are re-sampled rather than used as-is:
      * "train"      -> a 30% sample of the original train split,
      * "validation" -> a 15% sample of the original validation split,
      * "test"       -> the first 600 rows of the part of the original train
                        split that falls outside the 30% sample (for the
                        same ``seed``).

    Args:
        split (str): Dataset split to process, can be "train", "validation", or "test".
        seed (int): Random seed for reproducibility of the sampling.
    Returns:
        list: Processed dataset with each item containing 'dataset', 'id', 'question',
               'choices', 'rationale', and 'answer'.
    Raises:
        ValueError: If ``split`` is not one of the supported names.
    """
    # Fail early with a clear message; otherwise an unknown split would only
    # surface later as an unbound-variable error.
    if split not in ("train", "validation", "test"):
        raise ValueError(
            f"Unsupported split {split!r}; expected 'train', 'validation' or 'test'."
        )

    if split == "test":
        # The held-out rows come from the train split: the complement of the
        # 30% sample that split="train" uses.
        medmcqa = load_dataset("openlifescienceai/medmcqa", split="train")
        medmcqa = medmcqa.train_test_split(train_size=0.3, seed=seed)["test"]
        medmcqa = medmcqa.select(range(600))
    else:
        sample_fraction = 0.3 if split == "train" else 0.15
        medmcqa = load_dataset("openlifescienceai/medmcqa", split=split)
        medmcqa = medmcqa.train_test_split(train_size=sample_fraction, seed=seed)["train"]

    processed = []

    for idx, item in tqdm(enumerate(medmcqa),
                          total=len(medmcqa),
                          desc=f"Processing MedMCQA ({split})"):

        # four answer options
        choices = [item["opa"], item["opb"], item["opc"], item["opd"]]

        # "cop" is used as the 0-based index of the correct option
        answer_idx = int(item["cop"])
        answer_letter = LETTER_4[answer_idx]

        processed.append(
            {
                "dataset" : "medmcqa",
                "id"      : f"medmcqa_{idx}",
                "question": item["question"],
                "choices" : choices,
                "rationale": item.get("exp"),
                "answer"  : answer_letter
            }
        )

    return processed


def process_aqua_rat(split="train", seed=42):
    """
    Process the AQUA-RAT dataset.

    Each question's options have their "A) "-style prefix stripped and are
    then reduced to at most four choices by dropping a wrong option.

    Args:
        split (str): Dataset split to process, can be "train", "validation", or "test".
        seed (int): Random seed for reproducibility. Pass None for
            non-reproducible option removal.
    Returns:
        list: Processed dataset with each item containing 'dataset', 'id', 'question',
               'choices', 'rationale', and 'answer'. 'rationale' is always None.
    """
    aqua = load_dataset("deepmind/aqua_rat", split=split)
    processed = []

    for idx, item in tqdm(enumerate(aqua), total=len(aqua), desc="Processing AQUA-RAT"):
        # clean every option (this builds a fresh list, so the row is untouched)
        clean_choices = [STRIP_PREFIX.sub("", opt).lstrip() for opt in item["options"]]

        # Derive a per-question seed: reusing the very same seed for every
        # question would make the helper drop the same relative position each
        # time instead of a random one. The result is still reproducible.
        item_seed = None if seed is None else seed + idx

        choices, answer_letter = reduce_to_four_choices(
            clean_choices,
            item["correct"],
            item_seed
        )

        processed.append(
            {
                "dataset": "aqua_rat",
                "id": f"aqua_rat_{idx}",
                "question": item["question"],
                "choices": choices,
                # The rationale is deliberately not exported for this dataset.
                "rationale": None,
                "answer": answer_letter,
            }
        )
    return processed


def process_sciq(split="train", seed=42):
    """
    Process the SciQ dataset.

    The correct answer and the three distractors are shuffled so the correct
    option does not always sit in the first position.

    Args:
        split (str): Dataset split to process, can be "train", "validation", or "test".
        seed (int): Seed applied to the global ``random`` module before shuffling.
    Returns:
        list: Processed dataset with each item containing 'dataset', 'id', 'question',
               'choices', 'rationale', and 'answer'.
    """
    random.seed(seed)
    sciq = load_dataset("allenai/sciq", split=split)

    option_fields = ("correct_answer", "distractor1", "distractor2", "distractor3")
    records = []

    progress = tqdm(enumerate(sciq), total=len(sciq), desc="Processing SciQ")
    for row_number, row in progress:
        options = [row[field] for field in option_fields]
        random.shuffle(options)

        # position of the correct answer once the options are shuffled
        gold_position = options.index(row["correct_answer"])

        record = {
            "dataset": "sciq",
            "id": f"sciq_{row_number}",
            "question": row["question"],
            "choices": options,
            "rationale": row.get("support"),
            "answer": LETTER_4[gold_position],
        }
        records.append(record)

    return records


def process_ai2_arc(split="train", seed=42):
    """
    Process the AI2 ARC dataset (ARC-Challenge configuration).

    Numeric answer keys are mapped to letters, and questions with more than
    four options are reduced by dropping a wrong option.

    Args:
        split (str): Dataset split to process, can be "train", "validation", or "test".
        seed (int): Random seed for reproducibility. Pass None for
            non-reproducible option removal.
    Returns:
        list: Processed dataset with each item containing 'dataset', 'id', 'question',
               'choices', 'rationale', and 'answer'. 'rationale' is always None.
    """
    arc = load_dataset("allenai/ai2_arc", "ARC-Challenge", split=split)
    processed = []

    for idx, item in tqdm(enumerate(arc), total=len(arc), desc="Processing AI2 ARC"):
        # Copy the option texts so the dataset row itself is never altered.
        choices = list(item["choices"]["text"])

        # Normalize the answerKey: map numbers (if any) to A/B/C...
        raw_answer = item["answerKey"].strip()
        answer_letter = NUMERIC_TO_LETTER.get(raw_answer, raw_answer)

        # Derive a per-question seed: reusing the very same seed for every
        # question would make the helper drop the same relative position each
        # time instead of a random one. The result is still reproducible.
        item_seed = None if seed is None else seed + idx

        choices, answer_letter = reduce_to_four_choices(
            choices,
            answer_letter,
            item_seed
        )

        processed.append(
            {
                "dataset": "ai2_arc",
                "id": f'ai2_arc_{idx}',
                "question": item["question"],
                "choices": choices,
                "rationale": None,
                "answer": answer_letter
            }
        )

    return processed


def process_mmlu(split="train"):
    """
    Process the MMLU-STEM dataset.

    Rows are copied over as-is: the choices and the answer are taken straight
    from the dataset without any conversion.

    Args:
        split (str): Dataset split to process, can be "train", "validation", or "test".
    Returns:
        list: Processed dataset with each item containing 'dataset', 'id', 'question',
               'choices', 'rationale', and 'answer'. 'rationale' is always None.
    """
    mmlu = load_dataset("antoine-444/mmlu_stem_dataset", split=split)
    progress = tqdm(enumerate(mmlu), total=len(mmlu), desc="Processing MMLU")

    # NOTE(review): 'answer' is passed through unchanged; presumably it is
    # already a letter like in the other datasets -- confirm against the data.
    return [
        {
            "dataset": "mmlu_stem",
            "id": f"mmlu_stem_{row_number}",
            "question": row["question"],
            "choices": row["choices"],
            "rationale": None,
            "answer": row["answer"],
        }
        for row_number, row in progress
    ]

In [None]:
import os

# Single source of truth for the export location, so the file that is written
# and the file named in the success message cannot drift apart.
output_path = "data/mmlu/train.json"

# Create the target folder up front: a missing directory would otherwise only
# fail at the very end, after all the slow dataset processing.
os.makedirs(os.path.dirname(output_path), exist_ok=True)

# Merge the training portion of every source dataset into one list.
data = []

data.extend(process_medmcqa(split="train"))
data.extend(process_aqua_rat(split="train"))
data.extend(process_sciq(split="train"))
data.extend(process_ai2_arc(split="train"))
data.extend(process_mmlu(split="train"))

with open(output_path, "w", encoding="utf-8") as f:
    json.dump(data, f, indent=2)

print(f"✅ {output_path} has been saved.")