[erniebot] Add new qianfan models, Add response_format argument (#349)
* add new models; add response_format

* remove aksk

* fix lint
sijunhe committed May 31, 2024
1 parent 3f2ecc0 commit 39f0a38
Showing 2 changed files with 47 additions and 10 deletions.
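
For context, the new `response_format` argument added by this commit is passed straight to `ChatCompletion.create`. A minimal sketch of how it might be used (not part of the commit; assumes qianfan credentials are already configured, e.g. via `erniebot.access_token`, and uses `ernie-3.5-8k-0329`, one of the model aliases added below):

```python
# Sketch only: exercising the new response_format argument against a newly added model alias.
import erniebot

erniebot.api_type = "qianfan"
# erniebot.access_token = "<your-access-token>"  # assumption: token-based auth already set up

response = erniebot.ChatCompletion.create(
    model="ernie-3.5-8k-0329",  # one of the model aliases added in this commit
    messages=[{"role": "user", "content": "Which company developed ERNIE Bot? Answer in JSON."}],
    response_format="json_object",  # new argument: request JSON output
)
print(response.get_result())
```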
43 changes: 34 additions & 9 deletions erniebot/src/erniebot/resources/chat_completion.py
@@ -58,15 +58,29 @@ class ChatCompletion(EBResource, CreatableWithStreaming):
"ernie-3.5-8k": {
"model_id": "completions",
},
"ernie-3.5-8k-0205": {
"model_id": "ernie-3.5-8k-0205",
},
"ernie-3.5-8k-0329": {
"model_id": "ernie-3.5-8k-0329",
},
"ernie-3.5-128k": {
"model_id": "ernie-3.5-128k",
},
"ernie-lite": {
"model_id": "eb-instant",
},
"ernie-lite-8k-0308": {
"model_id": "ernie-lite-8k",
},
"ernie-4.0": {
"model_id": "completions_pro",
},
"ernie-longtext": {
# ernie-longtext(ernie_bot_8k) will be deprecated in 2024.4.11
"model_id": "completions",
"ernie-4.0-8k-0329": {
"model_id": "ernie-4.0-8k-0329",
},
"ernie-4.0-8k-0104": {
"model_id": "ernie-4.0-8k-0104",
},
"ernie-speed": {
"model_id": "ernie_speed",
@@ -97,10 +111,6 @@ class ChatCompletion(EBResource, CreatableWithStreaming):
"ernie-4.0": {
"model_id": "completions_pro",
},
"ernie-longtext": {
# ernie-longtext(ernie_bot_8k) will be deprecated in 2024.4.11
"model_id": "completions",
},
"ernie-speed": {
"model_id": "ernie_speed",
},
@@ -156,6 +166,7 @@ def create(
extra_params: Optional[dict] = ...,
headers: Optional[HeadersType] = ...,
request_timeout: Optional[float] = ...,
response_format: Optional[Literal["json_object", "text"]] = ...,
max_output_tokens: Optional[int] = ...,
_config_: Optional[ConfigDictType] = ...,
) -> "ChatCompletionResponse":
@@ -183,6 +194,7 @@ def create(
extra_params: Optional[dict] = ...,
headers: Optional[HeadersType] = ...,
request_timeout: Optional[float] = ...,
response_format: Optional[Literal["json_object", "text"]] = ...,
max_output_tokens: Optional[int] = ...,
_config_: Optional[ConfigDictType] = ...,
) -> Iterator["ChatCompletionResponse"]:
@@ -210,6 +222,7 @@ def create(
extra_params: Optional[dict] = ...,
headers: Optional[HeadersType] = ...,
request_timeout: Optional[float] = ...,
response_format: Optional[Literal["json_object", "text"]] = ...,
max_output_tokens: Optional[int] = ...,
_config_: Optional[ConfigDictType] = ...,
) -> Union["ChatCompletionResponse", Iterator["ChatCompletionResponse"]]:
@@ -236,6 +249,7 @@ def create(
extra_params: Optional[dict] = None,
headers: Optional[HeadersType] = None,
request_timeout: Optional[float] = None,
response_format: Optional[Literal["json_object", "text"]] = None,
max_output_tokens: Optional[int] = None,
_config_: Optional[ConfigDictType] = None,
) -> Union["ChatCompletionResponse", Iterator["ChatCompletionResponse"]]:
@@ -292,6 +306,8 @@ def create(
kwargs["headers"] = headers
if request_timeout is not None:
kwargs["request_timeout"] = request_timeout
if response_format is not None:
kwargs["response_format"] = response_format

resp = resource.create_resource(**kwargs)
return transform(ChatCompletionResponse.from_mapping, resp)
@@ -318,6 +334,7 @@ async def acreate(
extra_params: Optional[dict] = ...,
headers: Optional[HeadersType] = ...,
request_timeout: Optional[float] = ...,
response_format: Optional[Literal["json_object", "text"]] = ...,
max_output_tokens: Optional[int] = ...,
_config_: Optional[ConfigDictType] = ...,
) -> EBResponse:
@@ -345,6 +362,7 @@ async def acreate(
extra_params: Optional[dict] = ...,
headers: Optional[HeadersType] = ...,
request_timeout: Optional[float] = ...,
response_format: Optional[Literal["json_object", "text"]] = ...,
max_output_tokens: Optional[int] = ...,
_config_: Optional[ConfigDictType] = ...,
) -> AsyncIterator["ChatCompletionResponse"]:
@@ -372,6 +390,7 @@ async def acreate(
extra_params: Optional[dict] = ...,
headers: Optional[HeadersType] = ...,
request_timeout: Optional[float] = ...,
response_format: Optional[Literal["json_object", "text"]] = ...,
max_output_tokens: Optional[int] = ...,
_config_: Optional[ConfigDictType] = ...,
) -> Union["ChatCompletionResponse", AsyncIterator["ChatCompletionResponse"]]:
@@ -398,6 +417,7 @@ async def acreate(
extra_params: Optional[dict] = None,
headers: Optional[HeadersType] = None,
request_timeout: Optional[float] = None,
response_format: Optional[Literal["json_object", "text"]] = None,
max_output_tokens: Optional[int] = None,
_config_: Optional[ConfigDictType] = None,
) -> Union["ChatCompletionResponse", AsyncIterator["ChatCompletionResponse"]]:
@@ -423,6 +443,7 @@ async def acreate(
validate_functions: Whether to validate the function descriptions.
headers: Custom headers to send with the request.
request_timeout: Timeout for a single request.
response_format: Format of the response.
_config_: Overrides the global settings.
Returns:
@@ -460,9 +481,11 @@ async def acreate(

def _check_model_kwargs(self, model_name: str, kwargs: Dict[str, Any]) -> None:
if model_name in ("ernie-speed", "ernie-speed-128k", "ernie-char-8k", "ernie-tiny-8k", "ernie-lite"):
for arg in ("functions", "disable_search", "enable_citation", "tool_choice"):
for arg in ("functions", "disable_search", "enable_citation", "tool_choice", "response_format"):
if arg in kwargs:
raise errors.InvalidArgumentError(f"`{arg}` is not supported by the {model_name} model.")
raise errors.InvalidArgumentError(
f"`{arg}` is not supported by the `{model_name}` model."
)

def _prepare_create(self, kwargs: Dict[str, Any]) -> RequestWithStream:
def _update_model_name(given_name: str, old_name_to_new_name: Dict[str, str]) -> str:
@@ -497,6 +520,7 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:
"extra_params",
"headers",
"request_timeout",
"response_format",
"max_output_tokens",
}

@@ -561,6 +585,7 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:
_set_val_if_key_exists(kwargs, params, "tool_choice")
_set_val_if_key_exists(kwargs, params, "stream")
_set_val_if_key_exists(kwargs, params, "max_output_tokens")
_set_val_if_key_exists(kwargs, params, "response_format")

if "extra_params" in kwargs:
params.update(kwargs["extra_params"])
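The diff above threads `response_format` through the same optional-argument plumbing used by the other keyword arguments: a key is copied from the caller's kwargs into the request parameters only when it was actually supplied. A standalone sketch of that pattern follows; `_set_val_if_key_exists` matches the helper named in the diff, while `build_params` and the exact key list are illustrative only.

```python
from typing import Any, Dict

def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:
    # Copy `key` from src to dst only when the caller actually provided it,
    # so unset optional arguments never show up in the request payload.
    if key in src:
        dst[key] = src[key]

def build_params(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    # Illustrative wrapper (not from the diff): collect only the options the
    # caller passed, mirroring how create()/acreate() assemble request params.
    params: Dict[str, Any] = {"messages": kwargs["messages"]}
    for key in ("functions", "temperature", "top_p", "stream",
                "max_output_tokens", "response_format"):
        _set_val_if_key_exists(kwargs, params, key)
    return params

if __name__ == "__main__":
    # response_format is forwarded only because it was explicitly supplied.
    print(build_params({"messages": [], "response_format": "json_object"}))
```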
14 changes: 13 additions & 1 deletion erniebot/tests/test_chat_completion.py
@@ -39,6 +39,18 @@ def create_chat_completion(model):
print(response.get_result())


def create_chat_completion_json_mode(model):
response = erniebot.ChatCompletion.create(
model=model,
messages=[
{"role": "user", "content": "文心一言是哪个公司开发的?"},
],
stream=False,
response_format="json_object",
)
print(response.get_result())


def create_chat_completion_stream(model):
response = erniebot.ChatCompletion.create(
model=model,
@@ -68,5 +80,5 @@ def create_chat_completion_stream(model):
erniebot.api_type = "qianfan"

create_chat_completion(model="ernie-turbo")

create_chat_completion_stream(model="ernie-turbo")
create_chat_completion_json_mode(model="ernie-lite")
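
Since `response_format="json_object"` asks the model to return a JSON string, a caller will typically decode the result. A small, purely illustrative sketch (assuming `get_result()` returns the raw model text, as in the test above):

```python
import json
from typing import Any

def parse_json_result(result_text: str) -> Any:
    # JSON mode is best-effort: parse the model output if possible,
    # otherwise fall back to returning the raw text unchanged.
    try:
        return json.loads(result_text)
    except json.JSONDecodeError:
        return result_text

# Usage (illustrative): data = parse_json_result(response.get_result())
```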
