From 3215ef17e6088dc92810859348be3c8f0dcd5101 Mon Sep 17 00:00:00 2001
From: 兼欣
Date: Fri, 12 Jan 2024 00:05:42 +0800
Subject: [PATCH] openai_api.py: bugfixes for system messages and function
 calling

---
 examples/function_call_examples.py | 325 +++++++++++----------
 openai_api.py                      | 443 +++++++++++++++--------------
 2 files changed, 404 insertions(+), 364 deletions(-)

diff --git a/examples/function_call_examples.py b/examples/function_call_examples.py
index be656788..94e3e71b 100644
--- a/examples/function_call_examples.py
+++ b/examples/function_call_examples.py
@@ -1,4 +1,6 @@
 # Reference: https://openai.com/blog/function-calling-and-other-api-updates
+import json
+from pprint import pprint
 import openai
 
@@ -9,216 +11,223 @@
 #   python openai_api.py;
 #
 # Then configure the api_base and api_key in your client:
-openai.api_base = "http://localhost:8000/v1"
-openai.api_key = "none"
+openai.api_base = 'http://localhost:8000/v1'
+openai.api_key = 'none'
 
 
 def call_qwen(messages, functions=None):
-    print(messages)
+    print('input:')
+    pprint(messages, indent=2)
     if functions:
-        response = openai.ChatCompletion.create(
-            model="Qwen", messages=messages, functions=functions
-        )
+        response = openai.ChatCompletion.create(model='Qwen',
+                                                messages=messages,
+                                                functions=functions)
     else:
-        response = openai.ChatCompletion.create(model="Qwen", messages=messages)
-    print(response)
-    print(response.choices[0].message.content)
+        response = openai.ChatCompletion.create(model='Qwen',
+                                                messages=messages)
+    response = response.choices[0]['message']
+    response = json.loads(json.dumps(response,
+                                     ensure_ascii=False))  # fix zh rendering
+    print('output:')
+    pprint(response, indent=2)
+    print()
     return response
 
 
 def test_1():
-    messages = [{"role": "user", "content": "你好"}]
+    messages = [{'role': 'user', 'content': '你好'}]
     call_qwen(messages)
-    messages.append({"role": "assistant", "content": "你好!很高兴为你提供帮助。"})
+    messages.append({'role': 'assistant', 'content': '你好!很高兴为你提供帮助。'})
 
-    messages.append({"role": "user", "content": "给我讲一个年轻人奋斗创业最终取得成功的故事。故事只能有一句话。"})
+    messages.append({
+        'role': 'user',
+        'content': '给我讲一个年轻人奋斗创业最终取得成功的故事。故事只能有一句话。'
+    })
     call_qwen(messages)
-    messages.append(
-        {
-            "role": "assistant",
-            "content": "故事的主人公叫李明,他来自一个普通的家庭,父母都是普通的工人。李明想要成为一名成功的企业家。……",
-        }
-    )
-
-    messages.append({"role": "user", "content": "给这个故事起一个标题"})
+    messages.append({
+        'role':
+        'assistant',
+        'content':
+        '故事的主人公叫李明,他来自一个普通的家庭,父母都是普通的工人。李明想要成为一名成功的企业家。……',
+    })
+
+    messages.append({'role': 'user', 'content': '给这个故事起一个标题'})
     call_qwen(messages)
 
 
 def test_2():
     functions = [
         {
-            "name_for_human": "谷歌搜索",
-            "name_for_model": "google_search",
-            "description_for_model": "谷歌搜索是一个通用搜索引擎,可用于访问互联网、查询百科知识、了解时事新闻等。"
-            + " Format the arguments as a JSON object.",
-            "parameters": [
-                {
-                    "name": "search_query",
-                    "description": "搜索关键词或短语",
-                    "required": True,
-                    "schema": {"type": "string"},
-                }
-            ],
+            'name_for_human':
+            '谷歌搜索',
+            'name_for_model':
+            'google_search',
+            'description_for_model':
+            '谷歌搜索是一个通用搜索引擎,可用于访问互联网、查询百科知识、了解时事新闻等。'
+            + ' Format the arguments as a JSON object.',
+            'parameters': [{
+                'name': 'search_query',
+                'description': '搜索关键词或短语',
+                'required': True,
+                'schema': {
+                    'type': 'string'
+                },
+            }],
         },
         {
-            "name_for_human": "文生图",
-            "name_for_model": "image_gen",
-            "description_for_model": "文生图是一个AI绘画(图像生成)服务,输入文本描述,返回根据文本作画得到的图片的URL。"
-            + " Format the arguments as a JSON object.",
-            "parameters": [
-                {
-                    "name": "prompt",
-                    "description": "英文关键词,描述了希望图像具有什么内容",
-                    "required": True,
-                    "schema": {"type": "string"},
-                }
-            ],
+            'name_for_human':
+            '文生图',
+            'name_for_model':
+            'image_gen',
+            'description_for_model':
+            '文生图是一个AI绘画(图像生成)服务,输入文本描述,返回根据文本作画得到的图片的URL。'
+            + ' Format the arguments as a JSON object.',
+            'parameters': [{
+                'name': 'prompt',
+                'description': '英文关键词,描述了希望图像具有什么内容',
+                'required': True,
+                'schema': {
+                    'type': 'string'
+                },
+            }],
         },
     ]
 
