Skip to content
This repository has been archived by the owner on Feb 22, 2023. It is now read-only.

Move page and page_size query param validation into serializer #868

Merged
merged 9 commits into from Sep 16, 2022
8 changes: 3 additions & 5 deletions api/catalog/api/controllers/search_controller.py
Expand Up @@ -18,7 +18,6 @@

import catalog.api.models as models
from catalog.api.utils.dead_link_mask import get_query_hash, get_query_mask
from catalog.api.utils.pagination import MAX_TOTAL_PAGE_COUNT
from catalog.api.utils.validate_images import validate_images


Expand Down Expand Up @@ -502,10 +501,9 @@ def _get_result_and_page_count(
return 0, 1

result_count = response_obj.hits.total.value
natural_page_count = int(result_count / page_size)
if natural_page_count % page_size != 0:
natural_page_count += 1
page_count = min(natural_page_count, MAX_TOTAL_PAGE_COUNT)
zackkrida marked this conversation as resolved.
Show resolved Hide resolved
page_count = int(result_count / page_size)
if page_count % page_size != 0:
page_count += 1
if len(results) < page_size and page_count == 0:
result_count = len(results)

Expand Down
44 changes: 39 additions & 5 deletions api/catalog/api/serializers/media_serializers.py
@@ -1,13 +1,16 @@
from collections import namedtuple

from django.conf import settings
from django.core.exceptions import ValidationError
from django.core.validators import MaxValueValidator
from rest_framework import serializers
from rest_framework.exceptions import NotAuthenticated

from catalog.api.constants.licenses import LICENSE_GROUPS
from catalog.api.controllers import search_controller
from catalog.api.models.media import AbstractMedia
from catalog.api.serializers.base import BaseModelSerializer
from catalog.api.serializers.fields import SchemableHyperlinkedIdentityField
from catalog.api.utils.exceptions import get_api_exception
from catalog.api.utils.help_text import make_comma_separated_help_text
from catalog.api.utils.licenses import get_license_url
from catalog.api.utils.url import add_protocol
Expand Down Expand Up @@ -42,6 +45,7 @@ class MediaSearchRequestSerializer(serializers.Serializer):
"mature",
"qa",
"page_size",
"page",
]
"""
Keep the fields names in sync with the actual fields below as this list is
Expand Down Expand Up @@ -111,6 +115,16 @@ class MediaSearchRequestSerializer(serializers.Serializer):
label="page_size",
help_text="Number of results to return per page.",
required=False,
default=settings.MAX_ANONYMOUS_PAGE_SIZE,
min_value=1,
)
page = serializers.IntegerField(
label="page",
help_text="The page of results to retrieve.",
required=False,
default=1,
max_value=settings.MAX_PAGINATION_DEPTH,
min_value=1,
)

@staticmethod
Expand Down Expand Up @@ -161,10 +175,30 @@ def validate_title(self, value):
def validate_page_size(self, value):
request = self.context.get("request")
is_anonymous = bool(request and request.user and request.user.is_anonymous)
if is_anonymous and value > 20:
raise get_api_exception(
"Page size must be between 1 & 20 for unauthenticated requests.", 401
)
max_value = (
settings.MAX_ANONYMOUS_PAGE_SIZE
if is_anonymous
else settings.MAX_AUTHED_PAGE_SIZE
)

validator = MaxValueValidator(
max_value,
message=serializers.IntegerField.default_error_messages["max_value"].format(
max_value=max_value
),
)
zackkrida marked this conversation as resolved.
Show resolved Hide resolved

if is_anonymous:
try:
validator(value)
except ValidationError as e:
raise NotAuthenticated(
detail=e.message,
code=e.code,
)
else:
validator(value)

return value

@staticmethod
Expand Down
42 changes: 3 additions & 39 deletions api/catalog/api/utils/pagination.py
@@ -1,11 +1,7 @@
from django.conf import settings
from rest_framework.pagination import PageNumberPagination
from rest_framework.response import Response

from catalog.api.utils.exceptions import get_api_exception


MAX_TOTAL_PAGE_COUNT = 20


class StandardPagination(PageNumberPagination):
page_size_query_param = "page_size"
Expand All @@ -15,45 +11,13 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.result_count = None # populated later
self.page_count = None # populated later

self._page_size = 20
self._page = None

@property
def page_size(self):
"""the number of results to show in one page"""
return self._page_size

@page_size.setter
def page_size(self, value):
if value is None or not str(value).isnumeric():
return
value = int(value) # convert str params to int
if value <= 0 or value > 500:
raise get_api_exception("Page size must be between 0 & 500.", 400)
self._page_size = value

@property
def page(self):
"""the current page number being served"""
return self._page

@page.setter
def page(self, value):
if value is None or not str(value).isnumeric():
value = 1
value = int(value) # convert str params to int
if value <= 0:
raise get_api_exception("Page must be greater than 0.", 400)
elif value > 20:
raise get_api_exception("Searches are limited to 20 pages.", 400)
self._page = value
self.page = 1 # default, get's updated when necessary

def get_paginated_response(self, data):
return Response(
{
"result_count": self.result_count,
"page_count": self.page_count,
"page_count": min(settings.MAX_PAGINATION_DEPTH, self.page_count),
"page_size": self.page_size,
"page": self.page,
"results": data,
Expand Down
8 changes: 3 additions & 5 deletions api/catalog/api/views/media_views.py
Expand Up @@ -76,13 +76,11 @@ def _get_request_serializer(self, request):
# Standard actions

def list(self, request, *_, **__):
self.paginator.page_size = request.query_params.get("page_size")
page_size = self.paginator.page_size
self.paginator.page = request.query_params.get("page")
page = self.paginator.page

params = self._get_request_serializer(request)

page_size = self.paginator.page_size = params.data["page_size"]
page = self.paginator.page = params.data["page"]

hashed_ip = hash(self._get_user_ip(request))
qa = params.validated_data["qa"]
filter_dead = params.validated_data["filter_dead"]
Expand Down
4 changes: 4 additions & 0 deletions api/catalog/settings.py
Expand Up @@ -372,3 +372,7 @@
# E.g. LINK_VALIDATION_CACHE_EXPIRY__200='{"days": 1}' will set the expiration time
# for links with HTTP status 200 to 1 day
LINK_VALIDATION_CACHE_EXPIRY_CONFIGURATION = LinkValidationCacheExpiryConfiguration()

MAX_ANONYMOUS_PAGE_SIZE = 20
MAX_AUTHED_PAGE_SIZE = 500
MAX_PAGINATION_DEPTH = 20
2 changes: 1 addition & 1 deletion api/test/auth_test.py
Expand Up @@ -97,7 +97,7 @@ def test_auth_rate_limit_reporting(


@pytest.mark.django_db
def test_pase_size_limit_unauthed(client):
def test_page_size_limit_unauthed(client):
query_params = {"filter_dead": False, "page_size": 20}
res = client.get("/v1/images/", query_params)
assert res.status_code == 200
Expand Down
9 changes: 6 additions & 3 deletions api/test/dead_link_filter_test.py
Expand Up @@ -2,12 +2,13 @@
from unittest.mock import MagicMock, patch
from uuid import uuid4

from django.conf import settings

import pytest
import requests
from fakeredis import FakeRedis

from catalog.api.controllers.search_controller import DEAD_LINK_RATIO
from catalog.api.utils.pagination import MAX_TOTAL_PAGE_COUNT


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -202,7 +203,7 @@ def test_page_consistency_removing_dead_links(search_without_dead_links):
Test the results returned in consecutive pages are never repeated when
filtering out dead links.
"""
total_pages = MAX_TOTAL_PAGE_COUNT
total_pages = settings.MAX_PAGINATION_DEPTH
page_size = 5

