From e699fce534f77106e52dda1f0f0b23a3f8bcdf81 Mon Sep 17 00:00:00 2001 From: jakekaplan <40362401+jakekaplan@users.noreply.github.com> Date: Thu, 14 Oct 2021 14:29:13 -0400 Subject: [PATCH] add task run name to context (#5055) * initial commit, add task run name to context * updates docs * add changes file * include more tests from similar dev effort --- changes/pr5055.yaml | 2 + src/prefect/engine/cloud/task_runner.py | 18 ++---- src/prefect/engine/task_runner.py | 21 ++++++- src/prefect/utilities/context.py | 1 + tests/engine/cloud/test_cloud_task_runner.py | 36 +++++++---- tests/engine/test_task_runner.py | 64 ++++++++++++++++++++ 6 files changed, 114 insertions(+), 28 deletions(-) create mode 100644 changes/pr5055.yaml diff --git a/changes/pr5055.yaml b/changes/pr5055.yaml new file mode 100644 index 000000000000..ea46831fa37d --- /dev/null +++ b/changes/pr5055.yaml @@ -0,0 +1,2 @@ +enhancement: + - 'Adds `task_run_name` to `prefect.context` [#5055](https://github.com/PrefectHQ/prefect/pull/5055)' diff --git a/src/prefect/engine/cloud/task_runner.py b/src/prefect/engine/cloud/task_runner.py index c36552277078..b941bf52b11a 100644 --- a/src/prefect/engine/cloud/task_runner.py +++ b/src/prefect/engine/cloud/task_runner.py @@ -303,21 +303,11 @@ def set_task_run_name(self, task_inputs: Dict[str, Result]) -> None: - task_inputs (Dict[str, Result]): a dictionary of inputs whose keys correspond to the task's `run()` arguments. """ - task_run_name = self.task.task_run_name - - if task_run_name: - raw_inputs = {k: r.value for k, r in task_inputs.items()} - formatting_kwargs = { - **prefect.context.get("parameters", {}), - **prefect.context, - **raw_inputs, - } - - if not isinstance(task_run_name, str): - task_run_name = task_run_name(**formatting_kwargs) - else: - task_run_name = task_run_name.format(**formatting_kwargs) + super().set_task_run_name(task_inputs) + + task_run_name = prefect.context.get("task_run_name") + if task_run_name is not None: self.client.set_task_run_name( task_run_id=self.task_run_id, name=task_run_name # type: ignore ) diff --git a/src/prefect/engine/task_runner.py b/src/prefect/engine/task_runner.py index 00ab11764652..2a1da280c3a4 100644 --- a/src/prefect/engine/task_runner.py +++ b/src/prefect/engine/task_runner.py @@ -685,13 +685,30 @@ def load_results( def set_task_run_name(self, task_inputs: Dict[str, Result]) -> None: """ - Sets the name for this task run. + Sets the name for this task run and adds to `prefect.context` Args: - task_inputs (Dict[str, Result]): a dictionary of inputs whose keys correspond to the task's `run()` arguments. + """ - pass + + task_run_name = self.task.task_run_name + + if task_run_name: + raw_inputs = {k: r.value for k, r in task_inputs.items()} + formatting_kwargs = { + **prefect.context.get("parameters", {}), + **prefect.context, + **raw_inputs, + } + + if not isinstance(task_run_name, str): + task_run_name = task_run_name(**formatting_kwargs) + else: + task_run_name = task_run_name.format(**formatting_kwargs) + + prefect.context.update({"task_run_name": task_run_name}) @call_state_handlers def check_target(self, state: State, inputs: Dict[str, Result]) -> State: diff --git a/src/prefect/utilities/context.py b/src/prefect/utilities/context.py index 90b6ea457205..03157c007885 100644 --- a/src/prefect/utilities/context.py +++ b/src/prefect/utilities/context.py @@ -42,6 +42,7 @@ | `task_tags` | the tags on the current task | | `task_run_count` | the run count of the task run - typically only interesting for retrying tasks | | `task_loop_count` | if the Task utilizes looping, the loop count of the task run | +| `task_run_name` | the run name of the current task (if provided, otherwise `None`) | | `task_loop_result` | if the Task is looping, the current loop result | In addition, Prefect Cloud supplies some additional context variables: diff --git a/tests/engine/cloud/test_cloud_task_runner.py b/tests/engine/cloud/test_cloud_task_runner.py index d1623b52bd52..f9288f3dbafb 100644 --- a/tests/engine/cloud/test_cloud_task_runner.py +++ b/tests/engine/cloud/test_cloud_task_runner.py @@ -999,11 +999,15 @@ def test_task_runner_sets_task_name(monkeypatch, cloud_settings): runner = CloudTaskRunner(task=task) runner.task_run_id = "id" - runner.set_task_run_name(task_inputs={}) + with prefect.context(): + assert prefect.context.get("task_run_name") is None - assert client.set_task_run_name.called - assert client.set_task_run_name.call_args[1]["name"] == "asdf" - assert client.set_task_run_name.call_args[1]["task_run_id"] == "id" + runner.set_task_run_name(task_inputs={}) + + assert client.set_task_run_name.called + assert client.set_task_run_name.call_args[1]["name"] == "asdf" + assert client.set_task_run_name.call_args[1]["task_run_id"] == "id" + assert prefect.context.get("task_run_name") == "asdf" task = Task(name="test", task_run_name="{map_index}") runner = CloudTaskRunner(task=task) @@ -1012,21 +1016,29 @@ def test_task_runner_sets_task_name(monkeypatch, cloud_settings): class Temp: value = 100 - runner.set_task_run_name(task_inputs={"map_index": Temp()}) + with prefect.context(): + assert prefect.context.get("task_run_name") is None + + runner.set_task_run_name(task_inputs={"map_index": Temp()}) - assert client.set_task_run_name.called - assert client.set_task_run_name.call_args[1]["name"] == "100" - assert client.set_task_run_name.call_args[1]["task_run_id"] == "id" + assert client.set_task_run_name.called + assert client.set_task_run_name.call_args[1]["name"] == "100" + assert client.set_task_run_name.call_args[1]["task_run_id"] == "id" + assert prefect.context.get("task_run_name") == "100" task = Task(name="test", task_run_name=lambda **kwargs: "name") runner = CloudTaskRunner(task=task) runner.task_run_id = "id" - runner.set_task_run_name(task_inputs={}) + with prefect.context(): + assert prefect.context.get("task_run_name") is None + + runner.set_task_run_name(task_inputs={}) - assert client.set_task_run_name.called - assert client.set_task_run_name.call_args[1]["name"] == "name" - assert client.set_task_run_name.call_args[1]["task_run_id"] == "id" + assert client.set_task_run_name.called + assert client.set_task_run_name.call_args[1]["name"] == "name" + assert client.set_task_run_name.call_args[1]["task_run_id"] == "id" + assert prefect.context.get("task_run_name") == "name" def test_task_runner_set_task_name_same_as_prefect_context(client): diff --git a/tests/engine/test_task_runner.py b/tests/engine/test_task_runner.py index 82f066a0a2dc..c9a922d53b67 100644 --- a/tests/engine/test_task_runner.py +++ b/tests/engine/test_task_runner.py @@ -2356,3 +2356,67 @@ def run(self): msg = line.split("INFO")[1] logged_map_index = msg[-1] assert msg.count(logged_map_index) == 2 + + +class TestTaskRunNames: + def test_task_runner_set_task_name(self): + task = Task(name="test", task_run_name="asdf") + runner = TaskRunner(task=task) + runner.task_run_id = "id" + + with prefect.context(): + assert prefect.context.get("task_run_name") is None + runner.set_task_run_name(task_inputs={}) + assert prefect.context.get("task_run_name") == "asdf" + + task = Task(name="test", task_run_name="{map_index}") + runner = TaskRunner(task=task) + runner.task_run_id = "id" + + class Temp: + value = 100 + + with prefect.context(): + assert prefect.context.get("task_run_name") is None + runner.set_task_run_name(task_inputs={"map_index": Temp()}) + assert prefect.context.get("task_run_name") == "100" + + task = Task(name="test", task_run_name=lambda **kwargs: "name") + runner = TaskRunner(task=task) + runner.task_run_id = "id" + + with prefect.context(): + assert prefect.context.get("task_run_name") is None + runner.set_task_run_name(task_inputs={}) + assert prefect.context.get("task_run_name") == "name" + + def test_task_runner_sets_task_run_name_in_context(self): + def dynamic_task_run_name(**task_inputs): + return f"hello-{task_inputs['input']}" + + @prefect.task(name="hey", task_run_name=dynamic_task_run_name) + def test_task(input): + return prefect.context.get("task_run_name") + + edge = Edge(Task(), Task(), key="input") + state = Success(result="my-value") + state = TaskRunner(task=test_task).run(upstream_states={edge: state}) + + assert state.result == "hello-my-value" + + def test_mapped_task_run_name_set_in_context(self): + def dynamic_task_run_name(**task_inputs): + return f"hello-{task_inputs['input']}" + + @prefect.task(name="hey", task_run_name=dynamic_task_run_name) + def test_task(input): + return prefect.context.get("task_run_name") + + from prefect import Flow + + with Flow("test") as flow: + data = [1, 2, 3] + test_task_key = test_task.map(data) + + state = flow.run() + assert state.result[test_task_key].result == ["hello-1", "hello-2", "hello-3"]