diff --git a/airflow/providers/amazon/aws/hooks/glue.py b/airflow/providers/amazon/aws/hooks/glue.py index e313b95f93315..421d11c09e55c 100644 --- a/airflow/providers/amazon/aws/hooks/glue.py +++ b/airflow/providers/amazon/aws/hooks/glue.py @@ -54,8 +54,6 @@ class GlueJobHook(AwsBaseHook): - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` """ - JOB_POLL_INTERVAL = 6 # polls job status after every JOB_POLL_INTERVAL seconds - class LogContinuationTokens: """Used to hold the continuation tokens when reading logs from both streams Glue Jobs write to.""" @@ -75,6 +73,7 @@ def __init__( iam_role_name: str | None = None, create_job_kwargs: dict | None = None, update_config: bool = False, + job_poll_interval: int | float = 6, *args, **kwargs, ): @@ -88,6 +87,7 @@ def __init__( self.s3_glue_logs = "logs/glue-logs/" self.create_job_kwargs = create_job_kwargs or {} self.update_config = update_config + self.job_poll_interval = job_poll_interval worker_type_exists = "WorkerType" in self.create_job_kwargs num_workers_exists = "NumberOfWorkers" in self.create_job_kwargs @@ -278,7 +278,7 @@ def job_completion(self, job_name: str, run_id: str, verbose: bool = False) -> d if ret: return ret else: - time.sleep(self.JOB_POLL_INTERVAL) + time.sleep(self.job_poll_interval) async def async_job_completion(self, job_name: str, run_id: str, verbose: bool = False) -> dict[str, str]: """ @@ -297,7 +297,7 @@ async def async_job_completion(self, job_name: str, run_id: str, verbose: bool = if ret: return ret else: - await asyncio.sleep(self.JOB_POLL_INTERVAL) + await asyncio.sleep(self.job_poll_interval) def _handle_state( self, diff --git a/airflow/providers/amazon/aws/operators/glue.py b/airflow/providers/amazon/aws/operators/glue.py index 37010b6fd876a..1d6146e42b9a2 100644 --- a/airflow/providers/amazon/aws/operators/glue.py +++ b/airflow/providers/amazon/aws/operators/glue.py @@ -99,6 +99,7 @@ def __init__( deferrable: bool = False, verbose: bool = False, update_config: bool = False, + job_poll_interval: int | float = 6, **kwargs, ): super().__init__(**kwargs) @@ -121,6 +122,7 @@ def __init__( self.verbose = verbose self.update_config = update_config self.deferrable = deferrable + self.job_poll_interval = job_poll_interval def execute(self, context: Context): """Execute AWS Glue Job from Airflow. @@ -151,6 +153,7 @@ def execute(self, context: Context): iam_role_name=self.iam_role_name, create_job_kwargs=self.create_job_kwargs, update_config=self.update_config, + job_poll_interval=self.job_poll_interval, ) self.log.info( "Initializing AWS Glue Job: %s. Wait for completion: %s", @@ -181,6 +184,7 @@ def execute(self, context: Context): run_id=glue_job_run["JobRunId"], verbose=self.verbose, aws_conn_id=self.aws_conn_id, + job_poll_interval=self.job_poll_interval, ), method_name="execute_complete", ) diff --git a/airflow/providers/amazon/aws/triggers/glue.py b/airflow/providers/amazon/aws/triggers/glue.py index 71df2f6cb5d15..00529330a7743 100644 --- a/airflow/providers/amazon/aws/triggers/glue.py +++ b/airflow/providers/amazon/aws/triggers/glue.py @@ -39,11 +39,14 @@ def __init__( run_id: str, verbose: bool, aws_conn_id: str, + job_poll_interval: int | float, ): + super().__init__() self.job_name = job_name self.run_id = run_id self.verbose = verbose self.aws_conn_id = aws_conn_id + self.job_poll_interval = job_poll_interval def serialize(self) -> tuple[str, dict[str, Any]]: return ( @@ -54,10 +57,11 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "run_id": self.run_id, "verbose": str(self.verbose), "aws_conn_id": self.aws_conn_id, + "job_poll_interval": self.job_poll_interval, }, ) async def run(self) -> AsyncIterator[TriggerEvent]: - hook = GlueJobHook(aws_conn_id=self.aws_conn_id) + hook = GlueJobHook(aws_conn_id=self.aws_conn_id, job_poll_interval=self.job_poll_interval) await hook.async_job_completion(self.job_name, self.run_id, self.verbose) yield TriggerEvent({"status": "success", "message": "Job done", "value": self.run_id}) diff --git a/tests/providers/amazon/aws/hooks/test_glue.py b/tests/providers/amazon/aws/hooks/test_glue.py index 1497affd8622a..c41598f3d9aa6 100644 --- a/tests/providers/amazon/aws/hooks/test_glue.py +++ b/tests/providers/amazon/aws/hooks/test_glue.py @@ -353,8 +353,7 @@ def test_print_job_logs_no_stream_yet(self, conn_mock: MagicMock, client_mock: M @mock.patch.object(GlueJobHook, "get_job_state") def test_job_completion_success(self, get_state_mock: MagicMock): - hook = GlueJobHook() - hook.JOB_POLL_INTERVAL = 0 + hook = GlueJobHook(job_poll_interval=0) get_state_mock.side_effect = [ "RUNNING", "RUNNING", @@ -368,8 +367,7 @@ def test_job_completion_success(self, get_state_mock: MagicMock): @mock.patch.object(GlueJobHook, "get_job_state") def test_job_completion_failure(self, get_state_mock: MagicMock): - hook = GlueJobHook() - hook.JOB_POLL_INTERVAL = 0 + hook = GlueJobHook(job_poll_interval=0) get_state_mock.side_effect = [ "RUNNING", "RUNNING", @@ -384,8 +382,7 @@ def test_job_completion_failure(self, get_state_mock: MagicMock): @pytest.mark.asyncio @mock.patch.object(GlueJobHook, "async_get_job_state") async def test_async_job_completion_success(self, get_state_mock: MagicMock): - hook = GlueJobHook() - hook.JOB_POLL_INTERVAL = 0 + hook = GlueJobHook(job_poll_interval=0) get_state_mock.side_effect = [ "RUNNING", "RUNNING", @@ -400,8 +397,7 @@ async def test_async_job_completion_success(self, get_state_mock: MagicMock): @pytest.mark.asyncio @mock.patch.object(GlueJobHook, "async_get_job_state") async def test_async_job_completion_failure(self, get_state_mock: MagicMock): - hook = GlueJobHook() - hook.JOB_POLL_INTERVAL = 0 + hook = GlueJobHook(job_poll_interval=0) get_state_mock.side_effect = [ "RUNNING", "RUNNING", diff --git a/tests/providers/amazon/aws/triggers/test_glue.py b/tests/providers/amazon/aws/triggers/test_glue.py index 014658b9503a3..cc98ecc74831b 100644 --- a/tests/providers/amazon/aws/triggers/test_glue.py +++ b/tests/providers/amazon/aws/triggers/test_glue.py @@ -30,12 +30,12 @@ class TestGlueJobTrigger: @pytest.mark.asyncio @mock.patch.object(GlueJobHook, "async_get_job_state") async def test_wait_job(self, get_state_mock: mock.MagicMock): - GlueJobHook.JOB_POLL_INTERVAL = 0.1 trigger = GlueJobCompleteTrigger( job_name="job_name", run_id="JobRunId", verbose=False, aws_conn_id="aws_conn_id", + job_poll_interval=0.1, ) get_state_mock.side_effect = [ "RUNNING", @@ -52,12 +52,12 @@ async def test_wait_job(self, get_state_mock: mock.MagicMock): @pytest.mark.asyncio @mock.patch.object(GlueJobHook, "async_get_job_state") async def test_wait_job_failed(self, get_state_mock: mock.MagicMock): - GlueJobHook.JOB_POLL_INTERVAL = 0.1 trigger = GlueJobCompleteTrigger( job_name="job_name", run_id="JobRunId", verbose=False, aws_conn_id="aws_conn_id", + job_poll_interval=0.1, ) get_state_mock.side_effect = [ "RUNNING",