From a65b1b68ac08e1f52c967865a3537764f4e616bc Mon Sep 17 00:00:00 2001 From: Frank Niessink Date: Fri, 10 May 2024 17:36:54 +0200 Subject: [PATCH] Refactor base collector unit tests. --- .../tests/base_collectors/test_collector.py | 91 +++++++------------ 1 file changed, 35 insertions(+), 56 deletions(-) diff --git a/components/collector/tests/base_collectors/test_collector.py b/components/collector/tests/base_collectors/test_collector.py index 1c2943bb1a..34b7761bbd 100644 --- a/components/collector/tests/base_collectors/test_collector.py +++ b/components/collector/tests/base_collectors/test_collector.py @@ -65,8 +65,8 @@ async def _fetch_measurements(self, mock_async_get_request, number=1, side_effec await self.collector.collect_metrics(session) await asyncio.gather(*self.collector.running_tasks) # Wait for the running tasks to finish - def _source(self, **kwargs: str) -> dict[str, str | None | list[dict[str, str]]]: - """Create a source.""" + def expected_source(self, **kwargs: str | None) -> dict[str, str | None | list[dict[str, str]]]: + """Create an expected source.""" connection_error = kwargs.get("connection_error") entities = kwargs.get( "entities", @@ -91,16 +91,27 @@ def _source(self, **kwargs: str) -> dict[str, str | None | list[dict[str, str]]] "source_uuid": SOURCE_ID, } + def expected_measurement( + self, + *, + metric_uuid: str = "metric_uuid", + **expected_source_kwargs, + ) -> dict[str, bool | list | str]: + """Create an expected inserted measurement.""" + return { + "has_error": "connection_error" in expected_source_kwargs, + "sources": [self.expected_source(**expected_source_kwargs)], + "metric_uuid": metric_uuid, + "report_uuid": "report1", + } + async def test_fetch_successful(self): """Test fetching a test metric.""" mock_async_get_request = AsyncMock() mock_async_get_request.json.side_effect = [self.pip_json] with patch(self.create_measurement) as post: await self._fetch_measurements(mock_async_get_request) - post.assert_called_once_with( - self.database, - {"has_error": False, "sources": [self._source()], "metric_uuid": "metric_uuid", "report_uuid": "report1"}, - ) + post.assert_called_once_with(self.database, self.expected_measurement()) async def test_fetch_without_sources(self): """Test fetching measurement for a metric without sources.""" @@ -142,15 +153,7 @@ async def test_fetch_with_client_error(self): """Test fetching measurement when getting measurements fails.""" with patch(self.create_measurement) as post: await self._fetch_measurements(None, side_effect=[aiohttp.ClientConnectionError("error")]) - post.assert_called_once_with( - self.database, - { - "has_error": True, - "sources": [self._source(connection_error="error")], - "metric_uuid": "metric_uuid", - "report_uuid": "report1", - }, - ) + post.assert_called_once_with(self.database, self.expected_measurement(connection_error="error")) async def test_fetch_with_empty_client_error(self): """Test fetching measurement when getting measurements fails with an 'empty' exception. @@ -159,15 +162,7 @@ async def test_fetch_with_empty_client_error(self): """ with patch(self.create_measurement) as post: await self._fetch_measurements(None, side_effect=[aiohttp.ClientPayloadError()]) - post.assert_called_once_with( - self.database, - { - "has_error": True, - "sources": [self._source(connection_error="ClientPayloadError")], - "metric_uuid": "metric_uuid", - "report_uuid": "report1", - }, - ) + post.assert_called_once_with(self.database, self.expected_measurement(connection_error="ClientPayloadError")) async def test_fetch_with_post_error(self): """Test fetching measurement when posting fails.""" @@ -175,10 +170,7 @@ async def test_fetch_with_post_error(self): mock_async_get_request.json.side_effect = [self.pip_json] with patch(self.create_measurement) as post: await self._fetch_measurements(mock_async_get_request) - post.assert_called_once_with( - self.database, - {"has_error": False, "sources": [self._source()], "metric_uuid": "metric_uuid", "report_uuid": "report1"}, - ) + post.assert_called_once_with(self.database, self.expected_measurement()) @patch("asyncio.sleep", AsyncMock(side_effect=[RuntimeError])) async def test_collect(self): @@ -202,10 +194,7 @@ async def test_fetch_twice(self): mock_async_get_request.json.side_effect = [self.pip_json, self.pip_json] with patch(self.create_measurement) as post: await self._fetch_measurements(mock_async_get_request, number=2) - post.assert_called_once_with( - self.database, - {"has_error": False, "sources": [self._source()], "metric_uuid": "metric_uuid", "report_uuid": "report1"}, - ) + post.assert_called_once_with(self.database, self.expected_measurement()) @patch.object(config, "MEASUREMENT_LIMIT", 1) async def test_fetch_in_batches(self): @@ -216,14 +205,8 @@ async def test_fetch_in_batches(self): with patch(self.create_measurement) as post: self.client["quality_time_db"]["reports"].insert_one(create_report(metric_id=METRIC_ID2)) await self._fetch_measurements(mock_async_get_request, number=2) - expected_call1 = call( - self.database, - {"has_error": False, "sources": [self._source()], "metric_uuid": "metric_uuid", "report_uuid": "report1"}, - ) - expected_call2 = call( - self.database, - {"has_error": False, "sources": [self._source()], "metric_uuid": "metric_uuid2", "report_uuid": "report1"}, - ) + expected_call1 = call(self.database, self.expected_measurement()) + expected_call2 = call(self.database, self.expected_measurement(metric_uuid="metric_uuid2")) post.assert_has_calls(calls=[expected_call1, expected_call2]) @patch.object(config, "MEASUREMENT_LIMIT", 1) @@ -241,18 +224,14 @@ async def test_prioritize_edited_metrics(self): self.client["quality_time_db"]["reports"].insert_one(report2) await self._fetch_measurements(mock_async_get_request, number=2) - expected_call1 = call( - self.database, - {"has_error": False, "sources": [self._source()], "metric_uuid": "metric_uuid", "report_uuid": "report1"}, - ) + expected_call1 = call(self.database, self.expected_measurement()) expected_call2 = call( self.database, - { - "has_error": False, - "sources": [self._source(api_url=edited_url, landing_url=edited_url)], - "metric_uuid": "metric_uuid2", - "report_uuid": "report1", - }, + self.expected_measurement( + metric_uuid="metric_uuid2", + api_url=edited_url, + landing_url=edited_url, + ), ) post.assert_has_calls(calls=[expected_call1, expected_call2]) @@ -281,12 +260,12 @@ async def test_missing_mandatory_parameter_with_default_value(self): await self._fetch_measurements(mock_async_get_request) post.assert_called_once_with( self.database, - { - "has_error": False, - "sources": [self._source(value="0", entities=[], api_url="", landing_url="")], - "metric_uuid": "metric_uuid", - "report_uuid": "report1", - }, + self.expected_measurement( + value="0", + entities=[], + api_url="", + landing_url="", + ), ) @patch("pathlib.Path.open", new_callable=mock_open)