97 changes: 97 additions & 0 deletions promptolution/utils/prompt_creation.py
@@ -0,0 +1,97 @@
"""Utility functions for prompt creation."""

from typing import List, Union

import numpy as np

from promptolution.llms.base_llm import BaseLLM
from promptolution.tasks.base_task import BaseTask
from promptolution.tasks.classification_tasks import ClassificationTask


def create_prompt_variation(prompt: Union[List[str], str], llm: BaseLLM, meta_prompt: str = None) -> List[str]:
    """Generate a variation of the given prompt(s) while keeping the semantic meaning.

    Idea taken from the paper Zhou et al. (2022) https://arxiv.org/pdf/2211.01910

    Args:
        prompt (Union[List[str], str]): The prompt(s) to generate variations of.
        llm (BaseLLM): The language model to use for generating the variations.
        meta_prompt (str): The meta prompt to use for generating the variations.
            If None, a default meta prompt is used. Must contain a <prev_prompt> placeholder.

    Returns:
        List[str]: A list of generated variations of the input prompt(s).
    """
    if meta_prompt is None:
        meta_prompt = """Generate a single variation of the following instruction while keeping the semantic meaning.
Generate the variation starting with <prompt> and ending with </prompt> tags.

Input: <prev_prompt>

Output:"""

    if isinstance(prompt, str):
        prompt = [prompt]
    varied_prompts = llm.get_response([meta_prompt.replace("<prev_prompt>", p) for p in prompt])

    # keep only the text between the <prompt> and </prompt> tags
    varied_prompts = [p.split("</prompt>")[0].split("<prompt>")[-1] for p in varied_prompts]

    return varied_prompts


def create_prompts_from_samples(task: BaseTask, llm: BaseLLM, meta_prompt: str = None, n_samples: int = 3) -> List[str]:
    """Generate a set of prompts from dataset examples sampled from a given task.

    Idea taken from the paper Zhou et al. (2022) https://arxiv.org/pdf/2211.01910
    Samples are selected such that
    (1) all possible classes are represented
    (2) the samples are as representative as possible

    Args:
        task (BaseTask): The task to generate prompts for.
            Xs and Ys from this object are used to generate the prompts.
        llm (BaseLLM): The language model to use for generating the prompts.
        meta_prompt (str): The meta prompt to use for generating the prompts.
            If None, a default meta prompt is used.
        n_samples (int): The number of samples to use for generating prompts.

    Returns:
        List[str]: A list of generated prompts.
    """
    if isinstance(task, ClassificationTask):
        # for classification tasks, sample such that all classes are represented
        # in proportion to their frequency, with at least one sample per class
        unique_classes, counts = np.unique(task.ys, return_counts=True)
        proportions = counts / len(task.ys)
        samples_per_class = np.round(proportions * n_samples).astype(int)
        samples_per_class = np.maximum(samples_per_class, 1)

        # sample (n_cls avoids shadowing the n_samples argument)
        xs = []
        ys = []
        for cls, n_cls in zip(unique_classes, samples_per_class):
            indices = np.where(task.ys == cls)[0]
            indices = np.random.choice(indices, n_cls, replace=False)
            xs.extend(task.xs[indices])
            ys.extend(task.ys[indices])

    else:
        # for other tasks, sample uniformly at random
        indices = np.random.choice(len(task.xs), n_samples, replace=False)
        xs = task.xs[indices].tolist()
        ys = task.ys[indices].tolist()

    if meta_prompt is None:
        meta_prompt = (
            "You are asked to give the corresponding prompt that produces the following outputs given these inputs. "
            "Return it starting with <prompt> and ending with </prompt> tags. "
            "Include the names of the output classes in the prompt."
        )

    for x, y in zip(xs, ys):
        meta_prompt += f"\n\nInput: {x}\nOutput: {y}"

    meta_prompt += "\nThe instruction was"

    prompt = llm.get_response([meta_prompt])[0]
    prompt = prompt.split("</prompt>")[0].split("<prompt>")[-1]

    return [prompt]
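
The two helpers compose naturally: create_prompts_from_samples induces a candidate instruction from labeled examples, and create_prompt_variation paraphrases prompts to widen the pool. A minimal usage sketch, assuming an already-constructed BaseLLM `llm` and ClassificationTask `task` (placeholders, not part of this diff). Note that the class-balanced sampling can slightly overshoot n_samples: with class counts [60, 30, 10] and n_samples=3, samples_per_class rounds to [2, 1, 0] and is then clipped up to [2, 1, 1], i.e. four samples in total.

from promptolution.utils.prompt_creation import create_prompt_variation, create_prompts_from_samples

# `llm` and `task` are assumed to exist; see scripts/prompt_creation_run.py below
induced = create_prompts_from_samples(task, llm, n_samples=5)  # e.g. ["Classify the text as subjective or objective."]
variants = create_prompt_variation(induced, llm)  # one paraphrase per induced prompt
candidates = induced + variants  # candidate pool for downstream evaluation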
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "promptolution"
version = "0.1.1"
version = "0.2.0"
description = ""
authors = ["Tom Zehle, Moritz Schlager, Timo Heiß"]
readme = "README.md"
42 changes: 42 additions & 0 deletions scripts/prompt_creation_run.py
@@ -0,0 +1,42 @@
"""Script to run prompt creation and evaluation."""

from configparser import ConfigParser
from logging import Logger

from promptolution.llms import get_llm
from promptolution.predictors import get_predictor
from promptolution.tasks import get_tasks
from promptolution.utils.prompt_creation import create_prompt_variation, create_prompts_from_samples

logger = Logger(__name__)


def main():
    """Main function to run the experiment."""
    config = ConfigParser()
    config.task_name = "subj"
    config.ds_path = "data_sets/cls/subj"
    config.random_seed = 42

    llm = get_llm("meta-llama/Meta-Llama-3-8B-Instruct")
    task = get_tasks(config)[0]
    predictor = get_predictor("meta-llama/Meta-Llama-3-8B-Instruct", classes=task.classes)

    init_prompts = create_prompts_from_samples(task, llm)
    logger.critical(f"Initial prompts: {init_prompts}")

    # evaluate on task
    scores = task.evaluate(init_prompts, predictor)
    logger.critical(f"Initial scores {scores.mean()} +/- {scores.std()}")

    # keep the full list so task.evaluate receives prompts, not characters of a single string
    varied_prompts = create_prompt_variation(init_prompts, llm)
    logger.critical(f"Varied prompts: {varied_prompts}")

    # evaluate on task
    scores = task.evaluate(varied_prompts, predictor)
    logger.critical(f"Varied scores {scores.mean()} +/- {scores.std()}")


if __name__ == "__main__":
    main()
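
Both helpers also accept a custom meta_prompt when the default induction wording does not fit. A hedged sketch of overriding it for create_prompt_variation, reusing `init_prompts` and `llm` from the script above (the rewrite wording is illustrative; only the <prev_prompt> placeholder is required):

# sketch: custom meta prompt for variation (wording is illustrative)
custom_meta = (
    "Rewrite the following instruction in simpler words, keeping its meaning. "
    "Return the rewrite between <prompt> and </prompt> tags.\n\n"
    "Input: <prev_prompt>\n\nOutput:"
)
simpler_prompts = create_prompt_variation(init_prompts, llm, meta_prompt=custom_meta)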