From 1ac1c99117af5819a0d77f025d2eabde8ca7dd70 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Mon, 15 Jul 2024 20:07:53 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E3=80=90=E9=97=AE?= =?UTF-8?q?=E7=AD=94=E9=A1=B5=E9=9D=A2=E3=80=91-=20=E5=BD=93=E5=BA=94?= =?UTF-8?q?=E7=94=A8=E5=85=B3=E8=81=94=E7=9A=84=E7=9F=A5=E8=AF=86=E5=BA=93?= =?UTF-8?q?=E4=B8=AD=E5=90=AB=E6=9C=89=E7=A6=81=E7=94=A8=E7=8A=B6=E6=80=81?= =?UTF-8?q?=E7=9A=84=E6=96=87=E6=A1=A3=E6=97=B6=EF=BC=8C=E9=97=AE=E7=AD=94?= =?UTF-8?q?=E6=97=B6=E7=82=B9=E5=87=BB=E6=8D=A2=E4=B8=AA=E7=AD=94=E6=A1=88?= =?UTF-8?q?=E4=B8=8D=E4=BC=9A=E6=8D=A2=E4=B8=80=E6=89=B9=E5=91=BD=E4=B8=AD?= =?UTF-8?q?=E5=88=86=E6=AE=B5=20#759?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../search_dataset_node/i_search_dataset_node.py | 16 +++++++++++++++- .../serializers/application_serializers.py | 2 ++ apps/embedding/vector/pg_vector.py | 4 ++-- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py b/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py index 0a134527c98..436de5a96d6 100644 --- a/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py +++ b/apps/application/flow/step_node/search_dataset_node/i_search_dataset_node.py @@ -13,6 +13,7 @@ from rest_framework import serializers from application.flow.i_step_node import INode, NodeResult +from common.util.common import flat_map from common.util.field_message import ErrMessage @@ -43,6 +44,13 @@ def is_valid(self, *, raise_exception=False): super().is_valid(raise_exception=True) +def get_paragraph_list(chat_record, node_id): + return flat_map([chat_record.details[key].get('paragraph_list', []) for key in chat_record.details if + (chat_record.details[ + key].get('type', '') == 'search-dataset-node') and chat_record.details[key].get( + 'paragraph_list', []) is not None and key == node_id]) + + class ISearchDatasetStepNode(INode): type = 'search-dataset-node' @@ -53,7 +61,13 @@ def _run(self): question = self.workflow_manage.get_reference_field( self.node_params_serializer.data.get('question_reference_address')[0], self.node_params_serializer.data.get('question_reference_address')[1:]) - return self.execute(**self.node_params_serializer.data, question=str(question), exclude_paragraph_id_list=[]) + history_chat_record = self.flow_params_serializer.data.get('history_chat_record', []) + paragraph_id_list = [p.get('id') for p in flat_map( + [get_paragraph_list(chat_record, self.node.id) for chat_record in history_chat_record if + chat_record.problem_text == question])] + exclude_paragraph_id_list = list(set(paragraph_id_list)) + return self.execute(**self.node_params_serializer.data, question=str(question), + exclude_paragraph_id_list=exclude_paragraph_id_list) def execute(self, dataset_id_list, dataset_setting, question, exclude_paragraph_id_list=None, diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 8ac962c95bc..d689448e12f 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -579,6 +579,8 @@ def one(self, with_valid=True): 'dataset_id_list': dataset_id_list} def get_search_node(self, work_flow): + if work_flow is None: + return [] return [node for node in work_flow.get('nodes', []) if node.get('type', '') == 'search-dataset-node'] def update_search_node(self, work_flow, user_dataset_id_list: List): diff --git a/apps/embedding/vector/pg_vector.py b/apps/embedding/vector/pg_vector.py index 5c0d045363b..e9a62ae577d 100644 --- a/apps/embedding/vector/pg_vector.py +++ b/apps/embedding/vector/pg_vector.py @@ -105,9 +105,9 @@ def query(self, query_text: str, query_embedding: List[float], dataset_id_list: return [] query_set = QuerySet(Embedding).filter(dataset_id__in=dataset_id_list, is_active=is_active) if exclude_document_id_list is not None and len(exclude_document_id_list) > 0: - exclude_dict.__setitem__('document_id__in', exclude_document_id_list) + query_set = query_set.exclude(document_id__in=exclude_document_id_list) if exclude_paragraph_list is not None and len(exclude_paragraph_list) > 0: - exclude_dict.__setitem__('paragraph_id__in', exclude_paragraph_list) + query_set = query_set.exclude(paragraph_id__in=exclude_paragraph_list) query_set = query_set.exclude(**exclude_dict) for search_handle in search_handle_list: if search_handle.support(search_mode):