Skip to content

Commit

Permalink
add task run name to context (#5055)
Browse files Browse the repository at this point in the history
* initial commit, add task run name to context

* updates docs

* add changes file

* include more tests from similar dev effort
  • Loading branch information
jakekaplan committed Oct 14, 2021
1 parent 014550f commit e699fce
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 28 deletions.
2 changes: 2 additions & 0 deletions changes/pr5055.yaml
@@ -0,0 +1,2 @@
enhancement:
- 'Adds `task_run_name` to `prefect.context` [#5055](https://github.com/PrefectHQ/prefect/pull/5055)'
18 changes: 4 additions & 14 deletions src/prefect/engine/cloud/task_runner.py
Expand Up @@ -303,21 +303,11 @@ def set_task_run_name(self, task_inputs: Dict[str, Result]) -> None:
- task_inputs (Dict[str, Result]): a dictionary of inputs whose keys correspond
to the task's `run()` arguments.
"""
task_run_name = self.task.task_run_name

if task_run_name:
raw_inputs = {k: r.value for k, r in task_inputs.items()}
formatting_kwargs = {
**prefect.context.get("parameters", {}),
**prefect.context,
**raw_inputs,
}

if not isinstance(task_run_name, str):
task_run_name = task_run_name(**formatting_kwargs)
else:
task_run_name = task_run_name.format(**formatting_kwargs)
super().set_task_run_name(task_inputs)

task_run_name = prefect.context.get("task_run_name")

if task_run_name is not None:
self.client.set_task_run_name(
task_run_id=self.task_run_id, name=task_run_name # type: ignore
)
Expand Down
21 changes: 19 additions & 2 deletions src/prefect/engine/task_runner.py
Expand Up @@ -685,13 +685,30 @@ def load_results(

def set_task_run_name(self, task_inputs: Dict[str, Result]) -> None:
"""
Sets the name for this task run.
Sets the name for this task run and adds to `prefect.context`
Args:
- task_inputs (Dict[str, Result]): a dictionary of inputs whose keys correspond
to the task's `run()` arguments.
"""
pass

task_run_name = self.task.task_run_name

if task_run_name:
raw_inputs = {k: r.value for k, r in task_inputs.items()}
formatting_kwargs = {
**prefect.context.get("parameters", {}),
**prefect.context,
**raw_inputs,
}

if not isinstance(task_run_name, str):
task_run_name = task_run_name(**formatting_kwargs)
else:
task_run_name = task_run_name.format(**formatting_kwargs)

prefect.context.update({"task_run_name": task_run_name})

@call_state_handlers
def check_target(self, state: State, inputs: Dict[str, Result]) -> State:
Expand Down
1 change: 1 addition & 0 deletions src/prefect/utilities/context.py
Expand Up @@ -42,6 +42,7 @@
| `task_tags` | the tags on the current task |
| `task_run_count` | the run count of the task run - typically only interesting for retrying tasks |
| `task_loop_count` | if the Task utilizes looping, the loop count of the task run |
| `task_run_name` | the run name of the current task (if provided, otherwise `None`) |
| `task_loop_result` | if the Task is looping, the current loop result |
In addition, Prefect Cloud supplies some additional context variables:
Expand Down
36 changes: 24 additions & 12 deletions tests/engine/cloud/test_cloud_task_runner.py
Expand Up @@ -999,11 +999,15 @@ def test_task_runner_sets_task_name(monkeypatch, cloud_settings):
runner = CloudTaskRunner(task=task)
runner.task_run_id = "id"

runner.set_task_run_name(task_inputs={})
with prefect.context():
assert prefect.context.get("task_run_name") is None

assert client.set_task_run_name.called
assert client.set_task_run_name.call_args[1]["name"] == "asdf"
assert client.set_task_run_name.call_args[1]["task_run_id"] == "id"
runner.set_task_run_name(task_inputs={})

assert client.set_task_run_name.called
assert client.set_task_run_name.call_args[1]["name"] == "asdf"
assert client.set_task_run_name.call_args[1]["task_run_id"] == "id"
assert prefect.context.get("task_run_name") == "asdf"

task = Task(name="test", task_run_name="{map_index}")
runner = CloudTaskRunner(task=task)
Expand All @@ -1012,21 +1016,29 @@ def test_task_runner_sets_task_name(monkeypatch, cloud_settings):
class Temp:
value = 100

runner.set_task_run_name(task_inputs={"map_index": Temp()})
with prefect.context():
assert prefect.context.get("task_run_name") is None

runner.set_task_run_name(task_inputs={"map_index": Temp()})

assert client.set_task_run_name.called
assert client.set_task_run_name.call_args[1]["name"] == "100"
assert client.set_task_run_name.call_args[1]["task_run_id"] == "id"
assert client.set_task_run_name.called
assert client.set_task_run_name.call_args[1]["name"] == "100"
assert client.set_task_run_name.call_args[1]["task_run_id"] == "id"
assert prefect.context.get("task_run_name") == "100"

task = Task(name="test", task_run_name=lambda **kwargs: "name")
runner = CloudTaskRunner(task=task)
runner.task_run_id = "id"

runner.set_task_run_name(task_inputs={})
with prefect.context():
assert prefect.context.get("task_run_name") is None

runner.set_task_run_name(task_inputs={})

assert client.set_task_run_name.called
assert client.set_task_run_name.call_args[1]["name"] == "name"
assert client.set_task_run_name.call_args[1]["task_run_id"] == "id"
assert client.set_task_run_name.called
assert client.set_task_run_name.call_args[1]["name"] == "name"
assert client.set_task_run_name.call_args[1]["task_run_id"] == "id"
assert prefect.context.get("task_run_name") == "name"


def test_task_runner_set_task_name_same_as_prefect_context(client):
Expand Down
64 changes: 64 additions & 0 deletions tests/engine/test_task_runner.py
Expand Up @@ -2356,3 +2356,67 @@ def run(self):
msg = line.split("INFO")[1]
logged_map_index = msg[-1]
assert msg.count(logged_map_index) == 2


class TestTaskRunNames:
def test_task_runner_set_task_name(self):
task = Task(name="test", task_run_name="asdf")
runner = TaskRunner(task=task)
runner.task_run_id = "id"

with prefect.context():
assert prefect.context.get("task_run_name") is None
runner.set_task_run_name(task_inputs={})
assert prefect.context.get("task_run_name") == "asdf"

task = Task(name="test", task_run_name="{map_index}")
runner = TaskRunner(task=task)
runner.task_run_id = "id"

class Temp:
value = 100

with prefect.context():
assert prefect.context.get("task_run_name") is None
runner.set_task_run_name(task_inputs={"map_index": Temp()})
assert prefect.context.get("task_run_name") == "100"

task = Task(name="test", task_run_name=lambda **kwargs: "name")
runner = TaskRunner(task=task)
runner.task_run_id = "id"

with prefect.context():
assert prefect.context.get("task_run_name") is None
runner.set_task_run_name(task_inputs={})
assert prefect.context.get("task_run_name") == "name"

def test_task_runner_sets_task_run_name_in_context(self):
def dynamic_task_run_name(**task_inputs):
return f"hello-{task_inputs['input']}"

@prefect.task(name="hey", task_run_name=dynamic_task_run_name)
def test_task(input):
return prefect.context.get("task_run_name")

edge = Edge(Task(), Task(), key="input")
state = Success(result="my-value")
state = TaskRunner(task=test_task).run(upstream_states={edge: state})

assert state.result == "hello-my-value"

def test_mapped_task_run_name_set_in_context(self):
def dynamic_task_run_name(**task_inputs):
return f"hello-{task_inputs['input']}"

@prefect.task(name="hey", task_run_name=dynamic_task_run_name)
def test_task(input):
return prefect.context.get("task_run_name")

from prefect import Flow

with Flow("test") as flow:
data = [1, 2, 3]
test_task_key = test_task.map(data)

state = flow.run()
assert state.result[test_task_key].result == ["hello-1", "hello-2", "hello-3"]

0 comments on commit e699fce

Please sign in to comment.