Skip to content

Commit

Permalink
fix: regain feature-parity with trogon/main
Browse files Browse the repository at this point in the history
Allow `ArgumentSchema` to contain one or more types.

Define `ChoiceSchema` to allow choice-types in a sequence of types;
used in place of `click.Choice`.
  • Loading branch information
Donald Mellenbruch authored and Donald Mellenbruch committed Jun 26, 2023
1 parent 1ef2c21 commit e35fde1
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 29 deletions.
2 changes: 1 addition & 1 deletion examples/demo_click.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def cli(ctx, verbose, hidden_arg):
"--extra",
"-e",
nargs=2,
type=(str, int),
type=(str, click.Choice(["1", "2", "3"])),
multiple=True,
default=[("one", 1), ("two", 2)],
help="Add extra data as key-value pairs (repeatable)",
Expand Down
2 changes: 1 addition & 1 deletion examples/demo_click_nogroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"--extra",
"-e",
nargs=2,
type=(str, int),
type=(str, click.Choice(["1", "2", "3"])),
multiple=True,
default=[("one", 1), ("two", 2)],
help="Add extra data as key-value pairs (repeatable)",
Expand Down
29 changes: 18 additions & 11 deletions trogon/click.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Sequence
from typing import Type


from trogon import Trogon
Expand All @@ -19,6 +19,7 @@
CommandName,
CommandSchema,
OptionSchema,
ChoiceSchema,
)
from typing import Type, Any
from uuid import UUID
Expand All @@ -37,6 +38,13 @@
click.Path: Path,
}

def _convert_click_to_py_type(click_type: Type[Any]) -> Type[Any]:
if isinstance(click_type, click.Choice):
return ChoiceSchema(choices=click_type.choices)

return CLICK_TO_PY_TYPES.get(
click_type, CLICK_TO_PY_TYPES.get(type(click_type), str)
)

