Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions gradient/api_sdk/clients/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,6 @@
from .job_client import JobsClient
from .machines_client import MachinesClient
from .model_client import ModelsClient
from .notebook_client import NotebooksClient
from .project_client import ProjectsClient
from .sdk_client import SdkClient
81 changes: 81 additions & 0 deletions gradient/api_sdk/clients/notebook_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from .base_client import BaseClient
from .. import repositories, models


class NotebooksClient(BaseClient):
def create(
self,
vm_type_id,
container_id,
cluster_id,
container_name=None,
name=None,
registry_username=None,
registry_password=None,
default_entrypoint=None,
container_user=None,
shutdown_timeout=None,
is_preemptible=None,
):
"""Create new notebook

:param int vm_type_id:
:param int container_id:
:param int cluster_id:
:param str container_name:
:param str name:
:param str registry_username:
:param str registry_password:
:param str default_entrypoint:
:param str container_user:
:param int|float shutdown_timeout:
:param bool is_preemptible:

:return: Notebook ID
:rtype str:
"""

notebook = models.Notebook(
vm_type_id=vm_type_id,
container_id=container_id,
cluster_id=cluster_id,
container_name=container_name,
name=name,
registry_username=registry_username,
registry_password=registry_password,
default_entrypoint=default_entrypoint,
container_user=container_user,
shutdown_timeout=shutdown_timeout,
is_preemptible=is_preemptible,
)

repository = repositories.CreateNotebook(api_key=self.api_key, logger=self.logger)
handle = repository.create(notebook)
return handle

def get(self, id):
"""Get Notebook

:param str id: Notebook ID
:rtype: models.Notebook
"""
repository = repositories.GetNotebook(api_key=self.api_key, logger=self.logger)
notebook = repository.get(id=id)
return notebook

def delete(self, id):
"""Delete existing notebook

:param str id: Notebook ID
"""
repository = repositories.DeleteNotebook(api_key=self.api_key, logger=self.logger)
repository.delete(id)

def list(self):
"""Get list of Notebooks

:rtype: list[models.Notebook]
"""
repository = repositories.ListNotebooks(api_key=self.api_key, logger=self.logger)
notebooks = repository.list()
return notebooks
5 changes: 4 additions & 1 deletion gradient/api_sdk/clients/sdk_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from . import DeploymentsClient, ExperimentsClient, HyperparameterJobsClient, ModelsClient, ProjectsClient
from . import DeploymentsClient, ExperimentsClient, HyperparameterJobsClient, ModelsClient, ProjectsClient, \
MachinesClient, NotebooksClient
from .job_client import JobsClient
from .. import logger as sdk_logger

Expand All @@ -15,3 +16,5 @@ def __init__(self, api_key, logger=sdk_logger.MuteLogger()):
self.models = ModelsClient(api_key=api_key, logger=logger)
self.jobs = JobsClient(api_key=api_key, logger=logger)
self.projects = ProjectsClient(api_key=api_key, logger=logger)
self.machines = MachinesClient(api_key=api_key, logger=logger)
self.notebooks = NotebooksClient(api_key=api_key, logger=logger)
1 change: 1 addition & 0 deletions gradient/api_sdk/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@
from .log import LogRow
from .machine import Machine, MachineEvent, MachineUtilization
from .model import Model
from .notebook import Notebook
from .project import Project
21 changes: 21 additions & 0 deletions gradient/api_sdk/models/notebook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import attr


@attr.s
class Notebook(object):
id = attr.ib(type=str, default=None)
vm_type_id = attr.ib(type=int, default=None)
container_id = attr.ib(type=int, default=None)
container_name = attr.ib(type=str, default=None)
name = attr.ib(type=str, default=None)
cluster_id = attr.ib(type=int, default=None)
registry_username = attr.ib(type=str, default=None)
registry_password = attr.ib(type=str, default=None)
default_entrypoint = attr.ib(type=str, default=None)
container_user = attr.ib(type=str, default=None)
shutdown_timeout = attr.ib(type=int, default=None)
is_preemptible = attr.ib(type=bool, default=None)
project_id = attr.ib(type=bool, default=None)
state = attr.ib(type=bool, default=None)
vm_type = attr.ib(type=bool, default=None)
fqdn = attr.ib(type=bool, default=None)
1 change: 1 addition & 0 deletions gradient/api_sdk/repositories/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@
from .machines import CheckMachineAvailability, CreateMachine, CreateResource, StartMachine, StopMachine, \
RestartMachine, GetMachine, UpdateMachine, GetMachineUtilization
from .models import ListModels
from .notebooks import CreateNotebook, DeleteNotebook, GetNotebook, ListNotebooks
from .projects import CreateProject, ListProjects
14 changes: 0 additions & 14 deletions gradient/api_sdk/repositories/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,6 @@ def _get_api_url(self, **_):
return config.config.CONFIG_HOST


class ParseJobDictMixin(object):
@staticmethod
def _parse_object(job_dict, **kwargs):
"""

:param job_dict:
:param kwargs:
:return:
:rtype: Job
"""
job = JobSchema().get_instance(job_dict)
return job


