Skip to content

Commit

Permalink
Merge pull request #1012 from DemocracyLab/ThrottleAPI
Browse files Browse the repository at this point in the history
Added api_view and throttle_class decorators to all views to throttle…
  • Loading branch information
marlonkeating committed Sep 3, 2023
2 parents 12ed1b4 + ee632a7 commit a61094a
Show file tree
Hide file tree
Showing 8 changed files with 258 additions and 70 deletions.
173 changes: 170 additions & 3 deletions civictechprojects/tests/test_views.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
from django.test import TestCase, Client, tag
from datetime import timedelta
from time import sleep
from typing import Literal
from django.core.cache import cache

from django.test import Client, TestCase, tag
from django.urls import reverse
from civictechprojects.models import Project
from django.utils.timezone import now

from civictechprojects.models import Group, Project
from democracylab.models import Contributor


Expand Down Expand Up @@ -45,4 +52,164 @@ def test_allauth_provider_url_resolves(self):

for provider in provider_list:
url = reverse(f'{provider}_login')
self.assertIsNotNone(url)
self.assertIsNotNone(url)

def make_many_requests(
self,
client: Client,
method: Literal['get', 'post', 'delete', 'put', 'patch', 'head'],
params: list[dict], # [{'path': str, 'data': dict}, ...]
):
""" A helper method to make many requests to a given URL; used by tests """
cache.clear() # clear the throttling cache before measuring num of throttled requests
num_succeeded, num_throttled = 0, 0

# cache has "1 second" resolution, so throttling counts API calls not from first call to
# last call but from beginning of second (say, 21:30:45.000) to end of second (21:30:45.999);
# to make a fair test, we start measuring right at the beginning of a second, thus the sleep:
sleep(1 - now().microsecond/1_000_000)

start = now()
for request_params in params:
response = getattr(client, method)(**request_params)
match response.status_code:
case num if 200 <= num <= 401:
num_succeeded += 1
case 429:
num_throttled += 1
case _:
self.assertTrue(False, f'Unexpected response code: {response.status_code}')

elapsed = now() - start
self.assertTrue(elapsed < timedelta(seconds=1), f"Expected requests to finish within 1s but it took {elapsed}")
return num_succeeded, num_throttled

def test_api_views_throttling__get_group(self):
group = Group.objects.create(
group_creator=self.test_user,
group_name='test-name',
is_searchable=True,
)
url = reverse('get_group', kwargs={'group_id': group.id})

for is_authenticated in {True, False}:
client = Client()
if is_authenticated:
client.force_login(self.test_user)

num_succeeded, num_throttled = self.make_many_requests(
client=client,
method='get',
params=[{
'path': url,
}]*12,
)

expect_succeeded = 10 if is_authenticated else 5
self.assertEqual(num_succeeded, expect_succeeded)
self.assertEqual(num_throttled, 12-expect_succeeded)

def test_api_views_throttling__group_create(self):
url = reverse('group_create')

for is_authenticated in {True, False}:
client = Client()
if is_authenticated:
client.force_login(self.test_user)

num_succeeded, num_throttled = self.make_many_requests(
client=client,
method='post',
params=[{
'path': url,
'data': {
'group_name': 'test name',
'group_description': 'test description',
'group_short_description': 'test short description',
},
}]*12,
)

expect_succeeded = 10 if is_authenticated else 5
self.assertEqual(num_succeeded, expect_succeeded)
self.assertEqual(num_throttled, 12-expect_succeeded)

def test_api_views_throttling__group_edit(self):
group = Group.objects.create(
group_creator=self.test_user,
group_name='test-name',
is_searchable=True,
)
url = reverse('group_edit', kwargs={'group_id': group.id})

for is_authenticated in {True, False}:
client = Client()
if is_authenticated:
client.force_login(self.test_user)

num_succeeded, num_throttled = self.make_many_requests(
client=client,
method='post',
params=[{
'path': url,
'data': {
'group_name': f'test name #{i}',
'group_description': 'test description',
'group_short_description': 'test short description',
},
} for i in range(12)],
)

expect_succeeded = 10 if is_authenticated else 5
self.assertEqual(num_succeeded, expect_succeeded)
self.assertEqual(num_throttled, 12-expect_succeeded)

def test_api_views_throttling__group_delete(self):
for is_authenticated in {True, False}:
client = Client()
if is_authenticated:
client.force_login(self.test_user)

groups = Group.objects.bulk_create([
Group(
group_creator=self.test_user,
group_name=f'test-name-{i}',
is_searchable=True,
) for i in range(12)
])
num_succeeded, num_throttled = self.make_many_requests(
client=client,
method='post',
params=[{
'path': reverse('group_delete', kwargs={'group_id': group.id}),
} for group in groups],
)

expect_succeeded = 10 if is_authenticated else 5
self.assertEqual(num_succeeded, expect_succeeded)
self.assertEqual(num_throttled, 12-expect_succeeded)

def test_api_views_throttling__project_delete(self):
for is_authenticated in {True, False}:
client = Client()
if is_authenticated:
client.force_login(self.test_user)

projects = Project.objects.bulk_create([
Project(
project_creator=self.test_user,
project_name=f'test-name-{i}',
is_searchable=True,
) for i in range(12)
])
num_succeeded, num_throttled = self.make_many_requests(
client=client,
method='post',
params=[{
'path': reverse('project_delete', kwargs={'project_id': project.id}),
} for project in projects],
)

expect_succeeded = 10 if is_authenticated else 5
self.assertEqual(num_succeeded, expect_succeeded)
self.assertEqual(num_throttled, 12-expect_succeeded)
Loading

0 comments on commit a61094a

Please sign in to comment.