From 511d0ee256b819690ccf0f6b30d12340b1dd7f0a Mon Sep 17 00:00:00 2001 From: Wojciech Januszek Date: Sat, 30 Apr 2022 20:34:39 +0200 Subject: [PATCH] Bigquery assets (#23165) --- .../providers/google/cloud/hooks/bigquery.py | 19 ++-- .../providers/google/cloud/links/bigquery.py | 77 +++++++++++++ .../google/cloud/operators/bigquery.py | 102 ++++++++++++++++-- .../cloud/transfers/bigquery_to_bigquery.py | 14 ++- .../google/cloud/transfers/bigquery_to_gcs.py | 15 ++- .../cloud/transfers/bigquery_to_mssql.py | 10 ++ airflow/providers/google/provider.yaml | 2 + .../google/cloud/hooks/test_bigquery.py | 10 -- .../google/cloud/operators/test_bigquery.py | 30 +++--- .../transfers/test_bigquery_to_bigquery.py | 2 +- .../cloud/transfers/test_bigquery_to_gcs.py | 2 +- .../cloud/transfers/test_bigquery_to_mssql.py | 2 +- 12 files changed, 240 insertions(+), 45 deletions(-) create mode 100644 airflow/providers/google/cloud/links/bigquery.py diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py index 339cd74dee27e..d4f54f56cef09 100644 --- a/airflow/providers/google/cloud/hooks/bigquery.py +++ b/airflow/providers/google/cloud/hooks/bigquery.py @@ -408,7 +408,7 @@ def create_empty_dataset( location: Optional[str] = None, dataset_reference: Optional[Dict[str, Any]] = None, exists_ok: bool = True, - ) -> None: + ) -> Dict[str, Any]: """ Create a new empty dataset: https://cloud.google.com/bigquery/docs/reference/rest/v2/datasets/insert @@ -452,8 +452,11 @@ def create_empty_dataset( dataset: Dataset = Dataset.from_api_repr(dataset_reference) self.log.info('Creating dataset: %s in project: %s ', dataset.dataset_id, dataset.project) - self.get_client(location=location).create_dataset(dataset=dataset, exists_ok=exists_ok) + dataset_object = self.get_client(location=location).create_dataset( + dataset=dataset, exists_ok=exists_ok + ) self.log.info('Dataset created successfully.') + return dataset_object.to_api_repr() @GoogleBaseHook.fallback_to_default_project_id def get_dataset_tables( @@ -533,7 +536,7 @@ def create_external_table( encryption_configuration: Optional[Dict] = None, location: Optional[str] = None, project_id: Optional[str] = None, - ) -> None: + ) -> Table: """ Creates a new external table in the dataset with the data from Google Cloud Storage. See here: @@ -659,10 +662,11 @@ def create_external_table( table.encryption_configuration = EncryptionConfiguration.from_api_repr(encryption_configuration) self.log.info('Creating external table: %s', external_project_dataset_table) - self.create_empty_table( + table_object = self.create_empty_table( table_resource=table.to_api_repr(), project_id=project_id, location=location, exists_ok=True ) self.log.info('External table created successfully: %s', external_project_dataset_table) + return table_object @GoogleBaseHook.fallback_to_default_project_id def update_table( @@ -1287,7 +1291,7 @@ def update_table_schema( dataset_id: str, table_id: str, project_id: Optional[str] = None, - ) -> None: + ) -> Dict[str, Any]: """ Update fields within a schema for a given dataset and table. 
Note that some fields in schemas are immutable and trying to change them will cause @@ -1361,13 +1365,14 @@ def _remove_policy_tags(schema: List[Dict[str, Any]]): if not include_policy_tags: _remove_policy_tags(new_schema) - self.update_table( + table = self.update_table( table_resource={"schema": {"fields": new_schema}}, fields=["schema"], project_id=project_id, dataset_id=dataset_id, table_id=table_id, ) + return table @GoogleBaseHook.fallback_to_default_project_id def poll_job_complete( @@ -2244,7 +2249,7 @@ def create_empty_table(self, *args, **kwargs) -> None: ) return self.hook.create_empty_table(*args, **kwargs) - def create_empty_dataset(self, *args, **kwargs) -> None: + def create_empty_dataset(self, *args, **kwargs) -> Dict[str, Any]: """ This method is deprecated. Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.create_empty_dataset` diff --git a/airflow/providers/google/cloud/links/bigquery.py b/airflow/providers/google/cloud/links/bigquery.py new file mode 100644 index 0000000000000..a80818e2034ed --- /dev/null +++ b/airflow/providers/google/cloud/links/bigquery.py @@ -0,0 +1,77 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""This module contains Google BigQuery links.""" +from typing import TYPE_CHECKING + +from airflow.models import BaseOperator +from airflow.providers.google.cloud.links.base import BaseGoogleLink + +if TYPE_CHECKING: + from airflow.utils.context import Context + +BIGQUERY_BASE_LINK = "https://console.cloud.google.com/bigquery" +BIGQUERY_DATASET_LINK = ( + BIGQUERY_BASE_LINK + "?referrer=search&project={project_id}&d={dataset_id}&p={project_id}&page=dataset" +) +BIGQUERY_TABLE_LINK = ( + BIGQUERY_BASE_LINK + + "?referrer=search&project={project_id}&d={dataset_id}&p={project_id}&page=table&t={table_id}" +) + + +class BigQueryDatasetLink(BaseGoogleLink): + """Helper class for constructing BigQuery Dataset Link""" + + name = "BigQuery Dataset" + key = "bigquery_dataset" + format_str = BIGQUERY_DATASET_LINK + + @staticmethod + def persist( + context: "Context", + task_instance: BaseOperator, + dataset_id: str, + project_id: str, + ): + task_instance.xcom_push( + context, + key=BigQueryDatasetLink.key, + value={"dataset_id": dataset_id, "project_id": project_id}, + ) + + +class BigQueryTableLink(BaseGoogleLink): + """Helper class for constructing BigQuery Table Link""" + + name = "BigQuery Table" + key = "bigquery_table" + format_str = BIGQUERY_TABLE_LINK + + @staticmethod + def persist( + context: "Context", + task_instance: BaseOperator, + dataset_id: str, + project_id: str, + table_id: str, + ): + task_instance.xcom_push( + context, + key=BigQueryTableLink.key, + value={"dataset_id": dataset_id, "project_id": project_id, "table_id": table_id}, + ) diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py index 8eb5c67b86985..ecd42a576b653 100644 --- a/airflow/providers/google/cloud/operators/bigquery.py +++ b/airflow/providers/google/cloud/operators/bigquery.py @@ -38,6 +38,7 @@ from airflow.operators.sql import SQLCheckOperator, SQLIntervalCheckOperator, SQLValueCheckOperator from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, BigQueryJob from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url +from airflow.providers.google.cloud.links.bigquery import BigQueryDatasetLink, BigQueryTableLink if TYPE_CHECKING: from airflow.models.taskinstance import TaskInstanceKey @@ -798,6 +799,7 @@ class BigQueryCreateEmptyTableOperator(BaseOperator): ) template_fields_renderers = {"table_resource": "json", "materialized_view": "json"} ui_color = BigQueryUIColors.TABLE.value + operator_extra_links = (BigQueryTableLink(),) def __init__( self, @@ -879,6 +881,13 @@ def execute(self, context: 'Context') -> None: table_resource=self.table_resource, exists_ok=self.exists_ok, ) + BigQueryTableLink.persist( + context=context, + task_instance=self, + dataset_id=table.to_api_repr()["tableReference"]["datasetId"], + project_id=table.to_api_repr()["tableReference"]["projectId"], + table_id=table.to_api_repr()["tableReference"]["tableId"], + ) self.log.info( 'Table %s.%s.%s created successfully', table.project, table.dataset_id, table.table_id ) @@ -977,6 +986,7 @@ class BigQueryCreateExternalTableOperator(BaseOperator): ) template_fields_renderers = {"table_resource": "json"} ui_color = BigQueryUIColors.TABLE.value + operator_extra_links = (BigQueryTableLink(),) def __init__( self, @@ -1094,9 +1104,16 @@ def execute(self, context: 'Context') -> None: impersonation_chain=self.impersonation_chain, ) if self.table_resource: - bq_hook.create_empty_table( + table = bq_hook.create_empty_table( 
table_resource=self.table_resource, ) + BigQueryTableLink.persist( + context=context, + task_instance=self, + dataset_id=table.to_api_repr()["tableReference"]["datasetId"], + project_id=table.to_api_repr()["tableReference"]["projectId"], + table_id=table.to_api_repr()["tableReference"]["tableId"], + ) return if not self.schema_fields and self.schema_object and self.source_format != 'DATASTORE_BACKUP': @@ -1111,7 +1128,7 @@ def execute(self, context: 'Context') -> None: source_uris = [f"gs://{self.bucket}/{source_object}" for source_object in self.source_objects] - bq_hook.create_external_table( + table = bq_hook.create_external_table( external_project_dataset_table=self.destination_project_dataset_table, schema_fields=schema_fields, source_uris=source_uris, @@ -1128,6 +1145,13 @@ def execute(self, context: 'Context') -> None: labels=self.labels, encryption_configuration=self.encryption_configuration, ) + BigQueryTableLink.persist( + context=context, + task_instance=self, + dataset_id=table.to_api_repr()["tableReference"]["datasetId"], + project_id=table.to_api_repr()["tableReference"]["projectId"], + table_id=table.to_api_repr()["tableReference"]["tableId"], + ) class BigQueryDeleteDatasetOperator(BaseOperator): @@ -1257,6 +1281,7 @@ class BigQueryCreateEmptyDatasetOperator(BaseOperator): ) template_fields_renderers = {"dataset_reference": "json"} ui_color = BigQueryUIColors.DATASET.value + operator_extra_links = (BigQueryDatasetLink(),) def __init__( self, @@ -1292,13 +1317,19 @@ def execute(self, context: 'Context') -> None: ) try: - bq_hook.create_empty_dataset( + dataset = bq_hook.create_empty_dataset( project_id=self.project_id, dataset_id=self.dataset_id, dataset_reference=self.dataset_reference, location=self.location, exists_ok=self.exists_ok, ) + BigQueryDatasetLink.persist( + context=context, + task_instance=self, + dataset_id=dataset["datasetReference"]["datasetId"], + project_id=dataset["datasetReference"]["projectId"], + ) except Conflict: dataset_id = self.dataset_reference.get("datasetReference", {}).get("datasetId", self.dataset_id) self.log.info('Dataset %s already exists.', dataset_id) @@ -1339,6 +1370,7 @@ class BigQueryGetDatasetOperator(BaseOperator): 'impersonation_chain', ) ui_color = BigQueryUIColors.DATASET.value + operator_extra_links = (BigQueryDatasetLink(),) def __init__( self, @@ -1367,7 +1399,14 @@ def execute(self, context: 'Context'): self.log.info('Start getting dataset: %s:%s', self.project_id, self.dataset_id) dataset = bq_hook.get_dataset(dataset_id=self.dataset_id, project_id=self.project_id) - return dataset.to_api_repr() + dataset = dataset.to_api_repr() + BigQueryDatasetLink.persist( + context=context, + task_instance=self, + dataset_id=dataset["datasetReference"]["datasetId"], + project_id=dataset["datasetReference"]["projectId"], + ) + return dataset class BigQueryGetDatasetTablesOperator(BaseOperator): @@ -1558,6 +1597,7 @@ class BigQueryUpdateTableOperator(BaseOperator): ) template_fields_renderers = {"table_resource": "json"} ui_color = BigQueryUIColors.TABLE.value + operator_extra_links = (BigQueryTableLink(),) def __init__( self, @@ -1589,7 +1629,7 @@ def execute(self, context: 'Context'): impersonation_chain=self.impersonation_chain, ) - return bq_hook.update_table( + table = bq_hook.update_table( table_resource=self.table_resource, fields=self.fields, dataset_id=self.dataset_id, @@ -1597,6 +1637,16 @@ def execute(self, context: 'Context'): project_id=self.project_id, ) + BigQueryTableLink.persist( + context=context, + task_instance=self, + 
dataset_id=table["tableReference"]["datasetId"], + project_id=table["tableReference"]["projectId"], + table_id=table["tableReference"]["tableId"], + ) + + return table + class BigQueryUpdateDatasetOperator(BaseOperator): """ @@ -1641,6 +1691,7 @@ class BigQueryUpdateDatasetOperator(BaseOperator): ) template_fields_renderers = {"dataset_resource": "json"} ui_color = BigQueryUIColors.DATASET.value + operator_extra_links = (BigQueryDatasetLink(),) def __init__( self, @@ -1677,7 +1728,15 @@ def execute(self, context: 'Context'): dataset_id=self.dataset_id, fields=fields, ) - return dataset.to_api_repr() + + dataset = dataset.to_api_repr() + BigQueryDatasetLink.persist( + context=context, + task_instance=self, + dataset_id=dataset["datasetReference"]["datasetId"], + project_id=dataset["datasetReference"]["projectId"], + ) + return dataset class BigQueryDeleteTableOperator(BaseOperator): @@ -1782,6 +1841,7 @@ class BigQueryUpsertTableOperator(BaseOperator): ) template_fields_renderers = {"table_resource": "json"} ui_color = BigQueryUIColors.TABLE.value + operator_extra_links = (BigQueryTableLink(),) def __init__( self, @@ -1813,11 +1873,18 @@ def execute(self, context: 'Context') -> None: location=self.location, impersonation_chain=self.impersonation_chain, ) - hook.run_table_upsert( + table = hook.run_table_upsert( dataset_id=self.dataset_id, table_resource=self.table_resource, project_id=self.project_id, ) + BigQueryTableLink.persist( + context=context, + task_instance=self, + dataset_id=table["tableReference"]["datasetId"], + project_id=table["tableReference"]["projectId"], + table_id=table["tableReference"]["tableId"], + ) class BigQueryUpdateTableSchemaOperator(BaseOperator): @@ -1879,6 +1946,7 @@ class BigQueryUpdateTableSchemaOperator(BaseOperator): ) template_fields_renderers = {"schema_fields_updates": "json"} ui_color = BigQueryUIColors.TABLE.value + operator_extra_links = (BigQueryTableLink(),) def __init__( self, @@ -1910,7 +1978,7 @@ def execute(self, context: 'Context'): impersonation_chain=self.impersonation_chain, ) - return bq_hook.update_table_schema( + table = bq_hook.update_table_schema( schema_fields_updates=self.schema_fields_updates, include_policy_tags=self.include_policy_tags, dataset_id=self.dataset_id, @@ -1918,6 +1986,15 @@ def execute(self, context: 'Context'): project_id=self.project_id, ) + BigQueryTableLink.persist( + context=context, + task_instance=self, + dataset_id=table["tableReference"]["datasetId"], + project_id=table["tableReference"]["projectId"], + table_id=table["tableReference"]["tableId"], + ) + return table + class BigQueryInsertJobOperator(BaseOperator): """ @@ -1983,6 +2060,7 @@ class BigQueryInsertJobOperator(BaseOperator): ) template_fields_renderers = {"configuration": "json", "configuration.query.query": "sql"} ui_color = BigQueryUIColors.QUERY.value + operator_extra_links = (BigQueryTableLink(),) def __init__( self, @@ -2088,6 +2166,14 @@ def execute(self, context: Any): f"Or, if you want to reattach in this scenario add {job.state} to `reattach_states`" ) + table = job.to_api_repr()["configuration"]["query"]["destinationTable"] + BigQueryTableLink.persist( + context=context, + task_instance=self, + dataset_id=table["datasetId"], + project_id=table["projectId"], + table_id=table["tableId"], + ) self.job_id = job.job_id return job.job_id diff --git a/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py b/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py index b81d3477464e1..87e6f3f586f4a 100644 --- 
a/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py +++ b/airflow/providers/google/cloud/transfers/bigquery_to_bigquery.py @@ -21,6 +21,7 @@ from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook +from airflow.providers.google.cloud.links.bigquery import BigQueryTableLink if TYPE_CHECKING: from airflow.utils.context import Context @@ -74,6 +75,7 @@ class BigQueryToBigQueryOperator(BaseOperator): ) template_ext: Sequence[str] = ('.sql',) ui_color = '#e6f0e4' + operator_extra_links = (BigQueryTableLink(),) def __init__( self, @@ -118,7 +120,7 @@ def execute(self, context: 'Context') -> None: with warnings.catch_warnings(): warnings.simplefilter("ignore", DeprecationWarning) - hook.run_copy( + job_id = hook.run_copy( source_project_dataset_tables=self.source_project_dataset_tables, destination_project_dataset_table=self.destination_project_dataset_table, write_disposition=self.write_disposition, @@ -126,3 +128,13 @@ def execute(self, context: 'Context') -> None: labels=self.labels, encryption_configuration=self.encryption_configuration, ) + + job = hook.get_job(job_id=job_id).to_api_repr() + conf = job["configuration"]["copy"]["destinationTable"] + BigQueryTableLink.persist( + context=context, + task_instance=self, + dataset_id=conf["datasetId"], + project_id=conf["projectId"], + table_id=conf["tableId"], + ) diff --git a/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py b/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py index a0d7b3cfff88e..09ac190e0f269 100644 --- a/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/bigquery_to_gcs.py @@ -20,6 +20,7 @@ from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook +from airflow.providers.google.cloud.links.bigquery import BigQueryTableLink if TYPE_CHECKING: from airflow.utils.context import Context @@ -70,6 +71,7 @@ class BigQueryToGCSOperator(BaseOperator): ) template_ext: Sequence[str] = () ui_color = '#e4e6f0' + operator_extra_links = (BigQueryTableLink(),) def __init__( self, @@ -113,7 +115,7 @@ def execute(self, context: 'Context'): location=self.location, impersonation_chain=self.impersonation_chain, ) - hook.run_extract( + job_id = hook.run_extract( source_project_dataset_table=self.source_project_dataset_table, destination_cloud_storage_uris=self.destination_cloud_storage_uris, compression=self.compression, @@ -122,3 +124,14 @@ def execute(self, context: 'Context'): print_header=self.print_header, labels=self.labels, ) + + job = hook.get_job(job_id=job_id).to_api_repr() + conf = job["configuration"]["extract"]["sourceTable"] + dataset_id, project_id, table_id = conf["datasetId"], conf["projectId"], conf["tableId"] + BigQueryTableLink.persist( + context=context, + task_instance=self, + dataset_id=dataset_id, + project_id=project_id, + table_id=table_id, + ) diff --git a/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py b/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py index ca63ff0d99b50..d8a600eabeb5f 100644 --- a/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py +++ b/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py @@ -20,6 +20,7 @@ from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook +from airflow.providers.google.cloud.links.bigquery import BigQueryTableLink from airflow.providers.google.cloud.utils.bigquery_get_data import 
bigquery_get_data from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook @@ -75,6 +76,7 @@ class BigQueryToMsSqlOperator(BaseOperator): """ template_fields: Sequence[str] = ('source_project_dataset_table', 'mssql_table', 'impersonation_chain') + operator_extra_links = (BigQueryTableLink(),) def __init__( self, @@ -118,6 +120,14 @@ def execute(self, context: 'Context') -> None: location=self.location, impersonation_chain=self.impersonation_chain, ) + project_id, dataset_id, table_id = self.source_project_dataset_table.split('.') + BigQueryTableLink.persist( + context=context, + task_instance=self, + dataset_id=dataset_id, + project_id=project_id, + table_id=table_id, + ) mssql_hook = MsSqlHook(mssql_conn_id=self.mssql_conn_id, schema=self.database) for rows in bigquery_get_data( self.log, diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml index 35524343d2563..4089cefd99d74 100644 --- a/airflow/providers/google/provider.yaml +++ b/airflow/providers/google/provider.yaml @@ -884,6 +884,8 @@ extra-links: - airflow.providers.google.cloud.operators.datafusion.DataFusionPipelinesLink - airflow.providers.google.cloud.links.dataplex.DataplexTaskLink - airflow.providers.google.cloud.links.dataplex.DataplexTasksLink + - airflow.providers.google.cloud.links.bigquery.BigQueryDatasetLink + - airflow.providers.google.cloud.links.bigquery.BigQueryTableLink - airflow.providers.google.cloud.links.bigquery_dts.BigQueryDataTransferConfigLink - airflow.providers.google.cloud.links.dataproc.DataprocLink - airflow.providers.google.cloud.links.dataproc.DataprocListLink diff --git a/tests/providers/google/cloud/hooks/test_bigquery.py b/tests/providers/google/cloud/hooks/test_bigquery.py index 5f0223931892c..9de8333eb4c78 100644 --- a/tests/providers/google/cloud/hooks/test_bigquery.py +++ b/tests/providers/google/cloud/hooks/test_bigquery.py @@ -918,16 +918,6 @@ def test_insert_job(self, mock_client, mock_query_job, nowait): def test_dbapi_get_uri(self): assert self.hook.get_uri().startswith('bigquery://') - def test_dbapi_get_sqlalchemy_engine_failed(self): - with pytest.raises( - AirflowException, - match="For now, we only support instantiating SQLAlchemy engine by" - " using ADC" - ", extra__google_cloud_platform__key_path" - "and extra__google_cloud_platform__keyfile_dict", - ): - self.hook.get_sqlalchemy_engine() - class TestBigQueryTableSplitter(unittest.TestCase): def test_internal_need_default_project(self): diff --git a/tests/providers/google/cloud/operators/test_bigquery.py b/tests/providers/google/cloud/operators/test_bigquery.py index 20bb205bb0f25..b5e42cce2f96a 100644 --- a/tests/providers/google/cloud/operators/test_bigquery.py +++ b/tests/providers/google/cloud/operators/test_bigquery.py @@ -82,7 +82,7 @@ def test_execute(self, mock_hook): task_id=TASK_ID, dataset_id=TEST_DATASET, project_id=TEST_GCP_PROJECT_ID, table_id=TEST_TABLE_ID ) - operator.execute(None) + operator.execute(context=MagicMock()) mock_hook.return_value.create_empty_table.assert_called_once_with( dataset_id=TEST_DATASET, project_id=TEST_GCP_PROJECT_ID, @@ -108,7 +108,7 @@ def test_create_view(self, mock_hook): view=VIEW_DEFINITION, ) - operator.execute(None) + operator.execute(context=MagicMock()) mock_hook.return_value.create_empty_table.assert_called_once_with( dataset_id=TEST_DATASET, project_id=TEST_GCP_PROJECT_ID, @@ -134,7 +134,7 @@ def test_create_materialized_view(self, mock_hook): materialized_view=MATERIALIZED_VIEW_DEFINITION, ) - operator.execute(None) + 
operator.execute(context=MagicMock()) mock_hook.return_value.create_empty_table.assert_called_once_with( dataset_id=TEST_DATASET, project_id=TEST_GCP_PROJECT_ID, @@ -170,7 +170,7 @@ def test_create_clustered_empty_table(self, mock_hook): cluster_fields=cluster_fields, ) - operator.execute(None) + operator.execute(context=MagicMock()) mock_hook.return_value.create_empty_table.assert_called_once_with( dataset_id=TEST_DATASET, project_id=TEST_GCP_PROJECT_ID, @@ -200,7 +200,7 @@ def test_execute(self, mock_hook): autodetect=True, ) - operator.execute(None) + operator.execute(context=MagicMock()) mock_hook.return_value.create_external_table.assert_called_once_with( external_project_dataset_table=f'{TEST_DATASET}.{TEST_TABLE_ID}', schema_fields=[], @@ -246,7 +246,7 @@ def test_execute(self, mock_hook): location=TEST_DATASET_LOCATION, ) - operator.execute(None) + operator.execute(context=MagicMock()) mock_hook.return_value.create_empty_dataset.assert_called_once_with( dataset_id=TEST_DATASET, project_id=TEST_GCP_PROJECT_ID, @@ -263,7 +263,7 @@ def test_execute(self, mock_hook): task_id=TASK_ID, dataset_id=TEST_DATASET, project_id=TEST_GCP_PROJECT_ID ) - operator.execute(None) + operator.execute(context=MagicMock()) mock_hook.return_value.get_dataset.assert_called_once_with( dataset_id=TEST_DATASET, project_id=TEST_GCP_PROJECT_ID ) @@ -281,7 +281,7 @@ def test_execute(self, mock_hook): project_id=TEST_GCP_PROJECT_ID, ) - operator.execute(None) + operator.execute(context=MagicMock()) mock_hook.return_value.update_table.assert_called_once_with( table_resource=table_resource, fields=None, @@ -310,7 +310,7 @@ def test_execute(self, mock_hook): table_id=TEST_TABLE_ID, project_id=TEST_GCP_PROJECT_ID, ) - operator.execute(None) + operator.execute(context=MagicMock()) mock_hook.return_value.update_table_schema.assert_called_once_with( schema_fields_updates=schema_field_updates, @@ -349,7 +349,7 @@ def test_execute(self, mock_hook): project_id=TEST_GCP_PROJECT_ID, ) - operator.execute(None) + operator.execute(context=MagicMock()) mock_hook.return_value.update_dataset.assert_called_once_with( dataset_resource=dataset_resource, dataset_id=TEST_DATASET, @@ -779,7 +779,7 @@ def test_execute(self, mock_hook): project_id=TEST_GCP_PROJECT_ID, ) - operator.execute(None) + operator.execute(context=MagicMock()) mock_hook.return_value.run_table_upsert.assert_called_once_with( dataset_id=TEST_DATASET, project_id=TEST_GCP_PROJECT_ID, table_resource=TEST_TABLE_RESOURCES ) @@ -809,7 +809,7 @@ def test_execute_success(self, mock_hook, mock_md5): job_id=job_id, project_id=TEST_GCP_PROJECT_ID, ) - result = op.execute({}) + result = op.execute(context=MagicMock()) mock_hook.return_value.insert_job.assert_called_once_with( configuration=configuration, @@ -846,7 +846,7 @@ def test_on_kill(self, mock_hook, mock_md5): project_id=TEST_GCP_PROJECT_ID, cancel_on_kill=False, ) - op.execute({}) + op.execute(context=MagicMock()) op.on_kill() mock_hook.return_value.cancel_job.assert_not_called() @@ -917,7 +917,7 @@ def test_execute_reattach(self, mock_hook, mock_md5): project_id=TEST_GCP_PROJECT_ID, reattach_states={"PENDING"}, ) - result = op.execute({}) + result = op.execute(context=MagicMock()) mock_hook.return_value.get_job.assert_called_once_with( location=TEST_DATASET_LOCATION, @@ -962,7 +962,7 @@ def test_execute_force_rerun(self, mock_hook, mock_uuid, mock_md5): project_id=TEST_GCP_PROJECT_ID, force_rerun=True, ) - result = op.execute({}) + result = op.execute(context=MagicMock()) 
mock_hook.return_value.insert_job.assert_called_once_with( configuration=configuration, diff --git a/tests/providers/google/cloud/transfers/test_bigquery_to_bigquery.py b/tests/providers/google/cloud/transfers/test_bigquery_to_bigquery.py index a3995461ffb78..109cafc4a3e7a 100644 --- a/tests/providers/google/cloud/transfers/test_bigquery_to_bigquery.py +++ b/tests/providers/google/cloud/transfers/test_bigquery_to_bigquery.py @@ -46,7 +46,7 @@ def test_execute(self, mock_hook): encryption_configuration=encryption_configuration, ) - operator.execute(None) + operator.execute(context=mock.MagicMock()) mock_hook.return_value.run_copy.assert_called_once_with( source_project_dataset_tables=source_project_dataset_tables, destination_project_dataset_table=destination_project_dataset_table, diff --git a/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py b/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py index 4542172649770..5ed9b660310ab 100644 --- a/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py +++ b/tests/providers/google/cloud/transfers/test_bigquery_to_gcs.py @@ -49,7 +49,7 @@ def test_execute(self, mock_hook): labels=labels, ) - operator.execute(None) + operator.execute(context=mock.MagicMock()) mock_hook.return_value.run_extract.assert_called_once_with( source_project_dataset_table=source_project_dataset_table, diff --git a/tests/providers/google/cloud/transfers/test_bigquery_to_mssql.py b/tests/providers/google/cloud/transfers/test_bigquery_to_mssql.py index 5f88c0cedff66..fdae853810fd5 100644 --- a/tests/providers/google/cloud/transfers/test_bigquery_to_mssql.py +++ b/tests/providers/google/cloud/transfers/test_bigquery_to_mssql.py @@ -38,7 +38,7 @@ def test_execute_good_request_to_bq(self, mock_hook): replace=False, ) - operator.execute(None) + operator.execute(context=mock.MagicMock()) # fmt: off mock_hook.return_value.list_rows.assert_called_once_with( dataset_id=TEST_DATASET,
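---

Reviewer note, not part of the patch: a quick illustration of what the new extra links do at runtime. This is a minimal sketch under assumed names — the DAG id, project, dataset, and table ids below are hypothetical placeholders, not values from the diff. What it relies on is in the patch itself: on success, each operator's `BigQueryDatasetLink.persist()` / `BigQueryTableLink.persist()` call XCom-pushes the link payload under key `bigquery_dataset` or `bigquery_table`, and the webserver formats the console URL from `BIGQUERY_DATASET_LINK` / `BIGQUERY_TABLE_LINK` in the new `links/bigquery.py` module.

    # Minimal sketch, assuming Airflow 2.x with this provider change applied.
    # PROJECT_ID / DATASET_NAME / table_id are illustrative placeholders.
    from datetime import datetime

    from airflow import DAG
    from airflow.providers.google.cloud.operators.bigquery import (
        BigQueryCreateEmptyDatasetOperator,
        BigQueryCreateEmptyTableOperator,
    )

    PROJECT_ID = "my-gcp-project"  # hypothetical
    DATASET_NAME = "my_dataset"    # hypothetical

    with DAG(
        dag_id="bigquery_links_demo",
        start_date=datetime(2022, 5, 1),
        schedule_interval=None,
        catchup=False,
    ) as dag:
        # After execute(), BigQueryDatasetLink.persist() pushes
        # {"dataset_id": ..., "project_id": ...} to XCom under key
        # "bigquery_dataset"; the UI renders it as the "BigQuery Dataset" link.
        create_dataset = BigQueryCreateEmptyDatasetOperator(
            task_id="create_dataset",
            project_id=PROJECT_ID,
            dataset_id=DATASET_NAME,
        )

        # Likewise, BigQueryTableLink.persist() pushes dataset/project/table
        # ids under key "bigquery_table" for the "BigQuery Table" extra link.
        create_table = BigQueryCreateEmptyTableOperator(
            task_id="create_table",
            project_id=PROJECT_ID,
            dataset_id=DATASET_NAME,
            table_id="my_table",
        )

        create_dataset >> create_table

This also shows why the hooks now return the created resources (`Dict[str, Any]` / `Table`) instead of `None`: the operators need the resolved `datasetReference` / `tableReference` back from the API to fill in the link payload, rather than trusting their own possibly-templated inputs.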