Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug in LivyOperator when its trigger times out #38916

Merged
merged 1 commit into from
Apr 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions airflow/providers/apache/livy/operators/livy.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def __init__(
self._extra_options = extra_options or {}
self._extra_headers = extra_headers or {}

self._batch_id: int | str
self._batch_id: int | str | None = None
self.retry_args = retry_args
self.deferrable = deferrable

Expand Down Expand Up @@ -170,6 +170,7 @@ def execute(self, context: Context) -> Any:
polling_interval=self._polling_interval,
extra_options=self._extra_options,
extra_headers=self._extra_headers,
execution_timeout=self.execution_timeout,
),
method_name="execute_complete",
)
Expand Down Expand Up @@ -217,8 +218,12 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
for log_line in event["log_lines"]:
self.log.info(log_line)

if event["status"] == "error":
if event["status"] == "timeout":
self.hook.delete_batch(event["batch_id"])

if event["status"] in ["error", "timeout"]:
raise AirflowException(event["response"])

self.log.info(
"%s completed with response %s",
self.task_id,
Expand Down
27 changes: 26 additions & 1 deletion airflow/providers/apache/livy/triggers/livy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from __future__ import annotations

import asyncio
from datetime import datetime, timedelta, timezone
from typing import Any, AsyncIterator

from airflow.providers.apache.livy.hooks.livy import BatchState, LivyAsyncHook
Expand Down Expand Up @@ -54,6 +55,7 @@ def __init__(
extra_options: dict[str, Any] | None = None,
extra_headers: dict[str, Any] | None = None,
livy_hook_async: LivyAsyncHook | None = None,
execution_timeout: timedelta | None = None,
):
super().__init__()
self._batch_id = batch_id
Expand All @@ -63,6 +65,7 @@ def __init__(
self._extra_options = extra_options
self._extra_headers = extra_headers
self._livy_hook_async = livy_hook_async
self._execution_timeout = execution_timeout

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize LivyTrigger arguments and classpath."""
Expand All @@ -76,6 +79,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"extra_options": self._extra_options,
"extra_headers": self._extra_headers,
"livy_hook_async": self._livy_hook_async,
"execution_timeout": self._execution_timeout,
},
)

Expand Down Expand Up @@ -113,16 +117,37 @@ async def poll_for_termination(self, batch_id: int | str) -> dict[str, Any]:

:param batch_id: id of the batch session to monitor.
"""
if self._execution_timeout is not None:
timeout_datetime = datetime.now(timezone.utc) + self._execution_timeout
else:
timeout_datetime = None
batch_execution_timed_out = False
hook = self._get_async_hook()
state = await hook.get_batch_state(batch_id)
self.log.info("Batch with id %s is in state: %s", batch_id, state["batch_state"].value)
while state["batch_state"] not in hook.TERMINAL_STATES:
self.log.info("Batch with id %s is in state: %s", batch_id, state["batch_state"].value)
batch_execution_timed_out = (
timeout_datetime is not None and datetime.now(timezone.utc) > timeout_datetime
)
if batch_execution_timed_out:
break
self.log.info("Sleeping for %s seconds", self._polling_interval)
await asyncio.sleep(self._polling_interval)
state = await hook.get_batch_state(batch_id)
self.log.info("Batch with id %s terminated with state: %s", batch_id, state["batch_state"].value)
log_lines = await hook.dump_batch_logs(batch_id)
if batch_execution_timed_out:
self.log.info(
"Batch with id %s did not terminate, but it reached execution timeout.",
batch_id,
)
return {
"status": "timeout",
"batch_id": batch_id,
"response": f"Batch {batch_id} timed out",
"log_lines": log_lines,
}
self.log.info("Batch with id %s terminated with state: %s", batch_id, state["batch_state"].value)
if state["batch_state"] != BatchState.SUCCESS:
return {
"status": "error",
Expand Down
37 changes: 37 additions & 0 deletions tests/providers/apache/livy/operators/test_livy.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,19 @@ def test_execution_with_extra_options_deferrable(
task.execute(context=self.mock_context)
assert task.hook.extra_options == extra_options

@patch("airflow.providers.apache.livy.operators.livy.LivyHook.delete_batch")
def test_when_kill_is_called_right_after_construction_it_should_not_raise_attribute_error(
self, mock_delete_batch
):
task = LivyOperator(
livy_conn_id="livyunittest",
file="sparkapp",
dag=self.dag,
task_id="livy_example",
)
task.kill()
mock_delete_batch.assert_not_called()

@patch("airflow.providers.apache.livy.operators.livy.LivyHook.delete_batch")
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID)
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.get_batch", return_value=GET_BATCH)
Expand Down Expand Up @@ -380,6 +393,30 @@ def test_execute_complete_error(self, mock_post):
)
self.mock_context["ti"].xcom_push.assert_not_called()

@patch("airflow.providers.apache.livy.operators.livy.LivyHook.post_batch", return_value=BATCH_ID)
@patch("airflow.providers.apache.livy.operators.livy.LivyHook.delete_batch")
def test_execute_complete_timeout(self, mock_delete, mock_post):
task = LivyOperator(
livy_conn_id="livyunittest",
file="sparkapp",
dag=self.dag,
task_id="livy_example",
polling_interval=1,
deferrable=True,
)
with pytest.raises(AirflowException):
task.execute_complete(
context=self.mock_context,
event={
"status": "timeout",
"log_lines": ["mock log"],
"batch_id": BATCH_ID,
"response": "mock timeout",
},
)
mock_delete.assert_called_once_with(BATCH_ID)
self.mock_context["ti"].xcom_push.assert_not_called()


@pytest.mark.db_test
def test_spark_params_templating(create_task_instance_of_operator):
Expand Down
30 changes: 30 additions & 0 deletions tests/providers/apache/livy/triggers/test_livy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import asyncio
from datetime import timedelta
from unittest import mock

import pytest
Expand Down Expand Up @@ -46,6 +47,7 @@ def test_livy_trigger_serialization(self):
"extra_options": None,
"extra_headers": None,
"livy_hook_async": None,
"execution_timeout": None,
}

@pytest.mark.asyncio
Expand Down Expand Up @@ -195,3 +197,31 @@ async def test_livy_trigger_poll_for_termination_state(self, mock_dump_batch_log
# TriggerEvent was not returned
assert task.done() is False
asyncio.get_event_loop().stop()

@pytest.mark.asyncio
@mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.get_batch_state")
@mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook.dump_batch_logs")
async def test_livy_trigger_poll_for_termination_timeout(
self, mock_dump_batch_logs, mock_get_batch_state
):
"""
Test if poll_for_termination() returns timeout response when execution times out.
"""
mock_get_batch_state.return_value = {"batch_state": BatchState.RUNNING}
mock_dump_batch_logs.return_value = ["mock_log"]
trigger = LivyTrigger(
batch_id=1,
spark_params={},
livy_conn_id=LivyHook.default_conn_name,
polling_interval=1,
execution_timeout=timedelta(seconds=0),
)

task = await trigger.poll_for_termination(1)

assert task == {
"status": "timeout",
"batch_id": 1,
"response": "Batch 1 timed out",
"log_lines": ["mock_log"],
}