Skip to content

Commit

Permalink
Updates to cache policies (#14164)
Browse files Browse the repository at this point in the history
  • Loading branch information
cicdw committed Jun 20, 2024
1 parent 8990e9a commit 7caeb2f
Show file tree
Hide file tree
Showing 6 changed files with 228 additions and 95 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@

@dataclass
class CachePolicy:
"""
Base class for all cache policies.
"""

@classmethod
def from_cache_key_fn(
cls, cache_key_fn: Callable[["TaskRunContext", Dict[str, Any]], Optional[str]]
Expand Down Expand Up @@ -59,6 +63,11 @@ def __add__(self, other: "CachePolicy") -> "CompoundCachePolicy":

@dataclass
class CacheKeyFnPolicy(CachePolicy):
"""
This policy accepts a custom function with signature f(task_run_context, task_parameters) -> Optional[str]
and uses it to compute a task run cache key.
"""

# making it optional for tests
cache_key_fn: Optional[
Callable[["TaskRunContext", Dict[str, Any]], Optional[str]]
Expand All @@ -77,6 +86,13 @@ def compute_key(

@dataclass
class CompoundCachePolicy(CachePolicy):
"""
This policy is constructed from two or more other cache policies and works by computing the key
for each policy individually, and then hashing the computed keys together.
Any policy whose computed key is `None` will be ignored.
"""

policies: Optional[list] = None

def compute_key(
Expand All @@ -88,20 +104,25 @@ def compute_key(
) -> Optional[str]:
keys = []
for policy in self.policies or []:
keys.append(
policy.compute_key(
task_ctx=task_ctx,
inputs=inputs,
flow_parameters=flow_parameters,
**kwargs,
)
policy_key = policy.compute_key(
task_ctx=task_ctx,
inputs=inputs,
flow_parameters=flow_parameters,
**kwargs,
)
if policy_key is not None:
keys.append(policy_key)
if not keys:
return None
return hash_objects(*keys)


@dataclass
class Default(CachePolicy):
"Execution run ID only"
class _None(CachePolicy):
"""
Policy that always returns `None` for the computed cache key.
This policy prevents persistence.
"""

def compute_key(
self,
Expand All @@ -110,12 +131,14 @@ def compute_key(
flow_parameters: Dict[str, Any],
**kwargs,
) -> Optional[str]:
return str(task_ctx.task_run.id)
return None


@dataclass
class _None(CachePolicy):
"ignore key policies altogether, always run - prevents persistence"
class TaskSource(CachePolicy):
"""
Policy for computing a cache key based on the source code of the task.
"""

def compute_key(
self,
Expand All @@ -124,33 +147,60 @@ def compute_key(
flow_parameters: Dict[str, Any],
**kwargs,
) -> Optional[str]:
return None
if not task_ctx:
return None
try:
lines = inspect.getsource(task_ctx.task)
except TypeError:
lines = inspect.getsource(task_ctx.task.fn.__class__)

return hash_objects(lines)


@dataclass
class TaskDef(CachePolicy):
class FlowParameters(CachePolicy):
"""
Policy that computes the cache key based on a hash of the flow parameters.
"""

def compute_key(
self,
task_ctx: TaskRunContext,
inputs: Dict[str, Any],
flow_parameters: Dict[str, Any],
**kwargs,
) -> Optional[str]:
lines = inspect.getsource(task_ctx.task)
return hash_objects(lines)
if not flow_parameters:
return None
return hash_objects(flow_parameters)


@dataclass
class FlowParameters(CachePolicy):
pass
class RunId(CachePolicy):
    """
    Policy that keys the cache on the current run.

    The key is the flow run ID attached to the task run; when the task run has
    no flow run ID, the task run's own ID is used instead.
    """

    def compute_key(
        self,
        task_ctx: TaskRunContext,
        inputs: Dict[str, Any],
        flow_parameters: Dict[str, Any],
        **kwargs,
    ) -> Optional[str]:
        # Without a task run context there is no run to key on.
        if not task_ctx:
            return None
        task_run = task_ctx.task_run
        flow_run_id = task_run.flow_run_id
        # Fall back to the task run ID only when the flow run ID is missing.
        return str(task_run.id if flow_run_id is None else flow_run_id)


@dataclass
class Inputs(CachePolicy):
"""
Exposes flag for whether to include flow parameters as well.
And exclude/include config.
Policy that computes a cache key based on a hash of the runtime inputs provided to the task.
"""

exclude: Optional[list] = None
Expand All @@ -166,14 +216,19 @@ def compute_key(
inputs = inputs or {}
exclude = self.exclude or []

if not inputs:
return None

for key, val in inputs.items():
if key not in exclude:
hashed_inputs[key] = val

return hash_objects(hashed_inputs)


DEFAULT = Default()
INPUTS = Inputs()
NONE = _None()
TASKDEF = TaskDef()
TASK_SOURCE = TaskSource()
FLOW_PARAMETERS = FlowParameters()
RUN_ID = RunId()
DEFAULT = INPUTS + TASK_SOURCE + RUN_ID
9 changes: 8 additions & 1 deletion src/prefect/task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,18 @@ def call_hooks(self, state: State = None) -> Iterable[Callable]:
def compute_transaction_key(self) -> str:
key = None
if self.task.cache_policy:
flow_run_context = FlowRunContext.get()
task_run_context = TaskRunContext.get()

if flow_run_context:
parameters = flow_run_context.parameters
else:
parameters = None

key = self.task.cache_policy.compute_key(
task_ctx=task_run_context,
inputs=self.parameters,
flow_parameters=None,
flow_parameters=parameters,
)
elif self.task.result_storage_key is not None:
key = _format_user_supplied_storage_key(self.task.result_storage_key)
Expand Down
2 changes: 1 addition & 1 deletion src/prefect/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

from typing_extensions import Literal, ParamSpec

from prefect.cache_policies import DEFAULT, NONE, CachePolicy
from prefect.client.orchestration import get_client
from prefect.client.schemas import TaskRun
from prefect.client.schemas.objects import TaskRunInput, TaskRunResult
Expand All @@ -43,7 +44,6 @@
)
from prefect.futures import PrefectDistributedFuture, PrefectFuture
from prefect.logging.loggers import get_logger
from prefect.records.cache_policies import DEFAULT, NONE, CachePolicy
from prefect.results import ResultFactory, ResultSerializer, ResultStorage
from prefect.settings import (
PREFECT_TASK_DEFAULT_RETRIES,
Expand Down
124 changes: 96 additions & 28 deletions tests/records/test_cache_policies.py → tests/test_cache_policies.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import itertools
from dataclasses import dataclass
from typing import Callable

import pytest

from prefect.records.cache_policies import (
from prefect.cache_policies import (
DEFAULT,
CachePolicy,
CompoundCachePolicy,
Default,
Inputs,
TaskDef,
RunId,
TaskSource,
_None,
)

Expand Down Expand Up @@ -40,26 +43,6 @@ def test_addition_of_none_is_noop(self, typ):
assert policy + other == other


class TestDefaultPolicy:
def test_initializes(self):
policy = Default()
assert isinstance(policy, CachePolicy)

def test_returns_run_id(self):
class Run:
id = "foo"

class TaskCtx:
pass

task_ctx = TaskCtx()
task_ctx.task_run = Run()

policy = Default()
key = policy.compute_key(task_ctx=task_ctx, inputs=None, flow_parameters=None)
assert key == "foo"


class TestInputsPolicy:
def test_initializes(self):
policy = Inputs()
Expand Down Expand Up @@ -133,7 +116,7 @@ def test_initializes(self):
assert isinstance(policy, CachePolicy)

def test_creation_via_addition(self):
one, two = Inputs(), Default()
one, two = Inputs(), TaskSource()
policy = one + two
assert isinstance(policy, CompoundCachePolicy)

Expand All @@ -152,18 +135,29 @@ def test_subtraction_creates_new_policies(self):
assert policy.policies != new_policy.policies

def test_creation_via_subtraction(self):
one = Default()
one = RunId()
policy = one - "foo"
assert isinstance(policy, CompoundCachePolicy)

def test_nones_are_ignored(self):
    """A compound policy whose members all compute `None` computes `None` itself."""
    policy = CompoundCachePolicy(policies=[_None(), _None()])
    assert isinstance(policy, CompoundCachePolicy)

    compound_key = policy.compute_key(
        task_ctx=None,
        inputs={"z": [1, 2]},
        flow_parameters={"x": 42, "y": "foo"},
    )
    assert compound_key is None


class TestTaskDefPolicy:
class TestTaskSourcePolicy:
def test_initializes(self):
policy = TaskDef()
policy = TaskSource()
assert isinstance(policy, CachePolicy)

def test_changes_in_def_change_key(self):
policy = TaskDef()
policy = TaskSource()

class TaskCtx:
pass
Expand All @@ -189,3 +183,77 @@ def my_func(x):
)

assert key != new_key


class TestDefaultPolicy:
    """Checks that inputs, run ID and task source each influence the `DEFAULT` key."""

    def test_changing_the_inputs_busts_the_cache(self):
        # Same call twice, differing only in the value of the `x` input.
        key, new_key = (
            DEFAULT.compute_key(task_ctx=None, inputs={"x": x}, flow_parameters=None)
            for x in (42, 43)
        )

        assert key != new_key

    def test_changing_the_run_id_busts_the_cache(self):
        @dataclass
        class Run:
            id: str
            flow_run_id: str = None

        def my_task():
            pass

        @dataclass
        class TaskCtx:
            task_run: Run
            task = my_task

        # Two runs with distinct flow run IDs, and two that only have a task run ID.
        runs = [
            Run(id="a", flow_run_id="a"),
            Run(id="b", flow_run_id="b"),
            Run(id="c", flow_run_id=None),
            Run(id="d", flow_run_id=None),
        ]
        keys = [
            DEFAULT.compute_key(
                task_ctx=TaskCtx(task_run=run), inputs=None, flow_parameters=None
            )
            for run in runs
        ]

        # Every key must differ from every other key.
        assert len(set(keys)) == len(keys)

    def test_changing_the_source_busts_the_cache(self):
        @dataclass
        class Run:
            id: str
            flow_run_id: str = None

        @dataclass
        class TaskCtx:
            task_run: Run
            task: Callable = None

        # Same run for both contexts; only the task callable differs.
        shared_run = Run(id="a", flow_run_id="b")
        foo_ctx = TaskCtx(task_run=shared_run, task=lambda: "foo")
        bar_ctx = TaskCtx(task_run=shared_run, task=lambda: "bar")

        foo_key, bar_key = (
            DEFAULT.compute_key(task_ctx=ctx, inputs=None, flow_parameters=None)
            for ctx in (foo_ctx, bar_ctx)
        )

        assert foo_key != bar_key

0 comments on commit 7caeb2f

Please sign in to comment.