Skip to content

Commit

Permalink
using search queries to get facets
Browse files Browse the repository at this point in the history
  • Loading branch information
snyaggarwal committed Aug 25, 2023
1 parent aa91392 commit ffe371b
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 26 deletions.
6 changes: 5 additions & 1 deletion core/common/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,18 @@


class CustomESFacetedSearch(FacetedSearch):
def __init__(self, query=None, filters={}, sort=()): # pylint: disable=dangerous-default-value
def __init__(self, query=None, filters={}, sort=(), _search=None): # pylint: disable=dangerous-default-value
self._search = _search
super().__init__(query=query, filters=filters, sort=sort)

@staticmethod
def format_search_str(search_str):
return f"{search_str}*".replace('**', '*')

def query(self, search, query):
if self._search:
from_search = self._search.to_dict()
return search.update_from_dict(from_search)
if query:
search_str = self.format_search_str(query)
if self.fields:
Expand Down
39 changes: 14 additions & 25 deletions core/common/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from core.common.mixins import PathWalkerMixin
from core.common.search import CustomESSearch
from core.common.serializers import RootSerializer
from core.common.utils import compact_dict_by_values, to_snake_case, to_camel_case, parse_updated_since_param, \
from core.common.utils import compact_dict_by_values, to_snake_case, parse_updated_since_param, \
to_int, get_user_specific_task_id, get_falsy_values, get_truthy_values
from core.concepts.permissions import CanViewParentDictionary, CanEditParentDictionary
from core.orgs.constants import ORG_OBJECT_TYPE
Expand Down Expand Up @@ -406,26 +406,16 @@ def get_facets(self):
if self.facet_class:
if self.is_user_document():
return facets
is_source_child_document_model = self.is_source_child_document_model()
default_filters = self.default_filters.copy()

if is_source_child_document_model and 'collection' not in self.kwargs and 'version' not in self.kwargs:
default_filters['is_latest_version'] = True

faceted_filters = {to_camel_case(k): v for k, v in self.get_faceted_filters(True).items()}
filters = {**default_filters, **self.get_facet_filters_from_kwargs(), **faceted_filters, 'retired': False}
if not self._should_exclude_retired_from_search_results() or not is_source_child_document_model:
filters.pop('retired')

faceted_search = self.facet_class( # pylint: disable=not-callable
self.get_search_string(lower=False),
filters=filters
_search=self.__get_search_results(ignore_retired_filter=True, sort=False, highlight=False),
)
faceted_search.params(request_timeout=ES_REQUEST_TIMEOUT)
try:
facets = faceted_search.execute().facets.to_dict()
s = faceted_search.execute()
facets = s.facets.to_dict()
except TransportError as ex: # pragma: no cover
raise Http400(detail=get(ex, 'error') or str(ex)) from ex
raise Http400(detail=get(ex, 'info') or get(ex, 'error') or str(ex)) from ex

return facets

Expand Down Expand Up @@ -519,7 +509,7 @@ def __should_query_latest_version(self):

return (not collection or collection.startswith('!')) and (not version or version.startswith('!'))

def __apply_common_search_filters(self):
def __apply_common_search_filters(self, ignore_retired_filter=False):
results = None
if not self.should_perform_es_search():
return results
Expand All @@ -538,7 +528,7 @@ def __apply_common_search_filters(self):
if updated_since:
results = results.query('range', last_update={"gte": updated_since})

if self._should_exclude_retired_from_search_results():
if not ignore_retired_filter and self._should_exclude_retired_from_search_results():
results = results.query('match', retired=False)

include_private = self._should_include_private()
Expand Down Expand Up @@ -584,7 +574,7 @@ def __get_search_aggregations(
):
results = self.__get_fuzzy_search_results(
source_versions=source_versions, other_filters=other_filters, sort=False
) if self.is_fuzzy_search else self.__search_results
) if self.is_fuzzy_search else self.__get_search_results()

results = results.extra(size=0)
search = CustomESSearch(results)
Expand All @@ -610,9 +600,8 @@ def __get_source_version_es_criteria(source_version):
criteria &= Q('match', owner_type=source_version.parent.resource_type)
return criteria

@property
def __search_results(self): # pylint: disable=too-many-branches,too-many-locals,too-many-statements
results = self.__apply_common_search_filters()
def __get_search_results(self, ignore_retired_filter=False, sort=True, highlight=True): # pylint: disable=too-many-branches,too-many-locals,too-many-statements
results = self.__apply_common_search_filters(ignore_retired_filter)
if results is None:
return results

Expand Down Expand Up @@ -681,9 +670,9 @@ def __search_results(self): # pylint: disable=too-many-branches,too-many-locals
else:
results = results.query('match', **{attr: value})

if self.request.query_params.get(INCLUDE_SEARCH_META_PARAM) in get_truthy_values():
if highlight and self.request.query_params.get(INCLUDE_SEARCH_META_PARAM) in get_truthy_values():
results = results.highlight(*self.clean_fields_for_highlight(fields))
return results.sort(*self._get_sort_attribute())
return results.sort(*self._get_sort_attribute()) if sort else results

@staticmethod
def clean_fields_for_highlight(fields):
Expand Down Expand Up @@ -716,10 +705,10 @@ def __get_queryset_from_search_results(self, search_results):
' or fine tune your query to get more accurate results.') from ex
raise ex
except TransportError as ex: # pragma: no cover
raise Http400(detail=get(ex, 'error') or str(ex)) from ex
raise Http400(detail=get(ex, 'info') or get(ex, 'error') or str(ex)) from ex

def get_search_results_qs(self):
return self.__get_queryset_from_search_results(self.__search_results)
return self.__get_queryset_from_search_results(self.__get_search_results())

def get_fuzzy_search_results_qs(
self, source_versions=None, other_filters=None
Expand Down

0 comments on commit ffe371b

Please sign in to comment.