Skip to content

Commit

Permalink
Add public/private collection support
Browse files Browse the repository at this point in the history
  • Loading branch information
danlamanna committed Sep 21, 2021
1 parent 3957882 commit 271a9cb
Show file tree
Hide file tree
Showing 8 changed files with 283 additions and 18 deletions.
43 changes: 30 additions & 13 deletions isic/core/api.py
@@ -1,14 +1,13 @@
from typing import Optional

from django.contrib.auth.models import User
from drf_yasg.utils import swagger_auto_schema
from rest_framework.decorators import action, api_view, permission_classes
from rest_framework.permissions import AllowAny
from rest_framework.response import Response
from rest_framework.viewsets import ReadOnlyModelViewSet

from isic.core.models.collection import Collection
from isic.core.models.image import Image
from isic.core.permissions import IsicObjectPermissionsFilter
from isic.core.permissions import IsicObjectPermissionsFilter, get_visible_objects
from isic.core.search import facets, search_images
from isic.core.serializers import ImageSerializer, SearchQuerySerializer
from isic.core.stats import get_archive_stats
Expand All @@ -23,14 +22,27 @@ def stats(request):
return Response(get_archive_stats())


def build_filtered_query(user: User, query: Optional[str] = None) -> dict:
def build_filtered_query(user: User, query_params: dict) -> dict:
"""Translate a django search request into an elasticsearch query."""
serializer = SearchQuerySerializer(data=query_params, context={'user': user})
serializer.is_valid(raise_exception=True)
collection_pks = serializer.validated_data.get('collections')
dsl_query = serializer.validated_data.get('query')

query_dict = {'bool': {}}

if query:
query_dict['bool']['must'] = {'query_string': {'query': query}}
if collection_pks is not None:
query_dict['bool'].setdefault('filter', {})
query_dict['bool']['filter']['terms'] = {'collections': collection_pks}

if dsl_query:
query_dict['bool'].setdefault('must', {})
query_dict['bool']['must']['query_string'] = {'query': dsl_query}

if user.is_anonymous:
query_dict['bool']['should'] = [{'term': {'public': 'true'}}]
# https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-bool-query.html#bool-min-should-match
query_dict['bool']['minimum_should_match'] = 1
elif not user.is_staff:
query_dict['bool']['should'] = [
{'term': {'public': 'true'}},
Expand All @@ -57,20 +69,25 @@ class ImageViewSet(ReadOnlyModelViewSet):
)
@action(detail=False, methods=['get'], pagination_class=None)
def facets(self, request):
serializer = SearchQuerySerializer(data=request.query_params)
serializer.is_valid(raise_exception=True)
query = build_filtered_query(request.user, serializer.data.get('query'))
return Response(facets(query))
query = build_filtered_query(request.user, request.query_params)
# Manually pass the list of visible collection PKs through so buckets with
# counts of 0 aren't included in the facets output for non-visible collections.
collection_pks = list(
get_visible_objects(
request.user,
'core.view_collection',
Collection.objects.values_list('pk', flat=True),
)
)
return Response(facets(query, collection_pks))

@swagger_auto_schema(
operation_description='Search images with a key:value query string.',
query_serializer=SearchQuerySerializer,
)
@action(detail=False, methods=['get'])
def search(self, request):
serializer = SearchQuerySerializer(data=request.query_params)
serializer.is_valid(raise_exception=True)
query = build_filtered_query(request.user, serializer.data.get('query'))
query = build_filtered_query(request.user, request.query_params)
search_results = search_images(
query, self.paginator.get_limit(request), self.paginator.get_offset(request)
)
Expand Down
18 changes: 18 additions & 0 deletions isic/core/migrations/0023_collection_public.py
@@ -0,0 +1,18 @@
# Generated by Django 3.2.6 on 2021-09-15 18:04

from django.db import migrations, models


class Migration(migrations.Migration):
    """Add the boolean ``public`` column (default ``False``) to ``Collection``."""

    dependencies = [('core', '0022_alter_image_isic')]

    operations = [
        migrations.AddField(
            model_name='collection',
            name='public',
            field=models.BooleanField(default=False),
        ),
    ]
15 changes: 11 additions & 4 deletions isic/core/models/collection.py
@@ -1,4 +1,8 @@
from typing import Optional

from django.contrib.auth.models import User
from django.db import models
from django.db.models.query import QuerySet
from django.urls import reverse
from django_extensions.db.models import TimeStampedModel

Expand All @@ -12,6 +16,8 @@ class Collection(TimeStampedModel):
name = models.CharField(max_length=200, unique=True)
description = models.TextField(blank=True)

public = models.BooleanField(default=False)

def __str__(self):
return self.name

Expand All @@ -25,14 +31,15 @@ class CollectionPermissions:
filters = {'view_collection': 'view_collection_list'}

@staticmethod
def view_collection_list(user_obj, qs=Collection._default_manager):
if not user_obj.is_active or not user_obj.is_authenticated:
return qs.none()
def view_collection_list(
user_obj: User, qs: Optional[QuerySet[Collection]] = None
) -> QuerySet[Collection]:
qs: QuerySet = qs if qs is not None else Collection._default_manager.all()

if user_obj.is_staff:
return qs

return qs.none()
return qs.filter(public=True)

@staticmethod
def view_collection(user_obj, obj):
Expand Down
1 change: 1 addition & 0 deletions isic/core/models/image.py
Expand Up @@ -66,6 +66,7 @@ def as_elasticsearch_document(self) -> dict:
user.pk for user in self.accession.upload.cohort.contributor.owners.all()
],
'shared_to': [user.pk for user in self.shares.all()],
'collections': list(self.collections.values_list('pk', flat=True)),
}
)

Expand Down
8 changes: 7 additions & 1 deletion isic/core/search.py
Expand Up @@ -151,12 +151,18 @@ def bulk_add_to_search_index(qs: QuerySet[Image], chunk_size: int = 500) -> None
logger.error('Failed to insert document into elasticsearch', info)


def facets(query: Optional[dict] = None) -> dict:
def facets(query: Optional[dict] = None, collections: Optional[list[int]] = None) -> dict:
body = {
'size': 0,
'aggs': DEFAULT_SEARCH_AGGREGATES,
}

if collections is not None:
# Note this include statement means we can only filter by ~65k collections. See:
# "By default, Elasticsearch limits the terms query to a maximum of 65,536 terms.
# You can change this limit using the index.max_terms_count setting."
body['aggs']['collections'] = {'terms': {'field': 'collections', 'include': collections}}

if query:
body['query'] = query

Expand Down
54 changes: 54 additions & 0 deletions isic/core/serializers.py
@@ -1,11 +1,65 @@
import re
from typing import List, Optional, Union

from rest_framework import serializers
from rest_framework.fields import Field

from isic.core.models import Image
from isic.core.models.collection import Collection
from isic.core.models.image import RESTRICTED_SEARCH_FIELDS
from isic.core.permissions import get_visible_objects


class CollectionsField(Field):
    """A field for comma separated collection ids.

    Incoming ids are filtered down to the collections that are visible to the
    user stored under the ``'user'`` key of the serializer context.
    """

    default_error_messages = {
        'invalid': 'Not a valid string.',
        'not_comma_delimited': 'Not a comma delimited string.',
    }

    def to_representation(self, obj: List[int]) -> str:
        """Render a list of collection ids as a comma separated string."""
        # The base ``Field.to_representation`` unconditionally raises
        # NotImplementedError, so it must not be delegated to here.
        return ','.join(str(element) for element in obj)

    def to_internal_value(self, data: Optional[Union[list, str]]) -> Optional[List[int]]:
        """Parse ``data`` into the list of visible collection ids.

        Returns ``None`` when ``data`` is empty, meaning "no collection filter".
        Fails validation with ``invalid`` when the value is not a string and
        with ``not_comma_delimited`` when it is not of the form ``1,2,3``.
        """
        if not data:
            return None

        # if the data is coming from swagger, it's built into a 1 element list
        if isinstance(data, list):
            data = data[0]

        # Check the type after unwrapping too, so that a list holding a
        # non-string is rejected as a validation error instead of raising
        # TypeError inside re.match.
        if not isinstance(data, str):
            self.fail('invalid')

        if not re.match(r'^(\d+)(,\d+)*$', data):
            self.fail('not_comma_delimited')

        collection_pks = [int(x) for x in data.split(',')]
        return self._filter_collection_pks(collection_pks)

    def _filter_collection_pks(self, collection_pks: List[int]) -> List[int]:
        """Return the subset of ``collection_pks`` the context user may view."""
        visible_collections = get_visible_objects(
            self.context['user'],
            'core.view_collection',
            Collection.objects.filter(pk__in=collection_pks),
        )

        return list(visible_collections.values_list('pk', flat=True))


class SearchQuerySerializer(serializers.Serializer):
    """Validate the parameters of a search query against images.

    This serializer must be constructed with a user object under the ``'user'``
    key of its context; the ``collections`` field reads it from there.
    """

    query = serializers.CharField(required=False)
    collections = CollectionsField(required=False)


class ImageUrlSerializer(serializers.Serializer):
Expand Down
128 changes: 128 additions & 0 deletions isic/core/tests/test_search.py
Expand Up @@ -48,6 +48,32 @@ def searchable_images_with_private_fields(image_factory, search_index):
return images


@pytest.fixture
def private_and_public_images_collections(search_index, image_factory, collection_factory):
    """Index a public image in a public collection and a private image in a private one.

    Yields the ``(public_collection, private_collection)`` pair.
    """
    public_coll = collection_factory(public=True)
    private_coll = collection_factory(public=False)
    public_image = image_factory(public=True)
    private_image = image_factory(public=False)

    public_coll.images.add(public_image)
    private_coll.images.add(private_image)

    add_to_search_index(public_image)
    add_to_search_index(private_image)

    # Refresh every index before handing control to the test.
    get_elasticsearch_client().indices.refresh(index='_all')

    yield public_coll, private_coll


