Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ jobs:
if: (contains( matrix.os, 'ubuntu' ) || contains( matrix.os, 'macos-12')) && ( matrix.subset == 'dragon' )
run: |
smart build --device cpu --onnx --dragon -v
echo "LD_LIBRARY_PATH=$(python -c 'import site; print(site.getsitepackages()[0])')/smartsim/_core/.dragon/dragon-0.9/lib:$LD_LIBRARY_PATH" >> $GITHUB_ENV
SP=$(python -c 'import site; print(site.getsitepackages()[0])')/smartsim/_core/config/dragon/.env
LLP=$(cat $SP | grep LD_LIBRARY_PATH | awk '{split($0, array, "="); print array[2]}')
echo "LD_LIBRARY_PATH=$LLP:$LD_LIBRARY_PATH" >> $GITHUB_ENV

- name: Install ML Runtimes with Smart (no ONNX,TF on Apple Silicon)
if: contains( matrix.os, 'macos-14' )
Expand Down
1 change: 1 addition & 0 deletions doc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Jump to:

Description

- Fix dragon package installation bug
- Adjust schemas for better performance
- Add TorchWorker first implementation and mock inference app example
- Add error handling in Worker Manager pipeline
Expand Down
40 changes: 21 additions & 19 deletions smartsim/_core/_cli/scripts/dragon_install.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,38 +155,40 @@ def retrieve_asset(working_dir: pathlib.Path, asset: GitReleaseAsset) -> pathlib

:param working_dir: location in file system where assets should be written
:param asset: GitHub release asset to retrieve
:returns: path to the downloaded asset"""
if working_dir.exists() and list(working_dir.rglob("*.whl")):
return working_dir
:returns: path to the directory containing the extracted release asset"""
download_dir = working_dir / str(asset.id)

# if we've previously downloaded the release and still have
# wheels lying around, use that cached version instead
if download_dir.exists() and list(download_dir.rglob("*.whl")):
return download_dir

archive = WebTGZ(asset.browser_download_url)
archive.extract(working_dir)
archive.extract(download_dir)

logger.debug(f"Retrieved {asset.browser_download_url} to {working_dir}")
return working_dir
logger.debug(f"Retrieved {asset.browser_download_url} to {download_dir}")
return download_dir


def install_package(asset_dir: pathlib.Path) -> int:
"""Install the package found in `asset_dir` into the current python environment

:param asset_dir: path to a decompressed archive contents for a release asset"""
wheels = asset_dir.rglob("*.whl")
wheel_path = next(wheels, None)
if not wheel_path:
logger.error(f"No wheel found for package in {asset_dir}")
found_wheels = list(asset_dir.rglob("*.whl"))
if not found_wheels:
logger.error(f"No wheel(s) found for package in {asset_dir}")
return 1

create_dotenv(wheel_path.parent)
create_dotenv(found_wheels[0].parent)

while wheel_path is not None:
logger.info(f"Installing package: {wheel_path.absolute()}")
try:
wheels = list(map(str, found_wheels))
logger.info("Installing packages:\n%s", "\n".join(wheels))

try:
pip("install", "--force-reinstall", str(wheel_path), "numpy<2")
wheel_path = next(wheels, None)
except Exception:
logger.error(f"Unable to install from {asset_dir}")
return 1
pip("install", *wheels)
except Exception:
logger.error(f"Unable to install from {asset_dir}")
return 1

return 0

Expand Down
100 changes: 92 additions & 8 deletions tests/test_dragon_installer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
retrieve_asset,
retrieve_asset_info,
)
from smartsim._core._install.builder import WebTGZ
from smartsim.error.errors import SmartSimCLIActionCancelled

# The tests in this file belong to the group_a group
Expand All @@ -58,14 +59,25 @@
def test_archive(test_dir: str, archive_path: pathlib.Path) -> pathlib.Path:
"""Fixture for returning a simple tarfile to test on"""
num_files = 10

archive_name = archive_path.name
archive_name = archive_name.replace(".tar.gz", "")

with tarfile.TarFile.open(archive_path, mode="w:gz") as tar:
mock_whl = pathlib.Path(test_dir) / "mock.whl"
mock_whl = pathlib.Path(test_dir) / archive_name / f"{archive_name}.whl"
mock_whl.parent.mkdir(parents=True, exist_ok=True)
mock_whl.touch()

tar.add(mock_whl)

for i in range(num_files):
content = pathlib.Path(test_dir) / f"{i:04}.txt"
content = pathlib.Path(test_dir) / archive_name / f"{i:04}.txt"
content.write_text(f"i am file {i}\n")
tar.add(content)
content.unlink()