-    messages = [{"role": "user", "content": "你好"}]
+    messages = [{'role': 'user', 'content': '(请不要调用工具)\n\n你好'}]
     call_qwen(messages, functions)
-    messages.append(
-        {"role": "assistant", "content": "你好!很高兴见到你。有什么我可以帮忙的吗?"},
-    )
+    messages.append({
+        'role': 'assistant',
+        'content': '你好!很高兴见到你。有什么我可以帮忙的吗?'
+    }, )
 
-    messages.append({"role": "user", "content": "谁是周杰伦"})
+    messages.append({'role': 'user', 'content': '搜索一下谁是周杰伦'})
     call_qwen(messages, functions)
-    messages.append(
-        {
-            "role": "assistant",
-            "content": "Thought: 我应该使用Google搜索查找相关信息。",
-            "function_call": {
-                "name": "google_search",
-                "arguments": '{"search_query": "周杰伦"}',
-            },
-        }
-    )
-
-    messages.append(
-        {
-            "role": "function",
-            "name": "google_search",
-            "content": "Jay Chou is a Taiwanese singer.",
-        }
-    )
-    call_qwen(messages, functions)
-    messages.append(
-        {
-            "role": "assistant",
-            "content": "周杰伦(Jay Chou)是一位来自台湾的歌手。",
+    messages.append({
+        'role': 'assistant',
+        'content': '我应该使用Google搜索查找相关信息。',
+        'function_call': {
+            'name': 'google_search',
+            'arguments': '{"search_query": "周杰伦"}',
         },
-    )
+    })
 
-    messages.append({"role": "user", "content": "他老婆是谁"})
+    messages.append({
+        'role': 'function',
+        'name': 'google_search',
+        'content': 'Jay Chou is a Taiwanese singer.',
+    })
     call_qwen(messages, functions)
     messages.append(
         {
-            "role": "assistant",
-            "content": "Thought: 我应该使用Google搜索查找相关信息。",
-            "function_call": {
-                "name": "google_search",
-                "arguments": '{"search_query": "周杰伦 老婆"}',
-            },
-        }
-    )
+            'role': 'assistant',
+            'content': '周杰伦(Jay Chou)是一位来自台湾的歌手。',
+        }, )
 
-    messages.append(
-        {"role": "function", "name": "google_search", "content": "Hannah Quinlivan"}
-    )
+    messages.append({'role': 'user', 'content': '搜索一下他老婆是谁'})
     call_qwen(messages, functions)
-    messages.append(
-        {
-            "role": "assistant",
-            "content": "周杰伦的老婆是Hannah Quinlivan。",
+    messages.append({
+        'role': 'assistant',
+        'content': '我应该使用Google搜索查找相关信息。',
+        'function_call': {
+            'name': 'google_search',
+            'arguments': '{"search_query": "周杰伦 老婆"}',
         },
-    )
+    })
 
-    messages.append({"role": "user", "content": "给我画个可爱的小猫吧,最好是黑猫"})
+    messages.append({
+        'role': 'function',
+        'name': 'google_search',
+        'content': 'Hannah Quinlivan'
+    })
     call_qwen(messages, functions)
     messages.append(
         {
-            "role": "assistant",
-            "content": "Thought: 我应该使用文生图API来生成一张可爱的小猫图片。",
-            "function_call": {
-                "name": "image_gen",
-                "arguments": '{"prompt": "cute black cat"}',
-            },
-        }
-    )
+            'role': 'assistant',
+            'content': '周杰伦的老婆是Hannah Quinlivan。',
+        }, )
 
-    messages.append(
-        {
-            "role": "function",
-            "name": "image_gen",
-            "content": '{"image_url": "https://image.pollinations.ai/prompt/cute%20black%20cat"}',
-        }
-    )
+    messages.append({'role': 'user', 'content': '用文生图工具画个可爱的小猫吧,最好是黑猫'})
+    call_qwen(messages, functions)
+    messages.append({
+        'role': 'assistant',
+        'content': '我应该使用文生图API来生成一张可爱的小猫图片。',
+        'function_call': {
+            'name': 'image_gen',
+            'arguments': '{"prompt": "cute black cat"}',
+        },
+    })
+
+    messages.append({
+        'role':
+        'function',
+        'name':
+        'image_gen',
+        'content':
+        '{"image_url": "https://image.pollinations.ai/prompt/cute%20black%20cat"}',
+    })
     call_qwen(messages, functions)
 
 
 def test_3():
