/
notebook_service.py
372 lines (335 loc) · 15.4 KB
/
notebook_service.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
# -*- coding: utf-8 -*-
#
# Copyright 2018-2022 - Swiss Data Science Center (SDSC)
# A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and
# Eidgenössische Technische Hochschule Zürich (ETHZ).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Docker based interactive session provider."""
import urllib
from pathlib import Path
from time import monotonic, sleep
from typing import Any, Dict, List, Optional, Tuple, Union
from lazy_object_proxy import Proxy
from renku.command.command_builder import inject
from renku.core import errors
from renku.core.interface.client_dispatcher import IClientDispatcher
from renku.core.management.client import LocalClient
from renku.core.plugin import hookimpl
from renku.core.session.utils import get_renku_project_name, get_renku_url
from renku.core.util import communication, requests
from renku.core.util.git import get_remote
from renku.domain_model.session import ISessionProvider, Session
def _get_token(client: LocalClient, renku_url: str) -> Tuple[str, bool]:
    """Return the credential used to authenticate with renku and whether it is a registered one.

    The client configuration is searched for a value in the ``http`` section under the
    host of ``renku_url`` (the JWT token stored by renku login). When such a value exists
    it is returned together with ``True``. Otherwise the anonymous user token is obtained
    and returned together with ``False``.

    Args:
        client: Client whose configuration values are read.
        renku_url: URL of the renku deployment.

    Returns:
        A tuple of the token and a flag that is ``True`` for a registered user and
        ``False`` for an anonymous user.
    """
    host = urllib.parse.urlparse(renku_url).netloc
    stored_token = client.get_value(section="http", key=host)
    if stored_token:
        return stored_token, True
    return _get_anonymous_credentials(client=client, renku_url=renku_url), False
def _get_anonymous_credentials(client: LocalClient, renku_url: str) -> str:
    """Return the anonymous user token for the renku deployment at ``renku_url``.

    A token already stored in the ``anonymous_token`` configuration section (keyed by the
    host of ``renku_url``) is reused. Otherwise a new one is requested from the deployment
    and written back to that section with ``global_only=True``.

    Args:
        client: Client used to read and write configuration values.
        renku_url: URL of the renku deployment.

    Returns:
        The anonymous user token.

    Raises:
        errors.AuthenticationError: If no token is stored and none could be obtained.
    """

    def _request_anon_cookie() -> Optional[str]:
        # NOTE: this local import shadows the module-level ``requests`` name that is
        # imported from ``renku.core.util``; ``Session`` and ``exceptions`` are taken from here.
        import requests

        user_endpoint = urllib.parse.urljoin(renku_url, "api/user")
        request_headers = {
            "X-Requested-With": "XMLHttpRequest",
            "X-Forwarded-Uri": "/api/user",
        }
        with requests.Session() as http_session:
            try:
                http_session.get(user_endpoint, headers=request_headers)
            except (requests.exceptions.RequestException, requests.exceptions.ConnectionError):
                # Best effort: a failed request simply leaves the cookie jar without "anon-id".
                pass
            return http_session.cookies.get("anon-id")

    host = urllib.parse.urlparse(renku_url).netloc

    cached_token = client.get_value(section="anonymous_token", key=host)
    if cached_token:
        return cached_token

    fresh_token = _request_anon_cookie()
    if not fresh_token:
        raise errors.AuthenticationError(
            "Could not get anonymous user token from Renku. "
            f"Ensure the Renku deployment at {renku_url} supports anonymous sessions."
        )

    client.set_value(section="anonymous_token", key=host, value=fresh_token, global_only=True)
    return fresh_token
