diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/custom_object_launcher.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/custom_object_launcher.py index bbe507716f959..39c9269f1544a 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/custom_object_launcher.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/custom_object_launcher.py @@ -248,7 +248,8 @@ def get_body(self): self.body.spec["imagePullSecrets"] = k8s_spec.image_pull_secrets for item in ["driver", "executor"]: # Env List - self.body.spec[item]["env"] = k8s_spec.env_vars + existing_env = self.body.spec[item].get("env") or [] + self.body.spec[item]["env"] = existing_env + k8s_spec.env_vars self.body.spec[item]["envFrom"] = k8s_spec.env_from # Volumes self.body.spec[item]["volumeMounts"] = k8s_spec.volume_mounts diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py index dfaac1ecd042d..0d77361199d15 100644 --- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py +++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py @@ -396,6 +396,12 @@ def _setup_spark_configuration(self, context: Context): spec_dict[component]["labels"].update(task_context_labels) + spec_dict = template_body.setdefault("spark", {}).setdefault("spec", {}) + for component in ["driver", "executor"]: + env_list = spec_dict.setdefault(component, {}).setdefault("env", []) + if not any(e.get("name") == "SPARK_APPLICATION_NAME" for e in env_list): + env_list.append({"name": "SPARK_APPLICATION_NAME", "value": self.name}) + self.log.info("Creating sparkApplication.") self.launcher = CustomObjectLauncher( name=self.name, diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py index 56f3132d7d0be..b5f278ad49d31 100644 --- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py +++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py @@ -166,11 +166,12 @@ def _get_expected_application_dict_with_labels(task_name="default_yaml"): } -def _get_expected_application_dict_without_task_context_labels(task_name="default_yaml"): +def _get_expected_application_dict_without_task_context_labels(task_name="default_yaml", app_name=None): """Create expected application dict without task context labels (only original file labels).""" original_file_labels = { "version": "2.4.5", } + app_name = app_name or task_name return { "apiVersion": "sparkoperator.k8s.io/v1beta2", @@ -193,6 +194,7 @@ def _get_expected_application_dict_without_task_context_labels(task_name="defaul "labels": original_file_labels.copy(), "serviceAccount": "spark", "volumeMounts": [{"name": "test-volume", "mountPath": "/tmp"}], + "env": [{"name": "SPARK_APPLICATION_NAME", "value": app_name}], }, "executor": { "cores": 1, @@ -200,6 +202,7 @@ def _get_expected_application_dict_without_task_context_labels(task_name="defaul "memory": "512m", "labels": original_file_labels.copy(), "volumeMounts": [{"name": "test-volume", "mountPath": "/tmp"}], + "env": [{"name": "SPARK_APPLICATION_NAME", "value": app_name}], }, }, } @@ -378,7 +381,9 @@ def test_create_application( assert isinstance(done_op.name, str) assert done_op.name != "" - expected_dict = _get_expected_application_dict_without_task_context_labels(task_name) + expected_dict = _get_expected_application_dict_without_task_context_labels( + task_name, app_name=done_op.name + ) expected_dict["metadata"]["name"] = done_op.name mock_create_namespaced_crd.assert_called_with( body=expected_dict, @@ -424,7 +429,9 @@ def test_create_application_and_use_name_from_operator_args( else: assert done_op.name == name_normalized - expected_dict = _get_expected_application_dict_without_task_context_labels(task_name) + expected_dict = _get_expected_application_dict_without_task_context_labels( + task_name, app_name=done_op.name + ) expected_dict["metadata"]["name"] = done_op.name mock_create_namespaced_crd.assert_called_with( body=expected_dict, @@ -467,7 +474,9 @@ def test_create_application_and_use_name_task_id( else: assert done_op.name == name_normalized - expected_dict = _get_expected_application_dict_without_task_context_labels(task_name) + expected_dict = _get_expected_application_dict_without_task_context_labels( + task_name, app_name=done_op.name + ) expected_dict["metadata"]["name"] = done_op.name mock_create_namespaced_crd.assert_called_with( body=expected_dict, @@ -504,6 +513,8 @@ def test_new_template_from_yaml( expected_dict = _get_expected_k8s_dict() expected_dict["metadata"]["name"] = done_op.name + expected_dict["spec"]["driver"]["env"] = [{"name": "SPARK_APPLICATION_NAME", "value": done_op.name}] + expected_dict["spec"]["executor"]["env"] = [{"name": "SPARK_APPLICATION_NAME", "value": done_op.name}] mock_create_namespaced_crd.assert_called_with( body=expected_dict, **self.call_commons, @@ -540,6 +551,8 @@ def test_template_spec( expected_dict = _get_expected_k8s_dict() expected_dict["metadata"]["name"] = done_op.name + expected_dict["spec"]["driver"]["env"] = [{"name": "SPARK_APPLICATION_NAME", "value": done_op.name}] + expected_dict["spec"]["executor"]["env"] = [{"name": "SPARK_APPLICATION_NAME", "value": done_op.name}] mock_create_namespaced_crd.assert_called_with( body=expected_dict, **self.call_commons, @@ -625,9 +638,11 @@ def test_env( task_name, mock_create_job_name, job_spec=job_spec, mock_get_kube_client=mock_get_kube_client ) assert op.launcher.body["spec"]["driver"]["env"] == [ + {"name": "SPARK_APPLICATION_NAME", "value": "default_env"}, k8s.V1EnvVar(name="TEST_ENV_1", value="VALUE1"), ] assert op.launcher.body["spec"]["executor"]["env"] == [ + {"name": "SPARK_APPLICATION_NAME", "value": "default_env"}, k8s.V1EnvVar(name="TEST_ENV_1", value="VALUE1"), ] @@ -1520,3 +1535,58 @@ def test_reattach_skips_launcher_creation_in_execute( # And verify delete works op.on_kill() mock_launcher_cls.return_value.delete_spark_job.assert_called() + + @patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.SparkKubernetesOperator.client") + def test_spark_application_name_env_injected(self, mock_client): + op = SparkKubernetesOperator( + task_id="test_task", + namespace="default", + template_spec={ + "apiVersion": "sparkoperator.k8s.io/v1beta2", + "kind": "SparkApplication", + "spec": { + "driver": {}, + "executor": {}, + }, + }, + reattach_on_restart=False, + ) + op.name = "my-spark-app-abc123" + + with mock.patch.object(op, "get_or_create_spark_crd", return_value=mock.MagicMock()): + op._setup_spark_configuration(mock.MagicMock()) + + body = op.launcher.body + for component in ["driver", "executor"]: + env = body["spec"][component].get("env", []) + names = [e["name"] for e in env] + assert "SPARK_APPLICATION_NAME" in names + value = next(e["value"] for e in env if e["name"] == "SPARK_APPLICATION_NAME") + assert value == "my-spark-app-abc123" + + @patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.SparkKubernetesOperator.client") + def test_spark_application_name_env_not_duplicated(self, mock_client): + op = SparkKubernetesOperator( + task_id="test_task", + namespace="default", + template_spec={ + "apiVersion": "sparkoperator.k8s.io/v1beta2", + "kind": "SparkApplication", + "spec": { + "driver": {"env": [{"name": "SPARK_APPLICATION_NAME", "value": "user-defined"}]}, + "executor": {"env": [{"name": "SPARK_APPLICATION_NAME", "value": "user-defined"}]}, + }, + }, + reattach_on_restart=False, + ) + op.name = "my-spark-app-abc123" + + with mock.patch.object(op, "get_or_create_spark_crd", return_value=mock.MagicMock()): + op._setup_spark_configuration(mock.MagicMock()) + + body = op.launcher.body + for component in ["driver", "executor"]: + env = body["spec"][component].get("env", []) + app_name_envs = [e for e in env if e["name"] == "SPARK_APPLICATION_NAME"] + assert len(app_name_envs) == 1 # not duplicated + assert app_name_envs[0]["value"] == "user-defined"