-    functions = [
-        {
-            "name": "get_current_weather",
-            "description": "Get the current weather in a given location.",
-            "parameters": {
-                "type": "object",
-                "properties": {
-                    "location": {
-                        "type": "string",
-                        "description": "The city and state, e.g. San Francisco, CA",
-                    },
-                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+    functions = [{
+        'name': 'get_current_weather',
+        'description': 'Get the current weather in a given location.',
+        'parameters': {
+            'type': 'object',
+            'properties': {
+                'location': {
+                    'type': 'string',
+                    'description':
+                    'The city and state, e.g. San Francisco, CA',
+                },
+                'unit': {
+                    'type': 'string',
+                    'enum': ['celsius', 'fahrenheit']
                 },
-                "required": ["location"],
             },
-        }
-    ]
-
-    messages = [
-        {
-            "role": "user",
-            # Note: The current version of Qwen-7B-Chat (as of 2023.08) performs okay with Chinese tool-use prompts,
-            # but performs terribly when it comes to English tool-use prompts, due to a mistake in data collecting.
-            "content": "波士顿天气如何?",
-        }
-    ]
+            'required': ['location'],
+        },
+    }]
+
+    messages = [{
+        'role': 'user',
+        # Note: The current version of Qwen-7B-Chat (as of 2023.08) performs okay with Chinese tool-use prompts,
+        # but performs terribly when it comes to English tool-use prompts, due to a mistake in data collecting.
+        'content': '波士顿天气如何?',
+    }]
     call_qwen(messages, functions)
     messages.append(
         {
-            "role": "assistant",
-            "content": None,
-            "function_call": {
-                "name": "get_current_weather",
-                "arguments": '{"location": "Boston, MA"}',
+            'role': 'assistant',
+            'content': None,
+            'function_call': {
+                'name': 'get_current_weather',
+                'arguments': '{"location": "Boston, MA"}',
             },
-        },
-    )
-
-    messages.append(
-        {
-            "role": "function",
-            "name": "get_current_weather",
-            "content": '{"temperature": "22", "unit": "celsius", "description": "Sunny"}',
-        }
-    )
+        }, )
+
+    messages.append({
+        'role':
+        'function',
+        'name':
+        'get_current_weather',
+        'content':
+        '{"temperature": "22", "unit": "celsius", "description": "Sunny"}',
+    })
     call_qwen(messages, functions)
 
 
 def test_4():
+    from langchain.agents import AgentType, initialize_agent, load_tools
     from langchain.chat_models import ChatOpenAI
-    from langchain.agents import load_tools, initialize_agent, AgentType
 
     llm = ChatOpenAI(
-        model_name="Qwen",
-        openai_api_base="http://localhost:8000/v1",
-        openai_api_key="EMPTY",
+        model_name='Qwen',
+        openai_api_base='http://localhost:8000/v1',
+        openai_api_key='EMPTY',
         streaming=False,
     )
-    tools = load_tools(
-        ["arxiv"],
-    )
+    tools = load_tools(['arxiv'], )
     agent_chain = initialize_agent(
         tools,
         llm,
@@ -226,15 +235,15 @@ def test_4():
         verbose=True,
     )
     # TODO: The performance is okay with Chinese prompts, but not so good when it comes to English.
-    agent_chain.run("查一下论文 1605.08386 的信息")
+    agent_chain.run('查一下论文 1605.08386 的信息')
 
 
-if __name__ == "__main__":
-    print("### Test Case 1 - No Function Calling (普通问答、无函数调用) ###")
+if __name__ == '__main__':
+    print('### Test Case 1 - No Function Calling (普通问答、无函数调用) ###')
     test_1()
-    print("### Test Case 2 - Use Qwen-Style Functions (函数调用,千问格式) ###")
+    print('### Test Case 2 - Use Qwen-Style Functions (函数调用,千问格式) ###')
     test_2()
-    print("### Test Case 3 - Use GPT-Style Functions (函数调用,GPT格式) ###")
+    print('### Test Case 3 - Use GPT-Style Functions (函数调用,GPT格式) ###')
     test_3()
-    print("### Test Case 4 - Use LangChain (接入Langchain) ###")
+    print('### Test Case 4 - Use LangChain (接入Langchain) ###')
     test_4()
diff --git a/openai_api.py b/openai_api.py
index e9674bcb..fd8e6354 100644
--- a/openai_api.py
+++ b/openai_api.py
@@ -1,14 +1,16 @@
-# coding=utf-8
-# Implements API for Qwen-7B in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
-# Usage: python openai_api.py
+# Requirement:
+#   pip install "openai<1.0"
+# Usage:
+#   python openai_api.py
 # Visit http://localhost:8000/docs for documents.
 
-import re
+import base64
 import copy
 import json
 import time
 from argparse import ArgumentParser
 from contextlib import asynccontextmanager
+from pprint import pprint
 from typing import Dict, List, Literal, Optional, Union
 
 import torch
@@ -17,20 +19,22 @@
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, Field
 from sse_starlette.sse import EventSourceResponse
-from transformers import AutoTokenizer, AutoModelForCausalLM
-from transformers.generation import GenerationConfig
 from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.requests import Request
 from starlette.responses import Response
-import base64
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation import GenerationConfig
+
 
 class BasicAuthMiddleware(BaseHTTPMiddleware):
+
     def __init__(self, app, username: str, password: str):
         super().__init__(app)
-        self.required_credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
+        self.required_credentials = base64.b64encode(
+            f'{username}:{password}'.encode()).decode()
 
     async def dispatch(self, request: Request, call_next):
-        authorization: str = request.headers.get("Authorization")
+        authorization: str = request.headers.get('Authorization')
         if authorization:
             try:
                 schema, credentials = authorization.split()
@@ -38,16 +42,18 @@ async def dispatch(self, request: Request, call_next):
                     return await call_next(request)
             except ValueError:
                 pass
-
+
         headers = {'WWW-Authenticate': 'Basic'}
         return Response(status_code=401, headers=headers)
-
+
+
 def _gc(forced: bool = False):
     global args
     if args.disable_gc and not forced:
         return
 
     import gc
+
     gc.collect()
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
@@ -63,36 +69,36 @@ async def lifespan(app: FastAPI):  # collects GPU memory
 
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=["*"],
+    allow_origins=['*'],
    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
+    allow_methods=['*'],
+    allow_headers=['*'],
 )
 
 
 class ModelCard(BaseModel):
     id: str
