Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 34 additions & 3 deletions gradient/cli/common.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import functools
import json
import re

import click
import colorama
import termcolor
import yaml
from click.exceptions import Exit
from click_didyoumean import DYMMixin
from click_help_colors import HelpColorsGroup

from gradient.cli import cli_types
from gradient.config import config

OPTIONS_FILE_OPTION_NAME = "optionsFile"
OPTIONS_FILE_PARAMETER_NAME = "options_file"
Expand Down Expand Up @@ -78,19 +82,46 @@ def handle_parse_result(self, ctx, opts, args):
ctx, opts, args)


class ArgumentReadValueFromConfigFile(ReadValueFromConfigFile, click.Argument):
class ColorExtrasInCommandHelpMixin(object):
def get_help_record(self, *args, **kwargs):
rv = super(ColorExtrasInCommandHelpMixin, self).get_help_record(*args, **kwargs)
if not config.USE_CONSOLE_COLORS:
return rv

help_str = rv[1]
if help_str:
help_str = self._color_extras(help_str)
rv = rv[0], help_str
return rv

def _color_extras(self, s):
pattern = re.compile(r"^.*(\[.*\])$")
found = re.findall(pattern, s)
if found:
extras_str = found[-1]
coloured_extras_str = self._color_str(extras_str)
s = s.replace(extras_str, coloured_extras_str)

return s

def _color_str(self, s):
s = termcolor.colored(s, config.HELP_HEADERS_COLOR)
return s


class GradientArgument(ColorExtrasInCommandHelpMixin, ReadValueFromConfigFile, click.Argument):
pass


class OptionReadValueFromConfigFile(ReadValueFromConfigFile, click.Option):
class GradientOption(ColorExtrasInCommandHelpMixin, ReadValueFromConfigFile, click.Option):
pass


api_key_option = click.option(
"--apiKey",
"api_key",
help="API key to use this time only",
cls=OptionReadValueFromConfigFile,
cls=GradientOption,
)


Expand Down
30 changes: 15 additions & 15 deletions gradient/cli/deployments.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,56 +44,56 @@ def get_deployment_client(api_key):
type=ChoiceType(DEPLOYMENT_TYPES_MAP, case_sensitive=False),
required=True,
help="Model deployment type. Only TensorFlow models can currently be deployed",
cls=common.OptionReadValueFromConfigFile,
cls=common.GradientOption,
)
@click.option(
"--modelId",
"model_id",
required=True,
help="ID of a trained model",
cls=common.OptionReadValueFromConfigFile,
cls=common.GradientOption,
)
@click.option(
"--name",
"name",
required=True,
help="Human-friendly name for new model deployment",
cls=common.OptionReadValueFromConfigFile,
cls=common.GradientOption,
)
@click.option(
"--machineType",
"machine_type",
required=True,
help="Type of machine for new deployment",
cls=common.OptionReadValueFromConfigFile,
cls=common.GradientOption,
)
@click.option(
"--imageUrl",
"image_url",
required=True,
help="Docker image for model serving",
cls=common.OptionReadValueFromConfigFile,
cls=common.GradientOption,
)
@click.option(
"--instanceCount",
"instance_count",
type=int,
required=True,
help="Number of machine instances",
cls=common.OptionReadValueFromConfigFile,
cls=common.GradientOption,
)
@click.option(
"--clusterId",
"cluster_id",
help="Cluster ID",
cls=common.OptionReadValueFromConfigFile,
cls=common.GradientOption,
)
@click.option(
"--vpc",
"use_vpc",
type=bool,
is_flag=True,
cls=common.OptionReadValueFromConfigFile,
cls=common.GradientOption,
)
@api_key_option
@common.options_file
Expand Down Expand Up @@ -123,19 +123,19 @@ def create_deployment(api_key, use_vpc, options_file, **kwargs):
"state",
type=ChoiceType(DEPLOYMENT_STATES_MAP, case_sensitive=False),
help="Filter by deployment state",
cls=common.OptionReadValueFromConfigFile,
cls=common.GradientOption,
)
@click.option(
"--projectId",
"project_id",
help="Use to filter by project ID",
cls=common.OptionReadValueFromConfigFile,
cls=common.GradientOption,
)
@click.option(
"--modelId",
"model_id",
help="Use to filter by model ID",
cls=common.OptionReadValueFromConfigFile,
cls=common.GradientOption,
)
@api_key_option
@common.options_file
Expand All @@ -155,14 +155,14 @@ def get_deployments_list(api_key, options_file, **filters):
"id_",
required=True,
help="Deployment ID",
cls=common.OptionReadValueFromConfigFile,
cls=common.GradientOption,
)
@click.option(
"--vpc",
"use_vpc",
type=bool,
is_flag=True,
cls=common.OptionReadValueFromConfigFile,
cls=common.GradientOption,
)
@api_key_option
@common.options_file
Expand All @@ -178,14 +178,14 @@ def start_deployment(id_, use_vpc, options_file, api_key=None):
"id_",
required=True,
help="Deployment ID",
cls=common.OptionReadValueFromConfigFile,
cls=common.GradientOption,
)
@click.option(
"--vpc",
"use_vpc",
type=bool,
is_flag=True,
cls=common.OptionReadValueFromConfigFile,
cls=common.GradientOption,
)
@api_key_option
@common.options_file
Expand Down
Loading