Skip to content

Commit

Permalink
Add an option to GlueJobOperator to stop the job run when the TI is…
Browse files Browse the repository at this point in the history
… killed (#32155)

---------

Signed-off-by: Hussein Awala <hussein@awala.fr>
  • Loading branch information
hussein-awala committed Jun 28, 2023
1 parent 98c47f4 commit 1d60332
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 18 deletions.
56 changes: 38 additions & 18 deletions airflow/providers/amazon/aws/operators/glue.py
Expand Up @@ -19,6 +19,7 @@

import os.path
import urllib.parse
from functools import cached_property
from typing import TYPE_CHECKING, Sequence

from airflow import AirflowException
Expand Down Expand Up @@ -60,6 +61,7 @@ class GlueJobOperator(BaseOperator):
(default: False)
:param verbose: If True, Glue Job Run logs show in the Airflow Task Logs. (default: False)
:param update_config: If True, Operator will update job configuration. (default: False)
:param stop_job_run_on_kill: If True, Operator will stop the job run when task is killed.
"""

template_fields: Sequence[str] = (
Expand Down Expand Up @@ -100,6 +102,7 @@ def __init__(
verbose: bool = False,
update_config: bool = False,
job_poll_interval: int | float = 6,
stop_job_run_on_kill: bool = False,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -123,12 +126,11 @@ def __init__(
self.update_config = update_config
self.deferrable = deferrable
self.job_poll_interval = job_poll_interval
self.stop_job_run_on_kill = stop_job_run_on_kill
self._job_run_id: str | None = None

def execute(self, context: Context):
"""Execute AWS Glue Job from Airflow.
:return: the current Glue job ID.
"""
@cached_property
def glue_job_hook(self) -> GlueJobHook:
if self.script_location is None:
s3_script_location = None
elif not self.script_location.startswith(self.s3_protocol):
Expand All @@ -140,7 +142,7 @@ def execute(self, context: Context):
s3_script_location = f"s3://{self.s3_bucket}/{self.s3_artifacts_prefix}{script_name}"
else:
s3_script_location = self.script_location
glue_job = GlueJobHook(
return GlueJobHook(
job_name=self.job_name,
desc=self.job_desc,
concurrent_run_limit=self.concurrent_run_limit,
Expand All @@ -155,52 +157,70 @@ def execute(self, context: Context):
update_config=self.update_config,
job_poll_interval=self.job_poll_interval,
)

def execute(self, context: Context):
    """Execute AWS Glue Job from Airflow.

    Starts a run through ``self.glue_job_hook``, remembers its run ID on the
    operator (so ``on_kill`` can refer to it), persists and logs a link to the
    run, and then either defers, waits for completion, or returns right away.

    :param context: Airflow task context, passed on to ``GlueJobRunDetailsLink.persist``.
    :return: the ID of the Glue job run that was started.
    """
    self.log.info(
        "Initializing AWS Glue Job: %s. Wait for completion: %s",
        self.job_name,
        self.wait_for_completion,
    )
    glue_job_run = self.glue_job_hook.initialize_job(self.script_args, self.run_job_kwargs)
    # Keep the run ID on the instance: it is needed after this method's locals are gone.
    self._job_run_id = glue_job_run["JobRunId"]
    glue_job_run_url = GlueJobRunDetailsLink.format_str.format(
        aws_domain=GlueJobRunDetailsLink.get_aws_domain(self.glue_job_hook.conn_partition),
        region_name=self.glue_job_hook.conn_region_name,
        job_name=urllib.parse.quote(self.job_name, safe=""),
        job_run_id=self._job_run_id,
    )
    GlueJobRunDetailsLink.persist(
        context=context,
        operator=self,
        region_name=self.glue_job_hook.conn_region_name,
        aws_partition=self.glue_job_hook.conn_partition,
        job_name=urllib.parse.quote(self.job_name, safe=""),
        job_run_id=self._job_run_id,
    )
    self.log.info("You can monitor this Glue Job run at: %s", glue_job_run_url)

    if self.deferrable:
        # Hand the wait over to the trigger; execution resumes in ``execute_complete``.
        self.defer(
            trigger=GlueJobCompleteTrigger(
                job_name=self.job_name,
                run_id=self._job_run_id,
                verbose=self.verbose,
                aws_conn_id=self.aws_conn_id,
                job_poll_interval=self.job_poll_interval,
            ),
            method_name="execute_complete",
        )
    elif self.wait_for_completion:
        glue_job_run = self.glue_job_hook.job_completion(self.job_name, self._job_run_id, self.verbose)
        self.log.info(
            "AWS Glue Job: %s status: %s. Run Id: %s",
            self.job_name,
            glue_job_run["JobRunState"],
            self._job_run_id,
        )
    else:
        self.log.info("AWS Glue Job: %s. Run Id: %s", self.job_name, self._job_run_id)
    return self._job_run_id

def execute_complete(self, context, event=None):
    """Return the value reported by the trigger when the deferred task resumes.

    :param context: Airflow task context (not used by this method).
    :param event: payload delivered by the trigger; its ``"status"`` and
        ``"value"`` keys are read here.
    :return: ``event["value"]`` when the status is ``"success"``.
    :raises AirflowException: if no event was delivered or its status is not
        ``"success"``.
    """
    # ``event`` defaults to None; check it explicitly so a missing payload is
    # reported as a task error rather than a TypeError from subscripting None.
    if event is None or event["status"] != "success":
        raise AirflowException(f"Error in glue job: {event}")
    return event["value"]

def on_kill(self):
    """Cancel the running AWS Glue Job.

    Does nothing unless ``stop_job_run_on_kill`` is set. When it is set and a
    run ID has been recorded, asks Glue to stop that run and logs an error if
    the response reports no successful submissions.
    """
    if not self.stop_job_run_on_kill:
        return
    if self._job_run_id is None:
        # No run ID has been recorded on this instance, so there is nothing
        # to stop; sending ``[None]`` as the run IDs would only fail.
        self.log.info("No AWS Glue Job run to stop for job: %s", self.job_name)
        return
    self.log.info("Stopping AWS Glue Job: %s. Run Id: %s", self.job_name, self._job_run_id)
    response = self.glue_job_hook.conn.batch_stop_job_run(
        JobName=self.job_name,
        JobRunIds=[self._job_run_id],
    )
    if not response["SuccessfulSubmissions"]:
        self.log.error("Failed to stop AWS Glue Job: %s. Run Id: %s", self.job_name, self._job_run_id)
43 changes: 43 additions & 0 deletions tests/providers/amazon/aws/operators/test_glue.py
Expand Up @@ -207,3 +207,46 @@ def test_log_correct_url(
assert job_run_id == JOB_RUN_ID

mock_log_info.assert_any_call("You can monitor this Glue Job run at: %s", glue_job_run_url)

@mock.patch.object(GlueJobHook, "conn")
@mock.patch.object(GlueJobHook, "get_conn")
def test_killed_without_stop_job_run_on_kill(self, _, mock_glue_hook):
    """``on_kill`` must not call ``batch_stop_job_run`` when the option is left at its default."""
    operator_kwargs = {
        "task_id": TASK_ID,
        "job_name": JOB_NAME,
        "script_location": "s3://folder/file",
        "aws_conn_id": "aws_default",
        "region_name": "us-west-2",
        "s3_bucket": "some_bucket",
        "iam_role_name": "my_test_role",
    }
    operator = GlueJobOperator(**operator_kwargs)

    operator.on_kill()

    mock_glue_hook.batch_stop_job_run.assert_not_called()

@mock.patch.object(GlueJobHook, "conn")
@mock.patch.object(GlueJobHook, "get_conn")
def test_killed_with_stop_job_run_on_kill(self, _, mock_glue_hook):
    """With ``stop_job_run_on_kill=True``, ``on_kill`` stops the recorded run exactly once."""
    operator_kwargs = {
        "task_id": TASK_ID,
        "job_name": JOB_NAME,
        "script_location": "s3://folder/file",
        "aws_conn_id": "aws_default",
        "region_name": "us-west-2",
        "s3_bucket": "some_bucket",
        "iam_role_name": "my_test_role",
        "stop_job_run_on_kill": True,
    }
    operator = GlueJobOperator(**operator_kwargs)
    # Set the run ID directly on the operator; ``execute`` is not called in this test.
    operator._job_run_id = JOB_RUN_ID

    operator.on_kill()

    expected_call = {"JobName": JOB_NAME, "JobRunIds": [JOB_RUN_ID]}
    mock_glue_hook.batch_stop_job_run.assert_called_once_with(**expected_call)

0 comments on commit 1d60332

Please sign in to comment.