Skip to content

Commit

Permalink
feat: remote cluster support
Browse files Browse the repository at this point in the history
cleanup formatting

feat: remote cluster support
  • Loading branch information
Ralf Grubenmann committed Jun 17, 2024
1 parent 77b1ecc commit 21d0d3d
Show file tree
Hide file tree
Showing 12 changed files with 133 additions and 96 deletions.
7 changes: 7 additions & 0 deletions example.config.hocon
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,11 @@ git {
}
k8s {
namespace = namespace_where_notebooks_run
remote_clusters = [
{
name = remote_cluster
namespace = notebooks_namespace_in_remote_cluster
kube_config_path = path_where_kubeconfig_is_mounted
}
]
}
2 changes: 1 addition & 1 deletion renku_notebooks/api/classes/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def get_image_manifest(self, image: "Image") -> Optional[dict[str, Any]]:
"""Query the docker API to get the manifest of an image."""
if image.hostname != self.hostname:
raise ImageParseError(
f"The image hostname {image.hostname} does not match " f"the image repository {self.hostname}"
f"The image hostname {image.hostname} does not match the image repository {self.hostname}"
)
token = self._get_docker_token(image)
image_digest_url = f"https://{image.hostname}/v2/{image.name}/manifests/{image.tag}"
Expand Down
91 changes: 54 additions & 37 deletions renku_notebooks/api/classes/k8s_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,11 +270,7 @@ def patch_statefulset_tokens(self, name: str, renku_tokens: RenkuTokens):
(None, None),
)
secrets_container_index, secrets_container = next(
(
(i, c)
for i, c in enumerate(init_containers)
if c.name == "init-user-secrets"
),
((i, c) for i, c in enumerate(init_containers) if c.name == "init-user-secrets"),
(None, None),
)

Expand All @@ -294,16 +290,11 @@ def patch_statefulset_tokens(self, name: str, renku_tokens: RenkuTokens):
else None
)
secrets_renku_access_token_env = (
find_env_var(secrets_container, "RENKU_ACCESS_TOKEN")
if secrets_container is not None
else None
find_env_var(secrets_container, "RENKU_ACCESS_TOKEN") if secrets_container is not None else None
)