class NotebookServiceSessionProvider(ISessionProvider):
    """A session provider that uses the notebook service API to launch sessions."""

    # Upper bound, in seconds, for the polling loops that wait for a session state or an image.
    DEFAULT_TIMEOUT_SECONDS = 300

    def __init__(self):
        """Initialize the lazily filled caches used by the private accessor methods."""
        self.__renku_url = None
        self.__notebooks_url = None
        self.__client_dispatcher = None

    def _client_dispatcher(self):
        """Get (and if required create) the client dispatcher cached on this instance."""
        if not self.__client_dispatcher:
            self.__client_dispatcher = Proxy(lambda: inject.instance(IClientDispatcher))
        return self.__client_dispatcher

    def _renku_url(self) -> str:
        """Get the URL of the renku instance.

        Raises:
            errors.UsageError: If ``get_renku_url`` does not yield a URL.
        """
        if not self.__renku_url:
            renku_url = get_renku_url()
            if not renku_url:
                raise errors.UsageError(
                    "Cannot determine the renku URL to launch a session. "
                    "Ensure your current project is a valid Renku project."
                )
            self.__renku_url = renku_url
        return self.__renku_url

    def _notebooks_url(self) -> str:
        """Get the url of the notebooks API."""
        if not self.__notebooks_url:
            url = urllib.parse.urljoin(self._renku_url(), "api/notebooks")
            self.__notebooks_url = url
        return self.__notebooks_url

    def _token(self) -> str:
        """Get the JWT token used to authenticate against Renku.

        Raises:
            errors.AuthenticationError: If no token is available.
        """
        token, _ = _get_token(client=self._client_dispatcher().current_client, renku_url=self._renku_url())
        if token is None:
            raise errors.AuthenticationError("Please run the renku login command to authenticate with Renku.")
        return token

    def _is_user_registered(self) -> bool:
        """Tell whether the current credentials belong to a registered (not anonymous) user."""
        _, is_user_registered = _get_token(client=self._client_dispatcher().current_client, renku_url=self._renku_url())
        return is_user_registered

    def _auth_header(self) -> Dict[str, str]:
        """Get the authentication header with the JWT token or cookie needed to authenticate with Renku."""
        if self._is_user_registered():
            return {"Authorization": f"Bearer {self._token()}"}
        return {"Cookie": f"anon-id={self._token()}"}

    def _get_renku_project_name_parts(self) -> Dict[str, str]:
        """Split the project name into the ``namespace`` and ``project`` parts.

        Returns:
            A dictionary with the keys ``namespace`` and ``project``.
        """
        client = self._client_dispatcher().current_client
        if client.remote["name"]:
            owner = client.remote["owner"]
            backup_prefix = "repos/"
            if get_remote(client.repository, name="renku-backup-origin") and owner.startswith(backup_prefix):
                # Drop exactly the prefix. ``str.lstrip`` would treat its argument as a set of
                # characters and could also eat the beginning of the actual owner name.
                owner = owner[len(backup_prefix) :]
            return {
                "namespace": owner,
                "project": client.remote["name"],
            }
        else:
            # INFO: In this case the owner/name split is not available. The project name is then
            # derived from the combined name of the remote and has to be split up in the two parts.
            parts = get_renku_project_name().split("/")
            return {
                "namespace": "/".join(parts[:-1]),
                # The project is the last path segment only, everything before it is the namespace.
                "project": parts[-1],
            }

    def _wait_for_session_status(
        self,
        name: Optional[str],
        status: str,
    ):
        """Poll the notebooks API until the session reaches the given status.

        Does nothing when ``name`` is empty. For the status ``stopping`` the wait ends as soon as
        the server is no longer found (HTTP 404); for any other status it ends when the reported
        state equals ``status``.

        Args:
            name: Name of the session to wait for.
            status: The state to wait for.

        Raises:
            errors.NotebookServiceSessionError: If the status is not reached within
                ``DEFAULT_TIMEOUT_SECONDS``.
        """
        if not name:
            return
        start = monotonic()
        while monotonic() - start < self.DEFAULT_TIMEOUT_SECONDS:
            res = self._send_renku_request(
                "get", f"{self._notebooks_url()}/servers/{name}", headers=self._auth_header()
            )
            if res.status_code == 404 and status == "stopping":
                return
            if res.status_code == 200 and status != "stopping":
                if res.json().get("status", {}).get("state") == status:
                    return
            sleep(5)
        raise errors.NotebookServiceSessionError(f"Waiting for the session {name} to reach status {status} timed out.")

    def _wait_for_image(
        self,
        image_name: str,
        config: Optional[Dict[str, Any]],
    ):
        """Check if an image exists, and if it does not wait for it to appear.

        Timeout after a specific period of time.

        Args:
            image_name: Name of the image to look for.
            config: Configuration that is passed on to ``find_image``.

        Raises:
            errors.NotebookServiceSessionError: If the image is not found within
                ``DEFAULT_TIMEOUT_SECONDS``.
        """
        start = monotonic()
        while monotonic() - start < self.DEFAULT_TIMEOUT_SECONDS:
            if self.find_image(image_name, config):
                return
            sleep(5)
        raise errors.NotebookServiceSessionError(
            f"Waiting for the image {image_name} to be built timed out. "
            "Are you sure that the image was successfully built? This could be the result "
            "of problems with your Dockerfile."
        )

    def pre_start_checks(self):
        """Check if the state of the repository is as expected before starting a session."""
        if not self._is_user_registered():
            return
        repository = self._client_dispatcher().current_client.repository
        if repository.is_dirty(untracked_files=True):
            communication.confirm(
                "You have new uncommitted or untracked changes to your repository. "
                "Renku can automatically commit these changes so that it builds "
                "the correct environment for your session. Do you wish to proceed?",
                abort=True,
            )
            repository.add(all=True)
            repository.commit("Automated commit by Renku CLI.")

    def _remote_head_hexsha(self):
        """Get the head of the remote of the current repository."""
        return get_remote(self._client_dispatcher().current_client.repository).head

    @staticmethod
    def _send_renku_request(req_type: str, *args, **kwargs):
        """Send a request with the module-level ``requests`` helper.

        Args:
            req_type: Name of the function to call on ``requests`` (e.g. ``get``, ``post``, ``delete``).
            *args: Positional arguments forwarded to that function.
            **kwargs: Keyword arguments forwarded to that function.

        Returns:
            The response of the request.

        Raises:
            errors.AuthenticationError: If the response status code is 401.
        """
        res = getattr(requests, req_type)(*args, **kwargs)
        if res.status_code == 401:
            raise errors.AuthenticationError(
                "Please run the renku login command to authenticate with Renku or to refresh your expired credentials."
            )
        return res

    def build_image(self, image_descriptor: Path, image_name: str, config: Optional[Dict[str, Any]]):
        """Builds the container image.

        Nothing is built locally: when the image is not found, the local commits are pushed (if
        the local head differs from the remote head) and the method waits for the image to appear.

        Args:
            image_descriptor: Not used by this provider.
            image_name: Name of the image.
            config: Configuration that is passed on to ``find_image``.

        Raises:
            errors.NotebookSessionImageNotExistError: If the image is missing and the user is anonymous.
        """
        if self.find_image(image_name, config=config):
            return
        if not self._is_user_registered():
            raise errors.NotebookSessionImageNotExistError(
                f"Renku cannot find the image {image_name} and use it in an anonymous session."
            )
        if self._client_dispatcher().current_client.repository.head.commit.hexsha != self._remote_head_hexsha():
            self._client_dispatcher().current_client.repository.push()
        self._wait_for_image(image_name=image_name, config=config)

    def find_image(self, image_name: str, config: Optional[Dict[str, Any]]) -> bool:
        """Find the given container image.

        Args:
            image_name: Name of the image, sent as the ``image_url`` query parameter.
            config: Not used by this provider.

        Returns:
            bool: ``True`` if the images endpoint of the notebooks API answers with HTTP 200.
        """
        return (
            self._send_renku_request(
                "get",
                f"{self._notebooks_url()}/images",
                headers=self._auth_header(),
                params={"image_url": image_name},
            ).status_code
            == 200
        )

    @hookimpl
    def session_provider(self) -> Tuple[ISessionProvider, str]:
        """Supported session provider.

        Returns:
            a tuple of ``self`` and provider name.
        """
        return (self, "notebook_service")

    def session_list(self, project_name: str, config: Optional[Dict[str, Any]]) -> List[Session]:
        """Lists all the sessions currently running by the given session provider.

        Args:
            project_name: Not used; the project is derived from the current client.
            config: Not used by this provider.

        Returns:
            list: a list of sessions. The list is empty when the request does not return HTTP 200.
        """
        sessions_res = self._send_renku_request(
            "get",
            f"{self._notebooks_url()}/servers",
            headers=self._auth_header(),
            params=self._get_renku_project_name_parts(),
        )
        if sessions_res.status_code == 200:
            return [
                Session(
                    session["name"],
                    session.get("status", {}).get("state", "unknown"),
                    session["url"],
                )
                for session in sessions_res.json().get("servers", {}).values()
            ]
        return []

    def session_start(
        self,
        image_name: str,
        project_name: str,
        config: Optional[Dict[str, Any]],
        client: LocalClient,
        cpu_request: Optional[float] = None,
        mem_request: Optional[str] = None,
        disk_request: Optional[str] = None,
        gpu_request: Optional[str] = None,
    ) -> str:
        """Creates an interactive session.

        Args:
            image_name: Image to launch the session with.
            project_name: Not used; the project is derived from the current client.
            config: Not used by this provider.
            client: Client whose repository head determines the commit of the session.
            cpu_request: Optional CPU request forwarded in the server options.
            mem_request: Optional memory request forwarded in the server options.
            disk_request: Optional disk request forwarded in the server options.
            gpu_request: Optional GPU request; converted to ``int`` before it is forwarded.

        Returns:
            str: a unique id for the created interactive session.

        Raises:
            errors.NotebookServiceSessionError: If the notebook service does not answer with
                HTTP 200 or 201.
        """
        session_commit = client.repository.head.commit.hexsha
        if not self._is_user_registered():
            communication.warn(
                "You are starting a session as an anonymous user. "
                "None of the local changes in this project will be reflected in your session. "
                "In addition, any changes you make in the new session will be lost when "
                "the session is shut down."
            )
        else:
            if client.repository.head.commit.hexsha != self._remote_head_hexsha():
                # INFO: The user is registered, the image is pinned or already available
                # but the local repository is not fully in sync with the remote
                communication.confirm(
                    "You have unpushed commits that will not be present in your session. "
                    "Renku can automatically push these commits so that they are present "
                    "in the session you are launching. Do you wish to proceed?",
                    abort=True,
                )
                client.repository.push()
        # Only the resource requests that were actually given are sent to the service.
        server_options: Dict[str, Union[str, float]] = {}
        if cpu_request:
            server_options["cpu_request"] = cpu_request
        if mem_request:
            server_options["mem_request"] = mem_request
        if gpu_request:
            server_options["gpu_request"] = int(gpu_request)
        if disk_request:
            server_options["disk_request"] = disk_request
        payload = {
            "image": image_name,
            "commit_sha": session_commit,
            "serverOptions": server_options,
            **self._get_renku_project_name_parts(),
        }
        res = self._send_renku_request(
            "post",
            f"{self._notebooks_url()}/servers",
            headers=self._auth_header(),
            json=payload,
        )
        if res.status_code in [200, 201]:
            session_name = res.json()["name"]
            self._wait_for_session_status(session_name, "running")
            return session_name
        raise errors.NotebookServiceSessionError("Cannot start session via the notebook service because " + res.text)

    def session_stop(self, project_name: str, session_name: Optional[str], stop_all: bool) -> bool:
        """Stops all sessions (for the given project) or a specific interactive session.

        Args:
            project_name: Passed on to ``session_list`` when all sessions are stopped.
            session_name: Name of the session to stop when ``stop_all`` is not set.
            stop_all: Whether to stop every session returned by ``session_list``.

        Returns:
            bool: ``True`` if every delete request was answered with HTTP 204.
        """
        responses = []
        if stop_all:
            sessions = self.session_list(project_name=project_name, config=None)
            for session in sessions:
                responses.append(
                    self._send_renku_request(
                        "delete", f"{self._notebooks_url()}/servers/{session.id}", headers=self._auth_header()
                    )
                )
                self._wait_for_session_status(session.id, "stopping")
        else:
            responses.append(
                self._send_renku_request(
                    "delete", f"{self._notebooks_url()}/servers/{session_name}", headers=self._auth_header()
                )
            )
            self._wait_for_session_status(session_name, "stopping")
        return all(response.status_code == 204 for response in responses)

    def session_url(self, session_name: str) -> Optional[str]:
        """Get the URL of the interactive session.

        Args:
            session_name: Name of the session.

        Returns:
            The ``url`` reported for the session, or ``None`` if the request does not return HTTP 200.

        Raises:
            errors.NotebookSessionNotReadyError: If the session exists but is not in the ``running`` state.
        """
        res = self._send_renku_request(
            "get", f"{self._notebooks_url()}/servers/{session_name}", headers=self._auth_header()
        )
        if res.status_code == 200:
            if res.json().get("status", {}).get("state") != "running":
                raise errors.NotebookSessionNotReadyError(
                    f"The session {session_name} cannot be accessed now because it is not ready."
                )
            return res.json().get("url")
        return None