-    object: str = "model"
+    object: str = 'model'
     created: int = Field(default_factory=lambda: int(time.time()))
-    owned_by: str = "owner"
+    owned_by: str = 'owner'
     root: Optional[str] = None
     parent: Optional[str] = None
     permission: Optional[list] = None
 
 
 class ModelList(BaseModel):
-    object: str = "list"
+    object: str = 'list'
     data: List[ModelCard] = []
 
 
 class ChatMessage(BaseModel):
-    role: Literal["user", "assistant", "system", "function"]
+    role: Literal['user', 'assistant', 'system', 'function']
     content: Optional[str]
     function_call: Optional[Dict] = None
 
 
 class DeltaMessage(BaseModel):
-    role: Optional[Literal["user", "assistant", "system"]] = None
+    role: Optional[Literal['user', 'assistant', 'system']] = None
     content: Optional[str] = None
 
 
@@ -102,6 +108,7 @@ class ChatCompletionRequest(BaseModel):
     functions: Optional[List[Dict]] = None
     temperature: Optional[float] = None
     top_p: Optional[float] = None
+    top_k: Optional[int] = None
     max_length: Optional[int] = None
     stream: Optional[bool] = False
     stop: Optional[List[str]] = None
@@ -109,29 +116,28 @@ class ChatCompletionResponseChoice(BaseModel):
     index: int
-    message: ChatMessage
-    finish_reason: Literal["stop", "length", "function_call"]
+    message: Union[ChatMessage]
+    finish_reason: Literal['stop', 'length', 'function_call']
 
 
 class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
     delta: DeltaMessage
-    finish_reason: Optional[Literal["stop", "length"]]
+    finish_reason: Optional[Literal['stop', 'length']]
 
 
 class ChatCompletionResponse(BaseModel):
     model: str
-    object: Literal["chat.completion", "chat.completion.chunk"]
-    choices: List[
-        Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]
-    ]
+    object: Literal['chat.completion', 'chat.completion.chunk']
+    choices: List[Union[ChatCompletionResponseChoice,
+                        ChatCompletionResponseStreamChoice]]
     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
 
 
-@app.get("/v1/models", response_model=ModelList)
+@app.get('/v1/models', response_model=ModelList)
 async def list_models():
     global model_args
-    model_card = ModelCard(id="gpt-3.5-turbo")
+    model_card = ModelCard(id='gpt-3.5-turbo')
     return ModelList(data=[model_card])
 
 
@@ -141,7 +147,7 @@ def add_extra_stop_words(stop_words):
         _stop_words = []
         _stop_words.extend(stop_words)
         for x in stop_words:
-            s = x.lstrip("\n")
+            s = x.lstrip('\n')
             if s and (s not in _stop_words):
                 _stop_words.append(s)
         return _stop_words
@@ -157,7 +163,10 @@ def trim_stop_words(response, stop_words):
     return response
 
 
-TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}"""
+TOOL_DESC = (
+    '{name_for_model}: Call this tool to interact with the {name_for_human} API.'
+    ' What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}'
+)
 
 REACT_INSTRUCTION = """Answer the following questions as best you can. You have access to the following APIs:
@@ -179,37 +188,28 @@ def trim_stop_words(response, stop_words):
 
 _TEXT_COMPLETION_CMD = object()
 
 
-#
-# Temporarily, the system role does not work as expected.
-# We advise that you write the setups for role-play in your query,
-# i.e., use the user role instead of the system role.
-#
-# TODO: Use real system role when the model is ready.
-#
 def parse_messages(messages, functions):
-    if all(m.role != "user" for m in messages):
+    if all(m.role != 'user' for m in messages):
         raise HTTPException(
             status_code=400,
-            detail=f"Invalid request: Expecting at least one user message.",
+            detail='Invalid request: Expecting at least one user message.',
         )
 
     messages = copy.deepcopy(messages)
-    default_system = "You are a helpful assistant."
-    system = ""
-    if messages[0].role == "system":
-        system = messages.pop(0).content.lstrip("\n").rstrip()
-        if system == default_system:
-            system = ""
+    if messages[0].role == 'system':
+        system = messages.pop(0).content.lstrip('\n').rstrip()
+    else:
+        system = 'You are a helpful assistant.'
 
     if functions:
         tools_text = []
         tools_name_text = []
         for func_info in functions:
-            name = func_info.get("name", "")
-            name_m = func_info.get("name_for_model", name)
-            name_h = func_info.get("name_for_human", name)
-            desc = func_info.get("description", "")
-            desc_m = func_info.get("description_for_model", desc)
+            name = func_info.get('name', '')
+            name_m = func_info.get('name_for_model', name)
+            name_h = func_info.get('name_for_human', name)
+            desc = func_info.get('description', '')
+            desc_m = func_info.get('description_for_model', desc)
             tool = TOOL_DESC.format(
                 name_for_model=name_m,
                 name_for_human=name_h,
@@ -217,150 +217,152 @@ def parse_messages(messages, functions):
                 # "Format the arguments as a JSON object."
                 # "Enclose the code within triple backticks (`) at the beginning and end of the code."
                 description_for_model=desc_m,
-                parameters=json.dumps(func_info["parameters"], ensure_ascii=False),
+                parameters=json.dumps(func_info['parameters'],
+                                      ensure_ascii=False),
             )
             tools_text.append(tool)
             tools_name_text.append(name_m)
-        tools_text = "\n\n".join(tools_text)
-        tools_name_text = ", ".join(tools_name_text)
-        system += "\n\n" + REACT_INSTRUCTION.format(
+        tools_text = '\n\n'.join(tools_text)
+        tools_name_text = ', '.join(tools_name_text)
+        instruction = (REACT_INSTRUCTION.format(
             tools_text=tools_text,
             tools_name_text=tools_name_text,
-        )
-        system = system.lstrip("\n").rstrip()
-
-    dummy_thought = {
-        "en": "\nThought: I now know the final answer.\nFinal answer: ",
-        "zh": "\nThought: 我会作答了。\nFinal answer: ",
-    }
+        ).lstrip('\n').rstrip())
+    else:
+        instruction = ''
 
-    _messages = messages
+    messages_with_fncall = messages
     messages = []
-    for m_idx, m in enumerate(_messages):
+    for m_idx, m in enumerate(messages_with_fncall):
         role, content, func_call = m.role, m.content, m.function_call
-        if content:
-            content = content.lstrip("\n").rstrip()
-        if role == "function":
-            if (len(messages) == 0) or (messages[-1].role != "assistant"):
+        content = content or ''
+        content = content.lstrip('\n').rstrip()
+        if role == 'function':
+            if (len(messages) == 0) or (messages[-1].role != 'assistant'):
                 raise HTTPException(
                     status_code=400,
-                    detail=f"Invalid request: Expecting role assistant before role function.",
+                    detail=
+                    'Invalid request: Expecting role assistant before role function.',
                 )
-            messages[-1].content += f"\nObservation: {content}"
-            if m_idx == len(_messages) - 1:
-                messages[-1].content += "\nThought:"
-        elif role == "assistant":
+            messages[-1].content += f'\nObservation: {content}'
+            if m_idx == len(messages_with_fncall) - 1:
+                # add a prefix for text completion
+                messages[-1].content += '\nThought:'
+        elif role == 'assistant':
             if len(messages) == 0:
                 raise HTTPException(
                     status_code=400,
-                    detail=f"Invalid request: Expecting role user before role assistant.",
+                    detail=
+                    'Invalid request: Expecting role user before role assistant.',
                 )
-            last_msg = messages[-1].content
-            last_msg_has_zh = len(re.findall(r"[\u4e00-\u9fff]+", last_msg)) > 0
             if func_call is None:
                 if functions:
-                    content = dummy_thought["zh" if last_msg_has_zh else "en"] + content
+                    content = f'Thought: I now know the final answer.\nFinal Answer: {content}'
            else:
-                f_name, f_args = func_call["name"], func_call["arguments"]
-                if not content:
-                    if last_msg_has_zh:
-                        content = f"Thought: 我可以使用 {f_name} API。"
-                    else:
-                        content = f"Thought: I can use {f_name}."
-                content = f"\n{content}\nAction: {f_name}\nAction Input: {f_args}"
-            if messages[-1].role == "user":
+                f_name, f_args = func_call['name'], func_call['arguments']
+                if not content.startswith('Thought:'):
+                    content = f'Thought: {content}'
+                content = f'{content}\nAction: {f_name}\nAction Input: {f_args}'
+            if messages[-1].role == 'user':
                 messages.append(
-                    ChatMessage(role="assistant", content=content.lstrip("\n").rstrip())
-                )
+                    ChatMessage(role='assistant',
+                                content=content.lstrip('\n').rstrip()))
             else:
-                messages[-1].content += content
-        elif role == "user":
+                messages[-1].content += '\n' + content
+        elif role == 'user':
             messages.append(
-                ChatMessage(role="user", content=content.lstrip("\n").rstrip())
-            )
+                ChatMessage(role='user',
+                            content=content.lstrip('\n').rstrip()))
         else:
             raise HTTPException(
-                status_code=400, detail=f"Invalid request: Incorrect role {role}."
-            )
+                status_code=400,
+                detail=f'Invalid request: Incorrect role {role}.')
 
     query = _TEXT_COMPLETION_CMD
-    if messages[-1].role == "user":
+    if messages[-1].role == 'user':
         query = messages[-1].content
         messages = messages[:-1]
 
     if len(messages) % 2 != 0:
-        raise HTTPException(status_code=400, detail="Invalid request")
+        raise HTTPException(status_code=400, detail='Invalid request')
 
     history = []  # [(Q1, A1), (Q2, A2), ..., (Q_last_turn, A_last_turn)]
     for i in range(0, len(messages), 2):
