Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 71 additions & 5 deletions src/unitxt/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,31 @@ def preprocess_input_and_reference_fields(


class MultipleChoiceTemplate(InputFormatTemplate):
"""Formats the input (that specifies the question), the multiple choices to select the answer from, and specifies the field with the correct answer."""
"""Formats the input that specifies a multiple-choice question, with a list of possible answers to choose from, and identifies the correct answer.

Args:
target_prefix (str): Optional prefix that can be added before the target label in
generated prompts or outputs.
choices_field (str): The key under which the multiple choices are stored in the
input and reference dictionaries.
target_field (str): The key under which the correct choice is stored in the
reference dictionary (can be integer index or textual label).
choices_separator (str): A string used to join formatted choices (e.g. ", ").
source_choice_format (str): A Python format string used for displaying each choice
in the input fields (e.g. "{choice_numeral}. {choice_text}").
target_choice_format (str): A Python format string used for displaying each choice
in the target or final output (e.g. "{choice_numeral}").
enumerator (str): Determines how choice numerals are enumerated. Possible values
include "capitals", "lowercase", "numbers", or "roman".
shuffle_choices (bool): If True, shuffle the choices. The shuffling seed can be
set with `shuffle_choices_seed`.
shuffle_choices_seed (int, optional): If provided, the choices are shuffled with
this fixed integer seed for reproducibility.
sort_choices_by_length (bool): If True, sorts choices by their length (ascending).
sort_choices_alphabetically (bool): If True, sorts choices in alphabetical order.
reverse_choices (bool): If True, reverses the order of the choices after any
sorting has been applied. Defaults to False to preserve backward compatibility.
"""

target_prefix: str = ""
choices_field: str = "choices"
Expand All @@ -504,7 +528,12 @@ class MultipleChoiceTemplate(InputFormatTemplate):
source_choice_format: str = "{choice_numeral}. {choice_text}"
target_choice_format: str = "{choice_numeral}"
enumerator: str = "capitals"

shuffle_choices: bool = False
shuffle_choices_seed: int = None
sort_choices_by_length: bool = False
sort_choices_alphabetically: bool = False
reverse_choices: bool = False # False by default for backward-compat

def prepare(self):
super().prepare()
Expand Down Expand Up @@ -538,6 +567,22 @@ def prepare(self):
"XX",
]

def verify(self):
    """Check that the choice-ordering flags do not contradict each other.

    Runs the parent class verification first, then rejects two kinds of
    configurations:

    * ``shuffle_choices`` together with any of ``sort_choices_by_length``,
      ``sort_choices_alphabetically`` or ``reverse_choices``;
    * ``sort_choices_by_length`` together with ``sort_choices_alphabetically``.

    Raises:
        UnitxtError: If one of the combinations above is configured.
    """
    super().verify()

    # Any flag that imposes a fixed order is incompatible with shuffling.
    fixed_order_requested = (
        self.sort_choices_by_length
        or self.sort_choices_alphabetically
        or self.reverse_choices
    )
    if self.shuffle_choices and fixed_order_requested:
        raise UnitxtError(
            "You cannot combine shuffle_choices with sorting or reversing flags."
        )

    # Only one sort key may be active at a time.
    both_sort_keys_requested = (
        self.sort_choices_by_length and self.sort_choices_alphabetically
    )
    if both_sort_keys_requested:
        raise UnitxtError(
            "You cannot combine both sort_choices_by_length and sort_choices_alphabetically simultaneously."
        )

