Skip to content

Commit

Permalink
Support ChatML prompt format in worker (#3668)
Browse files Browse the repository at this point in the history
  • Loading branch information
olliestanley committed Aug 28, 2023
1 parent 8c0e1a3 commit d613c81
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 42 deletions.
31 changes: 19 additions & 12 deletions inference/worker/chat_chain.py
Expand Up @@ -14,8 +14,6 @@
PREFIX,
SUFFIX,
THOUGHT_SEQ,
V2_ASST_PREFIX,
V2_PROMPTER_PREFIX,
)
from chat_chain_utils import compose_tools_from_plugin, extract_tool_and_input, prepare_prompt, use_tool
from hf_langchain_inference import HFInference
Expand All @@ -26,6 +24,7 @@
from oasst_shared.model_configs import ModelConfig
from oasst_shared.schemas import inference
from settings import settings
from utils import special_tokens

# Exclude tools description from final prompt. Saves ctx space but can hurt output
# quality especially if truncation kicks in. Dependent on model used
Expand Down Expand Up @@ -143,7 +142,11 @@ def handle_plugin_usage(
action_input_format = (
JSON_FORMAT_PAYLOAD if prompt_template.template.find("payload") != -1 else JSON_FORMAT_NO_PAYLOAD
)
eos_token = tokenizer.eos_token if hasattr(tokenizer, "eos_token") else ""
eos_token = ""
if special_tokens["end"]:
eos_token = special_tokens["end"]
elif hasattr(tokenizer, "eos_token"):
eos_token = tokenizer.eos_token
tool_names = [tool.name for tool in tools]

chain = PromptedLLM(
Expand All @@ -170,7 +173,7 @@ def handle_plugin_usage(
),
)

init_prompt = f"{input_prompt}{eos_token}{V2_ASST_PREFIX}"
init_prompt = f"{input_prompt}{eos_token}{special_tokens['assistant']}"
init_prompt, chain_response = chain.call(init_prompt)

inner_monologue.append("In: " + str(init_prompt))
Expand Down Expand Up @@ -203,7 +206,9 @@ def handle_plugin_usage(

# Save previous chain response for use in final prompt
prev_chain_response = chain_response
new_prompt = f"{input_prompt}{eos_token}{V2_ASST_PREFIX}{chain_response}{OBSERVATION_SEQ} {tool_response}"
new_prompt = (
f"{input_prompt}{eos_token}{special_tokens['assistant']}{chain_response}{OBSERVATION_SEQ} {tool_response}"
)

new_prompt, chain_response = chain.call(new_prompt)

Expand Down Expand Up @@ -239,15 +244,13 @@ def handle_plugin_usage(
chain_finished = True

if REMOVE_TOOLS_FROM_FINAL_PROMPT:
TEMPLATE = f"""{V2_PROMPTER_PREFIX}{PREFIX}{SUFFIX}"""
TEMPLATE = f"""{special_tokens['prompter']}{PREFIX}{SUFFIX}"""
input_variables = ["input", "chat_history", "language", "current_time"]

prompt_template = PromptTemplate(input_variables=input_variables, template=TEMPLATE)
tool_names = None

final_input = (
f"{input_prompt}{eos_token}{V2_ASST_PREFIX}\n{prev_chain_response}{OBSERVATION_SEQ} {tool_response}"
)
final_input = f"{input_prompt}{eos_token}{special_tokens['assistant']}\n{prev_chain_response}{OBSERVATION_SEQ} {tool_response}"
inner_prompt = prepare_prompt(
final_input,
prompt_template,
Expand Down Expand Up @@ -312,14 +315,18 @@ def handle_standard_usage(
tokenizer: transformers.PreTrainedTokenizer,
custom_instructions: str = "",
):
eos_token = tokenizer.eos_token if hasattr(tokenizer, "eos_token") else ""
eos_token = ""
if special_tokens["end"]:
eos_token = special_tokens["end"]
elif hasattr(tokenizer, "eos_token"):
eos_token = tokenizer.eos_token
current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

# Non-plugin prompt template can include some external data e.g. datetime, language
action_input_format = (
JSON_FORMAT_PAYLOAD if prompt_template.template.find("payload") != -1 else JSON_FORMAT_NO_PAYLOAD
)
input = f"{original_prompt}{eos_token}{V2_ASST_PREFIX}"
input = f"{original_prompt}{eos_token}{special_tokens['assistant']}"
init_prompt = prepare_prompt(
input,
prompt_template,
Expand Down Expand Up @@ -372,7 +379,7 @@ def handle_conversation(
plugin_enabled = len(tools) > 0
memory: ConversationBufferMemory = build_memory(work_request)

TEMPLATE = f"""{V2_PROMPTER_PREFIX}{PREFIX}{tools_instructions_template}{SUFFIX}"""
TEMPLATE = f"""{special_tokens['prompter']}{PREFIX}{tools_instructions_template}{SUFFIX}"""
input_variables = [
"input",
"chat_history",
Expand Down
4 changes: 0 additions & 4 deletions inference/worker/chat_chain_prompts.py
@@ -1,7 +1,3 @@
V2_ASST_PREFIX = "<|assistant|>"
V2_PROMPTER_PREFIX = "<|prompter|>"
V2_SYSTEM_PREFIX = "<|system|>"

ASSISTANT_PREFIX = "Open Assistant"
HUMAN_PREFIX = "Human"
OBSERVATION_SEQ = "Observation:"
Expand Down
10 changes: 5 additions & 5 deletions inference/worker/chat_chain_utils.py
Expand Up @@ -4,7 +4,7 @@

import requests
import transformers
from chat_chain_prompts import INSTRUCTIONS, OBSERVATION_SEQ, TOOLS_PREFIX, V2_ASST_PREFIX, V2_PROMPTER_PREFIX
from chat_chain_prompts import INSTRUCTIONS, OBSERVATION_SEQ, TOOLS_PREFIX
from hf_langchain_inference import HFInference
from langchain.agents import Tool
from langchain.memory import ConversationBufferMemory
Expand All @@ -13,15 +13,15 @@
from oasst_shared.schemas import inference
from openapi_parser import prepare_plugin_for_llm
from settings import settings
from utils import shared_tokenizer_lock
from utils import shared_tokenizer_lock, special_tokens

RESPONSE_MAX_LENGTH = 2048
DESCRIPTION_FOR_MODEL_MAX_LENGTH = 512

llm_json_parser = HFInference(
inference_server_url=settings.inference_server_url,
max_new_tokens=512,
stop_sequences=["</s>"],
stop_sequences=[special_tokens["end"] if special_tokens["end"] else "</s>"],
top_k=5,
temperature=0.20,
repetition_penalty=(1 / 0.83),
Expand Down Expand Up @@ -139,7 +139,7 @@ def prepare_json(json_str: str) -> str:
except json.decoder.JSONDecodeError as e:
logger.warning(f"JSON is still not valid, trying to fix it with LLM {fixed_json}")
# If still invalid, try using LLM to fix
prompt = f"""{V2_PROMPTER_PREFIX}Below is malformed JSON object string:
prompt = f"""{special_tokens['prompter']}Below is malformed JSON object string:
--------------
{json_str}
--------------
Expand All @@ -151,7 +151,7 @@ def prepare_json(json_str: str) -> str:
1. If malformed JSON object string contains multiple objects, you merge them into one.
2. You will never made up or add any new data, you will only fix the malformed JSON object string.
Here is the fixed JSON object string:</s>{V2_ASST_PREFIX}"""
Here is the fixed JSON object string:{special_tokens['end'] or '</s>'}{special_tokens['assistant']}"""
logger.warning(f"JSON Fix Prompt: {prompt}")
out = llm_json_parser.generate(prompts=[prompt]).generations[0][0].text
out = out[: out.find("}") + 1]
Expand Down
3 changes: 3 additions & 0 deletions inference/worker/settings.py
Expand Up @@ -11,6 +11,9 @@ class Settings(pydantic.BaseSettings):

oa_protocol_version: str = "v2"

# Supported: oasst, chatml
model_prompt_format: str = "oasst"

retry_on_error: bool = True
hf_pause: float = 0.075
max_parallel_requests: int = 1
Expand Down
39 changes: 31 additions & 8 deletions inference/worker/utils.py
Expand Up @@ -11,14 +11,29 @@
import sseclient
import transformers
import websocket
from chat_chain_prompts import V2_PROMPTER_PREFIX, V2_SYSTEM_PREFIX
from loguru import logger
from oasst_shared.schemas import inference
from settings import settings

shared_tokenizer_lock = threading.Lock()


if settings.model_prompt_format == "chatml":
special_tokens = {
"prompter": "<|im_start|>user\n",
"assistant": "<|im_start|>assistant\n",
"system": "<|im_start|>system\n",
"end": "<|im_end|>\n",
}
else:
special_tokens = {
"prompter": "<|prompter|>",
"assistant": "<|assistant|>",
"system": "<|system|>",
"end": "",
}


class TokenBuffer:
"""
A buffer for storing and managing tokens based on various conditions including stop sequences.
Expand Down Expand Up @@ -80,6 +95,13 @@ def get_max_input_length(worker_config: inference.WorkerConfig, plugin_used: boo
return max_input_length


def get_tokens_until(tokens: list[int], target: int | list[int]) -> list[int]:
if isinstance(target, int):
return tokens[: tokens.index(target)]
else:
return next((i for i in range(len(tokens) - len(target) + 1) if tokens[i : i + len(target)] == target))


def truncate_prompt(
tokenizer: transformers.PreTrainedTokenizer,
worker_config: inference.WorkerConfig,
Expand All @@ -96,13 +118,14 @@ def truncate_prompt(
"""
with shared_tokenizer_lock:
ids = tokenizer.encode(prompt)
prompter_prefix_id = tokenizer.convert_tokens_to_ids(V2_PROMPTER_PREFIX)
# prompter_prefix_ids could be int or list of ints
prompter_prefix_ids = tokenizer.convert_tokens_to_ids(special_tokens["prompter"])

system_prompt: str | None = None
system_tokens: list[int] | None = None
if prompt.startswith(V2_SYSTEM_PREFIX):
system_prompt = prompt[: prompt.index(V2_PROMPTER_PREFIX)]
system_tokens = ids[: ids.index(prompter_prefix_id)]
if prompt.startswith(special_tokens["system"]):
system_prompt = prompt[: prompt.index(special_tokens["prompter"])]
system_tokens = get_tokens_until(ids, prompter_prefix_ids)

max_input_length = get_max_input_length(worker_config, plugin_used)

Expand All @@ -117,9 +140,9 @@ def truncate_prompt(
with shared_tokenizer_lock:
prompt = tokenizer.decode(ids)

if V2_PROMPTER_PREFIX not in prompt:
prompt = V2_PROMPTER_PREFIX + prompt
ids = tokenizer.encode(V2_PROMPTER_PREFIX) + ids
if special_tokens["prompter"] not in prompt:
prompt = special_tokens["prompter"] + prompt
ids = tokenizer.encode(special_tokens["prompter"]) + ids

if system_tokens:
prompt = system_prompt + prompt
Expand Down
27 changes: 14 additions & 13 deletions inference/worker/work.py
Expand Up @@ -14,14 +14,11 @@
OBSERVATION_SEQ,
START_SEQ,
THOUGHT_SEQ,
V2_ASST_PREFIX,
V2_PROMPTER_PREFIX,
V2_SYSTEM_PREFIX,
)
from loguru import logger
from oasst_shared.schemas import inference
from settings import settings
from utils import shared_tokenizer_lock
from utils import shared_tokenizer_lock, special_tokens


def make_prompt_and_parameters(
Expand All @@ -32,10 +29,14 @@ def make_prompt_and_parameters(
if settings.oa_protocol_version != "v2":
raise RuntimeError(f"Unsupported oa protocol version: {settings.oa_protocol_version}")

eos_token = tokenizer.eos_token if hasattr(tokenizer, "eos_token") else ""
eos_token = ""
if special_tokens["end"]:
eos_token = special_tokens["end"]
elif hasattr(tokenizer, "eos_token"):
eos_token = tokenizer.eos_token

def _prepare_message(message: inference.MessageRead) -> str:
prefix = V2_ASST_PREFIX if message.is_assistant else V2_PROMPTER_PREFIX
prefix = special_tokens["assistant"] if message.is_assistant else special_tokens["prompter"]
return prefix + message.content + eos_token

# Construct prompt
Expand All @@ -44,7 +45,7 @@ def _prepare_message(message: inference.MessageRead) -> str:
# Prepend system prompt and custom_instructions if it was specified in work parameters
work_params = work_request.parameters
if work_params.system_prompt or work_params.user_profile or work_params.user_response_instructions:
pre_prompt = V2_SYSTEM_PREFIX + (work_params.system_prompt or "")
pre_prompt = special_tokens["system"] + (work_params.system_prompt or "")

if work_params.user_profile or work_params.user_response_instructions:
pre_prompt = f"""{pre_prompt}\n{CUSTOM_INSTRUCTIONS_PREFIX.format(user_profile=work_params.user_profile or "", user_response_instructions=work_params.user_response_instructions or "")}"""
Expand All @@ -53,14 +54,14 @@ def _prepare_message(message: inference.MessageRead) -> str:
messages = [pre_prompt] + messages

# Stringify and append assistant prefix to signify start of generation
prompt = "".join(messages) + V2_ASST_PREFIX
prompt = "".join(messages) + special_tokens["assistant"]

parameters = interface.GenerateStreamParameters.from_work_parameters(work_request.parameters)
if settings.use_stop_sequences:
parameters.stop = [
V2_PROMPTER_PREFIX,
V2_ASST_PREFIX,
V2_SYSTEM_PREFIX,
special_tokens["prompter"],
special_tokens["assistant"],
special_tokens["system"],
]
if eos_token:
parameters.stop.append(eos_token)
Expand All @@ -73,9 +74,9 @@ def _prepare_message(message: inference.MessageRead) -> str:
def prepare_safe_prompt(prompt: str, label: str, rots: str) -> str:
"""Given a prompt, safety label, and safety rule of thumb, prepare a 'safe prompt' to replace the prompt."""
pre_prompt = f"Answer the following request with {label} as responsible chatbot that believes that {rots}: "
input_list = prompt.split(V2_PROMPTER_PREFIX)
input_list = prompt.split(special_tokens["prompter"])
input_list[-1] = pre_prompt + input_list[-1]
return V2_PROMPTER_PREFIX.join(input_list)
return special_tokens["prompter"].join(input_list)


def is_safety_triggered(safety_label: str, safety_level: int) -> bool:
Expand Down
5 changes: 5 additions & 0 deletions oasst-shared/oasst_shared/model_configs.py
Expand Up @@ -145,4 +145,9 @@ def compat_hash(self) -> str:
max_total_length=2048,
quantized=True,
),
"OA_SFT_Llama2_70B_10": ModelConfig(
model_id="OpenAssistant/llama2-70b-oasst-sft-v10",
max_input_length=3072,
max_total_length=4096,
),
}

0 comments on commit d613c81

Please sign in to comment.