-
Notifications
You must be signed in to change notification settings - Fork 40
/
raw_sql.py
156 lines (128 loc) · 6.35 KB
/
raw_sql.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
from __future__ import annotations
import logging
from collections.abc import Iterable, Mapping
from typing import Any, Callable
try:
from airflow.decorators.base import TaskDecorator
except ImportError:
from airflow.decorators import _TaskDecorator as TaskDecorator
import airflow
from airflow.decorators.base import task_decorator_factory
if airflow.__version__ >= "2.3":
from sqlalchemy.engine.row import LegacyRow as SQLAlcRow
else:
from sqlalchemy.engine.result import RowProxy as SQLAlcRow
from astro import settings
from astro.exceptions import IllegalLoadToDatabaseException
from astro.sql.operators.base_decorator import BaseSQLDecoratedOperator
from astro.utils.compat.typing import Context
class RawSQLOperator(BaseSQLDecoratedOperator):
    """
    Given a SQL statement, (optional) tables and a (optional) function, execute the SQL statement
    and apply the function to the results, returning the result of the function.

    Disclaimer: this could potentially trash the XCom Database, depending on the XCom backend used
    and on the SQL statement/function declared by the user.
    """

    def execute(self, context: Context) -> Any:
        """
        Run ``self.sql`` on ``self.database_impl`` and post-process the handler's output.

        :param context: Airflow task context, forwarded to the base operator's ``execute``.
        :return: ``None`` when no handler is configured. The raw ``run_sql`` result when the
            database's ``sql_type`` is ``"delta"``. Otherwise the handler's output, passed through
            ``make_row_serializable`` and, when ``response_size`` is >= 0, sliced to at most
            ``response_size`` items (or characters, for a string). Outputs without a length
            (``None``, scalars) are returned unchanged.
        :raises IllegalLoadToDatabaseException: if ``response_limit`` is >= 0 and the response
            has more elements than ``response_limit``.
        """
        super().execute(context)
        result = self.database_impl.run_sql(sql=self.sql, parameters=self.parameters, handler=self.handler)
        if self.response_size == -1 and not settings.IS_CUSTOM_XCOM_BACKEND:
            logging.warning(
                "Using `run_raw_sql` without `response_size` can result in excessive amount of data being recorded "
                "to the Airflow metadata database, leading to issues to the orchestration of tasks. It is possible to "
                "avoid this problem by either setting `response_size` to a small integer or by using a custom XCom "
                "backend."
            )

        if not self.handler:
            return None
        if self.database_impl.sql_type == "delta":
            # For delta the result of ``run_sql`` is returned as-is; the handler is not applied here.
            return result

        response = self.make_row_serializable(self.handler(result))

        # Handlers may legitimately return None or a scalar (e.g. ``result.scalar()``). Such values
        # have no length and cannot be sliced, so neither the limit check nor trimming applies.
        if not hasattr(response, "__len__"):
            return response
        if 0 <= self.response_limit < len(response):
            raise IllegalLoadToDatabaseException()  # pragma: no cover
        if self.response_size >= 0:
            return response[: self.response_size]
        return response

    @staticmethod
    def make_row_serializable(rows: Any) -> Any:
        """
        Convert SQLAlchemy rows to a serializable format.

        When ``settings.NEED_CUSTOM_SERIALIZATION`` is false, ``rows`` is returned unchanged.
        Otherwise any iterable (other than ``str``/``bytes``/``bytearray``) is materialized into a
        list in which every ``SQLAlcRow`` element is replaced by an ``SdkLegacyRow`` copy and all
        other elements are kept as-is. Non-iterables are returned unchanged.

        :param rows: the value returned by the user's handler.
        :return: the possibly converted value.
        """
        if not settings.NEED_CUSTOM_SERIALIZATION:
            return rows
        # Text/bytes are Iterable too; without this guard a string response would be exploded
        # into a list of single characters instead of being returned (and trimmed) as a string.
        if isinstance(rows, (str, bytes, bytearray)):
            return rows
        if isinstance(rows, Iterable):
            return [SdkLegacyRow.from_legacy_row(r) if isinstance(r, SQLAlcRow) else r for r in rows]
        return rows
