diff --git a/keybert/llm/_litellm.py b/keybert/llm/_litellm.py index f0e55469..5a1fe1c8 100644 --- a/keybert/llm/_litellm.py +++ b/keybert/llm/_litellm.py @@ -32,6 +32,9 @@ class LiteLLM(BaseLLM): `self.default_prompt_` is used instead. NOTE: Use `"[DOCUMENT]"` in the prompt to decide where the document needs to be inserted + system_prompt: The message that sets the behavior of the assistant. + It's typically used to provide high-level instructions + for the conversation. delay_in_seconds: The delay in seconds between consecutive prompts in order to prevent RateLimitErrors. verbose: Set this to True if you want to see a progress bar for the @@ -68,6 +71,7 @@ class LiteLLM(BaseLLM): def __init__(self, model: str = "gpt-3.5-turbo", prompt: str = None, + system_prompt: str = "You are a helpful assistant.", generator_kwargs: Mapping[str, Any] = {}, delay_in_seconds: float = None, verbose: bool = False @@ -79,6 +83,7 @@ def __init__(self, else: self.prompt = prompt + self.system_prompt = system_prompt self.default_prompt_ = DEFAULT_PROMPT self.delay_in_seconds = delay_in_seconds self.verbose = verbose @@ -116,7 +121,7 @@ def extract_keywords(self, documents: List[str], candidate_keywords: List[List[s # Use a chat model messages = [ - {"role": "system", "content": "You are a helpful assistant."}, + {"role": "system", "content": self.system_prompt}, {"role": "user", "content": prompt} ] kwargs = {"model": self.model, "messages": messages, **self.generator_kwargs} diff --git a/keybert/llm/_openai.py b/keybert/llm/_openai.py index 5c8c078c..9514b5e2 100644 --- a/keybert/llm/_openai.py +++ b/keybert/llm/_openai.py @@ -60,6 +60,9 @@ class OpenAI(BaseLLM): `self.default_prompt_` is used instead. NOTE: Use `"[DOCUMENT]"` in the prompt to decide where the document needs to be inserted + system_prompt: The message that sets the behavior of the assistant. + It's typically used to provide high-level instructions + for the conversation. delay_in_seconds: The delay in seconds between consecutive prompts in order to prevent RateLimitErrors. exponential_backoff: Retry requests with a random exponential backoff. @@ -114,6 +117,7 @@ def __init__(self, client, model: str = "gpt-3.5-turbo-instruct", prompt: str = None, + system_prompt: str = "You are a helpful assistant.", generator_kwargs: Mapping[str, Any] = {}, delay_in_seconds: float = None, exponential_backoff: bool = False, @@ -128,6 +132,7 @@ def __init__(self, else: self.prompt = prompt + self.system_prompt = system_prompt self.default_prompt_ = DEFAULT_CHAT_PROMPT if chat else DEFAULT_PROMPT self.delay_in_seconds = delay_in_seconds self.exponential_backoff = exponential_backoff @@ -170,7 +175,7 @@ def extract_keywords(self, documents: List[str], candidate_keywords: List[List[s # Use a chat model if self.chat: messages = [ - {"role": "system", "content": "You are a helpful assistant."}, + {"role": "system", "content": self.system_prompt}, {"role": "user", "content": prompt} ] kwargs = {"model": self.model, "messages": messages, **self.generator_kwargs}