46 changes: 2 additions & 44 deletions airflow/providers/amazon/aws/log/s3_task_handler.py
@@ -24,8 +24,6 @@
 from functools import cached_property
 from typing import TYPE_CHECKING
 
-from packaging.version import Version
-
 from airflow.configuration import conf
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 from airflow.utils.log.file_task_handler import FileTaskHandler
@@ -35,18 +33,6 @@
     from airflow.models.taskinstance import TaskInstance
 
 
-def get_default_delete_local_copy():
-    """Load delete_local_logs conf if Airflow version > 2.6 and return False if not.
-
-    TODO: delete this function when min airflow version >= 2.6
-    """
-    from airflow.version import version
-
-    if Version(version) < Version("2.6"):
-        return False
-    return conf.getboolean("logging", "delete_local_logs")
-
-
 class S3TaskHandler(FileTaskHandler, LoggingMixin):
     """
     S3TaskHandler is a python log handler that handles and reads task instance logs.
Expand All @@ -66,8 +52,8 @@ def __init__(
self._hook = None
self.closed = False
self.upload_on_close = True
self.delete_local_copy = (
kwargs["delete_local_copy"] if "delete_local_copy" in kwargs else get_default_delete_local_copy()
self.delete_local_copy = kwargs.get(
"delete_local_copy", conf.getboolean("logging", "delete_local_logs")
)

@cached_property
Expand Down Expand Up @@ -145,34 +131,6 @@ def _read_remote_logs(self, ti, try_number, metadata=None) -> tuple[list[str], l
messages.append(f"No logs found on s3 for ti={ti}")
return messages, logs

def _read(self, ti, try_number, metadata=None):
"""
Read logs of given task instance and try_number from S3 remote storage.

If failed, read the log from task instance host machine.

todo: when min airflow version >= 2.6 then remove this method (``_read``)

:param ti: task instance object
:param try_number: task instance try_number to read logs from
:param metadata: log metadata,
can be used for steaming log reading and auto-tailing.
"""
# from airflow 2.6 we no longer implement the _read method
if hasattr(super(), "_read_remote_logs"):
return super()._read(ti, try_number, metadata)
# if we get here, we're on airflow < 2.6 and we use this backcompat logic
messages, logs = self._read_remote_logs(ti, try_number, metadata)
if logs:
return "".join(f"*** {x}\n" for x in messages) + "\n".join(logs), {"end_of_log": True}
else:
if metadata and metadata.get("log_pos", 0) > 0:
log_prefix = ""
else:
log_prefix = "*** Falling back to local log\n"
local_log, metadata = super()._read(ti, try_number, metadata)
return f"{log_prefix}{local_log}", metadata

def s3_log_exists(self, remote_log_location: str) -> bool:
"""
Check if remote_log_location exists in remote storage.
Expand Down
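Review note: with the version gate removed, `delete_local_copy` is resolved in a single step. An explicit constructor kwarg wins; otherwise the `[logging] delete_local_logs` config entry supplies the default (the old code forced `False` on Airflow < 2.6 regardless of config). A minimal sketch of that resolution order; `resolve_delete_local_copy` is an illustrative helper, not part of the provider:

```python
from airflow.configuration import conf


def resolve_delete_local_copy(**kwargs) -> bool:
    # Illustrative helper: an explicit kwarg wins; otherwise fall back to
    # the [logging] delete_local_logs config entry. Note that kwargs.get
    # evaluates the default eagerly, so the config is read either way.
    return kwargs.get("delete_local_copy", conf.getboolean("logging", "delete_local_logs"))


resolve_delete_local_copy()                        # config-driven default
resolve_delete_local_copy(delete_local_copy=True)  # caller override wins
```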
39 changes: 4 additions & 35 deletions tests/providers/amazon/aws/log/test_s3_task_handler.py
@@ -148,33 +148,6 @@ def test_read_when_s3_log_missing(self):
         assert actual == expected
         assert {"end_of_log": True, "log_pos": 0} == metadata[0]
 
-    def test_read_when_s3_log_missing_and_log_pos_missing_pre_26(self):
-        ti = copy.copy(self.ti)
-        ti.state = TaskInstanceState.SUCCESS
-        # mock that super class has no _read_remote_logs method
-        with mock.patch("airflow.providers.amazon.aws.log.s3_task_handler.hasattr", return_value=False):
-            log, metadata = self.s3_task_handler.read(ti)
-        assert 1 == len(log)
-        assert log[0][0][-1].startswith("*** Falling back to local log")
-
-    def test_read_when_s3_log_missing_and_log_pos_zero_pre_26(self):
-        ti = copy.copy(self.ti)
-        ti.state = TaskInstanceState.SUCCESS
-        # mock that super class has no _read_remote_logs method
-        with mock.patch("airflow.providers.amazon.aws.log.s3_task_handler.hasattr", return_value=False):
-            log, metadata = self.s3_task_handler.read(ti, metadata={"log_pos": 0})
-        assert 1 == len(log)
-        assert log[0][0][-1].startswith("*** Falling back to local log")
-
-    def test_read_when_s3_log_missing_and_log_pos_over_zero_pre_26(self):
-        ti = copy.copy(self.ti)
-        ti.state = TaskInstanceState.SUCCESS
-        # mock that super class has no _read_remote_logs method
-        with mock.patch("airflow.providers.amazon.aws.log.s3_task_handler.hasattr", return_value=False):
-            log, metadata = self.s3_task_handler.read(ti, metadata={"log_pos": 1})
-        assert 1 == len(log)
-        assert not log[0][0][-1].startswith("*** Falling back to local log")
-
     def test_s3_read_when_log_missing(self):
         handler = self.s3_task_handler
         url = "s3://bucket/foo"
@@ -240,15 +213,11 @@ def test_close_no_upload(self):
             boto3.resource("s3").Object("bucket", self.remote_log_key).get()
 
     @pytest.mark.parametrize(
-        "delete_local_copy, expected_existence_of_local_copy, airflow_version",
-        [(True, False, "2.6.0"), (False, True, "2.6.0"), (True, True, "2.5.0"), (False, True, "2.5.0")],
+        "delete_local_copy, expected_existence_of_local_copy",
+        [(True, False), (False, True)],
     )
-    def test_close_with_delete_local_logs_conf(
-        self, delete_local_copy, expected_existence_of_local_copy, airflow_version
-    ):
-        with conf_vars({("logging", "delete_local_logs"): str(delete_local_copy)}), mock.patch(
-            "airflow.version.version", airflow_version
-        ):
+    def test_close_with_delete_local_logs_conf(self, delete_local_copy, expected_existence_of_local_copy):
+        with conf_vars({("logging", "delete_local_logs"): str(delete_local_copy)}):
             handler = S3TaskHandler(self.local_log_location, self.remote_log_base)
 
         handler.log.info("test")
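Review note: since the config value alone now drives the default, the `airflow_version` parametrization disappears and the matrix shrinks to the two boolean cases. A condensed sketch of the simplified test flow, assuming the suite's `conf_vars` helper; the test name and log locations are illustrative, not the file's actual test:

```python
import pytest

from airflow.providers.amazon.aws.log.s3_task_handler import S3TaskHandler
from tests.test_utils.config import conf_vars


@pytest.mark.parametrize("delete_local_copy", [True, False])
def test_delete_local_copy_follows_config(delete_local_copy):
    # No airflow.version patching needed anymore: the config alone decides.
    with conf_vars({("logging", "delete_local_logs"): str(delete_local_copy)}):
        handler = S3TaskHandler("/tmp/airflow/logs", "s3://bucket/remote/log")
    assert handler.delete_local_copy is delete_local_copy
```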