Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,18 @@ def _deploy_pathways_proxy_server(
_logger.info("Successfully deployed Pathways proxy.")


def _restore_env_var(key: str, original_value: str | None) -> None:
  """Puts an environment variable back to a previously captured state.

  Args:
    key: Name of the environment variable.
    original_value: The value to write back into `os.environ`. `None` means
      the variable is removed from `os.environ` instead; removal is a no-op
      when the variable is already absent.
  """
  if original_value is not None:
    _logger.info(
        "Restoring environment variable '%s' to '%s'", key, original_value
    )
    os.environ[key] = original_value
    return

  # Nothing to write back: drop the variable. The default passed to `pop`
  # keeps this from raising when the key is not present.
  _logger.info("Unsetting environment variable: %s", key)
  os.environ.pop(key, None)


class _ISCPathways:
"""Class for managing TPUs for interactive supercomputing.

Expand Down Expand Up @@ -163,6 +175,10 @@ def __init__(
self._proxy_port = None
self.proxy_server_image = proxy_server_image
self.proxy_options = proxy_options or ProxyOptions()
self._old_jax_platforms = None
self._old_jax_backend_target = None
self._old_jax_platforms_config = None
self._old_jax_backend_target_config = None

def __repr__(self):
return (
Expand All @@ -176,6 +192,15 @@ def __repr__(self):

def __enter__(self):
"""Enters the context manager, ensuring cluster exists."""
self._old_jax_platforms = os.environ.get(_JAX_PLATFORMS_KEY)
self._old_jax_backend_target = os.environ.get(_JAX_BACKEND_TARGET_KEY)
self._old_jax_platforms_config = getattr(
jax.config, _JAX_PLATFORMS_KEY, None
)
self._old_jax_backend_target_config = getattr(
jax.config, _JAX_BACKEND_TARGET_KEY, None
)

try:
_deploy_pathways_proxy_server(
pathways_service=self.pathways_service,
Expand All @@ -199,11 +224,17 @@ def __enter__(self):
)

# Update the JAX backend to use the proxy.
os.environ[_JAX_PLATFORMS_KEY] = _JAX_PLATFORM_PROXY
os.environ[
_JAX_BACKEND_TARGET_KEY
] = f"{_JAX_BACKEND_TARGET_HOSTNAME}:{self._proxy_port}"

jax.config.update(_JAX_PLATFORMS_KEY, _JAX_PLATFORM_PROXY)
jax.config.update(
_JAX_BACKEND_TARGET_KEY,
f"{_JAX_BACKEND_TARGET_HOSTNAME}:{self._proxy_port}",
)

pathwaysutils.initialize()
_logger.info(
"Interactive supercomputing proxy client ready for cluster '%s'.",
Expand All @@ -221,7 +252,7 @@ def __exit__(self, exc_type, exc_value, traceback):
_logger.info("Exiting ISCPathways context.")
self._cleanup()

def _cleanup(self):
def _cleanup(self) -> None:
"""Cleans up resources created by the ISCPathways context."""
# 1. Clear JAX caches and run garbage collection.
_logger.info("Starting Pathways proxy cleanup.")
Expand All @@ -248,6 +279,16 @@ def _cleanup(self):
gke_utils.delete_gke_job(self._proxy_job_name)
_logger.info("Pathways proxy GKE job deletion complete.")

# 4. Restore JAX variables.
_logger.info("Restoring JAX env and config variables...")
_restore_env_var(_JAX_PLATFORMS_KEY, self._old_jax_platforms)
_restore_env_var(_JAX_BACKEND_TARGET_KEY, self._old_jax_backend_target)
jax.config.update(_JAX_PLATFORMS_KEY, self._old_jax_platforms_config)
jax.config.update(
_JAX_BACKEND_TARGET_KEY, self._old_jax_backend_target_config
)
_logger.info("JAX variables restored.")


@contextlib.contextmanager
def connect(
Expand Down
Loading