page_results = []
Expand All @@ -226,6 +227,8 @@ def no_duplicates(xs):
@pytest.mark.django_db
def test_max_page_count():
response = requests.get(
f"{API_URL}/v1/images", params={"page": MAX_TOTAL_PAGE_COUNT + 1}, verify=False
f"{API_URL}/v1/images",
params={"page": settings.MAX_PAGINATION_DEPTH + 1},
verify=False,
)
assert response.status_code == 400
3 changes: 0 additions & 3 deletions api/test/unit/controllers/test_search_controller.py
Expand Up @@ -10,7 +10,6 @@

from catalog.api.controllers import search_controller
from catalog.api.utils.dead_link_mask import get_query_hash, save_query_mask
from catalog.api.utils.pagination import MAX_TOTAL_PAGE_COUNT


@pytest.mark.parametrize(
Expand Down Expand Up @@ -41,8 +40,6 @@
(20, 5, 5, (20, 5)),
# Fewer hits than page size, but result list somehow differs, use that for count
(48, 20, 50, (20, 0)),
# Page count gets truncated always
(5000, 10, 10, (5000, MAX_TOTAL_PAGE_COUNT)),
],
)
def test_get_result_and_page_count(total_hits, real_result_count, page_size, expected):
Expand Down
85 changes: 76 additions & 9 deletions api/test/unit/serializers/media_serializers_test.py
@@ -1,21 +1,34 @@
import uuid
from test.factory.models.oauth2 import AccessTokenFactory
from unittest.mock import MagicMock

