MLCOMPUTE-949 | Pick spark ui port from a preferred port range (#128)
* Pick spark ui port from a preferred port range

* Load port range from srv configs

* Bump version

* Fix typo and add more comments
edingroot committed Sep 20, 2023
1 parent 653ea7e commit c9aaf8e
Showing 4 changed files with 73 additions and 31 deletions.
19 changes: 9 additions & 10 deletions service_configuration_lib/spark_config.py
@@ -20,13 +20,12 @@
from urllib.parse import urlparse

import boto3
import ephemeral_port_reserve
import requests
import yaml
from boto3 import Session

from service_configuration_lib import utils
from service_configuration_lib.text_colors import TextColors
from service_configuration_lib.utils import load_spark_srv_conf

AWS_CREDENTIALS_DIR = '/etc/boto_cfg/'
AWS_ENV_CREDENTIALS_PROVIDER = 'com.amazonaws.auth.EnvironmentVariableCredentialsProvider'
@@ -78,7 +77,6 @@

SUPPORTED_CLUSTER_MANAGERS = ['kubernetes', 'local']
DEFAULT_SPARK_RUN_CONFIG = '/nail/srv/configs/spark.yaml'
PREFERRED_SPARK_UI_PORT = 39091

log = logging.Logger(__name__)
log.setLevel(logging.INFO)
@@ -184,11 +182,6 @@ def assume_aws_role(
return resp['Credentials']


def _pick_random_port(preferred_port: int = 0) -> int:
"""Return a random port. """
return ephemeral_port_reserve.reserve('0.0.0.0', preferred_port)


def _get_k8s_docker_volumes_conf(
volumes: Optional[List[Mapping[str, str]]] = None,
):
@@ -418,7 +411,7 @@ def __init__(self):
(
self.spark_srv_conf, self.spark_constants, self.default_spark_srv_conf,
self.mandatory_default_spark_srv_conf, self.spark_costs,
) = load_spark_srv_conf()
) = utils.load_spark_srv_conf()
except Exception as e:
log.error(f'Failed to load Spark srv configs: {e}')

@@ -1075,9 +1068,15 @@ def get_spark_conf(
spark_app_base_name
)

# Pick a port from a pre-defined port range, which will then be used by our Jupyter
# server metric aggregator API. The aggregator API collects Prometheus metrics from multiple
# Spark sessions and exposes them through a single endpoint.
ui_port = int(
(spark_opts_from_env or {}).get('spark.ui.port') or
_pick_random_port(PREFERRED_SPARK_UI_PORT),
utils.ephemeral_port_reserve_range(
self.spark_constants.get('preferred_spark_ui_port_start'),
self.spark_constants.get('preferred_spark_ui_port_end'),
),
)

spark_conf = {**(spark_opts_from_env or {}), **_filter_user_spark_opts(user_spark_opts)}
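
For illustration, here is a minimal, self-contained sketch of the new port-selection logic in get_spark_conf. The range values (39091-39100) and the spark_opts_from_env contents below are assumptions for the example; in the library the range comes from the spark_constants mapping returned by utils.load_spark_srv_conf(), which reads /nail/srv/configs/spark.yaml.

from service_configuration_lib import utils

# Assumed srv-config values, shown inline only for this example.
spark_constants = {
    'preferred_spark_ui_port_start': 39091,
    'preferred_spark_ui_port_end': 39100,
}
spark_opts_from_env = {}  # an explicit {'spark.ui.port': '33000'} would take precedence

ui_port = int(
    spark_opts_from_env.get('spark.ui.port') or
    utils.ephemeral_port_reserve_range(
        spark_constants['preferred_spark_ui_port_start'],
        spark_constants['preferred_spark_ui_port_end'],
    ),
)
print(f'spark.ui.port -> {ui_port}')
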
51 changes: 51 additions & 0 deletions service_configuration_lib/utils.py
@@ -1,9 +1,16 @@
import contextlib
import errno
import logging
from socket import error as SocketError
from socket import SO_REUSEADDR
from socket import socket
from socket import SOL_SOCKET
from typing import Mapping
from typing import Tuple

import yaml


DEFAULT_SPARK_RUN_CONFIG = '/nail/srv/configs/spark.yaml'

log = logging.Logger(__name__)
@@ -28,3 +35,47 @@ def load_spark_srv_conf(preset_values=None) -> Tuple[Mapping, Mapping, Mapping,
except Exception as e:
log.warning(f'Failed to load {DEFAULT_SPARK_RUN_CONFIG}: {e}')
raise e


def ephemeral_port_reserve_range(preferred_port_start: int, preferred_port_end: int, ip='127.0.0.1') -> int:
"""
Pick an available from the preferred port range. If all ports from the port range are unavailable,
pick a random available ephemeral port.
Implemetation referenced from upstream:
https://github.com/Yelp/ephemeral-port-reserve/blob/master/ephemeral_port_reserve.py
This function is used to pick a Spark UI (API) port from a pre-defined port range which is used by
our Jupyter server metric aggregator. The aggregator API collects Prometheus metrics from multiple
Spark sessions and exposes them through a single endpoint.
"""
    assert preferred_port_start <= preferred_port_end

    with contextlib.closing(socket()) as s:
        binded = False
        for port in range(preferred_port_start, preferred_port_end + 1):
            s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
            try:
                s.bind((ip, port))
                binded = True
                break
            except SocketError as e:
                # socket.error: EADDRINUSE Address already in use
                if e.errno == errno.EADDRINUSE:
                    continue
                else:
                    raise
        if not binded:
            s.bind((ip, 0))

        # the connect below deadlocks on kernel >= 4.4.0 unless this arg is greater than zero
        s.listen(1)

        sockname = s.getsockname()

        # these three are necessary just to get the port into a TIME_WAIT state
        with contextlib.closing(socket()) as s2:
            s2.connect(sockname)
            sock, _ = s.accept()
            with contextlib.closing(sock):
                return sockname[1]
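
A short usage sketch of the new helper; the range below is an assumption chosen to line up with the previously hard-coded preferred port (39091).

from service_configuration_lib.utils import ephemeral_port_reserve_range

# Try each port in the preferred range in order; if every one is already bound,
# fall back to a random kernel-assigned ephemeral port (bind to port 0).
port = ephemeral_port_reserve_range(39091, 39100)
print(f'Reserved Spark UI port: {port}')

The connect/accept/close sequence at the end mirrors the upstream ephemeral-port-reserve trick: it leaves the reserved port in TIME_WAIT, so the kernel will not hand it out again as an ephemeral port, while the caller can still bind it explicitly because SO_REUSEADDR is set.
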
2 changes: 1 addition & 1 deletion setup.py
@@ -17,7 +17,7 @@

setup(
name='service-configuration-lib',
version='2.18.6',
version='2.18.7',
provides=['service_configuration_lib'],
description='Start, stop, and inspect Yelp SOA services',
url='https://github.com/Yelp/service_configuration_lib',
32 changes: 12 additions & 20 deletions tests/spark_config_test.py
@@ -130,16 +130,6 @@ def test_fail(self, tmpdir):
spark_config.get_aws_credentials(aws_credentials_yaml=str(fp))


def test_pick_random_port():
with mock.patch('ephemeral_port_reserve.reserve') as mock_reserve:
preferred_port = 33123 # Any ephemeral port for testing
port = spark_config._pick_random_port(preferred_port)
(host, prefer_port), _ = mock_reserve.call_args
assert host == '0.0.0.0'
assert prefer_port >= 33000
assert port == mock_reserve.return_value


class MockConfigFunction:

def __init__(self, mock_obj, mock_func, return_value):
@@ -1092,13 +1082,13 @@ def test_convert_user_spark_opts_value_str(self):
}

@pytest.fixture
def mock_pick_random_port(self):
def mock_ephemeral_port_reserve_range(self):
port = '12345'
with mock.patch.object(spark_config, '_pick_random_port', return_value=port):
with mock.patch.object(utils, 'ephemeral_port_reserve_range', return_value=port):
yield port

@pytest.fixture(params=[None, '23456'])
def ui_port(self, request, mock_pick_random_port):
def ui_port(self, request):
return request.param

@pytest.fixture(params=[None, 'test_app_name_from_env'])
@@ -1111,8 +1101,8 @@ def spark_opts_from_env(self, request, ui_port):
return spark_opts or None

@pytest.fixture
def assert_ui_port(self, spark_opts_from_env, ui_port, mock_pick_random_port):
expected_output = ui_port if ui_port else mock_pick_random_port
def assert_ui_port(self, ui_port, mock_ephemeral_port_reserve_range):
expected_output = ui_port or mock_ephemeral_port_reserve_range

def verify(output):
key = 'spark.ui.port'
@@ -1125,13 +1115,13 @@ def user_spark_opts(self, request):
return request.param

@pytest.fixture
def assert_app_name(self, spark_opts_from_env, user_spark_opts, ui_port, mock_pick_random_port):
def assert_app_name(self, spark_opts_from_env, user_spark_opts, ui_port, mock_ephemeral_port_reserve_range):
expected_output = (spark_opts_from_env or {}).get('spark.app.name')
if not expected_output:
expected_output = (
(user_spark_opts or {}).get('spark.app.name') or
self.spark_app_base_name
) + '_' + (ui_port or mock_pick_random_port) + '_123'
) + '_' + (ui_port or mock_ephemeral_port_reserve_range) + '_123'

def verify(output):
key = 'spark.app.name'
Expand Down Expand Up @@ -1189,8 +1179,8 @@ def _get_k8s_base_volumes(self):
]

@pytest.fixture
def assert_kubernetes_conf(self, base_volumes, ui_port, mock_pick_random_port):
expected_ui_port = ui_port if ui_port else mock_pick_random_port
def assert_kubernetes_conf(self, base_volumes, ui_port, mock_ephemeral_port_reserve_range):
expected_ui_port = ui_port if ui_port else mock_ephemeral_port_reserve_range

expected_output = {
'spark.master': f'k8s://https://k8s.{self.cluster}.paasta:6443',
@@ -1238,7 +1228,6 @@ def test_leaders_get_spark_conf_kubernetes(
self,
user_spark_opts,
spark_opts_from_env,
ui_port,
base_volumes,
mock_append_spark_prometheus_conf,
mock_append_event_log_conf,
@@ -1248,6 +1237,7 @@
mock_get_dra_configs,
mock_update_spark_srv_configs,
mock_spark_srv_conf_file,
mock_ephemeral_port_reserve_range,
mock_time,
assert_ui_port,
assert_app_name,
@@ -1341,6 +1331,7 @@ def test_show_console_progress_jupyter(
mock_adjust_spark_requested_resources_kubernetes,
mock_get_dra_configs,
mock_spark_srv_conf_file,
mock_ephemeral_port_reserve_range,
mock_time,
assert_ui_port,
assert_app_name,
@@ -1382,6 +1373,7 @@ def test_local_spark(
mock_get_dra_configs,
mock_update_spark_srv_configs,
mock_spark_srv_conf_file,
mock_ephemeral_port_reserve_range,
mock_time,
assert_ui_port,
assert_app_name,
