DruidHook add SQL-based task support (#32795)
* DruidHook supports SQL-based task
---------

Co-authored-by: vio.ao <vio.ao@skyscanner.net>
Co-authored-by: Ashish Patel <ashishpatel0720@gmail.com>
3 people committed Aug 6, 2023
1 parent 0d8a24b commit d24933c
Showing 5 changed files with 93 additions and 10 deletions.
33 changes: 27 additions & 6 deletions airflow/providers/apache/druid/hooks/druid.py
@@ -18,6 +18,7 @@
from __future__ import annotations

import time
+from enum import Enum
from typing import Any, Iterable

import requests
@@ -28,6 +29,17 @@
from airflow.providers.common.sql.hooks.sql import DbApiHook


+class IngestionType(Enum):
+    """
+    Druid Ingestion Type. Could be Native batch ingestion or SQL-based ingestion.
+    https://druid.apache.org/docs/latest/ingestion/index.html
+    """
+
+    BATCH = 1
+    MSQ = 2


class DruidHook(BaseHook):
"""
Connection to Druid overlord for ingestion.
Expand Down Expand Up @@ -59,13 +71,16 @@ def __init__(
if self.timeout < 1:
raise ValueError("Druid timeout should be equal or greater than 1")

-    def get_conn_url(self) -> str:
+    def get_conn_url(self, ingestion_type: IngestionType = IngestionType.BATCH) -> str:
        """Get Druid connection url."""
        conn = self.get_connection(self.druid_ingest_conn_id)
        host = conn.host
        port = conn.port
        conn_type = conn.conn_type or "http"
-        endpoint = conn.extra_dejson.get("endpoint", "")
+        if ingestion_type == IngestionType.BATCH:
+            endpoint = conn.extra_dejson.get("endpoint", "")
+        else:
+            endpoint = conn.extra_dejson.get("msq_endpoint", "")
        return f"{conn_type}://{host}:{port}/{endpoint}"

    def get_auth(self) -> requests.auth.HTTPBasicAuth | None:
@@ -82,9 +97,11 @@ def get_auth(self) -> requests.auth.HTTPBasicAuth | None:
        else:
            return None

-    def submit_indexing_job(self, json_index_spec: dict[str, Any] | str) -> None:
+    def submit_indexing_job(
+        self, json_index_spec: dict[str, Any] | str, ingestion_type: IngestionType = IngestionType.BATCH
+    ) -> None:
        """Submit Druid ingestion job."""
-        url = self.get_conn_url()
+        url = self.get_conn_url(ingestion_type)

        self.log.info("Druid ingestion spec: %s", json_index_spec)
        req_index = requests.post(url, data=json_index_spec, headers=self.header, auth=self.get_auth())
@@ -96,14 +113,18 @@ def submit_indexing_job(self, json_index_spec: dict[str, Any] | str) -> None:

        req_json = req_index.json()
        # Wait until the job is completed
-        druid_task_id = req_json["task"]
+        if ingestion_type == IngestionType.BATCH:
+            druid_task_id = req_json["task"]
+        else:
+            druid_task_id = req_json["taskId"]
+        druid_task_status_url = f"{self.get_conn_url()}/{druid_task_id}/status"
        self.log.info("Druid indexing task-id: %s", druid_task_id)

        running = True

        sec = 0
        while running:
-            req_status = requests.get(f"{url}/{druid_task_id}/status", auth=self.get_auth())
+            req_status = requests.get(druid_task_status_url, auth=self.get_auth())

            self.log.info("Job still running for %s seconds...", sec)


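For orientation, here is a minimal sketch of how the reworked hook is meant to be driven; the connection id, host, and extras values are illustrative assumptions, not part of this change:

from airflow.providers.apache.druid.hooks.druid import DruidHook, IngestionType

# Hypothetical Airflow connection "druid_ingest_default" configured with extras:
#   {"endpoint": "druid/indexer/v1/task", "msq_endpoint": "druid/v2/sql/task"}
hook = DruidHook(druid_ingest_conn_id="druid_ingest_default")

# The default, IngestionType.BATCH, resolves against the "endpoint" extra,
# e.g. http://overlord:8081/druid/indexer/v1/task
batch_url = hook.get_conn_url()

# IngestionType.MSQ resolves against the "msq_endpoint" extra,
# e.g. http://overlord:8081/druid/v2/sql/task
msq_url = hook.get_conn_url(IngestionType.MSQ)
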
7 changes: 5 additions & 2 deletions airflow/providers/apache/druid/operators/druid.py
@@ -20,7 +20,7 @@
from typing import TYPE_CHECKING, Any, Sequence

from airflow.models import BaseOperator
-from airflow.providers.apache.druid.hooks.druid import DruidHook
+from airflow.providers.apache.druid.hooks.druid import DruidHook, IngestionType

if TYPE_CHECKING:
    from airflow.utils.context import Context
@@ -36,6 +36,7 @@ class DruidOperator(BaseOperator):
    :param timeout: The interval (in seconds) between polling the Druid job for the status
        of the ingestion job. Must be greater than or equal to 1
    :param max_ingestion_time: The maximum ingestion time before assuming the job failed
+    :param ingestion_type: The ingestion type of the job. Could be IngestionType.BATCH or IngestionType.MSQ
    """

    template_fields: Sequence[str] = ("json_index_file",)
@@ -49,13 +50,15 @@ def __init__(
        druid_ingest_conn_id: str = "druid_ingest_default",
        timeout: int = 1,
        max_ingestion_time: int | None = None,
+        ingestion_type: IngestionType = IngestionType.BATCH,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.json_index_file = json_index_file
        self.conn_id = druid_ingest_conn_id
        self.timeout = timeout
        self.max_ingestion_time = max_ingestion_time
+        self.ingestion_type = ingestion_type

    def execute(self, context: Context) -> None:
        hook = DruidHook(
@@ -64,4 +67,4 @@ def execute(self, context: Context) -> None:
            max_ingestion_time=self.max_ingestion_time,
        )
        self.log.info("Submitting %s", self.json_index_file)
-        hook.submit_indexing_job(self.json_index_file)
+        hook.submit_indexing_job(self.json_index_file, self.ingestion_type)
1 change: 1 addition & 0 deletions docs/apache-airflow-providers-apache-druid/operators.rst
@@ -29,6 +29,7 @@ DruidOperator
-------------------

To submit a task directly to Druid, you need to provide the filepath to the Druid index specification ``json_index_file``, and the connection id of the Druid overlord ``druid_ingest_conn_id`` which accepts index jobs in Airflow Connections.
+In addition, you can provide the ingestion type ``ingestion_type`` to determine whether the job uses batch ingestion or SQL-based ingestion.

There is also an example of the Druid Ingestion specification below.

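As a usage sketch (the task id and spec filename below are hypothetical), the new parameter is passed straight through to the operator:

from airflow.providers.apache.druid.hooks.druid import IngestionType
from airflow.providers.apache.druid.operators.druid import DruidOperator

submit_msq_task = DruidOperator(
    task_id="druid_sql_based_ingestion",    # hypothetical task id
    json_index_file="sql_based_spec.json",  # hypothetical path to a SQL-based ingestion spec
    druid_ingest_conn_id="druid_ingest_default",
    ingestion_type=IngestionType.MSQ,       # omit to keep the default, IngestionType.BATCH
)

Omitting ``ingestion_type`` keeps the previous batch behavior, so existing DAGs are unaffected.
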
35 changes: 33 additions & 2 deletions tests/providers/apache/druid/hooks/test_druid.py
@@ -23,7 +23,7 @@
import requests

from airflow.exceptions import AirflowException
-from airflow.providers.apache.druid.hooks.druid import DruidDbApiHook, DruidHook
+from airflow.providers.apache.druid.hooks.druid import DruidDbApiHook, DruidHook, IngestionType


class TestDruidHook:
@@ -35,7 +35,11 @@ def setup_method(self):
        session.mount("mock", adapter)

        class TestDRuidhook(DruidHook):
-            def get_conn_url(self):
+            self.is_sql_based_ingestion = False
+
+            def get_conn_url(self, ingestion_type: IngestionType = IngestionType.BATCH):
+                if ingestion_type == IngestionType.MSQ:
+                    return "http://druid-overlord:8081/druid/v2/sql/task"
                return "http://druid-overlord:8081/druid/indexer/v1/task"

        self.db_hook = TestDRuidhook()
@@ -73,6 +77,22 @@ def test_submit_ok(self, requests_mock):
        assert task_post.called_once
        assert status_check.called_once

+    def test_submit_sql_based_ingestion_ok(self, requests_mock):
+        task_post = requests_mock.post(
+            "http://druid-overlord:8081/druid/v2/sql/task",
+            text='{"taskId":"9f8a7359-77d4-4612-b0cd-cc2f6a3c28de"}',
+        )
+        status_check = requests_mock.get(
+            "http://druid-overlord:8081/druid/indexer/v1/task/9f8a7359-77d4-4612-b0cd-cc2f6a3c28de/status",
+            text='{"status":{"status": "SUCCESS"}}',
+        )
+
+        # Exists just as it should
+        self.db_hook.submit_indexing_job("Long json file", IngestionType.MSQ)
+
+        assert task_post.called_once
+        assert status_check.called_once

    def test_submit_correct_json_body(self, requests_mock):
        task_post = requests_mock.post(
            "http://druid-overlord:8081/druid/indexer/v1/task",
Expand Down Expand Up @@ -149,6 +169,17 @@ def test_get_conn_url(self, mock_get_connection):
hook = DruidHook(timeout=1, max_ingestion_time=5)
assert hook.get_conn_url() == "https://test_host:1/ingest"

@patch("airflow.providers.apache.druid.hooks.druid.DruidHook.get_connection")
def test_get_conn_url_with_ingestion_type(self, mock_get_connection):
get_conn_value = MagicMock()
get_conn_value.host = "test_host"
get_conn_value.conn_type = "https"
get_conn_value.port = "1"
get_conn_value.extra_dejson = {"endpoint": "ingest", "msq_endpoint": "sql_ingest"}
mock_get_connection.return_value = get_conn_value
hook = DruidHook(timeout=1, max_ingestion_time=5)
assert hook.get_conn_url(IngestionType.MSQ) == "https://test_host:1/sql_ingest"

    @patch("airflow.providers.apache.druid.hooks.druid.DruidHook.get_connection")
    def test_get_auth(self, mock_get_connection):
        get_conn_value = MagicMock()
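
The tests above pin down the behavioral difference: a batch submission returns its task id under "task", an MSQ submission returns it under "taskId", and both are polled on the indexer status endpoint. As an end-to-end sketch, the SQL text and context option below are assumptions drawn from Druid's MSQ documentation, not part of this change:

import json

from airflow.providers.apache.druid.hooks.druid import DruidHook, IngestionType

# Hypothetical SQL-based ingestion spec; submit_indexing_job posts the body verbatim.
msq_spec = json.dumps(
    {
        "query": "INSERT INTO wikipedia SELECT * FROM source_table PARTITIONED BY DAY",
        "context": {"maxNumTasks": 3},  # assumed MSQ context option
    }
)

hook = DruidHook(druid_ingest_conn_id="druid_ingest_default")
hook.submit_indexing_job(msq_spec, IngestionType.MSQ)
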
27 changes: 27 additions & 0 deletions tests/providers/apache/druid/operators/test_druid.py
@@ -18,7 +18,9 @@
from __future__ import annotations

import json
+from unittest.mock import MagicMock, patch

+from airflow.providers.apache.druid.hooks.druid import IngestionType
from airflow.providers.apache.druid.operators.druid import DruidOperator
from airflow.utils import timezone
from airflow.utils.types import DagRunType
@@ -104,3 +106,28 @@ def test_init_default_timeout():
    )
    expected_default_timeout = 1
    assert expected_default_timeout == operator.timeout


@patch("airflow.providers.apache.druid.operators.druid.DruidHook")
def test_execute_calls_druid_hook_with_the_right_parameters(mock_druid_hook):
mock_druid_hook_instance = MagicMock()
mock_druid_hook.return_value = mock_druid_hook_instance
json_index_file = "sql.json"
druid_ingest_conn_id = "druid_ingest_default"
max_ingestion_time = 5
timeout = 5
operator = DruidOperator(
task_id="spark_submit_job",
json_index_file=json_index_file,
druid_ingest_conn_id=druid_ingest_conn_id,
timeout=timeout,
ingestion_type=IngestionType.MSQ,
max_ingestion_time=max_ingestion_time,
)
operator.execute(context={})
mock_druid_hook.assert_called_once_with(
druid_ingest_conn_id=druid_ingest_conn_id,
timeout=timeout,
max_ingestion_time=max_ingestion_time,
)
mock_druid_hook_instance.submit_indexing_job.assert_called_once_with(json_index_file, IngestionType.MSQ)
