diff --git a/service_configuration_lib/spark_config.py b/service_configuration_lib/spark_config.py index 5e293ae..9cca422 100644 --- a/service_configuration_lib/spark_config.py +++ b/service_configuration_lib/spark_config.py @@ -77,6 +77,12 @@ SUPPORTED_CLUSTER_MANAGERS = ['kubernetes', 'local'] DEFAULT_SPARK_RUN_CONFIG = '/nail/srv/configs/spark.yaml' +TICKET_NOT_REQUIRED_USERS = { + 'batch', # non-human spark-run from batch boxes + 'TRON', # tronjobs that run commands like paasta mark-for-deployment + None, # placeholder for being unable to determine user +} +USER_LABEL_UNSPECIFIED = 'UNSPECIFIED' log = logging.Logger(__name__) log.setLevel(logging.WARN) @@ -305,7 +311,7 @@ def _get_k8s_spark_env( service_account_name: Optional[str] = None, include_self_managed_configs: bool = True, k8s_server_address: Optional[str] = None, - user: Optional[str] = None, + user: Optional[str] = USER_LABEL_UNSPECIFIED, jira_ticket: Optional[str] = None, ) -> Dict[str, str]: # RFC 1123: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#dns-label-names @@ -314,8 +320,6 @@ def _get_k8s_spark_env( _paasta_cluster = utils.get_k8s_resource_name_limit_size_with_hash(paasta_cluster) _paasta_service = utils.get_k8s_resource_name_limit_size_with_hash(paasta_service) _paasta_instance = utils.get_k8s_resource_name_limit_size_with_hash(paasta_instance) - if not user: - user = os.environ.get('USER', 'UNSPECIFIED') spark_env = { 'spark.master': f'k8s://https://k8s.{paasta_cluster}.paasta:6443', @@ -1040,6 +1044,7 @@ def get_spark_conf( :param service_account_name: The k8s service account to use for spark k8s authentication. :param force_spark_resource_configs: skip the resource/instances recalculation. This is strongly not recommended. + :param user: the user who is running the spark job. :returns: spark opts in a dict. """ # Mesos deprecation @@ -1051,8 +1056,11 @@ def get_spark_conf( # is str type. user_spark_opts = _convert_user_spark_opts_value_to_str(user_spark_opts) + # Get user from environment variables if it's not set + user = user or os.environ.get('USER', None) + if self.mandatory_default_spark_srv_conf.get('spark.yelp.jira_ticket.enabled') == 'true': - needs_jira_check = os.environ.get('USER', '') not in ['batch', 'TRON', ''] + needs_jira_check = cluster_manager != 'local' and user not in TICKET_NOT_REQUIRED_USERS if needs_jira_check: valid_ticket = self._get_valid_jira_ticket(jira_ticket) if valid_ticket is None: @@ -1060,7 +1068,7 @@ def get_spark_conf( 'Job requires a valid Jira ticket (format PROJ-1234).\n' 'Please pass the parameter as: paasta spark-run --jira-ticket=PROJ-1234 \n' 'For more information: https://yelpwiki.yelpcorp.com/spaces/AML/pages/402885641 \n' - 'If you have questions, please reach out to #spark on Slack.\n' + f'If you have questions, please reach out to #spark on Slack. (user={user})\n' ) raise RuntimeError(error_msg) else: diff --git a/tests/spark_config_test.py b/tests/spark_config_test.py index 2aaab2e..29f9a11 100644 --- a/tests/spark_config_test.py +++ b/tests/spark_config_test.py @@ -623,6 +623,20 @@ def gpu_pool(self, tmpdir, monkeypatch): }, False, ), + ( + 'Local spark cluster', + 'local', + { + 'spark.executor.cores': '2', + 'spark.executor.instances': '600', + }, + { + 'spark.executor.memory': '28g', + 'spark.executor.cores': '4', + 'spark.executor.instances': '600', + }, + False, + ), ], ) def test_adjust_spark_requested_resources( @@ -1887,37 +1901,27 @@ def test_get_spark_conf_with_jira_validation_disabled(self, mock_spark_srv_conf_ assert 'spark.kubernetes.executor.label.spark.yelp.com/jira_ticket' not in result @pytest.mark.parametrize( - 'user_env,should_check', [ - ('regular_user', True), - ('batch', False), - ('TRON', False), - ('', False), + 'cluster_manager,user,should_check', [ + ('kubernetes', 'regular_user', True), + ('kubernetes', 'batch', False), + ('kubernetes', 'TRON', False), + ('kubernetes', None, False), + ('local', 'regular_user', False), + ('local', 'TRON', False), + ('local', None, False), ], ) def test_jira_ticket_check_for_different_users( - self, user_env, should_check, mock_spark_srv_conf_file_with_jira_enabled, mock_log, + self, cluster_manager, user, should_check, mock_spark_srv_conf_file_with_jira_enabled, mock_log, ): """Test that Jira ticket validation is skipped for certain users.""" - with mock.patch.dict(os.environ, {'USER': user_env}): - spark_conf_builder = spark_config.SparkConfBuilder() + spark_conf_builder = spark_config.SparkConfBuilder() - if should_check: - # For regular users, validation should be enforced - with pytest.raises(RuntimeError): - spark_conf_builder.get_spark_conf( - cluster_manager='kubernetes', - spark_app_base_name='test-app', - user_spark_opts={}, - paasta_cluster='test-cluster', - paasta_pool='test-pool', - paasta_service='test-service', - paasta_instance='test-instance', - docker_img='test-image', - ) - else: - # For special users, validation should be skipped + if should_check: + # For regular users, validation should be enforced + with pytest.raises(RuntimeError): spark_conf_builder.get_spark_conf( - cluster_manager='kubernetes', + cluster_manager=cluster_manager, spark_app_base_name='test-app', user_spark_opts={}, paasta_cluster='test-cluster', @@ -1925,5 +1929,19 @@ def test_jira_ticket_check_for_different_users( paasta_service='test-service', paasta_instance='test-instance', docker_img='test-image', + user=user, ) - mock_log.debug.assert_called_with('Jira ticket check not required for this job configuration.') + else: + # For special users, validation should be skipped + spark_conf_builder.get_spark_conf( + cluster_manager=cluster_manager, + spark_app_base_name='test-app', + user_spark_opts={}, + paasta_cluster='test-cluster', + paasta_pool='test-pool', + paasta_service='test-service', + paasta_instance='test-instance', + docker_img='test-image', + user=user, + ) + mock_log.debug.assert_called_with('Jira ticket check not required for this job configuration.')