diff --git a/docs/release-notes.md b/docs/release-notes.md
index 3a81e29..18a1abb 100644
--- a/docs/release-notes.md
+++ b/docs/release-notes.md
@@ -9,6 +9,7 @@
 * generalize the Classificator
 * add verbosity and callback handling in EvoPromptGA
 * add timestamp to the callback
+* removed datasets from repo
 * changed task creation (now by default with a dataset)
diff --git a/promptolution/llms/base_llm.py b/promptolution/llms/base_llm.py
index 438ccf1..5cdeb53 100644
--- a/promptolution/llms/base_llm.py
+++ b/promptolution/llms/base_llm.py
@@ -75,6 +75,14 @@ def get_response(self, prompts: str) -> str:
 
         return responses
 
+    def set_generation_seed(self, seed: int):
+        """Set the random seed for reproducibility per request.
+
+        Args:
+            seed (int): Random seed value.
+        """
+        pass
+
     @abstractmethod
     def _get_response(self, prompts: List[str]) -> List[str]:
         """Generate responses for the given prompts.
diff --git a/promptolution/llms/vllm.py b/promptolution/llms/vllm.py
index e3706c5..b8db466 100644
--- a/promptolution/llms/vllm.py
+++ b/promptolution/llms/vllm.py
@@ -161,6 +161,14 @@ def update_token_count(self, inputs: List[str], outputs: List[str]):
         for output in outputs:
             self.output_token_count += len(self.tokenizer.encode(output))
 
+    def set_generation_seed(self, seed):
+        """Set the random seed for text generation.
+
+        Args:
+            seed (int): Random seed for text generation.
+        """
+        self.sampling_params.seed = seed
+
     def __del__(self):
         """Cleanup method to delete the LLM instance and free up GPU memory."""
         del self.llm
diff --git a/promptolution/optimizers/opro.py b/promptolution/optimizers/opro.py
index 7ef3616..aae22e8 100644
--- a/promptolution/optimizers/opro.py
+++ b/promptolution/optimizers/opro.py
@@ -1,6 +1,6 @@
-"""Module for OPRO."""
+"""Module implementing the OPRO (Optimization by PROmpting) algorithm."""
 
-from typing import List
+from typing import Dict, List, Optional
 
 import numpy as np
 
@@ -10,88 +10,140 @@
 class Opro(BaseOptimizer):
-    """Opro: Optimization by PROmpting.
+    """OPRO: Optimization by PROmpting.
 
-    Proposed by the paper "Large Language Models as Optimizers" by Yang et. al: https://arxiv.org/abs/2309.03409.
-    This Optimizer works by providing the Meta-LLM with a task-description, as well as previous
-    prompts with their respective score.
+    Implementation of the technique proposed in "Large Language Models as Optimizers"
+    (Yang et al., 2023: https://arxiv.org/abs/2309.03409).
 
-    Attributes:
-        llm (BaseLLM): The Meta-LLM to optimize.
-        n_samples (int): The number of samples from the task dataset to show the Meta-LLM.
-
-    Methods:
-        _sample_examples: Sample examples from the task dataset.
-        _format_old_instructions: Format the previous prompts and their scores.
-        optimize: Optimize the Meta-LLM by providing it with a new prompt.
+    OPRO works by providing a meta-LLM with task descriptions and previous
+    prompt-score pairs to generate improved prompts for a downstream LLM.
     """
 
-    def __init__(self, meta_llm: BaseLLM, n_samples: int = 2, prompt_template: str = None, **args):
-        """Initialize the Opro optimizer."""
-        self.meta_llm = meta_llm
-
-        assert n_samples > 0, "n_samples must be greater than 0."
-        self.n_samples = n_samples
-
-        self.meta_prompt = prompt_template if prompt_template else OPRO_TEMPLATE
+    def __init__(
+        self,
+        meta_llm: BaseLLM,
+        prompt_template: Optional[str] = None,
+        max_num_instructions: int = 20,
+        num_instructions_per_step: int = 8,
+        num_few_shots: int = 3,
+        **kwargs,
+    ) -> None:
+        """Initialize the OPRO optimizer.
-        super().__init__(**args)
+        Args:
+            meta_llm: LLM that generates improved prompts
+            prompt_template: Custom meta prompt template (uses OPRO_TEMPLATE if None)
+            max_num_instructions: Maximum previous instructions to include in meta prompt
+            num_instructions_per_step: Number of prompts to generate in each step
+            num_few_shots: Number of few-shot examples to include (0 for none)
+            **kwargs: Additional arguments passed to the BaseOptimizer
+        """
+        super().__init__(**kwargs)
+        self.meta_llm = meta_llm
 
-        self.scores = [
-            self.task.evaluate(p, self.predictor, subsample=True, n_samples=self.n_eval_samples)[0]
-            for p in self.prompts
-        ]
+        self.meta_prompt_template = prompt_template if prompt_template else OPRO_TEMPLATE
+        self.max_num_instructions = max_num_instructions
+        self.num_instructions_per_step = num_instructions_per_step
+        self.num_few_shots = num_few_shots
 
-    def _sample_examples(self):
-        """Sample examples from the task dataset with their label.
+    def _sample_examples(self) -> str:
+        """Sample few-shot examples from the dataset.
 
         Returns:
-            str: The formatted string of sampled examples.
+            Formatted string of few-shot examples with inputs and expected outputs
         """
-        idx = np.random.choice(len(self.task.xs), self.n_samples)
+        idx = np.random.choice(len(self.task.xs), self.num_few_shots)
         sample_x = self.task.xs[idx]
         sample_y = self.task.ys[idx]
 
         return "\n".join([f"Input: {x}\nOutput: {y}" for x, y in zip(sample_x, sample_y)])
 
-    def _format_old_instructions(self):
-        """Format the previous prompts and their respective scores.
+    def _format_instructions(self) -> str:
+        """Format previous prompts and their scores for the meta prompt.
 
         Returns:
-            str: The formatted string of previous prompts and their scores.
+            Formatted string of previous prompts and their scores,
+            sorted by ascending score (worse to better)
         """
-        return "".join(
-            [
-                f"The old instruction was:\n{prompt}\nIt scored: {score}\n\n"
-                for prompt, score in zip(self.prompts, self.scores)
-            ]
-        )
+        prompt_score_pairs = list(zip(self.prompts, self.scores))
+        sorted_pairs = sorted(prompt_score_pairs, key=lambda x: x[1])
+
+        return "".join([f"text:\n{prompt}\nscore: {int(100 * round(score, 2))}\n\n" for prompt, score in sorted_pairs])
+
+    def _add_prompt_and_score(self, prompt: str, score: float) -> None:
+        """Add a prompt and its score to the lists, maintaining max length.
+
+        Args:
+            prompt: The prompt to add
+            score: The corresponding score for the prompt
+        """
+        if prompt in self.prompts:
+            return
+
+        self.prompts.append(prompt)
+        self.scores.append(score)
+
+        # Keep only the top-performing prompts if we exceed the maximum number of instructions
+        keep_indices = np.argsort(self.scores)[-self.max_num_instructions :]
+        self.prompts = [self.prompts[i] for i in keep_indices]
+        self.scores = [self.scores[i] for i in keep_indices]
 
     def optimize(self, n_steps: int) -> List[str]:
-        """Optimize the Meta-LLM by providing it with a new prompt.
+        """Run the OPRO optimization process.
 
         Args:
-            n_steps (int): The number of optimization steps to perform.
+            n_steps: Number of optimization steps to perform
 
         Returns:
-            str: The best prompt found by the optimizer.
+            List of all prompts generated during optimization
         """
+        self.scores = list(self.task.evaluate(self.prompts, self.predictor))
+        self.meta_prompt = self.meta_prompt_template.replace("", self._format_instructions()).replace(
+            "", self._sample_examples()
+        )
+
         for _ in range(n_steps):
-            meta_prompt = self.meta_prompt.replace("", self._format_old_instructions()).replace(
+            duplicate_prompts = 0
+            for _ in range(self.num_instructions_per_step):
+                generation_seed = np.random.randint(0, int(1e9))
+                self.meta_llm.set_generation_seed(generation_seed)
+
+                if self.verbosity > 1:
+                    print(f"Seed: {generation_seed}")
+                response = self.meta_llm.get_response([self.meta_prompt])[0]
+
+                prompt = response.split("")[-1].split("")[0].strip()
+
+                if prompt in self.prompts:
+                    duplicate_prompts += 1
+                    continue
+
+                score = self.task.evaluate(prompt, self.predictor)[0]
+
+                self._add_prompt_and_score(prompt, score)
+
+                if self.verbosity > 1:
+                    print(f"New Instruction: {prompt}\nScore: {score}\n")
+
+            # Update meta prompt
+            self.meta_prompt = self.meta_prompt_template.replace("", self._format_instructions()).replace(
                 "", self._sample_examples()
             )
-            prompt = self.meta_llm.get_response([meta_prompt])[0]
-            prompt = prompt.split("")[-1].split("")[0].strip()
-            score = self.task.evaluate(prompt, self.predictor, subsample=True, n_samples=self.n_eval_samples)
-
-            self.prompts.append(prompt)
-            self.scores.append(score)
+            if self.verbosity > 1:
+                print(f"New meta prompt:\n{self.meta_prompt}\n")
 
             continue_optimization = self._on_step_end()
+
             if not continue_optimization:
                 break
 
-        self._on_epoch_end()
+            # stop optimization if all generated prompts are duplicates (converged)
+            if duplicate_prompts == self.num_instructions_per_step:
+                if self.verbosity > 0:
+                    print("All generated prompts are duplicates. Stopping optimization.")
+                break
+
+        self._on_train_end()
 
         return self.prompts
diff --git a/promptolution/tasks/classification_tasks.py b/promptolution/tasks/classification_tasks.py
index 82823d3..03231e7 100644
--- a/promptolution/tasks/classification_tasks.py
+++ b/promptolution/tasks/classification_tasks.py
@@ -75,10 +75,7 @@ def evaluate(
             n_samples (int, optional): Number of samples to use if subsampling. Defaults to 20.
             subsample (bool, optional): Whether to use subsampling.
                 If set to true, samples a different subset per call. Defaults to False.
-            return_seq (bool, optional): whether to return the generating sequence
-            subsample (bool, optional): Whether to use subsampling.
-                If set to true, samples a different subset per call. Defaults to False.
-            return_seq (bool, optional): whether to return the generating sequence
+            return_seq (bool, optional): whether to return the generating sequence.
 
         Returns:
             np.ndarray: Array of accuracy scores for each prompt.
diff --git a/promptolution/templates.py b/promptolution/templates.py
index 6cbc39e..514e48a 100644
--- a/promptolution/templates.py
+++ b/promptolution/templates.py
@@ -90,8 +90,7 @@
 Below are some previous instructions with their scores. The score ranges from 0 to 100.
-
-
+
 Here are some examples of the target dataset:
@@ -104,8 +103,7 @@
 Below are some previous instructions with their scores. The score ranges from 0 to 100.
-
-
+
 Here are some examples of the target dataset:
diff --git a/scripts/opro_test.py b/scripts/opro_test.py
new file mode 100644
index 0000000..278a53f
--- /dev/null
+++ b/scripts/opro_test.py
@@ -0,0 +1,102 @@
+"""Test run for the Opro optimizer."""
+
+import argparse
+import random
+from logging import Logger
+
+from promptolution.callbacks import LoggerCallback, CSVCallback, TokenCountCallback
+from promptolution.templates import OPRO_TEMPLATE_TD
+from promptolution.helpers import get_llm
+from promptolution.tasks import ClassificationTask
+from promptolution.predictors import MarkerBasedClassificator
+from promptolution.optimizers import Opro
+from datasets import load_dataset
+
+logger = Logger(__name__)
+
+"""Run a test run for any of the implemented optimizers."""
+parser = argparse.ArgumentParser()
+parser.add_argument("--model")
+parser.add_argument("--model-storage-path", default="../models/")
+parser.add_argument("--output-dir", default="results/opro_test/")
+parser.add_argument("--max-model-len", type=int, default=2048)
+parser.add_argument("--n-steps", type=int, default=999)
+parser.add_argument("--token", default=None)
+parser.add_argument("--seed", type=int, default=187)
+args = parser.parse_args()
+
+callbacks = [
+    LoggerCallback(logger),
+    CSVCallback(args.output_dir),
+    TokenCountCallback(5000000, "input_tokens"),
+]
+
+df = load_dataset("SetFit/ag_news", split="train", revision="main").to_pandas().sample(300, random_state=args.seed)
+
+df["input"] = df["text"]
+df["target"] = df["label_text"]
+
+task = ClassificationTask(
+    df,
+    description="The dataset contains news articles categorized into four classes: World, Sports, Business, and Tech. The task is to classify each news article into one of the four categories.",
+    x_column="input",
+    y_column="target",
+)
+
+initial_prompts = [
+    "Classify this news article as World, Sports, Business, or Tech. Provide your answer between and tags.",
+    "Read the following news article and determine which category it belongs to: World, Sports, Business, or Tech. Your classification must be placed between markers.",
+    "What is the primary category of this news piece? Choose from World, Sports, Business, or Tech. Place your selected category between .",
+    "Analyze this news article and categorize it as either World, Sports, Business, or Tech. Format your answer within tags.",
+    "Your task is to identify whether this news article belongs to World, Sports, Business, or Tech news. Provide your classification between the markers .",
+    "Please review the following news content and classify it into one of these categories: World, Sports, Business, or Tech. Your answer must be formatted with tags.",
+    "Based on the content, determine if this news article falls under World, Sports, Business, or Tech category. Return only your classification within .",
+    "Examine this news article and identify its primary category (World, Sports, Business, or Tech). Your final classification should be enclosed between markers.",
+    "In this task, you must categorize a news article into one of four classes: World, Sports, Business, or Tech. Remember to place your answer between tags for proper evaluation.",
+    "Read the provided news excerpt carefully and assign it to either World, Sports, Business, or Tech category. Ensure your answer appears between tags.",
+    "Considering the main subject matter, classify this news article as World, Sports, Business, or Tech. Format your response with .",
+    "Determine the appropriate category for this news article from the following options: World, Sports, Business, or Tech. Your selected category must be placed within markers.",
+    "After analyzing the given news article, assign it to the most suitable category: World, Sports, Business, or Tech. Your classification should be enclosed in tags.",
+    "Your objective is to classify the news article into one of the following categories: World, Sports, Business, or Tech based on its primary focus. Submit your answer between tags.",
+    "Which category best describes this news article: World, Sports, Business, or Tech? Provide your answer inside markers.",
+    "As a content classifier, determine if the following news article belongs to World, Sports, Business, or Tech news. Place your answer within tags.",
+    "Evaluate the following news article and indicate whether it primarily concerns World, Sports, Business, or Tech topics. Your classification must appear between .",
+    "Given a news article, your task is to determine its primary category from World, Sports, Business, or Tech. The final classification must be provided between tags.",
+    "Conduct a thorough analysis of the provided news article and classify it as belonging to one of these four categories: World, Sports, Business, or Tech. Your answer should be presented within markers.",
+    "Simply indicate whether this news article is about World, Sports, Business, or Tech. Include your answer between tags.",
+]
+
+initial_prompts = random.sample(initial_prompts, 10)
+
+if "vllm" in args.model:
+    llm = get_llm(
+        args.model,
+        batch_size=None,
+        max_model_len=args.max_model_len,
+        model_storage_path=args.model_storage_path,
+        revision="main",
+    )
+else:
+    llm = get_llm(args.model, args.token)
+
+downstream_llm = llm
+meta_llm = llm
+
+predictor = MarkerBasedClassificator(downstream_llm, classes=task.classes)
+
+optimizer = Opro(
+    task=task,
+    prompt_template=OPRO_TEMPLATE_TD.replace("
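Note for reviewers: the prompt history fed back into the meta prompt is now bounded. `_add_prompt_and_score` in opro.py skips prompts that were already generated and then keeps only the `max_num_instructions` best-scoring candidates via `np.argsort`. Below is a minimal standalone sketch of that keep-top-k bookkeeping; the helper name `keep_top_k` and the toy values are illustrative only and are not part of the library.

import numpy as np

def keep_top_k(prompts, scores, new_prompt, new_score, max_len=20):
    """Append a (prompt, score) pair, then keep only the max_len best-scoring prompts."""
    if new_prompt in prompts:  # duplicates are skipped, mirroring _add_prompt_and_score
        return prompts, scores
    prompts, scores = prompts + [new_prompt], scores + [new_score]
    keep = np.argsort(scores)[-max_len:]  # indices of the max_len highest scores, ascending
    return [prompts[i] for i in keep], [scores[i] for i in keep]

prompts, scores = keep_top_k(["p1", "p2"], [0.42, 0.61], "p3", 0.55, max_len=2)
print(prompts, scores)  # ['p3', 'p2'] [0.55, 0.61]: the lowest-scoring prompt is dropped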