diff --git a/airflow/providers/google/cloud/hooks/bigquery.py b/airflow/providers/google/cloud/hooks/bigquery.py index 0594ce4351884..f270e256fbf75 100644 --- a/airflow/providers/google/cloud/hooks/bigquery.py +++ b/airflow/providers/google/cloud/hooks/bigquery.py @@ -3388,6 +3388,31 @@ async def create_job_for_partition_get( job_query_resp = await job_client.query(query_request, cast(Session, session)) return job_query_resp["jobReference"]["jobId"] + async def cancel_job(self, job_id: str, project_id: str | None, location: str | None) -> None: + """ + Cancel a BigQuery job. + + :param job_id: ID of the job to cancel. + :param project_id: Google Cloud Project where the job was running. + :param location: Location where the job was running. + """ + async with ClientSession() as session: + token = await self.get_token(session=session) + job = Job(job_id=job_id, project=project_id, location=location, token=token, session=session) # type: ignore[arg-type] + + self.log.info( + "Attempting to cancel BigQuery job: %s in project: %s, location: %s", + job_id, + project_id, + location, + ) + try: + await job.cancel() + self.log.info("Job %s cancellation requested.", job_id) + except Exception as e: + self.log.error("Failed to cancel BigQuery job %s: %s", job_id, str(e)) + raise + def get_records(self, query_results: dict[str, Any], as_dict: bool = False) -> list[Any]: """Convert a response from BigQuery to records. diff --git a/airflow/providers/google/cloud/operators/bigquery.py b/airflow/providers/google/cloud/operators/bigquery.py index 68b423fb4693c..9da97afc2a2d7 100644 --- a/airflow/providers/google/cloud/operators/bigquery.py +++ b/airflow/providers/google/cloud/operators/bigquery.py @@ -2903,6 +2903,7 @@ def execute(self, context: Any): location=self.location or hook.location, poll_interval=self.poll_interval, impersonation_chain=self.impersonation_chain, + cancel_on_kill=self.cancel_on_kill, ), method_name="execute_complete", ) diff --git a/airflow/providers/google/cloud/triggers/bigquery.py b/airflow/providers/google/cloud/triggers/bigquery.py index eafa4825bec2e..fd0170526199d 100644 --- a/airflow/providers/google/cloud/triggers/bigquery.py +++ b/airflow/providers/google/cloud/triggers/bigquery.py @@ -57,6 +57,7 @@ def __init__( table_id: str | None = None, poll_interval: float = 4.0, impersonation_chain: str | Sequence[str] | None = None, + cancel_on_kill: bool = True, ): super().__init__() self.log.info("Using the connection %s .", conn_id) @@ -69,6 +70,7 @@ def __init__( self.table_id = table_id self.poll_interval = poll_interval self.impersonation_chain = impersonation_chain + self.cancel_on_kill = cancel_on_kill def serialize(self) -> tuple[str, dict[str, Any]]: """Serialize BigQueryInsertJobTrigger arguments and classpath.""" @@ -83,6 +85,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "table_id": self.table_id, "poll_interval": self.poll_interval, "impersonation_chain": self.impersonation_chain, + "cancel_on_kill": self.cancel_on_kill, }, ) @@ -113,6 +116,14 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] self.poll_interval, ) await asyncio.sleep(self.poll_interval) + except asyncio.CancelledError: + self.log.info("Task was killed.") + if self.job_id and self.cancel_on_kill: + await hook.cancel_job( # type: ignore[union-attr] + job_id=self.job_id, project_id=self.project_id, location=self.location + ) + else: + self.log.info("Skipping to cancel job: %s:%s.%s", self.project_id, self.location, self.job_id) except Exception as e: self.log.exception("Exception occurred while checking for query completion") yield TriggerEvent({"status": "error", "message": str(e)}) diff --git a/tests/providers/google/cloud/hooks/test_bigquery.py b/tests/providers/google/cloud/hooks/test_bigquery.py index 91181162878a3..37096b0ff3bde 100644 --- a/tests/providers/google/cloud/hooks/test_bigquery.py +++ b/tests/providers/google/cloud/hooks/test_bigquery.py @@ -2190,6 +2190,55 @@ async def test_get_job_output_assert_once_with(self, mock_job_instance): resp = await hook.get_job_output(job_id=JOB_ID, project_id=PROJECT_ID) assert resp == response + @pytest.mark.asyncio + @pytest.mark.db_test + @mock.patch("google.auth.default") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Job") + async def test_cancel_job_success(self, mock_job, mock_auth_default): + mock_credentials = mock.MagicMock(spec=google.auth.compute_engine.Credentials) + mock_credentials.token = "ACCESS_TOKEN" + mock_auth_default.return_value = (mock_credentials, PROJECT_ID) + job_id = "test_job_id" + project_id = "test_project" + location = "US" + + mock_job_instance = AsyncMock() + mock_job_instance.cancel.return_value = None + mock_job.return_value = mock_job_instance + + await self.hook.cancel_job(job_id=job_id, project_id=project_id, location=location) + + mock_job_instance.cancel.assert_called_once() + + @pytest.mark.asyncio + @pytest.mark.db_test + @mock.patch("google.auth.default") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.Job") + async def test_cancel_job_failure(self, mock_job, mock_auth_default): + """ + Test that BigQueryAsyncHook handles exceptions during job cancellation correctly. + """ + mock_credentials = mock.MagicMock(spec=google.auth.compute_engine.Credentials) + mock_credentials.token = "ACCESS_TOKEN" + mock_auth_default.return_value = (mock_credentials, PROJECT_ID) + + mock_job_instance = AsyncMock() + mock_job_instance.cancel.side_effect = Exception("Cancellation failed") + mock_job.return_value = mock_job_instance + + hook = BigQueryAsyncHook() + + job_id = "test_job_id" + project_id = "test_project" + location = "US" + + with pytest.raises(Exception) as excinfo: + await hook.cancel_job(job_id=job_id, project_id=project_id, location=location) + + assert "Cancellation failed" in str(excinfo.value), "Exception message not passed correctly" + + mock_job_instance.cancel.assert_called_once() + @pytest.mark.asyncio @mock.patch("airflow.providers.google.cloud.hooks.bigquery.ClientSession") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_instance") diff --git a/tests/providers/google/cloud/triggers/test_bigquery.py b/tests/providers/google/cloud/triggers/test_bigquery.py index 9eec245f83203..367b4850de740 100644 --- a/tests/providers/google/cloud/triggers/test_bigquery.py +++ b/tests/providers/google/cloud/triggers/test_bigquery.py @@ -165,6 +165,7 @@ def test_serialization(self, insert_job_trigger): classpath, kwargs = insert_job_trigger.serialize() assert classpath == "airflow.providers.google.cloud.triggers.bigquery.BigQueryInsertJobTrigger" assert kwargs == { + "cancel_on_kill": True, "conn_id": TEST_CONN_ID, "job_id": TEST_JOB_ID, "project_id": TEST_GCP_PROJECT_ID, @@ -233,6 +234,41 @@ async def test_bigquery_op_trigger_exception(self, mock_job_status, caplog, inse actual = await generator.asend(None) assert TriggerEvent({"status": "error", "message": "Test exception"}) == actual + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.cancel_job") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status") + async def test_bigquery_insert_job_trigger_cancellation( + self, mock_get_job_status, mock_cancel_job, caplog, insert_job_trigger + ): + """ + Test that BigQueryInsertJobTrigger handles cancellation correctly, logs the appropriate message, + and conditionally cancels the job based on the `cancel_on_kill` attribute. + """ + insert_job_trigger.cancel_on_kill = True + insert_job_trigger.job_id = "1234" + + mock_get_job_status.side_effect = [ + {"status": "running", "message": "Job is still running"}, + asyncio.CancelledError(), + ] + + mock_cancel_job.return_value = asyncio.Future() + mock_cancel_job.return_value.set_result(None) + + caplog.set_level(logging.INFO) + + try: + async for _ in insert_job_trigger.run(): + pass + except asyncio.CancelledError: + pass + + assert ( + "Task was killed" in caplog.text + or "Bigquery job status is running. Sleeping for 4.0 seconds." in caplog.text + ), "Expected messages about task status or cancellation not found in log." + mock_cancel_job.assert_awaited_once() + class TestBigQueryGetDataTrigger: def test_bigquery_get_data_trigger_serialization(self, get_data_trigger):