from rest_framework.request import Request
from rest_framework.test import APIRequestFactory
from django.conf import settings
from rest_framework.exceptions import NotAuthenticated, ValidationError
from rest_framework.test import APIRequestFactory, force_authenticate
from rest_framework.views import APIView

import pytest

from catalog.api.serializers.audio_serializers import AudioSerializer
from catalog.api.serializers.image_serializers import ImageSerializer
from catalog.api.serializers.media_serializers import MediaSearchRequestSerializer


# TODO: @sarayourfriend consolidate these with the other
# request factory fixtures into conftest.py
@pytest.fixture
def req():
factory = APIRequestFactory()
request = factory.get("/")
request = Request(request)
return request
def request_factory() -> APIRequestFactory():
request_factory = APIRequestFactory(defaults={"REMOTE_ADDR": "192.0.2.1"})

return request_factory


@pytest.fixture
def access_token():
token = AccessTokenFactory.create()
token.application.verified = True
token.application.save()
return token


@pytest.fixture
Expand All @@ -28,17 +41,71 @@ def hit():
return hit


@pytest.fixture
def authed_request(access_token, request_factory):
request = request_factory.get("/")

force_authenticate(request, token=access_token.token)

return APIView().initialize_request(request)


@pytest.fixture
def anon_request(request_factory):
return APIView().initialize_request(request_factory.get("/"))


@pytest.mark.django_db
@pytest.mark.parametrize(
("page_size", "authenticated"),
(
pytest.param(-1, False, marks=pytest.mark.raises(exception=ValidationError)),
pytest.param(0, False, marks=pytest.mark.raises(exception=ValidationError)),
(1, False),
(settings.MAX_ANONYMOUS_PAGE_SIZE, False),
pytest.param(
settings.MAX_ANONYMOUS_PAGE_SIZE + 1,
False,
marks=pytest.mark.raises(exception=NotAuthenticated),
),
pytest.param(
settings.MAX_AUTHED_PAGE_SIZE,
False,
marks=pytest.mark.raises(exception=NotAuthenticated),
),
pytest.param(-1, True, marks=pytest.mark.raises(exception=ValidationError)),
pytest.param(0, True, marks=pytest.mark.raises(exception=ValidationError)),
(1, True),
(settings.MAX_ANONYMOUS_PAGE_SIZE + 1, True),
(settings.MAX_AUTHED_PAGE_SIZE, True),
pytest.param(
settings.MAX_AUTHED_PAGE_SIZE + 1,
True,
marks=pytest.mark.raises(exception=ValidationError),
),
),
)
def test_page_size_validation(page_size, authenticated, anon_request, authed_request):
request = authed_request if authenticated else anon_request
serializer = MediaSearchRequestSerializer(
context={"request": request}, data={"page_size": page_size}
)
assert serializer.is_valid(raise_exception=True)


@pytest.mark.parametrize(
"serializer_class",
[
AudioSerializer,
ImageSerializer,
],
)
def test_media_serializer_adds_license_url_if_missing(req, hit, serializer_class):
def test_media_serializer_adds_license_url_if_missing(
anon_request, hit, serializer_class
):
# Note that this behaviour is inherited from the parent `MediaSerializer` class, but
# it cannot be tested without a concrete model to test with.

del hit.license_url # without the ``del``, the property is dynamically generated
repr = serializer_class(hit, context={"request": req}).data
repr = serializer_class(hit, context={"request": anon_request}).data
assert repr["license_url"] == "https://creativecommons.org/publicdomain/zero/1.0/"