diff --git a/promptolution/llms/api_llm.py b/promptolution/llms/api_llm.py
index d00bc91..e1286f6 100644
--- a/promptolution/llms/api_llm.py
+++ b/promptolution/llms/api_llm.py
@@ -1,144 +1,90 @@
"""Module to interface with various language models through their respective APIs."""
-import asyncio
-import time
-from logging import Logger
-from typing import Any, List
-import nest_asyncio
-import openai
-import requests
-from langchain_anthropic import ChatAnthropic
-from langchain_community.chat_models.deepinfra import ChatDeepInfra, ChatDeepInfraException
-from langchain_core.messages import HumanMessage, SystemMessage
-from langchain_openai import ChatOpenAI
+try:
+ import asyncio
-from promptolution.llms.base_llm import BaseLLM
+ from openai import AsyncOpenAI
-logger = Logger(__name__)
+ import_successful = True
+except ImportError:
+ import_successful = False
+from logging import Logger
+from typing import Any, List
-async def invoke_model(prompt, system_prompt, model, semaphore):
- """Asynchronously invoke a language model with retry logic.
+from promptolution.llms.base_llm import BaseLLM
- Args:
- prompt (str): The input prompt for the model.
- system_prompt (str): The system prompt for the model.
- model: The language model to invoke.
- semaphore (asyncio.Semaphore): Semaphore to limit concurrent calls.
+logger = Logger(__name__)
- Returns:
- str: The model's response content.
- Raises:
- ChatDeepInfraException: If all retry attempts fail.
- """
+async def _invoke_model(prompt, system_prompt, max_tokens, model_id, client, semaphore):
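+    """Send one chat completion request, limiting concurrency with the shared semaphore."""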
async with semaphore:
- max_retries = 100
- delay = 3
- attempts = 0
-
- while attempts < max_retries:
- try:
- response = await model.ainvoke([SystemMessage(content=system_prompt), HumanMessage(content=prompt)])
- return response.content
- except ChatDeepInfraException as e:
- print(f"DeepInfra error: {e}. Attempt {attempts}/{max_retries}. Retrying in {delay} seconds...")
- attempts += 1
- await asyncio.sleep(delay)
+ messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": prompt}]
+ response = await client.chat.completions.create(
+ model=model_id,
+ messages=messages,
+ max_tokens=max_tokens,
+ )
+ return response
class APILLM(BaseLLM):
- """A class to interface with various language models through their respective APIs.
+ """A class to interface with language models through their respective APIs.
- This class supports Claude (Anthropic), GPT (OpenAI), and LLaMA (DeepInfra) models.
- It handles API key management, model initialization, and provides methods for
- both synchronous and asynchronous inference.
+ This class provides a unified interface for making API calls to language models
+    using the OpenAI client library. It limits the number of concurrent requests
+    with an asyncio semaphore and supports both synchronous and asynchronous use.
Attributes:
- model: The initialized language model instance.
-
- Methods:
- get_response: Synchronously get responses for a list of prompts.
- get_response_async: Asynchronously get responses for a list of prompts.
+ model_id (str): Identifier for the model to use.
+ client (AsyncOpenAI): The initialized API client.
+ max_tokens (int): Maximum number of tokens in model responses.
+ semaphore (asyncio.Semaphore): Semaphore to limit concurrent API calls.
"""
- def __init__(self, model_id: str, token: str = None, **kwargs: Any):
- """Initialize the APILLM with a specific model.
+ def __init__(
+ self, api_url: str, model_id: str, token: str = None, max_concurrent_calls=50, max_tokens=512, **kwargs: Any
+ ):
+ """Initialize the APILLM with a specific model and API configuration.
Args:
+ api_url (str): The base URL for the API endpoint.
model_id (str): Identifier for the model to use.
- token (str): API key for the model.
+ token (str, optional): API key for authentication. Defaults to None.
+ max_concurrent_calls (int, optional): Maximum number of concurrent API calls. Defaults to 50.
+ max_tokens (int, optional): Maximum number of tokens in model responses. Defaults to 512.
+ **kwargs (Any): Additional parameters to pass to the API client.
Raises:
- ValueError: If an unknown model identifier is provided.
+ ImportError: If required libraries are not installed.
"""
+ if not import_successful:
+ raise ImportError(
+ "Could not import at least one of the required libraries: openai, asyncio. "
+ "Please ensure they are installed in your environment."
+ )
super().__init__()
- if "claude" in model_id:
- self.model = ChatAnthropic(model=model_id, api_key=token)
- elif "gpt" in model_id:
- self.model = ChatOpenAI(model=model_id, api_key=token)
- else:
- self.model = ChatDeepInfra(model_name=model_id, deepinfra_api_token=token)
-
- def _get_response(self, prompts: List[str], system_prompts: List[str] = None) -> List[str]:
- """Get responses for a list of prompts in a synchronous manner.
+ self.model_id = model_id
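+        # The AsyncOpenAI client can target any OpenAI-compatible endpoint via base_url.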
+ self.client = AsyncOpenAI(base_url=api_url, api_key=token, **kwargs)
+ self.max_tokens = max_tokens
- This method includes retry logic for handling connection errors and rate limits.
+ self.semaphore = asyncio.Semaphore(max_concurrent_calls)
- Args:
- prompts (list[str]): List of input prompts.
- system_prompts (list[str]): List of system prompts. If not provided, uses default system_prompts
-
- Returns:
- list[str]: List of model responses.
-
- Raises:
- requests.exceptions.ConnectionError: If max retries are exceeded.
- """
- max_retries = 100
- delay = 3
- attempts = 0
-
- nest_asyncio.apply()
-
- while attempts < max_retries:
- try:
- responses = asyncio.run(self.get_response_async(prompts))
- return responses
- except requests.exceptions.ConnectionError as e:
- attempts += 1
- logger.critical(
- f"Connection error: {e}. Attempt {attempts}/{max_retries}. Retrying in {delay} seconds..."
- )
- time.sleep(delay)
- except openai.RateLimitError as e:
- attempts += 1
- logger.critical(
- f"Rate limit error: {e}. Attempt {attempts}/{max_retries}. Retrying in {delay} seconds..."
- )
- time.sleep(delay)
-
- # If the loop exits, it means max retries were reached
- raise requests.exceptions.ConnectionError("Max retries exceeded. Connection could not be established.")
-
- async def get_response_async(self, prompts: list[str], max_concurrent_calls=200) -> list[str]:
- """Asynchronously get responses for a list of prompts.
-
- This method uses a semaphore to limit the number of concurrent API calls.
-
- Args:
- prompts (list[str]): List of input prompts.
- max_concurrent_calls (int): Maximum number of concurrent API calls allowed.
-
- Returns:
- list[str]: List of model responses.
- """
- semaphore = asyncio.Semaphore(max_concurrent_calls)
- tasks = []
-
- for prompt in prompts:
- tasks.append(invoke_model(prompt, self.model, semaphore))
+ def _get_response(self, prompts: List[str], system_prompts: List[str]) -> List[str]:
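+        """Get responses synchronously by running the async implementation to completion."""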
+ # Setup for async execution in sync context
+ loop = asyncio.get_event_loop()
+ responses = loop.run_until_complete(self._get_response_async(prompts, system_prompts))
+ return responses
+ async def _get_response_async(self, prompts: List[str], system_prompts: List[str]) -> List[str]:
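+        """Asynchronously gather one API response per prompt/system prompt pair."""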
+ tasks = [
+ _invoke_model(prompt, system_prompt, self.max_tokens, self.model_id, self.client, self.semaphore)
+ for prompt, system_prompt in zip(prompts, system_prompts)
+ ]
responses = await asyncio.gather(*tasks)
- return responses
+ return [response.choices[0].message.content for response in responses]
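For reference, a minimal usage sketch of the new `APILLM` (illustrative only, not part of the patch): it assumes `BaseLLM` exposes a public `get_response` wrapper around `_get_response`, as the replaced docstring describes, and it reuses the endpoint and model values from `scripts/api_test.py` below.

```python
from promptolution.llms.api_llm import APILLM

# Endpoint and model values mirror scripts/api_test.py; the token is a placeholder.
llm = APILLM(
    api_url="https://api.openai.com/v1",
    model_id="gpt-4o-2024-08-06",
    token="YOUR_API_KEY",
    max_concurrent_calls=50,
    max_tokens=512,
)

prompts = ["Classify this news article: 'Stocks rallied after the earnings report.'"]
system_prompts = ["Answer with exactly one of: World, Sports, Business, Tech."]

# get_response is assumed to be the public wrapper that BaseLLM defines around _get_response.
responses = llm.get_response(prompts, system_prompts)
print(responses[0])
```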
diff --git a/promptolution/llms/base_llm.py b/promptolution/llms/base_llm.py
index 1a79d29..b9be7fb 100644
--- a/promptolution/llms/base_llm.py
+++ b/promptolution/llms/base_llm.py
@@ -91,7 +91,7 @@ def set_generation_seed(self, seed: int):
pass
@abstractmethod
- def _get_response(self, prompts: List[str], system_prompts: List[str] = None) -> List[str]:
+ def _get_response(self, prompts: List[str], system_prompts: List[str]) -> List[str]:
"""Generate responses for the given prompts.
This method should be implemented by subclasses to define how
diff --git a/promptolution/llms/local_llm.py b/promptolution/llms/local_llm.py
index 46afe17..c735f3c 100644
--- a/promptolution/llms/local_llm.py
+++ b/promptolution/llms/local_llm.py
@@ -2,11 +2,10 @@
try:
import torch
import transformers
-except ImportError as e:
- import logging
- logger = logging.getLogger(__name__)
- logger.warning(f"Could not import torch or transformers in local_llm.py: {e}")
+ imports_successful = True
+except ImportError:
+ imports_successful = False
from promptolution.llms.base_llm import BaseLLM
@@ -35,6 +34,11 @@ def __init__(self, model_id: str, batch_size=8):
This method sets up a text generation pipeline with bfloat16 precision,
automatic device mapping, and specific generation parameters.
"""
+ if not imports_successful:
+ raise ImportError(
+ "Could not import at least one of the required libraries: torch, transformers. "
+ "Please ensure they are installed in your environment."
+ )
super().__init__()
self.pipeline = transformers.pipeline(
@@ -78,8 +82,5 @@ def _get_response(self, prompts: list[str], system_prompts: list[str]) -> list[s
def __del__(self):
"""Cleanup method to delete the pipeline and free up GPU memory."""
- try:
- del self.pipeline
- torch.cuda.empty_cache()
- except Exception as e:
- logger.warning(f"Error during LocalLLM cleanup: {e}")
+ del self.pipeline
+ torch.cuda.empty_cache()
diff --git a/promptolution/llms/vllm.py b/promptolution/llms/vllm.py
index ec6505e..824ef08 100644
--- a/promptolution/llms/vllm.py
+++ b/promptolution/llms/vllm.py
@@ -12,8 +12,10 @@
import torch
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
-except ImportError as e:
- logger.warning(f"Could not import vllm, torch or transformers in vllm.py: {e}")
+
+ imports_successful = True
+except ImportError:
+ imports_successful = False
class VLLM(BaseLLM):
@@ -68,6 +70,11 @@ def __init__(
Note:
This method sets up a vLLM engine with specified parameters for efficient inference.
"""
+ if not imports_successful:
+ raise ImportError(
+ "Could not import at least one of the required libraries: torch, transformers, vllm. "
+ "Please ensure they are installed in your environment."
+ )
super().__init__()
self.dtype = dtype
diff --git a/pyproject.toml b/pyproject.toml
index d8bc054..97571a4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,34 +1,36 @@
[tool.poetry]
name = "promptolution"
version = "1.3.2"
-description = ""
+description = "A framework for prompt optimization and a zoo of prompt optimization algorithms."
authors = ["Tom Zehle, Moritz Schlager, Timo Heiß"]
readme = "README.md"
[tool.poetry.dependencies]
python = "^3.9"
numpy = "^1.26.0"
-langchain-anthropic = "^0.1.22"
-langchain-openai = "^0.1.21"
-langchain-core = "^0.2.29"
-langchain-community = "^0.2.12"
pandas = "^2.2.2"
tqdm = "^4.66.5"
scikit-learn = "^1.5.2"
+
+[tool.poetry.group.requests.dependencies]
+openai = "^1.0.0"
+requests = "^2.31.0"
+
+[tool.poetry.group.vllm.dependencies]
vllm = "^0.7.3"
-datasets = "^3.3.2"
+
+[tool.poetry.group.transformers.dependencies]
+transformers = "^4.48.0"
[tool.poetry.group.dev.dependencies]
matplotlib = "^3.9.2"
seaborn = "^0.13.2"
-transformers = "^4.48.0"
black = "^24.4.2"
flake8 = "^7.1.0"
isort = "^5.13.2"
pre-commit = "^3.7.1"
ipykernel = "^6.29.5"
-
[tool.poetry.group.docs.dependencies]
mkdocs = "^1.6.1"
mkdocs-material = "^9.5.39"
@@ -46,4 +48,4 @@ line_length = 120
profile = "black"
[tool.pydocstyle]
-convention = "google"
+convention = "google"
\ No newline at end of file
diff --git a/scripts/api_test.py b/scripts/api_test.py
new file mode 100644
index 0000000..cfced84
--- /dev/null
+++ b/scripts/api_test.py
@@ -0,0 +1,69 @@
+"""Test run for the Opro optimizer."""
+
+import argparse
+import logging
+
+from datasets import load_dataset
+
+from promptolution.callbacks import LoggerCallback
+from promptolution.llms.api_llm import APILLM
+from promptolution.optimizers import EvoPromptGA
+from promptolution.predictors import MarkerBasedClassificator
+from promptolution.tasks import ClassificationTask
+from promptolution.templates import EVOPROMPT_GA_TEMPLATE
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--base-url", default="https://api.openai.com/v1")
+parser.add_argument("--model", default="gpt-4o-2024-08-06")
+# parser.add_argument("--base-url", default="https://api.deepinfra.com/v1/openai")
+# parser.add_argument("--model", default="meta-llama/Meta-Llama-3-8B-Instruct")
+# parser.add_argument("--base-url", default="https://api.anthropic.com/v1/")
+# parser.add_argument("--model", default="claude-3-haiku-20240307")
+parser.add_argument("--n-steps", type=int, default=2)
+parser.add_argument("--token", default=None)
+args = parser.parse_args()
+
+df = load_dataset("SetFit/ag_news", split="train", revision="main").to_pandas().sample(300)
+
+df["input"] = df["text"]
+df["target"] = df["label_text"]
+
+task = ClassificationTask(
+ df,
+ description="The dataset contains news articles categorized into four classes: World, Sports, Business, and Tech. The task is to classify each news article into one of the four categories.",
+ x_column="input",
+ y_column="target",
+)
+
+initial_prompts = [
+ "Classify this news article as World, Sports, Business, or Tech. Provide your answer between and tags.",
+ "Read the following news article and determine which category it belongs to: World, Sports, Business, or Tech. Your classification must be placed between markers.",
+ "Your task is to identify whether this news article belongs to World, Sports, Business, or Tech news. Provide your classification between the markers .",
+ "Conduct a thorough analysis of the provided news article and classify it as belonging to one of these four categories: World, Sports, Business, or Tech. Your answer should be presented within markers.",
+]
+
+llm = APILLM(api_url=args.base_url, model_id=args.model, token=args.token)
+downstream_llm = llm
+meta_llm = llm
+
+predictor = MarkerBasedClassificator(downstream_llm, classes=task.classes)
+
+callbacks = [LoggerCallback(logger)]
+
+optimizer = EvoPromptGA(
+ task=task,
+ prompt_template=EVOPROMPT_GA_TEMPLATE,
+ predictor=predictor,
+ meta_llm=meta_llm,
+ initial_prompts=initial_prompts,
+ callbacks=callbacks,
+ n_eval_samples=20,
+ verbosity=2, # for debugging
+)
+
+best_prompts = optimizer.optimize(n_steps=args.n_steps)
+
+logger.info(f"Optimized prompts: {best_prompts}")