diff --git a/promptolution/llms/api_llm.py b/promptolution/llms/api_llm.py
index d00bc91..e1286f6 100644
--- a/promptolution/llms/api_llm.py
+++ b/promptolution/llms/api_llm.py
@@ -1,144 +1,86 @@
 """Module to interface with various language models through their respective APIs."""
-import asyncio
-import time
-from logging import Logger
-from typing import Any, List
-
-import nest_asyncio
-import openai
-import requests
-from langchain_anthropic import ChatAnthropic
-from langchain_community.chat_models.deepinfra import ChatDeepInfra, ChatDeepInfraException
-from langchain_core.messages import HumanMessage, SystemMessage
-from langchain_openai import ChatOpenAI
+try:
+    import asyncio

-from promptolution.llms.base_llm import BaseLLM
+    from openai import AsyncOpenAI

-logger = Logger(__name__)
+    import_successful = True
+except ImportError:
+    import_successful = False

+from logging import Logger
+from typing import Any, List

-async def invoke_model(prompt, system_prompt, model, semaphore):
-    """Asynchronously invoke a language model with retry logic.
+from promptolution.llms.base_llm import BaseLLM

-    Args:
-        prompt (str): The input prompt for the model.
-        system_prompt (str): The system prompt for the model.
-        model: The language model to invoke.
-        semaphore (asyncio.Semaphore): Semaphore to limit concurrent calls.
+logger = Logger(__name__)

-    Returns:
-        str: The model's response content.

-    Raises:
-        ChatDeepInfraException: If all retry attempts fail.
-    """
+async def _invoke_model(prompt, system_prompt, max_tokens, model_id, client, semaphore):
     async with semaphore:
-        max_retries = 100
-        delay = 3
-        attempts = 0
-
-        while attempts < max_retries:
-            try:
-                response = await model.ainvoke([SystemMessage(content=system_prompt), HumanMessage(content=prompt)])
-                return response.content
-            except ChatDeepInfraException as e:
-                print(f"DeepInfra error: {e}. Attempt {attempts}/{max_retries}. Retrying in {delay} seconds...")
-                attempts += 1
-                await asyncio.sleep(delay)
+        messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}]
+        response = await client.chat.completions.create(
+            model=model_id,
+            messages=messages,
+            max_tokens=max_tokens,
+        )
+        return response


 class APILLM(BaseLLM):
-    """A class to interface with various language models through their respective APIs.
+    """A class to interface with language models through their respective APIs.

-    This class supports Claude (Anthropic), GPT (OpenAI), and LLaMA (DeepInfra) models.
-    It handles API key management, model initialization, and provides methods for
-    both synchronous and asynchronous inference.
+    This class provides a unified interface for making API calls to language models
+    using the OpenAI client library. It handles rate limiting through semaphores
+    and supports both synchronous and asynchronous operations.

     Attributes:
-        model: The initialized language model instance.
-
-    Methods:
-        get_response: Synchronously get responses for a list of prompts.
-        get_response_async: Asynchronously get responses for a list of prompts.
+        model_id (str): Identifier for the model to use.
+        client (AsyncOpenAI): The initialized API client.
+        max_tokens (int): Maximum number of tokens in model responses.
+        semaphore (asyncio.Semaphore): Semaphore to limit concurrent API calls.
     """

-    def __init__(self, model_id: str, token: str = None, **kwargs: Any):
-        """Initialize the APILLM with a specific model.
+    def __init__(
+        self, api_url: str, model_id: str, token: str = None, max_concurrent_calls=50, max_tokens=512, **kwargs: Any
+    ):
+        """Initialize the APILLM with a specific model and API configuration.

         Args:
+            api_url (str): The base URL for the API endpoint.
             model_id (str): Identifier for the model to use.
-            token (str): API key for the model.
+            token (str, optional): API key for authentication. Defaults to None.
+            max_concurrent_calls (int, optional): Maximum number of concurrent API calls. Defaults to 50.
+            max_tokens (int, optional): Maximum number of tokens in model responses. Defaults to 512.
+            **kwargs (Any): Additional parameters to pass to the API client.

         Raises:
-            ValueError: If an unknown model identifier is provided.
+            ImportError: If required libraries are not installed.
         """
+        if not import_successful:
+            raise ImportError(
+                "Could not import at least one of the required libraries: openai, asyncio. "
+                "Please ensure they are installed in your environment."
+            )
         super().__init__()
-        if "claude" in model_id:
-            self.model = ChatAnthropic(model=model_id, api_key=token)
-        elif "gpt" in model_id:
-            self.model = ChatOpenAI(model=model_id, api_key=token)
-        else:
-            self.model = ChatDeepInfra(model_name=model_id, deepinfra_api_token=token)
-
-    def _get_response(self, prompts: List[str], system_prompts: List[str] = None) -> List[str]:
-        """Get responses for a list of prompts in a synchronous manner.
+        self.model_id = model_id
+        self.client = AsyncOpenAI(base_url=api_url, api_key=token, **kwargs)
+        self.max_tokens = max_tokens

-        This method includes retry logic for handling connection errors and rate limits.
+        self.semaphore = asyncio.Semaphore(max_concurrent_calls)

-        Args:
-            prompts (list[str]): List of input prompts.
-            system_prompts (list[str]): List of system prompts. If not provided, uses default system_prompts
-
-        Returns:
-            list[str]: List of model responses.
-
-        Raises:
-            requests.exceptions.ConnectionError: If max retries are exceeded.
-        """
-        max_retries = 100
-        delay = 3
-        attempts = 0
-
-        nest_asyncio.apply()
-
-        while attempts < max_retries:
-            try:
-                responses = asyncio.run(self.get_response_async(prompts))
-                return responses
-            except requests.exceptions.ConnectionError as e:
-                attempts += 1
-                logger.critical(
-                    f"Connection error: {e}. Attempt {attempts}/{max_retries}. Retrying in {delay} seconds..."
-                )
-                time.sleep(delay)
-            except openai.RateLimitError as e:
-                attempts += 1
-                logger.critical(
-                    f"Rate limit error: {e}. Attempt {attempts}/{max_retries}. Retrying in {delay} seconds..."
-                )
-                time.sleep(delay)
-
-        # If the loop exits, it means max retries were reached
-        raise requests.exceptions.ConnectionError("Max retries exceeded. Connection could not be established.")
-
-    async def get_response_async(self, prompts: list[str], max_concurrent_calls=200) -> list[str]:
-        """Asynchronously get responses for a list of prompts.
-
-        This method uses a semaphore to limit the number of concurrent API calls.
-
-        Args:
-            prompts (list[str]): List of input prompts.
-            max_concurrent_calls (int): Maximum number of concurrent API calls allowed.
-
-        Returns:
-            list[str]: List of model responses.
-        """
-        semaphore = asyncio.Semaphore(max_concurrent_calls)
-        tasks = []
-
-        for prompt in prompts:
-            tasks.append(invoke_model(prompt, self.model, semaphore))
+    def _get_response(self, prompts: List[str], system_prompts: List[str]) -> List[str]:
+        # Setup for async execution in sync context
+        loop = asyncio.get_event_loop()
+        responses = loop.run_until_complete(self._get_response_async(prompts, system_prompts))
+        return responses

+    async def _get_response_async(self, prompts: List[str], system_prompts: List[str]) -> List[str]:
+        tasks = [
+            _invoke_model(prompt, system_prompt, self.max_tokens, self.model_id, self.client, self.semaphore)
+            for prompt, system_prompt in zip(prompts, system_prompts)
+        ]
         responses = await asyncio.gather(*tasks)
-        return responses
+        return [response.choices[0].message.content for response in responses]
diff --git a/promptolution/llms/base_llm.py b/promptolution/llms/base_llm.py
index 1a79d29..b9be7fb 100644
--- a/promptolution/llms/base_llm.py
+++ b/promptolution/llms/base_llm.py
@@ -91,7 +91,7 @@ def set_generation_seed(self, seed: int):
         pass

     @abstractmethod
-    def _get_response(self, prompts: List[str], system_prompts: List[str] = None) -> List[str]:
+    def _get_response(self, prompts: List[str], system_prompts: List[str]) -> List[str]:
         """Generate responses for the given prompts.

         This method should be implemented by subclasses to define how
diff --git a/promptolution/llms/local_llm.py b/promptolution/llms/local_llm.py
index 46afe17..c735f3c 100644
--- a/promptolution/llms/local_llm.py
+++ b/promptolution/llms/local_llm.py
@@ -2,11 +2,10 @@
 try:
     import torch
     import transformers
-except ImportError as e:
-    import logging

-    logger = logging.getLogger(__name__)
-    logger.warning(f"Could not import torch or transformers in local_llm.py: {e}")
+    imports_successful = True
+except ImportError:
+    imports_successful = False

 from promptolution.llms.base_llm import BaseLLM

@@ -35,6 +34,11 @@ def __init__(self, model_id: str, batch_size=8):
         This method sets up a text generation pipeline with bfloat16 precision,
         automatic device mapping, and specific generation parameters.
         """
+        if not imports_successful:
+            raise ImportError(
+                "Could not import at least one of the required libraries: torch, transformers. "
+                "Please ensure they are installed in your environment."
+            )
         super().__init__()

         self.pipeline = transformers.pipeline(
@@ -78,8 +82,5 @@ def _get_response(self, prompts: list[str], system_prompts: list[str]) -> list[s

     def __del__(self):
         """Cleanup method to delete the pipeline and free up GPU memory."""
-        try:
-            del self.pipeline
-            torch.cuda.empty_cache()
-        except Exception as e:
-            logger.warning(f"Error during LocalLLM cleanup: {e}")
+        del self.pipeline
+        torch.cuda.empty_cache()
diff --git a/promptolution/llms/vllm.py b/promptolution/llms/vllm.py
index ec6505e..824ef08 100644
--- a/promptolution/llms/vllm.py
+++ b/promptolution/llms/vllm.py
@@ -12,8 +12,10 @@
     import torch
     from transformers import AutoTokenizer
     from vllm import LLM, SamplingParams
-except ImportError as e:
-    logger.warning(f"Could not import vllm, torch or transformers in vllm.py: {e}")
+
+    imports_successful = True
+except ImportError:
+    imports_successful = False


 class VLLM(BaseLLM):
@@ -68,6 +70,11 @@ def __init__(
         Note:
             This method sets up a vLLM engine with specified parameters for efficient inference.
         """
+        if not imports_successful:
+            raise ImportError(
+                "Could not import at least one of the required libraries: torch, transformers, vllm. "
" + "Please ensure they are installed in your environment." + ) super().__init__() self.dtype = dtype diff --git a/pyproject.toml b/pyproject.toml index d8bc054..97571a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,34 +1,36 @@ [tool.poetry] name = "promptolution" version = "1.3.2" -description = "" +description = "A framework for prompt optimization and a zoo of prompt optimization algorithms." authors = ["Tom Zehle, Moritz Schlager, Timo Heiß"] readme = "README.md" [tool.poetry.dependencies] python = "^3.9" numpy = "^1.26.0" -langchain-anthropic = "^0.1.22" -langchain-openai = "^0.1.21" -langchain-core = "^0.2.29" -langchain-community = "^0.2.12" pandas = "^2.2.2" tqdm = "^4.66.5" scikit-learn = "^1.5.2" + +[tool.poetry.group.requests.dependencies] +openai = "^1.0.0" +requests = "^2.31.0" + +[tool.poetry.group.vllm.dependencies] vllm = "^0.7.3" -datasets = "^3.3.2" + +[tool.poetry.group.transformers.dependencies] +transformers = "^4.48.0" [tool.poetry.group.dev.dependencies] matplotlib = "^3.9.2" seaborn = "^0.13.2" -transformers = "^4.48.0" black = "^24.4.2" flake8 = "^7.1.0" isort = "^5.13.2" pre-commit = "^3.7.1" ipykernel = "^6.29.5" - [tool.poetry.group.docs.dependencies] mkdocs = "^1.6.1" mkdocs-material = "^9.5.39" @@ -46,4 +48,4 @@ line_length = 120 profile = "black" [tool.pydocstyle] -convention = "google" +convention = "google" \ No newline at end of file diff --git a/scripts/api_test.py b/scripts/api_test.py new file mode 100644 index 0000000..cfced84 --- /dev/null +++ b/scripts/api_test.py @@ -0,0 +1,70 @@ +"""Test run for the Opro optimizer.""" + +import argparse +import random +from logging import Logger + +from promptolution.callbacks import LoggerCallback +from promptolution.templates import EVOPROMPT_GA_TEMPLATE +from promptolution.tasks import ClassificationTask +from promptolution.predictors import MarkerBasedClassificator +from promptolution.optimizers import EvoPromptGA +from datasets import load_dataset + +from promptolution.llms.api_llm import APILLM + +logger = Logger(__name__) + +"""Run a test run for any of the implemented optimizers.""" +parser = argparse.ArgumentParser() +parser.add_argument("--base-url", default="https://api.openai.com/v1") +parser.add_argument("--model", default="gpt-4o-2024-08-06") +# parser.add_argument("--base-url", default="https://api.deepinfra.com/v1/openai") +# parser.add_argument("--model", default="meta-llama/Meta-Llama-3-8B-Instruct") +# parser.add_argument("--base-url", default="https://api.anthropic.com/v1/") +# parser.add_argument("--model", default="claude-3-haiku-20240307") +parser.add_argument("--n-steps", type=int, default=2) +parser.add_argument("--token", default=None) +args = parser.parse_args() + +df = load_dataset("SetFit/ag_news", split="train", revision="main").to_pandas().sample(300) + +df["input"] = df["text"] +df["target"] = df["label_text"] + +task = ClassificationTask( + df, + description="The dataset contains news articles categorized into four classes: World, Sports, Business, and Tech. The task is to classify each news article into one of the four categories.", + x_column="input", + y_column="target", +) + +initial_prompts = [ + "Classify this news article as World, Sports, Business, or Tech. Provide your answer between and tags.", + "Read the following news article and determine which category it belongs to: World, Sports, Business, or Tech. Your classification must be placed between markers.", + "Your task is to identify whether this news article belongs to World, Sports, Business, or Tech news. 
+    "Conduct a thorough analysis of the provided news article and classify it as belonging to one of these four categories: World, Sports, Business, or Tech. Your answer should be presented within markers.",
+]
+
+llm = APILLM(api_url=args.base_url, model_id=args.model, token=args.token)
+downstream_llm = llm
+meta_llm = llm
+
+predictor = MarkerBasedClassificator(downstream_llm, classes=task.classes)
+
+callbacks = [LoggerCallback(logger)]
+
+optimizer = EvoPromptGA(
+    task=task,
+    prompt_template=EVOPROMPT_GA_TEMPLATE,
+    predictor=predictor,
+    meta_llm=meta_llm,
+    initial_prompts=initial_prompts,
+    callbacks=callbacks,
+    n_eval_samples=20,
+    verbosity=2,  # for debugging
+)
+
+best_prompts = optimizer.optimize(n_steps=args.n_steps)
+
+logger.info(f"Optimized prompts: {best_prompts}")
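For context, a minimal usage sketch of the reworked APILLM introduced in this diff. The endpoint URL, model id, token value, and the public `get_response` wrapper (documented on the old class and presumably provided by `BaseLLM` around `_get_response`) are assumptions for illustration, not part of the change itself:

```python
# Hypothetical usage sketch of the new APILLM (not part of this diff).
# Assumes BaseLLM exposes a public get_response() that delegates to _get_response().
from promptolution.llms.api_llm import APILLM

llm = APILLM(
    api_url="https://api.openai.com/v1",  # any OpenAI-compatible endpoint
    model_id="gpt-4o-2024-08-06",
    token="YOUR_API_KEY",                 # placeholder credential
    max_concurrent_calls=10,              # bounds parallel requests via the semaphore
    max_tokens=256,
)

prompts = ["Classify: stocks rallied after the quarterly earnings report."]
system_prompts = ["You are a helpful classification assistant."]

# Each prompt is paired with a system prompt; requests are dispatched
# concurrently through the AsyncOpenAI client and message contents are returned.
responses = llm.get_response(prompts, system_prompts)
print(responses)
```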