Skip to content
Permalink
Browse files
More typing and minor refactor for kubernetes (#24719)
  • Loading branch information
jedcunningham committed Jun 29, 2022
1 parent 737e2f5 commit 9d307102b4a604034d9b1d7f293884821263575f
Showing 4 changed files with 26 additions and 23 deletions.
@@ -42,7 +42,7 @@
MAX_LABEL_LEN = 63


def make_safe_label_value(string):
def make_safe_label_value(string: str) -> str:
"""
Valid label values must be 63 characters or less and must be empty or begin and
end with an alphanumeric character ([a-z0-9A-Z]) with dashes (-), underscores (_),
@@ -169,7 +169,7 @@ def _deprecation_warning_core_param(deprecation_warnings):
DeprecationWarning,
)

def get_conn(self) -> Any:
def get_conn(self) -> client.ApiClient:
"""Returns kubernetes api session for use with requests"""
in_cluster = self._coalesce_param(
self.in_cluster, self.conn_extras.get("extra__kubernetes__in_cluster") or None
@@ -258,7 +258,7 @@ def get_conn(self) -> Any:

return self._get_default_client(cluster_context=cluster_context)

def _get_default_client(self, *, cluster_context=None):
def _get_default_client(self, *, cluster_context: Optional[str] = None) -> client.ApiClient:
# if we get here, then no configuration has been supplied
# we should try in_cluster since that's most likely
# but failing that just load assuming a kubeconfig file
@@ -276,20 +276,21 @@ def _get_default_client(self, *, cluster_context=None):
return client.ApiClient()

@property
def is_in_cluster(self):
def is_in_cluster(self) -> bool:
"""Expose whether the hook is configured with ``load_incluster_config`` or not"""
if self._is_in_cluster is not None:
return self._is_in_cluster
self.api_client # so we can determine if we are in_cluster or not
assert self._is_in_cluster is not None
return self._is_in_cluster

@cached_property
def api_client(self) -> Any:
def api_client(self) -> client.ApiClient:
"""Cached Kubernetes API client"""
return self.get_conn()

@cached_property
def core_v1_client(self):
def core_v1_client(self) -> client.CoreV1Api:
return client.CoreV1Api(api_client=self.api_client)

def create_custom_object(
@@ -377,12 +378,11 @@ def get_pod_log_stream(
:param container: container name
:param namespace: kubernetes namespace
"""
api = client.CoreV1Api(self.api_client)
watcher = watch.Watch()
return (
watcher,
watcher.stream(
api.read_namespaced_pod_log,
self.core_v1_client.read_namespaced_pod_log,
name=pod_name,
container=container,
namespace=namespace if namespace else self.get_namespace(),
@@ -402,8 +402,7 @@ def get_pod_logs(
:param container: container name
:param namespace: kubernetes namespace
"""
api = client.CoreV1Api(self.api_client)
return api.read_namespaced_pod_log(
return self.core_v1_client.read_namespaced_pod_log(
name=pod_name,
container=container,
_preload_content=False,
@@ -300,7 +300,9 @@ def _render_nested_template_fields(
super()._render_nested_template_fields(content, context, jinja_env, seen_oids)

@staticmethod
def _get_ti_pod_labels(context: Optional[dict] = None, include_try_number: bool = True) -> dict:
def _get_ti_pod_labels(
context: Optional['Context'] = None, include_try_number: bool = True
) -> Dict[str, str]:
"""
Generate labels for the pod to track the pod in case of Operator crash
@@ -360,7 +362,9 @@ def hook(self) -> KubernetesHook:
def client(self) -> CoreV1Api:
return self.hook.core_v1_client

def find_pod(self, namespace, context, *, exclude_checked=True) -> Optional[k8s.V1Pod]:
def find_pod(
self, namespace: str, context: 'Context', *, exclude_checked: bool = True
) -> Optional[k8s.V1Pod]:
"""Returns an already-running pod for this task instance if one exists."""
label_selector = self._build_find_pod_label_selector(context, exclude_checked=exclude_checked)
pod_list = self.client.list_namespaced_pod(
@@ -379,7 +383,7 @@ def find_pod(self, namespace, context, *, exclude_checked=True) -> Optional[k8s.
self.log.info("`try_number` of pod: %s", pod.metadata.labels['try_number'])
return pod

def get_or_create_pod(self, pod_request_obj: k8s.V1Pod, context):
def get_or_create_pod(self, pod_request_obj: k8s.V1Pod, context: 'Context') -> k8s.V1Pod:
if self.reattach_on_restart:
pod = self.find_pod(self.namespace or pod_request_obj.metadata.namespace, context=context)
if pod:
@@ -388,7 +392,7 @@ def get_or_create_pod(self, pod_request_obj: k8s.V1Pod, context):
self.pod_manager.create_pod(pod=pod_request_obj)
return pod_request_obj

def await_pod_start(self, pod):
def await_pod_start(self, pod: k8s.V1Pod):
try:
self.pod_manager.await_pod_start(pod=pod, startup_timeout=self.startup_timeout_seconds)
except PodLaunchFailedException:
@@ -397,7 +401,7 @@ def await_pod_start(self, pod):
self.log.error("Pod Event: %s - %s", event.reason, event.message)
raise

def extract_xcom(self, pod):
def extract_xcom(self, pod: k8s.V1Pod):
"""Retrieves xcom value and kills xcom sidecar container"""
result = self.pod_manager.extract_xcom(pod)
self.log.info("xcom result: \n%s", result)
@@ -461,15 +465,17 @@ def cleanup(self, pod: k8s.V1Pod, remote_pod: k8s.V1Pod):
with _suppress(Exception):
self.process_pod_deletion(remote_pod)

def process_pod_deletion(self, pod):
def process_pod_deletion(self, pod: k8s.V1Pod):
if pod is not None:
if self.is_delete_operator_pod:
self.log.info("Deleting pod: %s", pod.metadata.name)
self.pod_manager.delete_pod(pod)
else:
self.log.info("skipping deleting pod: %s", pod.metadata.name)

def _build_find_pod_label_selector(self, context: Optional[dict] = None, *, exclude_checked=True) -> str:
def _build_find_pod_label_selector(
self, context: Optional['Context'] = None, *, exclude_checked=True
) -> str:
labels = self._get_ti_pod_labels(context, include_try_number=False)
label_strings = [f'{label_id}={label}' for label_id, label in sorted(labels.items())]
labels_value = ','.join(label_strings)
@@ -478,7 +484,7 @@ def _build_find_pod_label_selector(self, context: Optional[dict] = None, *, excl
labels_value += ',!airflow-worker'
return labels_value

def _set_name(self, name):
def _set_name(self, name: Optional[str]) -> Optional[str]:
if name is None:
if self.pod_template_file or self.full_pod_spec:
return None
@@ -504,7 +510,7 @@ def on_kill(self) -> None:
kwargs.update(grace_period_seconds=self.termination_grace_period)
self.client.delete_namespaced_pod(**kwargs)

def build_pod_request_obj(self, context=None):
def build_pod_request_obj(self, context: Optional['Context'] = None) -> k8s.V1Pod:
"""
Returns V1Pod object based on pod template file, full pod spec, and other operator parameters.
@@ -106,10 +106,8 @@ def test_in_cluster_connection(
else:
mock_get_default_client.assert_called()
assert isinstance(api_conn, kubernetes.client.api_client.ApiClient)
if mock_get_default_client.called:
# get_default_client sets it, but it's mocked
assert kubernetes_hook.is_in_cluster is None
else:
if not mock_get_default_client.called:
# get_default_client is mocked, so only check is_in_cluster if it isn't called
assert kubernetes_hook.is_in_cluster is in_cluster_called

@pytest.mark.parametrize('in_cluster_fails', [True, False])

0 comments on commit 9d30710

Please sign in to comment.