Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor base collector unit tests. #8663

Merged
merged 1 commit into from
May 10, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 35 additions & 56 deletions components/collector/tests/base_collectors/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ async def _fetch_measurements(self, mock_async_get_request, number=1, side_effec
await self.collector.collect_metrics(session)
await asyncio.gather(*self.collector.running_tasks) # Wait for the running tasks to finish

def _source(self, **kwargs: str) -> dict[str, str | None | list[dict[str, str]]]:
"""Create a source."""
def expected_source(self, **kwargs: str | None) -> dict[str, str | None | list[dict[str, str]]]:
"""Create an expected source."""
connection_error = kwargs.get("connection_error")
entities = kwargs.get(
"entities",
Expand All @@ -91,16 +91,27 @@ def _source(self, **kwargs: str) -> dict[str, str | None | list[dict[str, str]]]
"source_uuid": SOURCE_ID,
}

def expected_measurement(
self,
*,
metric_uuid: str = "metric_uuid",
**expected_source_kwargs,
) -> dict[str, bool | list | str]:
"""Create an expected inserted measurement."""
return {
"has_error": "connection_error" in expected_source_kwargs,
"sources": [self.expected_source(**expected_source_kwargs)],
"metric_uuid": metric_uuid,
"report_uuid": "report1",
}

async def test_fetch_successful(self):
"""Test fetching a test metric."""
mock_async_get_request = AsyncMock()
mock_async_get_request.json.side_effect = [self.pip_json]
with patch(self.create_measurement) as post:
await self._fetch_measurements(mock_async_get_request)
post.assert_called_once_with(
self.database,
{"has_error": False, "sources": [self._source()], "metric_uuid": "metric_uuid", "report_uuid": "report1"},
)
post.assert_called_once_with(self.database, self.expected_measurement())

async def test_fetch_without_sources(self):
"""Test fetching measurement for a metric without sources."""
Expand Down Expand Up @@ -142,15 +153,7 @@ async def test_fetch_with_client_error(self):
"""Test fetching measurement when getting measurements fails."""
with patch(self.create_measurement) as post:
await self._fetch_measurements(None, side_effect=[aiohttp.ClientConnectionError("error")])
post.assert_called_once_with(
self.database,
{
"has_error": True,
"sources": [self._source(connection_error="error")],
"metric_uuid": "metric_uuid",
"report_uuid": "report1",
},
)
post.assert_called_once_with(self.database, self.expected_measurement(connection_error="error"))

async def test_fetch_with_empty_client_error(self):
"""Test fetching measurement when getting measurements fails with an 'empty' exception.
Expand All @@ -159,26 +162,15 @@ async def test_fetch_with_empty_client_error(self):
"""
with patch(self.create_measurement) as post:
await self._fetch_measurements(None, side_effect=[aiohttp.ClientPayloadError()])
post.assert_called_once_with(
self.database,
{
"has_error": True,
"sources": [self._source(connection_error="ClientPayloadError")],
"metric_uuid": "metric_uuid",
"report_uuid": "report1",
},
)
post.assert_called_once_with(self.database, self.expected_measurement(connection_error="ClientPayloadError"))

async def test_fetch_with_post_error(self):
"""Test fetching measurement when posting fails."""
mock_async_get_request = AsyncMock()
mock_async_get_request.json.side_effect = [self.pip_json]
with patch(self.create_measurement) as post:
await self._fetch_measurements(mock_async_get_request)
post.assert_called_once_with(
self.database,
{"has_error": False, "sources": [self._source()], "metric_uuid": "metric_uuid", "report_uuid": "report1"},
)
post.assert_called_once_with(self.database, self.expected_measurement())

@patch("asyncio.sleep", AsyncMock(side_effect=[RuntimeError]))
async def test_collect(self):
Expand All @@ -202,10 +194,7 @@ async def test_fetch_twice(self):
mock_async_get_request.json.side_effect = [self.pip_json, self.pip_json]
with patch(self.create_measurement) as post:
await self._fetch_measurements(mock_async_get_request, number=2)
post.assert_called_once_with(
self.database,
{"has_error": False, "sources": [self._source()], "metric_uuid": "metric_uuid", "report_uuid": "report1"},
)
post.assert_called_once_with(self.database, self.expected_measurement())

@patch.object(config, "MEASUREMENT_LIMIT", 1)
async def test_fetch_in_batches(self):
Expand All @@ -216,14 +205,8 @@ async def test_fetch_in_batches(self):
with patch(self.create_measurement) as post:
self.client["quality_time_db"]["reports"].insert_one(create_report(metric_id=METRIC_ID2))
await self._fetch_measurements(mock_async_get_request, number=2)
expected_call1 = call(
self.database,
{"has_error": False, "sources": [self._source()], "metric_uuid": "metric_uuid", "report_uuid": "report1"},
)
expected_call2 = call(
self.database,
{"has_error": False, "sources": [self._source()], "metric_uuid": "metric_uuid2", "report_uuid": "report1"},
)
expected_call1 = call(self.database, self.expected_measurement())
expected_call2 = call(self.database, self.expected_measurement(metric_uuid="metric_uuid2"))
post.assert_has_calls(calls=[expected_call1, expected_call2])

@patch.object(config, "MEASUREMENT_LIMIT", 1)
Expand All @@ -241,18 +224,14 @@ async def test_prioritize_edited_metrics(self):
self.client["quality_time_db"]["reports"].insert_one(report2)
await self._fetch_measurements(mock_async_get_request, number=2)

expected_call1 = call(
self.database,
{"has_error": False, "sources": [self._source()], "metric_uuid": "metric_uuid", "report_uuid": "report1"},
)
expected_call1 = call(self.database, self.expected_measurement())
expected_call2 = call(
self.database,
{
"has_error": False,
"sources": [self._source(api_url=edited_url, landing_url=edited_url)],
"metric_uuid": "metric_uuid2",
"report_uuid": "report1",
},
self.expected_measurement(
metric_uuid="metric_uuid2",
api_url=edited_url,
landing_url=edited_url,
),
)
post.assert_has_calls(calls=[expected_call1, expected_call2])

Expand Down Expand Up @@ -281,12 +260,12 @@ async def test_missing_mandatory_parameter_with_default_value(self):
await self._fetch_measurements(mock_async_get_request)
post.assert_called_once_with(
self.database,
{
"has_error": False,
"sources": [self._source(value="0", entities=[], api_url="", landing_url="")],
"metric_uuid": "metric_uuid",
"report_uuid": "report1",
},
self.expected_measurement(
value="0",
entities=[],
api_url="",
landing_url="",
),
)

@patch("pathlib.Path.open", new_callable=mock_open)
Expand Down