-        if messages[i].role == "user" and messages[i + 1].role == "assistant":
-            usr_msg = messages[i].content.lstrip("\n").rstrip()
-            bot_msg = messages[i + 1].content.lstrip("\n").rstrip()
-            if system and (i == len(messages) - 2):
-                usr_msg = f"{system}\n\nQuestion: {usr_msg}"
-                system = ""
-            for t in dummy_thought.values():
-                t = t.lstrip("\n")
-                if bot_msg.startswith(t) and ("\nAction: " in bot_msg):
-                    bot_msg = bot_msg[len(t) :]
+        if messages[i].role == 'user' and messages[i + 1].role == 'assistant':
+            usr_msg = messages[i].content.lstrip('\n').rstrip()
+            bot_msg = messages[i + 1].content.lstrip('\n').rstrip()
+            if instruction and (i == len(messages) - 2):
+                usr_msg = f'{instruction}\n\nQuestion: {usr_msg}'
+                instruction = ''
             history.append([usr_msg, bot_msg])
         else:
             raise HTTPException(
                 status_code=400,
-                detail="Invalid request: Expecting exactly one user (or function) role before every assistant role.",
+                detail=
+                'Invalid request: Expecting exactly one user (or function) role before every assistant role.',
             )
-    if system:
+    if instruction:
         assert query is not _TEXT_COMPLETION_CMD
-        query = f"{system}\n\nQuestion: {query}"
-    return query, history
+        query = f'{instruction}\n\nQuestion: {query}'
+    return query, history, system
 
 
 def parse_response(response):
-    func_name, func_args = "", ""
-    i = response.rfind("\nAction:")
-    j = response.rfind("\nAction Input:")
-    k = response.rfind("\nObservation:")
+    func_name, func_args = '', ''
+    i = response.find('\nAction:')
+    j = response.find('\nAction Input:')
+    k = response.find('\nObservation:')
     if 0 <= i < j:  # If the text has `Action` and `Action input`,
         if k < j:  # but does not contain `Observation`,
             # then it is likely that `Observation` is omitted by the LLM,
            # because the output text may have discarded the stop word.
-            response = response.rstrip() + "\nObservation:"  # Add it back.
-            k = response.rfind("\nObservation:")
-        func_name = response[i + len("\nAction:") : j].strip()
-        func_args = response[j + len("\nAction Input:") : k].strip()
+            response = response.rstrip() + '\nObservation:'  # Add it back.
+            k = response.find('\nObservation:')
+        func_name = response[i + len('\nAction:'):j].strip()
+        func_args = response[j + len('\nAction Input:'):k].strip()
+    if func_name:
+        response = response[:i]
+        t = response.find('Thought: ')
+        if t >= 0:
+            response = response[t + len('Thought: '):]
+        response = response.strip()
         choice_data = ChatCompletionResponseChoice(
             index=0,
             message=ChatMessage(
-                role="assistant",
-                content=response[:i],
-                function_call={"name": func_name, "arguments": func_args},
+                role='assistant',
+                content=response,
+                function_call={
+                    'name': func_name,
+                    'arguments': func_args
+                },
             ),
-            finish_reason="function_call",
+            finish_reason='function_call',
         )
         return choice_data
-    z = response.rfind("\nFinal Answer: ")
+
+    z = response.rfind('\nFinal Answer: ')
     if z >= 0:
-        response = response[z + len("\nFinal Answer: ") :]
+        response = response[z + len('\nFinal Answer: '):]
     choice_data = ChatCompletionResponseChoice(
         index=0,
-        message=ChatMessage(role="assistant", content=response),
-        finish_reason="stop",
+        message=ChatMessage(role='assistant', content=response),
+        finish_reason='stop',
     )
     return choice_data
 
 
 # completion mode, not chat mode
-def text_complete_last_message(history, stop_words_ids, gen_kwargs):
-    im_start = "<|im_start|>"
-    im_end = "<|im_end|>"
-    prompt = f"{im_start}system\nYou are a helpful assistant.{im_end}"
+def text_complete_last_message(history, stop_words_ids, gen_kwargs, system):
+    im_start = '<|im_start|>'
+    im_end = '<|im_end|>'
+    prompt = f'{im_start}system\n{system}{im_end}'
     for i, (query, response) in enumerate(history):
-        query = query.lstrip("\n").rstrip()
-        response = response.lstrip("\n").rstrip()
-        prompt += f"\n{im_start}user\n{query}{im_end}"
-        prompt += f"\n{im_start}assistant\n{response}{im_end}"
-    prompt = prompt[: -len(im_end)]
+        query = query.lstrip('\n').rstrip()
+        response = response.lstrip('\n').rstrip()
+        prompt += f'\n{im_start}user\n{query}{im_end}'
+        prompt += f'\n{im_start}assistant\n{response}{im_end}'
+    prompt = prompt[:-len(im_end)]
 
     _stop_words_ids = [tokenizer.encode(im_end)]
     if stop_words_ids:
