In [195]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell executes so that edits to imported project code are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Load Three Selection Task Datasets

In [198]:
import os
import json

# The data paths below are relative to the directory one level above the
# kernel's starting directory (presumably the project root - confirm). Only
# move up when the data directory is not already reachable, so that re-running
# this cell in a live kernel does not climb one directory further each time.
if not os.path.isdir("data_save"):
    os.chdir("../")

# JSON text is UTF-8 by specification; pass the encoding explicitly instead of
# relying on the platform default, which differs between operating systems.
with open("data_save/task_data/landmarks.json", "r", encoding="utf-8") as f:
    LANDMARKS = json.load(f)

with open("data_save/task_data/rhymes.json", "r", encoding="utf-8") as f:
    RHYMES_WITH = json.load(f)

with open("data_save/task_data/profession.json", "r", encoding="utf-8") as f:
    PROFESSIONS = json.load(f)

In [202]:
import random
from typing import Literal

def compose_selection_prompt(
    data: dict,
    entity_type: Literal["words", "people", "landmarks"],
    n_distractors: int = 3,
) -> "tuple[str, str]":  # quoted so the annotation is also valid before Python 3.9
    """Compose a random multiple-choice selection prompt and its correct answer.

    A category is drawn at random from ``data``. For ``"landmarks"`` the
    subject of the question is the category name itself; otherwise it is a
    random item of that category. The target (correct option) is a different
    item of the same category, and every distractor comes from a distinct
    other category. Options are shuffled and rendered either as a
    comma-separated list, an " or "-separated list, or a numbered list, placed
    before or after the question.

    Args:
        data: Mapping from category name to a list of string items.
        entity_type: Selects the predicate used in the question
            ("rhymes with", "shares a profession with", "is located in") and
            whether the subject is the category name (``"landmarks"``) or an
            item of the category.
        n_distractors: Number of incorrect options, one per other category.

    Returns:
        A ``(prompt, target)`` tuple: the formatted prompt ending in
        ``"Answer:"`` and the option that correctly answers it.

    Raises:
        KeyError: If ``entity_type`` is not one of the supported values.
        ValueError: If fewer than ``n_distractors`` other categories can
            supply a distractor.
        IndexError: If no category in ``data`` offers a valid target.
    """
    # Map the entity type to the predicate used in the question
    predicate_map = {
        "words": "rhymes with",
        "people": "shares a profession with",
        "landmarks": "is located in",
    }
    predicate = predicate_map[entity_type]
    subject_is_category = entity_type == "landmarks"

    def offers_target(name):
        """Return True if any subject drawn for ``name`` leaves an item to use as target."""
        items = data[name]
        if subject_is_category:
            return any(item != name for item in items)
        # Two distinct items guarantee a target whichever one becomes the subject.
        return any(item != items[0] for item in items)

    # Select a category for the subject, skipping categories that could never
    # yield a target (previously these crashed only when randomly drawn).
    all_categories = list(data.keys())
    category = random.choice([name for name in all_categories if offers_target(name)])

    # Select a subject
    subject = category if subject_is_category else random.choice(data[category])

    # Select a target that is not the subject
    possible_targets = [item for item in data[category] if item != subject]
    target = random.choice(possible_targets)

    # Collect distractor candidates from the other categories. Items that also
    # belong to the chosen category (or equal the subject) are excluded: they
    # would be a second correct answer or a duplicate option.
    off_limits = set(data[category]) | {subject}
    distractor_pools = {}
    for name in all_categories:
        if name == category:
            continue
        pool = [item for item in data[name] if item not in off_limits]
        if pool:
            distractor_pools[name] = pool

    try:
        distractor_categories = random.sample(list(distractor_pools), n_distractors)
    except ValueError as err:
        raise ValueError(
            f"Cannot find {n_distractors} distractors for {category}."
        ) from err
    distractors = [random.choice(distractor_pools[name]) for name in distractor_categories]

    # Create and shuffle the final list of options
    options = [target] + distractors
    random.shuffle(options)

    # Format the options list into a string
    joiner = random.choice([", ", " or ", "ORDERED_LIST"])
    if joiner == "ORDERED_LIST":
        options_str = "\n".join(f"{i+1}. {option}" for i, option in enumerate(options))
    else:
        options_str = joiner.join(options)

    # Choose a template; named fields let both layouts share one format call.
    prompt_templates = [
        "Which of these {entity} {predicate} {subject}?\n{options}.\nAnswer:",
        "{options}.\nWhich of these {entity} {predicate} {subject}?\nAnswer:",
    ]
    template = random.choice(prompt_templates)
    prompt = template.format(
        entity=entity_type,
        predicate=predicate,
        subject=subject,
        options=options_str,
    )
    return prompt, target


In [241]:
# Draw one sample "shared profession" question with five distractor options.
prompt, target = compose_selection_prompt(PROFESSIONS, entity_type="people", n_distractors=5)

# The `=` specifier echoes each variable's name next to its repr, one per line.
print(f"{prompt=}", f"{target=}", sep="\n")

prompt='Which of these people shares a profession with Rihanna?\nBob Woodward, Pedri, Dua Lipa, Ridley Scott, Daniil Medvedev, Aziz Ansari.\nAnswer:'
target='Dua Lipa'
