diff --git a/airflow/contrib/example_dags/example_kubernetes_executor.py b/airflow/contrib/example_dags/example_kubernetes_executor.py index d288211564b74..661075a8f2952 100644 --- a/airflow/contrib/example_dags/example_kubernetes_executor.py +++ b/airflow/contrib/example_dags/example_kubernetes_executor.py @@ -94,4 +94,10 @@ def use_zip_binary(): "affinity": affinity}} ) -start_task.set_downstream([one_task, two_task, three_task]) +# Add arbitrary labels to worker pods +four_task = PythonOperator( + task_id="four_task", python_callable=print_stuff, dag=dag, + executor_config={"KubernetesExecutor": {"labels": {"foo": "bar"}}} +) + +start_task.set_downstream([one_task, two_task, three_task, four_task]) diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py index 5da42a95d8bff..f86137f2f59e2 100644 --- a/airflow/executors/kubernetes_executor.py +++ b/airflow/executors/kubernetes_executor.py @@ -45,7 +45,7 @@ class KubernetesExecutorConfig: def __init__(self, image=None, image_pull_policy=None, request_memory=None, request_cpu=None, limit_memory=None, limit_cpu=None, gcp_service_account_key=None, node_selectors=None, affinity=None, - annotations=None, volumes=None, volume_mounts=None, tolerations=None): + annotations=None, volumes=None, volume_mounts=None, tolerations=None, labels={}): self.image = image self.image_pull_policy = image_pull_policy self.request_memory = request_memory @@ -59,17 +59,18 @@ def __init__(self, image=None, image_pull_policy=None, request_memory=None, self.volumes = volumes self.volume_mounts = volume_mounts self.tolerations = tolerations + self.labels = labels def __repr__(self): return "{}(image={}, image_pull_policy={}, request_memory={}, request_cpu={}, " \ "limit_memory={}, limit_cpu={}, gcp_service_account_key={}, " \ "node_selectors={}, affinity={}, annotations={}, volumes={}, " \ - "volume_mounts={}, tolerations={})" \ + "volume_mounts={}, tolerations={}, labels={})" \ .format(KubernetesExecutorConfig.__name__, self.image, self.image_pull_policy, self.request_memory, self.request_cpu, self.limit_memory, self.limit_cpu, self.gcp_service_account_key, self.node_selectors, self.affinity, self.annotations, self.volumes, self.volume_mounts, - self.tolerations) + self.tolerations, self.labels) @staticmethod def from_dict(obj): @@ -96,6 +97,7 @@ def from_dict(obj): volumes=namespaced.get('volumes', []), volume_mounts=namespaced.get('volume_mounts', []), tolerations=namespaced.get('tolerations', None), + labels=namespaced.get('labels', {}), ) def as_dict(self): @@ -113,6 +115,7 @@ def as_dict(self): 'volumes': self.volumes, 'volume_mounts': self.volume_mounts, 'tolerations': self.tolerations, + 'labels': self.labels, } diff --git a/airflow/kubernetes/worker_configuration.py b/airflow/kubernetes/worker_configuration.py index 6b5b217d24350..29ea54cb4827c 100644 --- a/airflow/kubernetes/worker_configuration.py +++ b/airflow/kubernetes/worker_configuration.py @@ -201,8 +201,9 @@ def _get_security_context(self): return security_context - def _get_labels(self, labels): + def _get_labels(self, kube_executor_labels, labels): copy = self.kube_config.kube_labels.copy() + copy.update(kube_executor_labels) copy.update(labels) return copy @@ -337,7 +338,7 @@ def make_pod(self, namespace, worker_uuid, pod_id, dag_id, task_id, execution_da image_pull_policy=(kube_executor_config.image_pull_policy or self.kube_config.kube_image_pull_policy), cmds=airflow_command, - labels=self._get_labels({ + labels=self._get_labels(kube_executor_config.labels, { 'airflow-worker': worker_uuid, 'dag_id': dag_id, 'task_id': task_id, diff --git a/tests/executors/test_kubernetes_executor.py b/tests/executors/test_kubernetes_executor.py index 9062419a38505..20d94d143964f 100644 --- a/tests/executors/test_kubernetes_executor.py +++ b/tests/executors/test_kubernetes_executor.py @@ -633,10 +633,14 @@ def test_get_configmaps(self): def test_get_labels(self): worker_config = WorkerConfiguration(self.kube_config) - labels = worker_config._get_labels({ + labels = worker_config._get_labels({'my_kube_executor_label': 'kubernetes'}, { 'dag_id': 'override_dag_id', }) - self.assertEqual({'my_label': 'label_id', 'dag_id': 'override_dag_id'}, labels) + self.assertEqual({ + 'my_label': 'label_id', + 'dag_id': 'override_dag_id', + 'my_kube_executor_label': 'kubernetes' + }, labels) class TestKubernetesExecutor(unittest.TestCase):