patches = list()
if (
git_proxy_container_index is not None
and git_proxy_renku_access_token_env is not None
):
if git_proxy_container_index is not None and git_proxy_renku_access_token_env is not None:
patches.append(
{
"op": "replace",
Expand All @@ -314,10 +305,7 @@ def patch_statefulset_tokens(self, name: str, renku_tokens: RenkuTokens):
"value": renku_tokens.access_token,
}
)
if (
git_proxy_container_index is not None
and git_proxy_renku_refresh_token_env is not None
):
if git_proxy_container_index is not None and git_proxy_renku_refresh_token_env is not None:
patches.append(
{
"op": "replace",
Expand All @@ -328,10 +316,7 @@ def patch_statefulset_tokens(self, name: str, renku_tokens: RenkuTokens):
"value": renku_tokens.refresh_token,
},
)
if (
git_clone_container_index is not None
and git_clone_renku_access_token_env is not None
):
if git_clone_container_index is not None and git_clone_renku_access_token_env is not None:
patches.append(
{
"op": "replace",
Expand All @@ -342,10 +327,7 @@ def patch_statefulset_tokens(self, name: str, renku_tokens: RenkuTokens):
"value": renku_tokens.access_token,
},
)
if (
secrets_container_index is not None
and secrets_renku_access_token_env is not None
):
if secrets_container_index is not None and secrets_renku_access_token_env is not None:
patches.append(
{
"op": "replace",
Expand All @@ -367,6 +349,26 @@ def patch_statefulset_tokens(self, name: str, renku_tokens: RenkuTokens):
)


class RemoteK8sClient(NamespacedK8sClient):
def __init__(
self,
name: str,
namespace: str,
kube_config_path: str,
amalthea_group: str,
amalthea_version: str,
amalthea_plural: str,
):
super().__init__(namespace, amalthea_group, amalthea_version, amalthea_plural)
self.name = name
load_config(kube_config_path=kube_config_path)
self._custom_objects = client.CustomObjectsApi(client.ApiClient())
self._custom_objects_patch = client.CustomObjectsApi(client.ApiClient())
self._custom_objects_patch.api_client.set_default_header("Content-Type", "application/json-patch+json")
self._core_v1 = client.CoreV1Api()
self._apps_v1 = client.AppsV1Api()


class JsServerCache:
def __init__(self, url: str):
self.url = url
Expand Down Expand Up @@ -406,7 +408,7 @@ def get_server(self, name: str) -> Optional[dict[str, Any]]:
if len(output) == 0:
return
if len(output) > 1:
raise ProgrammingError(f"Expected to find 1 server when getting server {name}, " f"found {len(output)}.")
raise ProgrammingError(f"Expected to find 1 server when getting server {name}, found {len(output)}.")
return output[0]


Expand All @@ -417,11 +419,13 @@ def __init__(
renku_ns_client: NamespacedK8sClient,
username_label: str,
session_ns_client: Optional[NamespacedK8sClient] = None,
remote_cluster_clients: dict[str, RemoteK8sClient] = {},
):
self.js_cache = js_cache
self.renku_ns_client = renku_ns_client
self.username_label = username_label
self.session_ns_client = session_ns_client
self.remote_cluster_clients = remote_cluster_clients
if not self.username_label:
raise ProgrammingError("username_label has to be provided to K8sClient")

Expand All @@ -430,14 +434,24 @@ def list_servers(self, safe_username: str) -> list[dict[str, Any]]:
Attempt to use the cache first but if the cache fails then use the k8s API.
"""
try:
return self.js_cache.list_servers(safe_username)
except JSCacheError:
logging.warning(f"Skipping the cache to list servers for user: {safe_username}")
label_selector = f"{self.username_label}={safe_username}"
return self.renku_ns_client.list_servers(label_selector) + (
self.session_ns_client.list_servers(label_selector) if self.session_ns_client is not None else []
)
# try:
# return self.js_cache.list_servers(safe_username)
# except JSCacheError:
# logging.warning(f"Skipping the cache to list servers for user: {safe_username}")
# label_selector = f"{self.username_label}={safe_username}"
# return self.renku_ns_client.list_servers(label_selector) + (
# self.session_ns_client.list_servers(label_selector) if self.session_ns_client is not None else []
# )
logging.warning(f"Skipping the cache to list servers for user: {safe_username}")
label_selector = f"{self.username_label}={safe_username}"
remote_cluster_servers = [
s for c in self.remote_cluster_clients.values() for s in c.list_servers(label_selector)
]
return (
self.renku_ns_client.list_servers(label_selector)
+ (self.session_ns_client.list_servers(label_selector) if self.session_ns_client is not None else [])
+ remote_cluster_servers
)

def get_server(self, name: str, safe_username: str) -> Optional[dict[str, Any]]:
"""Attempt to get a specific server by name from the cache.
Expand All @@ -459,7 +473,7 @@ def get_server(self, name: str, safe_username: str) -> Optional[dict[str, Any]]:
output.append(res)
if len(output) > 1:
raise ProgrammingError(
"Expected less than two results for searching for " f"server {name}, but got {len(output)}"
"Expected less than two results for searching for server {name}, but got {len(output)}"
)
if len(output) == 0:
return
Expand Down Expand Up @@ -493,12 +507,15 @@ def get_secret(self, name: str) -> Optional[dict[str, Any]]:
return secret
return self.renku_ns_client.get_secret(name)

def create_server(self, manifest: dict[str, Any], safe_username: str):
def create_server(self, manifest: dict[str, Any], safe_username: str, cluster: str | None):
server_name = manifest.get("metadata", {}).get("name")
server = self.get_server(server_name, safe_username)
if server:
# NOTE: server already exists
return server
if cluster:
cluster_client = self.remote_cluster_clients[cluster]
return cluster_client.create_server(manifest)
if not self.session_ns_client:
return self.renku_ns_client.create_server(manifest)
return self.session_ns_client.create_server(manifest)
Expand All @@ -507,7 +524,7 @@ def patch_server(self, server_name: str, safe_username: str, patch: dict[str, An
server = self.get_server(server_name, safe_username)
if not server:
raise MissingResourceError(
f"Cannot find server {server_name} for user " f"{safe_username} in order to patch it."
f"Cannot find server {server_name} for user {safe_username} in order to patch it."
)

namespace = server.get("metadata", {}).get("namespace")
Expand All @@ -525,7 +542,7 @@ def delete_server(self, server_name: str, safe_username: str, forced: bool = Fal
server = self.get_server(server_name, safe_username)
if not server:
raise MissingResourceError(
f"Cannot find server {server_name} for user " f"{safe_username} in order to delete it."
f"Cannot find server {server_name} for user {safe_username} in order to delete it."
)
namespace = server.get("metadata", {}).get("namespace")
if namespace == self.renku_ns_client.namespace:
Expand Down
48 changes: 20 additions & 28 deletions renku_notebooks/api/classes/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(
user: AnonymousUser | RegisteredUser,
server_name: str,
image: str | None,
cluster_name: str | None,
server_options: ServerOptions,
environment_variables: dict[str, str],
user_secrets: K8sUserSecrets | None,
Expand All @@ -43,6 +44,7 @@ def __init__(
using_default_image: bool = False,
is_image_private: bool = False,
repositories: list[Repository] = [],
host: str | None = None,
**_,
):
self._check_flask_config()
Expand All @@ -51,6 +53,7 @@ def __init__(
self._k8s_client: K8sClient = k8s_client
self.safe_username = self._user.safe_username
self.image = image
self.cluster_name = cluster_name
self.server_options = server_options
self.environment_variables = environment_variables
self.user_secrets = user_secrets
Expand All @@ -59,6 +62,7 @@ def __init__(
self.work_dir = work_dir
self.cloudstorage: list[ICloudStorageRequest] | None = cloudstorage
self.is_image_private = is_image_private
self.host = host or config.sessions.ingress.host

if self.server_options.idle_threshold_seconds is not None:
self.idle_seconds_threshold = self.server_options.idle_threshold_seconds
Expand Down Expand Up @@ -110,37 +114,32 @@ def repositories(self) -> list[Repository]:
@property
def server_url(self) -> str:
"""The URL where a user can access their session."""

if type(self._user) is RegisteredUser:
return urljoin(
f"https://{config.sessions.ingress.host}",
f"https://{self.host}",
f"sessions/{self.server_name}",
)
return urljoin(
f"https://{config.sessions.ingress.host}",
f"https://{self.host}",
f"sessions/{self.server_name}?token={self._user.username}",
)

@property
def git_providers(self) -> list[GitProvider]:
"""The list of git providers."""
if self._git_providers is None:
self._git_providers = config.git_provider_helper.get_providers(
user=self.user
)
self._git_providers = config.git_provider_helper.get_providers(user=self.user)
return self._git_providers

@property
def required_git_providers(self) -> list[GitProvider]:
"""The list of required git providers."""
required_provider_ids: set[str] = set(
r.provider for r in self.repositories if r.provider
)
required_provider_ids: set[str] = set(r.provider for r in self.repositories if r.provider)
return [p for p in self.git_providers if p.id in required_provider_ids]

def __str__(self):
return (
f"<UserServer user: {self._user.username} server_name: {self.server_name}>"
)
return f"<UserServer user: {self._user.username} server_name: {self.server_name}>"

def start(self) -> dict[str, Any] | None:
"""Create the jupyterserver resource in k8s."""
Expand All @@ -152,9 +151,7 @@ def start(self) -> dict[str, Any] | None:
f"or Docker resources are missing: {', '.join(errors)}"
)
)
return self._k8s_client.create_server(
self._get_session_manifest(), self.safe_username
)
return self._k8s_client.create_server(self._get_session_manifest(), self.safe_username, self.cluster_name)

@staticmethod
def _check_flask_config():
Expand Down Expand Up @@ -190,8 +187,7 @@ def _check_environment_variables_overrides(patches_list: list[dict[str, Any]]):

if key in env_vars and env_vars[key] != value:
raise DuplicateEnvironmentVariableError(
message=f"Environment variable {path}::{name} is being overridden by "
"multiple patches"
message=f"Environment variable {path}::{name} is being overridden by multiple patches"
)
else:
env_vars[key] = value
Expand Down Expand Up @@ -358,16 +354,12 @@ def get_annotations(self) -> dict[str, str | None]:
f"{prefix}hibernationDirty": "",
f"{prefix}hibernationSynchronized": "",
f"{prefix}hibernationDate": "",
f"{prefix}hibernatedSecondsThreshold": str(
self.hibernated_seconds_threshold
),
f"{prefix}hibernatedSecondsThreshold": str(self.hibernated_seconds_threshold),
f"{prefix}lastActivityDate": "",
f"{prefix}idleSecondsThreshold": str(self.idle_seconds_threshold),
}
if self.server_options.resource_class_id:
annotations[f"{prefix}resourceClassId"] = str(
self.server_options.resource_class_id
)
annotations[f"{prefix}resourceClassId"] = str(self.server_options.resource_class_id)
return annotations


Expand All @@ -384,6 +376,7 @@ def __init__(
commit_sha: str,
notebook: str | None, # TODO: Is this value actually needed?
image: str | None,
cluster_name: str | None,
server_options: ServerOptions,
environment_variables: dict[str, str],
user_secrets: K8sUserSecrets | None,
Expand All @@ -393,6 +386,7 @@ def __init__(
work_dir: Path,
using_default_image: bool = False,
is_image_private: bool = False,
host: str | None = None,
**_,
):
gitlab_project_name = f"{namespace}/{project}"
Expand All @@ -412,6 +406,7 @@ def __init__(
user=user,
server_name=server_name,
image=image,
cluster_name=cluster_name,
server_options=server_options,
environment_variables=environment_variables,
user_secrets=user_secrets,
Expand All @@ -422,6 +417,7 @@ def __init__(
using_default_image=using_default_image,
is_image_private=is_image_private,
repositories=[single_repository] if single_repository is not None else [],
host=host,
)

self.namespace = namespace
Expand Down Expand Up @@ -455,9 +451,7 @@ def _branch_exists(self):
try:
self.gitlab_project.branches.get(self.branch)
except Exception as err:
current_app.logger.warning(
f"Branch {self.branch} cannot be verified or does not exist. {err}"
)
current_app.logger.warning(f"Branch {self.branch} cannot be verified or does not exist. {err}")
else:
return True
return False
Expand All @@ -468,9 +462,7 @@ def _commit_sha_exists(self):
try:
self.gitlab_project.commits.get(self.commit_sha)
except Exception as err:
current_app.logger.warning(
f"Commit {self.commit_sha} cannot be verified or does not exist. {err}"
)
current_app.logger.warning(f"Commit {self.commit_sha} cannot be verified or does not exist. {err}")
else:
return True
return False
Expand Down
2 changes: 1 addition & 1 deletion renku_notebooks/api/classes/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,4 +146,4 @@ def git_creds_from_headers(headers):
)

def __str__(self):
return f"<Registered user username:{self.username} name: " f"{self.full_name} email: {self.email}>"
return f"<Registered user username:{self.username} name: {self.full_name} email: {self.email}>"
Loading

0 comments on commit 21d0d3d

Please sign in to comment.