/
druid.py
238 lines (195 loc) · 8.68 KB
/
druid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import time
from enum import Enum
from functools import cached_property
from typing import TYPE_CHECKING, Any, Iterable
import requests
from pydruid.db import connect
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.providers.common.sql.hooks.sql import DbApiHook
if TYPE_CHECKING:
from airflow.models import Connection
class IngestionType(Enum):
    """
    Kind of ingestion a Druid job is submitted as.

    Either native batch ingestion or SQL-based ingestion, see
    https://druid.apache.org/docs/latest/ingestion/index.html
    """

    # Native batch ingestion.
    BATCH = 1
    # SQL-based ingestion.
    MSQ = 2
class DruidHook(BaseHook):
    """
    Connection to Druid overlord for ingestion.

    To connect to a Druid cluster that is secured with the druid-basic-security
    extension, add the username and password to the druid ingestion connection.

    :param druid_ingest_conn_id: The connection id to the Druid overlord machine
        which accepts index jobs
    :param timeout: The interval, in seconds, between polling
        the Druid job for the status of the ingestion job.
        Must be greater than or equal to 1
    :param max_ingestion_time: The maximum ingestion time, in seconds, before assuming the job failed
    :param verify_ssl: Whether to use SSL encryption to submit indexing job. If set to False then checks
        connection information for path to a CA bundle to use. Defaults to True
    """

    def __init__(
        self,
        druid_ingest_conn_id: str = "druid_ingest_default",
        timeout: int = 1,
        max_ingestion_time: int | None = None,
        verify_ssl: bool = True,
    ) -> None:
        super().__init__()
        self.druid_ingest_conn_id = druid_ingest_conn_id
        self.timeout = timeout
        self.max_ingestion_time = max_ingestion_time
        # Every submission is sent with a JSON content type.
        self.header = {"content-type": "application/json"}
        self.verify_ssl = verify_ssl

        if self.timeout < 1:
            raise ValueError("Druid timeout should be equal or greater than 1")

    @cached_property
    def conn(self) -> Connection:
        """Return the ingestion connection, looked up once and cached on the instance."""
        return self.get_connection(self.druid_ingest_conn_id)

    def get_conn_url(self, ingestion_type: IngestionType = IngestionType.BATCH) -> str:
        """
        Get Druid connection url.

        :param ingestion_type: Selects which connection extra holds the path: ``endpoint`` for
            ``IngestionType.BATCH``, ``msq_endpoint`` for any other ingestion type.
        :return: ``<conn_type>://<host>[:<port>]/<endpoint>``; the scheme falls back to ``http`` when
            the connection has no ``conn_type`` and the endpoint falls back to an empty string.
        """
        host = self.conn.host
        port = self.conn.port
        conn_type = self.conn.conn_type or "http"
        endpoint_key = "endpoint" if ingestion_type == IngestionType.BATCH else "msq_endpoint"
        endpoint = self.conn.extra_dejson.get(endpoint_key, "")
        # Omit the port segment when none is configured instead of rendering "host:None".
        netloc = host if port is None else f"{host}:{port}"
        return f"{conn_type}://{netloc}/{endpoint}"

    def get_auth(self) -> requests.auth.HTTPBasicAuth | None:
        """
        Return username and password from connections tab as requests.auth.HTTPBasicAuth object.

        If these details have not been set then returns None.
        """
        user = self.conn.login
        password = self.conn.password
        if user is None or password is None:
            return None
        return requests.auth.HTTPBasicAuth(user, password)

    def get_verify(self) -> bool | str:
        """
        Return the value to pass as ``verify`` to ``requests``.

        :return: The ``ca_bundle_path`` connection extra when ``verify_ssl`` is False and that extra
            is set; otherwise the ``verify_ssl`` flag itself.
        """
        ca_bundle_path: str | None = self.conn.extra_dejson.get("ca_bundle_path", None)
        if not self.verify_ssl and ca_bundle_path:
            self.log.info("Using CA bundle to verify connection")
            return ca_bundle_path

        return self.verify_ssl

    def submit_indexing_job(
        self, json_index_spec: dict[str, Any] | str, ingestion_type: IngestionType = IngestionType.BATCH
    ) -> None:
        """
        Submit Druid ingestion job and block until it finishes.

        :param json_index_spec: The ingestion spec, either already serialized to a JSON string or as
            a dict that is serialized on submission.
        :param ingestion_type: Which endpoint to submit to, and which response key carries the task id
            (``task`` for ``IngestionType.BATCH``, ``taskId`` otherwise).
        :raises AirflowException: If the submission is not answered with a 2xx status code, if the job
            runs longer than ``max_ingestion_time`` (a shutdown is requested first), if the job
            reports ``FAILED``, or if it reports any status other than ``RUNNING``/``SUCCESS``.
        """
        url = self.get_conn_url(ingestion_type)

        self.log.info("Druid ingestion spec: %s", json_index_spec)
        # A pre-serialized spec is sent verbatim. A dict must go through ``json=``: with ``data=``
        # requests would form-encode it even though the header announces JSON.
        if isinstance(json_index_spec, (str, bytes)):
            payload: dict[str, Any] = {"data": json_index_spec}
        else:
            payload = {"json": json_index_spec}
        req_index = requests.post(
            url, headers=self.header, auth=self.get_auth(), verify=self.get_verify(), **payload
        )

        code = req_index.status_code
        if not 200 <= code < 300:
            self.log.error("Error submitting the Druid job to %s (%s) %s", url, code, req_index.content)
            raise AirflowException(f"Did not get 200 or 202 when submitting the Druid job to {url}")

        req_json = req_index.json()
        task_id_key = "task" if ingestion_type == IngestionType.BATCH else "taskId"
        druid_task_id = req_json[task_id_key]
        # Status and shutdown are both addressed through the default (batch) endpoint, whatever
        # endpoint the job was submitted to.
        druid_task_url = f"{self.get_conn_url()}/{druid_task_id}"
        druid_task_status_url = f"{druid_task_url}/status"

        self.log.info("Druid indexing task-id: %s", druid_task_id)

        # Wait until the job is completed
        sec = 0
        while True:
            req_status = requests.get(druid_task_status_url, auth=self.get_auth(), verify=self.get_verify())

            self.log.info("Job still running for %s seconds...", sec)

            if self.max_ingestion_time and sec > self.max_ingestion_time:
                # ensure that the job gets killed if the max ingestion time is exceeded
                requests.post(f"{druid_task_url}/shutdown", auth=self.get_auth(), verify=self.get_verify())
                raise AirflowException(f"Druid ingestion took more than {self.max_ingestion_time} seconds")

            time.sleep(self.timeout)
            sec += self.timeout

            status = req_status.json()["status"]["status"]
            if status == "RUNNING":
                continue
            if status == "SUCCESS":
                break  # Great success!
            if status == "FAILED":
                raise AirflowException("Druid indexing job failed, check console for more info")
            raise AirflowException(f"Could not get status of the job, got {status}")

        self.log.info("Successful index")
