Skip to content

Commit

Permalink
Adding optional SSL verification for druid operator (#37629)
Browse files Browse the repository at this point in the history
* Initial changes

* Add tests

---------

Co-authored-by: Daniel Bell <daniel.bell@skyscanner.net>
  • Loading branch information
danielbe11 and Daniel Bell committed Feb 23, 2024
1 parent 8c05e59 commit d2ba126
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 2 deletions.
9 changes: 8 additions & 1 deletion airflow/providers/apache/druid/hooks/druid.py
Expand Up @@ -53,19 +53,24 @@ class DruidHook(BaseHook):
the Druid job for the status of the ingestion job.
Must be greater than or equal to 1
:param max_ingestion_time: The maximum ingestion time before assuming the job failed
:param verify_ssl: Either a boolean, in which case it controls whether we verify the server's TLS
certificate, or a string, in which case it must be a path to a CA bundle to use.
Defaults to True.
"""

def __init__(
self,
druid_ingest_conn_id: str = "druid_ingest_default",
timeout: int = 1,
max_ingestion_time: int | None = None,
verify_ssl: bool | str = True,
) -> None:
super().__init__()
self.druid_ingest_conn_id = druid_ingest_conn_id
self.timeout = timeout
self.max_ingestion_time = max_ingestion_time
self.header = {"content-type": "application/json"}
self.verify_ssl = verify_ssl

if self.timeout < 1:
raise ValueError("Druid timeout should be equal or greater than 1")
Expand Down Expand Up @@ -103,7 +108,9 @@ def submit_indexing_job(
url = self.get_conn_url(ingestion_type)

self.log.info("Druid ingestion spec: %s", json_index_spec)
req_index = requests.post(url, data=json_index_spec, headers=self.header, auth=self.get_auth())
req_index = requests.post(
url, data=json_index_spec, headers=self.header, auth=self.get_auth(), verify=self.verify_ssl
)

code = req_index.status_code
not_accepted = not (200 <= code < 300)
Expand Down
6 changes: 6 additions & 0 deletions airflow/providers/apache/druid/operators/druid.py
Expand Up @@ -37,6 +37,9 @@ class DruidOperator(BaseOperator):
of the ingestion job. Must be greater than or equal to 1
:param max_ingestion_time: The maximum ingestion time before assuming the job failed
:param ingestion_type: The ingestion type of the job. Could be IngestionType.Batch or IngestionType.MSQ
:param verify_ssl: Either a boolean, in which case it controls whether we verify the server's TLS
certificate, or a string, in which case it must be a path to a CA bundle to use.
Defaults to True.
"""

template_fields: Sequence[str] = ("json_index_file",)
Expand All @@ -51,6 +54,7 @@ def __init__(
timeout: int = 1,
max_ingestion_time: int | None = None,
ingestion_type: IngestionType = IngestionType.BATCH,
verify_ssl: bool | str = True,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
Expand All @@ -59,12 +63,14 @@ def __init__(
self.timeout = timeout
self.max_ingestion_time = max_ingestion_time
self.ingestion_type = ingestion_type
self.verify_ssl = verify_ssl

def execute(self, context: Context) -> None:
    """Submit this operator's ingestion spec to Druid.

    A ``DruidHook`` is created from the operator's connection id, timeout,
    maximum ingestion time and SSL verification setting; the spec held in
    ``json_index_file`` is then handed to it together with the ingestion type.

    :param context: Task execution context (not used by this method).
    """
    druid_hook = DruidHook(
        druid_ingest_conn_id=self.conn_id,
        timeout=self.timeout,
        max_ingestion_time=self.max_ingestion_time,
        verify_ssl=self.verify_ssl,
    )
    spec = self.json_index_file
    self.log.info("Submitting %s", spec)
    druid_hook.submit_indexing_job(spec, self.ingestion_type)
19 changes: 19 additions & 0 deletions tests/providers/apache/druid/hooks/test_druid.py
Expand Up @@ -96,6 +96,25 @@ def test_submit_sql_based_ingestion_ok(self, requests_mock):
assert task_post.called_once
assert status_check.called_once

def test_submit_with_correct_ssl_arg(self, requests_mock):
    """Check that the hook's ``verify_ssl`` value reaches the task-submission request.

    The hook is configured with a CA bundle path, an indexing job is submitted
    against mocked overlord endpoints, and the ``verify`` value recorded on the
    POST request is compared with that path.
    """
    ca_bundle_path = "/path/to/ca.crt"
    self.db_hook.verify_ssl = ca_bundle_path
    task_post = requests_mock.post(
        "http://druid-overlord:8081/druid/indexer/v1/task",
        text='{"task":"9f8a7359-77d4-4612-b0cd-cc2f6a3c28de"}',
    )
    status_check = requests_mock.get(
        "http://druid-overlord:8081/druid/indexer/v1/task/9f8a7359-77d4-4612-b0cd-cc2f6a3c28de/status",
        text='{"status":{"status": "SUCCESS"}}',
    )

    self.db_hook.submit_indexing_job("Long json file")

    assert task_post.called_once
    assert status_check.called_once
    # Asserted unconditionally: the POST was already verified to have happened
    # exactly once, so guarding this check with an ``if`` would only add a
    # branch that can never be skipped.
    assert task_post.request_history[0].verify == ca_bundle_path

def test_submit_correct_json_body(self, requests_mock):
task_post = requests_mock.post(
"http://druid-overlord:8081/druid/indexer/v1/task",
Expand Down
13 changes: 12 additions & 1 deletion tests/providers/apache/druid/operators/test_druid.py
Expand Up @@ -102,14 +102,22 @@ def test_init_with_timeout_and_max_ingestion_time():
assert expected_values["max_ingestion_time"] == operator.max_ingestion_time


def test_init_default_timeout():
def test_init_defaults():
    """A ``DruidOperator`` created with only the required arguments exposes the expected defaults."""
    operator = DruidOperator(
        task_id="spark_submit_job",
        json_index_file=JSON_INDEX_STR,
        params={"index_type": "index_hadoop", "datasource": "datasource_prd"},
    )

    # Attribute name on the operator -> value it should hold when not overridden.
    expected_defaults = {
        "conn_id": "druid_ingest_default",
        "timeout": 1,
        "max_ingestion_time": None,
        "ingestion_type": IngestionType.BATCH,
        "verify_ssl": True,
    }
    for attribute, expected in expected_defaults.items():
        assert expected == getattr(operator, attribute)


@patch("airflow.providers.apache.druid.operators.druid.DruidHook")
Expand All @@ -120,18 +128,21 @@ def test_execute_calls_druid_hook_with_the_right_parameters(mock_druid_hook):
druid_ingest_conn_id = "druid_ingest_default"
max_ingestion_time = 5
timeout = 5
verify_ssl = "/path/to/ca.crt"
operator = DruidOperator(
task_id="spark_submit_job",
json_index_file=json_index_file,
druid_ingest_conn_id=druid_ingest_conn_id,
timeout=timeout,
ingestion_type=IngestionType.MSQ,
max_ingestion_time=max_ingestion_time,
verify_ssl=verify_ssl,
)
operator.execute(context={})
mock_druid_hook.assert_called_once_with(
druid_ingest_conn_id=druid_ingest_conn_id,
timeout=timeout,
max_ingestion_time=max_ingestion_time,
verify_ssl=verify_ssl,
)
mock_druid_hook_instance.submit_indexing_job.assert_called_once_with(json_index_file, IngestionType.MSQ)

0 comments on commit d2ba126

Please sign in to comment.