Skip to content

Commit

Permalink
Merge pull request #2602 from jcrist/dynamic-doc-sig
Browse files Browse the repository at this point in the history
Dynamic `__doc__`/`__signature__` on tasks
  • Loading branch information
jcrist committed Jun 2, 2020
2 parents c417ad0 + ffbe022 commit 6f1c8e1
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 4 deletions.
3 changes: 3 additions & 0 deletions changes/pr2602.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
enhancement:
- "Task instances define a `__signature__` attribute, for improved introspection and tab-completion - [#2602](https://github.com/PrefectHQ/prefect/pull/2602)"
- "Tasks created with `@task` forward the wrapped function's docstring - [#2602](https://github.com/PrefectHQ/prefect/pull/2602)"
22 changes: 22 additions & 0 deletions src/prefect/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,19 @@ def copy(self, **task_args: Any) -> "Task":

return new

@property
def __signature__(self) -> inspect.Signature:
"""Dynamically generate the signature, replacing ``*args``/``**kwargs``
with parameters from ``run``"""
if not hasattr(self, "_cached_signature"):
sig = inspect.Signature.from_callable(self.run)
parameters = list(sig.parameters.values())
parameters.extend(EXTRA_CALL_PARAMETERS)
self._cached_signature = inspect.Signature(
parameters=parameters, return_annotation="Task"
)
return self._cached_signature

def __call__(
self,
*args: Any,
Expand Down Expand Up @@ -1170,3 +1183,12 @@ def serialize(self) -> Dict[str, Any]:
- dict representing this parameter
"""
return prefect.serialization.task.ParameterSchema().dump(self)


# All keyword-only arguments to Task.__call__, used for dynamically generating
# Signature objects for Task objects
EXTRA_CALL_PARAMETERS = [
p
for p in inspect.Signature.from_callable(Task.__call__).parameters.values()
if p.kind == inspect.Parameter.KEYWORD_ONLY
]
28 changes: 24 additions & 4 deletions src/prefect/tasks/core/function.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,32 @@
"""
The tasks in this module can be used to represent arbitrary functions.
In general, users will not instantiate these tasks by hand; they will automatically be
applied when users apply the `@task` decorator.
In general, users will not instantiate these tasks by hand; they will
automatically be applied when users apply the `@task` decorator.
"""

from typing import Any, Callable

import prefect


class _DocProxy(object):
"""A descriptor that proxies through the docstring for the wrapped task as
the docstring for a `FunctionTask` instance."""

def __init__(self, cls_doc):
self._cls_doc = cls_doc

def __get__(self, obj, cls):
if obj is None:
return self._cls_doc
else:
return getattr(obj.run, "__doc__", None) or self._cls_doc


class FunctionTask(prefect.Task):
"""
A convenience Task for functionally creating Task instances with
__doc__ = _DocProxy(
"""A convenience Task for functionally creating Task instances with
arbitrary callable `run` methods.
Args:
Expand All @@ -33,6 +47,7 @@ class FunctionTask(prefect.Task):
result = task(42)
```
"""
)

def __init__(self, fn: Callable, name: str = None, **kwargs: Any):
if not callable(fn):
Expand All @@ -46,3 +61,8 @@ def __init__(self, fn: Callable, name: str = None, **kwargs: Any):
self.run = fn

super().__init__(name=name, **kwargs)

def __getattr__(self, k):
if k == "__wrapped__":
return self.run
raise AttributeError(f"'FunctionTask' object has no attribute {k}")
18 changes: 18 additions & 0 deletions tests/core/test_task.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import logging
import uuid
from datetime import timedelta
Expand Down Expand Up @@ -232,6 +233,23 @@ def test_class_instantiation_raises_helpful_warning_for_unsupported_callables(se
with pytest.raises(ValueError, match="This function can not be inspected"):
task(zip)

def test_task_signature_generation(self):
class Test(Task):
def run(self, x: int, y: bool, z: int = 1):
pass

t = Test()

sig = inspect.signature(t)
# signature is a superset of the `run` method
for k, p in inspect.signature(t.run).parameters.items():
assert sig.parameters[k] == p
# extra kwonly args to __call__ also in sig
assert set(sig.parameters).issuperset(
{"mapped", "task_args", "upstream_tasks", "flow"}
)
assert sig.return_annotation == "Task"

def test_create_task_with_and_without_cache_for(self):
t1 = Task()
assert t1.cache_validator is never_use
Expand Down
25 changes: 25 additions & 0 deletions tests/tasks/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,31 @@ def my_fn():
f = FunctionTask(fn=my_fn, name="test")
assert f.name == "test"

def test_function_task_docstring(self):
def my_fn():
"""An example docstring."""
pass

# Original docstring available on class
assert "FunctionTask" in FunctionTask.__doc__

# Wrapped function is docstring on instance
f = FunctionTask(fn=my_fn)
assert f.__doc__ == my_fn.__doc__

# Except when no docstring on wrapped function
f = FunctionTask(fn=lambda x: x + 1)
assert "FunctionTask" in f.__doc__

def test_function_task_sets__wrapped__(self):
def my_fn():
"""An example function"""
pass

t = FunctionTask(fn=my_fn)
assert t.__wrapped__ == my_fn
assert not hasattr(FunctionTask, "__wrapped__")


class TestCollections:
def test_list_returns_a_list(self):
Expand Down

0 comments on commit 6f1c8e1

Please sign in to comment.