Skip to content

Commit

Permalink
Restore airflow_local_settings after the `test_should_use_configured_…
Browse files Browse the repository at this point in the history
…log_name` (#38722)
  • Loading branch information
Taragolis committed Apr 3, 2024
1 parent 537e0e6 commit 40c70e3
Showing 1 changed file with 23 additions and 17 deletions.
40 changes: 23 additions & 17 deletions tests/providers/google/cloud/log/test_stackdriver_task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from airflow.providers.google.cloud.log.stackdriver_task_handler import StackdriverTaskHandler
from airflow.utils import timezone
from airflow.utils.state import TaskInstanceState
from tests.test_utils.config import conf_vars
from tests.test_utils.db import clear_db_dags, clear_db_runs


Expand Down Expand Up @@ -71,29 +72,34 @@ def test_should_pass_message_to_client(mock_client, mock_get_creds_and_project_i
@mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.get_credentials_and_project_id")
@mock.patch("airflow.providers.google.cloud.log.stackdriver_task_handler.gcp_logging.Client")
def test_should_use_configured_log_name(mock_client, mock_get_creds_and_project_id):
mock_get_creds_and_project_id.return_value = ("creds", "project_id")
import importlib
import logging

with mock.patch.dict(
"os.environ",
AIRFLOW__LOGGING__REMOTE_LOGGING="true",
AIRFLOW__LOGGING__REMOTE_BASE_LOG_FOLDER="stackdriver://host/path",
):
import importlib
import logging
from airflow import settings
from airflow.config_templates import airflow_local_settings

from airflow import settings
from airflow.config_templates import airflow_local_settings
mock_get_creds_and_project_id.return_value = ("creds", "project_id")

try:
with conf_vars(
{
("logging", "remote_logging"): "True",
("logging", "remote_base_log_folder"): "stackdriver://host/path",
}
):
importlib.reload(airflow_local_settings)
settings.configure_logging()

logger = logging.getLogger("airflow.task")
handler = logger.handlers[0]
assert isinstance(handler, StackdriverTaskHandler)
with mock.patch.object(handler, "transport_type") as transport_type_mock:
logger.error("foo")
transport_type_mock.assert_called_once_with(mock_client.return_value, "path")
finally:
importlib.reload(airflow_local_settings)
settings.configure_logging()

logger = logging.getLogger("airflow.task")
handler = logger.handlers[0]
assert isinstance(handler, StackdriverTaskHandler)
with mock.patch.object(handler, "transport_type") as transport_type_mock:
logger.error("foo")
transport_type_mock.assert_called_once_with(mock_client.return_value, "path")


@pytest.mark.db_test
class TestStackdriverLoggingHandlerTask:
Expand Down

0 comments on commit 40c70e3

Please sign in to comment.