Skip to content

Commit

Permalink
Remove duplicated methods in K8S pod operator module and import them …
Browse files Browse the repository at this point in the history
…from helper function (#36427)

* Remove duplicated methods in K8S pod operator module and import them from helper function

* Clean the tests
  • Loading branch information
hussein-awala committed Dec 26, 2023
1 parent 7bd998e commit af9328e
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 73 deletions.
67 changes: 7 additions & 60 deletions airflow/providers/cncf/kubernetes/operators/pod.py
Expand Up @@ -21,7 +21,6 @@
import json
import logging
import re
import secrets
import shlex
import string
import warnings
Expand All @@ -32,7 +31,6 @@

from kubernetes.client import CoreV1Api, V1Pod, models as k8s
from kubernetes.stream import stream
from slugify import slugify
from urllib3.exceptions import HTTPError

from airflow.configuration import conf
Expand All @@ -51,7 +49,11 @@
convert_volume_mount,
)
from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook
from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import POD_NAME_MAX_LENGTH
from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import (
POD_NAME_MAX_LENGTH,
add_pod_suffix,
create_pod_id,
)
from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator
from airflow.providers.cncf.kubernetes.triggers.pod import KubernetesPodTrigger
from airflow.providers.cncf.kubernetes.utils import xcom_sidecar # type: ignore[attr-defined]
Expand Down Expand Up @@ -83,61 +85,6 @@
KUBE_CONFIG_ENV_VAR = "KUBECONFIG"


def _rand_str(num):
"""Generate random lowercase alphanumeric string of length num.
TODO: when min airflow version >= 2.5, delete this function and import from kubernetes_helper_functions.
:meta private:
"""
return "".join(secrets.choice(alphanum_lower) for _ in range(num))


def _add_pod_suffix(*, pod_name, rand_len=8, max_len=POD_NAME_MAX_LENGTH):
"""Add random string to pod name while staying under max len.
TODO: when min airflow version >= 2.5, delete this function and import from kubernetes_helper_functions.
:meta private:
"""
suffix = "-" + _rand_str(rand_len)
return pod_name[: max_len - len(suffix)].strip("-.") + suffix


def _create_pod_id(
dag_id: str | None = None,
task_id: str | None = None,
*,
max_length: int = POD_NAME_MAX_LENGTH,
unique: bool = True,
) -> str:
"""
Generate unique pod ID given a dag_id and / or task_id.
TODO: when min airflow version >= 2.5, delete this function and import from kubernetes_helper_functions.
:param dag_id: DAG ID
:param task_id: Task ID
:param max_length: max number of characters
:param unique: whether a random string suffix should be added
:return: A valid identifier for a kubernetes pod name
"""
if not (dag_id or task_id):
raise ValueError("Must supply either dag_id or task_id.")
name = ""
if dag_id:
name += dag_id
if task_id:
if name:
name += "-"
name += task_id
base_name = slugify(name, lowercase=True)[:max_length].strip(".-")
if unique:
return _add_pod_suffix(pod_name=base_name, max_len=max_length)
else:
return base_name


class PodReattachFailure(AirflowException):
"""When we expect to be able to find a pod but cannot."""

Expand Down Expand Up @@ -963,12 +910,12 @@ def build_pod_request_obj(self, context: Context | None = None) -> k8s.V1Pod:
pod = PodGenerator.reconcile_pods(pod_template, pod)

if not pod.metadata.name:
pod.metadata.name = _create_pod_id(
pod.metadata.name = create_pod_id(
task_id=self.task_id, unique=self.random_name_suffix, max_length=POD_NAME_MAX_LENGTH
)
elif self.random_name_suffix:
# user has supplied pod name, we're just adding suffix
pod.metadata.name = _add_pod_suffix(pod_name=pod.metadata.name)
pod.metadata.name = add_pod_suffix(pod_name=pod.metadata.name)

if not pod.metadata.namespace:
hook_namespace = self.hook.get_namespace()
Expand Down
Expand Up @@ -22,17 +22,10 @@
import pytest

from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import create_pod_id
from airflow.providers.cncf.kubernetes.operators.pod import _create_pod_id

pod_name_regex = r"^[a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*$"


# todo: when cncf provider min airflow version >= 2.5 remove this parameterization
# we added this function to provider temporarily until min airflow version catches up
# meanwhile, we use this one test to test both core and provider
@pytest.mark.parametrize(
"create_pod_id", [pytest.param(_create_pod_id, id="provider"), pytest.param(create_pod_id, id="core")]
)
class TestCreatePodId:
@pytest.mark.parametrize(
"val, expected",
Expand All @@ -46,7 +39,7 @@ class TestCreatePodId:
("90AçLbˆˆç˙ßߘ˜˙c*a", "90aclb-c-ssss-c-a"), # weird unicode
],
)
def test_create_pod_id_task_only(self, val, expected, create_pod_id):
def test_create_pod_id_task_only(self, val, expected):
actual = create_pod_id(task_id=val, unique=False)
assert actual == expected
assert re.match(pod_name_regex, actual)
Expand All @@ -63,7 +56,7 @@ def test_create_pod_id_task_only(self, val, expected, create_pod_id):
("90AçLbˆˆç˙ßߘ˜˙c*a", "90aclb-c-ssss-c-a"), # weird unicode
],
)
def test_create_pod_id_dag_only(self, val, expected, create_pod_id):
def test_create_pod_id_dag_only(self, val, expected):
actual = create_pod_id(dag_id=val, unique=False)
assert actual == expected
assert re.match(pod_name_regex, actual)
Expand All @@ -80,26 +73,26 @@ def test_create_pod_id_dag_only(self, val, expected, create_pod_id):
("90AçLbˆˆç˙ßߘ˜˙c*a", "90AçLbˆˆç˙ßߘ˜˙c*a", "90aclb-c-ssss-c-a-90aclb-c-ssss-c-a"), # ugly
],
)
def test_create_pod_id_dag_and_task(self, dag_id, task_id, expected, create_pod_id):
def test_create_pod_id_dag_and_task(self, dag_id, task_id, expected):
actual = create_pod_id(dag_id=dag_id, task_id=task_id, unique=False)
assert actual == expected
assert re.match(pod_name_regex, actual)

def test_create_pod_id_dag_too_long_with_suffix(self, create_pod_id):
def test_create_pod_id_dag_too_long_with_suffix(self):
actual = create_pod_id("0" * 254)
assert len(actual) == 63
assert re.match(r"0{54}-[a-z0-9]{8}", actual)
assert re.match(pod_name_regex, actual)

def test_create_pod_id_dag_too_long_non_unique(self, create_pod_id):
def test_create_pod_id_dag_too_long_non_unique(self):
actual = create_pod_id("0" * 254, unique=False)
assert len(actual) == 63
assert re.match(r"0{63}", actual)
assert re.match(pod_name_regex, actual)

@pytest.mark.parametrize("unique", [True, False])
@pytest.mark.parametrize("length", [25, 100, 200, 300])
def test_create_pod_id(self, create_pod_id, length, unique):
def test_create_pod_id(self, length, unique):
"""Test behavior of max_length and unique."""
dag_id = "dag-dag-dag-dag-dag-dag-dag-dag-dag-dag-dag-dag-dag-dag-dag-dag-"
task_id = "task-task-task-task-task-task-task-task-task-task-task-task-task-task-task-task-task-"
Expand Down

0 comments on commit af9328e

Please sign in to comment.