diff --git a/airflow/providers/apache/druid/hooks/druid.py b/airflow/providers/apache/druid/hooks/druid.py
index 5b5c814fb5885..7708684e60cf5 100644
--- a/airflow/providers/apache/druid/hooks/druid.py
+++ b/airflow/providers/apache/druid/hooks/druid.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 import time
+from enum import Enum
 from typing import Any, Iterable
 
 import requests
@@ -28,6 +29,17 @@ from airflow.providers.common.sql.hooks.sql import DbApiHook
 
 
+class IngestionType(Enum):
+    """
+    Druid ingestion type: native batch ingestion or SQL-based (multi-stage query) ingestion.
+
+    https://druid.apache.org/docs/latest/ingestion/index.html
+    """
+
+    BATCH = 1
+    MSQ = 2
+
+
 class DruidHook(BaseHook):
     """
     Connection to Druid overlord for ingestion.
 
@@ -59,13 +71,16 @@ def __init__(
         if self.timeout < 1:
             raise ValueError("Druid timeout should be equal or greater than 1")
 
-    def get_conn_url(self) -> str:
+    def get_conn_url(self, ingestion_type: IngestionType = IngestionType.BATCH) -> str:
         """Get Druid connection url."""
         conn = self.get_connection(self.druid_ingest_conn_id)
         host = conn.host
         port = conn.port
         conn_type = conn.conn_type or "http"
-        endpoint = conn.extra_dejson.get("endpoint", "")
+        if ingestion_type == IngestionType.BATCH:
+            endpoint = conn.extra_dejson.get("endpoint", "")
+        else:
+            endpoint = conn.extra_dejson.get("msq_endpoint", "")
         return f"{conn_type}://{host}:{port}/{endpoint}"
 
     def get_auth(self) -> requests.auth.HTTPBasicAuth | None:
@@ -82,9 +97,11 @@ def get_auth(self) -> requests.auth.HTTPBasicAuth | None:
         else:
             return None
 
-    def submit_indexing_job(self, json_index_spec: dict[str, Any] | str) -> None:
+    def submit_indexing_job(
+        self, json_index_spec: dict[str, Any] | str, ingestion_type: IngestionType = IngestionType.BATCH
+    ) -> None:
         """Submit Druid ingestion job."""
-        url = self.get_conn_url()
+        url = self.get_conn_url(ingestion_type)
 
         self.log.info("Druid ingestion spec: %s", json_index_spec)
         req_index = requests.post(url, data=json_index_spec, headers=self.header, auth=self.get_auth())
@@ -96,14 +113,18 @@ def submit_indexing_job(self, json_index_spec: dict[str, Any] | str) -> None:
 
         req_json = req_index.json()
         # Wait until the job is completed
-        druid_task_id = req_json["task"]
+        if ingestion_type == IngestionType.BATCH:
+            druid_task_id = req_json["task"]
+        else:
+            druid_task_id = req_json["taskId"]
+        druid_task_status_url = f"{self.get_conn_url()}/{druid_task_id}/status"
         self.log.info("Druid indexing task-id: %s", druid_task_id)
 
         running = True
 
         sec = 0
         while running:
-            req_status = requests.get(f"{url}/{druid_task_id}/status", auth=self.get_auth())
+            req_status = requests.get(druid_task_status_url, auth=self.get_auth())
 
             self.log.info("Job still running for %s seconds...", sec)
 
diff --git a/airflow/providers/apache/druid/operators/druid.py b/airflow/providers/apache/druid/operators/druid.py
index 7e1cd607999a6..080287e5ec613 100644
--- a/airflow/providers/apache/druid/operators/druid.py
+++ b/airflow/providers/apache/druid/operators/druid.py
@@ -20,7 +20,7 @@
 from typing import TYPE_CHECKING, Any, Sequence
 
 from airflow.models import BaseOperator
-from airflow.providers.apache.druid.hooks.druid import DruidHook
+from airflow.providers.apache.druid.hooks.druid import DruidHook, IngestionType
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -36,6 +36,7 @@ class DruidOperator(BaseOperator):
     :param timeout: The interval (in seconds) between polling the Druid job for the status
         of the ingestion job. Must be greater than or equal to 1
     :param max_ingestion_time: The maximum ingestion time before assuming the job failed
+    :param ingestion_type: The ingestion type of the job, either IngestionType.BATCH or IngestionType.MSQ
     """
 
     template_fields: Sequence[str] = ("json_index_file",)
@@ -49,6 +50,7 @@ def __init__(
         druid_ingest_conn_id: str = "druid_ingest_default",
         timeout: int = 1,
         max_ingestion_time: int | None = None,
+        ingestion_type: IngestionType = IngestionType.BATCH,
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -56,6 +58,7 @@ def __init__(
         self.conn_id = druid_ingest_conn_id
         self.timeout = timeout
         self.max_ingestion_time = max_ingestion_time
+        self.ingestion_type = ingestion_type
 
     def execute(self, context: Context) -> None:
         hook = DruidHook(
@@ -64,4 +67,4 @@ def execute(self, context: Context) -> None:
             max_ingestion_time=self.max_ingestion_time,
         )
         self.log.info("Submitting %s", self.json_index_file)
-        hook.submit_indexing_job(self.json_index_file)
+        hook.submit_indexing_job(self.json_index_file, self.ingestion_type)
diff --git a/docs/apache-airflow-providers-apache-druid/operators.rst b/docs/apache-airflow-providers-apache-druid/operators.rst
index 1d2f1c022ab67..6930e7b4d3ae3 100644
--- a/docs/apache-airflow-providers-apache-druid/operators.rst
+++ b/docs/apache-airflow-providers-apache-druid/operators.rst
@@ -29,6 +29,7 @@ DruidOperator
 -------------------
 
 To submit a task directly to Druid, you need to provide the filepath to the Druid index specification ``json_index_file``,
 and the connection id of the Druid overlord ``druid_ingest_conn_id`` which accepts index jobs in Airflow Connections.
+In addition, you can provide the ingestion type ``ingestion_type`` to determine whether the job is a native batch ingestion or a SQL-based ingestion.
 
 There is also an example of the Druid Ingestion specification below.
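To make the documented ``ingestion_type`` parameter concrete, here is a minimal DAG sketch (not part of this change; the DAG id, spec file path, and connection extras are hypothetical, and the connection's ``msq_endpoint`` extra must be set explicitly, since the hook falls back to an empty string)::

    # A minimal sketch, assuming a Druid connection "druid_ingest_default" whose
    # extras define both endpoints, e.g.
    # {"endpoint": "druid/indexer/v1/task", "msq_endpoint": "druid/v2/sql/task"}.
    from datetime import datetime

    from airflow import DAG
    from airflow.providers.apache.druid.hooks.druid import IngestionType
    from airflow.providers.apache.druid.operators.druid import DruidOperator

    with DAG(dag_id="druid_msq_example", start_date=datetime(2023, 1, 1), schedule=None) as dag:
        DruidOperator(
            task_id="submit_sql_based_ingestion",
            json_index_file="sql_ingestion.json",  # hypothetical MSQ spec file
            druid_ingest_conn_id="druid_ingest_default",
            ingestion_type=IngestionType.MSQ,
        )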
diff --git a/tests/providers/apache/druid/hooks/test_druid.py b/tests/providers/apache/druid/hooks/test_druid.py
index 7d97f857b44d0..38f26f06cc380 100644
--- a/tests/providers/apache/druid/hooks/test_druid.py
+++ b/tests/providers/apache/druid/hooks/test_druid.py
@@ -23,7 +23,7 @@
 import requests
 
 from airflow.exceptions import AirflowException
-from airflow.providers.apache.druid.hooks.druid import DruidDbApiHook, DruidHook
+from airflow.providers.apache.druid.hooks.druid import DruidDbApiHook, DruidHook, IngestionType
 
 
 class TestDruidHook:
@@ -35,7 +35,11 @@ def setup_method(self):
         session.mount("mock", adapter)
 
+        self.is_sql_based_ingestion = False
+
         class TestDRuidhook(DruidHook):
-            def get_conn_url(self):
+            def get_conn_url(self, ingestion_type: IngestionType = IngestionType.BATCH):
+                if ingestion_type == IngestionType.MSQ:
+                    return "http://druid-overlord:8081/druid/v2/sql/task"
                 return "http://druid-overlord:8081/druid/indexer/v1/task"
 
         self.db_hook = TestDRuidhook()
@@ -73,6 +77,22 @@ def test_submit_ok(self, requests_mock):
         assert task_post.called_once
         assert status_check.called_once
 
+    def test_submit_sql_based_ingestion_ok(self, requests_mock):
+        task_post = requests_mock.post(
+            "http://druid-overlord:8081/druid/v2/sql/task",
+            text='{"taskId":"9f8a7359-77d4-4612-b0cd-cc2f6a3c28de"}',
+        )
+        status_check = requests_mock.get(
+            "http://druid-overlord:8081/druid/indexer/v1/task/9f8a7359-77d4-4612-b0cd-cc2f6a3c28de/status",
+            text='{"status":{"status": "SUCCESS"}}',
+        )
+
+        # Exists just as it should
+        self.db_hook.submit_indexing_job("Long json file", IngestionType.MSQ)
+
+        assert task_post.called_once
+        assert status_check.called_once
+
     def test_submit_correct_json_body(self, requests_mock):
         task_post = requests_mock.post(
             "http://druid-overlord:8081/druid/indexer/v1/task",
@@ -149,6 +169,17 @@ def test_get_conn_url(self, mock_get_connection):
         hook = DruidHook(timeout=1, max_ingestion_time=5)
         assert hook.get_conn_url() == "https://test_host:1/ingest"
 
+    @patch("airflow.providers.apache.druid.hooks.druid.DruidHook.get_connection")
+    def test_get_conn_url_with_ingestion_type(self, mock_get_connection):
+        get_conn_value = MagicMock()
+        get_conn_value.host = "test_host"
+        get_conn_value.conn_type = "https"
+        get_conn_value.port = "1"
+        get_conn_value.extra_dejson = {"endpoint": "ingest", "msq_endpoint": "sql_ingest"}
+        mock_get_connection.return_value = get_conn_value
+        hook = DruidHook(timeout=1, max_ingestion_time=5)
+        assert hook.get_conn_url(IngestionType.MSQ) == "https://test_host:1/sql_ingest"
+
     @patch("airflow.providers.apache.druid.hooks.druid.DruidHook.get_connection")
     def test_get_auth(self, mock_get_connection):
         get_conn_value = MagicMock()
diff --git a/tests/providers/apache/druid/operators/test_druid.py b/tests/providers/apache/druid/operators/test_druid.py
index 30d14f4f1a8c1..86f4650f70e44 100644
--- a/tests/providers/apache/druid/operators/test_druid.py
+++ b/tests/providers/apache/druid/operators/test_druid.py
@@ -18,7 +18,9 @@
 from __future__ import annotations
 
 import json
+from unittest.mock import MagicMock, patch
 
+from airflow.providers.apache.druid.hooks.druid import IngestionType
 from airflow.providers.apache.druid.operators.druid import DruidOperator
 from airflow.utils import timezone
 from airflow.utils.types import DagRunType
@@ -104,3 +106,28 @@ def test_init_default_timeout():
     )
     expected_default_timeout = 1
     assert expected_default_timeout == operator.timeout
+
+
+@patch("airflow.providers.apache.druid.operators.druid.DruidHook")
+def test_execute_calls_druid_hook_with_the_right_parameters(mock_druid_hook):
+    mock_druid_hook_instance = MagicMock()
+    mock_druid_hook.return_value = mock_druid_hook_instance
+    json_index_file = "sql.json"
+    druid_ingest_conn_id = "druid_ingest_default"
+    max_ingestion_time = 5
+    timeout = 5
+    operator = DruidOperator(
+        task_id="druid_ingest_job",
+        json_index_file=json_index_file,
+        druid_ingest_conn_id=druid_ingest_conn_id,
+        timeout=timeout,
+        ingestion_type=IngestionType.MSQ,
+        max_ingestion_time=max_ingestion_time,
+    )
+    operator.execute(context={})
+    mock_druid_hook.assert_called_once_with(
+        druid_ingest_conn_id=druid_ingest_conn_id,
+        timeout=timeout,
+        max_ingestion_time=max_ingestion_time,
+    )
+    mock_druid_hook_instance.submit_indexing_job.assert_called_once_with(json_index_file, IngestionType.MSQ)
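As a hook-level usage sketch (again hypothetical, not part of this change): with ``IngestionType.MSQ`` the spec is POSTed to the connection's ``msq_endpoint`` and the task id is read from the ``taskId`` response key, while the status poll still goes through the batch endpoint returned by ``get_conn_url()``::

    # A minimal sketch, assuming the same hypothetical "druid_ingest_default"
    # connection as above; submit_indexing_job() blocks, polling the task
    # status until completion or until max_ingestion_time is exceeded.
    import json

    from airflow.providers.apache.druid.hooks.druid import DruidHook, IngestionType

    hook = DruidHook(druid_ingest_conn_id="druid_ingest_default", max_ingestion_time=3600)
    sql_task_spec = json.dumps(
        {
            # Hypothetical SQL-based (MSQ) ingestion spec
            "query": "INSERT INTO my_datasource SELECT * FROM TABLE(EXTERN(...)) PARTITIONED BY DAY",
            "context": {"finalizeAggregations": True},
        }
    )
    hook.submit_indexing_job(sql_task_spec, IngestionType.MSQ)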