class ListJobs(GetBaseJobApiUrlMixin, ListResources):

def get_request_url(self, **kwargs):
Expand Down
81 changes: 81 additions & 0 deletions gradient/api_sdk/repositories/notebooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from gradient import config
from .common import CreateResource, DeleteResource, ListResources, GetResource
from .. import serializers


class GetNotebookApiUrlMixin(object):
def _get_api_url(self, use_vpc=False):
return config.config.CONFIG_HOST


class CreateNotebook(GetNotebookApiUrlMixin, CreateResource):
SERIALIZER_CLS = serializers.NotebookSchema

def get_request_url(self, **kwargs):
return "notebooks/createNotebook"

def _process_instance_dict(self, instance_dict):
# the API requires this field but marshmallow does not create it if it's value is None
instance_dict.setdefault("containerId")
return instance_dict


class DeleteNotebook(GetNotebookApiUrlMixin, DeleteResource):
def get_request_url(self, **kwargs):
return "notebooks/v2/deleteNotebook"

def _get_request_json(self, kwargs):
notebook_id = kwargs["id"]
d = {"notebookId": notebook_id}
return d

def _send_request(self, client, url, json_data=None):
response = client.post(url, json=json_data)
return response


class GetNotebook(GetNotebookApiUrlMixin, GetResource):
def get_request_url(self, **kwargs):
notebook_id = kwargs["id"]
url = "notebooks/{}/getNotebook".format(notebook_id)
return url

def _parse_object(self, data, **kwargs):
# this ugly hack is here because marshmallow disallows reading value into `id` field
# if JSON's field was named differently (despite using load_from in schema definition)
data["id"] = data["handle"]

serializer = serializers.NotebookSchema()
notebooks = serializer.get_instance(data)
return notebooks


class ListNotebooks(GetNotebookApiUrlMixin, ListResources):
def get_request_url(self, **kwargs):
return "notebooks/getNotebooks"

def _parse_objects(self, data, **kwargs):
notebook_dicts = data["notebookList"]
# this ugly hack is here because marshmallow disallows reading value into `id` field
# if JSON's field was named differently (despite using load_from in schema definition)
for d in notebook_dicts:
d["id"] = d["handle"]

serializer = serializers.NotebookSchema()
notebooks = serializer.get_instance(notebook_dicts, many=True)
return notebooks

def _get_request_json(self, kwargs):
json_ = {
"filter": {
"filter": {
"limit": 11,
"offset": 0,
"where": {
"dtDeleted": None,
},
"order": "jobId desc",
},
},
}
return json_
1 change: 1 addition & 0 deletions gradient/api_sdk/serializers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@
from .log import LogRowSchema
from .machine import MachineSchema, MachineSchemaForListing, MachineEventSchema
from .model import Model
from .notebook import NotebookSchema
from .project import Project
25 changes: 25 additions & 0 deletions gradient/api_sdk/serializers/notebook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import marshmallow

from . import BaseSchema
from .. import models


class NotebookSchema(BaseSchema):
MODEL = models.Notebook

id = marshmallow.fields.Str()
vm_type_id = marshmallow.fields.Int(load_from="vmTypeId", dump_to="vmTypeId")
container_id = marshmallow.fields.Int(load_from="containerId", dump_to="containerId", allow_none=True)
container_name = marshmallow.fields.Str(load_from="containerName", dump_to="containerName", allow_none=True)
name = marshmallow.fields.Str()
cluster_id = marshmallow.fields.Int(load_from="clusterId", dump_to="clusterId")
registry_username = marshmallow.fields.Str(load_from="registryUsername", dump_to="registryUsername")
registry_password = marshmallow.fields.Str(load_from="registryPassword", dump_to="registryPassword")
default_entrypoint = marshmallow.fields.Str(load_from="defaultEntrypoint", dump_to="defaultEntrypoint")
container_user = marshmallow.fields.Str(load_from="containerUser", dump_to="containerUser")
shutdown_timeout = marshmallow.fields.Int(load_from="shutdownTimeout", dump_to="shutdownTimeout")
is_preemptible = marshmallow.fields.Bool(load_from="isPreemptible", dump_to="isPreemptible")
project_id = marshmallow.fields.Str(load_from="projectHandle", dump_to="projectHandle")
state = marshmallow.fields.Str()
vm_type = marshmallow.fields.Str(load_from="vmType", dump_to="vmType")
fqdn = marshmallow.fields.Str()
1 change: 1 addition & 0 deletions gradient/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import gradient.cli.jobs
import gradient.cli.machines
import gradient.cli.models
import gradient.cli.notebooks
import gradient.cli.projects
import gradient.cli.run

Expand Down
7 changes: 0 additions & 7 deletions gradient/cli/hyperparameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,6 @@
from gradient.commands import hyperparameters as hyperparameters_commands


def add_use_docker_file_flag_if_used(ctx, param, value):
if value:
ctx.params["useDockerFile"] = True

return value


@cli.group("hyperparameters", help="Manage hyperparameters", cls=ClickGroup)
def hyperparameters_group():
pass
Expand Down
Loading