Fix four bugs in StackdriverTaskHandler (#13784)
mik-laj committed Feb 2, 2021
1 parent d65376c commit 833e338
Showing 5 changed files with 95 additions and 45 deletions.
25 changes: 18 additions & 7 deletions airflow/providers/google/cloud/log/stackdriver_task_handler.py
@@ -99,6 +99,7 @@ def __init__(
self.resource: Resource = resource
self.labels: Optional[Dict[str, str]] = labels
self.task_instance_labels: Optional[Dict[str, str]] = {}
self.task_instance_hostname = 'default-hostname'

@cached_property
def _client(self) -> gcp_logging.Client:
@@ -146,10 +147,11 @@ def set_context(self, task_instance: TaskInstance) -> None:
:type task_instance: :class:`airflow.models.TaskInstance`
"""
self.task_instance_labels = self._task_instance_to_labels(task_instance)
self.task_instance_hostname = task_instance.hostname

def read(
self, task_instance: TaskInstance, try_number: Optional[int] = None, metadata: Optional[Dict] = None
) -> Tuple[List[str], List[Dict]]:
) -> Tuple[List[Tuple[Tuple[str, str]]], List[Dict[str, str]]]:
"""
Read logs of given task instance from Stackdriver logging.
@@ -160,12 +162,14 @@ def read(
:type try_number: Optional[int]
:param metadata: log metadata. It is used for streaming log reading and auto-tailing.
:type metadata: Dict
:return: a tuple of list of logs and list of metadata
:rtype: Tuple[List[str], List[Dict]]
:return: a tuple of (a list of one-element tuples, each holding a
(hostname, logs) pair, and a list of metadata dicts)
:rtype: Tuple[List[Tuple[Tuple[str, str]]], List[Dict[str, str]]]
"""
if try_number is not None and try_number < 1:
logs = [f"Error fetching the logs. Try number {try_number} is invalid."]
return logs, [{"end_of_log": "true"}]
logs = f"Error fetching the logs. Try number {try_number} is invalid."
return [((self.task_instance_hostname, logs),)], [{"end_of_log": "true"}]

if not metadata:
metadata = {}
@@ -188,7 +192,7 @@ def read(
if next_page_token:
new_metadata['next_page_token'] = next_page_token

return [messages], [new_metadata]
return [((self.task_instance_hostname, messages),)], [new_metadata]

def _prepare_log_filter(self, ti_labels: Dict[str, str]) -> str:
"""
@@ -252,6 +256,8 @@ def _read_logs(
log_filter=log_filter, page_token=next_page_token
)
messages.append(new_messages)
if not messages:
break

end_of_log = True
next_page_token = None
@@ -271,7 +277,9 @@ def _read_single_logs_page(self, log_filter: str, page_token: Optional[str] = No
:return: Downloaded logs and next page token
:rtype: Tuple[str, str]
"""
entries = self._client.list_entries(filter_=log_filter, page_token=page_token)
entries = self._client.list_entries(
filter_=log_filter, page_token=page_token, order_by='timestamp asc', page_size=1000
)
page = next(entries.pages)
next_page_token = entries.next_page_token
messages = []
@@ -331,3 +339,6 @@ def get_external_log_url(self, task_instance: TaskInstance, try_number: int) ->

url = f"{self.LOG_VIEWER_BASE_URL}?{urlencode(url_query_string)}"
return url

def close(self) -> None:
self._transport.flush()
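Taken together, the four fixes in this file appear to be: (1) read() now returns (hostname, log) pairs in the nested shape Airflow's log reader expects, instead of bare strings, with the hostname captured in set_context(); (2) list_entries() is called with order_by='timestamp asc' and page_size=1000 so entries come back in a stable order and paginate predictably; (3) the paging loop in _read_logs gains an early break; and (4) a close() method flushes the background transport so buffered records are not lost at shutdown. Below is a minimal sketch, not part of the commit, of how a caller might consume the new read() return shape; the hostname and messages are invented for illustration.

from typing import Dict, List, Tuple

def consume_read_result(
    logs: List[Tuple[Tuple[str, str]]], metadata: List[Dict[str, str]]
) -> None:
    # Each element of `logs` is a one-element tuple holding (hostname, log_text);
    # the matching metadata dict may carry 'end_of_log' and 'next_page_token'.
    for entry, meta in zip(logs, metadata):
        for hostname, log_text in entry:
            print(f"[{hostname}]\n{log_text}")
        if not meta.get("end_of_log"):
            print(f"more pages available, resume from {meta.get('next_page_token')}")

consume_read_result(
    [(("default-hostname", "MSG1\nMSG2"),)],
    [{"end_of_log": "true"}],
)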
6 changes: 3 additions & 3 deletions airflow/utils/log/log_reader.py
@@ -16,7 +16,7 @@
# under the License.

import logging
from typing import Any, Dict, Iterator, List, Optional, Tuple
from typing import Dict, Iterator, List, Optional, Tuple

from cached_property import cached_property

@@ -31,7 +31,7 @@ class TaskLogReader:

def read_log_chunks(
self, ti: TaskInstance, try_number: Optional[int], metadata
) -> Tuple[List[str], Dict[str, Any]]:
) -> Tuple[List[Tuple[Tuple[str, str]]], Dict[str, str]]:
"""
Reads chunks of Task Instance logs.
@@ -42,7 +42,7 @@ def read_log_chunks(
:type try_number: Optional[int]
:param metadata: A dictionary containing information about how to read the task log
:type metadata: dict
:rtype: Tuple[List[str], Dict[str, Any]]
:rtype: Tuple[List[Tuple[Tuple[str, str]]], Dict[str, str]]
The following is an example of how to use this method to read log:
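(The docstring's actual usage example is collapsed in this view; the lines below are a rough sketch of such a caller under the new annotation, not the collapsed content. `ti` is assumed to be an existing TaskInstance.)

from airflow.utils.log.log_reader import TaskLogReader

task_log_reader = TaskLogReader()
logs, metadata = task_log_reader.read_log_chunks(ti, try_number=1, metadata={})
for chunk in logs:                    # List[Tuple[Tuple[str, str]]]
    for hostname, log_text in chunk:  # each chunk holds (hostname, log) pairs
        print(f"*** Reading from {hostname}\n{log_text}")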
3 changes: 3 additions & 0 deletions tests/cli/commands/test_info_command.py
@@ -18,6 +18,7 @@
import contextlib
import importlib
import io
import logging
import os
import unittest
from unittest import mock
@@ -129,6 +130,8 @@ def test_should_read_logging_configuration(self):
assert "stackdriver" in text

def tearDown(self) -> None:
for handler_ref in logging._handlerList[:]:
logging._removeHandlerRef(handler_ref)
importlib.reload(airflow_local_settings)
configure_logging()

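For context on the tearDown change above: reloading airflow_local_settings and calling configure_logging() creates fresh handlers, but logging._handlerList still holds weak references to the old ones (including any StackdriverTaskHandler), and logging.shutdown() will try to flush them at interpreter exit. A sketch of the same cleanup as a standalone helper, relying on the same private CPython logging internals the test uses; the helper name is invented:

import logging

def drop_dangling_handler_refs() -> None:
    # Iterate over a copy: _removeHandlerRef() mutates the module-level list.
    for handler_ref in logging._handlerList[:]:
        logging._removeHandlerRef(handler_ref)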
98 changes: 65 additions & 33 deletions tests/providers/google/cloud/log/test_stackdriver_task_handler.py
@@ -35,10 +35,21 @@ def _create_list_response(messages, token):
return mock.MagicMock(pages=(n for n in [page]), next_page_token=token)


def _remove_stackdriver_handlers():
for handler_ref in reversed(logging._handlerList[:]):
handler = handler_ref()
if not isinstance(handler, StackdriverTaskHandler):
continue
logging._removeHandlerRef(handler_ref)
del handler


class TestStackdriverLoggingHandlerStandalone(unittest.TestCase):
@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id')
@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client')
def test_should_pass_message_to_client(self, mock_client, mock_get_creds_and_project_id):
self.addCleanup(_remove_stackdriver_handlers)

mock_get_creds_and_project_id.return_value = ('creds', 'project_id')

transport_type = mock.MagicMock()
@@ -69,6 +80,7 @@ def setUp(self) -> None:
self.ti.try_number = 1
self.ti.state = State.RUNNING
self.addCleanup(self.dag.clear)
self.addCleanup(_remove_stackdriver_handlers)

@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id')
@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client')
@@ -128,14 +140,18 @@ def test_should_read_logs_for_all_try(self, mock_client, mock_get_creds_and_proj

logs, metadata = self.stackdriver_task_handler.read(self.ti)
mock_client.return_value.list_entries.assert_called_once_with(
filter_='resource.type="global"\n'
'logName="projects/asf-project/logs/airflow"\n'
'labels.task_id="task_for_testing_file_log_handler"\n'
'labels.dag_id="dag_for_testing_file_task_handler"\n'
'labels.execution_date="2016-01-01T00:00:00+00:00"',
filter_=(
'resource.type="global"\n'
'logName="projects/asf-project/logs/airflow"\n'
'labels.task_id="task_for_testing_file_log_handler"\n'
'labels.dag_id="dag_for_testing_file_task_handler"\n'
'labels.execution_date="2016-01-01T00:00:00+00:00"'
),
order_by='timestamp asc',
page_size=1000,
page_token=None,
)
assert ['MSG1\nMSG2'] == logs
assert [(('default-hostname', 'MSG1\nMSG2'),)] == logs
assert [{'end_of_log': True}] == metadata

@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id')
@@ -149,14 +165,18 @@ def test_should_read_logs_for_task_with_quote(self, mock_client, mock_get_creds_
self.ti.task_id = "K\"OT"
logs, metadata = self.stackdriver_task_handler.read(self.ti)
mock_client.return_value.list_entries.assert_called_once_with(
filter_='resource.type="global"\n'
'logName="projects/asf-project/logs/airflow"\n'
'labels.task_id="K\\"OT"\n'
'labels.dag_id="dag_for_testing_file_task_handler"\n'
'labels.execution_date="2016-01-01T00:00:00+00:00"',
filter_=(
'resource.type="global"\n'
'logName="projects/asf-project/logs/airflow"\n'
'labels.task_id="K\\"OT"\n'
'labels.dag_id="dag_for_testing_file_task_handler"\n'
'labels.execution_date="2016-01-01T00:00:00+00:00"'
),
order_by='timestamp asc',
page_size=1000,
page_token=None,
)
assert ['MSG1\nMSG2'] == logs
assert [(('default-hostname', 'MSG1\nMSG2'),)] == logs
assert [{'end_of_log': True}] == metadata

@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id')
@@ -170,15 +190,19 @@ def test_should_read_logs_for_single_try(self, mock_client, mock_get_creds_and_p

logs, metadata = self.stackdriver_task_handler.read(self.ti, 3)
mock_client.return_value.list_entries.assert_called_once_with(
filter_='resource.type="global"\n'
'logName="projects/asf-project/logs/airflow"\n'
'labels.task_id="task_for_testing_file_log_handler"\n'
'labels.dag_id="dag_for_testing_file_task_handler"\n'
'labels.execution_date="2016-01-01T00:00:00+00:00"\n'
'labels.try_number="3"',
filter_=(
'resource.type="global"\n'
'logName="projects/asf-project/logs/airflow"\n'
'labels.task_id="task_for_testing_file_log_handler"\n'
'labels.dag_id="dag_for_testing_file_task_handler"\n'
'labels.execution_date="2016-01-01T00:00:00+00:00"\n'
'labels.try_number="3"'
),
order_by='timestamp asc',
page_size=1000,
page_token=None,
)
assert ['MSG1\nMSG2'] == logs
assert [(('default-hostname', 'MSG1\nMSG2'),)] == logs
assert [{'end_of_log': True}] == metadata

@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id')
@@ -190,14 +214,18 @@ def test_should_read_logs_with_pagination(self, mock_client, mock_get_creds_and_
]
mock_get_creds_and_project_id.return_value = ('creds', 'project_id')
logs, metadata1 = self.stackdriver_task_handler.read(self.ti, 3)
mock_client.return_value.list_entries.assert_called_once_with(filter_=mock.ANY, page_token=None)
assert ['MSG1\nMSG2'] == logs
mock_client.return_value.list_entries.assert_called_once_with(
filter_=mock.ANY, order_by='timestamp asc', page_size=1000, page_token=None
)
assert [(('default-hostname', 'MSG1\nMSG2'),)] == logs
assert [{'end_of_log': False, 'next_page_token': 'TOKEN1'}] == metadata1

mock_client.return_value.list_entries.return_value.next_page_token = None
logs, metadata2 = self.stackdriver_task_handler.read(self.ti, 3, metadata1[0])
mock_client.return_value.list_entries.assert_called_with(filter_=mock.ANY, page_token="TOKEN1")
assert ['MSG3\nMSG4'] == logs
mock_client.return_value.list_entries.assert_called_with(
filter_=mock.ANY, order_by='timestamp asc', page_size=1000, page_token="TOKEN1"
)
assert [(('default-hostname', 'MSG3\nMSG4'),)] == logs
assert [{'end_of_log': True}] == metadata2

@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id')
@@ -211,7 +239,7 @@ def test_should_read_logs_with_download(self, mock_client, mock_get_creds_and_pr

logs, metadata1 = self.stackdriver_task_handler.read(self.ti, 3, {'download_logs': True})

assert ['MSG1\nMSG2\nMSG3\nMSG4'] == logs
assert [(('default-hostname', 'MSG1\nMSG2\nMSG3\nMSG4'),)] == logs
assert [{'end_of_log': True}] == metadata1

@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id')
@@ -240,17 +268,21 @@ def test_should_read_logs_with_custom_resources(self, mock_client, mock_get_cred

logs, metadata = self.stackdriver_task_handler.read(self.ti)
mock_client.return_value.list_entries.assert_called_once_with(
filter_='resource.type="cloud_composer_environment"\n'
'logName="projects/asf-project/logs/airflow"\n'
'resource.labels."environment.name"="test-instancce"\n'
'resource.labels.location="europpe-west-3"\n'
'resource.labels.project_id="asf-project"\n'
'labels.task_id="task_for_testing_file_log_handler"\n'
'labels.dag_id="dag_for_testing_file_task_handler"\n'
'labels.execution_date="2016-01-01T00:00:00+00:00"',
filter_=(
'resource.type="cloud_composer_environment"\n'
'logName="projects/asf-project/logs/airflow"\n'
'resource.labels."environment.name"="test-instancce"\n'
'resource.labels.location="europpe-west-3"\n'
'resource.labels.project_id="asf-project"\n'
'labels.task_id="task_for_testing_file_log_handler"\n'
'labels.dag_id="dag_for_testing_file_task_handler"\n'
'labels.execution_date="2016-01-01T00:00:00+00:00"'
),
order_by='timestamp asc',
page_size=1000,
page_token=None,
)
assert ['TEXT\nTEXT'] == logs
assert [(('default-hostname', 'TEXT\nTEXT'),)] == logs
assert [{'end_of_log': True}] == metadata

@mock.patch('airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id')
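One detail the quote-handling test above (task_id K"OT) pins down: label values in the generated filter are wrapped in double quotes with any embedded quotes backslash-escaped. A hypothetical helper, not the handler's actual implementation, showing just that escaping rule:

def escape_label_value(value: str) -> str:
    # Wrap in double quotes and escape embedded quotes, matching the expected
    # filter fragment labels.task_id="K\"OT" asserted in the test.
    return '"{}"'.format(value.replace('"', '\\"'))

assert escape_label_value('K"OT') == '"K\\"OT"'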
@@ -62,7 +62,7 @@ def test_should_support_key_auth(self, session):
assert 0 == subprocess.Popen(["airflow", "scheduler", "--num-runs", "1"]).wait()
ti = session.query(TaskInstance).filter(TaskInstance.task_id == "create_entry_group").first()

self.assert_remote_logs("INFO - Task exited with return code 0", ti)
self.assert_remote_logs("terminated with exit code 0", ti)

@provide_session
def test_should_support_adc(self, session):
@@ -78,7 +78,7 @@ def test_should_support_adc(self, session):
assert 0 == subprocess.Popen(["airflow", "scheduler", "--num-runs", "1"]).wait()
ti = session.query(TaskInstance).filter(TaskInstance.task_id == "create_entry_group").first()

self.assert_remote_logs("INFO - Task exited with return code 0", ti)
self.assert_remote_logs("terminated with exit code 0", ti)

def assert_remote_logs(self, expected_message, ti):
with provide_gcp_context(GCP_STACKDRIVER), conf_vars(
Expand All @@ -94,4 +94,8 @@ def assert_remote_logs(self, expected_message, ti):

task_log_reader = TaskLogReader()
logs = "\n".join(task_log_reader.read_log_stream(ti, try_number=None, metadata={}))
# Preview content
print("=" * 80)
print(logs)
print("=" * 80)
assert expected_message in logs
