Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 169 additions & 37 deletions orchestrator/utilities/ray_env/ordered_pip.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import contextlib
import logging
import os
import threading
import typing

import ray._private.runtime_env.packaging
from ray._private.runtime_env import virtualenv_utils
from ray._private.runtime_env.pip import PipPlugin
from ray._private.runtime_env.plugin import RuntimeEnvPlugin
Expand All @@ -29,18 +31,24 @@ async def create_or_get_virtualenv(path: str, cwd: str, logger: logging.Logger):
await original_create_or_get_virtualenv(path=path, cwd=cwd, logger=logger)


# Serialises the monkey-patching of virtualenv_utils.create_or_get_virtualenv
# performed by patch_create_or_get_virtualenv() below.
_monkey_patch_lock = threading.RLock()


@contextlib.contextmanager
def patch_create_or_get_virtualenv(phase_index: int):
    """Temporarily swap ray's ``create_or_get_virtualenv`` for later phases.

    For every phase after the first one (``phase_index > 0``) the attribute
    ``virtualenv_utils.create_or_get_virtualenv`` is pointed at this module's
    ``create_or_get_virtualenv`` wrapper for the duration of the ``with`` body.
    On exit the original function is always put back, even when the body raises.

    Args:
        phase_index: Zero-based index of the ordered_pip phase being installed.

    Yields:
        None. The patch is active while the caller's ``with`` body executes.
    """
    # NOTE(review): this is a threading.RLock that stays held while the caller's
    # with-body runs, and the caller in this file awaits inside that body.
    # An RLock is re-entrant for the owning thread - confirm that this provides
    # the intended exclusion between coroutines sharing one event-loop thread.
    with _monkey_patch_lock:
        if phase_index > 0:
            virtualenv_utils.create_or_get_virtualenv = create_or_get_virtualenv
        try:
            yield
        finally:
            # Restoring is harmless when nothing was patched (phase 0)
            virtualenv_utils.create_or_get_virtualenv = (
                original_create_or_get_virtualenv
            )


class OrderedPipPlugin(RuntimeEnvPlugin):
Expand Down Expand Up @@ -93,27 +101,81 @@ def try_import_torch():
"""

# Key under which this plugin's settings appear in a runtime_env dictionary
name = "ordered_pip"

# VV: Configure Ray to use this RuntimeEnvPlugin last
priority = 100
# Dotted import path of this class; presumably what Ray is given to load the plugin - TODO confirm
ClassPath = "orchestrator.utilities.ray_env.ordered_pip.OrderedPipPlugin"

def __init__(self, resources_dir: str | None = None):
    """Set up the locks and bookkeeping of the plugin.

    Args:
        resources_dir: Directory under which the pip virtual environments
            live. When it is None the directory is not decided here; it is
            selected later by ``_try_switch_resources_dir_from_context()``.
    """
    # Guards the selection of the resources dir and the registration of per-URI locks
    self._global_mtx = threading.RLock()
    # One lock per environment URI, created on demand
    self._create_env_mtx: dict[str, threading.RLock] = {}
    self._pip_resources_dir = resources_dir

    # VV: Maintains a cache of the environments that have been built thus far
    # (maps a URI to the number of bytes accumulated while building it)
    self._cache: dict[str, int] = {}

def _try_switch_resources_dir_from_context(
    self,
    context: "RuntimeEnvContext",  # noqa: F821
    logger: logging.Logger | None = default_logger,
):
    """Select the resources directory once a RuntimeEnvContext is available.

    Does nothing when a resources directory has already been chosen.
    Otherwise the directory is derived from the ``PYTHONPATH`` entry of
    ``context.env_vars``; when that does not yield exactly one candidate a
    fresh temporary directory under ``/tmp/ray`` is used instead.

    Args:
        context: The runtime env context whose ``env_vars`` are inspected.
        logger: Logger for progress messages; None falls back to the
            module's default logger.
    """
    # VV: When ray instantiates custom RuntimeEnvPlugins it does not provide a resources_dir path.
    # This method is a HACK that derives the resources_dir from the RuntimeEnvContext which is known
    # at the time of CREATING a virtual environment i.e. **after** the RuntimeEnvPlugin is initialized.
    if logger is None:
        logger = default_logger

    with self._global_mtx:
        # VV: Stick with whatever resources dir we've already picked
        if self._pip_resources_dir:
            return

        logger.info("Generating resources dir")
        unique: set[str] = set()
        if "PYTHONPATH" in context.env_vars:
            # VV: This is a HACK to find the "runtime_resources" path inside the PYTHONPATH env-var
            # This is an env-var that the WorkingDirPlugin inserts.
            # I noticed that sometimes the PYTHONPATH contains multiple copies of the same PATH.
            # The PYTHONPATH looks like this:
            # /tmp/ray/session_$timestamp/runtime_resources/working_dir_files/_ray_pkg_$uid
            many = context.env_vars["PYTHONPATH"].split(os.pathsep)
            logger.info(f"Current PYTHONPATH {many}")
            runtime_resources_followup = f"{os.sep}working_dir_files{os.sep}"
            unique.update(
                os.path.join(
                    x.split(runtime_resources_followup, 1)[0], "ordered_pip"
                )
                for x in many
                if runtime_resources_followup in x
            )

        logger.info(f"The candidate locations of runtime_resources: {list(unique)}")

        if len(unique) == 1:
            resources_dir = unique.pop()
        else:
            # Zero or several candidates: fall back to a private temporary directory
            import tempfile

            fallback_root = "/tmp/ray"
            # mkdtemp() raises FileNotFoundError when its parent directory is missing
            os.makedirs(fallback_root, exist_ok=True)
            resources_dir = tempfile.mkdtemp(
                prefix="ordered_pip_", dir=fallback_root
            )

        self._switch_resources_dir(resources_dir)

def _switch_resources_dir(self, resources_dir: str):
    """Adopt ``resources_dir`` as the resources directory and try to create it.

    Args:
        resources_dir: Path that becomes ``self._pip_resources_dir``.
    """
    from ray._common.utils import try_to_create_directory

    with self._global_mtx:
        self._pip_resources_dir = resources_dir
        try_to_create_directory(resources_dir)

@property
def _pip_plugin(self) -> PipPlugin:
    """Build a brand-new PipPlugin for the current resources directory.

    The PipPlugin keeps an internal cache of virtual environments it has
    created but not yet deleted. When .create() is called, it checks this
    cache for a venv matching the given URI and, on a match, assumes the venv
    already exists and skips re-creation. ordered_pip however needs to reuse
    the same venv multiple times (once per "phase"), so a new PipPlugin
    instance is handed out on every access. The decision whether to create a
    new "ordered_pip" venv or reuse an existing one is taken from this
    class's own record of venvs instead.
    """
    fresh_plugin = PipPlugin(self._pip_resources_dir)
    return fresh_plugin

@staticmethod
def validate(runtime_env_dict: dict[str, typing.Any]) -> RuntimeEnv:
Expand All @@ -132,14 +194,14 @@ def validate(runtime_env_dict: dict[str, typing.Any]) -> RuntimeEnv:
raise ValueError("runtime_env must be a dictionary")

if "ordered_pip" not in runtime_env_dict:
raise ValueError("missing the 'ordered_pip' key", runtime_env_dict)
return RuntimeEnv(**runtime_env_dict)

if not isinstance(runtime_env_dict["ordered_pip"], dict):
raise ValueError("runtime_env['ordered_pip'] must be a dictionary")

if not isinstance(runtime_env_dict["ordered_pip"]["phases"], list):
raise ValueError(
"runtime_env['ordered_pip']['phases'] must be a dictionary consistent with pip"
"runtime_env['ordered_pip']['phases'] must be an array of pip entries"
)

phases = []
Expand All @@ -164,6 +226,9 @@ def validate(runtime_env_dict: dict[str, typing.Any]) -> RuntimeEnv:
return result

def get_uris(self, runtime_env: "RuntimeEnv") -> list[str]:
if not self.is_ordered_pip_runtimeenv(runtime_env):
return []

# VV: We want the hash to be invariant to the order of package names within a phase,
# and we also want the order of phases to be reflected in the hash.
aggregate_packages = [
Expand All @@ -178,31 +243,76 @@ def get_uris(self, runtime_env: "RuntimeEnv") -> list[str]:
"pip://" + hashlib.sha1(str(aggregate_packages).encode("utf-8")).hexdigest()
]

def is_ordered_pip_runtimeenv(self, runtime_env: "RuntimeEnv") -> bool:
    """Report whether ``runtime_env`` carries a truthy "ordered_pip" entry.

    The runtime env is passed through ``validate()`` first.
    """
    validated = self.validate(runtime_env)
    ordered_pip_entry = validated.get("ordered_pip")
    return bool(ordered_pip_entry)

async def create(
    self,
    uri: str,
    runtime_env: "RuntimeEnv",  # noqa: F821
    context: "RuntimeEnvContext",  # noqa: F821
    logger: logging.Logger | None = default_logger,
) -> int:
    """Build the ordered_pip virtual environment, one pip phase at a time.

    Args:
        uri: Ignored; the URI is recomputed from ``runtime_env`` via
            ``get_uris()``.
        runtime_env: The runtime environment to materialise.
        context: Runtime env context, forwarded to the pip plugin and used to
            pick the resources directory on first use.
        logger: Logger for progress messages; None falls back to the
            module's default logger.

    Returns:
        0 when ``runtime_env`` is not an ordered_pip runtime env. Otherwise
        the sum of the values returned by the pip plugin for each phase, or
        the cached sum when the environment was already built.
    """
    if logger is None:
        logger = default_logger

    self._try_switch_resources_dir_from_context(context, logger)

    if not self.is_ordered_pip_runtimeenv(runtime_env):
        return 0

    uri = self.get_uris(runtime_env)[0]

    with self._global_mtx:
        # Register the per-URI lock exactly once
        env_lock = self._create_env_mtx.setdefault(uri, threading.RLock())

    # NOTE(review): env_lock is a threading.RLock held across the awaits below.
    # It is re-entrant for the owning thread - confirm it gives the intended
    # exclusion between coroutines running on the same event-loop thread.
    with env_lock:
        logger.info(f"Creating {uri} for {runtime_env}")

        # Reuse the venv only when we both recorded it and it is still on disk;
        # a directory without a cache entry is rebuilt.
        if uri in self._cache and os.path.isdir(self.get_path_to_pip_venv(uri)):
            logger.info(f"Virtual environment for {uri} already exists")
            return self._cache[uri]

        self._cache[uri] = 0
        phases = self.validate(runtime_env)["ordered_pip"]["phases"]
        for idx, pip in enumerate(phases):
            with patch_create_or_get_virtualenv(idx):
                logger.info(f"Creating {idx} for {uri}")

                # A fresh PipPlugin per phase (see _pip_plugin) so that the
                # same venv is extended instead of being skipped.
                self._cache[uri] += await self._pip_plugin.create(
                    uri=uri,
                    runtime_env=RuntimeEnv(pip=pip),
                    context=context,
                    logger=logger,
                )
                logger.info(f"Done creating {idx} for {uri}")

        return self._cache[uri]

def get_path_to_pip_venv(self, uri: str) -> str:
    """Return ``<resources dir>/pip/<hash>`` for the environment named by ``uri``.

    The hash is the second element produced by ray's ``parse_uri`` for ``uri``.
    """
    parse_uri = ray._private.runtime_env.packaging.parse_uri
    _protocol, env_hash = parse_uri(uri)
    return os.path.join(self._pip_resources_dir, "pip", env_hash)

def delete_uri(
    self, uri: str, logger: logging.Logger | None = default_logger
) -> int:
    """Forget the environment for ``uri`` and remove its directory.

    Args:
        uri: URI of the virtual environment to delete.
        logger: Logger for progress messages; None falls back to the
            module's default logger.

    Returns:
        The size in bytes of the environment directory, measured before the
        removal is attempted.
    """
    import shutil

    import ray._private.utils

    if logger is None:
        logger = default_logger

    logger.info(f"Cleaning up {uri}")
    # pop() instead of del: a URI without a cache entry must not abort the
    # clean-up of its directory with a KeyError.
    self._cache.pop(uri, None)

    env_dir = self.get_path_to_pip_venv(uri)
    num_bytes = ray._private.utils.get_directory_size_bytes(env_dir)

    try:
        shutil.rmtree(env_dir)
    except Exception as e:
        # Deliberately best-effort: a failed removal is logged and ignored
        logger.warning(f"Exception while cleaning up {env_dir} {e!s} - will ignore")

    return num_bytes

def modify_context(
self,
Expand All @@ -211,7 +321,14 @@ def modify_context(
context: "RuntimeEnvContext", # noqa: F821
logger: logging.Logger = default_logger,
):
phases = self.validate(runtime_env)["ordered_pip"]["phases"]
self._try_switch_resources_dir_from_context(context)

runtime_env = self.validate(runtime_env)
if not runtime_env.get("ordered_pip"):
return

logger.info(f"Modifying the context for {uris} and {runtime_env}")
phases = runtime_env["ordered_pip"]["phases"]

if not len(phases):
return
Expand All @@ -222,3 +339,18 @@ def modify_context(
context=context,
logger=logger,
)

if "PYTHONPATH" in context.env_vars:
# VV: Ensure unique paths in PYTHONPATH
paths = context.env_vars["PYTHONPATH"].split(os.pathsep)

unique = []
for k in paths:
if k not in unique:
unique.append(k)

context.env_vars["PYTHONPATH"] = os.pathsep.join(unique)

logger.info(
f"Modified the context for {uris} and {runtime_env} with {context.py_executable} {context.env_vars}"
)
Loading