Add support in AWS Batch Operator for multinode jobs (#29522)
Picking up #28321 after it was somewhat abandoned by the original author.
Addressed my own comment about the empty array; it should be good to go, I think.

Initial description from @camilleanne:

Adds support for AWS Batch multinode jobs by allowing a `node_overrides` JSON object to be passed through to the boto3 `submit_job` method.

Adds support for multinode jobs by properly parsing the output of `describe_jobs` (whose format differs between container and multinode jobs) to extract the log stream names.
closes: #25522
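
For illustration, a minimal usage sketch of the new parameter (the job name, queue, definition, and command below are placeholders, not part of this commit):

```python
# Hypothetical example -- job name, queue, definition, and command are placeholders.
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.batch import BatchOperator

with DAG(dag_id="example_batch_multinode", start_date=datetime(2023, 1, 1), schedule=None):
    submit_multinode = BatchOperator(
        task_id="submit_multinode_job",
        job_name="example-multinode-job",
        job_definition="example-multinode-job-def",
        job_queue="example-job-queue",
        # Forwarded as-is to boto3 submit_job as `nodeOverrides`.
        node_overrides={
            "nodePropertyOverrides": [
                {
                    "targetNodes": "0:",
                    "containerOverrides": {"command": ["echo", "hello from the main node"]},
                },
            ],
        },
    )
```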
vandonr-amz committed Apr 12, 2023
1 parent f080e1e commit 2ce1130
Showing 4 changed files with 286 additions and 63 deletions.
87 changes: 60 additions & 27 deletions airflow/providers/amazon/aws/hooks/batch_client.py
@@ -414,43 +414,76 @@ def parse_job_description(job_id: str, response: dict) -> dict:
return matching_jobs[0]

def get_job_awslogs_info(self, job_id: str) -> dict[str, str] | None:
all_info = self.get_job_all_awslogs_info(job_id)
if not all_info:
return None
if len(all_info) > 1:
self.log.warning(
f"AWS Batch job ({job_id}) has more than one log stream, " f"only returning the first one."
)
return all_info[0]

def get_job_all_awslogs_info(self, job_id: str) -> list[dict[str, str]]:
"""
Parse job description to extract AWS CloudWatch information.
:param job_id: AWS Batch Job ID
"""
job_container_desc = self.get_job_description(job_id=job_id).get("container", {})
log_configuration = job_container_desc.get("logConfiguration", {})

# In case if user select other "logDriver" rather than "awslogs"
# than CloudWatch logging should be disabled.
# If user not specify anything than expected that "awslogs" will use
# with default settings:
# awslogs-group = /aws/batch/job
# awslogs-region = `same as AWS Batch Job region`
log_driver = log_configuration.get("logDriver", "awslogs")
if log_driver != "awslogs":
job_desc = self.get_job_description(job_id=job_id)

job_node_properties = job_desc.get("nodeProperties", {})
job_container_desc = job_desc.get("container", {})

if job_node_properties:
# one log config per node
log_configs = [
p.get("container", {}).get("logConfiguration", {})
for p in job_node_properties.get("nodeRangeProperties", {})
]
# one stream name per attempt
stream_names = [a.get("container", {}).get("logStreamName") for a in job_desc.get("attempts", [])]
elif job_container_desc:
log_configs = [job_container_desc.get("logConfiguration", {})]
stream_name = job_container_desc.get("logStreamName")
stream_names = [stream_name] if stream_name is not None else []
else:
raise AirflowException(
f"AWS Batch job ({job_id}) is not a supported job type. "
"Supported job types: container, array, multinode."
)

# If the user selected another logDriver than "awslogs", then CloudWatch logging is disabled.
if any([c.get("logDriver", "awslogs") != "awslogs" for c in log_configs]):
self.log.warning(
"AWS Batch job (%s) uses logDriver (%s). AWS CloudWatch logging disabled.", job_id, log_driver
f"AWS Batch job ({job_id}) uses non-aws log drivers. AWS CloudWatch logging disabled."
)
return None
return []

awslogs_stream_name = job_container_desc.get("logStreamName")
if not awslogs_stream_name:
# In case of call this method on very early stage of running AWS Batch
# there is possibility than AWS CloudWatch Stream Name not exists yet.
# AWS CloudWatch Stream Name also not created in case of misconfiguration.
self.log.warning("AWS Batch job (%s) doesn't create AWS CloudWatch Stream.", job_id)
return None
if not stream_names:
# If this method is called very early after starting the AWS Batch job,
# there is a possibility that the AWS CloudWatch Stream Name would not exist yet.
# This can also happen in case of misconfiguration.
self.log.warning(f"AWS Batch job ({job_id}) doesn't have any AWS CloudWatch Stream.")
return []

# Try to get user-defined log configuration options
log_options = log_configuration.get("options", {})

return {
"awslogs_stream_name": awslogs_stream_name,
"awslogs_group": log_options.get("awslogs-group", "/aws/batch/job"),
"awslogs_region": log_options.get("awslogs-region", self.conn_region_name),
}
log_options = [c.get("options", {}) for c in log_configs]

# cross stream names with options (i.e. attempts X nodes) to generate all log infos
result = []
for stream in stream_names:
for option in log_options:
result.append(
{
"awslogs_stream_name": stream,
# If the user did not specify anything, the default settings are:
# awslogs-group = /aws/batch/job
# awslogs-region = `same as AWS Batch Job region`
"awslogs_group": option.get("awslogs-group", "/aws/batch/job"),
"awslogs_region": option.get("awslogs-region", self.conn_region_name),
}
)
return result

@staticmethod
def add_jitter(delay: int | float, width: int | float = 1, minima: int | float = 0) -> float:
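To make the new return shape concrete, here is a small sketch (not part of the commit) of the attempts-times-nodes cross product that `get_job_all_awslogs_info` now produces; the stream names, log groups, and region are illustrative values mirroring the multinode test case further down:

```python
# Illustrative sketch of the cross-product result; values mirror the multinode test case.
stream_names = ["test/stream/attempt0", "test/stream/attempt1"]  # one per attempt
log_options = [  # one per node range
    {"awslogs-group": "/test/batch/job-a", "awslogs-region": "us-east-1"},
    {"awslogs-group": "/test/batch/job-b", "awslogs-region": "us-east-1"},
]

all_awslogs_info = [
    {
        "awslogs_stream_name": stream,
        "awslogs_group": option.get("awslogs-group", "/aws/batch/job"),
        "awslogs_region": option.get("awslogs-region", "us-east-1"),  # placeholder region
    }
    for stream in stream_names
    for option in log_options
]
assert len(all_awslogs_info) == 4  # 2 attempts x 2 node ranges
```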
102 changes: 78 additions & 24 deletions airflow/providers/amazon/aws/operators/batch.py
@@ -25,6 +25,7 @@
"""
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, Sequence

from airflow.compat.functools import cached_property
@@ -54,7 +55,9 @@ class BatchOperator(BaseOperator):
:param job_name: the name for the job that will run on AWS Batch (templated)
:param job_definition: the job definition name on AWS Batch
:param job_queue: the queue name on AWS Batch
:param overrides: the `containerOverrides` parameter for boto3 (templated)
:param overrides: DEPRECATED, use container_overrides instead with the same value.
:param container_overrides: the `containerOverrides` parameter for boto3 (templated)
:param node_overrides: the `nodeOverrides` parameter for boto3 (templated)
:param array_properties: the `arrayProperties` parameter for boto3
:param parameters: the `parameters` for boto3 (templated)
:param job_id: the job ID, usually unknown (None) until the
@@ -88,14 +91,19 @@ class BatchOperator(BaseOperator):
"job_name",
"job_definition",
"job_queue",
"overrides",
"container_overrides",
"array_properties",
"node_overrides",
"parameters",
"waiters",
"tags",
"wait_for_completion",
)
template_fields_renderers = {"overrides": "json", "parameters": "json"}
template_fields_renderers = {
"container_overrides": "json",
"parameters": "json",
"node_overrides": "json",
}

@property
def operator_extra_links(self):
@@ -114,8 +122,10 @@ def __init__(
job_name: str,
job_definition: str,
job_queue: str,
overrides: dict,
overrides: dict | None = None, # deprecated
container_overrides: dict | None = None,
array_properties: dict | None = None,
node_overrides: dict | None = None,
parameters: dict | None = None,
job_id: str | None = None,
waiters: Any | None = None,
@@ -133,17 +143,43 @@ def __init__(
self.job_name = job_name
self.job_definition = job_definition
self.job_queue = job_queue
self.overrides = overrides or {}
self.array_properties = array_properties or {}

self.container_overrides = container_overrides
# handle `overrides` deprecation in favor of `container_overrides`
if overrides:
if container_overrides:
# disallow setting both old and new params
raise AirflowException(
"'container_overrides' replaces the 'overrides' parameter. "
"You cannot specify both. Please remove assignation to the deprecated 'overrides'."
)
self.container_overrides = overrides
warnings.warn(
"Parameter `overrides` is deprecated, Please use `container_overrides` instead.",
DeprecationWarning,
stacklevel=2,
)

self.node_overrides = node_overrides
self.array_properties = array_properties
self.parameters = parameters or {}
self.waiters = waiters
self.tags = tags or {}
self.wait_for_completion = wait_for_completion
self.hook = BatchClientHook(
max_retries=max_retries,
status_retries=status_retries,
aws_conn_id=aws_conn_id,
region_name=region_name,

# params for hook
self.max_retries = max_retries
self.status_retries = status_retries
self.aws_conn_id = aws_conn_id
self.region_name = region_name

@cached_property
def hook(self) -> BatchClientHook:
return BatchClientHook(
max_retries=self.max_retries,
status_retries=self.status_retries,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
)

def execute(self, context: Context):
@@ -174,18 +210,27 @@ def submit_job(self, context: Context):
self.job_definition,
self.job_queue,
)
self.log.info("AWS Batch job - container overrides: %s", self.overrides)

if self.container_overrides:
self.log.info("AWS Batch job - container overrides: %s", self.container_overrides)
if self.array_properties:
self.log.info("AWS Batch job - array properties: %s", self.array_properties)
if self.node_overrides:
self.log.info("AWS Batch job - node properties: %s", self.node_overrides)

args = {
"jobName": self.job_name,
"jobQueue": self.job_queue,
"jobDefinition": self.job_definition,
"arrayProperties": self.array_properties,
"parameters": self.parameters,
"tags": self.tags,
"containerOverrides": self.container_overrides,
"nodeOverrides": self.node_overrides,
}

try:
response = self.hook.client.submit_job(
jobName=self.job_name,
jobQueue=self.job_queue,
jobDefinition=self.job_definition,
arrayProperties=self.array_properties,
parameters=self.parameters,
containerOverrides=self.overrides,
tags=self.tags,
)
response = self.hook.client.submit_job(**trim_none_values(args))
except Exception as e:
self.log.error(
"AWS Batch job failed submission - job definition: %s - on queue %s",
@@ -249,15 +294,24 @@ def monitor_job(self, context: Context):
else:
self.hook.wait_for_job(self.job_id)

awslogs = self.hook.get_job_awslogs_info(self.job_id)
awslogs = self.hook.get_job_all_awslogs_info(self.job_id)
if awslogs:
self.log.info("AWS Batch job (%s) CloudWatch Events details found: %s", self.job_id, awslogs)
self.log.info("AWS Batch job (%s) CloudWatch Events details found. Links to logs:", self.job_id)
link_builder = CloudWatchEventsLink()
for log in awslogs:
self.log.info(link_builder.format_link(**log))
if len(awslogs) > 1:
# there can be several log streams on multi-node jobs
self.log.warning(
"out of all those logs, we can only link to one in the UI. " "Using the first one."
)

CloudWatchEventsLink.persist(
context=context,
operator=self,
region_name=self.hook.conn_region_name,
aws_partition=self.hook.conn_partition,
**awslogs,
**awslogs[0],
)

self.hook.check_job_success(self.job_id)
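As a usage note (not part of the diff), the deprecation shim in `__init__` means DAGs that still pass `overrides=` keep working: the value is copied to `container_overrides` and a `DeprecationWarning` is emitted. A small sketch, with placeholder job names:

```python
import warnings

from airflow.providers.amazon.aws.operators.batch import BatchOperator

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    op = BatchOperator(
        task_id="legacy_overrides",
        job_name="example-job",            # placeholder values
        job_definition="example-job-def",
        job_queue="example-job-queue",
        overrides={"command": ["echo", "legacy"]},  # deprecated in favor of container_overrides
    )

assert op.container_overrides == {"command": ["echo", "legacy"]}
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```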
74 changes: 71 additions & 3 deletions tests/providers/amazon/aws/hooks/test_batch_client.py
@@ -280,15 +280,24 @@ def test_job_no_awslogs_stream(self, caplog):
"jobs": [
{
"jobId": JOB_ID,
"container": {},
"container": {"logConfiguration": {}},
}
]
}

with caplog.at_level(level=logging.WARNING):
assert self.batch_client.get_job_awslogs_info(JOB_ID) is None
assert len(caplog.records) == 1
assert "doesn't create AWS CloudWatch Stream" in caplog.messages[0]
assert "doesn't have any AWS CloudWatch Stream" in caplog.messages[0]

def test_job_not_recognized_job(self):
self.client_mock.describe_jobs.return_value = {"jobs": [{"jobId": JOB_ID}]}
with pytest.raises(AirflowException) as ctx:
self.batch_client.get_job_awslogs_info(JOB_ID)
# describe_jobs should only be called once; no retries on an unsupported job type
self.client_mock.describe_jobs.assert_called_once_with(jobs=[JOB_ID])
msg = "is not a supported job type"
assert msg in str(ctx.value)

def test_job_splunk_logs(self, caplog):
self.client_mock.describe_jobs.return_value = {
@@ -307,7 +316,66 @@ def test_job_splunk_logs(self, caplog):
with caplog.at_level(level=logging.WARNING):
assert self.batch_client.get_job_awslogs_info(JOB_ID) is None
assert len(caplog.records) == 1
assert "uses logDriver (splunk). AWS CloudWatch logging disabled." in caplog.messages[0]
assert "uses non-aws log drivers. AWS CloudWatch logging disabled." in caplog.messages[0]

def test_job_awslogs_multinode_job(self):
self.client_mock.describe_jobs.return_value = {
"jobs": [
{
"jobId": JOB_ID,
"attempts": [
{"container": {"exitCode": 0, "logStreamName": "test/stream/attempt0"}},
{"container": {"exitCode": 0, "logStreamName": "test/stream/attempt1"}},
],
"nodeProperties": {
"mainNode": 0,
"nodeRangeProperties": [
{
"targetNodes": "0:",
"container": {
"logConfiguration": {
"logDriver": "awslogs",
"options": {
"awslogs-group": "/test/batch/job-a",
"awslogs-region": AWS_REGION,
},
}
},
},
{
"targetNodes": "1:",
"container": {
"logConfiguration": {
"logDriver": "awslogs",
"options": {
"awslogs-group": "/test/batch/job-b",
"awslogs-region": AWS_REGION,
},
}
},
},
],
},
}
]
}
awslogs = self.batch_client.get_job_all_awslogs_info(JOB_ID)
assert len(awslogs) == 4
assert all([log["awslogs_region"] == AWS_REGION for log in awslogs])

combinations = {
("test/stream/attempt0", "/test/batch/job-a"): False,
("test/stream/attempt0", "/test/batch/job-b"): False,
("test/stream/attempt1", "/test/batch/job-a"): False,
("test/stream/attempt1", "/test/batch/job-b"): False,
}
for log_info in awslogs:
# mark combinations that we see
combinations[(log_info["awslogs_stream_name"], log_info["awslogs_group"])] = True

assert len(combinations) == 4
# all combinations listed above should have been seen
assert all(combinations.values())


class TestBatchClientDelays:
