From 7975d563ba32e6a7cccbeffc5f5c34d2f7980610 Mon Sep 17 00:00:00 2001
From: shaohuzhang1
Date: Thu, 5 Dec 2024 11:02:27 +0800
Subject: [PATCH] fix: Fix the issue where image recognition history records
 cannot display multiple images
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../impl/base_image_understand_node.py        | 37 ++++++++++++-------
 1 file changed, 23 insertions(+), 14 deletions(-)

diff --git a/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py b/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py
index 9e09e8a8ea6..1c2536e0c86 100644
--- a/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py
+++ b/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py
@@ -57,13 +57,20 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
     _write_context(node_variable, workflow_variable, node, workflow, answer)
 
 
+def file_id_to_base64(file_id: str):
+    file = QuerySet(File).filter(id=file_id).first()
+    base64_image = base64.b64encode(file.get_byte()).decode("utf-8")
+    return base64_image
+
+
 class BaseImageUnderstandNode(IImageUnderstandNode):
     def save_context(self, details, workflow_manage):
         self.context['answer'] = details.get('answer')
         self.context['question'] = details.get('question')
         self.answer_text = details.get('answer')
 
-    def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id, chat_record_id,
+    def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id,
+                chat_record_id,
                 image,
                 **kwargs) -> NodeResult:
         # Handle invalid parameters
@@ -72,12 +79,13 @@ def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, hist
 
         image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
         # History messages in the execution details do not need the image content
-        history_message =self.get_history_message_for_details(history_chat_record, dialogue_number)
+        history_message = self.get_history_message_for_details(history_chat_record, dialogue_number)
         self.context['history_message'] = history_message
         question = self.generate_prompt_question(prompt)
         self.context['question'] = question.content
         # Build the message list with the real history_message
-        message_list = self.generate_message_list(image_model, system, prompt, self.get_history_message(history_chat_record, dialogue_number), image)
+        message_list = self.generate_message_list(image_model, system, prompt,
+                                                  self.get_history_message(history_chat_record, dialogue_number), image)
         self.context['message_list'] = message_list
         self.context['image_list'] = image
         self.context['dialogue_type'] = dialogue_type
@@ -92,11 +100,11 @@ def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, hist
                            'history_message': history_message, 'question': question.content}, {},
                           _write_context=write_context)
 
-
     def get_history_message_for_details(self, history_chat_record, dialogue_number):
         start_index = len(history_chat_record) - dialogue_number
         history_message = reduce(lambda x, y: [*x, *y], [
-            [self.generate_history_human_message_for_details(history_chat_record[index]), self.generate_history_ai_message(history_chat_record[index])]
+            [self.generate_history_human_message_for_details(history_chat_record[index]),
+             self.generate_history_ai_message(history_chat_record[index])]
             for index in
             range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
         return history_message
@@ -115,17 +123,19 @@ def generate_history_human_message_for_details(self, chat_record):
                 image_list = data['image_list']
                 if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
                     return HumanMessage(content=chat_record.problem_text)
-                file_id = image_list[0]['file_id']
+                file_id_list = [image.get('file_id') for image in image_list]
                 return HumanMessage(content=[
-                    {'type': 'text', 'text': data['question']},
-                    {'type': 'image_url', 'image_url': {'url': f'/api/file/{file_id}'}},
-                ])
+                    {'type': 'text', 'text': data['question']},
+                    *[{'type': 'image_url', 'image_url': {'url': f'/api/file/{file_id}'}} for file_id in file_id_list]
+
+                ])
         return HumanMessage(content=chat_record.problem_text)
 
     def get_history_message(self, history_chat_record, dialogue_number):
         start_index = len(history_chat_record) - dialogue_number
         history_message = reduce(lambda x, y: [*x, *y], [
-            [self.generate_history_human_message(history_chat_record[index]), self.generate_history_ai_message(history_chat_record[index])]
+            [self.generate_history_human_message(history_chat_record[index]),
+             self.generate_history_ai_message(history_chat_record[index])]
            for index in
             range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
         return history_message
@@ -137,13 +147,12 @@ def generate_history_human_message(self, chat_record):
                 image_list = data['image_list']
                 if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
                     return HumanMessage(content=chat_record.problem_text)
-                file_id = image_list[0]['file_id']
-                file = QuerySet(File).filter(id=file_id).first()
-                base64_image = base64.b64encode(file.get_byte()).decode("utf-8")
+                image_base64_list = [file_id_to_base64(image.get('file_id')) for image in image_list]
                 return HumanMessage(
                     content=[
                         {'type': 'text', 'text': data['question']},
-                        {'type': 'image_url', 'image_url': {'url': f'data:image/jpeg;base64,{base64_image}'}},
+                        *[{'type': 'image_url', 'image_url': {'url': f'data:image/jpeg;base64,{base64_image}'}} for
+                          base64_image in image_base64_list]
                     ])
         return HumanMessage(content=chat_record.problem_text)
 
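
Note on the resulting message shape: before this patch, only image_list[0] was turned into an image_url content part, so history turns with several attachments lost all but the first image. The snippet below is a minimal, standalone sketch (not MaxKB code) of the multi-image HumanMessage the patched node now builds; it assumes langchain_core is installed, and fake_files / the sample ids are hypothetical stand-ins for the bytes the real node reads via QuerySet(File).filter(id=file_id).first().get_byte().

# Standalone sketch: one text part plus one image_url part per image.
import base64

from langchain_core.messages import HumanMessage

fake_files = {  # hypothetical file_id -> raw image bytes
    'id-1': b'first image bytes',
    'id-2': b'second image bytes',
}

image_list = [{'file_id': 'id-1'}, {'file_id': 'id-2'}]  # as stored in the chat record details

# Encode every referenced image, not just image_list[0].
image_base64_list = [base64.b64encode(fake_files[image['file_id']]).decode('utf-8')
                     for image in image_list]

message = HumanMessage(content=[
    {'type': 'text', 'text': 'What is in these pictures?'},
    *[{'type': 'image_url', 'image_url': {'url': f'data:image/jpeg;base64,{b64}'}}
      for b64 in image_base64_list],
])

print(len(message.content))  # 3 parts: one text block and two image_url blocks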