Skip to content

Commit

Permalink
feat: Support Dep on function based commands.
Browse files Browse the repository at this point in the history
  • Loading branch information
DanCardin committed Feb 10, 2024
1 parent f86666d commit 11ae187
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 19 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

## 0.16

### 0.16.1

- feat: Support `Dep` on function based commands.

### 0.16.0

- feat: Add support for `BinaryIO` and `TextIO` for representing preconfigured
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "cappa"
version = "0.16.0"
version = "0.16.1"
description = "Declarative CLI argument parser."

repository = "https://github.com/dancardin/cappa"
Expand Down
47 changes: 32 additions & 15 deletions src/cappa/class_inspect.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from __future__ import annotations

import dataclasses
import functools
import inspect
import typing
from enum import Enum

import typing_inspect
from typing_extensions import Self, get_args

from cappa.typing import MISSING, get_type_hints, missing
from cappa.typing import MISSING, find_type_annotation, get_type_hints, missing

if typing.TYPE_CHECKING:
from cappa import Arg, Subcommand
Expand Down Expand Up @@ -201,28 +202,44 @@ def get_command_capable_object(obj):
the arguments to the dataclass into the original callable.
"""
if inspect.isfunction(obj):
from cappa import Dep

def call(self):
function_args = []

@functools.wraps(obj)
def call(self, **deps):
kwargs = dataclasses.asdict(self)
return obj(**kwargs)
return obj(**kwargs, **deps)

# We need to create a fake signature for the above callable, which does
# not retain the `Arg` annotations
sig = inspect.signature(obj)
sig_params: dict = dict(sig.parameters)
sig._parameters = sig_params # type: ignore
call.__signature__ = sig # type: ignore

args = get_type_hints(obj, include_extras=True)
parameters = inspect.signature(obj).parameters
fields = [
(
name,
annotation,
dataclasses.field(
default=parameters[name].default
if parameters[name].default is not inspect.Parameter.empty
else dataclasses.MISSING
),
for name, annotation in args.items():
if find_type_annotation(annotation, Dep).obj:
continue

sig_params.pop(name, None)
function_args.append(
(
name,
annotation,
dataclasses.field(
default=parameters[name].default
if parameters[name].default is not inspect.Parameter.empty
else dataclasses.MISSING
),
)
)
for name, annotation in args.items()
]

return dataclasses.make_dataclass(
obj.__name__,
fields,
function_args,
namespace={"__call__": call},
)

Expand Down
2 changes: 1 addition & 1 deletion src/cappa/invoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def resolve_implicit_deps(command: Command, instance: HasCommand) -> dict:
def fullfill_deps(fn: Callable, fullfilled_deps: dict) -> typing.Any:
result = {}

signature = inspect.signature(fn)
signature = getattr(fn, "__signature__", None) or inspect.signature(fn)
try:
annotations = get_type_hints(fn, include_extras=True)
except NameError as e: # pragma: no cover
Expand Down
5 changes: 3 additions & 2 deletions src/cappa/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,9 @@ def is_subclass(typ, superclass):
def get_type_hints(obj, include_extras=False):
result = typing_extensions.get_type_hints(obj, include_extras=include_extras)
if sys.version_info < (3, 11): # pragma: no cover
return fix_annotated_optional_type_hints(result)
return result
result = fix_annotated_optional_type_hints(result)

return {k: v for k, v in result.items() if k not in {"return"}}


def fix_annotated_optional_type_hints(
Expand Down
19 changes: 19 additions & 0 deletions tests/command/test_function_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,22 @@ def function(sub: cappa.Subcommands[Union[Sub, None]] = None):

result = invoke(function, "sub", "--bar", "34", backend=backend)
assert result == 35


def foo():
return 5


@backends
def test_invoke_partial_arg_partial_dep(backend):
def function(
dep: Annotated[int, cappa.Dep(foo)],
foo: Annotated[int, cappa.Arg(long=True)] = 15,
):
return dep + foo

result = invoke(function, backend=backend)
assert result == 20

result = invoke(function, "--foo", "53", backend=backend)
assert result == 58

0 comments on commit 11ae187

Please sign in to comment.