Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AIRFLOW-4739] Add ability to specify labels per task with kubernetes executor config #5376

Merged
merged 1 commit into from Jun 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 7 additions & 1 deletion airflow/contrib/example_dags/example_kubernetes_executor.py
Expand Up @@ -94,4 +94,10 @@ def use_zip_binary():
"affinity": affinity}}
)

start_task.set_downstream([one_task, two_task, three_task])
# Add arbitrary labels to worker pods
four_task = PythonOperator(
task_id="four_task", python_callable=print_stuff, dag=dag,
executor_config={"KubernetesExecutor": {"labels": {"foo": "bar"}}}
)

start_task.set_downstream([one_task, two_task, three_task, four_task])
9 changes: 6 additions & 3 deletions airflow/executors/kubernetes_executor.py
Expand Up @@ -45,7 +45,7 @@ class KubernetesExecutorConfig:
def __init__(self, image=None, image_pull_policy=None, request_memory=None,
request_cpu=None, limit_memory=None, limit_cpu=None,
gcp_service_account_key=None, node_selectors=None, affinity=None,
annotations=None, volumes=None, volume_mounts=None, tolerations=None):
annotations=None, volumes=None, volume_mounts=None, tolerations=None, labels={}):
self.image = image
self.image_pull_policy = image_pull_policy
self.request_memory = request_memory
Expand All @@ -59,17 +59,18 @@ def __init__(self, image=None, image_pull_policy=None, request_memory=None,
self.volumes = volumes
self.volume_mounts = volume_mounts
self.tolerations = tolerations
self.labels = labels

def __repr__(self):
return "{}(image={}, image_pull_policy={}, request_memory={}, request_cpu={}, " \
"limit_memory={}, limit_cpu={}, gcp_service_account_key={}, " \
"node_selectors={}, affinity={}, annotations={}, volumes={}, " \
"volume_mounts={}, tolerations={})" \
"volume_mounts={}, tolerations={}, labels={})" \
.format(KubernetesExecutorConfig.__name__, self.image, self.image_pull_policy,
self.request_memory, self.request_cpu, self.limit_memory,
self.limit_cpu, self.gcp_service_account_key, self.node_selectors,
self.affinity, self.annotations, self.volumes, self.volume_mounts,
self.tolerations)
self.tolerations, self.labels)

@staticmethod
def from_dict(obj):
Expand All @@ -96,6 +97,7 @@ def from_dict(obj):
volumes=namespaced.get('volumes', []),
volume_mounts=namespaced.get('volume_mounts', []),
tolerations=namespaced.get('tolerations', None),
labels=namespaced.get('labels', {}),
)

def as_dict(self):
Expand All @@ -113,6 +115,7 @@ def as_dict(self):
'volumes': self.volumes,
'volume_mounts': self.volume_mounts,
'tolerations': self.tolerations,
'labels': self.labels,
}


Expand Down
5 changes: 3 additions & 2 deletions airflow/kubernetes/worker_configuration.py
Expand Up @@ -201,8 +201,9 @@ def _get_security_context(self):

return security_context

def _get_labels(self, labels):
def _get_labels(self, kube_executor_labels, labels):
copy = self.kube_config.kube_labels.copy()
copy.update(kube_executor_labels)
copy.update(labels)
return copy

Expand Down Expand Up @@ -337,7 +338,7 @@ def make_pod(self, namespace, worker_uuid, pod_id, dag_id, task_id, execution_da
image_pull_policy=(kube_executor_config.image_pull_policy or
self.kube_config.kube_image_pull_policy),
cmds=airflow_command,
labels=self._get_labels({
labels=self._get_labels(kube_executor_config.labels, {
'airflow-worker': worker_uuid,
'dag_id': dag_id,
'task_id': task_id,
Expand Down
8 changes: 6 additions & 2 deletions tests/executors/test_kubernetes_executor.py
Expand Up @@ -633,10 +633,14 @@ def test_get_configmaps(self):

def test_get_labels(self):
worker_config = WorkerConfiguration(self.kube_config)
labels = worker_config._get_labels({
labels = worker_config._get_labels({'my_kube_executor_label': 'kubernetes'}, {
'dag_id': 'override_dag_id',
})
self.assertEqual({'my_label': 'label_id', 'dag_id': 'override_dag_id'}, labels)
self.assertEqual({
'my_label': 'label_id',
'dag_id': 'override_dag_id',
'my_kube_executor_label': 'kubernetes'
}, labels)


class TestKubernetesExecutor(unittest.TestCase):
Expand Down