@@ -369,20 +371,24 @@ def text_complete_last_message(history, stop_words_ids, gen_kwargs):
     stop_words_ids = _stop_words_ids
 
     input_ids = torch.tensor([tokenizer.encode(prompt)]).to(model.device)
-    output = model.generate(input_ids, stop_words_ids=stop_words_ids, **gen_kwargs).tolist()[0]
-    output = tokenizer.decode(output, errors="ignore")
+    output = model.generate(input_ids,
+                            stop_words_ids=stop_words_ids,
+                            **gen_kwargs).tolist()[0]
+    output = tokenizer.decode(output, errors='ignore')
     assert output.startswith(prompt)
-    output = output[len(prompt) :]
-    output = trim_stop_words(output, ["<|endoftext|>", im_end])
-    print(f"\n{prompt}\n\n{output}\n")
+    output = output[len(prompt):]
+    output = trim_stop_words(output, ['<|endoftext|>', im_end])
+    print(f'\n{prompt}\n\n{output}\n')
     return output
 
 
-@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
+@app.post('/v1/chat/completions', response_model=ChatCompletionResponse)
 async def create_chat_completion(request: ChatCompletionRequest):
     global model, tokenizer
 
     gen_kwargs = {}
+    if request.top_k is not None:
+        gen_kwargs['top_k'] = request.top_k
     if request.temperature is not None:
         if request.temperature < 0.01:
             gen_kwargs['top_k'] = 1  # greedy decoding
@@ -395,32 +401,46 @@ async def create_chat_completion(request: ChatCompletionRequest):
     stop_words = add_extra_stop_words(request.stop)
     if request.functions:
         stop_words = stop_words or []
-        if "Observation:" not in stop_words:
-            stop_words.append("Observation:")
+        if 'Observation:' not in stop_words:
+            stop_words.append('Observation:')
 
-    query, history = parse_messages(request.messages, request.functions)
+    query, history, system = parse_messages(request.messages,
+                                            request.functions)
 
     if request.stream:
         if request.functions:
             raise HTTPException(
                 status_code=400,
-                detail="Invalid request: Function calling is not yet implemented for stream mode.",
+                detail=
+                'Invalid request: Function calling is not yet implemented for stream mode.',
             )
-        generate = predict(query, history, request.model, stop_words, gen_kwargs)
-        return EventSourceResponse(generate, media_type="text/event-stream")
-
-    stop_words_ids = [tokenizer.encode(s) for s in stop_words] if stop_words else None
+        generate = predict(query,
+                           history,
+                           request.model,
+                           stop_words,
+                           gen_kwargs,
+                           system=system)
+        return EventSourceResponse(generate, media_type='text/event-stream')
+
+    stop_words_ids = [tokenizer.encode(s)
+                      for s in stop_words] if stop_words else None
     if query is _TEXT_COMPLETION_CMD:
-        response = text_complete_last_message(history, stop_words_ids=stop_words_ids, gen_kwargs=gen_kwargs)
+        response = text_complete_last_message(history,
+                                              stop_words_ids=stop_words_ids,
+                                              gen_kwargs=gen_kwargs,
+                                              system=system)
     else:
         response, _ = model.chat(
             tokenizer,
             query,
            history=history,
+            system=system,
             stop_words_ids=stop_words_ids,
-            **gen_kwargs
+            **gen_kwargs,
         )
-        print(f"\n{history}\n{query}\n\n{response}\n")
+        print('')
+        pprint(history, indent=2)
+        print(f'{query}\n\n{response}\n')
     _gc()
 
     response = trim_stop_words(response, stop_words)
@@ -429,12 +449,12 @@ async def create_chat_completion(request: ChatCompletionRequest):
     else:
         choice_data = ChatCompletionResponseChoice(
             index=0,
-            message=ChatMessage(role="assistant", content=response),
-            finish_reason="stop",
+            message=ChatMessage(role='assistant', content=response),
+            finish_reason='stop',
         )
-    return ChatCompletionResponse(
-        model=request.model, choices=[choice_data], object="chat.completion"
-    )
+    return ChatCompletionResponse(model=request.model,
+                                  choices=[choice_data],
+                                  object='chat.completion')
 
 
 def _dump_json(data: BaseModel, *args, **kwargs) -> str:
@@ -445,28 +465,37 @@ def _dump_json(data: BaseModel, *args, **kwargs) -> str:
 
 
 async def predict(
-    query: str, history: List[List[str]], model_id: str, stop_words: List[str], gen_kwargs: Dict,
+    query: str,
+    history: List[List[str]],
+    model_id: str,
+    stop_words: List[str],
+    gen_kwargs: Dict,
+    system: str,
 ):
     global model, tokenizer
 
     choice_data = ChatCompletionResponseStreamChoice(
-        index=0, delta=DeltaMessage(role="assistant"), finish_reason=None
-    )
-    chunk = ChatCompletionResponse(
-        model=model_id, choices=[choice_data], object="chat.completion.chunk"
-    )
-    yield "{}".format(_dump_json(chunk, exclude_unset=True))
+        index=0, delta=DeltaMessage(role='assistant'), finish_reason=None)
+    chunk = ChatCompletionResponse(model=model_id,
+                                   choices=[choice_data],
+                                   object='chat.completion.chunk')
+    yield '{}'.format(_dump_json(chunk, exclude_unset=True))
 
     current_length = 0
