Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions service_configuration_lib/spark_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@

SUPPORTED_CLUSTER_MANAGERS = ['kubernetes', 'local']
DEFAULT_SPARK_RUN_CONFIG = '/nail/srv/configs/spark.yaml'
# Users exempt from the Jira-ticket validation performed in get_spark_conf().
# None is a member because the caller may be unable to determine the user at all.
TICKET_NOT_REQUIRED_USERS = {
    'batch', # non-human spark-run from batch boxes
    'TRON', # tronjobs that run commands like paasta mark-for-deployment
    None, # placeholder for being unable to determine user
}
# Sentinel label used for k8s resource labels when no user name is available
# (see the `user` default of _get_k8s_spark_env).
USER_LABEL_UNSPECIFIED = 'UNSPECIFIED'

# Module-level logger. Use logging.getLogger() rather than instantiating
# logging.Logger directly: getLogger() registers the logger with the logging
# manager so it participates in the standard logger hierarchy and is the same
# object returned by any other getLogger(__name__) call; a directly
# constructed Logger is an orphan the logging configuration cannot reach.
log = logging.getLogger(__name__)
log.setLevel(logging.WARN)
Expand Down Expand Up @@ -305,7 +311,7 @@ def _get_k8s_spark_env(
service_account_name: Optional[str] = None,
include_self_managed_configs: bool = True,
k8s_server_address: Optional[str] = None,
user: Optional[str] = None,
user: Optional[str] = USER_LABEL_UNSPECIFIED,
jira_ticket: Optional[str] = None,
) -> Dict[str, str]:
# RFC 1123: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#dns-label-names
Expand All @@ -314,8 +320,6 @@ def _get_k8s_spark_env(
_paasta_cluster = utils.get_k8s_resource_name_limit_size_with_hash(paasta_cluster)
_paasta_service = utils.get_k8s_resource_name_limit_size_with_hash(paasta_service)
_paasta_instance = utils.get_k8s_resource_name_limit_size_with_hash(paasta_instance)
if not user:
user = os.environ.get('USER', 'UNSPECIFIED')

spark_env = {
'spark.master': f'k8s://https://k8s.{paasta_cluster}.paasta:6443',
Expand Down Expand Up @@ -1040,6 +1044,7 @@ def get_spark_conf(
:param service_account_name: The k8s service account to use for spark k8s authentication.
:param force_spark_resource_configs: skip the resource/instances recalculation.
This is strongly not recommended.
:param user: the user who is running the spark job.
:returns: spark opts in a dict.
"""
# Mesos deprecation
Expand All @@ -1051,16 +1056,19 @@ def get_spark_conf(
# is str type.
user_spark_opts = _convert_user_spark_opts_value_to_str(user_spark_opts)

# Get user from environment variables if it's not set
user = user or os.environ.get('USER', None)

if self.mandatory_default_spark_srv_conf.get('spark.yelp.jira_ticket.enabled') == 'true':
needs_jira_check = os.environ.get('USER', '') not in ['batch', 'TRON', '']
needs_jira_check = cluster_manager != 'local' and user not in TICKET_NOT_REQUIRED_USERS
if needs_jira_check:
valid_ticket = self._get_valid_jira_ticket(jira_ticket)
if valid_ticket is None:
error_msg = (
'Job requires a valid Jira ticket (format PROJ-1234).\n'
'Please pass the parameter as: paasta spark-run --jira-ticket=PROJ-1234 \n'
'For more information: https://yelpwiki.yelpcorp.com/spaces/AML/pages/402885641 \n'
'If you have questions, please reach out to #spark on Slack.\n'
f'If you have questions, please reach out to #spark on Slack. (user={user})\n'
)
raise RuntimeError(error_msg)
else:
Expand Down
68 changes: 43 additions & 25 deletions tests/spark_config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,20 @@ def gpu_pool(self, tmpdir, monkeypatch):
},
False,
),
(
'Local spark cluster',
'local',
{
'spark.executor.cores': '2',
'spark.executor.instances': '600',
},
{
'spark.executor.memory': '28g',
'spark.executor.cores': '4',
'spark.executor.instances': '600',
},
False,
),
],
)
def test_adjust_spark_requested_resources(
Expand Down Expand Up @@ -1887,43 +1901,47 @@ def test_get_spark_conf_with_jira_validation_disabled(self, mock_spark_srv_conf_
assert 'spark.kubernetes.executor.label.spark.yelp.com/jira_ticket' not in result

@pytest.mark.parametrize(
'user_env,should_check', [
('regular_user', True),
('batch', False),
('TRON', False),
('', False),
'cluster_manager,user,should_check', [
('kubernetes', 'regular_user', True),
('kubernetes', 'batch', False),
('kubernetes', 'TRON', False),
('kubernetes', None, False),
('local', 'regular_user', False),
('local', 'TRON', False),
('local', None, False),
],
)
def test_jira_ticket_check_for_different_users(
self, user_env, should_check, mock_spark_srv_conf_file_with_jira_enabled, mock_log,
self, cluster_manager, user, should_check, mock_spark_srv_conf_file_with_jira_enabled, mock_log,
):
"""Test that Jira ticket validation is skipped for certain users."""
with mock.patch.dict(os.environ, {'USER': user_env}):
spark_conf_builder = spark_config.SparkConfBuilder()
spark_conf_builder = spark_config.SparkConfBuilder()

if should_check:
# For regular users, validation should be enforced
with pytest.raises(RuntimeError):
spark_conf_builder.get_spark_conf(
cluster_manager='kubernetes',
spark_app_base_name='test-app',
user_spark_opts={},
paasta_cluster='test-cluster',
paasta_pool='test-pool',
paasta_service='test-service',
paasta_instance='test-instance',
docker_img='test-image',
)
else:
# For special users, validation should be skipped
if should_check:
# For regular users, validation should be enforced
with pytest.raises(RuntimeError):
spark_conf_builder.get_spark_conf(
cluster_manager='kubernetes',
cluster_manager=cluster_manager,
spark_app_base_name='test-app',
user_spark_opts={},
paasta_cluster='test-cluster',
paasta_pool='test-pool',
paasta_service='test-service',
paasta_instance='test-instance',
docker_img='test-image',
user=user,
)
mock_log.debug.assert_called_with('Jira ticket check not required for this job configuration.')
else:
# For special users, validation should be skipped
spark_conf_builder.get_spark_conf(
cluster_manager=cluster_manager,
spark_app_base_name='test-app',
user_spark_opts={},
paasta_cluster='test-cluster',
paasta_pool='test-pool',
paasta_service='test-service',
paasta_instance='test-instance',
docker_img='test-image',
user=user,
)
mock_log.debug.assert_called_with('Jira ticket check not required for this job configuration.')