From 39f0a389b40aa33afae3e4fe22f13691fceff41e Mon Sep 17 00:00:00 2001
From: Sijun He
Date: Fri, 31 May 2024 16:26:01 +0800
Subject: [PATCH] [erniebot] Add new qianfan models, Add `response_format` argument (#349)

* add new models; add response_format

* remove aksk

* fix lint
---
 .../src/erniebot/resources/chat_completion.py | 43 +++++++++++++++----
 erniebot/tests/test_chat_completion.py        | 14 +++++-
 2 files changed, 47 insertions(+), 10 deletions(-)

diff --git a/erniebot/src/erniebot/resources/chat_completion.py b/erniebot/src/erniebot/resources/chat_completion.py
index cd3fff0ce..3f91c4bd3 100644
--- a/erniebot/src/erniebot/resources/chat_completion.py
+++ b/erniebot/src/erniebot/resources/chat_completion.py
@@ -58,15 +58,29 @@ class ChatCompletion(EBResource, CreatableWithStreaming):
                 "ernie-3.5-8k": {
                     "model_id": "completions",
                 },
+                "ernie-3.5-8k-0205": {
+                    "model_id": "ernie-3.5-8k-0205",
+                },
+                "ernie-3.5-8k-0329": {
+                    "model_id": "ernie-3.5-8k-0329",
+                },
+                "ernie-3.5-128k": {
+                    "model_id": "ernie-3.5-128k",
+                },
                 "ernie-lite": {
                     "model_id": "eb-instant",
                 },
+                "ernie-lite-8k-0308": {
+                    "model_id": "ernie-lite-8k",
+                },
                 "ernie-4.0": {
                     "model_id": "completions_pro",
                 },
-                "ernie-longtext": {
-                    # ernie-longtext(ernie_bot_8k) will be deprecated in 2024.4.11
-                    "model_id": "completions",
+                "ernie-4.0-8k-0329": {
+                    "model_id": "ernie-4.0-8k-0329",
+                },
+                "ernie-4.0-8k-0104": {
+                    "model_id": "ernie-4.0-8k-0104",
                 },
                 "ernie-speed": {
                     "model_id": "ernie_speed",
@@ -97,10 +111,6 @@ class ChatCompletion(EBResource, CreatableWithStreaming):
                 "ernie-4.0": {
                     "model_id": "completions_pro",
                 },
-                "ernie-longtext": {
-                    # ernie-longtext(ernie_bot_8k) will be deprecated in 2024.4.11
-                    "model_id": "completions",
-                },
                 "ernie-speed": {
                     "model_id": "ernie_speed",
                 },
@@ -156,6 +166,7 @@ def create(
         extra_params: Optional[dict] = ...,
         headers: Optional[HeadersType] = ...,
         request_timeout: Optional[float] = ...,
+        response_format: Optional[Literal["json_object", "text"]] = ...,
         max_output_tokens: Optional[int] = ...,
         _config_: Optional[ConfigDictType] = ...,
     ) -> "ChatCompletionResponse":
@@ -183,6 +194,7 @@ def create(
         extra_params: Optional[dict] = ...,
         headers: Optional[HeadersType] = ...,
         request_timeout: Optional[float] = ...,
+        response_format: Optional[Literal["json_object", "text"]] = ...,
         max_output_tokens: Optional[int] = ...,
         _config_: Optional[ConfigDictType] = ...,
     ) -> Iterator["ChatCompletionResponse"]:
@@ -210,6 +222,7 @@ def create(
         extra_params: Optional[dict] = ...,
         headers: Optional[HeadersType] = ...,
         request_timeout: Optional[float] = ...,
+        response_format: Optional[Literal["json_object", "text"]] = ...,
         max_output_tokens: Optional[int] = ...,
         _config_: Optional[ConfigDictType] = ...,
     ) -> Union["ChatCompletionResponse", Iterator["ChatCompletionResponse"]]:
@@ -236,6 +249,7 @@ def create(
         extra_params: Optional[dict] = None,
         headers: Optional[HeadersType] = None,
         request_timeout: Optional[float] = None,
+        response_format: Optional[Literal["json_object", "text"]] = None,
         max_output_tokens: Optional[int] = None,
         _config_: Optional[ConfigDictType] = None,
     ) -> Union["ChatCompletionResponse", Iterator["ChatCompletionResponse"]]:
@@ -292,6 +306,8 @@ def create(
             kwargs["headers"] = headers
         if request_timeout is not None:
             kwargs["request_timeout"] = request_timeout
+        if response_format is not None:
+            kwargs["response_format"] = response_format
 
         resp = resource.create_resource(**kwargs)
         return transform(ChatCompletionResponse.from_mapping, resp)
@@ -318,6 +334,7 @@ async def acreate(
         extra_params: Optional[dict] = ...,
         headers: Optional[HeadersType] = ...,
         request_timeout: Optional[float] = ...,
+        response_format: Optional[Literal["json_object", "text"]] = ...,
         max_output_tokens: Optional[int] = ...,
         _config_: Optional[ConfigDictType] = ...,
     ) -> EBResponse:
@@ -345,6 +362,7 @@ async def acreate(
         extra_params: Optional[dict] = ...,
         headers: Optional[HeadersType] = ...,
         request_timeout: Optional[float] = ...,
+        response_format: Optional[Literal["json_object", "text"]] = ...,
         max_output_tokens: Optional[int] = ...,
         _config_: Optional[ConfigDictType] = ...,
     ) -> AsyncIterator["ChatCompletionResponse"]:
@@ -372,6 +390,7 @@ async def acreate(
         extra_params: Optional[dict] = ...,
         headers: Optional[HeadersType] = ...,
         request_timeout: Optional[float] = ...,
+        response_format: Optional[Literal["json_object", "text"]] = ...,
         max_output_tokens: Optional[int] = ...,
         _config_: Optional[ConfigDictType] = ...,
     ) -> Union["ChatCompletionResponse", AsyncIterator["ChatCompletionResponse"]]:
@@ -398,6 +417,7 @@ async def acreate(
         extra_params: Optional[dict] = None,
         headers: Optional[HeadersType] = None,
         request_timeout: Optional[float] = None,
+        response_format: Optional[Literal["json_object", "text"]] = None,
         max_output_tokens: Optional[int] = None,
         _config_: Optional[ConfigDictType] = None,
     ) -> Union["ChatCompletionResponse", AsyncIterator["ChatCompletionResponse"]]:
@@ -423,6 +443,7 @@ async def acreate(
             validate_functions: Whether to validate the function descriptions.
             headers: Custom headers to send with the request.
             request_timeout: Timeout for a single request.
+            response_format: Format of the response.
             _config_: Overrides the global settings.
 
         Returns:
@@ -460,9 +481,11 @@ async def acreate(
 
     def _check_model_kwargs(self, model_name: str, kwargs: Dict[str, Any]) -> None:
         if model_name in ("ernie-speed", "ernie-speed-128k", "ernie-char-8k", "ernie-tiny-8k", "ernie-lite"):
-            for arg in ("functions", "disable_search", "enable_citation", "tool_choice"):
+            for arg in ("functions", "disable_search", "enable_citation", "tool_choice", "response_format"):
                 if arg in kwargs:
-                    raise errors.InvalidArgumentError(f"`{arg}` is not supported by the {model_name} model.")
+                    raise errors.InvalidArgumentError(
+                        f"`{arg}` is not supported by the `{model_name}` model."
+                    )
 
     def _prepare_create(self, kwargs: Dict[str, Any]) -> RequestWithStream:
         def _update_model_name(given_name: str, old_name_to_new_name: Dict[str, str]) -> str:
@@ -497,6 +520,7 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:
             "extra_params",
             "headers",
             "request_timeout",
+            "response_format",
             "max_output_tokens",
         }
 
@@ -561,6 +585,7 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:
         _set_val_if_key_exists(kwargs, params, "tool_choice")
         _set_val_if_key_exists(kwargs, params, "stream")
         _set_val_if_key_exists(kwargs, params, "max_output_tokens")
+        _set_val_if_key_exists(kwargs, params, "response_format")
         if "extra_params" in kwargs:
             params.update(kwargs["extra_params"])
 
diff --git a/erniebot/tests/test_chat_completion.py b/erniebot/tests/test_chat_completion.py
index d20d8eaba..073a44b44 100644
--- a/erniebot/tests/test_chat_completion.py
+++ b/erniebot/tests/test_chat_completion.py
@@ -39,6 +39,18 @@ def create_chat_completion(model):
     print(response.get_result())
 
 
+def create_chat_completion_json_mode(model):
+    response = erniebot.ChatCompletion.create(
+        model=model,
+        messages=[
+            {"role": "user", "content": "文心一言是哪个公司开发的?"},
+        ],
+        stream=False,
+        response_format="json_object",
+    )
+    print(response.get_result())
+
+
 def create_chat_completion_stream(model):
     response = erniebot.ChatCompletion.create(
         model=model,
@@ -68,5 +80,5 @@ def create_chat_completion_stream(model):
     erniebot.api_type = "qianfan"
 
     create_chat_completion(model="ernie-turbo")
-    create_chat_completion_stream(model="ernie-turbo")
+    create_chat_completion_json_mode(model="ernie-lite")
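
A minimal usage sketch of the new `response_format` argument (illustrative, not part of the patch itself): it assumes qianfan authentication is already configured, and it uses the ernie-3.5-8k alias, which is not in the list that `_check_model_kwargs` rejects. The JSON parsing step is an assumption about how a "json_object" reply is typically consumed.

import json

import erniebot

# Assumption: credentials are configured elsewhere (e.g. an access token or
# environment variables); adjust to your own setup.
erniebot.api_type = "qianfan"

response = erniebot.ChatCompletion.create(
    model="ernie-3.5-8k",  # one of the aliases registered in this patch
    messages=[{"role": "user", "content": "Which company develops ERNIE Bot? Reply in JSON."}],
    response_format="json_object",  # new argument added by this patch
)

# With response_format="json_object" the reply is expected to be a JSON string,
# so it can be parsed directly; fall back to the raw text if parsing fails.
result = response.get_result()
try:
    print(json.loads(result))
except ValueError:
    print(result)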