def inputs_to_choices(self, data: Dict[str, Any], choice_format: str) -> str:
choices = data[self.choices_field]
enumrated_choices = []
Expand Down Expand Up @@ -612,16 +657,37 @@ def reference_fields_to_target_and_references(
def preprocess_input_and_reference_fields(
self, input_fields: Dict[str, Any], reference_fields: Dict[str, Any]
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
choices = input_fields[self.choices_field]

# --- handle sorting first ---
if self.sort_choices_by_length:
choices.sort(key=len)
if self.sort_choices_alphabetically:
choices.sort()
if self.reverse_choices:
choices.reverse()

# Update both input_fields and reference_fields for sorting
input_fields[self.choices_field] = choices
reference_fields[self.choices_field] = choices

# --- Then handle shuffling, if enabled ---
if self.shuffle_choices:
# Find the original target index before shuffling
target_index = self.outputs_to_target_index(reference_fields)
original_label_choice = reference_fields[self.choices_field][target_index]
choices = input_fields[self.choices_field]
random_seed = {**input_fields}

random_generator = new_random_generator(random_seed)
# Use a fixed seed if provided, otherwise generate from input fields
if self.shuffle_choices_seed is not None:
random_generator = new_random_generator(self.shuffle_choices_seed)
else:
random_seed = {**input_fields}
random_generator = new_random_generator(random_seed)

random_generator.shuffle(choices)
input_fields[self.choices_field] = choices

# Update the fields after shuffling
input_fields[self.choices_field] = choices
reference_fields[self.choices_field] = choices
reference_fields[self.target_field] = choices.index(original_label_choice)

Expand Down
231 changes: 231 additions & 0 deletions tests/library/test_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -1054,6 +1054,237 @@ def test_multiple_choice_template_with_shuffle(self):
str(ve.exception.__cause__),
)

def test_multiple_choice_template_with_shuffle_choices_seed(self):
    """A fixed ``shuffle_choices_seed`` yields the pinned order and a remapped label.

    The template is built with ``shuffle_choices=True`` and
    ``shuffle_choices_seed=12345``. The test pins the resulting choice order
    and checks that the integer label is updated to the new position of the
    originally correct choice.
    """
    template = MultipleChoiceTemplate(
        input_format="Text: {text}, Choices: {choices}.",
        enumerator="capitals",
        shuffle_choices=True,
        shuffle_choices_seed=12345,  # fixed seed
    )

    original_choices = ["Alpha", "Beta", "Gamma"]
    # Order pinned by this test for seed=12345. Kept in one place so the
    # expected fields and the explanatory comments cannot drift apart.
    shuffled_choices = ["Gamma", "Beta", "Alpha"]

    inputs = [
        {
            "input_fields": {
                # Separate list objects, as in a real instance.
                "choices": list(original_choices),
                "text": "example seed",
            },
            "reference_fields": {
                "choices": list(original_choices),
                "label": 2,  # "Gamma" is index 2 before shuffling
            },
        }
    ]

    # The enumerator is "capitals" => ["A", "B", "C"].
    targets = [
        {
            "input_fields": {
                "choices": list(shuffled_choices),  # after shuffling
                "text": "example seed",
                "options": ["A", "B", "C"],  # enumerator placeholders
            },
            "reference_fields": {
                "choices": list(shuffled_choices),
                "label": 0,  # "Gamma" is now index 0
            },
            # source_choice_format = "{choice_numeral}. {choice_text}"
            # => "A. Gamma, B. Beta, C. Alpha"
            "source": "Text: example seed, Choices: A. Gamma, B. Beta, C. Alpha.",
            # Original label=2 => "Gamma" => new index=0 => enumerator "A"
            "target": "A",
            "references": ["A"],  # For final references
            "instruction": "",
            "target_prefix": "",
            "postprocessors": ["processors.to_string_stripped"],
        }
    ]

    check_operator(template, inputs, targets, tester=self)

def test_multiple_choice_template_with_sort_choices_by_length(self):
    """``sort_choices_by_length=True`` orders the choices from shortest to longest.

    The label is given as text ("Large"), so it stays unchanged in the
    reference fields while the rendered target follows its new position.
    """
    unsorted_choices = ["Large", "No", "Yes", "Tiny"]
    # Lengths 2, 3, 4, 5 => no ties, so the expected order is unambiguous.
    shortest_first = ["No", "Yes", "Tiny", "Large"]
    question = "example length sort"

    operator = MultipleChoiceTemplate(
        input_format="Text: {text}, Choices: {choices}.",
        enumerator="capitals",
        sort_choices_by_length=True,
    )

    instance = {
        "input_fields": {
            "choices": list(unsorted_choices),
            "text": question,
        },
        "reference_fields": {
            "choices": list(unsorted_choices),
            "label": "Large",
        },
    }

    # "Large" moves to index 3 => enumerator "D".
    expected = {
        "input_fields": {
            "choices": list(shortest_first),
            "text": question,
            "options": ["A", "B", "C", "D"],  # enumerator placeholders
        },
        "reference_fields": {
            "choices": list(shortest_first),
            "label": "Large",
        },
        "source": "Text: example length sort, Choices: A. No, B. Yes, C. Tiny, D. Large.",
        "target": "D",
        "references": ["D"],
        "instruction": "",
        "target_prefix": "",
        "postprocessors": ["processors.to_string_stripped"],
    }

    check_operator(operator, [instance], [expected], tester=self)

def test_multiple_choice_template_with_sort_choices_alphabetically(self):
    """``sort_choices_alphabetically=True`` puts the choices in alphabetical order.

    The textual label "Charlie" is left as-is in the reference fields; the
    rendered target reflects its position in the sorted list.
    """
    unsorted_choices = ["Delta", "Alpha", "Charlie", "Bravo"]
    alphabetical = ["Alpha", "Bravo", "Charlie", "Delta"]
    question = "example alphabetical sort"

    operator = MultipleChoiceTemplate(
        input_format="Text: {text}, Choices: {choices}.",
        enumerator="capitals",
        sort_choices_alphabetically=True,
    )

    instance = {
        "input_fields": {
            "choices": list(unsorted_choices),
            "text": question,
        },
        "reference_fields": {
            "choices": list(unsorted_choices),
            "label": "Charlie",
        },
    }

    # "Charlie" lands at index 2 => enumerator "C".
    expected = {
        "input_fields": {
            "choices": list(alphabetical),
            "text": question,
            "options": ["A", "B", "C", "D"],
        },
        "reference_fields": {
            "choices": list(alphabetical),
            "label": "Charlie",
        },
        "source": "Text: example alphabetical sort, Choices: A. Alpha, B. Bravo, C. Charlie, D. Delta.",
        "target": "C",
        "references": ["C"],
        "instruction": "",
        "target_prefix": "",
        "postprocessors": ["processors.to_string_stripped"],
    }

    check_operator(operator, [instance], [expected], tester=self)

def test_multiple_choice_template_with_reverse_choices(self):
    """``reverse_choices=True`` with no sorting flags reverses the original list."""
    forward = ["One", "Two", "Three"]
    backward = ["Three", "Two", "One"]
    question = "example reverse"

    # No sorting flags => just reversing the original list
    operator = MultipleChoiceTemplate(
        input_format="Text: {text}, Choices: {choices}.",
        enumerator="capitals",
        reverse_choices=True,
    )

    instance = {
        "input_fields": {
            "choices": list(forward),
            "text": question,
        },
        "reference_fields": {
            "choices": list(forward),
            "label": "Two",
        },
    }

    # "Two" is the middle element, so it stays at index 1 => enumerator "B".
    expected = {
        "input_fields": {
            "choices": list(backward),
            "text": question,
            "options": ["A", "B", "C"],  # enumerator
        },
        "reference_fields": {
            "choices": list(backward),
            "label": "Two",
        },
        "source": "Text: example reverse, Choices: A. Three, B. Two, C. One.",
        "target": "B",
        "references": ["B"],
        "instruction": "",
        "target_prefix": "",
        "postprocessors": ["processors.to_string_stripped"],
    }

    check_operator(operator, [instance], [expected], tester=self)

def test_multiple_choice_template_verify_flags(self):
    """Each incompatible combination of ordering flags raises ``UnitxtError``.

    Every scenario runs in its own ``subTest`` and checks that the error
    message contains the expected fragment.
    """
    shuffle_conflict = "You cannot combine shuffle_choices"
    double_sort_conflict = "You cannot combine both sort_choices_by_length and sort_choices_alphabetically"

    # (subTest label, conflicting flags, fragment expected in the error message)
    scenarios = [
        (
            "shuffle_choices + sort_choices_by_length",
            {"shuffle_choices": True, "sort_choices_by_length": True},
            shuffle_conflict,
        ),
        (
            "shuffle_choices + sort_choices_alphabetically",
            {"shuffle_choices": True, "sort_choices_alphabetically": True},
            shuffle_conflict,
        ),
        (
            "shuffle_choices + reverse_choices",
            {"shuffle_choices": True, "reverse_choices": True},
            shuffle_conflict,
        ),
        (
            "sort_choices_by_length + sort_choices_alphabetically",
            {"sort_choices_by_length": True, "sort_choices_alphabetically": True},
            double_sort_conflict,
        ),
    ]

    for description, conflicting_flags, expected_fragment in scenarios:
        with self.subTest(description):
            with self.assertRaises(UnitxtError) as ve:
                template = MultipleChoiceTemplate(
                    input_format="Text: {text}, Choices: {choices}.",
                    enumerator="capitals",
                    **conflicting_flags,
                )
                template.verify()
            # Checked after the assertRaises block has captured the exception.
            self.assertIn(expected_fragment, str(ve.exception))

def test_key_val_template_simple(self):
template = KeyValTemplate()
instance = {
Expand Down
Loading