Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion sagemaker-core/src/sagemaker/core/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,10 +500,14 @@ def create_tar_file(source_files, target=None):
Returns:
(str): path to created tar file
"""
for sf in source_files:
_validate_tar_source_path(sf)

if target:
filename = target
else:
_, filename = tempfile.mkstemp()
file_descriptor, filename = tempfile.mkstemp()
os.close(file_descriptor)

with tarfile.open(filename, mode="w:gz", dereference=True) as t:
for sf in source_files:
Expand All @@ -512,6 +516,45 @@ def create_tar_file(source_files, target=None):
return filename


def _validate_tar_source_path(source_path):
"""Validate that tar source symlinks do not escape their source root."""
source_root = os.path.realpath(source_path)

if _is_link_or_junction(source_path):
_validate_link_target_within_root(source_path, os.path.dirname(source_path))
return

if not os.path.isdir(source_path):
return

for root, dirs, files in os.walk(source_path, followlinks=False):
for name in dirs + files:
path = os.path.join(root, name)
if _is_link_or_junction(path):
_validate_link_target_within_root(path, source_root)


def _is_link_or_junction(path):
is_junction = getattr(os.path, "isjunction", lambda _: False)
return os.path.islink(path) or is_junction(path)


def _validate_link_target_within_root(link_path, source_root):
link_real = os.path.realpath(link_path)
root_real = os.path.realpath(source_root)
try:
common_path = os.path.commonpath([link_real, root_real])
except ValueError as e:
raise ValueError(
f"Source file link '{link_path}' resolves outside the source directory"
) from e

if common_path != root_real:
raise ValueError(
f"Source file link '{link_path}' resolves outside the source directory"
)


@contextlib.contextmanager
def _tmpdir(suffix="", prefix="tmp", directory=None):
"""Create a temporary directory with a context manager.
Expand Down
35 changes: 35 additions & 0 deletions sagemaker-core/tests/unit/test_common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,41 @@ def test_create_tar_file_with_target(self, tmp_path):
assert os.path.exists(tar_path)
os.remove(tar_path)

def test_create_tar_file_rejects_symlink_outside_source_directory(self, tmp_path):
"""Test that tar creation does not dereference links outside the source."""
source_dir = tmp_path / "source"
source_dir.mkdir()
outside_file = tmp_path / "secret.txt"
outside_file.write_text("secret")
link_path = source_dir / "leak.txt"

try:
link_path.symlink_to(outside_file)
except (OSError, NotImplementedError):
pytest.skip("symlink creation is not available")

with pytest.raises(ValueError, match="resolves outside the source directory"):
create_tar_file([str(source_dir)])

def test_create_tar_file_allows_symlink_inside_source_directory(self, tmp_path):
"""Test that safe internal links are not rejected."""
source_dir = tmp_path / "source"
source_dir.mkdir()
target_file = source_dir / "data.txt"
target_file.write_text("data")
link_path = source_dir / "data-link.txt"

try:
link_path.symlink_to(target_file)
except (OSError, NotImplementedError):
pytest.skip("symlink creation is not available")

tar_path = create_tar_file([str(source_dir)])
try:
assert os.path.exists(tar_path)
finally:
os.remove(tar_path)


class TestTmpdir:
"""Test _tmpdir context manager."""
Expand Down
Loading