From 98b1221c2dc146e8845bb0abcca10b6dc521d419 Mon Sep 17 00:00:00 2001 From: zhuming Date: Wed, 8 Jan 2025 15:23:50 +0800 Subject: [PATCH] fix: fix developer message missed --- ghostos/core/messages/openai.py | 5 +++++ ghostos/framework/llms/openai_driver.py | 26 +++++++++++++++---------- pyproject.toml | 2 +- 3 files changed, 22 insertions(+), 11 deletions(-) diff --git a/ghostos/core/messages/openai.py b/ghostos/core/messages/openai.py index 9f7f92cf..dc3e0a15 100644 --- a/ghostos/core/messages/openai.py +++ b/ghostos/core/messages/openai.py @@ -8,6 +8,7 @@ from openai.types.chat.chat_completion_assistant_message_param import ChatCompletionAssistantMessageParam, FunctionCall from openai.types.chat.chat_completion_message_tool_call_param import ChatCompletionMessageToolCallParam from openai.types.chat.chat_completion_system_message_param import ChatCompletionSystemMessageParam +from openai.types.chat.chat_completion_developer_message_param import ChatCompletionDeveloperMessageParam from openai.types.chat.chat_completion_user_message_param import ChatCompletionUserMessageParam from openai.types.chat.chat_completion_function_message_param import ChatCompletionFunctionMessageParam from ghostos.core.messages import ( @@ -189,6 +190,10 @@ def _parse_message(self, message: Message) -> Iterable[ChatCompletionMessagePara return [ ChatCompletionSystemMessageParam(content=message.get_content(), role="system") ] + elif message.role == Role.DEVELOPER: + return [ + ChatCompletionDeveloperMessageParam(content=message.get_content(), role="developer") + ] elif message.role == Role.USER: item = ChatCompletionUserMessageParam(content=message.get_content(), role="user") if message.name: diff --git a/ghostos/framework/llms/openai_driver.py b/ghostos/framework/llms/openai_driver.py index 9a0f0ed8..95b7fd8b 100644 --- a/ghostos/framework/llms/openai_driver.py +++ b/ghostos/framework/llms/openai_driver.py @@ -118,33 +118,38 @@ def parse_message_params(self, messages: List[Message]) -> List[ChatCompletionMe messages = self.parse_by_compatible_settings(messages) return list(self._parser.parse_message_list(messages, self.model.message_types)) + @staticmethod + def _parse_system_to_develop(messages: List[Message]) -> List[Message]: + changed = [] + for message in messages: + if message.role == Role.SYSTEM: + message = message.model_copy(update={"role": Role.DEVELOPER.value}, deep=True) + changed.append(message) + return changed + def parse_by_compatible_settings(self, messages: List[Message]) -> List[Message]: # developer role test if self.service.compatible.use_developer_role: - changed = [] - for message in messages: - if message.role == Role.SYSTEM: - message = message.model_copy(update={"role": Role.DEVELOPER}, deep=True) - changed.append(message) - messages = changed + messages = self._parse_system_to_develop(messages) else: changed = [] for message in messages: if message.role == Role.DEVELOPER: - message = message.model_copy(update={"role": Role.SYSTEM}, deep=True) + message = message.model_copy(update={"role": Role.SYSTEM.value}, deep=True) changed.append(message) messages = changed # allow system messages if not self.service.compatible.allow_system_in_messages: changed = [] + count = 0 for message in messages: - if message.role == Role.SYSTEM or message.role == Role.DEVELOPER: + if count > 0 and message.role == Role.SYSTEM or message.role == Role.DEVELOPER: name = f"__{message.role}__" - message = message.model_copy(update={"role": Role.USER, "name": name}, deep=True) + message = message.model_copy(update={"role": Role.USER.value, "name": name}, deep=True) changed.append(message) + count += 1 messages = changed - return messages def _chat_completion(self, prompt: Prompt, stream: bool) -> Union[ChatCompletion, Iterable[ChatCompletionChunk]]: @@ -238,6 +243,7 @@ def _reasoning_completion(self, prompt: Prompt) -> ChatCompletion: ) # include_usage = ChatCompletionStreamOptionsParam(include_usage=True) if stream else NOT_GIVEN messages = prompt.get_messages() + messages = self._parse_system_to_develop(messages) messages = self.parse_message_params(messages) if not messages: raise AttributeError("empty chat!!") diff --git a/pyproject.toml b/pyproject.toml index 1f8dcba6..bbd71945 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "ghostos" -version = "0.1.0" +version = "0.1.1" description = "A framework offers an operating system simulator with a Python Code Interface for AI Agents" authors = ["zhuming ", "Nile Zhou "] license = "MIT"