Skip to content

Commit

Permalink
Remove usage of deprecated methods from BigQueryCursor (#35606)
Browse files Browse the repository at this point in the history
  • Loading branch information
MaksYermak committed Nov 17, 2023
1 parent 03a0b72 commit 0c6fd5b
Show file tree
Hide file tree
Showing 2 changed files with 187 additions and 21 deletions.
170 changes: 168 additions & 2 deletions airflow/providers/google/cloud/hooks/bigquery.py
Expand Up @@ -129,7 +129,8 @@ def __init__(

def get_conn(self) -> BigQueryConnection:
"""Get a BigQuery PEP 249 connection object."""
service = self.get_service()
http_authorized = self._authorize()
service = build("bigquery", "v2", http=http_authorized, cache_discovery=False)
return BigQueryConnection(
service=service,
project_id=self.project_id,
Expand Down Expand Up @@ -2775,7 +2776,7 @@ def execute(self, operation: str, parameters: dict | None = None) -> None:
"""
sql = _bind_parameters(operation, parameters) if parameters else operation
self.flush_results()
self.job_id = self.hook.run_query(sql)
self.job_id = self._run_query(sql)

query_results = self._get_query_result()
if "schema" in query_results:
Expand Down Expand Up @@ -2913,6 +2914,171 @@ def _get_query_result(self) -> dict:

return query_results

def _run_query(
self,
sql,
location: str | None = None,
) -> str:
"""Run job query."""
if not self.project_id:
raise ValueError("The project_id should be set")

configuration = self._prepare_query_configuration(sql)
job = self.hook.insert_job(configuration=configuration, project_id=self.project_id, location=location)

return job.job_id

def _prepare_query_configuration(
self,
sql,
destination_dataset_table: str | None = None,
write_disposition: str = "WRITE_EMPTY",
allow_large_results: bool = False,
flatten_results: bool | None = None,
udf_config: list | None = None,
use_legacy_sql: bool | None = None,
maximum_billing_tier: int | None = None,
maximum_bytes_billed: float | None = None,
create_disposition: str = "CREATE_IF_NEEDED",
query_params: list | None = None,
labels: dict | None = None,
schema_update_options: Iterable | None = None,
priority: str | None = None,
time_partitioning: dict | None = None,
api_resource_configs: dict | None = None,
cluster_fields: list[str] | None = None,
encryption_configuration: dict | None = None,
):
"""Helper method that prepare configuration for query."""
labels = labels or self.hook.labels
schema_update_options = list(schema_update_options or [])

priority = priority or self.hook.priority

if time_partitioning is None:
time_partitioning = {}

if not api_resource_configs:
api_resource_configs = self.hook.api_resource_configs
else:
_validate_value("api_resource_configs", api_resource_configs, dict)

configuration = deepcopy(api_resource_configs)

if "query" not in configuration:
configuration["query"] = {}
else:
_validate_value("api_resource_configs['query']", configuration["query"], dict)

if sql is None and not configuration["query"].get("query", None):
raise TypeError("`BigQueryBaseCursor.run_query` missing 1 required positional argument: `sql`")

# BigQuery also allows you to define how you want a table's schema to change
# as a side effect of a query job
# for more details:
# https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.query.schemaUpdateOptions

allowed_schema_update_options = ["ALLOW_FIELD_ADDITION", "ALLOW_FIELD_RELAXATION"]

if not set(allowed_schema_update_options).issuperset(set(schema_update_options)):
raise ValueError(
f"{schema_update_options} contains invalid schema update options."
f" Please only use one or more of the following options: {allowed_schema_update_options}"
)

if schema_update_options:
if write_disposition not in ["WRITE_APPEND", "WRITE_TRUNCATE"]:
raise ValueError(
"schema_update_options is only "
"allowed if write_disposition is "
"'WRITE_APPEND' or 'WRITE_TRUNCATE'."
)

if destination_dataset_table:
destination_project, destination_dataset, destination_table = self.hook.split_tablename(
table_input=destination_dataset_table, default_project_id=self.project_id
)

destination_dataset_table = { # type: ignore
"projectId": destination_project,
"datasetId": destination_dataset,
"tableId": destination_table,
}

if cluster_fields:
cluster_fields = {"fields": cluster_fields} # type: ignore

query_param_list: list[tuple[Any, str, str | bool | None | dict, type | tuple[type]]] = [
(sql, "query", None, (str,)),
(priority, "priority", priority, (str,)),
(use_legacy_sql, "useLegacySql", self.use_legacy_sql, bool),
(query_params, "queryParameters", None, list),
(udf_config, "userDefinedFunctionResources", None, list),
(maximum_billing_tier, "maximumBillingTier", None, int),
(maximum_bytes_billed, "maximumBytesBilled", None, float),
(time_partitioning, "timePartitioning", {}, dict),
(schema_update_options, "schemaUpdateOptions", None, list),
(destination_dataset_table, "destinationTable", None, dict),
(cluster_fields, "clustering", None, dict),
]

for param, param_name, param_default, param_type in query_param_list:
if param_name not in configuration["query"] and param in [None, {}, ()]:
if param_name == "timePartitioning":
param_default = _cleanse_time_partitioning(destination_dataset_table, time_partitioning)
param = param_default

if param in [None, {}, ()]:
continue

_api_resource_configs_duplication_check(param_name, param, configuration["query"])

configuration["query"][param_name] = param

# check valid type of provided param,
# it last step because we can get param from 2 sources,
# and first of all need to find it

_validate_value(param_name, configuration["query"][param_name], param_type)

if param_name == "schemaUpdateOptions" and param:
self.log.info("Adding experimental 'schemaUpdateOptions': %s", schema_update_options)

if param_name == "destinationTable":
for key in ["projectId", "datasetId", "tableId"]:
if key not in configuration["query"]["destinationTable"]:
raise ValueError(
"Not correct 'destinationTable' in "
"api_resource_configs. 'destinationTable' "
"must be a dict with {'projectId':'', "
"'datasetId':'', 'tableId':''}"
)
else:
configuration["query"].update(
{
"allowLargeResults": allow_large_results,
"flattenResults": flatten_results,
"writeDisposition": write_disposition,
"createDisposition": create_disposition,
}
)

if (
"useLegacySql" in configuration["query"]
and configuration["query"]["useLegacySql"]
and "queryParameters" in configuration["query"]
):
raise ValueError("Query parameters are not allowed when using legacy SQL")

if labels:
_api_resource_configs_duplication_check("labels", labels, configuration)
configuration["labels"] = labels

if encryption_configuration:
configuration["query"]["destinationEncryptionConfiguration"] = encryption_configuration

return configuration


def _bind_parameters(operation: str, parameters: dict) -> str:
"""Helper method that binds parameters to a SQL query."""
Expand Down
38 changes: 19 additions & 19 deletions tests/providers/google/cloud/hooks/test_bigquery.py
Expand Up @@ -1208,7 +1208,7 @@ def test_create_materialized_view(self, mock_bq_client, mock_table):

@pytest.mark.db_test
class TestBigQueryCursor(_BigQueryBaseTestClass):
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.build")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
def test_execute_with_parameters(self, mock_insert, _):
bq_cursor = self.hook.get_cursor()
Expand All @@ -1223,7 +1223,7 @@ def test_execute_with_parameters(self, mock_insert, _):
}
mock_insert.assert_called_once_with(configuration=conf, project_id=PROJECT_ID, location=None)

@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.build")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
def test_execute_many(self, mock_insert, _):
bq_cursor = self.hook.get_cursor()
Expand Down Expand Up @@ -1275,10 +1275,10 @@ def test_format_schema_for_description(self):
("field_3", "STRING", None, None, None, None, False),
]

@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.build")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
def test_description(self, mock_insert, mock_get_service):
mock_get_query_results = mock_get_service.return_value.jobs.return_value.getQueryResults
def test_description(self, mock_insert, mock_build):
mock_get_query_results = mock_build.return_value.jobs.return_value.getQueryResults
mock_execute = mock_get_query_results.return_value.execute
mock_execute.return_value = {
"schema": {
Expand All @@ -1292,10 +1292,10 @@ def test_description(self, mock_insert, mock_get_service):
bq_cursor.execute("SELECT CURRENT_TIMESTAMP() as ts")
assert bq_cursor.description == [("ts", "TIMESTAMP", None, None, None, None, True)]

@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.build")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
def test_description_no_schema(self, mock_insert, mock_get_service):
mock_get_query_results = mock_get_service.return_value.jobs.return_value.getQueryResults
def test_description_no_schema(self, mock_insert, mock_build):
mock_get_query_results = mock_build.return_value.jobs.return_value.getQueryResults
mock_execute = mock_get_query_results.return_value.execute
mock_execute.return_value = {}

Expand Down Expand Up @@ -1369,9 +1369,9 @@ def test_next_buffer(self, mock_get_service):
result = bq_cursor.next()
assert result is None

@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
def test_next(self, mock_get_service):
mock_get_query_results = mock_get_service.return_value.jobs.return_value.getQueryResults
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.build")
def test_next(self, mock_build):
mock_get_query_results = mock_build.return_value.jobs.return_value.getQueryResults
mock_execute = mock_get_query_results.return_value.execute
mock_execute.return_value = {
"rows": [
Expand Down Expand Up @@ -1402,10 +1402,10 @@ def test_next(self, mock_get_service):
)
mock_execute.assert_called_once_with(num_retries=bq_cursor.num_retries)

@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.build")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryCursor.flush_results")
def test_next_no_rows(self, mock_flush_results, mock_get_service):
mock_get_query_results = mock_get_service.return_value.jobs.return_value.getQueryResults
def test_next_no_rows(self, mock_flush_results, mock_build):
mock_get_query_results = mock_build.return_value.jobs.return_value.getQueryResults
mock_execute = mock_get_query_results.return_value.execute
mock_execute.return_value = {}

Expand All @@ -1421,10 +1421,10 @@ def test_next_no_rows(self, mock_flush_results, mock_get_service):
mock_execute.assert_called_once_with(num_retries=bq_cursor.num_retries)
assert mock_flush_results.call_count == 1

@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.build")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryCursor.flush_results")
def test_flush_cursor_in_execute(self, _, mock_insert, mock_get_service):
def test_flush_cursor_in_execute(self, _, mock_insert, mock_build):
bq_cursor = self.hook.get_cursor()
bq_cursor.execute("SELECT %(foo)s", {"foo": "bar"})
assert mock_insert.call_count == 1
Expand Down Expand Up @@ -1786,7 +1786,7 @@ def test_run_query_with_arg(self, mock_insert):
class TestBigQueryHookLegacySql(_BigQueryBaseTestClass):
"""Ensure `use_legacy_sql` param in `BigQueryHook` propagates properly."""

@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.build")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
def test_hook_uses_legacy_sql_by_default(self, mock_insert, _):
self.hook.get_first("query")
Expand All @@ -1797,10 +1797,10 @@ def test_hook_uses_legacy_sql_by_default(self, mock_insert, _):
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.get_credentials_and_project_id",
return_value=(CREDENTIALS, PROJECT_ID),
)
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.build")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
def test_legacy_sql_override_propagates_properly(
self, mock_insert, mock_get_service, mock_get_creds_and_proj_id
self, mock_insert, mock_build, mock_get_creds_and_proj_id
):
bq_hook = BigQueryHook(use_legacy_sql=False)
bq_hook.get_first("query")
Expand Down

0 comments on commit 0c6fd5b

Please sign in to comment.