Commit

Use different rate limiter
jdddog committed May 17, 2023
1 parent 586ef06 commit 73c9bc7
Showing 3 changed files with 55 additions and 117 deletions.
@@ -19,6 +19,7 @@
import datetime
import logging
import os
import time
from concurrent.futures import ProcessPoolExecutor, as_completed, ThreadPoolExecutor
from datetime import timedelta
from typing import List, Dict, Tuple
@@ -32,7 +33,9 @@
from google.cloud import bigquery
from google.cloud.bigquery import SourceFormat
from importlib_metadata import metadata
from ratelimit import limits, sleep_and_retry
from limits import RateLimitItemPerSecond
from limits.storage import storage_from_string
from limits.strategies import FixedWindowElasticExpiryRateLimiter

from academic_observatory_workflows.config import schema_folder as default_schema_folder, Tag
from observatory.api.client.model.dataset_release import DatasetRelease
@@ -59,6 +62,9 @@
CROSSREF_EVENTS_HOST = "https://api.eventdata.crossref.org/v1/events"
DATE_FORMAT = "YYYY-MM-DD"

backend = storage_from_string("memory://")
moving_window = FixedWindowElasticExpiryRateLimiter(backend)


class CrossrefEventsRelease(ChangefileRelease):
    def __init__(
@@ -593,7 +599,11 @@ def __init__(self, day: pendulum.DateTime, mailto: str):
def make_crossref_events_url(
    *, action: str, start_date: pendulum.DateTime, end_date: pendulum.DateTime, mailto: str, rows: int, cursor: str
):
    assert action in {"created", "edited", "deleted"}, f"make_crossref_events_url: unknown action={action}, must be one of created, edited or deleted"
    assert action in {
        "created",
        "edited",
        "deleted",
    }, f"make_crossref_events_url: unknown action={action}, must be one of created, edited or deleted"

    # Set query params and set path in URL
    url = CROSSREF_EVENTS_HOST
@@ -763,11 +773,19 @@ def fetch_events(request: EventRequest, cursor: str = None, n_rows: int = 1000)
    return events, next_cursor


@sleep_and_retry
@limits(calls=15, period=1)
def crossref_events_limiter():
    """Task to throttle the calls to the Crossref Events API"""
    return
def crossref_events_limiter(calls_per_second: int = 12):
    """Function to throttle the calls to the Crossref Events API"""

    identifier = "crossref_events_limiter"
    item = RateLimitItemPerSecond(calls_per_second)  # 12 per second

    while True:
        if not moving_window.test(item, identifier):
            time.sleep(0.01)
        else:
            break

    moving_window.hit(item, identifier)


def transform_events(download_path: str, transform_folder: str):
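Note that the new crossref_events_limiter is a blocking guard rather than a decorator, so each worker thread is expected to call it immediately before issuing an API request. The call site is not shown in this diff, so the snippet below is only an illustrative sketch (fetch_page and its arguments are hypothetical); it assumes the limiter is shared across threads via the module-level backend and moving_window defined above.

import requests

from academic_observatory_workflows.workflows.crossref_events_telescope import crossref_events_limiter

def fetch_page(url: str, headers: dict) -> dict:
    # Hypothetical call site: block until the shared fixed window has capacity,
    # then issue the request; all workers together stay at roughly 12 calls/s.
    crossref_events_limiter(calls_per_second=12)
    response = requests.get(url, headers=headers)
    response.raise_for_status()
    return response.json()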
@@ -17,6 +17,7 @@
import datetime
import json
import os
from concurrent.futures import as_completed, ThreadPoolExecutor

import pendulum
import responses
@@ -30,6 +31,7 @@
    make_day_requests,
    parse_release_msg,
    EventRequest,
    crossref_events_limiter,
)
from observatory.platform.api import get_dataset_releases
from observatory.platform.observatory_config import Workflow
@@ -111,8 +113,8 @@ def test_telescope(self):
dag_id=self.dag_id,
cloud_workspace=env.cloud_workspace,
bq_dataset_id=bq_dataset_id,
max_threads=1,
max_processes=1,
max_threads=3,
max_processes=3,
events_start_date=start_date,
n_rows=3, # needs to be 3 because that is what the mocked data uses
)
@@ -356,7 +358,7 @@ def test_telescope(self):
ti = env.run_task(workflow.bq_delete_records.__name__)
self.assertEqual(State.SUCCESS, ti.state)

# TODO: assert that we have correct dataset state
# Assert that we have correct dataset state
expected_content = load_and_parse_json(
test_fixtures_folder(self.dag_id, "run2-expected.json"),
date_fields={"occurred_at", "timestamp", "updated_date"},
@@ -380,11 +382,31 @@ def test_telescope(self):
self.assertEqual(State.SUCCESS, ti.state)


class TestOpenAlexUtils(ObservatoryTestCase):
# TODO: implement rate limit / retry / 429 backoff and test it
class TestCrossrefEventsUtils(ObservatoryTestCase):
    def test_crossref_events_limiter(self):
        n_per_second = 12

        def my_func():
            crossref_events_limiter(n_per_second)
            print("Called my_func")

    def test_get_event_date(self):
        self.fail()
        num_calls = 1000
        max_workers = 100
        expected_wait = num_calls / n_per_second
        print(f"test_crossref_events_limiter: expected wait time {expected_wait}s")
        start = datetime.datetime.now()
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = []
            for event in range(num_calls):
                futures.append(executor.submit(my_func))
            for future in as_completed(futures):
                future.result()

        end = datetime.datetime.now()
        duration = (end - start).total_seconds()
        actual_n_per_second = 1 / (duration / num_calls)
        print(f"test_crossref_events_limiter: actual_n_per_second {actual_n_per_second}")
        self.assertAlmostEqual(float(n_per_second), actual_n_per_second, delta=2.5)

    def test_event_request(self):
        day = pendulum.datetime(2023, 1, 1)
@@ -439,105 +461,3 @@ def test_make_day_requests(self):
            ],
            [req.date for req in requests],
        )

    def test_download_events(self):
        self.fail()

    def test_fetch_events(self):
        self.fail()

    def test_transform_events(self):
        self.fail()

    def test_transform_event(self):
        self.fail()

# @patch.object(CrossrefEventsRelease, "download_batch")
# @patch("observatory.platform.utils.workflow_utils.Variable.get")
# def test_download(self, mock_variable_get, mock_download_batch):
# """Test the download method of the release in parallel mode
# :return: None.
# """
# mock_variable_get.return_value = "data"
# with CliRunner().isolated_filesystem():
# # Test download without any events returned
# with self.assertRaises(AirflowSkipException):
# self.release.download()
#
# # Test download with events returned
# mock_download_batch.reset_mock()
# events_path = os.path.join(self.release.download_folder, "events.jsonl")
# with open(events_path, "w") as f:
# f.write("[{'test': 'test'}]\n")
#
# self.release.download()
# self.assertEqual(len(self.release.urls), mock_download_batch.call_count)
#
# @patch("academic_observatory_workflows.workflows.crossref_events_telescope.download_events")
# @patch("observatory.platform.utils.workflow_utils.Variable.get")
# def test_download_batch(self, mock_variable_get, mock_download_events):
# """Test download_batch function
# :return: None.
# """
# mock_variable_get.return_value = os.path.join(os.getcwd(), "data")
# self.release.first_release = True
# batch_number = 0
# url = self.release.urls[batch_number]
# headers = {"User-Agent": get_user_agent(package_name="academic_observatory_workflows")}
# with CliRunner().isolated_filesystem():
# events_path = self.release.batch_path(url)
# cursor_path = self.release.batch_path(url, cursor=True)
#
# # Test with existing cursor path
# with open(cursor_path, "w") as f:
# f.write("cursor")
# mock_download_events.return_value = (None, 10, 10)
# self.release.download_batch(batch_number, url)
# self.assertFalse(os.path.exists(cursor_path))
# mock_download_events.assert_called_once_with(url, headers, events_path, cursor_path)
#
# # Test with no existing previous files
# mock_download_events.reset_mock()
# mock_download_events.return_value = (None, 10, 10)
# self.release.download_batch(batch_number, url)
# mock_download_events.assert_called_once_with(url, headers, events_path, cursor_path)
#
# # Test with events path and no cursor path, so previous successful attempt
# mock_download_events.reset_mock()
# with open(events_path, "w") as f:
# f.write("events")
# self.release.download_batch(batch_number, url)
# mock_download_events.assert_not_called()
# os.remove(events_path)
#
# @patch("observatory.platform.utils.workflow_utils.Variable.get")
# def test_transform_batch(self, mock_variable_get):
# """Test the transform_batch method of the release
# :return: None.
# """
#
# with CliRunner().isolated_filesystem() as t:
# mock_variable_get.return_value = os.path.join(t, "data")
#
# # Use release info so that we can download the right data
# release = CrossrefEventsRelease(
# "crossref_events",
# pendulum.datetime(2018, 5, 14),
# pendulum.datetime(2018, 5, 19),
# True,
# metadata("academic_observatory_workflows").get("Author-email"),
# max_threads=1,
# max_processes=1,
# )
#
# # Download files
# with vcr.use_cassette(self.first_cassette):
# with self.retry_get_url_patch:
# release.download()
#
# # Transform batch
# for file_path in release.download_files:
# transform_batch(file_path, release.transform_folder)
#
# # Assert all transformed
# self.assertEqual(len(release.download_files), len(release.transform_files))
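As a sanity check on the numbers in test_crossref_events_limiter above: 1000 calls throttled to 12 per second should take roughly 1000 / 12 ≈ 83 s, and actual_n_per_second equals num_calls / duration, so assertAlmostEqual with delta=2.5 accepts any achieved rate between about 9.5 and 14.5 calls per second.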
2 changes: 1 addition & 1 deletion requirements.txt
@@ -7,4 +7,4 @@ beautifulsoup4>=4.9.3,<5
boto3>=1.15.0,<2
nltk==3.*
Deprecated>1,<2
responses>=0.23.1,<1
limits>3,<4
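
For reference, a minimal standalone sketch of the limits API that the new requirement pulls in, using the same in-memory storage, fixed-window-with-elastic-expiry strategy, identifier, and 12-per-second item as the workflow code above; the test-then-hit loop mirrors what crossref_events_limiter does for every API call.

import time

from limits import RateLimitItemPerSecond
from limits.storage import storage_from_string
from limits.strategies import FixedWindowElasticExpiryRateLimiter

storage = storage_from_string("memory://")              # process-local counters
limiter = FixedWindowElasticExpiryRateLimiter(storage)
item = RateLimitItemPerSecond(12)                       # 12 hits per 1-second window

# test() checks whether a hit would be allowed without consuming a slot;
# hit() consumes one slot in the current window.
while not limiter.test(item, "crossref_events_limiter"):
    time.sleep(0.01)
limiter.hit(item, "crossref_events_limiter")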
