@@ -57,13 +57,20 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
     _write_context(node_variable, workflow_variable, node, workflow, answer)
 
 
+def file_id_to_base64(file_id: str):
+    file = QuerySet(File).filter(id=file_id).first()
+    base64_image = base64.b64encode(file.get_byte()).decode("utf-8")
+    return base64_image
+
+
 class BaseImageUnderstandNode(IImageUnderstandNode):
     def save_context(self, details, workflow_manage):
         self.context['answer'] = details.get('answer')
         self.context['question'] = details.get('question')
         self.answer_text = details.get('answer')
 
-    def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id, chat_record_id,
+    def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id,
+                chat_record_id,
                 image,
                 **kwargs) -> NodeResult:
         # Handle invalid parameters
@@ -72,12 +79,13 @@ def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, hist

         image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
         # History messages in the execution details do not need the image content
-        history_message =self.get_history_message_for_details(history_chat_record, dialogue_number)
+        history_message = self.get_history_message_for_details(history_chat_record, dialogue_number)
         self.context['history_message'] = history_message
         question = self.generate_prompt_question(prompt)
         self.context['question'] = question.content
         # Build the message list from the real history_message
-        message_list = self.generate_message_list(image_model, system, prompt, self.get_history_message(history_chat_record, dialogue_number), image)
+        message_list = self.generate_message_list(image_model, system, prompt,
+                                                  self.get_history_message(history_chat_record, dialogue_number), image)
         self.context['message_list'] = message_list
         self.context['image_list'] = image
         self.context['dialogue_type'] = dialogue_type
@@ -92,11 +100,11 @@ def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, hist
                            'history_message': history_message, 'question': question.content}, {},
                           _write_context=write_context)
 
-
     def get_history_message_for_details(self, history_chat_record, dialogue_number):
         start_index = len(history_chat_record) - dialogue_number
         history_message = reduce(lambda x, y: [*x, *y], [
-            [self.generate_history_human_message_for_details(history_chat_record[index]), self.generate_history_ai_message(history_chat_record[index])]
+            [self.generate_history_human_message_for_details(history_chat_record[index]),
+             self.generate_history_ai_message(history_chat_record[index])]
             for index in
             range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
         return history_message
@@ -115,17 +123,19 @@ def generate_history_human_message_for_details(self, chat_record):
             image_list = data['image_list']
             if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
                 return HumanMessage(content=chat_record.problem_text)
-            file_id = image_list[0]['file_id']
+            file_id_list = [image.get('file_id') for image in image_list]
             return HumanMessage(content=[
-                {'type': 'text', 'text': data['question']},
-                {'type': 'image_url', 'image_url': {'url': f'/api/file/{file_id}'}},
-            ])
+                {'type': 'text', 'text': data['question']},
+                *[{'type': 'image_url', 'image_url': {'url': f'/api/file/{file_id}'}} for file_id in file_id_list]
+
+            ])
         return HumanMessage(content=chat_record.problem_text)
 
     def get_history_message(self, history_chat_record, dialogue_number):
         start_index = len(history_chat_record) - dialogue_number
         history_message = reduce(lambda x, y: [*x, *y], [
-            [self.generate_history_human_message(history_chat_record[index]), self.generate_history_ai_message(history_chat_record[index])]
+            [self.generate_history_human_message(history_chat_record[index]),
+             self.generate_history_ai_message(history_chat_record[index])]
             for index in
             range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
         return history_message
@@ -137,13 +147,12 @@ def generate_history_human_message(self, chat_record):
             image_list = data['image_list']
             if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
                 return HumanMessage(content=chat_record.problem_text)
-            file_id = image_list[0]['file_id']
-            file = QuerySet(File).filter(id=file_id).first()
-            base64_image = base64.b64encode(file.get_byte()).decode("utf-8")
+            image_base64_list = [file_id_to_base64(image.get('file_id')) for image in image_list]
             return HumanMessage(
                 content=[
                     {'type': 'text', 'text': data['question']},
-                    {'type': 'image_url', 'image_url': {'url': f'data:image/jpeg;base64,{base64_image}'}},
+                    *[{'type': 'image_url', 'image_url': {'url': f'data:image/jpeg;base64,{base64_image}'}} for
+                      base64_image in image_base64_list]
                 ])
         return HumanMessage(content=chat_record.problem_text)

Review comment from the contributor (PR author):
There are a few issues and possible optimizations in this code:

  1. A helper that converts a file ID to Base64 is missing.

  2. The execute method does not properly handle bad input from the history chat records.

  3. write_context should not be called from other methods, as that can trigger an infinite loop.

  4. In some cases data integrity needs attention (for example, image paths may not be unique); serialization or hashing can be used to avoid creating duplicate copies of the same object.

  5. There is room to simplify the logic: a simpler replacement condition would reduce the complexity of the branching.

Please review this analysis carefully and propose your corrections and a plan. These suggestions may change the original structure or add new functionality.

Also, if you are looking for insight into a specific programming language or software-development topic, please provide more details so they can be addressed.
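
The file_id_to_base64 helper added in this diff will raise an AttributeError whenever a file id does not resolve to a File row (point 2 above), and identical images are re-encoded on every call (related to point 4). Below is a minimal, hedged sketch of a more defensive variant; the function name file_id_to_base64_safe, the hash-keyed cache, and the dataset.models import path are illustrative assumptions, not part of this PR:

```python
import base64
import hashlib
from typing import Dict, Optional

from django.db.models import QuerySet

from dataset.models import File  # assumed import path; adjust to this repository's layout

# Cache keyed by content hash so identical images are only encoded once (review point 4).
_base64_cache: Dict[str, str] = {}


def file_id_to_base64_safe(file_id: Optional[str]) -> Optional[str]:
    """Return the Base64 body for a stored file, or None if the id cannot be resolved."""
    if not file_id:
        return None
    file = QuerySet(File).filter(id=file_id).first()
    if file is None:
        # Missing record: return None instead of failing on file.get_byte().
        return None
    raw = file.get_byte()
    digest = hashlib.sha256(raw).hexdigest()
    if digest not in _base64_cache:
        _base64_cache[digest] = base64.b64encode(raw).decode("utf-8")
    return _base64_cache[digest]
```

If a caller such as generate_history_human_message used this variant, it could drop the None entries before building the image_url list instead of relying on every file id being valid.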