class SdkLegacyRow(SQLAlcRow):
    """
    ``SQLAlcRow`` subclass that can be dumped to a plain dict and rebuilt from one.

    Instances are built by calling the parent constructor with ``None`` for its first two
    positional arguments, followed by the key map, key style and row data.
    """

    # Version number attached to the serialized form; presumably read by the serialization
    # machinery that calls ``serialize``/``deserialize`` -- NOTE(review): confirm the consumer.
    version: int = 1

    def serialize(self):
        """Return the row's key map, key style and data as a dict."""
        return dict(key_map=self._keymap, key_style=self._key_style, data=self._data)

    @staticmethod
    def deserialize(data: dict, version: int):  # skipcq: PYL-W0613
        """
        Rebuild a row from a dict shaped like the output of ``serialize``.

        ``version`` is accepted for interface compatibility but is not used.
        """
        keymap = data["key_map"]
        key_style = data["key_style"]
        values = data["data"]
        return SdkLegacyRow(None, None, keymap, key_style, values)

    @staticmethod
    def from_legacy_row(obj):
        """Build an ``SdkLegacyRow`` carrying the key map, key style and data of ``obj``."""
        keymap, key_style, values = obj._keymap, obj._key_style, obj._data  # skipcq: PYL-W0212
        return SdkLegacyRow(None, None, keymap, key_style, values)
def run_raw_sql(
    python_callable: Callable | None = None,
    conn_id: str = "",
    parameters: Mapping | Iterable | None = None,
    database: str | None = None,
    schema: str | None = None,
    handler: Callable | None = None,
    response_size: int = settings.RAW_SQL_MAX_RESPONSE_SIZE,
    **kwargs: Any,
) -> TaskDecorator:
    """
    Given a python function that returns a SQL statement and (optional) tables, execute the SQL statement
    and optionally return the result of a ``handler`` applied to the query result.

    Use this function as a decorator like so:

    .. code-block:: python

        @run_raw_sql
        def my_sql_statement(table1: Table) -> Table:
            return "DROP TABLE {{table1}}"

    In this example, by identifying parameters as ``Table`` objects, astro knows to automatically convert those
    objects into tables (if they are, for example, a dataframe). Any type besides table will lead astro to assume
    you do not want the parameter converted.

    Please note that the ``run_raw_sql`` function will not create a temporary table. It will either return the
    result of a provided ``handler`` function or it will not return anything at all.

    :param python_callable: This parameter is filled in automatically when you use ``run_raw_sql`` as a
        decorator. This is where the python function gets passed to the wrapping function
    :param conn_id: Connection ID for the database you want to connect to.
        If you do not pass in a value for this object we can infer the connection ID from the first table
        passed into the python_callable function. (required if there are no table arguments)
    :param parameters: parameters to pass into the SQL query
    :param database: Database within the SQL instance you want to access. If left blank we will default to the
        table.metadata.database in the first Table passed to the function
        (required if there are no table arguments)
    :param schema: Schema within the SQL instance you want to access. If left blank we will default to the
        table.metadata.schema in the first Table passed to the function
        (required if there are no table arguments)
    :param handler: Handler function to process the result of the SQL query. For more information please consult
        https://docs.sqlalchemy.org/en/14/core/connections.html#sqlalchemy.engine.Result
    :param response_size: Used to trim the responses returned to avoid trashing the Airflow DB.
        Defaults to ``settings.RAW_SQL_MAX_RESPONSE_SIZE``. A negative value (e.g. -1) means the response is not
        changed. Otherwise, if the response is a list, returns up to the desired amount of items; if the response
        is a string, trims it to the desired size.
    :param kwargs: Any additional keyword arguments, forwarded to ``task_decorator_factory`` together with the
        explicit parameters above (the explicit parameters take precedence over same-named entries in kwargs).
    :return: A task decorator. When the resulting task runs, it returns None unless there is a handler function,
        in which case it returns the result of the handler.
    """
    # Explicit arguments are written last so they override any same-named keys passed via **kwargs.
    kwargs.update(
        {
            "conn_id": conn_id,
            "parameters": parameters,
            "database": database,
            "schema": schema,
            "handler": handler,
            "response_size": response_size,
        }
    )
    return task_decorator_factory(
        python_callable=python_callable,
        multiple_outputs=False,
        decorated_operator_class=RawSQLOperator,
        **kwargs,
    )