class DruidDbApiHook(DbApiHook):
    """
    Interact with Druid broker.

    This hook is purely for users to query druid broker.
    For ingestion, please use druidHook.

    :param context: Optional query context parameters to pass to the SQL endpoint.
        Example: ``{"sqlFinalizeOuterSketches": True}``
        See: https://druid.apache.org/docs/latest/querying/sql-query-context/
    """

    conn_name_attr = "druid_broker_conn_id"
    default_conn_name = "druid_broker_default"
    conn_type = "druid"
    hook_name = "Druid"
    supports_autocommit = False

    def __init__(self, context: dict | None = None, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Fall back to an empty context when none is given.
        self.context = context or {}

    def get_conn(self) -> connect:
        """
        Establish a connection to druid broker.

        Host, port, login and password come from the Airflow connection; the path is read from the
        ``endpoint`` extra (default ``/druid/v2/sql``) and the scheme from the ``schema`` extra
        (default ``http``).
        """
        conn = self.get_connection(getattr(self, self.conn_name_attr))
        druid_broker_conn = connect(
            host=conn.host,
            port=conn.port,
            path=conn.extra_dejson.get("endpoint", "/druid/v2/sql"),
            scheme=conn.extra_dejson.get("schema", "http"),
            user=conn.login,
            password=conn.password,
            context=self.context,
        )
        self.log.info("Get the connection to druid broker on %s using user %s", conn.host, conn.login)
        return druid_broker_conn

    def get_uri(self) -> str:
        """
        Get the connection uri for druid broker.

        e.g: druid://localhost:8082/druid/v2/sql/

        The port is omitted when the connection does not define one, and the scheme falls back to
        ``druid`` when the connection has no ``conn_type``.
        """
        conn = self.get_connection(getattr(self, self.conn_name_attr))
        host = conn.host
        if conn.port is not None:
            host += f":{conn.port}"
        conn_type = conn.conn_type or "druid"
        endpoint = conn.extra_dejson.get("endpoint", "druid/v2/sql")
        # ``get_conn`` accepts an endpoint with a leading slash; strip it so the separator added
        # below does not produce "host//druid/v2/sql".
        endpoint = str(endpoint).lstrip("/")
        return f"{conn_type}://{host}/{endpoint}"

    def set_autocommit(self, conn: connect, autocommit: bool) -> None:
        """Not implemented by this hook; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def insert_rows(
        self,
        table: str,
        rows: Iterable[tuple[str]],
        target_fields: Iterable[str] | None = None,
        commit_every: int = 1000,
        replace: bool = False,
        **kwargs: Any,
    ) -> None:
        """Not implemented by this hook; always raises ``NotImplementedError``."""
        raise NotImplementedError()