@pytest.fixture
def collection_with_image(search_index, image_factory, collection_factory):
    """Yield a public collection holding one indexed public image with ``age`` 52."""
    collection = collection_factory(public=True)
    image = image_factory(public=True, accession__metadata={'age': 52})
    collection.images.add(image)

    add_to_search_index(image)
    # Refresh every index before handing control to the test.
    get_elasticsearch_client().indices.refresh(index='_all')

    yield collection


@pytest.mark.django_db
def test_core_api_image_search(searchable_images, staff_api_client):
r = staff_api_client.get('/api/v2/images/search/')
Expand All @@ -66,6 +92,20 @@ def test_core_api_image_search_private_image(private_searchable_image, authentic
assert r.data['count'] == 0


@pytest.mark.django_db
def test_core_api_image_search_private_image_as_guest(private_searchable_image, api_client):
    """A search through the unauthenticated client must not return the private image."""
    response = api_client.get('/api/v2/images/search/')

    assert response.status_code == 200, response.data
    assert response.data['count'] == 0


@pytest.mark.django_db
def test_core_api_image_search_images_as_guest(searchable_images, api_client):
    """A search through the unauthenticated client returns exactly one of the fixture images."""
    response = api_client.get('/api/v2/images/search/')

    assert response.status_code == 200, response.data
    assert response.data['count'] == 1


@pytest.mark.django_db
def test_core_api_image_search_contributed(
private_searchable_image, authenticated_api_client, user
Expand Down Expand Up @@ -104,3 +144,91 @@ def test_core_api_image_hides_fields(
assert r.data['count'] == 3
for image in r.data['results']:
assert restricted_field not in image['metadata']


@pytest.mark.django_db
def test_core_api_image_search_collection_and_query(
    collection_with_image, authenticated_api_client
):
    """A collection filter and a query string can be combined in one search."""
    params = {
        'collections': str(collection_with_image.pk),
        'query': 'age_approx:50',
    }

    response = authenticated_api_client.get('/api/v2/images/search/', params)

    assert response.status_code == 200, response.data
    assert response.data['count'] == 1


@pytest.mark.django_db
@pytest.mark.parametrize(
    'collection_is_public,image_is_public,can_see',
    [
        (True, True, True),
        # Don't leak which images are in a private collection
        (False, True, False),
        (False, False, False),
    ],
    ids=['all-public', 'private-coll-public-image', 'all-private'],
)
def test_core_api_image_search_collection(
    authenticated_api_client,
    image_factory,
    collection_factory,
    search_index,
    collection_is_public,
    image_is_public,
    can_see,
):
    """Filtering by a collection only returns its image when the collection is public."""
    collection = collection_factory(public=collection_is_public)
    image = image_factory(public=image_is_public)
    collection.images.add(image)
    add_to_search_index(image)
    get_elasticsearch_client().indices.refresh(index='_all')

    response = authenticated_api_client.get(
        '/api/v2/images/search/', {'collections': str(collection.pk)}
    )

    assert response.status_code == 200, response.data
    expected_count = 1 if can_see else 0
    assert response.data['count'] == expected_count


@pytest.mark.django_db
def test_core_api_image_search_collection_parsing(
    private_and_public_images_collections, authenticated_api_client
):
    """A comma separated ``collections`` value naming both collections yields one image."""
    public_coll, private_coll = private_and_public_images_collections
    requested = f'{public_coll.pk},{private_coll.pk}'

    response = authenticated_api_client.get('/api/v2/images/search/', {'collections': requested})

    assert response.status_code == 200, response.data
    assert response.data['count'] == 1


@pytest.mark.django_db
def test_core_api_image_faceting_collections(
    private_and_public_images_collections, authenticated_api_client
):
    """Faceting with both collections requested only reports the public collection bucket."""
    public_coll, private_coll = private_and_public_images_collections
    requested = f'{public_coll.pk},{private_coll.pk}'

    response = authenticated_api_client.get('/api/v2/images/facets/', {'collections': requested})

    assert response.status_code == 200, response.data
    collection_buckets = response.data['collections']['buckets']
    assert len(collection_buckets) == 1
    assert collection_buckets[0] == {'key': public_coll.pk, 'doc_count': 1}


@pytest.mark.django_db
def test_core_api_image_faceting(private_and_public_images_collections, authenticated_api_client):
    """Faceting without a collection filter only reports the public collection bucket."""
    public_coll, _private_coll = private_and_public_images_collections

    response = authenticated_api_client.get('/api/v2/images/facets/')

    assert response.status_code == 200, response.data
    collection_buckets = response.data['collections']['buckets']
    assert len(collection_buckets) == 1, collection_buckets
    assert collection_buckets[0] == {'key': public_coll.pk, 'doc_count': 1}, collection_buckets

0 comments on commit 271a9cb

Please sign in to comment.