diff --git a/inference/server/oasst_inference_server/plugin_utils.py b/inference/server/oasst_inference_server/plugin_utils.py new file mode 100644 index 0000000000..4e9e3b90a9 --- /dev/null +++ b/inference/server/oasst_inference_server/plugin_utils.py @@ -0,0 +1,55 @@ +import asyncio +import json + +import aiohttp +import yaml +from aiohttp.client_exceptions import ClientConnectorError, ServerTimeoutError +from fastapi import HTTPException +from loguru import logger +from oasst_shared.schemas import inference + + +async def attempt_fetch_plugin(session: aiohttp.ClientSession, url: str, timeout: float = 5.0): + async with session.get(url, timeout=timeout) as response: + content_type = response.headers.get("Content-Type") + + if response.status == 404: + raise HTTPException(status_code=404, detail="Plugin not found") + if response.status != 200: + raise HTTPException(status_code=500, detail="Failed to fetch plugin") + + if "application/json" in content_type or "text/json" in content_type or url.endswith(".json"): + if "text/json" in content_type: + logger.warning(f"Plugin {url} is using text/json as its content type. This is not recommended.") + config = json.loads(await response.text()) + else: + config = await response.json() + elif ( + "application/yaml" in content_type + or "application/x-yaml" in content_type + or url.endswith(".yaml") + or url.endswith(".yml") + ): + config = yaml.safe_load(await response.text()) + else: + raise HTTPException( + status_code=400, + detail=f"Unsupported content type: {content_type}. Only JSON and YAML are supported.", + ) + + return inference.PluginConfig(**config) + + +async def fetch_plugin(url: str, retries: int = 3, timeout: float = 5.0) -> inference.PluginConfig: + async with aiohttp.ClientSession() as session: + for attempt in range(retries): + try: + plugin_config = await attempt_fetch_plugin(session, url, timeout=timeout) + return plugin_config + except (ClientConnectorError, ServerTimeoutError) as e: + if attempt == retries - 1: + raise HTTPException(status_code=500, detail=f"Request failed after {retries} retries: {e}") + await asyncio.sleep(2**attempt) # exponential backoff + except aiohttp.ClientError as e: + raise HTTPException(status_code=500, detail=f"Request failed: {e}") + raise HTTPException(status_code=500, detail="Failed to fetch plugin") diff --git a/inference/server/oasst_inference_server/routes/configs.py b/inference/server/oasst_inference_server/routes/configs.py index d09988c68b..b9318d1a90 100644 --- a/inference/server/oasst_inference_server/routes/configs.py +++ b/inference/server/oasst_inference_server/routes/configs.py @@ -1,13 +1,8 @@ -import asyncio -import json - -import aiohttp import fastapi import pydantic -import yaml -from aiohttp.client_exceptions import ClientConnectorError, ServerTimeoutError from fastapi import HTTPException from loguru import logger +from oasst_inference_server import plugin_utils from oasst_inference_server.settings import settings from oasst_shared import model_configs from oasst_shared.schemas import inference @@ -114,50 +109,6 @@ class ModelConfigInfo(pydantic.BaseModel): ] -async def fetch_plugin(url: str, retries: int = 3, timeout: float = 5.0) -> inference.PluginConfig: - async with aiohttp.ClientSession() as session: - for attempt in range(retries): - try: - async with session.get(url, timeout=timeout) as response: - content_type = response.headers.get("Content-Type") - - if response.status == 200: - if "application/json" in content_type or "text/json" in content_type or 
url.endswith(".json"): - if "text/json" in content_type: - logger.warning( - f"Plugin {url} is using text/json as its content type. This is not recommended." - ) - config = json.loads(await response.text()) - else: - config = await response.json() - elif ( - "application/yaml" in content_type - or "application/x-yaml" in content_type - or url.endswith(".yaml") - or url.endswith(".yml") - ): - config = yaml.safe_load(await response.text()) - else: - raise HTTPException( - status_code=400, - detail=f"Unsupported content type: {content_type}. Only JSON and YAML are supported.", - ) - - return inference.PluginConfig(**config) - elif response.status == 404: - raise HTTPException(status_code=404, detail="Plugin not found") - else: - raise HTTPException(status_code=response.status, detail="Unexpected status code") - except (ClientConnectorError, ServerTimeoutError) as e: - if attempt == retries - 1: # last attempt - raise HTTPException(status_code=500, detail=f"Request failed after {retries} retries: {e}") - await asyncio.sleep(2**attempt) # exponential backoff - - except aiohttp.ClientError as e: - raise HTTPException(status_code=500, detail=f"Request failed: {e}") - raise HTTPException(status_code=500, detail="Failed to fetch plugin") - - @router.get("/model_configs") async def get_model_configs() -> list[ModelConfigInfo]: return [ @@ -173,7 +124,7 @@ async def get_model_configs() -> list[ModelConfigInfo]: @router.post("/plugin_config") async def get_plugin_config(plugin: inference.PluginEntry) -> inference.PluginEntry: try: - plugin_config = await fetch_plugin(plugin.url) + plugin_config = await plugin_utils.fetch_plugin(plugin.url) except HTTPException as e: logger.warning(f"Failed to fetch plugin config from {plugin.url}: {e.detail}") raise fastapi.HTTPException(status_code=e.status_code, detail=e.detail) @@ -187,7 +138,7 @@ async def get_builtin_plugins() -> list[inference.PluginEntry]: for plugin in OA_PLUGINS: try: - plugin_config = await fetch_plugin(plugin.url) + plugin_config = await plugin_utils.fetch_plugin(plugin.url) except HTTPException as e: logger.warning(f"Failed to fetch plugin config from {plugin.url}: {e.detail}") continue diff --git a/inference/worker/basic_hf_server.py b/inference/worker/basic_hf_server.py index 4c3da096df..4e44a0578f 100644 --- a/inference/worker/basic_hf_server.py +++ b/inference/worker/basic_hf_server.py @@ -5,13 +5,13 @@ from queue import Queue import fastapi +import hf_stopping import hf_streamer import interface import torch import transformers import uvicorn from fastapi.middleware.cors import CORSMiddleware -from hf_stopping import SequenceStoppingCriteria from loguru import logger from oasst_shared import model_configs from settings import settings @@ -85,7 +85,9 @@ def print_text(token_id: int): streamer = hf_streamer.HFStreamer(input_ids=ids, printer=print_text) ids = ids.to(model.device) stopping_criteria = ( - transformers.StoppingCriteriaList([SequenceStoppingCriteria(tokenizer, stop_sequences, prompt)]) + transformers.StoppingCriteriaList( + [hf_stopping.SequenceStoppingCriteria(tokenizer, stop_sequences, prompt)] + ) if stop_sequences else None ) diff --git a/inference/worker/chat_chain.py b/inference/worker/chat_chain.py index 3422d3eca6..d64d6cb1ce 100644 --- a/inference/worker/chat_chain.py +++ b/inference/worker/chat_chain.py @@ -25,48 +25,91 @@ from oasst_shared.schemas import inference from settings import settings -# NOTE: Max depth of retries for tool usage +# Max depth of retries for tool usage MAX_DEPTH = 6 -# NOTE: If we want to 
exclude tools description from final prompt, -# to save ctx token space, but it can hurt output quality, especially if -# truncation kicks in! -# I keep switching this on/off, depending on the model used. +# Exclude tools description from final prompt. Saves ctx space but can hurt output +# quality especially if truncation kicks in. Dependent on model used REMOVE_TOOLS_FROM_FINAL_PROMPT = False -current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") - llm = HFInference( inference_server_url=settings.inference_server_url, max_new_tokens=512, stop_sequences=[], top_k=50, temperature=0.20, - # NOTE: It seems to me like it's better without repetition_penalty for - # llama-sft7e3 model seed=43, - repetition_penalty=(1 / 0.92), # works with good with > 0.88 + repetition_penalty=(1 / 0.92), # Best with > 0.88 ) -def populate_memory(memory: ConversationBufferMemory, work_request: inference.WorkRequest) -> None: - for message in work_request.thread.messages[:-1]: - if message.role == "prompter" and message.state == inference.MessageState.manual and message.content: - memory.chat_memory.add_user_message(message.content) - elif message.role == "assistant" and message.state == inference.MessageState.complete and message.content: - memory.chat_memory.add_ai_message(message.content) +class PromptedLLM: + """ + Handles calls to an LLM via LangChain with a prompt template and memory. + """ + + def __init__( + self, + tokenizer: transformers.PreTrainedTokenizer, + worker_config: inference.WorkerConfig, + parameters: interface.GenerateStreamParameters, + prompt_template: PromptTemplate, + memory: ConversationBufferMemory, + tool_names: list[str], + language: str, + action_input_format: str, + ): + self.tokenizer = tokenizer + self.worker_config = worker_config + self.parameters = parameters + self.prompt_template = prompt_template + self.memory = memory + self.tool_names = tool_names + self.language = language + self.action_input_format = action_input_format + self.current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + def call(self, prompt: str) -> tuple[str, str]: + """Prepares and truncates prompt, calls LLM, returns used prompt and response.""" + prompt = prepare_prompt( + prompt, + self.prompt_template, + self.memory, + self.tool_names, + self.current_time, + self.language, + self.tokenizer, + self.worker_config, + self.action_input_format, + ) + + # We do not strip() outputs as it seems to degrade instruction-following abilities of the model + prompt = utils.truncate_prompt(self.tokenizer, self.worker_config, self.parameters, prompt, True) + + response = ( + llm.generate(prompts=[prompt], stop=[ASSISTANT_PREFIX, OBSERVATION_SEQ, f"\n{OBSERVATION_SEQ}"]) + .generations[0][0] + .text + ) + + if response: + response = response.replace("\n\n", "\n") + if response[0] != "\n": + response = f"\n{response}" + + return prompt, response def handle_plugin_usage( input_prompt: str, prompt_template: PromptTemplate, language: str, - tools: list[Tool], memory: ConversationBufferMemory, - plugin: inference.PluginEntry | None, worker_config: inference.WorkerConfig, tokenizer: transformers.PreTrainedTokenizer, parameters: interface.GenerateStreamParameters, + tools: list[Tool], + plugin: inference.PluginEntry | None, ) -> tuple[str, inference.PluginUsed]: execution_details = inference.PluginExecutionDetails( inner_monologue=[], @@ -94,105 +137,45 @@ def handle_plugin_usage( action_input_format = ( JSON_FORMAT_PAYLOAD if prompt_template.template.find("payload") != -1 else 
JSON_FORMAT_NO_PAYLOAD ) + eos_token = tokenizer.eos_token if hasattr(tokenizer, "eos_token") else "" + tool_names = [tool.name for tool in tools] - eos_token = "" - if hasattr(tokenizer, "eos_token"): - eos_token = tokenizer.eos_token - - tools_names = [tool.name for tool in tools] - - init_prompt = f"{input_prompt}{eos_token}{V2_ASST_PREFIX}" - memory, init_prompt = prepare_prompt( - init_prompt, - prompt_template, - memory, - tools_names, - current_time, - language, - tokenizer, - worker_config, - action_input_format, + chain = PromptedLLM( + tokenizer, worker_config, parameters, prompt_template, memory, tool_names, language, action_input_format ) - # NOTE: Do not strip() any of the outputs ever, as it will degrade the - # instruction following performance, at least with - # `OpenAssistant/oasst-sft-6-llama-30b-epoch-1 model` - - init_prompt = utils.truncate_prompt(tokenizer, worker_config, parameters, init_prompt, True) - chain_response = ( - llm.generate(prompts=[init_prompt], stop=[ASSISTANT_PREFIX, OBSERVATION_SEQ, f"\n{OBSERVATION_SEQ}"]) - .generations[0][0] - .text - ) - if chain_response is not None and chain_response != "": - chain_response = chain_response.replace("\n\n", "\n") - if chain_response[0] != "\n": - chain_response = f"\n{chain_response}" + init_prompt = f"{input_prompt}{eos_token}{V2_ASST_PREFIX}" + init_prompt, chain_response = chain.call(init_prompt) inner_monologue.append("In: " + str(init_prompt)) inner_monologue.append("Out: " + str(chain_response)) - # out_1 -> tool name/assistant prefix - # out_2 -> tool input/assistant response - out_1, out_2 = extract_tool_and_input(llm_output=chain_response, ai_prefix=ASSISTANT_PREFIX) - - # whether model decided to use Plugin or not - assisted = False if ASSISTANT_PREFIX in out_1 else True + # Tool name/assistant prefix, Tool input/assistant response + prefix, response = extract_tool_and_input(llm_output=chain_response, ai_prefix=ASSISTANT_PREFIX) + assisted = False if ASSISTANT_PREFIX in prefix else True chain_finished = not assisted - # Check if there is need to go deeper while not chain_finished and assisted and achieved_depth < MAX_DEPTH: - tool_response = use_tool(out_1, out_2, tools) + tool_response = use_tool(prefix, response, tools) - # Save previous chain response, that we will use for the final prompt + # Save previous chain response for use in final prompt prev_chain_response = chain_response - new_prompt = f"{input_prompt}{eos_token}{V2_ASST_PREFIX}{chain_response}{OBSERVATION_SEQ} {tool_response}" - memory, new_prompt = prepare_prompt( - new_prompt, - prompt_template, - memory, - tools_names, - current_time, - language, - tokenizer, - worker_config, - action_input_format, - ) - # NOTE: Do not strip() any of the outputs ever, as it will degrade the - # instruction following performance, at least with - # `OpenAssistant/oasst-sft-6-llama-30b-epoch-1 model` - new_prompt = utils.truncate_prompt(tokenizer, worker_config, parameters, new_prompt, True) - chain_response = ( - llm.generate(prompts=[new_prompt], stop=[ASSISTANT_PREFIX, OBSERVATION_SEQ, f"\n{OBSERVATION_SEQ}"]) - .generations[0][0] - .text - ) - - if chain_response is not None and chain_response != "": - chain_response = chain_response.replace("\n\n", "\n") - if chain_response[0] != "\n": - chain_response = f"\n{chain_response}" + new_prompt, chain_response = chain.call(new_prompt) inner_monologue.append("In: " + str(new_prompt)) inner_monologue.append("Out: " + str(chain_response)) - out_1, out_2 = extract_tool_and_input(llm_output=chain_response, 
ai_prefix=ASSISTANT_PREFIX) - # Did model decided to use Plugin again or not? - assisted = False if ASSISTANT_PREFIX in out_1 else True + prefix, response = extract_tool_and_input(llm_output=chain_response, ai_prefix=ASSISTANT_PREFIX) + assisted = False if ASSISTANT_PREFIX in prefix else True - # NOTE: Check if tool response contains ERROR string, this is something - # that we would like to avoid, but until models are better, we will - # help them with this... - # for now models, sometime decides to retry, when tool usage reports - # error, but sometime it just ignore error... + # Check if tool response contains ERROR string and force retry + # Current models sometimes decide to retry on error but sometimes just ignore if tool_response.find("ERROR") != -1 and assisted is False: chain_response = prev_chain_response assisted = True - # Now LLM is done with using the plugin, - # so we need to generate the final prompt if not assisted: chain_finished = True @@ -201,17 +184,17 @@ def handle_plugin_usage( input_variables = ["input", "chat_history", "language", "current_time"] prompt_template = PromptTemplate(input_variables=input_variables, template=TEMPLATE) - tools_names = None + tool_names = None final_input = ( f"{input_prompt}{eos_token}{V2_ASST_PREFIX}\n{prev_chain_response}{OBSERVATION_SEQ} {tool_response}" ) - memory, inner_prompt = prepare_prompt( + inner_prompt = prepare_prompt( final_input, prompt_template, memory, - tools_names, - current_time, + tool_names, + chain.current_time, language, tokenizer, worker_config, @@ -226,26 +209,24 @@ def handle_plugin_usage( plugin_used.execution_details.final_generation_assisted = True plugin_used.execution_details.achieved_depth = achieved_depth + 1 plugin_used.execution_details.status = "success" - plugin_used.name = getattr(plugin.plugin_config, "name_for_human", None) - plugin_used.trusted = getattr(plugin, "trusted", None) - plugin_used.url = getattr(plugin, "url", None) + plugin_used.name = plugin.plugin_config.name_for_human + plugin_used.trusted = plugin.trusted + plugin_used.url = plugin.url return inner_prompt, plugin_used achieved_depth += 1 - plugin_used.name = getattr(plugin.plugin_config, "name_for_human", None) - plugin_used.trusted = getattr(plugin, "trusted", None) - plugin_used.url = getattr(plugin, "url", None) + plugin_used.name = plugin.plugin_config.name_for_human + plugin_used.trusted = plugin.trusted + plugin_used.url = plugin.url plugin_used.execution_details.inner_monologue = inner_monologue - # bring back ASSISTANT_PREFIX to chain_response, - # that was omitted with stop=[ASSISTANT_PREFIX] + # Re-add ASSISTANT_PREFIX to chain_response, omitted with stop=[ASSISTANT_PREFIX] chain_response = f"{chain_response}{ASSISTANT_PREFIX}: " - # Return non-assisted response if chain_finished: - # Malformed non-assisted LLM output - if not out_2 or out_2 == "": + if not response: + # Malformed non-assisted LLM output plugin_used.execution_details.status = "failure" plugin_used.execution_details.error_message = "Malformed LLM output" return init_prompt, plugin_used @@ -253,7 +234,7 @@ def handle_plugin_usage( plugin_used.execution_details.status = "success" return f"{init_prompt}{THOUGHT_SEQ} I now know the final answer\n{ASSISTANT_PREFIX}: ", plugin_used else: - # Max depth reached, just try to answer without using a tool + # Max depth reached, answer without tool plugin_used.execution_details.final_prompt = init_prompt plugin_used.execution_details.achieved_depth = achieved_depth plugin_used.execution_details.status = "failure" @@ 
-262,6 +243,46 @@ def handle_plugin_usage( return init_prompt, plugin_used +def handle_standard_usage( + original_prompt: str, + prompt_template: PromptTemplate, + language: str, + memory: ConversationBufferMemory, + worker_config: inference.WorkerConfig, + tokenizer: transformers.PreTrainedTokenizer, +): + eos_token = tokenizer.eos_token if hasattr(tokenizer, "eos_token") else "" + current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + # Non-plugin prompt template can include some external data e.g. datetime, language + action_input_format = ( + JSON_FORMAT_PAYLOAD if prompt_template.template.find("payload") != -1 else JSON_FORMAT_NO_PAYLOAD + ) + input = f"{original_prompt}{eos_token}{V2_ASST_PREFIX}" + init_prompt = prepare_prompt( + input, prompt_template, memory, None, current_time, language, tokenizer, worker_config, action_input_format + ) + return init_prompt, None + + +def build_memory(work_request: inference.WorkRequest) -> ConversationBufferMemory: + memory = ConversationBufferMemory( + memory_key="chat_history", + input_key="input", + output_key="output", + ai_prefix=ASSISTANT_PREFIX, + human_prefix=HUMAN_PREFIX, + ) + + for message in work_request.thread.messages[:-1]: + if message.role == "prompter" and message.state == inference.MessageState.manual and message.content: + memory.chat_memory.add_user_message(message.content) + elif message.role == "assistant" and message.state == inference.MessageState.complete and message.content: + memory.chat_memory.add_ai_message(message.content) + + return memory + + def handle_conversation( work_request: inference.WorkRequest, worker_config: inference.WorkerConfig, @@ -274,30 +295,11 @@ def handle_conversation( raise ValueError("Prompt is empty") language = "English" - - # Get one and only one enabled plugin - # TODO: Add support for multiple plugins at once - # maybe... should be explored plugin = next((p for p in parameters.plugins if p.enabled), None) - # Compose tools from plugin, where every endpoint of plugin will become - # one tool, and return prepared prompt with instructions tools_instructions_template, tools = compose_tools_from_plugin(plugin) - plugin_enabled = len(tools) - - eos_token = "" - if hasattr(tokenizer, "eos_token"): - eos_token = tokenizer.eos_token - - memory = ConversationBufferMemory( - memory_key="chat_history", - input_key="input", - output_key="output", - ai_prefix=ASSISTANT_PREFIX, - human_prefix=HUMAN_PREFIX, - ) - - populate_memory(memory, work_request) + plugin_enabled = len(tools) > 0 + memory: ConversationBufferMemory = build_memory(work_request) TEMPLATE = f"""{V2_PROMPTER_PREFIX}{PREFIX}{tools_instructions_template}{SUFFIX}""" input_variables = [ @@ -308,42 +310,20 @@ def handle_conversation( "action_input_format", ] + (["tools_names"] if plugin_enabled else []) - # NOTE: Should we pass language from the UI here? + # TODO: Consider passing language from the UI here prompt_template = PromptTemplate(input_variables=input_variables, template=TEMPLATE) - # Run trough plugin chain. Returns PluginUsed and final prompt - # that will be passed to worker for final completion with LLM - # using sampling settings derived from frontend UI if plugin_enabled: return handle_plugin_usage( - original_prompt, prompt_template, language, tools, memory, plugin, worker_config, tokenizer, parameters + original_prompt, prompt_template, language, memory, worker_config, tokenizer, parameters, tools, plugin ) - # Just regular prompt template without plugin chain. 
- # Here is prompt in format of a template, that includes some - # external/ "realtime" data, such as current date&time and language - # that can be passed from frontend here. - action_input_format = ( - JSON_FORMAT_PAYLOAD if prompt_template.template.find("payload") != -1 else JSON_FORMAT_NO_PAYLOAD - ) - input = f"{original_prompt}{eos_token}{V2_ASST_PREFIX}" - memory, init_prompt = prepare_prompt( - input, prompt_template, memory, None, current_time, language, tokenizer, worker_config, action_input_format - ) - return init_prompt, None - + return handle_standard_usage(original_prompt, prompt_template, language, memory, worker_config, tokenizer) except Exception as e: logger.error(f"Error while handling conversation: {e}") return "", None -# NOTE: Only for local DEV and prompt "engineering" -# some of the plugins that can be used for testing: -# - https://www.klarna.com/.well-known/ai-plugin.json -# - https://nla.zapier.com/.well-known/ai-plugin.json (this one is behind auth) -# - https://chat-calculator-plugin.supportmirage.repl.co/.well-known/ai-plugin.json (error responses seems to be html) -# - https://www.joinmilo.com/.well-known/ai-plugin.json (works quite well, but -# is very simple, so it's not very useful for testing) if __name__ == "__main__": plugin = inference.PluginEntry( enabled=True, diff --git a/inference/worker/chat_chain_utils.py b/inference/worker/chat_chain_utils.py index b10b190c29..a50fc860c5 100644 --- a/inference/worker/chat_chain_utils.py +++ b/inference/worker/chat_chain_utils.py @@ -1,5 +1,6 @@ import json import re +from typing import Callable import requests import transformers @@ -15,8 +16,9 @@ from utils import shared_tokenizer_lock RESPONSE_MAX_LENGTH = 2048 +DESCRIPTION_FOR_MODEL_MAX_LENGTH = 512 -llm_parser = HFInference( +llm_json_parser = HFInference( inference_server_url=settings.inference_server_url, max_new_tokens=512, stop_sequences=[""], @@ -26,12 +28,9 @@ ) -# NOTE: https://en.wikipedia.org/wiki/Jaro%E2%80%93Winkler_distance -# We are using plugin API-s endpoint/paths as tool names, -# e.g.: /get_weather, /get_news etc... so this algo should be fine -# possible improvement: try levenshtein or vector distances -# but best way is to just use better models. +# This algo should be fine but possible improvements could be levenshtein or vector distance def similarity(ts1: str, ts2: str) -> float: + """Compute Jaro-Winkler distance between two strings.""" if ts1 == ts2: return 1 @@ -68,28 +67,29 @@ def similarity(ts1: str, ts2: str) -> float: return (match / len1 + match / len2 + (match - t) / match) / 3.0 -# TODO: Can be improved, like... try to use another pass trough LLM -# with custom tuned prompt for fixing the formatting. -# e.g. "This is malformed text, please fix it: {malformed text} -> FIX magic :)" def extract_tool_and_input(llm_output: str, ai_prefix: str) -> tuple[str, str]: + """ + Extract tool name and tool input from LLM output. If LLM chose not to use a tool, `ai_prefix` is returned instead of tool name, and LLM output is returned instead of tool input. 
+ """ llm_output = llm_output.strip().replace("```", "") if f"{ai_prefix}:" in llm_output: + # No tool used, return LLM prefix and LLM output return ai_prefix, llm_output.split(f"{ai_prefix}:")[-1].strip() + regex = r"Action: (.*?)[\n]*Action Input:\n?(.*)" - # match = re.search(regex, llm_output) # this is for 65B llama :( match = re.search(regex, llm_output, re.MULTILINE | re.DOTALL) if not match: if OBSERVATION_SEQ in llm_output: return ai_prefix, llm_output.split(OBSERVATION_SEQ)[-1].strip() return ai_prefix, llm_output + action = match.group(1) action_input = match.group(2) return action.strip().replace("'", ""), action_input.strip().strip(" ") -# Truncate string, but append matching bracket if string starts with [ or { or ( -# it helps in a way, that LLM will not try to just continue generating output -# continuation +# Truncate, but append closing bracket if string starts with [ or { or ( +# Helps prevent LLM from just generating output continuously def truncate_str(output: str, max_length: int = 1024) -> str: if len(output) > max_length: if output[0] == "(": @@ -103,7 +103,7 @@ def truncate_str(output: str, max_length: int = 1024) -> str: return output -# Parse JSON and try to fix it if it's not valid +# Parse JSON and try to fix it if invalid def prepare_json(json_str: str) -> str: json_str = json_str.strip() fixed_json = json_str @@ -138,7 +138,7 @@ def prepare_json(json_str: str) -> str: json.loads(fixed_json) except json.decoder.JSONDecodeError as e: logger.warning(f"JSON is still not valid, trying to fix it with LLM {fixed_json}") - # if it's still not valid, try with LLM fixer + # If still invalid, try using LLM to fix prompt = f"""{V2_PROMPTER_PREFIX}Below is malformed JSON object string: -------------- {json_str} @@ -153,7 +153,7 @@ def prepare_json(json_str: str) -> str: Here is the fixed JSON object string:{V2_ASST_PREFIX}""" logger.warning(f"JSON Fix Prompt: {prompt}") - out = llm_parser.generate(prompts=[prompt]).generations[0][0].text + out = llm_json_parser.generate(prompts=[prompt]).generations[0][0].text out = out[: out.find("}") + 1] logger.warning(f"JSON Fix Output: {out}") return out @@ -161,20 +161,32 @@ def prepare_json(json_str: str) -> str: return fixed_json -def use_tool(tool_name: str, tool_input: str, tools: list) -> str: - best_match, best_similarity = max( - ((tool, similarity(tool.name, tool_name)) for tool in tools), key=lambda x: x[1], default=(None, 0) +def select_tool(tool_name: str, tools: list[Tool]) -> Tool | None: + tool = next((t for t in tools if t.name in tool_name), None) + if tool: + return tool + tool, tool_similarity = max( + ((t, similarity(t.name, tool_name)) for t in tools), + key=lambda x: x[1], + default=(None, 0), ) - # This should become stricter and stricter as we get better models - if best_match is not None and best_similarity > 0.75: - tool_input = prepare_json(tool_input) - return best_match.func(tool_input) - return f"ERROR! {tool_name} is not a valid tool. Try again with different tool!" + # TODO: make stricter with better models + if tool and tool_similarity > 0.75: + return tool + return None + + +def use_tool(tool_name: str, tool_input: str, tools: list[Tool]) -> str: + tool = select_tool(tool_name, tools) + if not tool: + return f"ERROR! {tool_name} is not a valid tool. Try again with different tool!" 
+ prepared_input = prepare_json(tool_input) + tool_output = tool.func(prepared_input) + return tool_output # Needs more work for errors, error-prompt tweaks are currently based on # `OpenAssistant/oasst-sft-6-llama-30b-epoch-1 model` -# TODO: Add other missing methods and Content-Types etc... class RequestsForLLM: def run(self, params: str, url: str, param_location: str, type: str, payload: str | None = None) -> str: return self.run_request(params, url, param_location, type, payload) @@ -195,7 +207,7 @@ def run_request(self, params: str, url: str, param_location: str, type: str, pay ) res = requests.get(url, params=query_params, headers=headers) elif type.lower() == "post": - # model didn't generated separate payload object, so we just put params as payload and hope for the best... + # if model did not generate payload object, use params as payload data = json.dumps(payload) if payload else json.dumps(params) logger.info( f"Running {type.upper()} request on {url} with\nparams: {params}\nparam_location: {param_location}\npayload: {data}" @@ -205,11 +217,10 @@ def run_request(self, params: str, url: str, param_location: str, type: str, pay return f"ERROR! Unsupported request type: {type}. Only GET and POST are supported. Try again!" return self.process_response(res) - except Exception as e: return f"ERROR! That didn't work, try modifying Action Input.\n{e}. Try again!" - def process_response(self, res): + def process_response(self, res: requests.Response) -> str: logger.info(f"Request response: {res.text}") if res.status_code != 200: return f"ERROR! That didn't work, try modifying Action Input.\n{res.text}. Try again!" @@ -227,15 +238,15 @@ def compose_tools_from_plugin(plugin: inference.PluginEntry | None) -> tuple[str if not plugin: return "", [] - llm_plugin = prepare_plugin_for_llm(plugin.url) + llm_plugin: inference.PluginConfig = prepare_plugin_for_llm(plugin.url) if not llm_plugin: return "", [] tools = [] request_tool = RequestsForLLM() - def create_tool_func(endpoint, param_location): - def func(req): + def create_tool_func(endpoint: inference.PluginOpenAPIEndpoint, param_location: str) -> Callable[..., str]: + def func(req) -> str: try: json_obj = json.loads(req) request = json_obj.get("request", {}) @@ -254,14 +265,9 @@ def func(req): return func - # Generate tool for each endpoint of the plugin - # NOTE: This approach is a bit weird, but it is a good way to help LLM - # to use tools, so LLM does not need to choose api server url - # and paramter locations: query, path, body, etc. on its own. - # LLM will only, choose what endpoint, what parameters and what values - # to use. Modifying this part of the prompt, we can degrade or improve - # performance of tool usage. - for endpoint in llm_plugin["endpoints"]: + # Generate tool for each plugin endpoint. Helps LLM use tools as it does not choose API server URL etc on its own + # LLM only chooses endpoint, parameters and values to use. 
Modifying this can degrade or improve tool usage + for endpoint in llm_plugin.endpoints: params = "\n\n".join( [ f""" name: "{param.name}",\n in: "{param.in_}",\n description: "{truncate_str(param.description, 128)}",\n schema: {param.schema_},\n required: {param.required}""" @@ -269,9 +275,8 @@ def func(req): ] ) - # NOTE: LangChain is using internaly {input_name} for templating - # and some OA/ChatGPT plugins of course, can have {some_word} in theirs - # descriptions + # LangChain uses {input_name} for templating + # Some plugins can have {some_word} in their description params = params.replace("{", "{{").replace("}", "}}") payload_description = "" if endpoint.payload: @@ -293,34 +298,27 @@ def func(req): param_location = endpoint.params[0].in_ if len(endpoint.params) > 0 else "query" - # some plugins do not have operation_id, so we use path as fallback + # If plugin has no operation_id, use path as fallback path = endpoint.path[1:] if endpoint.path and len(endpoint.path) > 0 else endpoint.path tool = Tool( name=endpoint.operation_id if endpoint.operation_id != "" else path, - # Could be path, e.g /api/v1/endpoint - # but it can lead LLM to makeup some URLs - # and problem with EP description is that - # it can be too long for some plugins + # Could be path, e.g /api/v1/endpoint but can lead LLM to invent URLs + # Problem with EP description is that it is too long for some plugins func=create_tool_func(endpoint, param_location), description=f"{openapi_specification_title}{parameters_description}{payload_description}tool description: {endpoint.summary}\n", ) tools.append(tool) tools_string = "\n".join([f"> {tool.name}{tool.description}" for tool in tools]) - # NOTE: This can be super long for some plugins, that I tested so far. - # and because we don't have 32k CTX size, we need to truncate it. - plugin_description_for_model = truncate_str(llm_plugin["description_for_model"], 512) + # This can be long for some plugins, we need to truncate due to ctx limitations + plugin_description_for_model = truncate_str(llm_plugin.description_for_model, DESCRIPTION_FOR_MODEL_MAX_LENGTH) return ( - f"{TOOLS_PREFIX}{tools_string}\n\n{llm_plugin['name_for_model']} plugin description:\n{plugin_description_for_model}\n\n{INSTRUCTIONS}", + f"{TOOLS_PREFIX}{tools_string}\n\n{llm_plugin.name_for_model} plugin description:\n{plugin_description_for_model}\n\n{INSTRUCTIONS}", tools, ) -# TODO: -# here we will not be not truncating per token, but will be deleting messages -# from the history, and we will leave hard truncation to work.py which if -# occurs it will degrade quality of the output. def prepare_prompt( input_prompt: str, prompt_template: PromptTemplate, @@ -331,7 +329,7 @@ def prepare_prompt( tokenizer: transformers.PreTrainedTokenizer, worker_config: inference.WorkerConfig, action_input_format: str, -) -> tuple[ConversationBufferMemory, str]: +) -> str: max_input_length = worker_config.model_config.max_input_length args = { @@ -350,7 +348,7 @@ def prepare_prompt( with shared_tokenizer_lock: ids = tokenizer.encode(out_prompt) - # soft truncation + # soft truncation (delete whole messages) while len(ids) > max_input_length and len(memory.chat_memory.messages) > 0: memory.chat_memory.messages.pop(0) args = { @@ -370,4 +368,4 @@ def prepare_prompt( ids = tokenizer.encode(out_prompt) logger.warning(f"Prompt too long, deleting chat history. 
New length: {len(ids)}") - return memory, out_prompt + return out_prompt diff --git a/inference/worker/hf_langchain_inference.py b/inference/worker/hf_langchain_inference.py index 591e44b743..fa9abfa414 100644 --- a/inference/worker/hf_langchain_inference.py +++ b/inference/worker/hf_langchain_inference.py @@ -41,11 +41,8 @@ def _call(self, prompt: str, stop: list[str] | None = None) -> str: for event in utils.get_inference_server_stream_events(request): stream_response = event - generated_text = stream_response.generated_text - if generated_text is None: - generated_text = "" + generated_text = stream_response.generated_text or "" - # remove stop sequences from the end of the generated text for stop_seq in stop: if stop_seq in generated_text: generated_text = generated_text[: generated_text.index(stop_seq)] diff --git a/inference/worker/openapi_parser.py b/inference/worker/openapi_parser.py index a452084450..2f644b6029 100644 --- a/inference/worker/openapi_parser.py +++ b/inference/worker/openapi_parser.py @@ -33,13 +33,14 @@ def get_plugin_config(url: str) -> inference.PluginConfig | None: response.raise_for_status() plugin_dict = response.json() logger.info(f"Plugin config downloaded {plugin_dict}") - return plugin_dict + plugin_config = inference.PluginConfig.parse_obj(plugin_dict) + return plugin_config except (requests.RequestException, ValueError) as e: logger.warning(f"Error downloading or parsing Plugin config: {e}") return None -def resolve_schema_reference(ref, openapi_dict): +def resolve_schema_reference(ref: str, openapi_dict: dict): if not ref.startswith("#/"): raise ValueError(f"Invalid reference format: {ref}") @@ -53,82 +54,100 @@ def resolve_schema_reference(ref, openapi_dict): return schema -# TODO: Extract endpoints from this function to separate one! -# also get rid of endpoints from PluginConfig class -def prepare_plugin_for_llm(plugin_url: str) -> inference.PluginConfig | None: - plugin_config = get_plugin_config(plugin_url) - if not plugin_config: - return None - - api_dict = plugin_config.get("api") - api_url = api_dict.get("url") if api_dict else None - if not api_url: - return None - # check if url has www or http and if not, add base url + url - # but delete everything from plugin_url after last slash - # if last char is slash, first delete it and then find last slash - parsed_link = urlsplit(plugin_url) - base_of_url = f"{parsed_link.scheme}://{parsed_link.netloc}" - api_url = api_url if api_url.startswith("http") else f"{base_of_url}/{api_url}" - openapi_dict = fetch_openapi_spec(api_url) - - if not openapi_dict: - return None - +def parse_plugin_endpoint( + api_url: str, + method: str, + details: dict, + base_url: str, + path: str, + openapi_dict: dict, +) -> inference.PluginOpenAPIEndpoint: + """ + Parse details of a single plugin endpoint from OpenAPI spec. + + Args: + api_url: URL of the plugin API. + method: HTTP method of the endpoint. + details: Details of the endpoint from OpenAPI spec. + base_url: Base URL of the plugin. + path: Path of the endpoint. + openapi_dict: Full OpenAPI spec of the plugin. 
+ """ + split_result = urlsplit(api_url) + backup_url = f"{split_result.scheme}://{split_result.netloc}" + params_list = [] + parameters = details.get("parameters", []) + if parameters is not None: + for param in parameters: + schema = None + if "$ref" in param["schema"]: + schema = resolve_schema_reference(param["schema"]["$ref"], openapi_dict) + + params_list.append( + inference.PluginOpenAPIParameter( + name=param.get("name", ""), + in_=param.get("in", "query"), + description=param.get("description", ""), + required=param.get("required", False), + schema_=schema, + ) + ) + # Check if the method is POST and extract request body schema + payload = None + if "requestBody" in details: + content = details["requestBody"].get("content", {}) + for media_type, media_schema in content.items(): + if media_type == "application/json": + if "$ref" in media_schema["schema"]: + payload = resolve_schema_reference(media_schema["schema"]["$ref"], openapi_dict) + else: + payload = media_schema["schema"] + + endpoint_data = { + "type": method, + "summary": details.get("summary", ""), + "operation_id": details.get("operationId", ""), + "url": f"{base_url}{path}" if base_url is not None else f"{backup_url}{path}", + "path": path, + "params": params_list, + "payload": payload, + } + + if "tags" in details: + tag_name = details["tags"][0] + endpoint_data["tag"] = tag_name + + endpoint = inference.PluginOpenAPIEndpoint(**endpoint_data) + return endpoint + + +def get_plugin_endpoints(api_url: str, openapi_dict: dict) -> list[inference.PluginOpenAPIEndpoint]: endpoints = [] - base_url = openapi_dict.get("servers", [{}])[0].get("url") - paths = openapi_dict.get("paths", {}) for path, methods in paths.items(): for method, details in methods.items(): - split_result = urlsplit(api_url) - backup_url = f"{split_result.scheme}://{split_result.netloc}" - params_list = [] - parameters = details.get("parameters", []) - if parameters is not None: - for param in parameters: - schema = None - if "$ref" in param["schema"]: - schema = resolve_schema_reference(param["schema"]["$ref"], openapi_dict) - - params_list.append( - inference.PluginOpenAPIParameter( - name=param.get("name", ""), - in_=param.get("in", "query"), - description=param.get("description", ""), - required=param.get("required", False), - schema_=schema, - ) - ) - # Check if the method is POST and extract request body schema - payload = None - if "requestBody" in details: - content = details["requestBody"].get("content", {}) - for media_type, media_schema in content.items(): - if media_type == "application/json": - if "$ref" in media_schema["schema"]: - payload = resolve_schema_reference(media_schema["schema"]["$ref"], openapi_dict) - else: - payload = media_schema["schema"] - - endpoint_data = { - "type": method, - "summary": details.get("summary", ""), - "operation_id": details.get("operationId", ""), - "url": f"{base_url}{path}" if base_url is not None else f"{backup_url}{path}", - "path": path, - "params": params_list, - "payload": payload, - } - - if "tags" in details: - tag_name = details["tags"][0] - endpoint_data["tag"] = tag_name - - endpoint = inference.PluginOpenAPIEndpoint(**endpoint_data) - endpoints.append(endpoint) - - plugin_config["endpoints"] = endpoints - return plugin_config + endpoints.append(parse_plugin_endpoint(api_url, method, details, base_url, path, openapi_dict)) + + return endpoints + + +def prepare_plugin_for_llm(plugin_url: str) -> inference.PluginConfig | None: + plugin_config = get_plugin_config(plugin_url) + if not plugin_config: + 
return None + + try: + api_url = plugin_config.api.url + if not api_url.startswith("http"): + parsed_link = urlsplit(plugin_url) + base_of_url = f"{parsed_link.scheme}://{parsed_link.netloc}" + api_url = f"{base_of_url}/{api_url}" + + openapi_dict = fetch_openapi_spec(api_url) + plugin_config.endpoints = get_plugin_endpoints(api_url, openapi_dict) + return plugin_config + except Exception: + logger.debug(f"Error preparing plugin: {plugin_url}") + return None diff --git a/inference/worker/work.py b/inference/worker/work.py index 322508fcf7..81cc02e80c 100644 --- a/inference/worker/work.py +++ b/inference/worker/work.py @@ -29,9 +29,7 @@ def make_prompt_and_parameters( if settings.oa_protocol_version != "v2": raise RuntimeError(f"Unsupported oa protocol version: {settings.oa_protocol_version}") - eos_token = "" - if hasattr(tokenizer, "eos_token"): - eos_token = tokenizer.eos_token + eos_token = tokenizer.eos_token if hasattr(tokenizer, "eos_token") else "" def _prepare_message(message: inference.MessageRead) -> str: prefix = V2_ASST_PREFIX if message.is_assistant else V2_PROMPTER_PREFIX @@ -84,18 +82,15 @@ def handle_work_request( prompt = "" used_plugin = None - # Check if any plugin is enabled, if so, use it. for plugin in parameters.plugins: if plugin.enabled: prompt, used_plugin = chat_chain.handle_conversation(work_request, worker_config, parameters, tokenizer) - # When using plugins, and final prompt being truncated due to the input - # length limit, LLaMA llm has tendency to leak internal promptings, - # and generate undesirable continuations, so here we will be adding - # some plugin keywords/sequences to the stop sequences to try preventing it - parameters.stop.extend([END_SEQ, START_SEQ, THOUGHT_SEQ, ASSISTANT_PREFIX]) + # When using plugins and final prompt is truncated due to length limit + # LLaMA has tendency to leak internal prompts and generate bad continuations + # So we add keywords/sequences to the stop sequences to reduce this + parameters.stop.extend([END_SEQ, START_SEQ, THOUGHT_SEQ, f"{ASSISTANT_PREFIX}:"]) break - # If no plugin was "used", use the default prompt generation. if not used_plugin: prompt, parameters = make_prompt_and_parameters(tokenizer=tokenizer, work_request=work_request) @@ -103,7 +98,6 @@ def handle_work_request( model_config = worker_config.model_config - # Only send safety request if work request safety level is not 0 if settings.enable_safety and work_request.safety_parameters.level: safety_request = inference.SafetyRequest(inputs=prompt, parameters=work_request.safety_parameters) safety_response = get_safety_server_response(safety_request) @@ -181,8 +175,7 @@ def handle_work_request( if model_config.is_llama: stream_response.generated_text = stream_response.generated_text.strip() - # NOTE: This is only help for RLHF models using plugin prompts... - # Get the generated text up to the first occurrence of any of: + # Helps with RLHF models using plugin prompts. 
Get generated text to first occurrence of: # START_SEQ, END_SEQ, ASSISTANT_PREFIX, THOUGHT_SEQ, OBSERVATION_SEQ end_seq_index = min( [ diff --git a/oasst-shared/oasst_shared/schemas/inference.py b/oasst-shared/oasst_shared/schemas/inference.py index 4b051b9a13..e3551be7bf 100644 --- a/oasst-shared/oasst_shared/schemas/inference.py +++ b/oasst-shared/oasst_shared/schemas/inference.py @@ -3,7 +3,7 @@ import random import uuid from datetime import datetime -from typing import Annotated, Any, Literal, Union +from typing import Annotated, Literal, Union import psutil import pydantic @@ -161,19 +161,12 @@ class PluginConfig(pydantic.BaseModel): legal_info_url: str | None = None endpoints: list[PluginOpenAPIEndpoint] | None = None - def __getitem__(self, key: str) -> Any: - return getattr(self, key) - - def __setitem__(self, key: str, value: Any) -> None: - setattr(self, key, value) - class PluginEntry(pydantic.BaseModel): url: str enabled: bool = True plugin_config: PluginConfig | None = None - # NOTE: Idea, is to have OA internal plugins as trusted, - # and all other plugins as untrusted by default(until proven otherwise) + # Idea is for OA internal plugins to be trusted, others untrusted by default trusted: bool | None = False
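
A minimal usage sketch of the extracted `plugin_utils.fetch_plugin` helper, assuming a hypothetical manifest URL; the call signature and exception behaviour follow the code added in this diff:

```python
import asyncio

from fastapi import HTTPException
from oasst_inference_server import plugin_utils


async def main() -> None:
    # Hypothetical plugin manifest URL, for illustration only.
    url = "https://example.com/.well-known/ai-plugin.json"
    try:
        # retries/timeout mirror the defaults introduced in plugin_utils.fetch_plugin.
        plugin_config = await plugin_utils.fetch_plugin(url, retries=3, timeout=5.0)
        print(plugin_config.name_for_human)
    except HTTPException as e:
        # fetch_plugin surfaces network and parsing failures as HTTPException.
        print(f"Fetch failed ({e.status_code}): {e.detail}")


if __name__ == "__main__":
    asyncio.run(main())
```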