Skip to content

Commit

Permalink
Add more type hints to PodLauncher (#18928)
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb committed Oct 13, 2021
1 parent a256af0 commit b2045d6
Showing 1 changed file with 17 additions and 15 deletions.
32 changes: 17 additions & 15 deletions airflow/providers/cncf/kubernetes/utils/pod_launcher.py
Expand Up @@ -19,11 +19,13 @@
import math
import time
from datetime import datetime as dt
from typing import Optional, Tuple, Union
from typing import Iterable, Optional, Tuple, Union

import pendulum
import tenacity
from kubernetes import client, watch
from kubernetes.client.models.v1_event import V1Event
from kubernetes.client.models.v1_event_list import V1EventList
from kubernetes.client.models.v1_pod import V1Pod
from kubernetes.client.rest import ApiException
from kubernetes.stream import stream as kubernetes_stream
Expand All @@ -39,7 +41,7 @@
from airflow.utils.state import State


def should_retry_start_pod(exception: Exception):
def should_retry_start_pod(exception: Exception) -> bool:
"""Check if an Exception indicates a transient error and warrants retrying"""
if isinstance(exception, ApiException):
return exception.status == 409
Expand Down Expand Up @@ -78,7 +80,7 @@ def __init__(
self._watch = watch.Watch()
self.extract_xcom = extract_xcom

def run_pod_async(self, pod: V1Pod, **kwargs):
def run_pod_async(self, pod: V1Pod, **kwargs) -> V1Pod:
"""Runs POD asynchronously"""
pod_mutation_hook(pod)

Expand All @@ -98,7 +100,7 @@ def run_pod_async(self, pod: V1Pod, **kwargs):
raise e
return resp

def delete_pod(self, pod: V1Pod):
def delete_pod(self, pod: V1Pod) -> None:
"""Deletes POD"""
try:
self._client.delete_namespaced_pod(
Expand All @@ -115,7 +117,7 @@ def delete_pod(self, pod: V1Pod):
reraise=True,
retry=tenacity.retry_if_exception(should_retry_start_pod),
)
def start_pod(self, pod: V1Pod, startup_timeout: int = 120):
def start_pod(self, pod: V1Pod, startup_timeout: int = 120) -> None:
"""
Launches the pod synchronously and waits for completion.
Expand Down Expand Up @@ -210,22 +212,22 @@ def parse_log_line(self, line: str) -> Tuple[Optional[Union[Date, Time, DateTime
return None, line
return last_log_time, message

def _task_status(self, event):
def _task_status(self, event: V1Event) -> str:
self.log.info('Event: %s had an event of type %s', event.metadata.name, event.status.phase)
status = self.process_status(event.metadata.name, event.status.phase)
return status

def pod_not_started(self, pod: V1Pod):
def pod_not_started(self, pod: V1Pod) -> bool:
"""Tests if pod has not started"""
state = self._task_status(self.read_pod(pod))
return state == State.QUEUED

def pod_is_running(self, pod: V1Pod):
def pod_is_running(self, pod: V1Pod) -> bool:
"""Tests if pod is running"""
state = self._task_status(self.read_pod(pod))
return state not in (State.SUCCESS, State.FAILED)

def base_container_is_running(self, pod: V1Pod):
def base_container_is_running(self, pod: V1Pod) -> bool:
"""Tests if base container is running"""
event = self.read_pod(pod)
status = next(iter(filter(lambda s: s.name == 'base', event.status.container_statuses)), None)
Expand All @@ -240,7 +242,7 @@ def read_pod_logs(
tail_lines: Optional[int] = None,
timestamps: bool = False,
since_seconds: Optional[int] = None,
):
) -> Iterable[str]:
"""Reads log from the POD"""
additional_kwargs = {}
if since_seconds:
Expand All @@ -265,7 +267,7 @@ def read_pod_logs(
raise

@tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(), reraise=True)
def read_pod_events(self, pod):
def read_pod_events(self, pod: V1Pod) -> V1EventList:
"""Reads events from the POD"""
try:
return self._client.list_namespaced_event(
Expand All @@ -275,14 +277,14 @@ def read_pod_events(self, pod):
raise AirflowException(f'There was an error reading the kubernetes API: {e}')

@tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(), reraise=True)
def read_pod(self, pod: V1Pod):
def read_pod(self, pod: V1Pod) -> V1Pod:
"""Read POD information"""
try:
return self._client.read_namespaced_pod(pod.metadata.name, pod.metadata.namespace)
except BaseHTTPError as e:
raise AirflowException(f'There was an error reading the kubernetes API: {e}')

def _extract_xcom(self, pod: V1Pod):
def _extract_xcom(self, pod: V1Pod) -> str:
resp = kubernetes_stream(
self._client.connect_get_namespaced_pod_exec,
pod.metadata.name,
Expand All @@ -304,7 +306,7 @@ def _extract_xcom(self, pod: V1Pod):
raise AirflowException(f'Failed to extract xcom from pod: {pod.metadata.name}')
return result

def _exec_pod_command(self, resp, command):
def _exec_pod_command(self, resp, command: str) -> None:
if resp.is_open():
self.log.info('Running command... %s\n', command)
resp.write_stdin(command + '\n')
Expand All @@ -317,7 +319,7 @@ def _exec_pod_command(self, resp, command):
break
return None

def process_status(self, job_id, status):
def process_status(self, job_id: str, status: str) -> str:
"""Process status information for the JOB"""
status = status.lower()
if status == PodStatus.PENDING:
Expand Down

0 comments on commit b2045d6

Please sign in to comment.