From dccb99c541fc3b80990912afe52937510442e1e1 Mon Sep 17 00:00:00 2001 From: Yimi81 <1548222878@qq.com> Date: Sun, 17 Dec 2023 17:39:21 +0800 Subject: [PATCH 1/3] 1. support openai_api.py to the latest openai sdk(>=1.0.0); 2. add streaming func call --- .../function_call_examples_v2.py | 307 ++++++ examples/openai_api_demo/openai_api.py | 858 +++++++++++++++++ .../openai_api_demo/openai_api_request.py | 40 + examples/openai_api_demo/openai_utils.py | 879 ++++++++++++++++++ 4 files changed, 2084 insertions(+) create mode 100644 examples/openai_api_demo/function_call_examples_v2.py create mode 100644 examples/openai_api_demo/openai_api.py create mode 100644 examples/openai_api_demo/openai_api_request.py create mode 100644 examples/openai_api_demo/openai_utils.py diff --git a/examples/openai_api_demo/function_call_examples_v2.py b/examples/openai_api_demo/function_call_examples_v2.py new file mode 100644 index 00000000..1564f9f1 --- /dev/null +++ b/examples/openai_api_demo/function_call_examples_v2.py @@ -0,0 +1,307 @@ +# Reference: https://openai.com/blog/function-calling-and-other-api-updates +import json +from openai import OpenAI + +# To start an Latest OpenAI-like Qwen server, use the following commands: +# git clone https://github.com/QwenLM/Qwen; +# cd Qwen; +# pip install fastapi uvicorn openai pydantic sse_starlette; +# python examples/openai_api_demo/openai_api.py; +# +# Then configure the api_base and api_key in your client: +client = OpenAI( + api_key="EMPTY", + base_url="http://localhost:8000/v1/", +) + + +# Change the default values of stream parameter to enable streaming +def call_qwen(messages, functions=None, stream=False): + print(messages) + if functions: + response = client.chat.completions.create( + model="Qwen", messages=messages, functions=functions, stream=stream + ) + else: + response = client.chat.completions.create(model="Qwen", messages=messages, stream=stream) + if stream: + for part in response: + print(part.choices[0].delta.content or "", end="", flush=True) + else: + # print(response) + print(response.choices[0].message.content) + return response + + +def test_1(): + messages = [{"role": "user", "content": "你好"}] + call_qwen(messages) + messages.append({"role": "assistant", "content": "你好!很高兴为你提供帮助。"}) + + messages.append({"role": "user", "content": "给我讲一个年轻人奋斗创业最终取得成功的故事。故事只能有一句话。"}) + call_qwen(messages) + messages.append( + { + "role": "assistant", + "content": "故事的主人公叫李明,他来自一个普通的家庭,父母都是普通的工人。李明想要成为一名成功的企业家。……", + } + ) + + messages.append({"role": "user", "content": "给这个故事起一个标题"}) + call_qwen(messages) + + +def test_2(): + functions = [ + { + "name_for_human": "谷歌搜索", + "name_for_model": "google_search", + "description_for_model": "谷歌搜索是一个通用搜索引擎,可用于访问互联网、查询百科知识、了解时事新闻等。" + + " Format the arguments as a JSON object.", + "parameters": [ + { + "name": "search_query", + "description": "搜索关键词或短语", + "required": True, + "schema": {"type": "string"}, + } + ], + }, + { + "name_for_human": "文生图", + "name_for_model": "image_gen", + "description_for_model": "文生图是一个AI绘画(图像生成)服务,输入文本描述,返回根据文本作画得到的图片的URL。" + + " Format the arguments as a JSON object.", + "parameters": [ + { + "name": "prompt", + "description": "英文关键词,描述了希望图像具有什么内容", + "required": True, + "schema": {"type": "string"}, + } + ], + }, + ] + + messages = [{"role": "user", "content": "你好"}] + call_qwen(messages, functions) + messages.append( + {"role": "assistant", "content": "你好!很高兴见到你。有什么我可以帮忙的吗?"}, + ) + + messages.append({"role": "user", "content": "谁是周杰伦"}) + call_qwen(messages, functions) + 
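+    # The assistant turn appended below illustrates Qwen-style function calling: the ReAct
+    # "Thought" stays in the `content` field, while the tool invocation goes into `function_call`.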
messages.append( + { + "role": "assistant", + "content": "Thought: 我应该使用Google搜索查找相关信息。", + "function_call": { + "name": "google_search", + "arguments": '{"search_query": "周杰伦"}', + }, + } + ) + + messages.append( + { + "role": "function", + "name": "google_search", + "content": "Jay Chou is a Taiwanese singer.", + } + ) + call_qwen(messages, functions) + messages.append( + { + "role": "assistant", + "content": "周杰伦(Jay Chou)是一位来自台湾的歌手。", + }, + ) + + messages.append({"role": "user", "content": "他老婆是谁"}) + call_qwen(messages, functions) + messages.append( + { + "role": "assistant", + "content": "Thought: 我应该使用Google搜索查找相关信息。", + "function_call": { + "name": "google_search", + "arguments": '{"search_query": "周杰伦 老婆"}', + }, + } + ) + + messages.append( + {"role": "function", "name": "google_search", "content": "Hannah Quinlivan"} + ) + call_qwen(messages, functions) + messages.append( + { + "role": "assistant", + "content": "周杰伦的老婆是Hannah Quinlivan。", + }, + ) + + messages.append({"role": "user", "content": "给我画个可爱的小猫吧,最好是黑猫"}) + call_qwen(messages, functions) + messages.append( + { + "role": "assistant", + "content": "Thought: 我应该使用文生图API来生成一张可爱的小猫图片。", + "function_call": { + "name": "image_gen", + "arguments": '{"prompt": "cute black cat"}', + }, + } + ) + + messages.append( + { + "role": "function", + "name": "image_gen", + "content": '{"image_url": "https://image.pollinations.ai/prompt/cute%20black%20cat"}', + } + ) + call_qwen(messages, functions) + + +def test_3(): + functions = [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + ] + + messages = [ + { + "role": "user", + # Note: The current version of Qwen-7B-Chat (as of 2023.08) performs okay with Chinese tool-use prompts, + # but performs terribly when it comes to English tool-use prompts, due to a mistake in data collecting. 
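+            # Hence the Chinese prompt below, which asks "How is the weather in Boston?".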
+ "content": "波士顿天气如何?", + } + ] + call_qwen(messages, functions) + messages.append( + { + "role": "assistant", + "content": None, + "function_call": { + "name": "get_current_weather", + "arguments": '{"location": "Boston, MA"}', + }, + }, + ) + + messages.append( + { + "role": "function", + "name": "get_current_weather", + "content": '{"temperature": "22", "unit": "celsius", "description": "Sunny"}', + } + ) + call_qwen(messages, functions) + + +def get_current_weather(location, unit="fahrenheit"): + """Get the current weather in a given location""" + if "tokyo" in location.lower(): + return json.dumps({"location": "Tokyo", "temperature": "10", "unit": unit}) + elif "san francisco" in location.lower(): + return json.dumps({"location": "San Francisco", "temperature": "72", "unit": unit}) + elif "paris" in location.lower(): + return json.dumps({"location": "Paris", "temperature": "22", "unit": unit}) + else: + return json.dumps({"location": location, "temperature": "unknown"}) + + +# Parallel function calling +def test_4(): + # Step 1: send the conversation and available functions to the model + messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}] + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + }, + } + ] + response = client.chat.completions.create( + model="qwen", + messages=messages, + tools=tools, + tool_choice="auto", # auto is default, but we'll be explicit + ) + response_message = response.choices[0].message + print(f"first_response: {response_message} \n") + + tool_calls = response_message.tool_calls + # Step 2: check if the model wanted to call a function + if tool_calls: + # Step 3: call the function + # Note: the JSON response may not always be valid; be sure to handle errors + available_functions = { + "get_current_weather": get_current_weather, + } # only one function in this example, but you can have multiple + messages.append(response_message) # extend conversation with assistant's reply + # Step 4: send the info for each function call and function response to the model + for tool_call in tool_calls: + function_name = tool_call.function.name + function_to_call = available_functions[function_name] + function_args = json.loads(tool_call.function.arguments) + function_response = function_to_call( + location=function_args.get("location"), + unit=function_args.get("unit"), + ) + print(f"function_response: {function_response} \n") + + messages.append( + { + "tool_call_id": tool_call.id, + "role": "tool", + "name": function_name, + "content": function_response, + } + ) # extend conversation with function response + + print(f"second_messages: {messages} \n") + second_response = client.chat.completions.create( + model="qwen", + messages=messages + ) # get a new response from the model where it can see the function response + print(f"second_response: {second_response}") + + +if __name__ == "__main__": + print("### Test Case 1 - No Function Calling (普通问答、无函数调用) ###") + test_1() + print("### Test Case 2 - Use Qwen-Style Functions (函数调用,千问格式) ###") + test_2() + print("### Test Case 3 - Use GPT-Style Functions (函数调用,GPT格式) ###") + test_3() + # # Qwen has not optimized parallel 
tool calls, often unable to parse a parallel call instruction into multiple tool_calls + # print("### Test Case 4 - Parallel function calling (并行函数调用,GPT格式) ###") + # test_4() diff --git a/examples/openai_api_demo/openai_api.py b/examples/openai_api_demo/openai_api.py new file mode 100644 index 00000000..f8363c34 --- /dev/null +++ b/examples/openai_api_demo/openai_api.py @@ -0,0 +1,858 @@ +import gc +import traceback +import torch +import uvicorn +import time +import uuid +import anyio +import json +from anyio.streams.memory import MemoryObjectSendStream + +from abc import ABC +from threading import Lock +from argparse import ArgumentParser +from contextlib import asynccontextmanager +from functools import partial +from typing import Dict, List, Any, Literal, Optional, Union, Tuple, Iterator, Iterable, AsyncIterator +from loguru import logger +from starlette.concurrency import run_in_threadpool, iterate_in_threadpool +from sse_starlette import EventSourceResponse +from pydantic import BaseModel + +from fastapi import FastAPI, HTTPException, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse + +from openai.types.model import Model +from openai.types.chat.chat_completion_message import FunctionCall +from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall +from openai.types.completion_usage import CompletionUsage +from openai.types.chat.chat_completion import Choice +from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice +from openai.types.chat.chat_completion_chunk import ( + ChoiceDelta, + ChoiceDeltaFunctionCall, + ChoiceDeltaToolCall, +) +from openai.types.chat import ( + ChatCompletionMessage, + ChatCompletion, + ChatCompletionChunk, + ChatCompletionMessageParam, + ChatCompletionToolChoiceOptionParam, +) + +from transformers import AutoTokenizer, AutoModelForCausalLM +from transformers.generation import GenerationConfig + +from openai_utils import ( + Role, + ModelList, + ChatCompletionCreateParams, + CompletionCreateParams, + ErrorCode, + ErrorResponse, + model_dump, + model_parse, + model_json, + build_qwen_chat_input, + is_partial_stop, + prepare_logits_processor) + + +llama_outer_lock = Lock() + + +@asynccontextmanager +async def lifespan(app: FastAPI): # collects GPU memory + yield + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + +app = FastAPI(lifespan=lifespan) + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +@app.get("/v1/models") +async def list_models(): + return ModelList( + data=[ + Model( + id="qwen", + object="model", + created=int(time.time()), + owned_by="open" + ) + ] +) + + +@app.post("/v1/chat/completions") +async def create_chat_completion( + request: ChatCompletionCreateParams, + raw_request: Request +): + global model, tokenizer + + if len(request.messages) < 1 or request.messages[-1]["role"] == Role.ASSISTANT: + raise HTTPException(status_code=400, detail="Invalid request") + + request = await handle_request(request, template.stop) + request.max_tokens = request.max_tokens or 1024 + + params = model_dump(request) + params.update(dict(echo=False)) + logger.debug(f"==== request ====\n{params}") + + iterator_or_completion = await run_in_threadpool(_create_chat_completion, params) + + if isinstance(iterator_or_completion, Iterator): + # It's easier to ask for forgiveness than permission + first_response = await 
run_in_threadpool(next, iterator_or_completion) + + # If no exception was raised from first_response, we can assume that + # the iterator is valid, and we can use it to stream the response. + def iterator() -> Iterator: + yield first_response + yield from iterator_or_completion + + send_chan, recv_chan = anyio.create_memory_object_stream(10) + return EventSourceResponse( + recv_chan, + data_sender_callable=partial( + get_event_publisher, + request=raw_request, + inner_send_chan=send_chan, + iterator=iterator(), + ), + ) + else: + return iterator_or_completion + + +def _create_chat_completion( + params: Optional[Dict[str, Any]] = None, + **kwargs, +) -> Union[Iterator, ChatCompletion]: + params = params or {} + params.update(kwargs) + return ( + _create_chat_completion_stream(params) + if params.get("stream", False) + else _create_chat_completion_non_stream(params) + ) + + +def _create_chat_completion_stream(params: Dict[str, Any]) -> Iterator: + """ + Creates a chat completion stream. + + Args: + params (Dict[str, Any]): The parameters for generating the chat completion. + + Yields: + Dict[str, Any]: The output of the chat completion stream. + """ + _id, _created, _model = None, None, None + has_function_call = False + for i, output in enumerate(_generate(params)): + if output["error_code"] != 0: + yield output + return + + _id, _created, _model = output["id"], output["created"], output["model"] + if i == 0: + choice = ChunkChoice( + index=0, + delta=ChoiceDelta(role="assistant", content=""), + finish_reason=None, + ) + yield ChatCompletionChunk( + id=f"chat{_id}", + choices=[choice], + created=_created, + model=_model, + object="chat.completion.chunk", + ) + + finish_reason = output["finish_reason"] + if len(output["delta"]) == 0 and finish_reason != "function_call": + continue + + function_call = None + if finish_reason == "function_call": + try: + _, function_call = template.parse_assistant_response( + output["text"], params.get("functions"), params.get("tools"), + ) + except Exception as e: + traceback.print_exc() + logger.warning("Failed to parse tool call") + + if isinstance(function_call, dict) and "arguments" in function_call: + has_function_call = True + function_call = ChoiceDeltaFunctionCall(**function_call) + delta = ChoiceDelta( + content=output["delta"], + function_call=function_call + ) + elif isinstance(function_call, dict) and "function" in function_call: + has_function_call = True + finish_reason = "tool_calls" + function_call["index"] = 0 + tool_calls = [model_parse(ChoiceDeltaToolCall, function_call)] + delta = ChoiceDelta( + content=output["delta"], + tool_calls=tool_calls, + ) + else: + delta = ChoiceDelta(content=output["delta"]) + + choice = ChunkChoice( + index=0, + delta=delta, + finish_reason=finish_reason + ) + yield ChatCompletionChunk( + id=f"chat{_id}", + choices=[choice], + created=_created, + model=_model, + object="chat.completion.chunk", + ) + + if not has_function_call: + choice = ChunkChoice( + index=0, + delta=ChoiceDelta(), + finish_reason="stop" + ) + yield ChatCompletionChunk( + id=f"chat{_id}", + choices=[choice], + created=_created, + model=_model, + object="chat.completion.chunk", + ) + + +def _create_chat_completion_non_stream(params: Dict[str, Any]) -> Union[ChatCompletion, JSONResponse]: + """ + Creates a chat completion based on the given parameters. + + Args: + params (Dict[str, Any]): The parameters for generating the chat completion. + + Returns: + ChatCompletion: The generated chat completion. 
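+        JSONResponse: An error response if generation failed.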
+ """ + last_output = None + for output in _generate(params): + last_output = output + + if last_output["error_code"] != 0: + return create_error_response(last_output["error_code"], last_output["text"]) + + function_call, finish_reason = None, "stop" + if params.get("functions") or params.get("tools"): + try: + res, function_call = template.parse_assistant_response( + last_output["text"], params.get("functions"), params.get("tools"), + ) + last_output["text"] = res + except Exception as e: + traceback.print_exc() + logger.warning("Failed to parse tool call") + + if isinstance(function_call, dict) and "arguments" in function_call: + finish_reason = "function_call" + function_call = FunctionCall(**function_call) + message = ChatCompletionMessage( + role="assistant", + content=last_output["text"], + function_call=function_call, + ) + elif isinstance(function_call, dict) and "function" in function_call: + finish_reason = "tool_calls" + tool_calls = [model_parse(ChatCompletionMessageToolCall, function_call)] + message = ChatCompletionMessage( + role="assistant", + content=last_output["text"], + tool_calls=tool_calls, + ) + else: + message = ChatCompletionMessage( + role="assistant", + content=last_output["text"].strip(), + ) + + choice = Choice( + index=0, + message=message, + finish_reason=finish_reason, + ) + usage = model_parse(CompletionUsage, last_output["usage"]) + return ChatCompletion( + id=f"chat{last_output['id']}", + choices=[choice], + created=last_output["created"], + model=last_output["model"], + object="chat.completion", + usage=usage, + ) + + +def _generate(params: Dict[str, Any]) -> Iterator: + """ + Generates text based on the given parameters. + + Args: + params (Dict[str, Any]): A dictionary containing the parameters for text generation. + + Yields: + Iterator: A dictionary containing the generated text and error code. + """ + messages = params.get("messages") + inputs, prompt = _apply_chat_template( + messages, + max_new_tokens=params.get("max_tokens", 256), + functions=params.get("functions"), + tools=params.get("tools"), + ) + + params.update(dict(inputs=inputs, prompt=prompt)) + + try: + for output in _generate_stream_func(params): + output["error_code"] = 0 + yield output + + except (ValueError, RuntimeError) as e: + traceback.print_exc() + yield { + "text": f"{e}", + "error_code": ErrorCode.INTERNAL_ERROR, + } + + +def _apply_chat_template( + messages: List[ChatCompletionMessageParam], + max_new_tokens: Optional[int] = 256, + functions: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, + tools: Optional[List[Dict[str, Any]]] = None, +) -> Tuple[Union[List[int], Dict[str, Any]], Optional[str]]: + """ + Apply chat template to generate model inputs and prompt. + + Args: + messages (List[ChatCompletionMessageParam]): List of chat completion message parameters. + max_new_tokens (Optional[int], optional): Maximum number of new tokens to generate. Defaults to 256. + functions (Optional[Union[Dict[str, Any], List[Dict[str, Any]]]], optional): Functions to apply to the messages. Defaults to None. + tools (Optional[List[Dict[str, Any]]], optional): Tools to apply to the messages. Defaults to None. + **kwargs: Additional keyword arguments. + + Returns: + Tuple[Union[List[int], Dict[str, Any]], Union[str, None]]: Tuple containing the generated inputs and prompt. 
+ """ + if template.function_call_available: + messages = template.postprocess_messages( + messages, functions, tools=tools, + ) + if functions or tools: + logger.debug(f"==== Messages with tools ====\n{messages}") + + inputs = build_qwen_chat_input( + tokenizer, messages, context_len, max_new_tokens, functions, tools + ) + return inputs, None + + +@torch.inference_mode() +def _generate_stream_func( + params: Dict[str, Any], +): + # Read parameters + input_ids = params.get("inputs") + prompt = params.get("prompt") + model_name = params.get("model", "llm") + temperature = float(params.get("temperature", 1.0)) + repetition_penalty = float(params.get("repetition_penalty", 1.0)) + top_p = float(params.get("top_p", 1.0)) + top_k = int(params.get("top_k", -1)) # -1 means disable + max_new_tokens = int(params.get("max_tokens", 256)) + logprobs = params.get("logprobs") + echo = bool(params.get("echo", True)) + stop_str = params.get("stop") + + stop_token_ids = params.get("stop_token_ids") or [] + if tokenizer.eos_token_id not in stop_token_ids: + stop_token_ids.append(tokenizer.eos_token_id) + + logits_processor = prepare_logits_processor( + temperature, repetition_penalty, top_p, top_k + ) + + output_ids = list(input_ids) + input_echo_len = len(input_ids) + + device = model.device + if model.config.is_encoder_decoder: + encoder_output = model.encoder( + input_ids=torch.as_tensor([input_ids], device=device) + )[0] + start_ids = torch.as_tensor( + [[model.generation_config.decoder_start_token_id]], + dtype=torch.int64, + device=device, + ) + else: + start_ids = torch.as_tensor([input_ids], device=device) + + past_key_values, sent_interrupt = None, False + token_logprobs = [None] # The first token has no logprobs. + completion_id: str = f"cmpl-{str(uuid.uuid4())}" + created: int = int(time.time()) + previous_text = "" + for i in range(max_new_tokens): + if i == 0: # prefill + if model.config.is_encoder_decoder: + out = model.decoder( + input_ids=start_ids, + encoder_hidden_states=encoder_output, + use_cache=True, + ) + logits = model.lm_head(out[0]) + else: + out = model(torch.as_tensor([input_ids], device=device), use_cache=True) + logits = out.logits + past_key_values = out.past_key_values + + if logprobs is not None: + # Prefull logprobs for the prompt. 
+ shift_input_ids = start_ids[..., 1:].contiguous() + shift_logits = logits[..., :-1, :].contiguous() + shift_logits = torch.log_softmax(shift_logits, dim=-1).tolist() + for label_id, logit in zip( + shift_input_ids[0].tolist(), shift_logits[0] + ): + token_logprobs.append(logit[label_id]) + + else: # decoding + if model.config.is_encoder_decoder: + out = model.decoder( + input_ids=torch.as_tensor( + [output_ids if sent_interrupt else [token]], device=device + ), + encoder_hidden_states=encoder_output, + use_cache=True, + past_key_values=None if sent_interrupt else past_key_values, + ) + sent_interrupt = False + + logits = model.lm_head(out[0]) + else: + out = model( + input_ids=torch.as_tensor( + [output_ids if sent_interrupt else [token]], device=device + ), + use_cache=True, + past_key_values=None if sent_interrupt else past_key_values, + ) + sent_interrupt = False + logits = out.logits + past_key_values = out.past_key_values + + if logits_processor: + if repetition_penalty > 1.0: + tmp_output_ids = torch.as_tensor([output_ids], device=logits.device) + else: + tmp_output_ids = None + last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0] + else: + last_token_logits = logits[0, -1, :] + + if device == "mps": + # Switch to CPU by avoiding some bugs in mps backend. + last_token_logits = last_token_logits.float().to("cpu") + + if temperature < 1e-5 or top_p < 1e-8: # greedy + _, indices = torch.topk(last_token_logits, 2) + tokens = [int(index) for index in indices.tolist()] + else: + probs = torch.softmax(last_token_logits, dim=-1) + indices = torch.multinomial(probs, num_samples=2) + tokens = [int(token) for token in indices.tolist()] + + token = tokens[0] + output_ids.append(token) + + if logprobs is not None: + # Cannot use last_token_logits because logprobs is based on raw logits. 
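+            # i.e. take log-softmax over the raw logits and record the sampled token's log-probability.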
+ token_logprobs.append( + torch.log_softmax(logits[0, -1, :], dim=-1)[token].tolist() + ) + + if token in stop_token_ids: + stopped = True + else: + stopped = False + + # Yield the output tokens + if i % 2 == 0 or i == max_new_tokens - 1 or stopped: + if echo: + tmp_output_ids = output_ids + rfind_start = len(prompt) + else: + tmp_output_ids = output_ids[input_echo_len:] + rfind_start = 0 + + output = tokenizer.decode( + tmp_output_ids, + skip_special_tokens=False, # fix for qwen react + spaces_between_special_tokens=False, + clean_up_tokenization_spaces=True, + ) + + ret_logprobs = None + if logprobs is not None: + ret_logprobs = { + "text_offset": [], + "tokens": [ + tokenizer.decode(token) + for token in ( + output_ids if echo else output_ids[input_echo_len:] + ) + ], + "token_logprobs": token_logprobs if echo else token_logprobs[input_echo_len:], + "top_logprobs": [{}] * len(token_logprobs if echo else token_logprobs[input_echo_len:]), + } + # Compute text_offset + curr_pos = 0 + for text in ret_logprobs["tokens"]: + ret_logprobs["text_offset"].append(curr_pos) + curr_pos += len(text) + + partially_stopped, finish_reason = False, None + if stop_str: + if isinstance(stop_str, str): + pos = output.rfind(stop_str, rfind_start) + if pos != -1: + output = output[:pos] + stopped = True + else: + partially_stopped = is_partial_stop(output, stop_str) + elif isinstance(stop_str, Iterable): + for each_stop in stop_str: + pos = output.rfind(each_stop, rfind_start) + if pos != -1: + output = output[:pos] + stopped = True + if each_stop == "Observation:": + finish_reason = "function_call" + break + else: + partially_stopped = is_partial_stop(output, each_stop) + if partially_stopped: + break + else: + raise ValueError("Invalid stop field type.") + + # Prevent yielding partial stop sequence + if (not partially_stopped) and output and output[-1] != "�": + delta_text = output[len(previous_text):] + previous_text = output + + yield { + "id": completion_id, + "object": "text_completion", + "created": created, + "model": model_name, + "delta": delta_text, + "text": output, + "logprobs": ret_logprobs, + "finish_reason": finish_reason, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + } + + if stopped: + break + + yield { + "id": completion_id, + "object": "text_completion", + "created": created, + "model": model_name, + "delta": "", + "text": output, + "logprobs": ret_logprobs, + "finish_reason": "stop", + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + } + + # Clean + del past_key_values, out + gc.collect() + torch.cuda.empty_cache() + + +class QwenTemplate(ABC): + + name = "qwen" + system_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + allow_models = ["qwen"] + stop = { + "token_ids": [151643, 151644, 151645], # "<|endoftext|>", "<|im_start|>", "<|im_end|>" + "strings": ["<|endoftext|>", "<|im_end|>"], + } + function_call_available = True + + @property + def template(self) -> str: + """ This template formats inputs in the standard ChatML format. 
See + https://github.com/openai/openai-python/blob/main/chatml.md + """ + return ( + "{{ system_prompt }}" + "{% for message in messages %}" + "{{ '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}" + "{% endfor %}" + "{% if add_generation_prompt %}" + "{{ '<|im_start|>assistant\\n' }}" + "{% endif %}" + ) + + def postprocess_messages( + self, + messages: List[ChatCompletionMessageParam], + functions: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, + tools: Optional[List[Dict[str, Any]]] = None, + ) -> List[Dict[str, Any]]: + return messages + + def parse_assistant_response( + self, + output: str, + functions: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, + tools: Optional[List[Dict[str, Any]]] = None, + ) -> Tuple[str, Optional[Union[str, Dict[str, Any]]]]: + func_name, func_args = "", "" + i = output.rfind("\nAction:") + j = output.rfind("\nAction Input:") + k = output.rfind("\nObservation:") + + if 0 <= i < j: # If the text has `Action` and `Action input`, + if k < j: # but does not contain `Observation`, + # then it is likely that `Observation` is omitted by the LLM, + # because the output text may have discarded the stop word. + output = output.rstrip() + "\nObservation:" # Add it back. + k = output.rfind("\nObservation:") + func_name = output[i + len("\nAction:"): j].strip() + func_args = output[j + len("\nAction Input:"): k].strip() + + if func_name: + if functions: + function_call = { + "name": func_name, + "arguments": func_args + } + else: + function_call = { + "function": { + "name": func_name, + "arguments": func_args + }, + "id": func_name, + "type": "function", + } + return output[:k], function_call + + z = output.rfind("\nFinal Answer: ") + if z >= 0: + output = output[z + len("\nFinal Answer: "):] + return output, None + + +async def handle_request( + request: Union[CompletionCreateParams, ChatCompletionCreateParams], + stop: Dict[str, Any] = None +) -> Union[Union[CompletionCreateParams, ChatCompletionCreateParams], JSONResponse]: + error_check_ret = check_requests(request) + if error_check_ret is not None: + raise error_check_ret + + # stop settings + _stop, _stop_token_ids = [], [] + if stop is not None: + _stop_token_ids = stop.get("token_ids", []) + _stop = stop.get("strings", []) + + request.stop = request.stop or [] + if isinstance(request.stop, str): + request.stop = [request.stop] + + if request.functions: + request.stop.append("Observation:") + + request.stop = list(set(_stop + request.stop)) + request.stop_token_ids = request.stop_token_ids or [] + request.stop_token_ids = list(set(_stop_token_ids + request.stop_token_ids)) + + return request + + +def check_requests(request: Union[CompletionCreateParams, ChatCompletionCreateParams]) -> Optional[JSONResponse]: + # Check all params + if request.max_tokens is not None and request.max_tokens <= 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.max_tokens} is less than the minimum of 1 - 'max_tokens'", + ) + if request.n is not None and request.n <= 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.n} is less than the minimum of 1 - 'n'", + ) + if request.temperature is not None and request.temperature < 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.temperature} is less than the minimum of 0 - 'temperature'", + ) + if request.temperature is not None and request.temperature > 2: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + 
f"{request.temperature} is greater than the maximum of 2 - 'temperature'", + ) + if request.top_p is not None and request.top_p < 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.top_p} is less than the minimum of 0 - 'top_p'", + ) + if request.top_p is not None and request.top_p > 1: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.top_p} is greater than the maximum of 1 - 'temperature'", + ) + if request.stop is None or isinstance(request.stop, (str, list)): + return None + else: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.stop} is not valid under any of the given schemas - 'stop'", + ) + + +def create_error_response(code: int, message: str) -> JSONResponse: + return JSONResponse(model_dump(ErrorResponse(message=message, code=code)), status_code=500) + + +async def get_event_publisher( + request: Request, + inner_send_chan: MemoryObjectSendStream, + iterator: Union[Iterator, AsyncIterator], +): + async with inner_send_chan: + try: + async for chunk in iterate_in_threadpool(iterator): + if isinstance(chunk, BaseModel): + chunk = model_json(chunk) + elif isinstance(chunk, dict): + chunk = json.dumps(chunk, ensure_ascii=False) + + await inner_send_chan.send(dict(data=chunk)) + + if await request.is_disconnected(): + raise anyio.get_cancelled_exc_class()() + + if llama_outer_lock.locked(): + await inner_send_chan.send(dict(data="[DONE]")) + raise anyio.get_cancelled_exc_class()() + except anyio.get_cancelled_exc_class() as e: + logger.info("disconnected") + with anyio.move_on_after(1, shield=True): + logger.info(f"Disconnected from client (via refresh/close) {request.client}") + raise e + + +def _get_args(): + parser = ArgumentParser() + parser.add_argument( + "-c", + "--checkpoint-path", + type=str, + default="Qwen/Qwen-7B-Chat", + help="Checkpoint name or path, default to %(default)r", + ) + parser.add_argument( + "--cpu-only", action="store_true", help="Run demo with CPU only" + ) + parser.add_argument( + "--server-port", type=int, default=8000, help="Demo server port." + ) + parser.add_argument( + "--server-name", + type=str, + default="127.0.0.1", + help="Demo server name. Default: 127.0.0.1, which is only visible from the local computer." + " If you want other computers to access your server, use 0.0.0.0 instead.", + ) + parser.add_argument( + "--context_len", type=int, default=None, help="Context length for generating completions." 
+ ) + parser.add_argument("--disable-gc", action="store_true", + help="Disable GC after each response generated.") + + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = _get_args() + + tokenizer = AutoTokenizer.from_pretrained( + args.checkpoint_path, + trust_remote_code=True, + resume_download=True, + ) + + if args.cpu_only: + device = "cpu" + else: + device = "cuda" + + model = AutoModelForCausalLM.from_pretrained( + args.checkpoint_path, + trust_remote_code=True, + resume_download=True, + ).to(device).eval() + + # Multi-GPU support, use the following two lines instead of the above line, num gpus to your actual number of graphics cards + # from utils import load_model_on_gpus + # model = load_model_on_gpus(args.checkpoint_path, num_gpus=2) + + model.generation_config = GenerationConfig.from_pretrained( + args.checkpoint_path, + trust_remote_code=True, + resume_download=True, + ) + + context_len = 8192 if args.context_len is None else args.context_len + template = QwenTemplate() + + uvicorn.run(app, host=args.server_name, port=args.server_port, workers=1) \ No newline at end of file diff --git a/examples/openai_api_demo/openai_api_request.py b/examples/openai_api_demo/openai_api_request.py new file mode 100644 index 00000000..081e7849 --- /dev/null +++ b/examples/openai_api_demo/openai_api_request.py @@ -0,0 +1,40 @@ +from openai import OpenAI + +client = OpenAI( + api_key="EMPTY", + base_url="http://localhost:8000/v1/", +) + + +# List models API +models = client.models.list() +print(models.model_dump()) + + +# Chat completion API +chat_completion = client.chat.completions.create( + messages=[ + { + "role": "user", + "content": "你好,请问你是谁?", + } + ], + model="qwen", +) +print(chat_completion) + + +# Stream +stream = client.chat.completions.create( + messages=[ + { + "role": "user", + "content": "你好,请问你是谁?", + } + ], + model="qwen", + stream=True, +) +for part in stream: + print(part.choices[0].delta.content or "", end="", flush=True) + diff --git a/examples/openai_api_demo/openai_utils.py b/examples/openai_api_demo/openai_utils.py new file mode 100644 index 00000000..fe30d554 --- /dev/null +++ b/examples/openai_api_demo/openai_utils.py @@ -0,0 +1,879 @@ + +import pydantic +import json +import re + +from copy import deepcopy +from enum import Enum, IntEnum +from pydantic import BaseModel +from loguru import logger +from typing import ( + Dict, + List, + Any, + Literal, + Optional, + Union, + cast, + Type, + Tuple +) + +from transformers import PreTrainedTokenizer +from transformers.generation.logits_process import ( + LogitsProcessorList, + RepetitionPenaltyLogitsProcessor, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, +) + +from fastapi import HTTPException + +from openai.types.model import Model +from openai.types.chat import ( + ChatCompletionMessageParam, + ChatCompletionToolChoiceOptionParam, + ChatCompletionUserMessageParam, + ChatCompletionAssistantMessageParam, +) +from openai.types.chat.chat_completion_message import FunctionCall +from openai.types.chat.completion_create_params import ResponseFormat +from openai.types.create_embedding_response import Usage + + +TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}""" + +REACT_INSTRUCTION = """Answer the following questions as best you can. 
You have access to the following APIs: + +{tools_text} + +Use the following format: + +Question: the input question you must answer +Thought: you should always think about what to do +Action: the action to take, should be one of [{tools_name_text}] +Action Input: the input to the action +Observation: the result of the action +... (this Thought/Action/Action Input/Observation can be repeated zero or more times) +Thought: I now know the final answer +Final Answer: the final answer to the original input question + +Begin!""" + +_TEXT_COMPLETION_CMD = object() + +# --------------- Pydantic v2 compatibility --------------- + +PYDANTIC_V2 = pydantic.VERSION.startswith("2.") + + +def model_json(model: pydantic.BaseModel, **kwargs) -> str: + if PYDANTIC_V2: + return model.model_dump_json(**kwargs) + return model.json(**kwargs) # type: ignore + + +def model_dump(model: pydantic.BaseModel, **kwargs) -> Dict[str, Any]: + if PYDANTIC_V2: + return model.model_dump(**kwargs) + return cast( + "dict[str, Any]", + model.dict(**kwargs), + ) + + +def model_parse(model: Type[pydantic.BaseModel], data: Any) -> pydantic.BaseModel: + if PYDANTIC_V2: + return model.model_validate(data) + return model.parse_obj(data) # pyright: ignore[reportDeprecated] + + +def disable_warnings(model: Type[pydantic.BaseModel]): + # Disable warning for model_name settings + if PYDANTIC_V2: + model.model_config["protected_namespaces"] = () + +def parse_messages( + messages: List[ChatCompletionMessageParam], split_role="user" +) -> Tuple[str, List[List[ChatCompletionMessageParam]]]: + """ + Parse a list of chat completion messages into system and rounds. + + Args: + messages (List[ChatCompletionMessageParam]): The list of chat completion messages. + split_role: The role at which to split the rounds. Defaults to Role.USER. + + Returns: + Tuple[str, List[List[ChatCompletionMessageParam]]]: A tuple containing the system message and a list of rounds. + """ + system, rounds = "", [] + r = [] + for i, message in enumerate(messages): + if message["role"] == "system": + system = message["content"] + continue + if message["role"] == split_role and r: + rounds.append(r) + r = [] + r.append(message) + if r: + rounds.append(r) + return system, rounds + + +class Role(str, Enum): + USER = "user" + ASSISTANT = "assistant" + SYSTEM = "system" + FUNCTION = "function" + TOOL = "tool" + + +class ErrorResponse(BaseModel): + object: str = "error" + message: str + code: int + + +class ErrorCode(IntEnum): + """ + https://platform.openai.com/docs/guides/error-codes/api-errors + """ + + VALIDATION_TYPE_ERROR = 40001 + + INVALID_AUTH_KEY = 40101 + INCORRECT_AUTH_KEY = 40102 + NO_PERMISSION = 40103 + + INVALID_MODEL = 40301 + PARAM_OUT_OF_RANGE = 40302 + CONTEXT_OVERFLOW = 40303 + + RATE_LIMIT = 42901 + QUOTA_EXCEEDED = 42902 + ENGINE_OVERLOADED = 42903 + + INTERNAL_ERROR = 50001 + CUDA_OUT_OF_MEMORY = 50002 + GRADIO_REQUEST_ERROR = 50003 + GRADIO_STREAM_UNKNOWN_ERROR = 50004 + CONTROLLER_NO_WORKER = 50005 + CONTROLLER_WORKER_TIMEOUT = 50006 + + +class ModelList(BaseModel): + object: str = "list" + data: List[Model] = [] + + +class ChatCompletionCreateParams(BaseModel): + messages: List[ChatCompletionMessageParam] + """A list of messages comprising the conversation so far. + + [Example Python code](https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models). + """ + + model: str + """ID of the model to use. 
+ + See the + [model endpoint compatibility](https://platform.openai.com/docs/models/model-endpoint-compatibility) + table for details on which models work with the Chat API. + """ + + frequency_penalty: Optional[float] = 0. + """Number between -2.0 and 2.0. + + Positive values penalize new tokens based on their existing frequency in the + text so far, decreasing the model's likelihood to repeat the same line verbatim. + + [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details) + """ + + function_call: Optional[FunctionCall] = None + """Deprecated in favor of `tool_choice`. + + Controls which (if any) function is called by the model. `none` means the model + will not call a function and instead generates a message. `auto` means the model + can pick between generating a message or calling a function. Specifying a + particular function via `{"name": "my_function"}` forces the model to call that + function. + + `none` is the default when no functions are present. `auto`` is the default if + functions are present. + """ + + functions: Optional[List] = None + """Deprecated in favor of `tools`. + + A list of functions the model may generate JSON inputs for. + """ + + logit_bias: Optional[Dict[str, int]] = None + """Modify the likelihood of specified tokens appearing in the completion. + + Accepts a JSON object that maps tokens (specified by their token ID in the + tokenizer) to an associated bias value from -100 to 100. Mathematically, the + bias is added to the logits generated by the model prior to sampling. The exact + effect will vary per model, but values between -1 and 1 should decrease or + increase likelihood of selection; values like -100 or 100 should result in a ban + or exclusive selection of the relevant token. + """ + + max_tokens: Optional[int] = None + """The maximum number of [tokens](/tokenizer) to generate in the chat completion. + + The total length of input tokens and generated tokens is limited by the model's + context length. + [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) + for counting tokens. + """ + + n: Optional[int] = 1 + """How many chat completion choices to generate for each input message.""" + + presence_penalty: Optional[float] = 0. + """Number between -2.0 and 2.0. + + Positive values penalize new tokens based on whether they appear in the text so + far, increasing the model's likelihood to talk about new topics. + + [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details) + """ + + response_format: Optional[ResponseFormat] = None + """An object specifying the format that the model must output. + + Used to enable JSON mode. + """ + + seed: Optional[int] = None + """This feature is in Beta. + + If specified, our system will make a best effort to sample deterministically, + such that repeated requests with the same `seed` and parameters should return + the same result. Determinism is not guaranteed, and you should refer to the + `system_fingerprint` response parameter to monitor changes in the backend. + """ + + stop: Optional[Union[str, List[str]]] = None + """Up to 4 sequences where the API will stop generating further tokens.""" + + temperature: Optional[float] = 0.9 + """What sampling temperature to use, between 0 and 2. + + Higher values like 0.8 will make the output more random, while lower values like + 0.2 will make it more focused and deterministic. 
+ + We generally recommend altering this or `top_p` but not both. + """ + + tool_choice: Optional[ChatCompletionToolChoiceOptionParam] = None + """ + Controls which (if any) function is called by the model. `none` means the model + will not call a function and instead generates a message. `auto` means the model + can pick between generating a message or calling a function. Specifying a + particular function via + `{"type: "function", "function": {"name": "my_function"}}` forces the model to + call that function. + + `none` is the default when no functions are present. `auto` is the default if + functions are present. + """ + + tools: Optional[List] = None + """A list of tools the model may call. + + Currently, only functions are supported as a tool. Use this to provide a list of + functions the model may generate JSON inputs for. + """ + + top_p: Optional[float] = 1.0 + """ + An alternative to sampling with temperature, called nucleus sampling, where the + model considers the results of the tokens with top_p probability mass. So 0.1 + means only the tokens comprising the top 10% probability mass are considered. + + We generally recommend altering this or `temperature` but not both. + """ + + user: Optional[str] = None + """ + A unique identifier representing your end-user, which can help OpenAI to monitor + and detect abuse. + [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids). + """ + + stream: Optional[bool] = False + """If set, partial message deltas will be sent, like in ChatGPT. + + Tokens will be sent as data-only + [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) + as they become available, with the stream terminated by a `data: [DONE]` + message. + [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions). + """ + + # Addictional parameters + repetition_penalty: Optional[float] = 1.03 + """The parameter for repetition penalty. 1.0 means no penalty. + See[this paper](https://arxiv.org / pdf / 1909.05858.pdf) for more details. + """ + + typical_p: Optional[float] = None + """Typical Decoding mass. + See[Typical Decoding for Natural Language Generation](https://arxiv.org / abs / 2202.00666) for more information + """ + + watermark: Optional[bool] = False + """Watermarking with [A Watermark for Large Language Models](https://arxiv.org / abs / 2301.10226) + """ + + best_of: Optional[int] = 1 + + ignore_eos: Optional[bool] = False + + use_beam_search: Optional[bool] = False + + stop_token_ids: Optional[List[int]] = None + + skip_special_tokens: Optional[bool] = True + + spaces_between_special_tokens: Optional[bool] = True + + min_p: Optional[float] = 0.0 + + +class CompletionCreateParams(BaseModel): + model: str + """ID of the model to use. + + You can use the + [List models](https://platform.openai.com/docs/api-reference/models/list) API to + see all of your available models, or see our + [Model overview](https://platform.openai.com/docs/models/overview) for + descriptions of them. + """ + + prompt: Union[str, List[str], List[int], List[List[int]], None] + """ + The prompt(s) to generate completions for, encoded as a string, array of + strings, array of tokens, or array of token arrays. + + Note that <|endoftext|> is the document separator that the model sees during + training, so if a prompt is not specified the model will generate as if from the + beginning of a new document. 
+ """ + + best_of: Optional[int] = 1 + """ + Generates `best_of` completions server-side and returns the "best" (the one with + the highest log probability per token). Results cannot be streamed. + + When used with `n`, `best_of` controls the number of candidate completions and + `n` specifies how many to return – `best_of` must be greater than `n`. + + **Note:** Because this parameter generates many completions, it can quickly + consume your token quota. Use carefully and ensure that you have reasonable + settings for `max_tokens` and `stop`. + """ + + echo: Optional[bool] = False + """Echo back the prompt in addition to the completion""" + + frequency_penalty: Optional[float] = 0. + """Number between -2.0 and 2.0. + + Positive values penalize new tokens based on their existing frequency in the + text so far, decreasing the model's likelihood to repeat the same line verbatim. + + [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details) + """ + + logit_bias: Optional[Dict[str, int]] = None + """Modify the likelihood of specified tokens appearing in the completion. + + Accepts a JSON object that maps tokens (specified by their token ID in the GPT + tokenizer) to an associated bias value from -100 to 100. You can use this + [tokenizer tool](/tokenizer?view=bpe) (which works for both GPT-2 and GPT-3) to + convert text to token IDs. Mathematically, the bias is added to the logits + generated by the model prior to sampling. The exact effect will vary per model, + but values between -1 and 1 should decrease or increase likelihood of selection; + values like -100 or 100 should result in a ban or exclusive selection of the + relevant token. + + As an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token + from being generated. + """ + + logprobs: Optional[int] = None + """ + Include the log probabilities on the `logprobs` most likely tokens, as well the + chosen tokens. For example, if `logprobs` is 5, the API will return a list of + the 5 most likely tokens. The API will always return the `logprob` of the + sampled token, so there may be up to `logprobs+1` elements in the response. + + The maximum value for `logprobs` is 5. + """ + + max_tokens: Optional[int] = 16 + """The maximum number of [tokens](/tokenizer) to generate in the completion. + + The token count of your prompt plus `max_tokens` cannot exceed the model's + context length. + [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) + for counting tokens. + """ + + n: Optional[int] = 1 + """How many completions to generate for each prompt. + + **Note:** Because this parameter generates many completions, it can quickly + consume your token quota. Use carefully and ensure that you have reasonable + settings for `max_tokens` and `stop`. + """ + + presence_penalty: Optional[float] = 0. + """Number between -2.0 and 2.0. + + Positive values penalize new tokens based on whether they appear in the text so + far, increasing the model's likelihood to talk about new topics. + + [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details) + """ + + seed: Optional[int] = None + """ + If specified, our system will make a best effort to sample deterministically, + such that repeated requests with the same `seed` and parameters should return + the same result. 
+ + Determinism is not guaranteed, and you should refer to the `system_fingerprint` + response parameter to monitor changes in the backend. + """ + + stop: Optional[Union[str, List[str]]] = None + """Up to 4 sequences where the API will stop generating further tokens. + + The returned text will not contain the stop sequence. + """ + + suffix: Optional[str] = None + """The suffix that comes after a completion of inserted text.""" + + temperature: Optional[float] = 1. + """What sampling temperature to use, between 0 and 2. + + Higher values like 0.8 will make the output more random, while lower values like + 0.2 will make it more focused and deterministic. + + We generally recommend altering this or `top_p` but not both. + """ + + top_p: Optional[float] = 1. + """ + An alternative to sampling with temperature, called nucleus sampling, where the + model considers the results of the tokens with top_p probability mass. So 0.1 + means only the tokens comprising the top 10% probability mass are considered. + + We generally recommend altering this or `temperature` but not both. + """ + + user: Optional[str] = None + """ + A unique identifier representing your end-user, which can help OpenAI to monitor + and detect abuse. + [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids). + """ + + stream: Optional[bool] = False + """If set, partial message deltas will be sent, like in ChatGPT. + + Tokens will be sent as data-only + [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) + as they become available, with the stream terminated by a `data: [DONE]` + message. + [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions). + """ + + # Addictional parameters + repetition_penalty: Optional[float] = 1.03 + """The parameter for repetition penalty. 1.0 means no penalty. + See[this paper](https://arxiv.org / pdf / 1909.05858.pdf) for more details. + """ + + typical_p: Optional[float] = None + """Typical Decoding mass. + See[Typical Decoding for Natural Language Generation](https://arxiv.org / abs / 2202.00666) for more information + """ + + watermark: Optional[bool] = False + """Watermarking with [A Watermark for Large Language Models](https://arxiv.org / abs / 2301.10226) + """ + + ignore_eos: Optional[bool] = False + + use_beam_search: Optional[bool] = False + + stop_token_ids: Optional[List[int]] = None + + skip_special_tokens: Optional[bool] = True + + spaces_between_special_tokens: Optional[bool] = True + + min_p: Optional[float] = 0.0 + + +class EmbeddingCreateParams(BaseModel): + input: Union[str, List[str], List[int], List[List[int]]] + """Input text to embed, encoded as a string or array of tokens. + + To embed multiple inputs in a single request, pass an array of strings or array + of token arrays. The input must not exceed the max input tokens for the model + (8192 tokens for `text-embedding-ada-002`) and cannot be an empty string. + [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) + for counting tokens. + """ + + model: str + """ID of the model to use. + + You can use the + [List models](https://platform.openai.com/docs/api-reference/models/list) API to + see all of your available models, or see our + [Model overview](https://platform.openai.com/docs/models/overview) for + descriptions of them. + """ + + encoding_format: Literal["float", "base64"] = "float" + """The format to return the embeddings in. 
+ + Can be either `float` or [`base64`](https://pypi.org/project/pybase64/). + """ + + user: Optional[str] = None + """ + A unique identifier representing your end-user, which can help OpenAI to monitor + and detect abuse. + [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids). + """ + + +class Embedding(BaseModel): + embedding: Any + """The embedding vector, which is a list of floats. + + The length of vector depends on the model as listed in the + [embedding guide](https://platform.openai.com/docs/guides/embeddings). + """ + + index: int + """The index of the embedding in the list of embeddings.""" + + object: Literal["embedding"] + """The object type, which is always "embedding".""" + + +class CreateEmbeddingResponse(BaseModel): + data: List[Embedding] + """The list of embeddings generated by the model.""" + + model: str + """The name of the model used to generate the embedding.""" + + object: Literal["list"] + """The object type, which is always "list".""" + + usage: Usage + """The usage information for the request.""" + + +def build_qwen_chat_input( + tokenizer: PreTrainedTokenizer, + messages: List[ChatCompletionMessageParam], + context_len: int = 8192, + max_new_tokens: int = 256, + functions: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, + tools: Optional[List[Dict[str, Any]]] = None, +) -> List[int]: + """ + Builds the input tokens for Qwen chat generation. + + Refs: + https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/qwen_generation_utils.py + + Args: + tokenizer: The tokenizer used to encode the input tokens. + messages: The list of chat messages. + context_len: The maximum length of the context. + max_new_tokens: The maximum number of new tokens to add. + functions: Optional dictionary or list of dictionaries representing the functions. + tools: Optional list of dictionaries representing the tools. + + Returns: + The list of input tokens. 
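+        When functions or tools are supplied, they are rendered as ReAct-style instructions and
+        prepended to the final user turn by `process_qwen_messages`.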
+ """ + query, history = process_qwen_messages(messages, functions, tools) + if query is _TEXT_COMPLETION_CMD: + return build_last_message_input(tokenizer, history) + + messages = [] + for q, r in history: + messages.extend( + [ + ChatCompletionUserMessageParam(role="user", content=q), + ChatCompletionAssistantMessageParam(role="assistant", content=r) + ] + ) + messages.append(ChatCompletionUserMessageParam(role="user", content=query)) + + max_input_tokens = context_len - max_new_tokens + system, rounds = parse_messages(messages) + system = f"You are a helpful assistant.{system}" + + im_start_tokens, im_end_tokens = [tokenizer.im_start_id], [tokenizer.im_end_id] + nl_tokens = tokenizer.encode("\n") + + def _tokenize_str(role, content): + return tokenizer.encode( + role, allowed_special=set() + ) + nl_tokens + tokenizer.encode(content, allowed_special=set()) + + system_tokens_part = _tokenize_str("system", system) + system_tokens = im_start_tokens + system_tokens_part + im_end_tokens + max_history_tokens = max_input_tokens - len(system_tokens) + + history_tokens = [] + for r in rounds[::-1]: + round_tokens = [] + for message in r: + if round_tokens: + round_tokens += nl_tokens + + if message["role"] == Role.USER: + content_tokens = im_start_tokens + _tokenize_str("user", message["content"]) + im_end_tokens + else: + content_tokens = im_start_tokens + _tokenize_str("assistant", message["content"]) + im_end_tokens + + round_tokens.extend(content_tokens) + + if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens: + if history_tokens: + history_tokens = nl_tokens + history_tokens + + history_tokens = round_tokens + history_tokens # concat left + if len(history_tokens) < max_history_tokens: + continue + break + + input_tokens = system_tokens + nl_tokens + history_tokens + if messages[-1]["role"] != Role.ASSISTANT: + input_tokens += nl_tokens + im_start_tokens + tokenizer.encode("assistant") + nl_tokens + return input_tokens[-max_input_tokens:] # truncate left + + + +def process_qwen_messages( + messages: List[ChatCompletionMessageParam], + functions: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, + tools: Optional[List[Dict[str, Any]]] = None, +) -> Tuple[str, List[List[str]]]: + """ + Process the Qwen messages and generate a query and history. + + Args: + messages (List[ChatCompletionMessageParam]): The list of chat completion messages. + functions (Optional[Union[Dict[str, Any], List[Dict[str, Any]]]]): The functions to be used. + tools (Optional[List[Dict[str, Any]]]): The tools to be used. + + Returns: + Tuple[str, List[List[str]]]: The generated query and history. + """ + if all(m["role"] != Role.USER for m in messages): + raise HTTPException( + status_code=400, + detail=f"Invalid request: Expecting at least one user message.", + ) + + messages = deepcopy(messages) + default_system = "You are a helpful assistant." 
+ system = "" + if messages[0]["role"] == Role.SYSTEM: + system = messages.pop(0)["content"].lstrip("\n").rstrip() + if system == default_system: + system = "" + + if tools: + functions = [t["function"] for t in tools] + + if functions: + tools_text = [] + tools_name_text = [] + for func_info in functions: + name = func_info.get("name", "") + name_m = func_info.get("name_for_model", name) + name_h = func_info.get("name_for_human", name) + desc = func_info.get("description", "") + desc_m = func_info.get("description_for_model", desc) + tool = TOOL_DESC.format( + name_for_model=name_m, + name_for_human=name_h, + # Hint: You can add the following format requirements in description: + # "Format the arguments as a JSON object." + # "Enclose the code within triple backticks (`) at the beginning and end of the code." + description_for_model=desc_m, + parameters=json.dumps(func_info["parameters"], ensure_ascii=False), + ) + + tools_text.append(tool) + tools_name_text.append(name_m) + + tools_text = "\n\n".join(tools_text) + tools_name_text = ", ".join(tools_name_text) + system += "\n\n" + REACT_INSTRUCTION.format( + tools_text=tools_text, + tools_name_text=tools_name_text, + ) + system = system.lstrip("\n").rstrip() + + dummy_thought = { + "en": "\nThought: I now know the final answer.\nFinal answer: ", + "zh": "\nThought: 我会作答了。\nFinal answer: ", + } + + _messages = messages + messages = [] + for m_idx, m in enumerate(_messages): + role, content = m["role"], m["content"] + func_call, tools_call = m.get("function_call", None), m.get("tools_call", None) + if content: + content = content.lstrip("\n").rstrip() + if role in [Role.FUNCTION, Role.TOOL]: + if (len(messages) == 0) or (messages[-1]["role"] != Role.ASSISTANT): + raise HTTPException( + status_code=400, + detail=f"Invalid request: Expecting role assistant before role function.", + ) + messages[-1]["content"] += f"\nObservation: {content}" + if m_idx == len(_messages) - 1: + messages[-1]["content"] += "\nThought:" + elif role == Role.ASSISTANT: + if len(messages) == 0: + raise HTTPException( + status_code=400, + detail=f"Invalid request: Expecting role user before role assistant.", + ) + last_msg = messages[-1]["content"] + last_msg_has_zh = len(re.findall(r"[\u4e00-\u9fff]+", last_msg)) > 0 + + if func_call is None and tools_call is None: + if functions or tools_call: + content = dummy_thought["zh" if last_msg_has_zh else "en"] + content + else: + if func_call: + f_name, f_args = func_call.get("name"), func_call.get("arguments") + else: + f_name, f_args = tools_call[0]["function"]["name"], tools_call[0]["function"]["arguments"] + if not content: + if last_msg_has_zh: + content = f"Thought: 我可以使用 {f_name} API。" + else: + content = f"Thought: I can use {f_name}." + + if messages[-1]["role"] == Role.USER: + messages.append( + ChatCompletionAssistantMessageParam(role="assistant", content=content.lstrip("\n").rstrip()) + ) + else: + messages[-1]["content"] += content + elif role == Role.USER: + messages.append( + ChatCompletionUserMessageParam(role="user", content=content.lstrip("\n").rstrip()) + ) + else: + raise HTTPException( + status_code=400, detail=f"Invalid request: Incorrect role {role}." 
+ ) + + query = _TEXT_COMPLETION_CMD + if messages[-1]["role"] == Role.USER: + query = messages[-1]["content"] + messages = messages[:-1] + + if len(messages) % 2 != 0: + raise HTTPException(status_code=400, detail="Invalid request") + + history = [] # [(Q1, A1), (Q2, A2), ..., (Q_last_turn, A_last_turn)] + for i in range(0, len(messages), 2): + if messages[i]["role"] == Role.USER and messages[i + 1]["role"] == Role.ASSISTANT: + usr_msg = messages[i]["content"].lstrip("\n").rstrip() + bot_msg = messages[i + 1]["content"].lstrip("\n").rstrip() + if system and (i == len(messages) - 2): + usr_msg = f"{system}\n\nQuestion: {usr_msg}" + system = "" + for t in dummy_thought.values(): + t = t.lstrip("\n") + if bot_msg.startswith(t) and ("\nAction: " in bot_msg): + bot_msg = bot_msg[len(t):] + history.append([usr_msg, bot_msg]) + else: + raise HTTPException( + status_code=400, + detail="Invalid request: Expecting exactly one user (or function) role before every assistant role.", + ) + if system: + assert query is not _TEXT_COMPLETION_CMD + query = f"{system}\n\nQuestion: {query}" + return query, history + + +def build_last_message_input(tokenizer: PreTrainedTokenizer, history: list): + im_start = "<|im_start|>" + im_end = "<|im_end|>" + prompt = f"{im_start}system\nYou are a helpful assistant.{im_end}" + for i, (query, response) in enumerate(history): + query = query.lstrip("\n").rstrip() + response = response.lstrip("\n").rstrip() + prompt += f"\n{im_start}user\n{query}{im_end}" + prompt += f"\n{im_start}assistant\n{response}{im_end}" + prompt = prompt[:-len(im_end)] + logger.debug(f"==== Prompt with tools ====\n{prompt}") + return tokenizer.encode(prompt) + + +def is_partial_stop(output: str, stop_str: str): + """ Check whether the output contains a partial stop str. """ + return any( + stop_str.startswith(output[-i:]) + for i in range(0, min(len(output), len(stop_str))) + ) + + +def prepare_logits_processor( + temperature: float, repetition_penalty: float, top_p: float, top_k: int +) -> LogitsProcessorList: + """ + Prepare a list of logits processors based on the provided parameters. + + Args: + temperature (float): The temperature value for temperature warping. + repetition_penalty (float): The repetition penalty value. + top_p (float): The top-p value for top-p warping. + top_k (int): The top-k value for top-k warping. + + Returns: + LogitsProcessorList: A list of logits processors. + """ + processor_list = LogitsProcessorList() + # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op, so we skip two cases. + if temperature >= 1e-5 and temperature != 1.0: + processor_list.append(TemperatureLogitsWarper(temperature)) + if repetition_penalty > 1.0: + processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty)) + if 1e-8 <= top_p < 1.0: + processor_list.append(TopPLogitsWarper(top_p)) + if top_k > 0: + processor_list.append(TopKLogitsWarper(top_k)) + return processor_list From 6403a98bab9e44fef110b96cee5d374839a0fca3 Mon Sep 17 00:00:00 2001 From: Yimi81 <1548222878@qq.com> Date: Sun, 17 Dec 2023 17:39:21 +0800 Subject: [PATCH 2/3] 1. support openai_api.py to the latest openai sdk(>=1.0.0); 2. 
add streaming func call --- .../function_call_examples_v2.py | 307 ++++++ examples/openai_api_demo/openai_api.py | 858 +++++++++++++++++ .../openai_api_demo/openai_api_request.py | 40 + examples/openai_api_demo/openai_utils.py | 879 ++++++++++++++++++ 4 files changed, 2084 insertions(+) create mode 100644 examples/openai_api_demo/function_call_examples_v2.py create mode 100644 examples/openai_api_demo/openai_api.py create mode 100644 examples/openai_api_demo/openai_api_request.py create mode 100644 examples/openai_api_demo/openai_utils.py diff --git a/examples/openai_api_demo/function_call_examples_v2.py b/examples/openai_api_demo/function_call_examples_v2.py new file mode 100644 index 00000000..1564f9f1 --- /dev/null +++ b/examples/openai_api_demo/function_call_examples_v2.py @@ -0,0 +1,307 @@ +# Reference: https://openai.com/blog/function-calling-and-other-api-updates +import json +from openai import OpenAI + +# To start an Latest OpenAI-like Qwen server, use the following commands: +# git clone https://github.com/QwenLM/Qwen; +# cd Qwen; +# pip install fastapi uvicorn openai pydantic sse_starlette; +# python examples/openai_api_demo/openai_api.py; +# +# Then configure the api_base and api_key in your client: +client = OpenAI( + api_key="EMPTY", + base_url="http://localhost:8000/v1/", +) + + +# Change the default values of stream parameter to enable streaming +def call_qwen(messages, functions=None, stream=False): + print(messages) + if functions: + response = client.chat.completions.create( + model="Qwen", messages=messages, functions=functions, stream=stream + ) + else: + response = client.chat.completions.create(model="Qwen", messages=messages, stream=stream) + if stream: + for part in response: + print(part.choices[0].delta.content or "", end="", flush=True) + else: + # print(response) + print(response.choices[0].message.content) + return response + + +def test_1(): + messages = [{"role": "user", "content": "你好"}] + call_qwen(messages) + messages.append({"role": "assistant", "content": "你好!很高兴为你提供帮助。"}) + + messages.append({"role": "user", "content": "给我讲一个年轻人奋斗创业最终取得成功的故事。故事只能有一句话。"}) + call_qwen(messages) + messages.append( + { + "role": "assistant", + "content": "故事的主人公叫李明,他来自一个普通的家庭,父母都是普通的工人。李明想要成为一名成功的企业家。……", + } + ) + + messages.append({"role": "user", "content": "给这个故事起一个标题"}) + call_qwen(messages) + + +def test_2(): + functions = [ + { + "name_for_human": "谷歌搜索", + "name_for_model": "google_search", + "description_for_model": "谷歌搜索是一个通用搜索引擎,可用于访问互联网、查询百科知识、了解时事新闻等。" + + " Format the arguments as a JSON object.", + "parameters": [ + { + "name": "search_query", + "description": "搜索关键词或短语", + "required": True, + "schema": {"type": "string"}, + } + ], + }, + { + "name_for_human": "文生图", + "name_for_model": "image_gen", + "description_for_model": "文生图是一个AI绘画(图像生成)服务,输入文本描述,返回根据文本作画得到的图片的URL。" + + " Format the arguments as a JSON object.", + "parameters": [ + { + "name": "prompt", + "description": "英文关键词,描述了希望图像具有什么内容", + "required": True, + "schema": {"type": "string"}, + } + ], + }, + ] + + messages = [{"role": "user", "content": "你好"}] + call_qwen(messages, functions) + messages.append( + {"role": "assistant", "content": "你好!很高兴见到你。有什么我可以帮忙的吗?"}, + ) + + messages.append({"role": "user", "content": "谁是周杰伦"}) + call_qwen(messages, functions) + messages.append( + { + "role": "assistant", + "content": "Thought: 我应该使用Google搜索查找相关信息。", + "function_call": { + "name": "google_search", + "arguments": '{"search_query": "周杰伦"}', + }, + } + ) + + messages.append( + { + "role": "function", 
+ "name": "google_search", + "content": "Jay Chou is a Taiwanese singer.", + } + ) + call_qwen(messages, functions) + messages.append( + { + "role": "assistant", + "content": "周杰伦(Jay Chou)是一位来自台湾的歌手。", + }, + ) + + messages.append({"role": "user", "content": "他老婆是谁"}) + call_qwen(messages, functions) + messages.append( + { + "role": "assistant", + "content": "Thought: 我应该使用Google搜索查找相关信息。", + "function_call": { + "name": "google_search", + "arguments": '{"search_query": "周杰伦 老婆"}', + }, + } + ) + + messages.append( + {"role": "function", "name": "google_search", "content": "Hannah Quinlivan"} + ) + call_qwen(messages, functions) + messages.append( + { + "role": "assistant", + "content": "周杰伦的老婆是Hannah Quinlivan。", + }, + ) + + messages.append({"role": "user", "content": "给我画个可爱的小猫吧,最好是黑猫"}) + call_qwen(messages, functions) + messages.append( + { + "role": "assistant", + "content": "Thought: 我应该使用文生图API来生成一张可爱的小猫图片。", + "function_call": { + "name": "image_gen", + "arguments": '{"prompt": "cute black cat"}', + }, + } + ) + + messages.append( + { + "role": "function", + "name": "image_gen", + "content": '{"image_url": "https://image.pollinations.ai/prompt/cute%20black%20cat"}', + } + ) + call_qwen(messages, functions) + + +def test_3(): + functions = [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + ] + + messages = [ + { + "role": "user", + # Note: The current version of Qwen-7B-Chat (as of 2023.08) performs okay with Chinese tool-use prompts, + # but performs terribly when it comes to English tool-use prompts, due to a mistake in data collecting. + "content": "波士顿天气如何?", + } + ] + call_qwen(messages, functions) + messages.append( + { + "role": "assistant", + "content": None, + "function_call": { + "name": "get_current_weather", + "arguments": '{"location": "Boston, MA"}', + }, + }, + ) + + messages.append( + { + "role": "function", + "name": "get_current_weather", + "content": '{"temperature": "22", "unit": "celsius", "description": "Sunny"}', + } + ) + call_qwen(messages, functions) + + +def get_current_weather(location, unit="fahrenheit"): + """Get the current weather in a given location""" + if "tokyo" in location.lower(): + return json.dumps({"location": "Tokyo", "temperature": "10", "unit": unit}) + elif "san francisco" in location.lower(): + return json.dumps({"location": "San Francisco", "temperature": "72", "unit": unit}) + elif "paris" in location.lower(): + return json.dumps({"location": "Paris", "temperature": "22", "unit": unit}) + else: + return json.dumps({"location": location, "temperature": "unknown"}) + + +# Parallel function calling +def test_4(): + # Step 1: send the conversation and available functions to the model + messages = [{"role": "user", "content": "What's the weather like in San Francisco, Tokyo, and Paris?"}] + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. 
San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + }, + } + ] + response = client.chat.completions.create( + model="qwen", + messages=messages, + tools=tools, + tool_choice="auto", # auto is default, but we'll be explicit + ) + response_message = response.choices[0].message + print(f"first_response: {response_message} \n") + + tool_calls = response_message.tool_calls + # Step 2: check if the model wanted to call a function + if tool_calls: + # Step 3: call the function + # Note: the JSON response may not always be valid; be sure to handle errors + available_functions = { + "get_current_weather": get_current_weather, + } # only one function in this example, but you can have multiple + messages.append(response_message) # extend conversation with assistant's reply + # Step 4: send the info for each function call and function response to the model + for tool_call in tool_calls: + function_name = tool_call.function.name + function_to_call = available_functions[function_name] + function_args = json.loads(tool_call.function.arguments) + function_response = function_to_call( + location=function_args.get("location"), + unit=function_args.get("unit"), + ) + print(f"function_response: {function_response} \n") + + messages.append( + { + "tool_call_id": tool_call.id, + "role": "tool", + "name": function_name, + "content": function_response, + } + ) # extend conversation with function response + + print(f"second_messages: {messages} \n") + second_response = client.chat.completions.create( + model="qwen", + messages=messages + ) # get a new response from the model where it can see the function response + print(f"second_response: {second_response}") + + +if __name__ == "__main__": + print("### Test Case 1 - No Function Calling (普通问答、无函数调用) ###") + test_1() + print("### Test Case 2 - Use Qwen-Style Functions (函数调用,千问格式) ###") + test_2() + print("### Test Case 3 - Use GPT-Style Functions (函数调用,GPT格式) ###") + test_3() + # # Qwen has not optimized parallel tool calls, often unable to parse a parallel call instruction into multiple tool_calls + # print("### Test Case 4 - Parallel function calling (并行函数调用,GPT格式) ###") + # test_4() diff --git a/examples/openai_api_demo/openai_api.py b/examples/openai_api_demo/openai_api.py new file mode 100644 index 00000000..f8363c34 --- /dev/null +++ b/examples/openai_api_demo/openai_api.py @@ -0,0 +1,858 @@ +import gc +import traceback +import torch +import uvicorn +import time +import uuid +import anyio +import json +from anyio.streams.memory import MemoryObjectSendStream + +from abc import ABC +from threading import Lock +from argparse import ArgumentParser +from contextlib import asynccontextmanager +from functools import partial +from typing import Dict, List, Any, Literal, Optional, Union, Tuple, Iterator, Iterable, AsyncIterator +from loguru import logger +from starlette.concurrency import run_in_threadpool, iterate_in_threadpool +from sse_starlette import EventSourceResponse +from pydantic import BaseModel + +from fastapi import FastAPI, HTTPException, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse + +from openai.types.model import Model +from openai.types.chat.chat_completion_message import FunctionCall +from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall +from openai.types.completion_usage import CompletionUsage +from openai.types.chat.chat_completion import Choice +from 
openai.types.chat.chat_completion_chunk import Choice as ChunkChoice +from openai.types.chat.chat_completion_chunk import ( + ChoiceDelta, + ChoiceDeltaFunctionCall, + ChoiceDeltaToolCall, +) +from openai.types.chat import ( + ChatCompletionMessage, + ChatCompletion, + ChatCompletionChunk, + ChatCompletionMessageParam, + ChatCompletionToolChoiceOptionParam, +) + +from transformers import AutoTokenizer, AutoModelForCausalLM +from transformers.generation import GenerationConfig + +from openai_utils import ( + Role, + ModelList, + ChatCompletionCreateParams, + CompletionCreateParams, + ErrorCode, + ErrorResponse, + model_dump, + model_parse, + model_json, + build_qwen_chat_input, + is_partial_stop, + prepare_logits_processor) + + +llama_outer_lock = Lock() + + +@asynccontextmanager +async def lifespan(app: FastAPI): # collects GPU memory + yield + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + +app = FastAPI(lifespan=lifespan) + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +@app.get("/v1/models") +async def list_models(): + return ModelList( + data=[ + Model( + id="qwen", + object="model", + created=int(time.time()), + owned_by="open" + ) + ] +) + + +@app.post("/v1/chat/completions") +async def create_chat_completion( + request: ChatCompletionCreateParams, + raw_request: Request +): + global model, tokenizer + + if len(request.messages) < 1 or request.messages[-1]["role"] == Role.ASSISTANT: + raise HTTPException(status_code=400, detail="Invalid request") + + request = await handle_request(request, template.stop) + request.max_tokens = request.max_tokens or 1024 + + params = model_dump(request) + params.update(dict(echo=False)) + logger.debug(f"==== request ====\n{params}") + + iterator_or_completion = await run_in_threadpool(_create_chat_completion, params) + + if isinstance(iterator_or_completion, Iterator): + # It's easier to ask for forgiveness than permission + first_response = await run_in_threadpool(next, iterator_or_completion) + + # If no exception was raised from first_response, we can assume that + # the iterator is valid, and we can use it to stream the response. + def iterator() -> Iterator: + yield first_response + yield from iterator_or_completion + + send_chan, recv_chan = anyio.create_memory_object_stream(10) + return EventSourceResponse( + recv_chan, + data_sender_callable=partial( + get_event_publisher, + request=raw_request, + inner_send_chan=send_chan, + iterator=iterator(), + ), + ) + else: + return iterator_or_completion + + +def _create_chat_completion( + params: Optional[Dict[str, Any]] = None, + **kwargs, +) -> Union[Iterator, ChatCompletion]: + params = params or {} + params.update(kwargs) + return ( + _create_chat_completion_stream(params) + if params.get("stream", False) + else _create_chat_completion_non_stream(params) + ) + + +def _create_chat_completion_stream(params: Dict[str, Any]) -> Iterator: + """ + Creates a chat completion stream. + + Args: + params (Dict[str, Any]): The parameters for generating the chat completion. + + Yields: + Dict[str, Any]: The output of the chat completion stream. 
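+
+    Note: each yielded item is an OpenAI-style ChatCompletionChunk; if the
+    underlying generator reports a non-zero error_code, the raw error payload
+    dict is yielded instead. Chunks are serialized into SSE "data:" events by
+    get_event_publisher.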
+ """ + _id, _created, _model = None, None, None + has_function_call = False + for i, output in enumerate(_generate(params)): + if output["error_code"] != 0: + yield output + return + + _id, _created, _model = output["id"], output["created"], output["model"] + if i == 0: + choice = ChunkChoice( + index=0, + delta=ChoiceDelta(role="assistant", content=""), + finish_reason=None, + ) + yield ChatCompletionChunk( + id=f"chat{_id}", + choices=[choice], + created=_created, + model=_model, + object="chat.completion.chunk", + ) + + finish_reason = output["finish_reason"] + if len(output["delta"]) == 0 and finish_reason != "function_call": + continue + + function_call = None + if finish_reason == "function_call": + try: + _, function_call = template.parse_assistant_response( + output["text"], params.get("functions"), params.get("tools"), + ) + except Exception as e: + traceback.print_exc() + logger.warning("Failed to parse tool call") + + if isinstance(function_call, dict) and "arguments" in function_call: + has_function_call = True + function_call = ChoiceDeltaFunctionCall(**function_call) + delta = ChoiceDelta( + content=output["delta"], + function_call=function_call + ) + elif isinstance(function_call, dict) and "function" in function_call: + has_function_call = True + finish_reason = "tool_calls" + function_call["index"] = 0 + tool_calls = [model_parse(ChoiceDeltaToolCall, function_call)] + delta = ChoiceDelta( + content=output["delta"], + tool_calls=tool_calls, + ) + else: + delta = ChoiceDelta(content=output["delta"]) + + choice = ChunkChoice( + index=0, + delta=delta, + finish_reason=finish_reason + ) + yield ChatCompletionChunk( + id=f"chat{_id}", + choices=[choice], + created=_created, + model=_model, + object="chat.completion.chunk", + ) + + if not has_function_call: + choice = ChunkChoice( + index=0, + delta=ChoiceDelta(), + finish_reason="stop" + ) + yield ChatCompletionChunk( + id=f"chat{_id}", + choices=[choice], + created=_created, + model=_model, + object="chat.completion.chunk", + ) + + +def _create_chat_completion_non_stream(params: Dict[str, Any]) -> Union[ChatCompletion, JSONResponse]: + """ + Creates a chat completion based on the given parameters. + + Args: + params (Dict[str, Any]): The parameters for generating the chat completion. + + Returns: + ChatCompletion: The generated chat completion. 
+ """ + last_output = None + for output in _generate(params): + last_output = output + + if last_output["error_code"] != 0: + return create_error_response(last_output["error_code"], last_output["text"]) + + function_call, finish_reason = None, "stop" + if params.get("functions") or params.get("tools"): + try: + res, function_call = template.parse_assistant_response( + last_output["text"], params.get("functions"), params.get("tools"), + ) + last_output["text"] = res + except Exception as e: + traceback.print_exc() + logger.warning("Failed to parse tool call") + + if isinstance(function_call, dict) and "arguments" in function_call: + finish_reason = "function_call" + function_call = FunctionCall(**function_call) + message = ChatCompletionMessage( + role="assistant", + content=last_output["text"], + function_call=function_call, + ) + elif isinstance(function_call, dict) and "function" in function_call: + finish_reason = "tool_calls" + tool_calls = [model_parse(ChatCompletionMessageToolCall, function_call)] + message = ChatCompletionMessage( + role="assistant", + content=last_output["text"], + tool_calls=tool_calls, + ) + else: + message = ChatCompletionMessage( + role="assistant", + content=last_output["text"].strip(), + ) + + choice = Choice( + index=0, + message=message, + finish_reason=finish_reason, + ) + usage = model_parse(CompletionUsage, last_output["usage"]) + return ChatCompletion( + id=f"chat{last_output['id']}", + choices=[choice], + created=last_output["created"], + model=last_output["model"], + object="chat.completion", + usage=usage, + ) + + +def _generate(params: Dict[str, Any]) -> Iterator: + """ + Generates text based on the given parameters. + + Args: + params (Dict[str, Any]): A dictionary containing the parameters for text generation. + + Yields: + Iterator: A dictionary containing the generated text and error code. + """ + messages = params.get("messages") + inputs, prompt = _apply_chat_template( + messages, + max_new_tokens=params.get("max_tokens", 256), + functions=params.get("functions"), + tools=params.get("tools"), + ) + + params.update(dict(inputs=inputs, prompt=prompt)) + + try: + for output in _generate_stream_func(params): + output["error_code"] = 0 + yield output + + except (ValueError, RuntimeError) as e: + traceback.print_exc() + yield { + "text": f"{e}", + "error_code": ErrorCode.INTERNAL_ERROR, + } + + +def _apply_chat_template( + messages: List[ChatCompletionMessageParam], + max_new_tokens: Optional[int] = 256, + functions: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, + tools: Optional[List[Dict[str, Any]]] = None, +) -> Tuple[Union[List[int], Dict[str, Any]], Optional[str]]: + """ + Apply chat template to generate model inputs and prompt. + + Args: + messages (List[ChatCompletionMessageParam]): List of chat completion message parameters. + max_new_tokens (Optional[int], optional): Maximum number of new tokens to generate. Defaults to 256. + functions (Optional[Union[Dict[str, Any], List[Dict[str, Any]]]], optional): Functions to apply to the messages. Defaults to None. + tools (Optional[List[Dict[str, Any]]], optional): Tools to apply to the messages. Defaults to None. + **kwargs: Additional keyword arguments. + + Returns: + Tuple[Union[List[int], Dict[str, Any]], Union[str, None]]: Tuple containing the generated inputs and prompt. 
+ """ + if template.function_call_available: + messages = template.postprocess_messages( + messages, functions, tools=tools, + ) + if functions or tools: + logger.debug(f"==== Messages with tools ====\n{messages}") + + inputs = build_qwen_chat_input( + tokenizer, messages, context_len, max_new_tokens, functions, tools + ) + return inputs, None + + +@torch.inference_mode() +def _generate_stream_func( + params: Dict[str, Any], +): + # Read parameters + input_ids = params.get("inputs") + prompt = params.get("prompt") + model_name = params.get("model", "llm") + temperature = float(params.get("temperature", 1.0)) + repetition_penalty = float(params.get("repetition_penalty", 1.0)) + top_p = float(params.get("top_p", 1.0)) + top_k = int(params.get("top_k", -1)) # -1 means disable + max_new_tokens = int(params.get("max_tokens", 256)) + logprobs = params.get("logprobs") + echo = bool(params.get("echo", True)) + stop_str = params.get("stop") + + stop_token_ids = params.get("stop_token_ids") or [] + if tokenizer.eos_token_id not in stop_token_ids: + stop_token_ids.append(tokenizer.eos_token_id) + + logits_processor = prepare_logits_processor( + temperature, repetition_penalty, top_p, top_k + ) + + output_ids = list(input_ids) + input_echo_len = len(input_ids) + + device = model.device + if model.config.is_encoder_decoder: + encoder_output = model.encoder( + input_ids=torch.as_tensor([input_ids], device=device) + )[0] + start_ids = torch.as_tensor( + [[model.generation_config.decoder_start_token_id]], + dtype=torch.int64, + device=device, + ) + else: + start_ids = torch.as_tensor([input_ids], device=device) + + past_key_values, sent_interrupt = None, False + token_logprobs = [None] # The first token has no logprobs. + completion_id: str = f"cmpl-{str(uuid.uuid4())}" + created: int = int(time.time()) + previous_text = "" + for i in range(max_new_tokens): + if i == 0: # prefill + if model.config.is_encoder_decoder: + out = model.decoder( + input_ids=start_ids, + encoder_hidden_states=encoder_output, + use_cache=True, + ) + logits = model.lm_head(out[0]) + else: + out = model(torch.as_tensor([input_ids], device=device), use_cache=True) + logits = out.logits + past_key_values = out.past_key_values + + if logprobs is not None: + # Prefull logprobs for the prompt. 
+ shift_input_ids = start_ids[..., 1:].contiguous() + shift_logits = logits[..., :-1, :].contiguous() + shift_logits = torch.log_softmax(shift_logits, dim=-1).tolist() + for label_id, logit in zip( + shift_input_ids[0].tolist(), shift_logits[0] + ): + token_logprobs.append(logit[label_id]) + + else: # decoding + if model.config.is_encoder_decoder: + out = model.decoder( + input_ids=torch.as_tensor( + [output_ids if sent_interrupt else [token]], device=device + ), + encoder_hidden_states=encoder_output, + use_cache=True, + past_key_values=None if sent_interrupt else past_key_values, + ) + sent_interrupt = False + + logits = model.lm_head(out[0]) + else: + out = model( + input_ids=torch.as_tensor( + [output_ids if sent_interrupt else [token]], device=device + ), + use_cache=True, + past_key_values=None if sent_interrupt else past_key_values, + ) + sent_interrupt = False + logits = out.logits + past_key_values = out.past_key_values + + if logits_processor: + if repetition_penalty > 1.0: + tmp_output_ids = torch.as_tensor([output_ids], device=logits.device) + else: + tmp_output_ids = None + last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0] + else: + last_token_logits = logits[0, -1, :] + + if device == "mps": + # Switch to CPU by avoiding some bugs in mps backend. + last_token_logits = last_token_logits.float().to("cpu") + + if temperature < 1e-5 or top_p < 1e-8: # greedy + _, indices = torch.topk(last_token_logits, 2) + tokens = [int(index) for index in indices.tolist()] + else: + probs = torch.softmax(last_token_logits, dim=-1) + indices = torch.multinomial(probs, num_samples=2) + tokens = [int(token) for token in indices.tolist()] + + token = tokens[0] + output_ids.append(token) + + if logprobs is not None: + # Cannot use last_token_logits because logprobs is based on raw logits. 
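+            # last_token_logits has already passed through the logits processors,
+            # so the sampled token's logprob is recomputed from the unprocessed logits;
+            # indexing the log-softmax gives a 0-dim tensor and .tolist() yields a plain float.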
+ token_logprobs.append( + torch.log_softmax(logits[0, -1, :], dim=-1)[token].tolist() + ) + + if token in stop_token_ids: + stopped = True + else: + stopped = False + + # Yield the output tokens + if i % 2 == 0 or i == max_new_tokens - 1 or stopped: + if echo: + tmp_output_ids = output_ids + rfind_start = len(prompt) + else: + tmp_output_ids = output_ids[input_echo_len:] + rfind_start = 0 + + output = tokenizer.decode( + tmp_output_ids, + skip_special_tokens=False, # fix for qwen react + spaces_between_special_tokens=False, + clean_up_tokenization_spaces=True, + ) + + ret_logprobs = None + if logprobs is not None: + ret_logprobs = { + "text_offset": [], + "tokens": [ + tokenizer.decode(token) + for token in ( + output_ids if echo else output_ids[input_echo_len:] + ) + ], + "token_logprobs": token_logprobs if echo else token_logprobs[input_echo_len:], + "top_logprobs": [{}] * len(token_logprobs if echo else token_logprobs[input_echo_len:]), + } + # Compute text_offset + curr_pos = 0 + for text in ret_logprobs["tokens"]: + ret_logprobs["text_offset"].append(curr_pos) + curr_pos += len(text) + + partially_stopped, finish_reason = False, None + if stop_str: + if isinstance(stop_str, str): + pos = output.rfind(stop_str, rfind_start) + if pos != -1: + output = output[:pos] + stopped = True + else: + partially_stopped = is_partial_stop(output, stop_str) + elif isinstance(stop_str, Iterable): + for each_stop in stop_str: + pos = output.rfind(each_stop, rfind_start) + if pos != -1: + output = output[:pos] + stopped = True + if each_stop == "Observation:": + finish_reason = "function_call" + break + else: + partially_stopped = is_partial_stop(output, each_stop) + if partially_stopped: + break + else: + raise ValueError("Invalid stop field type.") + + # Prevent yielding partial stop sequence + if (not partially_stopped) and output and output[-1] != "�": + delta_text = output[len(previous_text):] + previous_text = output + + yield { + "id": completion_id, + "object": "text_completion", + "created": created, + "model": model_name, + "delta": delta_text, + "text": output, + "logprobs": ret_logprobs, + "finish_reason": finish_reason, + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + } + + if stopped: + break + + yield { + "id": completion_id, + "object": "text_completion", + "created": created, + "model": model_name, + "delta": "", + "text": output, + "logprobs": ret_logprobs, + "finish_reason": "stop", + "usage": { + "prompt_tokens": input_echo_len, + "completion_tokens": i, + "total_tokens": input_echo_len + i, + }, + } + + # Clean + del past_key_values, out + gc.collect() + torch.cuda.empty_cache() + + +class QwenTemplate(ABC): + + name = "qwen" + system_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + allow_models = ["qwen"] + stop = { + "token_ids": [151643, 151644, 151645], # "<|endoftext|>", "<|im_start|>", "<|im_end|>" + "strings": ["<|endoftext|>", "<|im_end|>"], + } + function_call_available = True + + @property + def template(self) -> str: + """ This template formats inputs in the standard ChatML format. 
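+
+        With the generation prompt appended, a single user turn renders as:
+
+            <|im_start|>system
+            You are a helpful assistant.<|im_end|>
+            <|im_start|>user
+            {user message}<|im_end|>
+            <|im_start|>assistant
+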
See + https://github.com/openai/openai-python/blob/main/chatml.md + """ + return ( + "{{ system_prompt }}" + "{% for message in messages %}" + "{{ '<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>' + '\\n' }}" + "{% endfor %}" + "{% if add_generation_prompt %}" + "{{ '<|im_start|>assistant\\n' }}" + "{% endif %}" + ) + + def postprocess_messages( + self, + messages: List[ChatCompletionMessageParam], + functions: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, + tools: Optional[List[Dict[str, Any]]] = None, + ) -> List[Dict[str, Any]]: + return messages + + def parse_assistant_response( + self, + output: str, + functions: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, + tools: Optional[List[Dict[str, Any]]] = None, + ) -> Tuple[str, Optional[Union[str, Dict[str, Any]]]]: + func_name, func_args = "", "" + i = output.rfind("\nAction:") + j = output.rfind("\nAction Input:") + k = output.rfind("\nObservation:") + + if 0 <= i < j: # If the text has `Action` and `Action input`, + if k < j: # but does not contain `Observation`, + # then it is likely that `Observation` is omitted by the LLM, + # because the output text may have discarded the stop word. + output = output.rstrip() + "\nObservation:" # Add it back. + k = output.rfind("\nObservation:") + func_name = output[i + len("\nAction:"): j].strip() + func_args = output[j + len("\nAction Input:"): k].strip() + + if func_name: + if functions: + function_call = { + "name": func_name, + "arguments": func_args + } + else: + function_call = { + "function": { + "name": func_name, + "arguments": func_args + }, + "id": func_name, + "type": "function", + } + return output[:k], function_call + + z = output.rfind("\nFinal Answer: ") + if z >= 0: + output = output[z + len("\nFinal Answer: "):] + return output, None + + +async def handle_request( + request: Union[CompletionCreateParams, ChatCompletionCreateParams], + stop: Dict[str, Any] = None +) -> Union[Union[CompletionCreateParams, ChatCompletionCreateParams], JSONResponse]: + error_check_ret = check_requests(request) + if error_check_ret is not None: + raise error_check_ret + + # stop settings + _stop, _stop_token_ids = [], [] + if stop is not None: + _stop_token_ids = stop.get("token_ids", []) + _stop = stop.get("strings", []) + + request.stop = request.stop or [] + if isinstance(request.stop, str): + request.stop = [request.stop] + + if request.functions: + request.stop.append("Observation:") + + request.stop = list(set(_stop + request.stop)) + request.stop_token_ids = request.stop_token_ids or [] + request.stop_token_ids = list(set(_stop_token_ids + request.stop_token_ids)) + + return request + + +def check_requests(request: Union[CompletionCreateParams, ChatCompletionCreateParams]) -> Optional[JSONResponse]: + # Check all params + if request.max_tokens is not None and request.max_tokens <= 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.max_tokens} is less than the minimum of 1 - 'max_tokens'", + ) + if request.n is not None and request.n <= 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.n} is less than the minimum of 1 - 'n'", + ) + if request.temperature is not None and request.temperature < 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.temperature} is less than the minimum of 0 - 'temperature'", + ) + if request.temperature is not None and request.temperature > 2: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + 
f"{request.temperature} is greater than the maximum of 2 - 'temperature'", + ) + if request.top_p is not None and request.top_p < 0: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.top_p} is less than the minimum of 0 - 'top_p'", + ) + if request.top_p is not None and request.top_p > 1: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.top_p} is greater than the maximum of 1 - 'temperature'", + ) + if request.stop is None or isinstance(request.stop, (str, list)): + return None + else: + return create_error_response( + ErrorCode.PARAM_OUT_OF_RANGE, + f"{request.stop} is not valid under any of the given schemas - 'stop'", + ) + + +def create_error_response(code: int, message: str) -> JSONResponse: + return JSONResponse(model_dump(ErrorResponse(message=message, code=code)), status_code=500) + + +async def get_event_publisher( + request: Request, + inner_send_chan: MemoryObjectSendStream, + iterator: Union[Iterator, AsyncIterator], +): + async with inner_send_chan: + try: + async for chunk in iterate_in_threadpool(iterator): + if isinstance(chunk, BaseModel): + chunk = model_json(chunk) + elif isinstance(chunk, dict): + chunk = json.dumps(chunk, ensure_ascii=False) + + await inner_send_chan.send(dict(data=chunk)) + + if await request.is_disconnected(): + raise anyio.get_cancelled_exc_class()() + + if llama_outer_lock.locked(): + await inner_send_chan.send(dict(data="[DONE]")) + raise anyio.get_cancelled_exc_class()() + except anyio.get_cancelled_exc_class() as e: + logger.info("disconnected") + with anyio.move_on_after(1, shield=True): + logger.info(f"Disconnected from client (via refresh/close) {request.client}") + raise e + + +def _get_args(): + parser = ArgumentParser() + parser.add_argument( + "-c", + "--checkpoint-path", + type=str, + default="Qwen/Qwen-7B-Chat", + help="Checkpoint name or path, default to %(default)r", + ) + parser.add_argument( + "--cpu-only", action="store_true", help="Run demo with CPU only" + ) + parser.add_argument( + "--server-port", type=int, default=8000, help="Demo server port." + ) + parser.add_argument( + "--server-name", + type=str, + default="127.0.0.1", + help="Demo server name. Default: 127.0.0.1, which is only visible from the local computer." + " If you want other computers to access your server, use 0.0.0.0 instead.", + ) + parser.add_argument( + "--context_len", type=int, default=None, help="Context length for generating completions." 
+ ) + parser.add_argument("--disable-gc", action="store_true", + help="Disable GC after each response generated.") + + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = _get_args() + + tokenizer = AutoTokenizer.from_pretrained( + args.checkpoint_path, + trust_remote_code=True, + resume_download=True, + ) + + if args.cpu_only: + device = "cpu" + else: + device = "cuda" + + model = AutoModelForCausalLM.from_pretrained( + args.checkpoint_path, + trust_remote_code=True, + resume_download=True, + ).to(device).eval() + + # Multi-GPU support, use the following two lines instead of the above line, num gpus to your actual number of graphics cards + # from utils import load_model_on_gpus + # model = load_model_on_gpus(args.checkpoint_path, num_gpus=2) + + model.generation_config = GenerationConfig.from_pretrained( + args.checkpoint_path, + trust_remote_code=True, + resume_download=True, + ) + + context_len = 8192 if args.context_len is None else args.context_len + template = QwenTemplate() + + uvicorn.run(app, host=args.server_name, port=args.server_port, workers=1) \ No newline at end of file diff --git a/examples/openai_api_demo/openai_api_request.py b/examples/openai_api_demo/openai_api_request.py new file mode 100644 index 00000000..081e7849 --- /dev/null +++ b/examples/openai_api_demo/openai_api_request.py @@ -0,0 +1,40 @@ +from openai import OpenAI + +client = OpenAI( + api_key="EMPTY", + base_url="http://localhost:8000/v1/", +) + + +# List models API +models = client.models.list() +print(models.model_dump()) + + +# Chat completion API +chat_completion = client.chat.completions.create( + messages=[ + { + "role": "user", + "content": "你好,请问你是谁?", + } + ], + model="qwen", +) +print(chat_completion) + + +# Stream +stream = client.chat.completions.create( + messages=[ + { + "role": "user", + "content": "你好,请问你是谁?", + } + ], + model="qwen", + stream=True, +) +for part in stream: + print(part.choices[0].delta.content or "", end="", flush=True) + diff --git a/examples/openai_api_demo/openai_utils.py b/examples/openai_api_demo/openai_utils.py new file mode 100644 index 00000000..fe30d554 --- /dev/null +++ b/examples/openai_api_demo/openai_utils.py @@ -0,0 +1,879 @@ + +import pydantic +import json +import re + +from copy import deepcopy +from enum import Enum, IntEnum +from pydantic import BaseModel +from loguru import logger +from typing import ( + Dict, + List, + Any, + Literal, + Optional, + Union, + cast, + Type, + Tuple +) + +from transformers import PreTrainedTokenizer +from transformers.generation.logits_process import ( + LogitsProcessorList, + RepetitionPenaltyLogitsProcessor, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, +) + +from fastapi import HTTPException + +from openai.types.model import Model +from openai.types.chat import ( + ChatCompletionMessageParam, + ChatCompletionToolChoiceOptionParam, + ChatCompletionUserMessageParam, + ChatCompletionAssistantMessageParam, +) +from openai.types.chat.chat_completion_message import FunctionCall +from openai.types.chat.completion_create_params import ResponseFormat +from openai.types.create_embedding_response import Usage + + +TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters}""" + +REACT_INSTRUCTION = """Answer the following questions as best you can. 
You have access to the following APIs: + +{tools_text} + +Use the following format: + +Question: the input question you must answer +Thought: you should always think about what to do +Action: the action to take, should be one of [{tools_name_text}] +Action Input: the input to the action +Observation: the result of the action +... (this Thought/Action/Action Input/Observation can be repeated zero or more times) +Thought: I now know the final answer +Final Answer: the final answer to the original input question + +Begin!""" + +_TEXT_COMPLETION_CMD = object() + +# --------------- Pydantic v2 compatibility --------------- + +PYDANTIC_V2 = pydantic.VERSION.startswith("2.") + + +def model_json(model: pydantic.BaseModel, **kwargs) -> str: + if PYDANTIC_V2: + return model.model_dump_json(**kwargs) + return model.json(**kwargs) # type: ignore + + +def model_dump(model: pydantic.BaseModel, **kwargs) -> Dict[str, Any]: + if PYDANTIC_V2: + return model.model_dump(**kwargs) + return cast( + "dict[str, Any]", + model.dict(**kwargs), + ) + + +def model_parse(model: Type[pydantic.BaseModel], data: Any) -> pydantic.BaseModel: + if PYDANTIC_V2: + return model.model_validate(data) + return model.parse_obj(data) # pyright: ignore[reportDeprecated] + + +def disable_warnings(model: Type[pydantic.BaseModel]): + # Disable warning for model_name settings + if PYDANTIC_V2: + model.model_config["protected_namespaces"] = () + +def parse_messages( + messages: List[ChatCompletionMessageParam], split_role="user" +) -> Tuple[str, List[List[ChatCompletionMessageParam]]]: + """ + Parse a list of chat completion messages into system and rounds. + + Args: + messages (List[ChatCompletionMessageParam]): The list of chat completion messages. + split_role: The role at which to split the rounds. Defaults to Role.USER. + + Returns: + Tuple[str, List[List[ChatCompletionMessageParam]]]: A tuple containing the system message and a list of rounds. + """ + system, rounds = "", [] + r = [] + for i, message in enumerate(messages): + if message["role"] == "system": + system = message["content"] + continue + if message["role"] == split_role and r: + rounds.append(r) + r = [] + r.append(message) + if r: + rounds.append(r) + return system, rounds + + +class Role(str, Enum): + USER = "user" + ASSISTANT = "assistant" + SYSTEM = "system" + FUNCTION = "function" + TOOL = "tool" + + +class ErrorResponse(BaseModel): + object: str = "error" + message: str + code: int + + +class ErrorCode(IntEnum): + """ + https://platform.openai.com/docs/guides/error-codes/api-errors + """ + + VALIDATION_TYPE_ERROR = 40001 + + INVALID_AUTH_KEY = 40101 + INCORRECT_AUTH_KEY = 40102 + NO_PERMISSION = 40103 + + INVALID_MODEL = 40301 + PARAM_OUT_OF_RANGE = 40302 + CONTEXT_OVERFLOW = 40303 + + RATE_LIMIT = 42901 + QUOTA_EXCEEDED = 42902 + ENGINE_OVERLOADED = 42903 + + INTERNAL_ERROR = 50001 + CUDA_OUT_OF_MEMORY = 50002 + GRADIO_REQUEST_ERROR = 50003 + GRADIO_STREAM_UNKNOWN_ERROR = 50004 + CONTROLLER_NO_WORKER = 50005 + CONTROLLER_WORKER_TIMEOUT = 50006 + + +class ModelList(BaseModel): + object: str = "list" + data: List[Model] = [] + + +class ChatCompletionCreateParams(BaseModel): + messages: List[ChatCompletionMessageParam] + """A list of messages comprising the conversation so far. + + [Example Python code](https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models). + """ + + model: str + """ID of the model to use. 
+ + See the + [model endpoint compatibility](https://platform.openai.com/docs/models/model-endpoint-compatibility) + table for details on which models work with the Chat API. + """ + + frequency_penalty: Optional[float] = 0. + """Number between -2.0 and 2.0. + + Positive values penalize new tokens based on their existing frequency in the + text so far, decreasing the model's likelihood to repeat the same line verbatim. + + [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details) + """ + + function_call: Optional[FunctionCall] = None + """Deprecated in favor of `tool_choice`. + + Controls which (if any) function is called by the model. `none` means the model + will not call a function and instead generates a message. `auto` means the model + can pick between generating a message or calling a function. Specifying a + particular function via `{"name": "my_function"}` forces the model to call that + function. + + `none` is the default when no functions are present. `auto`` is the default if + functions are present. + """ + + functions: Optional[List] = None + """Deprecated in favor of `tools`. + + A list of functions the model may generate JSON inputs for. + """ + + logit_bias: Optional[Dict[str, int]] = None + """Modify the likelihood of specified tokens appearing in the completion. + + Accepts a JSON object that maps tokens (specified by their token ID in the + tokenizer) to an associated bias value from -100 to 100. Mathematically, the + bias is added to the logits generated by the model prior to sampling. The exact + effect will vary per model, but values between -1 and 1 should decrease or + increase likelihood of selection; values like -100 or 100 should result in a ban + or exclusive selection of the relevant token. + """ + + max_tokens: Optional[int] = None + """The maximum number of [tokens](/tokenizer) to generate in the chat completion. + + The total length of input tokens and generated tokens is limited by the model's + context length. + [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) + for counting tokens. + """ + + n: Optional[int] = 1 + """How many chat completion choices to generate for each input message.""" + + presence_penalty: Optional[float] = 0. + """Number between -2.0 and 2.0. + + Positive values penalize new tokens based on whether they appear in the text so + far, increasing the model's likelihood to talk about new topics. + + [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details) + """ + + response_format: Optional[ResponseFormat] = None + """An object specifying the format that the model must output. + + Used to enable JSON mode. + """ + + seed: Optional[int] = None + """This feature is in Beta. + + If specified, our system will make a best effort to sample deterministically, + such that repeated requests with the same `seed` and parameters should return + the same result. Determinism is not guaranteed, and you should refer to the + `system_fingerprint` response parameter to monitor changes in the backend. + """ + + stop: Optional[Union[str, List[str]]] = None + """Up to 4 sequences where the API will stop generating further tokens.""" + + temperature: Optional[float] = 0.9 + """What sampling temperature to use, between 0 and 2. + + Higher values like 0.8 will make the output more random, while lower values like + 0.2 will make it more focused and deterministic. 
+ + We generally recommend altering this or `top_p` but not both. + """ + + tool_choice: Optional[ChatCompletionToolChoiceOptionParam] = None + """ + Controls which (if any) function is called by the model. `none` means the model + will not call a function and instead generates a message. `auto` means the model + can pick between generating a message or calling a function. Specifying a + particular function via + `{"type: "function", "function": {"name": "my_function"}}` forces the model to + call that function. + + `none` is the default when no functions are present. `auto` is the default if + functions are present. + """ + + tools: Optional[List] = None + """A list of tools the model may call. + + Currently, only functions are supported as a tool. Use this to provide a list of + functions the model may generate JSON inputs for. + """ + + top_p: Optional[float] = 1.0 + """ + An alternative to sampling with temperature, called nucleus sampling, where the + model considers the results of the tokens with top_p probability mass. So 0.1 + means only the tokens comprising the top 10% probability mass are considered. + + We generally recommend altering this or `temperature` but not both. + """ + + user: Optional[str] = None + """ + A unique identifier representing your end-user, which can help OpenAI to monitor + and detect abuse. + [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids). + """ + + stream: Optional[bool] = False + """If set, partial message deltas will be sent, like in ChatGPT. + + Tokens will be sent as data-only + [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) + as they become available, with the stream terminated by a `data: [DONE]` + message. + [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions). + """ + + # Addictional parameters + repetition_penalty: Optional[float] = 1.03 + """The parameter for repetition penalty. 1.0 means no penalty. + See[this paper](https://arxiv.org / pdf / 1909.05858.pdf) for more details. + """ + + typical_p: Optional[float] = None + """Typical Decoding mass. + See[Typical Decoding for Natural Language Generation](https://arxiv.org / abs / 2202.00666) for more information + """ + + watermark: Optional[bool] = False + """Watermarking with [A Watermark for Large Language Models](https://arxiv.org / abs / 2301.10226) + """ + + best_of: Optional[int] = 1 + + ignore_eos: Optional[bool] = False + + use_beam_search: Optional[bool] = False + + stop_token_ids: Optional[List[int]] = None + + skip_special_tokens: Optional[bool] = True + + spaces_between_special_tokens: Optional[bool] = True + + min_p: Optional[float] = 0.0 + + +class CompletionCreateParams(BaseModel): + model: str + """ID of the model to use. + + You can use the + [List models](https://platform.openai.com/docs/api-reference/models/list) API to + see all of your available models, or see our + [Model overview](https://platform.openai.com/docs/models/overview) for + descriptions of them. + """ + + prompt: Union[str, List[str], List[int], List[List[int]], None] + """ + The prompt(s) to generate completions for, encoded as a string, array of + strings, array of tokens, or array of token arrays. + + Note that <|endoftext|> is the document separator that the model sees during + training, so if a prompt is not specified the model will generate as if from the + beginning of a new document. 
+ """ + + best_of: Optional[int] = 1 + """ + Generates `best_of` completions server-side and returns the "best" (the one with + the highest log probability per token). Results cannot be streamed. + + When used with `n`, `best_of` controls the number of candidate completions and + `n` specifies how many to return – `best_of` must be greater than `n`. + + **Note:** Because this parameter generates many completions, it can quickly + consume your token quota. Use carefully and ensure that you have reasonable + settings for `max_tokens` and `stop`. + """ + + echo: Optional[bool] = False + """Echo back the prompt in addition to the completion""" + + frequency_penalty: Optional[float] = 0. + """Number between -2.0 and 2.0. + + Positive values penalize new tokens based on their existing frequency in the + text so far, decreasing the model's likelihood to repeat the same line verbatim. + + [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details) + """ + + logit_bias: Optional[Dict[str, int]] = None + """Modify the likelihood of specified tokens appearing in the completion. + + Accepts a JSON object that maps tokens (specified by their token ID in the GPT + tokenizer) to an associated bias value from -100 to 100. You can use this + [tokenizer tool](/tokenizer?view=bpe) (which works for both GPT-2 and GPT-3) to + convert text to token IDs. Mathematically, the bias is added to the logits + generated by the model prior to sampling. The exact effect will vary per model, + but values between -1 and 1 should decrease or increase likelihood of selection; + values like -100 or 100 should result in a ban or exclusive selection of the + relevant token. + + As an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token + from being generated. + """ + + logprobs: Optional[int] = None + """ + Include the log probabilities on the `logprobs` most likely tokens, as well the + chosen tokens. For example, if `logprobs` is 5, the API will return a list of + the 5 most likely tokens. The API will always return the `logprob` of the + sampled token, so there may be up to `logprobs+1` elements in the response. + + The maximum value for `logprobs` is 5. + """ + + max_tokens: Optional[int] = 16 + """The maximum number of [tokens](/tokenizer) to generate in the completion. + + The token count of your prompt plus `max_tokens` cannot exceed the model's + context length. + [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) + for counting tokens. + """ + + n: Optional[int] = 1 + """How many completions to generate for each prompt. + + **Note:** Because this parameter generates many completions, it can quickly + consume your token quota. Use carefully and ensure that you have reasonable + settings for `max_tokens` and `stop`. + """ + + presence_penalty: Optional[float] = 0. + """Number between -2.0 and 2.0. + + Positive values penalize new tokens based on whether they appear in the text so + far, increasing the model's likelihood to talk about new topics. + + [See more information about frequency and presence penalties.](https://platform.openai.com/docs/guides/gpt/parameter-details) + """ + + seed: Optional[int] = None + """ + If specified, our system will make a best effort to sample deterministically, + such that repeated requests with the same `seed` and parameters should return + the same result. 
+ + Determinism is not guaranteed, and you should refer to the `system_fingerprint` + response parameter to monitor changes in the backend. + """ + + stop: Optional[Union[str, List[str]]] = None + """Up to 4 sequences where the API will stop generating further tokens. + + The returned text will not contain the stop sequence. + """ + + suffix: Optional[str] = None + """The suffix that comes after a completion of inserted text.""" + + temperature: Optional[float] = 1. + """What sampling temperature to use, between 0 and 2. + + Higher values like 0.8 will make the output more random, while lower values like + 0.2 will make it more focused and deterministic. + + We generally recommend altering this or `top_p` but not both. + """ + + top_p: Optional[float] = 1. + """ + An alternative to sampling with temperature, called nucleus sampling, where the + model considers the results of the tokens with top_p probability mass. So 0.1 + means only the tokens comprising the top 10% probability mass are considered. + + We generally recommend altering this or `temperature` but not both. + """ + + user: Optional[str] = None + """ + A unique identifier representing your end-user, which can help OpenAI to monitor + and detect abuse. + [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids). + """ + + stream: Optional[bool] = False + """If set, partial message deltas will be sent, like in ChatGPT. + + Tokens will be sent as data-only + [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) + as they become available, with the stream terminated by a `data: [DONE]` + message. + [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions). + """ + + # Addictional parameters + repetition_penalty: Optional[float] = 1.03 + """The parameter for repetition penalty. 1.0 means no penalty. + See[this paper](https://arxiv.org / pdf / 1909.05858.pdf) for more details. + """ + + typical_p: Optional[float] = None + """Typical Decoding mass. + See[Typical Decoding for Natural Language Generation](https://arxiv.org / abs / 2202.00666) for more information + """ + + watermark: Optional[bool] = False + """Watermarking with [A Watermark for Large Language Models](https://arxiv.org / abs / 2301.10226) + """ + + ignore_eos: Optional[bool] = False + + use_beam_search: Optional[bool] = False + + stop_token_ids: Optional[List[int]] = None + + skip_special_tokens: Optional[bool] = True + + spaces_between_special_tokens: Optional[bool] = True + + min_p: Optional[float] = 0.0 + + +class EmbeddingCreateParams(BaseModel): + input: Union[str, List[str], List[int], List[List[int]]] + """Input text to embed, encoded as a string or array of tokens. + + To embed multiple inputs in a single request, pass an array of strings or array + of token arrays. The input must not exceed the max input tokens for the model + (8192 tokens for `text-embedding-ada-002`) and cannot be an empty string. + [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) + for counting tokens. + """ + + model: str + """ID of the model to use. + + You can use the + [List models](https://platform.openai.com/docs/api-reference/models/list) API to + see all of your available models, or see our + [Model overview](https://platform.openai.com/docs/models/overview) for + descriptions of them. + """ + + encoding_format: Literal["float", "base64"] = "float" + """The format to return the embeddings in. 
+ + Can be either `float` or [`base64`](https://pypi.org/project/pybase64/). + """ + + user: Optional[str] = None + """ + A unique identifier representing your end-user, which can help OpenAI to monitor + and detect abuse. + [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids). + """ + + +class Embedding(BaseModel): + embedding: Any + """The embedding vector, which is a list of floats. + + The length of vector depends on the model as listed in the + [embedding guide](https://platform.openai.com/docs/guides/embeddings). + """ + + index: int + """The index of the embedding in the list of embeddings.""" + + object: Literal["embedding"] + """The object type, which is always "embedding".""" + + +class CreateEmbeddingResponse(BaseModel): + data: List[Embedding] + """The list of embeddings generated by the model.""" + + model: str + """The name of the model used to generate the embedding.""" + + object: Literal["list"] + """The object type, which is always "list".""" + + usage: Usage + """The usage information for the request.""" + + +def build_qwen_chat_input( + tokenizer: PreTrainedTokenizer, + messages: List[ChatCompletionMessageParam], + context_len: int = 8192, + max_new_tokens: int = 256, + functions: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, + tools: Optional[List[Dict[str, Any]]] = None, +) -> List[int]: + """ + Builds the input tokens for Qwen chat generation. + + Refs: + https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/qwen_generation_utils.py + + Args: + tokenizer: The tokenizer used to encode the input tokens. + messages: The list of chat messages. + context_len: The maximum length of the context. + max_new_tokens: The maximum number of new tokens to add. + functions: Optional dictionary or list of dictionaries representing the functions. + tools: Optional list of dictionaries representing the tools. + + Returns: + The list of input tokens. 
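+
+ For a single user turn, the returned token ids correspond roughly to this
+ ChatML layout (earlier rounds, when present, go between the system block and
+ the final user turn, and the trailing assistant header is only added when the
+ last message is not already an assistant turn):
+
+     <|im_start|>system
+     You are a helpful assistant.<|im_end|>
+     <|im_start|>user
+     {query}<|im_end|>
+     <|im_start|>assistant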
+ """ + query, history = process_qwen_messages(messages, functions, tools) + if query is _TEXT_COMPLETION_CMD: + return build_last_message_input(tokenizer, history) + + messages = [] + for q, r in history: + messages.extend( + [ + ChatCompletionUserMessageParam(role="user", content=q), + ChatCompletionAssistantMessageParam(role="assistant", content=r) + ] + ) + messages.append(ChatCompletionUserMessageParam(role="user", content=query)) + + max_input_tokens = context_len - max_new_tokens + system, rounds = parse_messages(messages) + system = f"You are a helpful assistant.{system}" + + im_start_tokens, im_end_tokens = [tokenizer.im_start_id], [tokenizer.im_end_id] + nl_tokens = tokenizer.encode("\n") + + def _tokenize_str(role, content): + return tokenizer.encode( + role, allowed_special=set() + ) + nl_tokens + tokenizer.encode(content, allowed_special=set()) + + system_tokens_part = _tokenize_str("system", system) + system_tokens = im_start_tokens + system_tokens_part + im_end_tokens + max_history_tokens = max_input_tokens - len(system_tokens) + + history_tokens = [] + for r in rounds[::-1]: + round_tokens = [] + for message in r: + if round_tokens: + round_tokens += nl_tokens + + if message["role"] == Role.USER: + content_tokens = im_start_tokens + _tokenize_str("user", message["content"]) + im_end_tokens + else: + content_tokens = im_start_tokens + _tokenize_str("assistant", message["content"]) + im_end_tokens + + round_tokens.extend(content_tokens) + + if len(history_tokens) == 0 or len(history_tokens) + len(round_tokens) <= max_history_tokens: + if history_tokens: + history_tokens = nl_tokens + history_tokens + + history_tokens = round_tokens + history_tokens # concat left + if len(history_tokens) < max_history_tokens: + continue + break + + input_tokens = system_tokens + nl_tokens + history_tokens + if messages[-1]["role"] != Role.ASSISTANT: + input_tokens += nl_tokens + im_start_tokens + tokenizer.encode("assistant") + nl_tokens + return input_tokens[-max_input_tokens:] # truncate left + + + +def process_qwen_messages( + messages: List[ChatCompletionMessageParam], + functions: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, + tools: Optional[List[Dict[str, Any]]] = None, +) -> Tuple[str, List[List[str]]]: + """ + Process the Qwen messages and generate a query and history. + + Args: + messages (List[ChatCompletionMessageParam]): The list of chat completion messages. + functions (Optional[Union[Dict[str, Any], List[Dict[str, Any]]]]): The functions to be used. + tools (Optional[List[Dict[str, Any]]]): The tools to be used. + + Returns: + Tuple[str, List[List[str]]]: The generated query and history. + """ + if all(m["role"] != Role.USER for m in messages): + raise HTTPException( + status_code=400, + detail=f"Invalid request: Expecting at least one user message.", + ) + + messages = deepcopy(messages) + default_system = "You are a helpful assistant." 
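+ # The rest of this function turns the OpenAI-style message list into ReAct-style
+ # (query, history) text: tool/function definitions are rendered into the system
+ # prompt via REACT_INSTRUCTION, function/tool results are appended to the
+ # preceding assistant turn as "\nObservation: ...", plain assistant replies get a
+ # dummy "Thought: ... Final answer: " prefix when functions are in play, and a
+ # trailing user message becomes the returned query.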
+ system = "" + if messages[0]["role"] == Role.SYSTEM: + system = messages.pop(0)["content"].lstrip("\n").rstrip() + if system == default_system: + system = "" + + if tools: + functions = [t["function"] for t in tools] + + if functions: + tools_text = [] + tools_name_text = [] + for func_info in functions: + name = func_info.get("name", "") + name_m = func_info.get("name_for_model", name) + name_h = func_info.get("name_for_human", name) + desc = func_info.get("description", "") + desc_m = func_info.get("description_for_model", desc) + tool = TOOL_DESC.format( + name_for_model=name_m, + name_for_human=name_h, + # Hint: You can add the following format requirements in description: + # "Format the arguments as a JSON object." + # "Enclose the code within triple backticks (`) at the beginning and end of the code." + description_for_model=desc_m, + parameters=json.dumps(func_info["parameters"], ensure_ascii=False), + ) + + tools_text.append(tool) + tools_name_text.append(name_m) + + tools_text = "\n\n".join(tools_text) + tools_name_text = ", ".join(tools_name_text) + system += "\n\n" + REACT_INSTRUCTION.format( + tools_text=tools_text, + tools_name_text=tools_name_text, + ) + system = system.lstrip("\n").rstrip() + + dummy_thought = { + "en": "\nThought: I now know the final answer.\nFinal answer: ", + "zh": "\nThought: 我会作答了。\nFinal answer: ", + } + + _messages = messages + messages = [] + for m_idx, m in enumerate(_messages): + role, content = m["role"], m["content"] + func_call, tools_call = m.get("function_call", None), m.get("tools_call", None) + if content: + content = content.lstrip("\n").rstrip() + if role in [Role.FUNCTION, Role.TOOL]: + if (len(messages) == 0) or (messages[-1]["role"] != Role.ASSISTANT): + raise HTTPException( + status_code=400, + detail=f"Invalid request: Expecting role assistant before role function.", + ) + messages[-1]["content"] += f"\nObservation: {content}" + if m_idx == len(_messages) - 1: + messages[-1]["content"] += "\nThought:" + elif role == Role.ASSISTANT: + if len(messages) == 0: + raise HTTPException( + status_code=400, + detail=f"Invalid request: Expecting role user before role assistant.", + ) + last_msg = messages[-1]["content"] + last_msg_has_zh = len(re.findall(r"[\u4e00-\u9fff]+", last_msg)) > 0 + + if func_call is None and tools_call is None: + if functions or tools_call: + content = dummy_thought["zh" if last_msg_has_zh else "en"] + content + else: + if func_call: + f_name, f_args = func_call.get("name"), func_call.get("arguments") + else: + f_name, f_args = tools_call[0]["function"]["name"], tools_call[0]["function"]["arguments"] + if not content: + if last_msg_has_zh: + content = f"Thought: 我可以使用 {f_name} API。" + else: + content = f"Thought: I can use {f_name}." + + if messages[-1]["role"] == Role.USER: + messages.append( + ChatCompletionAssistantMessageParam(role="assistant", content=content.lstrip("\n").rstrip()) + ) + else: + messages[-1]["content"] += content + elif role == Role.USER: + messages.append( + ChatCompletionUserMessageParam(role="user", content=content.lstrip("\n").rstrip()) + ) + else: + raise HTTPException( + status_code=400, detail=f"Invalid request: Incorrect role {role}." 
+ ) + + query = _TEXT_COMPLETION_CMD + if messages[-1]["role"] == Role.USER: + query = messages[-1]["content"] + messages = messages[:-1] + + if len(messages) % 2 != 0: + raise HTTPException(status_code=400, detail="Invalid request") + + history = [] # [(Q1, A1), (Q2, A2), ..., (Q_last_turn, A_last_turn)] + for i in range(0, len(messages), 2): + if messages[i]["role"] == Role.USER and messages[i + 1]["role"] == Role.ASSISTANT: + usr_msg = messages[i]["content"].lstrip("\n").rstrip() + bot_msg = messages[i + 1]["content"].lstrip("\n").rstrip() + if system and (i == len(messages) - 2): + usr_msg = f"{system}\n\nQuestion: {usr_msg}" + system = "" + for t in dummy_thought.values(): + t = t.lstrip("\n") + if bot_msg.startswith(t) and ("\nAction: " in bot_msg): + bot_msg = bot_msg[len(t):] + history.append([usr_msg, bot_msg]) + else: + raise HTTPException( + status_code=400, + detail="Invalid request: Expecting exactly one user (or function) role before every assistant role.", + ) + if system: + assert query is not _TEXT_COMPLETION_CMD + query = f"{system}\n\nQuestion: {query}" + return query, history + + +def build_last_message_input(tokenizer: PreTrainedTokenizer, history: list): + im_start = "<|im_start|>" + im_end = "<|im_end|>" + prompt = f"{im_start}system\nYou are a helpful assistant.{im_end}" + for i, (query, response) in enumerate(history): + query = query.lstrip("\n").rstrip() + response = response.lstrip("\n").rstrip() + prompt += f"\n{im_start}user\n{query}{im_end}" + prompt += f"\n{im_start}assistant\n{response}{im_end}" + prompt = prompt[:-len(im_end)] + logger.debug(f"==== Prompt with tools ====\n{prompt}") + return tokenizer.encode(prompt) + + +def is_partial_stop(output: str, stop_str: str): + """ Check whether the output contains a partial stop str. """ + return any( + stop_str.startswith(output[-i:]) + for i in range(0, min(len(output), len(stop_str))) + ) + + +def prepare_logits_processor( + temperature: float, repetition_penalty: float, top_p: float, top_k: int +) -> LogitsProcessorList: + """ + Prepare a list of logits processors based on the provided parameters. + + Args: + temperature (float): The temperature value for temperature warping. + repetition_penalty (float): The repetition penalty value. + top_p (float): The top-p value for top-p warping. + top_k (int): The top-k value for top-k warping. + + Returns: + LogitsProcessorList: A list of logits processors. + """ + processor_list = LogitsProcessorList() + # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op, so we skip two cases. 
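+ # The other parameters are likewise skipped when they would be no-ops:
+ # repetition_penalty == 1.0, top_p == 1.0 and top_k <= 0 leave the distribution
+ # unchanged, so no processor is added for them.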
+ if temperature >= 1e-5 and temperature != 1.0: + processor_list.append(TemperatureLogitsWarper(temperature)) + if repetition_penalty > 1.0: + processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty)) + if 1e-8 <= top_p < 1.0: + processor_list.append(TopPLogitsWarper(top_p)) + if top_k > 0: + processor_list.append(TopKLogitsWarper(top_k)) + return processor_list From db223a797084e5e3ce9a6f09b2507b0fc47b8dfd Mon Sep 17 00:00:00 2001 From: Yimi81 <1548222878@qq.com> Date: Tue, 26 Dec 2023 22:25:39 +0800 Subject: [PATCH 3/3] fix unused f_args --- examples/openai_api_demo/function_call_examples_v2.py | 8 ++++---- examples/openai_api_demo/openai_api.py | 5 ----- examples/openai_api_demo/openai_utils.py | 10 ++++++---- 3 files changed, 10 insertions(+), 13 deletions(-) diff --git a/examples/openai_api_demo/function_call_examples_v2.py b/examples/openai_api_demo/function_call_examples_v2.py index 1564f9f1..a0b1d306 100644 --- a/examples/openai_api_demo/function_call_examples_v2.py +++ b/examples/openai_api_demo/function_call_examples_v2.py @@ -296,10 +296,10 @@ def test_4(): if __name__ == "__main__": - print("### Test Case 1 - No Function Calling (普通问答、无函数调用) ###") - test_1() - print("### Test Case 2 - Use Qwen-Style Functions (函数调用,千问格式) ###") - test_2() + # print("### Test Case 1 - No Function Calling (普通问答、无函数调用) ###") + # test_1() + # print("### Test Case 2 - Use Qwen-Style Functions (函数调用,千问格式) ###") + # test_2() print("### Test Case 3 - Use GPT-Style Functions (函数调用,GPT格式) ###") test_3() # # Qwen has not optimized parallel tool calls, often unable to parse a parallel call instruction into multiple tool_calls diff --git a/examples/openai_api_demo/openai_api.py b/examples/openai_api_demo/openai_api.py index f8363c34..8daad2a1 100644 --- a/examples/openai_api_demo/openai_api.py +++ b/examples/openai_api_demo/openai_api.py @@ -39,7 +39,6 @@ ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam, - ChatCompletionToolChoiceOptionParam, ) from transformers import AutoTokenizer, AutoModelForCausalLM @@ -842,10 +841,6 @@ def _get_args(): resume_download=True, ).to(device).eval() - # Multi-GPU support, use the following two lines instead of the above line, num gpus to your actual number of graphics cards - # from utils import load_model_on_gpus - # model = load_model_on_gpus(args.checkpoint_path, num_gpus=2) - model.generation_config = GenerationConfig.from_pretrained( args.checkpoint_path, trust_remote_code=True, diff --git a/examples/openai_api_demo/openai_utils.py b/examples/openai_api_demo/openai_utils.py index fe30d554..05f8e91f 100644 --- a/examples/openai_api_demo/openai_utils.py +++ b/examples/openai_api_demo/openai_utils.py @@ -747,7 +747,7 @@ def process_qwen_messages( messages = [] for m_idx, m in enumerate(_messages): role, content = m["role"], m["content"] - func_call, tools_call = m.get("function_call", None), m.get("tools_call", None) + func_call, tool_calls = m.get("function_call", None), m.get("tool_calls", None) if content: content = content.lstrip("\n").rstrip() if role in [Role.FUNCTION, Role.TOOL]: @@ -768,19 +768,21 @@ def process_qwen_messages( last_msg = messages[-1]["content"] last_msg_has_zh = len(re.findall(r"[\u4e00-\u9fff]+", last_msg)) > 0 - if func_call is None and tools_call is None: - if functions or tools_call: + if func_call is None and tool_calls is None: + if functions or tool_calls: content = dummy_thought["zh" if last_msg_has_zh else "en"] + content else: if func_call: f_name, f_args = func_call.get("name"), 
func_call.get("arguments") else: - f_name, f_args = tools_call[0]["function"]["name"], tools_call[0]["function"]["arguments"] + f_name, f_args = tool_calls[0]["function"]["name"], tool_calls[0]["function"]["arguments"] if not content: if last_msg_has_zh: content = f"Thought: 我可以使用 {f_name} API。" else: content = f"Thought: I can use {f_name}." + if func_call: + content = f"\n{content}\nAction: {f_name}\nAction Input: {f_args}" if messages[-1]["role"] == Role.USER: messages.append(