mock_whl.unlink()

return archive_path


Expand Down Expand Up @@ -118,6 +130,7 @@ def test_assets(monkeypatch: pytest.MonkeyPatch) -> t.Dict[str, GitReleaseAsset]
_git_attr(value=f"http://foo/{archive_name}"),
)
monkeypatch.setattr(asset, "_name", _git_attr(value=archive_name))
monkeypatch.setattr(asset, "_id", _git_attr(value=123))
assets.append(asset)

return assets
Expand Down Expand Up @@ -149,11 +162,22 @@ def test_retrieve_cached(
test_archive: pathlib.Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Verify that a previously retrieved asset archive is re-used"""
with tarfile.TarFile.open(test_archive) as tar:
tar.extractall(test_dir)
"""Verify that a previously retrieved asset archive is re-used and the
release asset retrieval is not attempted"""

ts1 = test_archive.parent.stat().st_ctime
asset_id = 123

def mock_webtgz_extract(self_, target_) -> None:
mock_extraction_dir = pathlib.Path(target_)
with tarfile.TarFile.open(test_archive) as tar:
tar.extractall(mock_extraction_dir)

# we'll use the mock extract to create the files that would normally be downloaded
expected_output_dir = test_archive.parent / str(asset_id)
mock_webtgz_extract(None, expected_output_dir)

# get the ctime (st_ctime) of the directory holding the "downloaded" archive
ts1 = expected_output_dir.stat().st_ctime

requester = Requester(
auth=None,
Expand All @@ -174,16 +198,76 @@ def test_retrieve_cached(
# ensure mocked asset has values that we use...
monkeypatch.setattr(asset, "_browser_download_url", _git_attr(value="http://foo"))
monkeypatch.setattr(asset, "_name", _git_attr(value=mock_archive_name))
monkeypatch.setattr(asset, "_id", _git_attr(value=asset_id))

# show that retrieving an asset w/a different ID results in ignoring
# other wheels from prior downloads in the parent directory of the asset
asset_path = retrieve_asset(test_archive.parent, asset)
ts2 = asset_path.stat().st_ctime

# NOTE: the file should be written to a subdir based on the asset ID
assert (
asset_path == test_archive.parent
) # show that the expected path matches the output path
asset_path == expected_output_dir
) # shows that the expected path matches the output path
assert ts1 == ts2 # show that the file wasn't changed...


def test_retrieve_updated(
    test_archive: pathlib.Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check that `retrieve_asset` does not hand back a download left over
    from a different asset ID when asked for a newer release asset"""

    old_asset_id, asset_id = 100, 123
    download_root = test_archive.parent

    # each asset is expected to land in a sub-directory named after its ID
    old_output_dir = download_root / str(old_asset_id)
    expected_output_dir = download_root / str(asset_id)

    def mock_webtgz_extract(self_, target_) -> None:
        # stand-in for `WebTGZ.extract`: unpack the local test archive
        # into the given directory instead of fetching anything remotely
        with tarfile.TarFile.open(test_archive) as tar:
            tar.extractall(pathlib.Path(target_))

    def extract_updated(s_, t_) -> None:
        # replacement for `WebTGZ.extract`; the requested target `t_` is not
        # consulted - output always goes to the new asset's directory
        return mock_webtgz_extract(s_, expected_output_dir)

    # seed the working directory with files from the "old" release only
    mock_webtgz_extract(None, old_output_dir)

    asset = GitReleaseAsset(
        Requester(
            auth=None,
            base_url="https://github.com",
            user_agent="mozilla",
            per_page=10,
            verify=False,
            timeout=1,
            retry=1,
            pool_size=1,
        ),
        {"mock-header": "mock-value"},
        {"mock-attr": "mock-attr-value"},
        True,
    )

    # give the mocked asset the values that `retrieve_asset` reads
    patched_attrs = (
        ("_browser_download_url", "http://foo"),
        ("_name", mock_archive_name),
        ("_id", asset_id),
    )
    for attr_name, attr_value in patched_attrs:
        monkeypatch.setattr(asset, attr_name, _git_attr(value=attr_value))

    # mock the retrieval of the updated archive
    monkeypatch.setattr(WebTGZ, "extract", extract_updated)

    # request the asset; the result should point at the new download
    asset_path = retrieve_asset(download_root, asset)

    # sanity check: the old and new locations are distinct paths
    assert old_output_dir != expected_output_dir

    # the previously "cached" directory must not be what was returned
    assert asset_path == expected_output_dir


@pytest.mark.parametrize(
"dragon_pin,pyv,is_found,is_crayex",
[
Expand Down