## Dataset Preparation (Afrimed-QA v2)

This notebook loads **Afrimed-QA v2** from Hugging Face, inspects question types, filters to **MCQ** items,
removes questions with **multiple correct answers**, and saves the processed dataset to Google Drive.

> Tip: Run cells top-to-bottom. If you restart the runtime, remount Drive before saving.


In [None]:
"""Utilities and dependencies.

- `datasets.load_dataset`: download/load Afrimed-QA v2.
- `google.colab.userdata` / `huggingface_hub.login`: optional; only needed for private or gated datasets.
"""

from datasets import load_dataset
from huggingface_hub import login  # optional: for authenticated HF access
from google.colab import userdata  # optional: to read HF token stored in Colab
import math  # optional: kept for future use


### Load Afrimed-QA v2

We load the dataset using `datasets.load_dataset`. The returned object is a `DatasetDict` keyed by split name
(e.g., `train`, `validation`, `test`); the cell output shows which splits are actually available.


In [None]:
# Load Afrimed-QA v2 from Hugging Face
ds = load_dataset("intronhealth/afrimedqa_v2")

# Quick sanity check: show available splits and features
ds

### Inspect question types

Afrimed-QA contains multiple question formats. We list the unique `question_type` values
so we can focus only on **MCQ** questions.
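
Per-type counts can be more informative than the unique values alone, since they show how many items will remain after filtering. The following is a minimal sketch on toy rows; the `"saq"` value and the row contents are illustrative assumptions, not taken from the dataset. With the real data, the column `ds["train"]["question_type"]` should be usable in place of the generator.

```python
from collections import Counter

# Toy rows standing in for ds["train"]; the values are illustrative only.
rows = [
    {"question_type": "mcq"},
    {"question_type": "saq"},  # hypothetical non-MCQ type
    {"question_type": "mcq"},
]

# Count how many rows fall under each question type.
type_counts = Counter(row["question_type"] for row in rows)

# Print types from most to least frequent.
for q_type, count in type_counts.most_common():
    print(f"{q_type}: {count}")
```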


In [None]:
# List unique question types present in the training split
unique_question_types = ds["train"].unique("question_type")
print(unique_question_types)

### Show one example per question type (for verification)

This is a quick qualitative check to ensure the dataset fields look as expected for each type.
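
As an alternative to running one filter per type, the first example of each type can be collected in a single pass. This is a sketch under the assumption that rows are plain dicts (iterating a `datasets.Dataset` yields dicts); `first_example_per_type` is a hypothetical helper name and the toy rows are made up.

```python
def first_example_per_type(rows, key="question_type"):
    """Return {type: first row seen with that type} in one pass over `rows`."""
    first_seen = {}
    for row in rows:
        # setdefault keeps the first row for each type and ignores later ones.
        first_seen.setdefault(row[key], row)
    return first_seen


# Toy rows standing in for ds["train"]; the contents are illustrative only.
rows = [
    {"question_type": "mcq", "question_clean": "Q1"},
    {"question_type": "saq", "question_clean": "Q2"},
    {"question_type": "mcq", "question_clean": "Q3"},
]

examples = first_example_per_type(rows)
for q_type, example in examples.items():
    print(f"\nQuestion Type: {q_type}")
    print(example)
```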


In [None]:
print("Examples for each question type:")

for q_type in unique_question_types:
    # Filter to the specific question type and select a single example
    example = ds["train"].filter(lambda x: x["question_type"] == q_type).select(range(1))

    print(f"\nQuestion Type: {q_type}")
    print(example[0])

### Filtering: keep only MCQ questions + essential columns

We:
1) Filter the training split to `question_type == "mcq"`.
2) Keep only the columns needed for fine-tuning / evaluation:
   - `question_clean`
   - `answer_options`
   - `correct_answer`
   - `answer_rationale`
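
Before dropping columns, it may be worth confirming that every expected column is present, since a misspelled name would otherwise go unnoticed. A minimal sketch follows; `missing_columns` is a hypothetical helper and `available_columns` is a made-up list (with the real dataset, `ds["train"].column_names` would supply it).

```python
def missing_columns(available, required):
    """Return the required column names that are absent from `available`, in order."""
    available_set = set(available)
    return [name for name in required if name not in available_set]


desired_columns = ["question_clean", "answer_options", "correct_answer", "answer_rationale"]

# Hypothetical column list for illustration; one desired column is deliberately absent.
available_columns = ["question_clean", "answer_options", "correct_answer", "question_type"]

missing = missing_columns(available_columns, desired_columns)
if missing:
    print(f"Missing expected columns: {missing}")
```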


In [None]:
# 1) Select only MCQ questions (plus their metadata)
mcq_questions = ds["train"].filter(lambda x: x["question_type"] == "mcq")

# 2) Keep only the columns we need
desired_columns = ["question_clean", "answer_options", "correct_answer", "answer_rationale"]
columns_to_remove = [col for col in mcq_questions.column_names if col not in desired_columns]

# Create the pared-down dataset
strat_mcq = mcq_questions.remove_columns(columns_to_remove)

# Display dataset summary
strat_mcq


### Mount Google Drive (for saving)


In [None]:
from google.colab import drive

# Mount Google Drive
drive.mount("/content/drive")
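
After a runtime restart the mount is gone, so a rough check before saving can catch a missing Drive early. The sketch below only tests that the mount directory exists and that the target path lies under it; `check_save_location` is a hypothetical helper, and the check does not prove that Drive is writable.

```python
import os


def check_save_location(save_path, mount_point="/content/drive"):
    """Return True if `mount_point` exists and `save_path` lies inside it."""
    mount_point = os.path.abspath(mount_point)
    save_path = os.path.abspath(save_path)
    # commonpath equals the mount point only when save_path is at or below it.
    inside_mount = os.path.commonpath([mount_point, save_path]) == mount_point
    return os.path.isdir(mount_point) and inside_mount


save_path = "/content/drive/MyDrive/AfricaLab/processed_mcq_dataset"
if not check_save_location(save_path):
    print("Drive does not appear to be mounted; rerun drive.mount before saving.")
```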


### Helper: detect questions with multiple correct answers

Some items have answers like `"option1, option4"` or `"option1 and option3"`.
For MCQ fine-tuning, where we want exactly one correct option per question, we remove these items.
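
A delimiter-based check can misfire if a single answer happens to contain a comma or the word "and". One stricter variant, assuming every correct option is written as an `optionN` token (an assumption drawn only from the example formats above), counts the distinct tokens instead. `OPTION_TOKEN` and `count_option_tokens` are hypothetical names used for illustration.

```python
import re

# Matches tokens such as "option1" or "Option 3" (case-insensitive).
OPTION_TOKEN = re.compile(r"option\s*\d+", re.IGNORECASE)


def count_option_tokens(answer):
    """Count distinct 'optionN' tokens in `answer`; non-strings count as 0."""
    if not isinstance(answer, str):
        return 0
    # Normalize case and inner whitespace so "Option 1" and "option1" match.
    tokens = {re.sub(r"\s+", "", match.lower()) for match in OPTION_TOKEN.findall(answer)}
    return len(tokens)


samples = ["option1", "option1, option4", "option1 and option3", None]
counts = [count_option_tokens(sample) for sample in samples]
for sample, count in zip(samples, counts):
    print(repr(sample), "->", count)
```

An item would then be treated as multi-answer when the count is greater than 1. Answers that contain no `optionN` token return 0 and would need separate handling.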


In [None]:
def has_multiple_correct_answers(answer: str) -> bool:
    """Return True if `answer` appears to contain multiple correct options.

    Afrimed-QA's `correct_answer` field is typically a string like:
    - "option1"
    - "option2"
    - "option1, option4"
    - "option1 and option3"

    This helper uses simple heuristics (delimiters and repeated 'option' tokens) to detect
    multi-answer cases.

    Args:
        answer: The `correct_answer` field from the dataset.

    Returns:
        True if the string likely encodes multiple correct answers, otherwise False.
    """
    if isinstance(answer, str):
        answer_lower = answer.lower()
        return (
            ("," in answer_lower)
            or (" and " in answer_lower)
            or (answer_lower.count("option") > 1)
        )
    return False


In [None]:
# Filter out questions with multiple correct answers
strat_mcq_filtered = strat_mcq.filter(lambda x: not has_multiple_correct_answers(x["correct_answer"]))

print(f"Original dataset size: {len(strat_mcq)}")
print(f"Dataset size after removing multi-answer items: {len(strat_mcq_filtered)}")

# Define the save path on Google Drive
save_path = "/content/drive/MyDrive/AfricaLab/processed_mcq_dataset"

# Save the filtered dataset to Google Drive
strat_mcq_filtered.save_to_disk(save_path)

print(f"Dataset saved to: {save_path}")


### Quick preview of the processed dataset

In [None]:
# Inspect the first processed example
strat_mcq_filtered[0]