def introspect_click_app(
app: BaseCommand, cmd_ignorelist: list[str] | None = None
Expand Down Expand Up @@ -73,13 +81,12 @@ def process_command(
subcommands={},
parent=parent,
)

for param in cmd_obj.params:
param_type: Type[Any] = CLICK_TO_PY_TYPES.get(
param.type, CLICK_TO_PY_TYPES.get(type(param.type), str)
)
param_choices: Sequence[str] | None = None
if isinstance(param.type, click.Choice):
param_choices = param.type.choices

click_types: list[Type[Any]] = param.type.types if isinstance(param.type, click.Tuple) else [param.type]

param_types: list[Type[Any]] = [_convert_click_to_py_type(x) for x in click_types]

if isinstance(param, (click.Option, click.core.Group)):
if param.hidden:
Expand All @@ -89,14 +96,14 @@ def process_command(

option_data = OptionSchema(
name=param.opts,
type=param_type,
type=param_types,
is_flag=param.is_flag,
counting=param.count,
secondary_opts=param.secondary_opts,
required=param.required,
default=param.default,
help=param.help,
choices=param_choices,
choices=None,
multiple=param.multiple,
nargs=param.nargs,
sensitive=param.hide_input,
Expand All @@ -108,9 +115,9 @@ def process_command(
elif isinstance(param, click.Argument):
argument_data = ArgumentSchema(
name=param.name,
type=param_type,
type=param_types,
required=param.required,
choices=param_choices,
choices=None,
multiple=param.multiple,
default=param.default,
nargs=param.nargs,
Expand Down
25 changes: 20 additions & 5 deletions trogon/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,19 @@ def process_cli_option(value) -> "MultiValueParamData":
return value


@dataclass
class ChoiceSchema:
# this is used in place of click.Choice
choices: Sequence[str]

def __post_init__(self):
self.__name__ = 'choice'


@dataclass
class ArgumentSchema:
name: str | list[str]
type: Type[Any] | None = None
type: Type[Any] | Sequence[Type[Any]] | None = None
required: bool = False
help: str | None = None
key: str | tuple[str] = field(default_factory=generate_unique_id)
Expand All @@ -52,14 +61,20 @@ def __post_init__(self):
self.default = MultiValueParamData.process_cli_option(self.default)

if not self.type:
self.type = str

if self.multi_value:
self.multiple = True
self.type = [str]
elif isinstance(self.type, Type):
self.type = [self.type]
elif len(self.type) == 1 and isinstance(self.type[0], ChoiceSchema):
# if there is only one type is it is a 'ChoiceSchema':
self.choices = self.type[0].choices
self.type = [str]

if self.choices:
self.choices = [str(x) for x in self.choices]

if self.multi_value:
self.multiple = True


@dataclass
class OptionSchema(ArgumentSchema):
Expand Down
32 changes: 21 additions & 11 deletions trogon/widgets/parameter_controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
)


from trogon.schemas import ArgumentSchema, MultiValueParamData, OptionSchema
from trogon.schemas import ArgumentSchema, MultiValueParamData, OptionSchema, ChoiceSchema
from trogon.widgets.multiple_choice import MultipleChoice

ControlWidgetType: TypeVar = Union[Input, Checkbox, MultipleChoice, Select]
Expand Down Expand Up @@ -135,7 +135,9 @@ def compose(self) -> ComposeResult:
# There's a special case where we have a Choice with multiple=True,
# in this case, we can just render a single MultipleChoice widget
# instead of multiple radio-sets.
control_method = self.get_control_method(schema=schema)
control_method = self.get_control_method(
param_type=ChoiceSchema(choices=schema.choices)
)
multiple_choice_widget = control_method(
default=default,
label=label,
Expand Down Expand Up @@ -187,7 +189,7 @@ def compose(self) -> ComposeResult:
# If it's a multiple, and it's a Choice parameter, then we display
# our special case MultiChoice widget, and so there's no need for this
# button.
if multiple or nargs == -1 and not schema.choices:
if (multiple or nargs == -1) and not schema.choices:
with Horizontal(classes="add-another-button-container"):
yield Button("+ value", variant="success", classes="add-another-button")

Expand All @@ -209,12 +211,20 @@ def make_widget_group(self) -> Iterable[Widget]:
)

# Get the types of the parameter. We can map these types on to widgets that will be rendered.
parameter_types = [parameter_type] * schema.nargs if schema.nargs > 1 else [parameter_type]
parameter_types = [
parameter_type[i] if i < len(parameter_type) else parameter_type[-1]
for i in range(schema.nargs if schema.nargs > 1 else 1)
]
# The above ensures that len(parameter_types) == nargs.
# if there are more parameter_types than args, parameter_types is truncated.
# if there are fewer parameter_types than args, the *last* parameter type is repeated as much as necessary.

# For each of the these parameters, render the corresponding widget for it.
# At this point we don't care about filling in the default values.
for _type in parameter_types:
control_method = self.get_control_method(schema=schema)
if schema.choices:
_type = ChoiceSchema(choices=schema.choices)
control_method = self.get_control_method(param_type=_type)
control_widgets = control_method(
default, label, multiple, schema, schema.key
)
Expand Down Expand Up @@ -301,12 +311,12 @@ def list_to_tuples(
return MultiValueParamData.process_cli_option(collected_values)

def get_control_method(
self, schema: ArgumentSchema
self, param_type: Type[Any],
) -> Callable[[Any, Text, bool, OptionSchema | ArgumentSchema, str], Widget]:
if schema.choices:
return partial(self.make_choice_control, choices=schema.choices)
if isinstance(param_type, ChoiceSchema):
return partial(self.make_choice_control, choices=param_type.choices)

if schema.type is bool:
if param_type is bool:
return self.make_checkbox_control

return self.make_text_control
Expand Down Expand Up @@ -382,7 +392,7 @@ def make_choice_control(
@staticmethod
def _make_command_form_control_label(
name: str | list[str],
type: Type[Any],
types: list[Type[Any]],
is_option: bool,
is_required: bool,
multiple: bool,
Expand All @@ -391,7 +401,7 @@ def _make_command_form_control_label(

names = Text(" / ", style="dim").join([Text(n) for n in names])
text = Text.from_markup(
f"{names}[dim]{' multiple' if multiple else ''} <{type.__name__}>[/] {' [b red]*[/]required' if is_required else ''}"
f"{names}[dim]{' multiple' if multiple else ''} <{', '.join(x.__name__ for x in types)}>[/] {' [b red]*[/]required' if is_required else ''}"
)

return text
Expand Down

0 comments on commit e35fde1

Please sign in to comment.