diff --git a/src/prefect/records/cache_policies.py b/src/prefect/cache_policies.py
similarity index 64%
rename from src/prefect/records/cache_policies.py
rename to src/prefect/cache_policies.py
index f0ec6d78abf0..d8f1eb314ac4 100644
--- a/src/prefect/records/cache_policies.py
+++ b/src/prefect/cache_policies.py
@@ -8,6 +8,10 @@
 
 @dataclass
 class CachePolicy:
+    """
+    Base class for all cache policies.
+    """
+
     @classmethod
     def from_cache_key_fn(
         cls, cache_key_fn: Callable[["TaskRunContext", Dict[str, Any]], Optional[str]]
@@ -59,6 +63,11 @@ def __add__(self, other: "CachePolicy") -> "CompoundCachePolicy":
 
 @dataclass
 class CacheKeyFnPolicy(CachePolicy):
+    """
+    This policy accepts a custom function with signature f(task_run_context, task_parameters, flow_parameters) -> str
+    and uses it to compute a task run cache key.
+    """
+
     # making it optional for tests
     cache_key_fn: Optional[
         Callable[["TaskRunContext", Dict[str, Any]], Optional[str]]
@@ -77,6 +86,13 @@ def compute_key(
 
 @dataclass
 class CompoundCachePolicy(CachePolicy):
+    """
+    This policy is constructed from two or more other cache policies and works by computing the keys
+    for each policy individually, and then hashing a sorted tuple of all computed keys.
+
+    Any keys that return `None` will be ignored.
+    """
+
     policies: Optional[list] = None
 
     def compute_key(
@@ -88,20 +104,25 @@ def compute_key(
     ) -> Optional[str]:
         keys = []
         for policy in self.policies or []:
-            keys.append(
-                policy.compute_key(
-                    task_ctx=task_ctx,
-                    inputs=inputs,
-                    flow_parameters=flow_parameters,
-                    **kwargs,
-                )
+            policy_key = policy.compute_key(
+                task_ctx=task_ctx,
+                inputs=inputs,
+                flow_parameters=flow_parameters,
+                **kwargs,
             )
+            if policy_key is not None:
+                keys.append(policy_key)
+        if not keys:
+            return None
         return hash_objects(*keys)
 
 
 @dataclass
-class Default(CachePolicy):
-    "Execution run ID only"
+class _None(CachePolicy):
+    """
+    Policy that always returns `None` for the computed cache key.
+    This policy prevents persistence.
+    """
 
     def compute_key(
         self,
@@ -110,12 +131,14 @@ def compute_key(
         flow_parameters: Dict[str, Any],
         **kwargs,
     ) -> Optional[str]:
-        return str(task_ctx.task_run.id)
+        return None
 
 
 @dataclass
-class _None(CachePolicy):
-    "ignore key policies altogether, always run - prevents persistence"
+class TaskSource(CachePolicy):
+    """
+    Policy for computing a cache key based on the source code of the task.
+    """
 
     def compute_key(
         self,
@@ -124,11 +147,22 @@ def compute_key(
         flow_parameters: Dict[str, Any],
         **kwargs,
     ) -> Optional[str]:
-        return None
+        if not task_ctx:
+            return None
+        try:
+            lines = inspect.getsource(task_ctx.task)
+        except TypeError:
+            lines = inspect.getsource(task_ctx.task.fn.__class__)
+
+        return hash_objects(lines)
 
 
 @dataclass
-class TaskDef(CachePolicy):
+class FlowParameters(CachePolicy):
+    """
+    Policy that computes the cache key based on a hash of the flow parameters.
+    """
+
     def compute_key(
         self,
         task_ctx: TaskRunContext,
@@ -136,21 +170,37 @@ def compute_key(
         flow_parameters: Dict[str, Any],
         **kwargs,
     ) -> Optional[str]:
-        lines = inspect.getsource(task_ctx.task)
-        return hash_objects(lines)
+        if not flow_parameters:
+            return None
+        return hash_objects(flow_parameters)
 
 
 @dataclass
-class FlowParameters(CachePolicy):
-    pass
+class RunId(CachePolicy):
+    """
+    Returns either the prevailing flow run ID, or if not found, the prevailing task
+    run ID.
+ """ + + def compute_key( + self, + task_ctx: TaskRunContext, + inputs: Dict[str, Any], + flow_parameters: Dict[str, Any], + **kwargs, + ) -> Optional[str]: + if not task_ctx: + return None + run_id = task_ctx.task_run.flow_run_id + if run_id is None: + run_id = task_ctx.task_run.id + return str(run_id) @dataclass class Inputs(CachePolicy): """ - Exposes flag for whether to include flow parameters as well. - - And exclude/include config. + Policy that computes a cache key based on a hash of the runtime inputs provided to the task.. """ exclude: Optional[list] = None @@ -166,6 +216,9 @@ def compute_key( inputs = inputs or {} exclude = self.exclude or [] + if not inputs: + return None + for key, val in inputs.items(): if key not in exclude: hashed_inputs[key] = val @@ -173,7 +226,9 @@ def compute_key( return hash_objects(hashed_inputs) -DEFAULT = Default() INPUTS = Inputs() NONE = _None() -TASKDEF = TaskDef() +TASK_SOURCE = TaskSource() +FLOW_PARAMETERS = FlowParameters() +RUN_ID = RunId() +DEFAULT = INPUTS + TASK_SOURCE + RUN_ID diff --git a/src/prefect/task_engine.py b/src/prefect/task_engine.py index 6c1a42efdfc6..da01c6a8afe3 100644 --- a/src/prefect/task_engine.py +++ b/src/prefect/task_engine.py @@ -174,11 +174,18 @@ def call_hooks(self, state: State = None) -> Iterable[Callable]: def compute_transaction_key(self) -> str: key = None if self.task.cache_policy: + flow_run_context = FlowRunContext.get() task_run_context = TaskRunContext.get() + + if flow_run_context: + parameters = flow_run_context.parameters + else: + parameters = None + key = self.task.cache_policy.compute_key( task_ctx=task_run_context, inputs=self.parameters, - flow_parameters=None, + flow_parameters=parameters, ) elif self.task.result_storage_key is not None: key = _format_user_supplied_storage_key(self.task.result_storage_key) diff --git a/src/prefect/tasks.py b/src/prefect/tasks.py index a5f5fd65e3a1..b024cb4d50ad 100644 --- a/src/prefect/tasks.py +++ b/src/prefect/tasks.py @@ -32,6 +32,7 @@ from typing_extensions import Literal, ParamSpec +from prefect.cache_policies import DEFAULT, NONE, CachePolicy from prefect.client.orchestration import get_client from prefect.client.schemas import TaskRun from prefect.client.schemas.objects import TaskRunInput, TaskRunResult @@ -43,7 +44,6 @@ ) from prefect.futures import PrefectDistributedFuture, PrefectFuture from prefect.logging.loggers import get_logger -from prefect.records.cache_policies import DEFAULT, NONE, CachePolicy from prefect.results import ResultFactory, ResultSerializer, ResultStorage from prefect.settings import ( PREFECT_TASK_DEFAULT_RETRIES, diff --git a/tests/records/test_cache_policies.py b/tests/test_cache_policies.py similarity index 62% rename from tests/records/test_cache_policies.py rename to tests/test_cache_policies.py index 4977eecd006a..26c4d5cb39e8 100644 --- a/tests/records/test_cache_policies.py +++ b/tests/test_cache_policies.py @@ -1,13 +1,16 @@ import itertools +from dataclasses import dataclass +from typing import Callable import pytest -from prefect.records.cache_policies import ( +from prefect.cache_policies import ( + DEFAULT, CachePolicy, CompoundCachePolicy, - Default, Inputs, - TaskDef, + RunId, + TaskSource, _None, ) @@ -40,26 +43,6 @@ def test_addition_of_none_is_noop(self, typ): assert policy + other == other -class TestDefaultPolicy: - def test_initializes(self): - policy = Default() - assert isinstance(policy, CachePolicy) - - def test_returns_run_id(self): - class Run: - id = "foo" - - class TaskCtx: - pass - - task_ctx = 
-        task_ctx.task_run = Run()
-
-        policy = Default()
-        key = policy.compute_key(task_ctx=task_ctx, inputs=None, flow_parameters=None)
-        assert key == "foo"
-
-
 class TestInputsPolicy:
     def test_initializes(self):
         policy = Inputs()
@@ -133,7 +116,7 @@ def test_initializes(self):
         assert isinstance(policy, CachePolicy)
 
     def test_creation_via_addition(self):
-        one, two = Inputs(), Default()
+        one, two = Inputs(), TaskSource()
         policy = one + two
         assert isinstance(policy, CompoundCachePolicy)
 
@@ -152,18 +135,29 @@ def test_subtraction_creates_new_policies(self):
         assert policy.policies != new_policy.policies
 
     def test_creation_via_subtraction(self):
-        one = Default()
+        one = RunId()
         policy = one - "foo"
         assert isinstance(policy, CompoundCachePolicy)
 
+    def test_nones_are_ignored(self):
+        one, two = _None(), _None()
+        policy = CompoundCachePolicy(policies=[one, two])
+        assert isinstance(policy, CompoundCachePolicy)
+
+        fparams = dict(x=42, y="foo")
+        compound_key = policy.compute_key(
+            task_ctx=None, inputs=dict(z=[1, 2]), flow_parameters=fparams
+        )
+        assert compound_key is None
+
 
-class TestTaskDefPolicy:
+class TestTaskSourcePolicy:
     def test_initializes(self):
-        policy = TaskDef()
+        policy = TaskSource()
         assert isinstance(policy, CachePolicy)
 
     def test_changes_in_def_change_key(self):
-        policy = TaskDef()
+        policy = TaskSource()
 
         class TaskCtx:
             pass
@@ -189,3 +183,77 @@ def my_func(x):
         )
 
         assert key != new_key
+
+
+class TestDefaultPolicy:
+    def test_changing_the_inputs_busts_the_cache(self):
+        inputs = dict(x=42)
+        key = DEFAULT.compute_key(task_ctx=None, inputs=inputs, flow_parameters=None)
+
+        inputs = dict(x=43)
+        new_key = DEFAULT.compute_key(
+            task_ctx=None, inputs=inputs, flow_parameters=None
+        )
+
+        assert key != new_key
+
+    def test_changing_the_run_id_busts_the_cache(self):
+        @dataclass
+        class Run:
+            id: str
+            flow_run_id: str = None
+
+        def my_task():
+            pass
+
+        @dataclass
+        class TaskCtx:
+            task_run: Run
+            task = my_task
+
+        task_run_a = Run(id="a", flow_run_id="a")
+        task_run_b = Run(id="b", flow_run_id="b")
+        task_run_c = Run(id="c", flow_run_id=None)
+        task_run_d = Run(id="d", flow_run_id=None)
+
+        key_a = DEFAULT.compute_key(
+            task_ctx=TaskCtx(task_run=task_run_a), inputs=None, flow_parameters=None
+        )
+        key_b = DEFAULT.compute_key(
+            task_ctx=TaskCtx(task_run=task_run_b), inputs=None, flow_parameters=None
+        )
+        key_c = DEFAULT.compute_key(
+            task_ctx=TaskCtx(task_run=task_run_c), inputs=None, flow_parameters=None
+        )
+        key_d = DEFAULT.compute_key(
+            task_ctx=TaskCtx(task_run=task_run_d), inputs=None, flow_parameters=None
+        )
+
+        assert key_a not in [key_b, key_c, key_d]
+        assert key_b not in [key_a, key_c, key_d]
+        assert key_c not in [key_a, key_b, key_d]
+        assert key_d not in [key_a, key_b, key_c]
+
+    def test_changing_the_source_busts_the_cache(self):
+        @dataclass
+        class Run:
+            id: str
+            flow_run_id: str = None
+
+        @dataclass
+        class TaskCtx:
+            task_run: Run
+            task: Callable = None
+
+        task_run = Run(id="a", flow_run_id="b")
+        ctx_one = TaskCtx(task_run=task_run, task=lambda: "foo")
+        ctx_two = TaskCtx(task_run=task_run, task=lambda: "bar")
+
+        key_one = DEFAULT.compute_key(
+            task_ctx=ctx_one, inputs=None, flow_parameters=None
+        )
+        key_two = DEFAULT.compute_key(
+            task_ctx=ctx_two, inputs=None, flow_parameters=None
+        )
+
+        assert key_one != key_two
diff --git a/tests/test_task_engine.py b/tests/test_task_engine.py
index 9e75d8882083..7b62903d7a1a 100644
--- a/tests/test_task_engine.py
+++ b/tests/test_task_engine.py
@@ -11,6 +11,7 @@
 import pytest
 
 from prefect import Task, flow, task
+from prefect.cache_policies import FLOW_PARAMETERS
 from prefect.client.orchestration import PrefectClient, SyncPrefectClient
 from prefect.client.schemas.objects import StateType
 from prefect.context import (
@@ -1071,39 +1072,6 @@ async def async_task():
         state = await async_task(return_state=True)
         assert await state.result() == 42
 
-    async def test_task_persists_results_with_run_id_key(self):
-        @task(persist_result=True)
-        async def async_task():
-            return 42
-
-        state = await async_task(return_state=True)
-        assert state.is_completed()
-        assert await state.result() == 42
-        assert isinstance(state.data, PersistedResult)
-        assert state.data.storage_key == str(state.state_details.task_run_id)
-
-    async def test_task_loads_result_if_exists(self, prefect_client, tmp_path):
-        run_id = uuid4()
-
-        fs = LocalFileSystem(basepath=tmp_path)
-
-        factory = await ResultFactory.default_factory(
-            client=prefect_client, persist_result=True, result_storage=fs
-        )
-        await factory.create_result(1800, key=str(run_id))
-
-        @task(result_storage=fs)
-        async def async_task():
-            return 42
-
-        state = await run_task_async(
-            async_task, task_run_id=run_id, return_type="state"
-        )
-        assert state.is_completed()
-        assert await state.result() == 1800
-        assert isinstance(state.data, PersistedResult)
-        assert state.data.storage_key == str(run_id)
-
     async def test_task_loads_result_if_exists_using_result_storage_key(
         self, prefect_client, tmp_path
     ):
@@ -1239,6 +1207,42 @@ async def async_task():
         assert first_val is None
         assert second_val is None
 
+    async def test_flow_parameter_caching(self, prefect_client, tmp_path):
+        fs = LocalFileSystem(basepath=tmp_path)
+
+        @task(
+            cache_policy=FLOW_PARAMETERS,
+            result_storage=fs,
+        )
+        def my_random_task(x: int):
+            import random
+
+            return random.randint(0, x)
+
+        @flow
+        def my_param_flow(x: int, other_val: str):
+            first_val = my_random_task(x, return_state=True)
+            second_val = my_random_task(x, return_state=True)
+            return first_val, second_val
+
+        first, second = my_param_flow(4200, other_val="foo")
+        assert first.name == "Completed"
+        assert second.name == "Cached"
+
+        first_result = await first.result()
+        second_result = await second.result()
+        assert first_result == second_result
+
+        third, fourth = my_param_flow(4200, other_val="bar")
+        assert third.name == "Completed"
+        assert fourth.name == "Cached"
+
+        third_result = await third.result()
+        fourth_result = await fourth.result()
+
+        assert third_result not in [first_result, second_result]
+        assert fourth_result not in [first_result, second_result]
+
 
 class TestGenerators:
     async def test_generator_task(self):
diff --git a/tests/test_tasks.py b/tests/test_tasks.py
index a3cfab81ece1..00a5e09c5600 100644
--- a/tests/test_tasks.py
+++ b/tests/test_tasks.py
@@ -17,6 +17,7 @@
 import prefect
 from prefect import flow, tags
 from prefect.blocks.core import Block
+from prefect.cache_policies import DEFAULT, INPUTS, NONE, TASK_SOURCE
 from prefect.client.orchestration import PrefectClient
 from prefect.client.schemas.filters import LogFilter, LogFilterFlowRunId
 from prefect.client.schemas.objects import StateType, TaskRunResult
@@ -31,7 +32,6 @@
 from prefect.futures import PrefectDistributedFuture
 from prefect.futures import PrefectFuture as NewPrefectFuture
 from prefect.logging import get_run_logger
-from prefect.records.cache_policies import DEFAULT, INPUTS, NONE, TASKDEF
 from prefect.results import ResultFactory
 from prefect.runtime import task_run as task_run_ctx
 from prefect.server import models
@@ -1186,7 +1186,7 @@ def test_flow():
 
 
 class TestTaskCaching:
-    async def test_repeated_task_call_within_flow_is_not_cached_by_default(self):
+    async def test_repeated_task_call_within_flow_is_cached_by_default(self):
         @task
         def foo(x):
             return x
@@ -1197,7 +1197,7 @@ def bar():
 
         first_state, second_state = bar()
         assert first_state.name == "Completed"
-        assert second_state.name == "Completed"
+        assert second_state.name == "Cached"
         assert await second_state.result() == await first_state.result()
 
     async def test_cache_hits_within_flows_are_cached(self):
@@ -3728,7 +3728,7 @@ async def test_sets_run_name_once_per_call():
     task_calls = 0
     generate_task_run_name = MagicMock(return_value="some-string")
 
-    def test_task():
+    def test_task(x: str):
         nonlocal task_calls
         task_calls += 1
 
@@ -3736,8 +3736,8 @@ def test_task():
 
     @flow
     def my_flow(name):
-        decorated_task_method()
-        decorated_task_method()
+        decorated_task_method("a")
+        decorated_task_method("b")
 
         return "hi"
 
@@ -4618,11 +4618,11 @@ def my_task():
         assert my_task.result_storage_key == "foo"
 
     def test_cache_policy_inits_as_expected(self):
-        @task(cache_policy=TASKDEF)
+        @task(cache_policy=TASK_SOURCE)
         def my_task():
             pass
 
-        assert my_task.cache_policy is TASKDEF
+        assert my_task.cache_policy is TASK_SOURCE
 
 
 class TestTransactions:
@@ -4642,7 +4642,6 @@ def commit(txn):
         assert state.is_completed()
         assert state.name == "Completed"
         assert isinstance(data["txn"], Transaction)
-        assert str(state.state_details.task_run_id) == data["txn"].key
 
 
 class TestApplyAsync:
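
Usage sketch (not part of the diff above): with this change, DEFAULT becomes a CompoundCachePolicy built from INPUTS, TASK_SOURCE, and RUN_ID, and any sub-policy that yields None is dropped from the compound key. The snippet below mirrors the new TestDefaultPolicy tests and assumes an environment where this branch is installed, so that the renamed prefect.cache_policies module is importable.

# Sketch only: assumes the renamed `prefect.cache_policies` module from this change.
from prefect.cache_policies import DEFAULT, INPUTS, TASK_SOURCE, CompoundCachePolicy

# DEFAULT is no longer a run-ID-only policy; it is a compound of three policies.
assert isinstance(DEFAULT, CompoundCachePolicy)

# Without a task context, TASK_SOURCE and RUN_ID return None and are ignored,
# so the compound key is driven by the hash of the task inputs alone.
key_one = DEFAULT.compute_key(task_ctx=None, inputs={"x": 42}, flow_parameters=None)
key_two = DEFAULT.compute_key(task_ctx=None, inputs={"x": 43}, flow_parameters=None)
assert key_one != key_two

# Policies still compose with `+`, e.g. a key derived from inputs and task source only.
custom_policy = INPUTS + TASK_SOURCE

Because None keys are discarded by CompoundCachePolicy, the same DEFAULT policy degrades gracefully: if every sub-policy returns None, compute_key returns None and nothing is persisted.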
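
A second sketch (also not part of the diff) shows the new FLOW_PARAMETERS policy from the caller's side, mirroring the added test_flow_parameter_caching test; the task and flow names and the storage path are illustrative, not taken from the PR.

import random

from prefect import flow, task
from prefect.cache_policies import FLOW_PARAMETERS
from prefect.filesystems import LocalFileSystem

# Illustrative storage location; the test above uses a tmp_path-backed block.
storage = LocalFileSystem(basepath="/tmp/prefect-flow-param-cache")


@task(cache_policy=FLOW_PARAMETERS, result_storage=storage)
def my_random_task(x: int) -> int:
    # The cache key is a hash of the flow run's parameters, so the second call
    # in the same flow run resolves to a Cached state carrying the first result.
    return random.randint(0, x)


@flow
def my_param_flow(x: int, other_val: str):
    first = my_random_task(x, return_state=True)
    second = my_random_task(x, return_state=True)
    return first, second


if __name__ == "__main__":
    first, second = my_param_flow(4200, other_val="foo")
    assert first.name == "Completed"
    assert second.name == "Cached"

Because the policy hashes the full flow-parameter dict, changing any flow parameter (including other_val, which the task never receives) produces a new cache key, exactly as the added test asserts.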