-    stop_words_ids = [tokenizer.encode(s) for s in stop_words] if stop_words else None
+    stop_words_ids = [tokenizer.encode(s)
+                      for s in stop_words] if stop_words else None
     if stop_words:
         # TODO: It's a little bit tricky to trim stop words in the stream mode.
         raise HTTPException(
             status_code=400,
-            detail="Invalid request: custom stop words are not yet supported for stream mode.",
+            detail=
+            'Invalid request: custom stop words are not yet supported for stream mode.',
         )
-    response_generator = model.chat_stream(
-        tokenizer, query, history=history, stop_words_ids=stop_words_ids, **gen_kwargs
-    )
+    response_generator = model.chat_stream(tokenizer,
+                                           query,
+                                           history=history,
+                                           stop_words_ids=stop_words_ids,
+                                           system=system,
+                                           **gen_kwargs)
     for new_response in response_generator:
         if len(new_response) == current_length:
             continue
@@ -475,21 +504,20 @@ async def predict(
         current_length = len(new_response)
 
         choice_data = ChatCompletionResponseStreamChoice(
-            index=0, delta=DeltaMessage(content=new_text), finish_reason=None
-        )
-        chunk = ChatCompletionResponse(
-            model=model_id, choices=[choice_data], object="chat.completion.chunk"
-        )
-        yield "{}".format(_dump_json(chunk, exclude_unset=True))
-
-    choice_data = ChatCompletionResponseStreamChoice(
-        index=0, delta=DeltaMessage(), finish_reason="stop"
-    )
-    chunk = ChatCompletionResponse(
-        model=model_id, choices=[choice_data], object="chat.completion.chunk"
-    )
-    yield "{}".format(_dump_json(chunk, exclude_unset=True))
-    yield "[DONE]"
+            index=0, delta=DeltaMessage(content=new_text), finish_reason=None)
+        chunk = ChatCompletionResponse(model=model_id,
+                                       choices=[choice_data],
+                                       object='chat.completion.chunk')
+        yield '{}'.format(_dump_json(chunk, exclude_unset=True))
+
+    choice_data = ChatCompletionResponseStreamChoice(index=0,
+                                                     delta=DeltaMessage(),
+                                                     finish_reason='stop')
+    chunk = ChatCompletionResponse(model=model_id,
+                                   choices=[choice_data],
+                                   object='chat.completion.chunk')
+    yield '{}'.format(_dump_json(chunk, exclude_unset=True))
+    yield '[DONE]'
 
     _gc()
 
@@ -497,36 +525,39 @@ async def predict(
 def _get_args():
     parser = ArgumentParser()
     parser.add_argument(
-        "-c",
-        "--checkpoint-path",
+        '-c',
+        '--checkpoint-path',
         type=str,
-        default="Qwen/Qwen-7B-Chat",
-        help="Checkpoint name or path, default to %(default)r",
+        default='Qwen/Qwen-7B-Chat',
+        help='Checkpoint name or path, default to %(default)r',
     )
+    parser.add_argument('--api-auth', help='API authentication credentials')
+    parser.add_argument('--cpu-only',
+                        action='store_true',
+                        help='Run demo with CPU only')
+    parser.add_argument('--server-port',
+                        type=int,
+                        default=8000,
+                        help='Demo server port.')
     parser.add_argument(
-        "--api-auth", help="API authentication credentials"
-    )
-    parser.add_argument(
-        "--cpu-only", action="store_true", help="Run demo with CPU only"
-    )
-    parser.add_argument(
-        "--server-port", type=int, default=8000, help="Demo server port."
+        '--server-name',
+        type=str,
+        default='127.0.0.1',
+        help=
+        'Demo server name. Default: 127.0.0.1, which is only visible from the local computer.'
+        ' If you want other computers to access your server, use 0.0.0.0 instead.',
    )
     parser.add_argument(
-        "--server-name",
-        type=str,
-        default="127.0.0.1",
-        help="Demo server name. Default: 127.0.0.1, which is only visible from the local computer."
-        " If you want other computers to access your server, use 0.0.0.0 instead.",
+        '--disable-gc',
+        action='store_true',
+        help='Disable GC after each response generated.',
     )
-    parser.add_argument("--disable-gc", action="store_true",
-                        help="Disable GC after each response generated.")
 
     args = parser.parse_args()
     return args
 
 
-if __name__ == "__main__":
+if __name__ == '__main__':
     args = _get_args()
 
     tokenizer = AutoTokenizer.from_pretrained(
@@ -536,14 +567,14 @@ def _get_args():
     )
 
     if args.api_auth:
-        app.add_middleware(
-            BasicAuthMiddleware, username=args.api_auth.split(":")[0], password=args.api_auth.split(":")[1]
-        )
+        app.add_middleware(BasicAuthMiddleware,
+                           username=args.api_auth.split(':')[0],
+                           password=args.api_auth.split(':')[1])
 
     if args.cpu_only:
-        device_map = "cpu"
+        device_map = 'cpu'
     else:
-        device_map = "auto"
+        device_map = 'auto'
 
     model = AutoModelForCausalLM.from_pretrained(
         args.checkpoint_path,