Add realtime container execution logs for BatchOperator (#31837)
* Update param description for get_batch_log_fetcher in batch_waiters

* Update the log fetcher with the latest continuation token changes
killua1zoldyck committed Jun 19, 2023
1 parent eb27641 commit e01ff47
Showing 12 changed files with 454 additions and 246 deletions.
23 changes: 21 additions & 2 deletions airflow/providers/amazon/aws/hooks/batch_client.py
@@ -28,13 +28,15 @@

from random import uniform
from time import sleep
from typing import Callable

import botocore.client
import botocore.exceptions
import botocore.waiter

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
from airflow.typing_compat import Protocol, runtime_checkable


@@ -253,19 +255,36 @@ def check_job_success(self, job_id: str) -> bool:

raise AirflowException(f"AWS Batch job ({job_id}) has unknown status: {job}")

def wait_for_job(self, job_id: str, delay: int | float | None = None) -> None:
def wait_for_job(
self,
job_id: str,
delay: int | float | None = None,
get_batch_log_fetcher: Callable[[str], AwsTaskLogFetcher | None] | None = None,
) -> None:
"""
Wait for Batch job to complete.
:param job_id: a Batch job ID
:param delay: a delay before polling for job status
:param get_batch_log_fetcher: a method that returns a batch_log_fetcher
:raises: AirflowException
"""
self.delay(delay)
self.poll_for_job_running(job_id, delay)
self.poll_for_job_complete(job_id, delay)
batch_log_fetcher = None
try:
if get_batch_log_fetcher:
batch_log_fetcher = get_batch_log_fetcher(job_id)
if batch_log_fetcher:
batch_log_fetcher.start()
self.poll_for_job_complete(job_id, delay)
finally:
if batch_log_fetcher:
batch_log_fetcher.stop()
batch_log_fetcher.join()
self.log.info("AWS Batch job (%s) has completed", job_id)

def poll_for_job_running(self, job_id: str, delay: int | float | None = None) -> None:
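For orientation, a minimal usage sketch of the new hook API, not part of this commit: it assumes the CloudWatch log group and stream name are already known, whereas BatchOperator derives them from the job description (see operators/batch.py further down).

# Usage sketch (not part of this commit): stream CloudWatch logs while
# BatchClientHook.wait_for_job blocks on job completion. The log group and
# stream names below are hypothetical placeholders.
from __future__ import annotations

import logging
from datetime import timedelta

from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher

hook = BatchClientHook(aws_conn_id="aws_default")

def make_batch_log_fetcher(job_id: str) -> AwsTaskLogFetcher | None:
    # Hypothetical fixed log group/stream; the operator instead resolves these
    # from hook.get_job_awslogs_info(job_id) and returns None until the stream exists.
    return AwsTaskLogFetcher(
        aws_conn_id="aws_default",
        region_name="us-east-1",
        log_group="/aws/batch/job",
        log_stream_name=f"example-job-definition/default/{job_id}",
        fetch_interval=timedelta(seconds=30),
        logger=logging.getLogger(__name__),
    )

# wait_for_job calls the factory once with the job id, starts the returned
# thread while polling for completion, and stops/joins it in a finally block.
hook.wait_for_job("00000000-0000-0000-0000-000000000000", get_batch_log_fetcher=make_batch_log_fetcher)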
30 changes: 25 additions & 5 deletions airflow/providers/amazon/aws/hooks/batch_waiters.py
@@ -29,13 +29,15 @@
import sys
from copy import deepcopy
from pathlib import Path
from typing import Callable

import botocore.client
import botocore.exceptions
import botocore.waiter

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher


class BatchWaitersHook(BatchClientHook):
@@ -184,7 +186,12 @@ def list_waiters(self) -> list[str]:
"""
return self.waiter_model.waiter_names

def wait_for_job(self, job_id: str, delay: int | float | None = None) -> None:
def wait_for_job(
self,
job_id: str,
delay: int | float | None = None,
get_batch_log_fetcher: Callable[[str], AwsTaskLogFetcher | None] | None = None,
) -> None:
"""
Wait for Batch job to complete. This assumes that the ``.waiter_model`` is configured
using some variation of the ``.default_config`` so that it can generate waiters with the
@@ -194,6 +201,9 @@ def wait_for_job(self, job_id: str, delay: int | float | None = None) -> None:
:param delay: A delay before polling for job status
:param get_batch_log_fetcher: A method that returns a batch_log_fetcher of
type AwsTaskLogFetcher, or None when the CloudWatch log stream hasn't been created yet.
:raises: AirflowException
.. note::
@@ -216,10 +226,20 @@ def wait_for_job(self, job_id: str, delay: int | float | None = None) -> None:
waiter.config.max_attempts = sys.maxsize # timeout is managed by Airflow
waiter.wait(jobs=[job_id])

waiter = self.get_waiter("JobComplete")
waiter.config.delay = self.add_jitter(waiter.config.delay, width=2, minima=1)
waiter.config.max_attempts = sys.maxsize # timeout is managed by Airflow
waiter.wait(jobs=[job_id])
batch_log_fetcher = None
try:
if get_batch_log_fetcher:
batch_log_fetcher = get_batch_log_fetcher(job_id)
if batch_log_fetcher:
batch_log_fetcher.start()
waiter = self.get_waiter("JobComplete")
waiter.config.delay = self.add_jitter(waiter.config.delay, width=2, minima=1)
waiter.config.max_attempts = sys.maxsize # timeout is managed by Airflow
waiter.wait(jobs=[job_id])
finally:
if batch_log_fetcher:
batch_log_fetcher.stop()
batch_log_fetcher.join()

except (botocore.exceptions.ClientError, botocore.exceptions.WaiterError) as err:
raise AirflowException(err)
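The contract is the same at the waiter level: the factory may return None while the CloudWatch stream does not exist yet, and wait_for_job then simply polls without a fetcher. A hedged sketch of such a factory, built on the hook's own get_job_awslogs_info in the way the operator code below does:

# Sketch only; assumes get_job_awslogs_info returns a falsy value until the
# log stream has been created, as the operator code in this commit implies.
from __future__ import annotations

import logging
from datetime import timedelta

from airflow.providers.amazon.aws.hooks.batch_waiters import BatchWaitersHook
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher

waiters_hook = BatchWaitersHook(aws_conn_id="aws_default")

def batch_log_fetcher_factory(job_id: str) -> AwsTaskLogFetcher | None:
    awslog_info = waiters_hook.get_job_awslogs_info(job_id)
    if not awslog_info:
        return None  # stream not created yet; waiting continues without logs
    return AwsTaskLogFetcher(
        aws_conn_id="aws_default",
        region_name=awslog_info["awslogs_region"],
        log_group=awslog_info["awslogs_group"],
        log_stream_name=awslog_info["awslogs_stream_name"],
        fetch_interval=timedelta(seconds=30),
        logger=logging.getLogger(__name__),
    )

waiters_hook.wait_for_job("00000000-0000-0000-0000-000000000000", get_batch_log_fetcher=batch_log_fetcher_factory)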
89 changes: 0 additions & 89 deletions airflow/providers/amazon/aws/hooks/ecs.py
@@ -17,19 +17,10 @@
# under the License.
from __future__ import annotations

import time
from collections import deque
from datetime import datetime, timedelta
from logging import Logger
from threading import Event, Thread
from typing import Generator

from botocore.exceptions import ClientError, ConnectionClosedError
from botocore.waiter import Waiter

from airflow.providers.amazon.aws.exceptions import EcsOperatorError, EcsTaskFailToStart
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.providers.amazon.aws.utils import _StringCompareEnum
from airflow.typing_compat import Protocol, runtime_checkable

@@ -143,86 +134,6 @@ def get_task_state(self, cluster, task) -> str:
return self.conn.describe_tasks(cluster=cluster, tasks=[task])["tasks"][0]["lastStatus"]


class EcsTaskLogFetcher(Thread):
"""
Fetches Cloudwatch log events with specific interval as a thread
and sends the log events to the info channel of the provided logger.
"""

def __init__(
self,
*,
log_group: str,
log_stream_name: str,
fetch_interval: timedelta,
logger: Logger,
aws_conn_id: str | None = "aws_default",
region_name: str | None = None,
):
super().__init__()
self._event = Event()

self.fetch_interval = fetch_interval

self.logger = logger
self.log_group = log_group
self.log_stream_name = log_stream_name

self.hook = AwsLogsHook(aws_conn_id=aws_conn_id, region_name=region_name)

def run(self) -> None:
continuation_token = AwsLogsHook.ContinuationToken()
while not self.is_stopped():
time.sleep(self.fetch_interval.total_seconds())
log_events = self._get_log_events(continuation_token)
for log_event in log_events:
self.logger.info(self._event_to_str(log_event))

def _get_log_events(self, skip_token: AwsLogsHook.ContinuationToken | None = None) -> Generator:
if skip_token is None:
skip_token = AwsLogsHook.ContinuationToken()
try:
yield from self.hook.get_log_events(
self.log_group, self.log_stream_name, continuation_token=skip_token
)
except ClientError as error:
if error.response["Error"]["Code"] != "ResourceNotFoundException":
self.logger.warning("Error on retrieving Cloudwatch log events", error)
else:
self.logger.info(
"Cannot find log stream yet, it can take a couple of seconds to show up. "
"If this error persists, check that the log group and stream are correct: "
"group: %s\tstream: %s",
self.log_group,
self.log_stream_name,
)
yield from ()
except ConnectionClosedError as error:
self.logger.warning("ConnectionClosedError on retrieving Cloudwatch log events", error)
yield from ()

def _event_to_str(self, event: dict) -> str:
event_dt = datetime.utcfromtimestamp(event["timestamp"] / 1000.0)
formatted_event_dt = event_dt.strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
message = event["message"]
return f"[{formatted_event_dt}] {message}"

def get_last_log_messages(self, number_messages) -> list:
return [log["message"] for log in deque(self._get_log_events(), maxlen=number_messages)]

def get_last_log_message(self) -> str | None:
try:
return self.get_last_log_messages(1)[0]
except IndexError:
return None

def is_stopped(self) -> bool:
return self._event.is_set()

def stop(self):
self._event.set()


@runtime_checkable
class EcsProtocol(Protocol):
"""
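The fetcher thread removed here now lives in airflow/providers/amazon/aws/utils/task_log_fetcher.py as AwsTaskLogFetcher; that new module is not visible in this view. The remaining call sites in this commit swap the class name without changing any arguments, so, assuming the relocated class keeps the same interface, standalone use would look roughly like this:

# Sketch under the assumption that AwsTaskLogFetcher keeps the interface of the
# removed EcsTaskLogFetcher: a Thread with start()/stop()/join() and
# get_last_log_message(), which is how this commit's call sites use it.
import logging
import time
from datetime import timedelta

from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher

fetcher = AwsTaskLogFetcher(
    aws_conn_id="aws_default",
    region_name="us-east-1",
    log_group="/aws/batch/job",            # hypothetical log group
    log_stream_name="example-stream",      # hypothetical stream name
    fetch_interval=timedelta(seconds=30),
    logger=logging.getLogger(__name__),
)

fetcher.start()    # begin relaying CloudWatch events to the logger
time.sleep(60)     # ... do other work while logs stream in ...
fetcher.stop()     # set the internal Event so the run loop exits
fetcher.join()     # wait for the final fetch iteration to finish
print(fetcher.get_last_log_message())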
39 changes: 36 additions & 3 deletions airflow/providers/amazon/aws/operators/batch.py
@@ -25,6 +25,7 @@
from __future__ import annotations

import warnings
from datetime import timedelta
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence

@@ -39,6 +40,7 @@
from airflow.providers.amazon.aws.links.logs import CloudWatchEventsLink
from airflow.providers.amazon.aws.triggers.batch import BatchOperatorTrigger
from airflow.providers.amazon.aws.utils import trim_none_values
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher

if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -79,6 +81,10 @@ class BatchOperator(BaseOperator):
:param tags: collection of tags to apply to the AWS Batch job submission
if None, no tags are submitted
:param deferrable: Run operator in the deferrable mode.
:param awslogs_enabled: Specifies whether logs from CloudWatch
should be printed or not. Defaults to False.
If it is an array job, only the logs of the first task will be printed.
:param awslogs_fetch_interval: The interval at which CloudWatch logs are to be fetched. Defaults to 30 seconds.
:param poll_interval: (Deferrable mode only) Time in seconds to wait between polling.
.. note::
@@ -104,6 +110,8 @@ class BatchOperator(BaseOperator):
"waiters",
"tags",
"wait_for_completion",
"awslogs_enabled",
"awslogs_fetch_interval",
)
template_fields_renderers = {
"container_overrides": "json",
@@ -145,6 +153,8 @@ def __init__(
wait_for_completion: bool = True,
deferrable: bool = False,
poll_interval: int = 30,
awslogs_enabled: bool = False,
awslogs_fetch_interval: timedelta = timedelta(seconds=30),
**kwargs,
) -> None:
BaseOperator.__init__(self, **kwargs)
@@ -179,6 +189,8 @@ def __init__(
self.wait_for_completion = wait_for_completion
self.deferrable = deferrable
self.poll_interval = poll_interval
self.awslogs_enabled = awslogs_enabled
self.awslogs_fetch_interval = awslogs_fetch_interval

# params for hook
self.max_retries = max_retries
@@ -319,10 +331,16 @@ def monitor_job(self, context: Context):
job_queue_arn=job_queue_arn,
)

if self.waiters:
self.waiters.wait_for_job(self.job_id)
if self.awslogs_enabled:
if self.waiters:
self.waiters.wait_for_job(self.job_id, get_batch_log_fetcher=self._get_batch_log_fetcher)
else:
self.hook.wait_for_job(self.job_id, get_batch_log_fetcher=self._get_batch_log_fetcher)
else:
self.hook.wait_for_job(self.job_id)
if self.waiters:
self.waiters.wait_for_job(self.job_id)
else:
self.hook.wait_for_job(self.job_id)

awslogs = self.hook.get_job_all_awslogs_info(self.job_id)
if awslogs:
@@ -347,6 +365,21 @@ def monitor_job(self, context: Context):
self.hook.check_job_success(self.job_id)
self.log.info("AWS Batch job (%s) succeeded", self.job_id)

def _get_batch_log_fetcher(self, job_id: str) -> AwsTaskLogFetcher | None:
awslog_info = self.hook.get_job_awslogs_info(job_id)

if not awslog_info:
return None

return AwsTaskLogFetcher(
aws_conn_id=self.aws_conn_id,
region_name=awslog_info["awslogs_region"],
log_group=awslog_info["awslogs_group"],
log_stream_name=awslog_info["awslogs_stream_name"],
fetch_interval=self.awslogs_fetch_interval,
logger=self.log,
)


class BatchCreateComputeEnvironmentOperator(BaseOperator):
"""Create an AWS Batch compute environment.
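At the DAG level the new behavior is opt-in through the two parameters documented above; job submission itself is unchanged. A hypothetical snippet (job name, queue, and definition are placeholders):

# Hypothetical DAG snippet; only awslogs_enabled and awslogs_fetch_interval are
# introduced by this commit, the remaining parameters predate it.
from __future__ import annotations

from datetime import datetime, timedelta

from airflow import DAG
from airflow.providers.amazon.aws.operators.batch import BatchOperator

with DAG(dag_id="example_batch_realtime_logs", start_date=datetime(2023, 6, 1), schedule=None):
    submit_job = BatchOperator(
        task_id="submit_batch_job",
        job_name="example-job",                       # placeholder
        job_queue="example-job-queue",                # placeholder
        job_definition="example-job-definition",      # placeholder
        container_overrides={"command": ["echo", "hello"]},
        awslogs_enabled=True,                         # stream container logs while waiting
        awslogs_fetch_interval=timedelta(seconds=10),
    )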
8 changes: 4 additions & 4 deletions airflow/providers/amazon/aws/operators/ecs.py
@@ -33,9 +33,9 @@
from airflow.providers.amazon.aws.hooks.ecs import (
EcsClusterStates,
EcsHook,
EcsTaskLogFetcher,
should_retry_eni,
)
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
from airflow.utils.helpers import prune_dict
from airflow.utils.session import provide_session

@@ -447,7 +447,7 @@ def __init__(

self.arn: str | None = None
self.retry_args = quota_retry
self.task_log_fetcher: EcsTaskLogFetcher | None = None
self.task_log_fetcher: AwsTaskLogFetcher | None = None
self.wait_for_completion = wait_for_completion
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
@@ -597,12 +597,12 @@ def _wait_for_task_ended(self) -> None:
def _aws_logs_enabled(self):
return self.awslogs_group and self.awslogs_stream_prefix

def _get_task_log_fetcher(self) -> EcsTaskLogFetcher:
def _get_task_log_fetcher(self) -> AwsTaskLogFetcher:
if not self.awslogs_group:
raise ValueError("must specify awslogs_group to fetch task logs")
log_stream_name = f"{self.awslogs_stream_prefix}/{self._get_ecs_task_id(self.arn)}"

return EcsTaskLogFetcher(
return AwsTaskLogFetcher(
aws_conn_id=self.aws_conn_id,
region_name=self.awslogs_region,
log_group=self.awslogs_group,
