Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add deferrable mode to KubernetesPodOperator #29017

Merged
merged 3 commits into from Jan 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
160 changes: 160 additions & 0 deletions airflow/providers/cncf/kubernetes/hooks/kubernetes.py
Expand Up @@ -16,19 +16,26 @@
# under the License.
from __future__ import annotations

import contextlib
import tempfile
import warnings
from typing import TYPE_CHECKING, Any, Generator

from asgiref.sync import sync_to_async
from kubernetes import client, config, watch
from kubernetes.client.models import V1Pod
from kubernetes.config import ConfigException
from kubernetes_asyncio import client as async_client, config as async_config
from urllib3.exceptions import HTTPError

from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.kubernetes.kube_client import _disable_verify_ssl, _enable_tcp_keepalive
from airflow.utils import yaml

LOADING_KUBE_CONFIG_FILE_RESOURCE = "Loading Kubernetes configuration file kube_config from {}..."


def _load_body_to_dict(body):
try:
Expand Down Expand Up @@ -396,6 +403,12 @@ def get_pod_logs(
namespace=namespace or self._get_namespace() or self.DEFAULT_NAMESPACE,
)

def get_pod(self, name: str, namespace: str) -> V1Pod:
return self.core_v1_client.read_namespaced_pod(
name=name,
namespace=namespace,
)

def get_namespaced_pod_list(
self,
label_selector: str | None = "",
Expand Down Expand Up @@ -431,3 +444,150 @@ def _get_bool(val) -> bool | None:
elif val.strip().lower() == "false":
return False
return None


class AsyncKubernetesHook(KubernetesHook):
"""Hook to use Kubernetes SDK asynchronously."""

def __init__(self, config_dict: dict | None = None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.config_dict = config_dict

self._extras: dict | None = None

async def _load_config(self):
"""Returns Kubernetes API session for use with requests"""
in_cluster = self._coalesce_param(self.in_cluster, await self._get_field("in_cluster"))
cluster_context = self._coalesce_param(self.cluster_context, await self._get_field("cluster_context"))
kubeconfig = await self._get_field("kube_config")

num_selected_configuration = len([o for o in [in_cluster, kubeconfig, self.config_dict] if o])

if num_selected_configuration > 1:
raise AirflowException(
"Invalid connection configuration. Options kube_config_path, "
"kube_config, in_cluster are mutually exclusive. "
"You can only use one option at a time."
)

if in_cluster:
self.log.debug(LOADING_KUBE_CONFIG_FILE_RESOURCE.format("within a pod"))
self._is_in_cluster = True
async_config.load_incluster_config()
return async_client.ApiClient()

if self.config_dict:
self.log.debug(LOADING_KUBE_CONFIG_FILE_RESOURCE.format("config dictionary"))
await async_config.load_kube_config_from_dict(self.config_dict)
return async_client.ApiClient()

if kubeconfig is not None:
with tempfile.NamedTemporaryFile() as temp_config:
self.log.debug(
"Reading kubernetes configuration file from connection "
"object and writing temporary config file with its content",
)
temp_config.write(kubeconfig.encode())
temp_config.flush()
self._is_in_cluster = False
await async_config.load_kube_config(
config_file=temp_config.name,
client_configuration=self.client_configuration,
context=cluster_context,
)
return async_client.ApiClient()
self.log.debug(LOADING_KUBE_CONFIG_FILE_RESOURCE.format("default configuration file"))
await async_config.load_kube_config(
client_configuration=self.client_configuration,
context=cluster_context,
)

async def get_conn_extras(self) -> dict:
if self._extras is None:
if self.conn_id:
connection = await sync_to_async(self.get_connection)(self.conn_id)
self._extras = connection.extra_dejson
else:
self._extras = {}
return self._extras

async def _get_field(self, field_name):
VladaZakharova marked this conversation as resolved.
Show resolved Hide resolved
if field_name.startswith("extra__"):
raise ValueError(
f"Got prefixed name {field_name}; please remove the 'extra__kubernetes__' prefix "
"when using this method."
)
extras = await self.get_conn_extras()
if field_name in extras:
return extras.get(field_name)
prefixed_name = f"extra__kubernetes__{field_name}"
return extras.get(prefixed_name)

@contextlib.asynccontextmanager
async def get_conn(self) -> async_client.ApiClient:
kube_client = None
try:
kube_client = await self._load_config() or async_client.ApiClient()
yield kube_client
finally:
if kube_client is not None:
await kube_client.close()

async def get_pod(self, name: str, namespace: str) -> V1Pod:
"""
Gets pod's object.

:param name: Name of the pod.
:param namespace: Name of the pod's namespace.
"""
async with self.get_conn() as connection:
v1_api = async_client.CoreV1Api(connection)
pod: V1Pod = await v1_api.read_namespaced_pod(
name=name,
namespace=namespace,
)
return pod

async def delete_pod(self, name: str, namespace: str):
"""
Deletes pod's object.

:param name: Name of the pod.
:param namespace: Name of the pod's namespace.
"""
async with self.get_conn() as connection:
try:
v1_api = async_client.CoreV1Api(connection)
await v1_api.delete_namespaced_pod(
name=name, namespace=namespace, body=client.V1DeleteOptions()
)
except async_client.ApiException as e:
# If the pod is already deleted
if e.status != 404:
raise

async def read_logs(self, name: str, namespace: str):
"""
Reads logs inside the pod while starting containers inside. All the logs will be outputted with its
timestamp to track the logs after the execution of the pod is completed. The method is used for async
output of the logs only in the pod failed it execution or the task was cancelled by the user.

:param name: Name of the pod.
:param namespace: Name of the pod's namespace.
"""
async with self.get_conn() as connection:
try:
v1_api = async_client.CoreV1Api(connection)
logs = await v1_api.read_namespaced_pod_log(
name=name,
namespace=namespace,
follow=False,
timestamps=True,
)
logs = logs.splitlines()
for line in logs:
self.log.info("Container logs from %s", line)
return logs
except HTTPError:
self.log.exception("There was an error reading the kubernetes API.")
raise