Skip to content

Commit

Permalink
Refactor base collector unit tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
fniessink committed May 10, 2024
1 parent 758d584 commit a65b1b6
Showing 1 changed file with 35 additions and 56 deletions.
91 changes: 35 additions & 56 deletions components/collector/tests/base_collectors/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ async def _fetch_measurements(self, mock_async_get_request, number=1, side_effec
await self.collector.collect_metrics(session)
await asyncio.gather(*self.collector.running_tasks) # Wait for the running tasks to finish

def _source(self, **kwargs: str) -> dict[str, str | None | list[dict[str, str]]]:
"""Create a source."""
def expected_source(self, **kwargs: str | None) -> dict[str, str | None | list[dict[str, str]]]:
"""Create an expected source."""
connection_error = kwargs.get("connection_error")
entities = kwargs.get(
"entities",
Expand All @@ -91,16 +91,27 @@ def _source(self, **kwargs: str) -> dict[str, str | None | list[dict[str, str]]]
"source_uuid": SOURCE_ID,
}

def expected_measurement(
self,
*,
metric_uuid: str = "metric_uuid",
**expected_source_kwargs,
) -> dict[str, bool | list | str]:
"""Create an expected inserted measurement."""
return {
"has_error": "connection_error" in expected_source_kwargs,
"sources": [self.expected_source(**expected_source_kwargs)],
"metric_uuid": metric_uuid,
"report_uuid": "report1",
}

async def test_fetch_successful(self):
"""Test fetching a test metric."""
mock_async_get_request = AsyncMock()
mock_async_get_request.json.side_effect = [self.pip_json]
with patch(self.create_measurement) as post:
await self._fetch_measurements(mock_async_get_request)
post.assert_called_once_with(
self.database,
{"has_error": False, "sources": [self._source()], "metric_uuid": "metric_uuid", "report_uuid": "report1"},
)
post.assert_called_once_with(self.database, self.expected_measurement())

async def test_fetch_without_sources(self):
"""Test fetching measurement for a metric without sources."""
Expand Down Expand Up @@ -142,15 +153,7 @@ async def test_fetch_with_client_error(self):
"""Test fetching measurement when getting measurements fails."""
with patch(self.create_measurement) as post:
await self._fetch_measurements(None, side_effect=[aiohttp.ClientConnectionError("error")])
post.assert_called_once_with(
self.database,
{
"has_error": True,
"sources": [self._source(connection_error="error")],
"metric_uuid": "metric_uuid",
"report_uuid": "report1",
},
)
post.assert_called_once_with(self.database, self.expected_measurement(connection_error="error"))

async def test_fetch_with_empty_client_error(self):
"""Test fetching measurement when getting measurements fails with an 'empty' exception.
Expand All @@ -159,26 +162,15 @@ async def test_fetch_with_empty_client_error(self):
"""
with patch(self.create_measurement) as post:
await self._fetch_measurements(None, side_effect=[aiohttp.ClientPayloadError()])
post.assert_called_once_with(
self.database,
{
"has_error": True,
"sources": [self._source(connection_error="ClientPayloadError")],
"metric_uuid": "metric_uuid",
"report_uuid": "report1",
},
)
post.assert_called_once_with(self.database, self.expected_measurement(connection_error="ClientPayloadError"))

async def test_fetch_with_post_error(self):
"""Test fetching measurement when posting fails."""
mock_async_get_request = AsyncMock()
mock_async_get_request.json.side_effect = [self.pip_json]
with patch(self.create_measurement) as post:
await self._fetch_measurements(mock_async_get_request)
post.assert_called_once_with(
self.database,
{"has_error": False, "sources": [self._source()], "metric_uuid": "metric_uuid", "report_uuid": "report1"},
)
post.assert_called_once_with(self.database, self.expected_measurement())

@patch("asyncio.sleep", AsyncMock(side_effect=[RuntimeError]))
async def test_collect(self):
Expand All @@ -202,10 +194,7 @@ async def test_fetch_twice(self):
mock_async_get_request.json.side_effect = [self.pip_json, self.pip_json]
with patch(self.create_measurement) as post:
await self._fetch_measurements(mock_async_get_request, number=2)
post.assert_called_once_with(
self.database,
{"has_error": False, "sources": [self._source()], "metric_uuid": "metric_uuid", "report_uuid": "report1"},
)
post.assert_called_once_with(self.database, self.expected_measurement())

@patch.object(config, "MEASUREMENT_LIMIT", 1)
async def test_fetch_in_batches(self):
Expand All @@ -216,14 +205,8 @@ async def test_fetch_in_batches(self):
with patch(self.create_measurement) as post:
self.client["quality_time_db"]["reports"].insert_one(create_report(metric_id=METRIC_ID2))
await self._fetch_measurements(mock_async_get_request, number=2)
expected_call1 = call(
self.database,
{"has_error": False, "sources": [self._source()], "metric_uuid": "metric_uuid", "report_uuid": "report1"},
)
expected_call2 = call(
self.database,
{"has_error": False, "sources": [self._source()], "metric_uuid": "metric_uuid2", "report_uuid": "report1"},
)
expected_call1 = call(self.database, self.expected_measurement())
expected_call2 = call(self.database, self.expected_measurement(metric_uuid="metric_uuid2"))
post.assert_has_calls(calls=[expected_call1, expected_call2])

@patch.object(config, "MEASUREMENT_LIMIT", 1)
Expand All @@ -241,18 +224,14 @@ async def test_prioritize_edited_metrics(self):
self.client["quality_time_db"]["reports"].insert_one(report2)
await self._fetch_measurements(mock_async_get_request, number=2)

expected_call1 = call(
self.database,
{"has_error": False, "sources": [self._source()], "metric_uuid": "metric_uuid", "report_uuid": "report1"},
)
expected_call1 = call(self.database, self.expected_measurement())
expected_call2 = call(
self.database,
{
"has_error": False,
"sources": [self._source(api_url=edited_url, landing_url=edited_url)],
"metric_uuid": "metric_uuid2",
"report_uuid": "report1",
},
self.expected_measurement(
metric_uuid="metric_uuid2",
api_url=edited_url,
landing_url=edited_url,
),
)
post.assert_has_calls(calls=[expected_call1, expected_call2])

Expand Down Expand Up @@ -281,12 +260,12 @@ async def test_missing_mandatory_parameter_with_default_value(self):
await self._fetch_measurements(mock_async_get_request)
post.assert_called_once_with(
self.database,
{
"has_error": False,
"sources": [self._source(value="0", entities=[], api_url="", landing_url="")],
"metric_uuid": "metric_uuid",
"report_uuid": "report1",
},
self.expected_measurement(
value="0",
entities=[],
api_url="",
landing_url="",
),
)

@patch("pathlib.Path.open", new_callable=mock_open)
Expand Down

0 comments on commit a65b1b6

Please sign in to comment.