Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from rest_framework import serializers

from application.flow.i_step_node import INode, NodeResult
from common.util.common import flat_map
from common.util.field_message import ErrMessage


Expand Down Expand Up @@ -43,6 +44,13 @@ def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)


def get_paragraph_list(chat_record, node_id):
    """Return the paragraphs retrieved by a search-dataset node in a past chat record.

    :param chat_record: a chat record whose ``details`` dict is keyed by node id
    :param node_id: id of the search-dataset node whose hits we want
    :return: the node's ``paragraph_list`` (a list of paragraph dicts), or ``[]``
             when the node is absent, is not a search-dataset node, or stored
             ``paragraph_list`` as ``None``/missing.
    """
    # Direct lookup by node id — details is keyed by node id, so there is no
    # need to scan every entry and filter on key == node_id.
    detail = chat_record.details.get(node_id)
    if detail is None or detail.get('type', '') != 'search-dataset-node':
        return []
    paragraph_list = detail.get('paragraph_list', [])
    # Guard against an explicit None stored under 'paragraph_list'.
    return paragraph_list if paragraph_list is not None else []


class ISearchDatasetStepNode(INode):
type = 'search-dataset-node'

Expand All @@ -53,7 +61,13 @@ def _run(self):
question = self.workflow_manage.get_reference_field(
self.node_params_serializer.data.get('question_reference_address')[0],
self.node_params_serializer.data.get('question_reference_address')[1:])
return self.execute(**self.node_params_serializer.data, question=str(question), exclude_paragraph_id_list=[])
history_chat_record = self.flow_params_serializer.data.get('history_chat_record', [])
paragraph_id_list = [p.get('id') for p in flat_map(
[get_paragraph_list(chat_record, self.node.id) for chat_record in history_chat_record if
chat_record.problem_text == question])]
exclude_paragraph_id_list = list(set(paragraph_id_list))
return self.execute(**self.node_params_serializer.data, question=str(question),
exclude_paragraph_id_list=exclude_paragraph_id_list)

def execute(self, dataset_id_list, dataset_setting, question,
exclude_paragraph_id_list=None,
Expand Down
2 changes: 2 additions & 0 deletions apps/application/serializers/application_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,8 @@ def one(self, with_valid=True):
'dataset_id_list': dataset_id_list}

def get_search_node(self, work_flow):
    """Collect every node of type 'search-dataset-node' from the workflow.

    Returns an empty list when *work_flow* is ``None`` or has no matching nodes.
    """
    if work_flow is None:
        return []
    matches = []
    for candidate in work_flow.get('nodes', []):
        if candidate.get('type', '') == 'search-dataset-node':
            matches.append(candidate)
    return matches

def update_search_node(self, work_flow, user_dataset_id_list: List):
Expand Down
4 changes: 2 additions & 2 deletions apps/embedding/vector/pg_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ def query(self, query_text: str, query_embedding: List[float], dataset_id_list:
return []
query_set = QuerySet(Embedding).filter(dataset_id__in=dataset_id_list, is_active=is_active)
if exclude_document_id_list is not None and len(exclude_document_id_list) > 0:
exclude_dict.__setitem__('document_id__in', exclude_document_id_list)
query_set = query_set.exclude(document_id__in=exclude_document_id_list)
if exclude_paragraph_list is not None and len(exclude_paragraph_list) > 0:
exclude_dict.__setitem__('paragraph_id__in', exclude_paragraph_list)
query_set = query_set.exclude(paragraph_id__in=exclude_paragraph_list)
query_set = query_set.exclude(**exclude_dict)
for search_handle in search_handle_list:
if search_handle.support(search_mode):
Expand Down