Add as_dict param to BigQueryGetDataOperator (#30887)
* Add "as_dict" param to BigQueryGetDataOperator
shahar1 committed May 9, 2023
1 parent dff7e0d commit b8f7376
Showing 6 changed files with 87 additions and 22 deletions.
12 changes: 9 additions & 3 deletions airflow/providers/google/cloud/hooks/bigquery.py
@@ -3125,20 +3125,26 @@ async def create_job_for_partition_get(
job_query_resp = await job_client.query(query_request, cast(Session, session))
return job_query_resp["jobReference"]["jobId"]

- def get_records(self, query_results: dict[str, Any]) -> list[Any]:
+ def get_records(self, query_results: dict[str, Any], as_dict: bool = False) -> list[Any]:
"""
Given the output query response from gcloud-aio bigquery, convert the response to records.
:param query_results: the results from a SQL query
:param as_dict: if True returns the result as a list of dictionaries, otherwise as list of lists.
"""
- buffer = []
+ buffer: list[Any] = []
if "rows" in query_results and query_results["rows"]:
rows = query_results["rows"]
fields = query_results["schema"]["fields"]
col_types = [field["type"] for field in fields]
for dict_row in rows:
typed_row = [bq_cast(vs["v"], col_types[idx]) for idx, vs in enumerate(dict_row["f"])]
- buffer.append(typed_row)
+ if not as_dict:
+     buffer.append(typed_row)
+ else:
+     fields_names = [field["name"] for field in fields]
+     typed_row_dict = {k: v for k, v in zip(fields_names, typed_row)}
+     buffer.append(typed_row_dict)
return buffer

def value_check(
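The conversion above amounts to casting each cell and, when ``as_dict`` is set, zipping the schema's field names with the cast values. A minimal standalone sketch of that behaviour (not part of the commit), assuming a hand-written sample response and plain string values in place of ``bq_cast``:

.. code-block:: python

    from typing import Any


    def records_from_query_results(query_results: dict[str, Any], as_dict: bool = False) -> list[Any]:
        # Simplified stand-in for BigQueryAsyncHook.get_records: same row/field handling,
        # but without the bq_cast type coercion, so values stay as strings.
        buffer: list[Any] = []
        if query_results.get("rows"):
            names = [field["name"] for field in query_results["schema"]["fields"]]
            for dict_row in query_results["rows"]:
                row = [cell["v"] for cell in dict_row["f"]]
                buffer.append(dict(zip(names, row)) if as_dict else row)
        return buffer


    sample = {  # hypothetical response in the gcloud-aio result shape
        "schema": {"fields": [{"name": "name"}, {"name": "age"}]},
        "rows": [{"f": [{"v": "Tony"}, {"v": "10"}]}, {"f": [{"v": "Mike"}, {"v": "20"}]}],
    }
    print(records_from_query_results(sample))  # [['Tony', '10'], ['Mike', '20']]
    print(records_from_query_results(sample, as_dict=True))  # [{'name': 'Tony', 'age': '10'}, ...]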
35 changes: 25 additions & 10 deletions airflow/providers/google/cloud/operators/bigquery.py
@@ -758,12 +758,19 @@ def execute(self, context=None):

class BigQueryGetDataOperator(GoogleCloudBaseOperator):
"""
- Fetches the data from a BigQuery table (alternatively fetch data for selected columns)
- and returns data in a python list. The number of elements in the returned list will
- be equal to the number of rows fetched. Each element in the list will again be a list
- where element would represent the columns values for that row.
+ Fetches the data from a BigQuery table (alternatively fetch data for selected columns) and returns data
+ in either of the following two formats, based on "as_dict" value:
+ 1. False (Default) - A Python list of lists, with the number of nested lists equal to the number of rows
+    fetched. Each nested list represents a row, where the elements within it correspond to the column values
+    for that particular row.
- **Example Result**: ``[['Tony', '10'], ['Mike', '20'], ['Steve', '15']]``
+    **Example Result**: ``[['Tony', 10], ['Mike', 20]]``
+ 2. True - A Python list of dictionaries, where each dictionary represents a row. In each dictionary,
+    the keys are the column names and the values are the corresponding values for those columns.
+    **Example Result**: ``[{'name': 'Tony', 'age': 10}, {'name': 'Mike', 'age': 20}]``
.. seealso::
For more information on how to use this operator, take a look at the guide:
@@ -810,6 +817,8 @@ class BigQueryGetDataOperator(GoogleCloudBaseOperator):
:param deferrable: Run operator in the deferrable mode
:param poll_interval: (Deferrable mode only) polling period in seconds to check for the status of job.
Defaults to 4 seconds.
:param as_dict: if True returns the result as a list of dictionaries, otherwise as list of lists
(default: False).
"""

template_fields: Sequence[str] = (
@@ -835,6 +844,7 @@ def __init__(
impersonation_chain: str | Sequence[str] | None = None,
deferrable: bool = False,
poll_interval: float = 4.0,
as_dict: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -849,6 +859,7 @@ def __init__(
self.project_id = project_id
self.deferrable = deferrable
self.poll_interval = poll_interval
self.as_dict = as_dict

def _submit_job(
self,
@@ -884,7 +895,6 @@ def execute(self, context: Context):
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
- self.hook = hook

if not self.deferrable:
self.log.info(
@@ -910,21 +920,26 @@

self.log.info("Total extracted rows: %s", len(rows))

- table_data = [row.values() for row in rows]
+ if self.as_dict:
+     table_data = [{k: v for k, v in row.items()} for row in rows]
+ else:
+     table_data = [row.values() for row in rows]

return table_data

job = self._submit_job(hook, job_id="")
- self.job_id = job.job_id
- context["ti"].xcom_push(key="job_id", value=self.job_id)

+ context["ti"].xcom_push(key="job_id", value=job.job_id)
self.defer(
timeout=self.execution_timeout,
trigger=BigQueryGetDataTrigger(
conn_id=self.gcp_conn_id,
- job_id=self.job_id,
+ job_id=job.job_id,
dataset_id=self.dataset_id,
table_id=self.table_id,
project_id=hook.project_id,
poll_interval=self.poll_interval,
as_dict=self.as_dict,
),
method_name="execute_complete",
)
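Taken together, the operator-side changes let a DAG opt into dictionary rows with a single flag. A hedged usage sketch; the DAG, dataset, table, and field names below are placeholders rather than anything defined by this commit:

.. code-block:: python

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.google.cloud.operators.bigquery import BigQueryGetDataOperator

    with DAG(dag_id="example_bq_get_data_as_dict", start_date=datetime(2023, 5, 1), schedule=None):
        get_rows = BigQueryGetDataOperator(
            task_id="get_rows",
            dataset_id="my_dataset",  # placeholder
            table_id="my_table",  # placeholder
            selected_fields="name,age",  # optional: limit the returned columns
            max_results=100,
            as_dict=True,  # returns [{'name': ..., 'age': ...}, ...] instead of [[...], ...]
        )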
13 changes: 11 additions & 2 deletions airflow/providers/google/cloud/triggers/bigquery.py
@@ -165,7 +165,16 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]


class BigQueryGetDataTrigger(BigQueryInsertJobTrigger):
"""BigQueryGetDataTrigger run on the trigger worker, inherits from BigQueryInsertJobTrigger class"""
"""
BigQueryGetDataTrigger run on the trigger worker, inherits from BigQueryInsertJobTrigger class
:param as_dict: if True returns the result as a list of dictionaries, otherwise as list of lists
(default: False).
"""

def __init__(self, as_dict: bool = False, **kwargs):
super().__init__(**kwargs)
self.as_dict = as_dict

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serializes BigQueryInsertJobTrigger arguments and classpath."""
@@ -190,7 +199,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
response_from_hook = await hook.get_job_status(job_id=self.job_id, project_id=self.project_id)
if response_from_hook == "success":
query_results = await hook.get_job_output(job_id=self.job_id, project_id=self.project_id)
- records = hook.get_records(query_results)
+ records = hook.get_records(query_results=query_results, as_dict=self.as_dict)
self.log.debug("Response from hook: %s", response_from_hook)
yield TriggerEvent(
{
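In deferrable mode the operator hands these same arguments to the trigger (see the ``self.defer`` call in the operator hunk above), and the trigger's event carries the records produced by ``get_records``. A rough sketch of exercising the trigger directly, assuming a configured ``google_cloud_default`` connection and hypothetical job and table identifiers:

.. code-block:: python

    import asyncio

    from airflow.providers.google.cloud.triggers.bigquery import BigQueryGetDataTrigger

    trigger = BigQueryGetDataTrigger(
        conn_id="google_cloud_default",  # assumed to exist in the Airflow connection store
        job_id="hypothetical-job-id",
        project_id="my-project",
        dataset_id="my_dataset",
        table_id="my_table",
        poll_interval=4.0,
        as_dict=True,
    )


    async def first_event():
        # run() polls the job status and yields a TriggerEvent once it completes
        async for event in trigger.run():
            return event


    print(asyncio.run(first_event()))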
@@ -208,10 +208,11 @@ To fetch data from a BigQuery table you can use
Alternatively you can fetch data for selected columns if you pass fields to
``selected_fields``.

- This operator returns data in a Python list where the number of elements in the
- returned list will be equal to the number of rows fetched. Each element in the
- list will again be a list where elements would represent the column values for
+ The result of this operator can be retrieved in two different formats, based on the value of the ``as_dict`` parameter:
+ ``False`` (default) - A Python list of lists, where the number of nested lists is equal to the number of rows fetched. Each nested
+ list represents a row, where the elements within it correspond to the column values for
that row.
+ ``True`` - A Python list of dictionaries, where each dictionary represents a row. In each dictionary, the keys are the column names and the values are the corresponding values for those columns.

.. exampleinclude:: /../../tests/system/providers/google/cloud/bigquery/example_bigquery_queries.py
:language: python
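Because the fetched rows are returned from ``execute`` (and therefore available via XCom), a downstream task can consume either shape directly. A small sketch using the TaskFlow API, with a hypothetical ``age`` column and the ``get_rows`` task from the operator example above:

.. code-block:: python

    from airflow.decorators import task


    @task
    def total_age(rows: list[dict]) -> int:
        # With as_dict=True each row is a dict keyed by column name.
        return sum(int(row["age"]) for row in rows)


    # Inside the DAG definition: total_age(get_rows.output)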
26 changes: 26 additions & 0 deletions tests/providers/google/cloud/hooks/test_bigquery.py
@@ -2348,3 +2348,29 @@ def test_get_records_return_type(self):
assert isinstance(result[0][0], int)
assert isinstance(result[0][1], float)
assert isinstance(result[0][2], str)

def test_get_records_as_dict(self):
query_result = {
"kind": "bigquery#getQueryResultsResponse",
"etag": "test_etag",
"schema": {
"fields": [
{"name": "f0_", "type": "INTEGER", "mode": "NULLABLE"},
{"name": "f1_", "type": "FLOAT", "mode": "NULLABLE"},
{"name": "f2_", "type": "STRING", "mode": "NULLABLE"},
]
},
"jobReference": {
"projectId": "test_airflow-providers",
"jobId": "test_jobid",
"location": "US",
},
"totalRows": "1",
"rows": [{"f": [{"v": "22"}, {"v": "3.14"}, {"v": "PI"}]}],
"totalBytesProcessed": "0",
"jobComplete": True,
"cacheHit": False,
}
hook = BigQueryAsyncHook()
result = hook.get_records(query_result, as_dict=True)
assert result == [{"f0_": 22, "f1_": 3.14, "f2_": "PI"}]
16 changes: 12 additions & 4 deletions tests/providers/google/cloud/operators/test_bigquery.py
@@ -785,8 +785,9 @@ def test_bigquery_operator_extra_link_when_multiple_query(


class TestBigQueryGetDataOperator:
@pytest.mark.parametrize("as_dict", [True, False])
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
- def test_execute(self, mock_hook):
+ def test_execute(self, mock_hook, as_dict):
max_results = 100
selected_fields = "DATE"
operator = BigQueryGetDataOperator(
Expand All @@ -797,6 +798,7 @@ def test_execute(self, mock_hook):
max_results=max_results,
selected_fields=selected_fields,
location=TEST_DATASET_LOCATION,
as_dict=as_dict,
)
operator.execute(None)
mock_hook.return_value.list_rows.assert_called_once_with(
@@ -840,9 +842,10 @@ def test_bigquery_get_data_operator_async_with_selected_fields(
exc.value.trigger, BigQueryGetDataTrigger
), "Trigger is not a BigQueryGetDataTrigger"

@pytest.mark.parametrize("as_dict", [True, False])
@mock.patch("airflow.providers.google.cloud.operators.bigquery.BigQueryHook")
def test_bigquery_get_data_operator_async_without_selected_fields(
- self, mock_hook, create_task_instance_of_operator
+ self, mock_hook, create_task_instance_of_operator, as_dict
):
"""
Asserts that a task is deferred and a BigQueryGetDataTrigger will be fired
@@ -862,6 +865,7 @@ def test_bigquery_get_data_operator_async_without_selected_fields(
table_id=TEST_TABLE_ID,
max_results=100,
deferrable=True,
as_dict=as_dict,
)

with pytest.raises(TaskDeferred) as exc:
@@ -871,7 +875,8 @@ def test_bigquery_get_data_operator_async_without_selected_fields(
exc.value.trigger, BigQueryGetDataTrigger
), "Trigger is not a BigQueryGetDataTrigger"

- def test_bigquery_get_data_operator_execute_failure(self):
+ @pytest.mark.parametrize("as_dict", [True, False])
+ def test_bigquery_get_data_operator_execute_failure(self, as_dict):
"""Tests that an AirflowException is raised in case of error event"""

operator = BigQueryGetDataOperator(
@@ -880,14 +885,16 @@ def test_bigquery_get_data_operator_execute_failure(self):
table_id="any",
max_results=100,
deferrable=True,
as_dict=as_dict,
)

with pytest.raises(AirflowException):
operator.execute_complete(
context=None, event={"status": "error", "message": "test failure message"}
)

- def test_bigquery_get_data_op_execute_complete_with_records(self):
+ @pytest.mark.parametrize("as_dict", [True, False])
+ def test_bigquery_get_data_op_execute_complete_with_records(self, as_dict):
"""Asserts that exception is raised with correct expected exception message"""

operator = BigQueryGetDataOperator(
@@ -896,6 +903,7 @@ def test_bigquery_get_data_op_execute_complete_with_records(self):
table_id="any",
max_results=100,
deferrable=True,
as_dict=as_dict,
)

with mock.patch.object(operator.log, "info") as mock_log_info:
