Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix potential tlsport selection collision by using state file as
tlsport lock file

This removes all potential concurrency issues during tls port selection
and ensures that concurrent mounts will never select the same port.
  • Loading branch information
RyanStan committed Dec 14, 2022
1 parent d504305 commit f3a8f88
Show file tree
Hide file tree
Showing 8 changed files with 130 additions and 34 deletions.
5 changes: 4 additions & 1 deletion amazon-efs-utils.spec
Expand Up @@ -35,7 +35,7 @@
%endif

Name : amazon-efs-utils
Version : 1.34.3
Version : 1.34.4
Release : 1%{platform}
Summary : This package provides utilities for simplifying the use of EFS file systems

Expand Down Expand Up @@ -137,6 +137,9 @@ fi
%clean

%changelog
* Tue Dec 13 2022 Ryan Stankiewicz <rjstank@amazon.com> - 1.34.4
- Fix potential tlsport selection collision by using state file as tlsport lock file.

* Thu Dec 1 2022 Preetham Puneeth Munipalli <tmunipre@amazon.com> - 1.34.3
- Fix potential tlsport selection race condition by closing socket right before establishing stunnel
- Fix stunnel constantly restarting when upgrading from version 1.32.1 or earlier to the latest version
Expand Down
2 changes: 1 addition & 1 deletion build-deb.sh
Expand Up @@ -11,7 +11,7 @@ set -ex

BASE_DIR=$(pwd)
BUILD_ROOT=${BASE_DIR}/build/debbuild
VERSION=1.34.3
VERSION=1.34.4
RELEASE=1
DEB_SYSTEM_RELEASE_PATH=/etc/os-release

Expand Down
2 changes: 1 addition & 1 deletion config.ini
Expand Up @@ -7,5 +7,5 @@
#

[global]
version=1.34.3
version=1.34.4
release=1
2 changes: 1 addition & 1 deletion dist/amazon-efs-utils.control
@@ -1,6 +1,6 @@
Package: amazon-efs-utils
Architecture: all
Version: 1.34.3
Version: 1.34.4
Section: utils
Depends: python3, nfs-common, stunnel4 (>= 4.56), openssl (>= 1.0.2), util-linux
Priority: optional
Expand Down
83 changes: 67 additions & 16 deletions src/mount_efs/__init__.py
Expand Up @@ -85,7 +85,7 @@
BOTOCORE_PRESENT = False


VERSION = "1.34.3"
VERSION = "1.34.4"
SERVICE = "elasticfilesystem"

AMAZON_LINUX_2_RELEASE_ID = "Amazon Linux release 2 (Karoo)"
Expand Down Expand Up @@ -939,7 +939,7 @@ def get_tls_port_range(config):
return lower_bound, upper_bound


def choose_tls_port_and_get_bind_sock(config, options):
def choose_tls_port_and_get_bind_sock(config, options, state_file_dir):
if "tlsport" in options:
ports_to_try = [int(options["tlsport"])]
else:
Expand All @@ -951,10 +951,14 @@ def choose_tls_port_and_get_bind_sock(config, options):
random.shuffle(ports_to_try)

if "netns" not in options:
tls_port_sock = find_tls_port_in_range_and_get_bind_sock(ports_to_try)
tls_port_sock = find_tls_port_in_range_and_get_bind_sock(
ports_to_try, state_file_dir
)
else:
with NetNS(nspath=options["netns"]):
tls_port_sock = find_tls_port_in_range_and_get_bind_sock(ports_to_try)
tls_port_sock = find_tls_port_in_range_and_get_bind_sock(
ports_to_try, state_file_dir
)

if tls_port_sock:
return tls_port_sock
Expand All @@ -971,20 +975,44 @@ def choose_tls_port_and_get_bind_sock(config, options):
)


def find_tls_port_in_range_and_get_bind_sock(ports_to_try, state_file_dir):
    """Bind a socket to the first usable TLS port in ``ports_to_try``.

    Ports are tried in the order given. A port is skipped without attempting
    to bind when ``find_existing_mount_using_tls_port`` reports that a file in
    ``state_file_dir`` already claims it, and also when ``bind`` fails.

    :param ports_to_try: iterable of candidate port numbers
    :param state_file_dir: directory that is scanned for files claiming a port
    :return: the socket bound to ``("localhost", port)``; the caller owns it
        and is responsible for closing it. ``None`` if no port could be bound.
    """
    sock = socket.socket()
    for tls_port in ports_to_try:
        mount = find_existing_mount_using_tls_port(state_file_dir, tls_port)
        if mount:
            # Another mount has already reserved this port through its state
            # file, so do not even try to bind it.
            logging.debug(
                "Skip binding TLS port %s as it is already assigned to %s",
                tls_port,
                mount,
            )
            continue
        try:
            logging.info("binding %s", tls_port)
            sock.bind(("localhost", tls_port))
            return sock
        except socket.error as e:
            # Port is unavailable; report it and move on to the next candidate.
            logging.warning(e)
            continue
    # Every candidate was either reserved or failed to bind.
    sock.close()
    return None


def find_existing_mount_using_tls_port(state_file_dir, tls_port):
    """Return the name of a file in ``state_file_dir`` that claims ``tls_port``.

    A file claims a port when its name ends with ``".<tls_port>"``.

    :param state_file_dir: directory to scan
    :param tls_port: port number to look for
    :return: the first matching file name in ``os.listdir`` order, or ``None``
        when the directory does not exist or no file name matches
    """
    # List the directory directly instead of checking for existence first, so
    # a directory that disappears between the check and the listing cannot
    # raise an unexpected error.
    try:
        fnames = os.listdir(state_file_dir)
    except FileNotFoundError:
        logging.debug(
            "State file dir %s does not exist, assuming no existing mount using tls port %s",
            state_file_dir,
            tls_port,
        )
        return None

    # The leading dot keeps e.g. port 49 from matching a file for port 20049.
    port_suffix = ".%s" % tls_port
    for fname in fnames:
        if fname.endswith(port_suffix):
            return fname

    return None


def is_ocsp_enabled(config, options):
if "ocsp" in options:
return True
Expand Down Expand Up @@ -1301,6 +1329,24 @@ def write_tls_tunnel_state_file(
return state_file


def rewrite_tls_tunnel_state_file(state, state_file_dir, state_file):
    """Overwrite ``state_file`` in ``state_file_dir`` with ``state`` as JSON.

    :param state: JSON-serializable object to store
    :param state_file_dir: directory containing the state file
    :param state_file: file name (not a full path) of the state file
    :return: ``state_file``, unchanged
    """
    state_file_path = os.path.join(state_file_dir, state_file)
    # Mode "w" truncates any previous content before the new state is written.
    with open(state_file_path, "w") as state_fp:
        json.dump(state, state_fp)
    return state_file


def update_tls_tunnel_temp_state_file_with_tunnel_pid(
    temp_tls_state_file, state_file_dir, stunnel_pid
):
    """Record the tunnel pid in an existing temporary state file.

    Loads the JSON state stored in ``temp_tls_state_file``, sets its ``"pid"``
    entry to ``stunnel_pid`` and writes the state back to the same file.

    :param temp_tls_state_file: file name of the temporary state file
    :param state_file_dir: directory containing the state file
    :param stunnel_pid: pid to store under the ``"pid"`` key
    :return: the state file name as returned by ``rewrite_tls_tunnel_state_file``
    """
    temp_state_path = os.path.join(state_file_dir, temp_tls_state_file)
    with open(temp_state_path, "r") as state_fp:
        tunnel_state = json.load(state_fp)

    tunnel_state["pid"] = stunnel_pid
    return rewrite_tls_tunnel_state_file(
        tunnel_state, state_file_dir, temp_tls_state_file
    )


def test_tunnel_process(tunnel_proc, fs_id):
tunnel_proc.poll()
if tunnel_proc.returncode is not None:
Expand Down Expand Up @@ -1478,7 +1524,7 @@ def bootstrap_tls(
state_file_dir=STATE_FILE_DIR,
fallback_ip_address=None,
):
tls_port_sock = choose_tls_port_and_get_bind_sock(config, options)
tls_port_sock = choose_tls_port_and_get_bind_sock(config, options, state_file_dir)
tls_port = get_tls_port_from_sock(tls_port_sock)

try:
Expand Down Expand Up @@ -1560,6 +1606,18 @@ def bootstrap_tls(
tunnel_args = [_stunnel_bin(), stunnel_config_file]
if "netns" in options:
tunnel_args = ["nsenter", "--net=" + options["netns"]] + tunnel_args

# This temp state file is acting like a tlsport lock file, which is why pid = -1
temp_tls_state_file = write_tls_tunnel_state_file(
fs_id,
mountpoint,
tls_port,
-1,
tunnel_args,
[stunnel_config_file],
state_file_dir,
cert_details=cert_details,
)
finally:
# Always close the socket we created when choosing TLS port only until now to
# 1. avoid concurrent TLS mount port collision 2. enable stunnel process to bind the port
Expand All @@ -1577,15 +1635,8 @@ def bootstrap_tls(
)
logging.info("Started TLS tunnel, pid: %d", tunnel_proc.pid)

temp_tls_state_file = write_tls_tunnel_state_file(
fs_id,
mountpoint,
tls_port,
tunnel_proc.pid,
tunnel_args,
[stunnel_config_file],
state_file_dir,
cert_details=cert_details,
update_tls_tunnel_temp_state_file_with_tunnel_pid(
temp_tls_state_file, state_file_dir, tunnel_proc.pid
)

if "netns" not in options:
Expand Down
2 changes: 1 addition & 1 deletion src/watchdog/__init__.py
Expand Up @@ -56,7 +56,7 @@
AMAZON_LINUX_2_RELEASE_ID,
AMAZON_LINUX_2_PRETTY_NAME,
]
VERSION = "1.34.3"
VERSION = "1.34.4"
SERVICE = "elasticfilesystem"

CONFIG_FILE = "/etc/amazon/efs/efs-utils.conf"
Expand Down
9 changes: 9 additions & 0 deletions test/mount_efs_test/test_bootstrap_tls.py
Expand Up @@ -44,6 +44,10 @@ def setup_mocks(mocker):
mocker.patch("mount_efs.create_certificate")
mocker.patch("os.rename")
mocker.patch("os.kill")
mocker.patch(
"mount_efs.update_tls_tunnel_temp_state_file_with_tunnel_pid",
return_value="~mocktempfile",
)

process_mock = MagicMock()
process_mock.communicate.return_value = (
Expand Down Expand Up @@ -72,6 +76,10 @@ def setup_mocks_without_popen(mocker):
)
mocker.patch("mount_efs.write_tls_tunnel_state_file", return_value="~mocktempfile")
mocker.patch("os.kill")
mocker.patch(
"mount_efs.update_tls_tunnel_temp_state_file_with_tunnel_pid",
return_value="~mocktempfile",
)

write_config_mock = mocker.patch(
"mount_efs.write_stunnel_config_file", return_value=EXPECTED_STUNNEL_CONFIG_FILE
Expand Down Expand Up @@ -115,6 +123,7 @@ def config_get_side_effect(section, field):
assert not os.path.exists(state_file_dir)

mocker.patch("mount_efs._stunnel_bin", return_value="/usr/bin/stunnel")
mocker.patch("mount_efs.find_existing_mount_using_tls_port", return_value=None)
with mount_efs.bootstrap_tls(
MOCK_CONFIG, INIT_SYSTEM, DNS_NAME, FS_ID, MOUNT_POINT, {}, state_file_dir
):
Expand Down
59 changes: 46 additions & 13 deletions test/mount_efs_test/test_choose_tls_port.py
Expand Up @@ -3,9 +3,12 @@
# Licensed under the MIT License. See the LICENSE accompanying this file
# for the specific language governing permissions and limitations under
# the License.

import logging
import random
import socket
import sys
import tempfile
import unittest
from unittest.mock import MagicMock

import pytest
Expand Down Expand Up @@ -45,42 +48,70 @@ def _get_config():
return config


def test_choose_tls_port_first_try(mocker, tmpdir):
    """A port inside the default range is chosen when the socket binds cleanly."""
    sock_mock = MagicMock()
    sock_mock.getsockname.return_value = ("local_host", DEFAULT_TLS_PORT)
    mocker.patch("socket.socket", return_value=sock_mock)
    options = {}

    # An empty temporary directory stands in for the state file directory.
    tls_port_sock = mount_efs.choose_tls_port_and_get_bind_sock(
        _get_config(), options, str(tmpdir)
    )
    tls_port = mount_efs.get_tls_port_from_sock(tls_port_sock)
    assert DEFAULT_TLS_PORT_RANGE_LOW <= tls_port <= DEFAULT_TLS_PORT_RANGE_HIGH


def test_choose_tls_port_second_try(mocker, tmpdir):
    """A failed first bind is followed by a second attempt, which succeeds."""
    bad_sock = MagicMock()
    # First bind raises, second bind succeeds.
    bad_sock.bind.side_effect = [socket.error, None]
    bad_sock.getsockname.return_value = ("local_host", DEFAULT_TLS_PORT)
    options = {}

    mocker.patch("socket.socket", return_value=bad_sock)

    tls_port_sock = mount_efs.choose_tls_port_and_get_bind_sock(
        _get_config(), options, str(tmpdir)
    )
    tls_port = mount_efs.get_tls_port_from_sock(tls_port_sock)

    assert DEFAULT_TLS_PORT_RANGE_LOW <= tls_port <= DEFAULT_TLS_PORT_RANGE_HIGH
    assert 2 == bad_sock.bind.call_count
    assert 1 == bad_sock.getsockname.call_count


def test_choose_tls_port_never_succeeds(mocker, capsys):
@unittest.skipIf(sys.version_info < (3, 6), reason="requires python3.6")
def test_choose_tls_port_collision(mocker, tmpdir, caplog):
    """Ensure we don't choose a port that is pending mount"""
    sock = MagicMock()
    mocker.patch("socket.socket", return_value=sock)
    # Replacing random.shuffle with a mock turns the in-place shuffle into a
    # no-op, so the candidate ports keep their original order.
    # NOTE(review): the bind assertion below presumes that order starts at
    # DEFAULT_TLS_PORT_RANGE_LOW -- confirm against choose_tls_port_and_get_bind_sock.
    mocker.patch(
        "random.shuffle",
        return_value=range(DEFAULT_TLS_PORT_RANGE_LOW, DEFAULT_TLS_PORT_RANGE_HIGH),
    )

    state_file_dir = str(tmpdir)
    port_suffix = ".%s" % str(DEFAULT_TLS_PORT_RANGE_LOW)
    options = {}

    # A file whose name ends in ".<lowest port>" marks that port as taken.
    # The context manager removes it even if the call under test raises.
    with tempfile.NamedTemporaryFile(
        suffix=port_suffix, prefix="~", dir=state_file_dir
    ):
        with caplog.at_level(logging.DEBUG):
            mount_efs.choose_tls_port_and_get_bind_sock(
                _get_config(), options, state_file_dir
            )

    sock.bind.assert_called_once_with(("localhost", DEFAULT_TLS_PORT_RANGE_LOW + 1))
    assert "Skip binding TLS port" in caplog.text


def test_choose_tls_port_never_succeeds(mocker, tmpdir, capsys):
bad_sock = MagicMock()
bad_sock.bind.side_effect = socket.error()
options = {}

mocker.patch("socket.socket", return_value=bad_sock)

with pytest.raises(SystemExit) as ex:
mount_efs.choose_tls_port_and_get_bind_sock(_get_config(), options)
mount_efs.choose_tls_port_and_get_bind_sock(_get_config(), options, str(tmpdir))

assert 0 != ex.value.code

Expand All @@ -93,27 +124,29 @@ def test_choose_tls_port_never_succeeds(mocker, capsys):
)


def test_choose_tls_port_option_specified(mocker, tmpdir):
    """An explicit ``tlsport`` option yields exactly that port."""
    sock_mock = MagicMock()
    sock_mock.getsockname.return_value = ("local_host", DEFAULT_TLS_PORT)
    mocker.patch("socket.socket", return_value=sock_mock)
    options = {"tlsport": DEFAULT_TLS_PORT}

    tls_port_sock = mount_efs.choose_tls_port_and_get_bind_sock(
        _get_config(), options, str(tmpdir)
    )
    tls_port = mount_efs.get_tls_port_from_sock(tls_port_sock)

    assert DEFAULT_TLS_PORT == tls_port


def test_choose_tls_port_option_specified_unavailable(mocker, capsys):
def test_choose_tls_port_option_specified_unavailable(mocker, tmpdir, capsys):
bad_sock = MagicMock()
bad_sock.bind.side_effect = socket.error()
options = {"tlsport": 1000}

mocker.patch("socket.socket", return_value=bad_sock)

with pytest.raises(SystemExit) as ex:
mount_efs.choose_tls_port_and_get_bind_sock(_get_config(), options)
mount_efs.choose_tls_port_and_get_bind_sock(_get_config(), options, str(tmpdir))

assert 0 != ex.value.code

Expand All @@ -123,13 +156,13 @@ def test_choose_tls_port_option_specified_unavailable(mocker, capsys):
assert 1 == bad_sock.bind.call_count


def test_choose_tls_port_under_netns(mocker, tmpdir):
    """With a ``netns`` option, port selection calls ``setns``."""
    # builtins.open is mocked so the namespace path is never really opened.
    mocker.patch("builtins.open")
    setns_mock = mocker.patch("mount_efs.setns", return_value=(None, None))
    mocker.patch("socket.socket", return_value=MagicMock())
    options = {"netns": "/proc/1000/ns/net"}

    mount_efs.choose_tls_port_and_get_bind_sock(_get_config(), options, str(tmpdir))
    utils.assert_called(setns_mock)


Expand Down

0 comments on commit f3a8f88

Please sign in to comment.