From d2ba126404ceb7fbbfca64317086c8058c897f68 Mon Sep 17 00:00:00 2001 From: Daniel Bell <70265071+danielbe11@users.noreply.github.com> Date: Fri, 23 Feb 2024 19:17:49 +0100 Subject: [PATCH] Adding optional SSL verification for druid operator (#37629) * Initial changes * Add tests --------- Co-authored-by: Daniel Bell --- airflow/providers/apache/druid/hooks/druid.py | 9 ++++++++- .../providers/apache/druid/operators/druid.py | 6 ++++++ .../apache/druid/hooks/test_druid.py | 19 +++++++++++++++++++ .../apache/druid/operators/test_druid.py | 13 ++++++++++++- 4 files changed, 45 insertions(+), 2 deletions(-) diff --git a/airflow/providers/apache/druid/hooks/druid.py b/airflow/providers/apache/druid/hooks/druid.py index 1c0f809247b49..9ab60a8c082bb 100644 --- a/airflow/providers/apache/druid/hooks/druid.py +++ b/airflow/providers/apache/druid/hooks/druid.py @@ -53,6 +53,9 @@ class DruidHook(BaseHook): the Druid job for the status of the ingestion job. Must be greater than or equal to 1 :param max_ingestion_time: The maximum ingestion time before assuming the job failed + :param verify_ssl: Either a boolean, in which case it controls whether we verify the server's TLS + certificate, or a string, in which case it must be a path to a CA bundle to use. 
+ Defaults to True. """ def __init__( @@ -60,12 +63,14 @@ def __init__( druid_ingest_conn_id: str = "druid_ingest_default", timeout: int = 1, max_ingestion_time: int | None = None, + verify_ssl: bool | str = True, ) -> None: super().__init__() self.druid_ingest_conn_id = druid_ingest_conn_id self.timeout = timeout self.max_ingestion_time = max_ingestion_time self.header = {"content-type": "application/json"} + self.verify_ssl = verify_ssl if self.timeout < 1: raise ValueError("Druid timeout should be equal or greater than 1") @@ -103,7 +108,9 @@ def submit_indexing_job( url = self.get_conn_url(ingestion_type) self.log.info("Druid ingestion spec: %s", json_index_spec) - req_index = requests.post(url, data=json_index_spec, headers=self.header, auth=self.get_auth()) + req_index = requests.post( + url, data=json_index_spec, headers=self.header, auth=self.get_auth(), verify=self.verify_ssl + ) code = req_index.status_code not_accepted = not (200 <= code < 300) diff --git a/airflow/providers/apache/druid/operators/druid.py b/airflow/providers/apache/druid/operators/druid.py index 080287e5ec613..9a5a411121b80 100644 --- a/airflow/providers/apache/druid/operators/druid.py +++ b/airflow/providers/apache/druid/operators/druid.py @@ -37,6 +37,9 @@ class DruidOperator(BaseOperator): of the ingestion job. Must be greater than or equal to 1 :param max_ingestion_time: The maximum ingestion time before assuming the job failed :param ingestion_type: The ingestion type of the job. Could be IngestionType.Batch or IngestionType.MSQ + :param verify_ssl: Either a boolean, in which case it controls whether we verify the server's TLS + certificate, or a string, in which case it must be a path to a CA bundle to use. + Defaults to True. 
""" template_fields: Sequence[str] = ("json_index_file",) @@ -51,6 +54,7 @@ def __init__( timeout: int = 1, max_ingestion_time: int | None = None, ingestion_type: IngestionType = IngestionType.BATCH, + verify_ssl: bool | str = True, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -59,12 +63,14 @@ def __init__( self.timeout = timeout self.max_ingestion_time = max_ingestion_time self.ingestion_type = ingestion_type + self.verify_ssl = verify_ssl def execute(self, context: Context) -> None: hook = DruidHook( druid_ingest_conn_id=self.conn_id, timeout=self.timeout, max_ingestion_time=self.max_ingestion_time, + verify_ssl=self.verify_ssl, ) self.log.info("Submitting %s", self.json_index_file) hook.submit_indexing_job(self.json_index_file, self.ingestion_type) diff --git a/tests/providers/apache/druid/hooks/test_druid.py b/tests/providers/apache/druid/hooks/test_druid.py index 5a389cb710c65..76f332bb97621 100644 --- a/tests/providers/apache/druid/hooks/test_druid.py +++ b/tests/providers/apache/druid/hooks/test_druid.py @@ -96,6 +96,25 @@ def test_submit_sql_based_ingestion_ok(self, requests_mock): assert task_post.called_once assert status_check.called_once + def test_submit_with_correct_ssl_arg(self, requests_mock): + self.db_hook.verify_ssl = "/path/to/ca.crt" + task_post = requests_mock.post( + "http://druid-overlord:8081/druid/indexer/v1/task", + text='{"task":"9f8a7359-77d4-4612-b0cd-cc2f6a3c28de"}', + ) + status_check = requests_mock.get( + "http://druid-overlord:8081/druid/indexer/v1/task/9f8a7359-77d4-4612-b0cd-cc2f6a3c28de/status", + text='{"status":{"status": "SUCCESS"}}', + ) + + self.db_hook.submit_indexing_job("Long json file") + + assert task_post.called_once + assert status_check.called_once + if task_post.called_once: + verify_ssl = task_post.request_history[0].verify + assert "/path/to/ca.crt" == verify_ssl + def test_submit_correct_json_body(self, requests_mock): task_post = requests_mock.post( 
"http://druid-overlord:8081/druid/indexer/v1/task", diff --git a/tests/providers/apache/druid/operators/test_druid.py b/tests/providers/apache/druid/operators/test_druid.py index f6fba6bffb592..28f9632cd17d9 100644 --- a/tests/providers/apache/druid/operators/test_druid.py +++ b/tests/providers/apache/druid/operators/test_druid.py @@ -102,14 +102,22 @@ def test_init_with_timeout_and_max_ingestion_time(): assert expected_values["max_ingestion_time"] == operator.max_ingestion_time -def test_init_default_timeout(): +def test_init_defaults(): operator = DruidOperator( task_id="spark_submit_job", json_index_file=JSON_INDEX_STR, params={"index_type": "index_hadoop", "datasource": "datasource_prd"}, ) + expected_default_druid_ingest_conn_id = "druid_ingest_default" expected_default_timeout = 1 + expected_default_max_ingestion_time = None + expected_default_ingestion_type = IngestionType.BATCH + expected_default_verify_ssl = True + assert expected_default_druid_ingest_conn_id == operator.conn_id assert expected_default_timeout == operator.timeout + assert expected_default_max_ingestion_time == operator.max_ingestion_time + assert expected_default_ingestion_type == operator.ingestion_type + assert expected_default_verify_ssl == operator.verify_ssl @patch("airflow.providers.apache.druid.operators.druid.DruidHook") @@ -120,6 +128,7 @@ def test_execute_calls_druid_hook_with_the_right_parameters(mock_druid_hook): druid_ingest_conn_id = "druid_ingest_default" max_ingestion_time = 5 timeout = 5 + verify_ssl = "/path/to/ca.crt" operator = DruidOperator( task_id="spark_submit_job", json_index_file=json_index_file, @@ -127,11 +136,13 @@ def test_execute_calls_druid_hook_with_the_right_parameters(mock_druid_hook): timeout=timeout, ingestion_type=IngestionType.MSQ, max_ingestion_time=max_ingestion_time, + verify_ssl=verify_ssl, ) operator.execute(context={}) mock_druid_hook.assert_called_once_with( druid_ingest_conn_id=druid_ingest_conn_id, timeout=timeout, 
max_ingestion_time=max_ingestion_time, + verify_ssl=verify_ssl, ) mock_druid_hook_instance.submit_indexing_job.assert_called_once_with(json_index_file, IngestionType.MSQ)