-
Notifications
You must be signed in to change notification settings - Fork 267
/
oai.py
103 lines (89 loc) · 3.81 KB
/
oai.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import copy
import os
from pprint import pformat
from typing import Dict, Iterator, List, Optional
import openai
if openai.__version__.startswith('0.'):
from openai.error import OpenAIError # noqa
else:
from openai import OpenAIError
from qwen_agent.llm.base import ModelServiceError, register_llm
from qwen_agent.llm.function_calling import BaseFnCallModel
from qwen_agent.llm.schema import ASSISTANT, Message
from qwen_agent.log import logger
@register_llm('oai')
class TextChatAtOAI(BaseFnCallModel):
    """Chat model served behind an OpenAI-compatible API.

    Supports both the legacy openai v0.x SDK (module-level configuration)
    and the v1.x SDK (client-object configuration). The endpoint may be
    given in the cfg dict as ``api_base``, ``base_url`` or ``model_server``
    (checked in that order); the key as ``api_key`` or via the
    ``OPENAI_API_KEY`` environment variable.
    """

    def __init__(self, cfg: Optional[Dict] = None):
        super().__init__(cfg)
        self.model = self.model or 'gpt-3.5-turbo'
        cfg = cfg or {}

        # Accept any of the three endpoint keys, in priority order.
        api_base = cfg.get(
            'api_base',
            cfg.get(
                'base_url',
                cfg.get('model_server', ''),
            ),
        ).strip()

        api_key = cfg.get('api_key', '')
        if not api_key:
            # 'EMPTY' is a conventional placeholder accepted by local
            # OpenAI-compatible servers that do not check the key.
            api_key = os.getenv('OPENAI_API_KEY', 'EMPTY')
        api_key = api_key.strip()

        if openai.__version__.startswith('0.'):
            # Legacy SDK: endpoint and key are module-level globals.
            if api_base:
                openai.api_base = api_base
            if api_key:
                openai.api_key = api_key
            self._chat_complete_create = openai.ChatCompletion.create
        else:
            api_kwargs = {}
            if api_base:
                api_kwargs['base_url'] = api_base
            if api_key:
                api_kwargs['api_key'] = api_key
            # Build the client once here. The previous implementation
            # instantiated openai.OpenAI() inside the wrapper, creating a
            # new client (and HTTP connection pool) on every request.
            client = openai.OpenAI(**api_kwargs)

            def _chat_complete_create(*args, **kwargs):
                # OpenAI API v1 does not allow the following args, must pass by extra_body
                extra_params = ['top_k', 'repetition_penalty']
                if any((k in kwargs) for k in extra_params):
                    kwargs['extra_body'] = copy.deepcopy(kwargs.get('extra_body', {}))
                    for k in extra_params:
                        if k in kwargs:
                            kwargs['extra_body'][k] = kwargs.pop(k)
                if 'request_timeout' in kwargs:
                    # v1 renamed request_timeout -> timeout.
                    kwargs['timeout'] = kwargs.pop('request_timeout')
                return client.chat.completions.create(*args, **kwargs)

            self._chat_complete_create = _chat_complete_create

    @staticmethod
    def _delta_content(chunk) -> str:
        """Return the text delta carried by a streaming chunk, or ''.

        Guards against chunks with an empty ``choices`` list (some
        OpenAI-compatible servers emit keep-alive chunks without choices,
        which would otherwise raise IndexError) and against deltas whose
        ``content`` is missing or None.
        """
        if not chunk.choices:
            return ''
        return getattr(chunk.choices[0].delta, 'content', None) or ''

    def _chat_stream(
        self,
        messages: List[Message],
        delta_stream: bool,
        generate_cfg: dict,
    ) -> Iterator[List[Message]]:
        """Stream a chat completion.

        Args:
            messages: Conversation history as schema Messages.
            delta_stream: If True, yield each new text fragment on its own;
                otherwise yield the cumulative response so far.
            generate_cfg: Extra keyword arguments forwarded to the API call.

        Yields:
            Single-element lists containing the assistant Message.

        Raises:
            ModelServiceError: Wrapping any OpenAIError from the service.
        """
        messages = [msg.model_dump() for msg in messages]
        logger.debug(f'LLM Input:\n{pformat(messages, indent=2)}')
        try:
            response = self._chat_complete_create(model=self.model, messages=messages, stream=True, **generate_cfg)
            if delta_stream:
                for chunk in response:
                    text = self._delta_content(chunk)
                    if text:
                        yield [Message(ASSISTANT, text)]
            else:
                full_response = ''
                for chunk in response:
                    text = self._delta_content(chunk)
                    if text:
                        full_response += text
                        yield [Message(ASSISTANT, full_response)]
        except OpenAIError as ex:
            # Chain the original exception for easier debugging.
            raise ModelServiceError(exception=ex) from ex

    def _chat_no_stream(
        self,
        messages: List[Message],
        generate_cfg: dict,
    ) -> List[Message]:
        """Request a complete (non-streaming) chat completion.

        Args:
            messages: Conversation history as schema Messages.
            generate_cfg: Extra keyword arguments forwarded to the API call.

        Returns:
            A single-element list with the assistant's full reply.

        Raises:
            ModelServiceError: Wrapping any OpenAIError from the service.
        """
        messages = [msg.model_dump() for msg in messages]
        logger.debug(f'LLM Input:\n{pformat(messages, indent=2)}')
        try:
            response = self._chat_complete_create(model=self.model, messages=messages, stream=False, **generate_cfg)
            return [Message(ASSISTANT, response.choices[0].message.content)]
        except OpenAIError as ex:
            # Chain the original exception for easier debugging.
            raise ModelServiceError(exception=ex) from ex