Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates to cache policies #14164

Merged
merged 9 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@

@dataclass
class CachePolicy:
"""
Base class for all cache policies.
"""

@classmethod
def from_cache_key_fn(
cls, cache_key_fn: Callable[["TaskRunContext", Dict[str, Any]], Optional[str]]
Expand Down Expand Up @@ -59,6 +63,11 @@ def __add__(self, other: "CachePolicy") -> "CompoundCachePolicy":

@dataclass
class CacheKeyFnPolicy(CachePolicy):
"""
This policy accepts a custom function with signature f(task_run_context, task_parameters) -> Optional[str]
and uses it to compute a task run cache key.
"""

# making it optional for tests
cache_key_fn: Optional[
Callable[["TaskRunContext", Dict[str, Any]], Optional[str]]
Expand All @@ -77,6 +86,13 @@ def compute_key(

@dataclass
class CompoundCachePolicy(CachePolicy):
"""
This policy is constructed from two or more other cache policies and works by computing the keys
for each policy individually, and then hashing a sorted tuple of all computed keys.

Any keys that return `None` will be ignored.
"""

policies: Optional[list] = None

def compute_key(
Expand All @@ -88,20 +104,25 @@ def compute_key(
) -> Optional[str]:
keys = []
for policy in self.policies or []:
keys.append(
policy.compute_key(
task_ctx=task_ctx,
inputs=inputs,
flow_parameters=flow_parameters,
**kwargs,
)
policy_key = policy.compute_key(
task_ctx=task_ctx,
inputs=inputs,
flow_parameters=flow_parameters,
**kwargs,
)
if policy_key is not None:
keys.append(policy_key)
if not keys:
return None
return hash_objects(*keys)


@dataclass
class Default(CachePolicy):
"Execution run ID only"
class _None(CachePolicy):
"""
Policy that always returns `None` for the computed cache key.
This policy prevents persistence.
"""

def compute_key(
self,
Expand All @@ -110,12 +131,14 @@ def compute_key(
flow_parameters: Dict[str, Any],
**kwargs,
) -> Optional[str]:
return str(task_ctx.task_run.id)
return None


@dataclass
class _None(CachePolicy):
"ignore key policies altogether, always run - prevents persistence"
class TaskSource(CachePolicy):
cicdw marked this conversation as resolved.
Show resolved Hide resolved
"""
Policy for computing a cache key based on the source code of the task.
"""

def compute_key(
self,
Expand All @@ -124,33 +147,60 @@ def compute_key(
flow_parameters: Dict[str, Any],
**kwargs,
) -> Optional[str]:
return None
if not task_ctx:
return None
try:
lines = inspect.getsource(task_ctx.task)
except TypeError:
lines = inspect.getsource(task_ctx.task.fn.__class__)

return hash_objects(lines)


@dataclass
class TaskDef(CachePolicy):
class FlowParameters(CachePolicy):
cicdw marked this conversation as resolved.
Show resolved Hide resolved
"""
Policy that computes the cache key based on a hash of the flow parameters.
"""

def compute_key(
self,
task_ctx: TaskRunContext,
inputs: Dict[str, Any],
flow_parameters: Dict[str, Any],
**kwargs,
) -> Optional[str]:
lines = inspect.getsource(task_ctx.task)
return hash_objects(lines)
if not flow_parameters:
return None
return hash_objects(flow_parameters)


@dataclass
class FlowParameters(CachePolicy):
pass
class RunId(CachePolicy):
cicdw marked this conversation as resolved.
Show resolved Hide resolved
"""
Returns either the prevailing flow run ID, or if not found, the prevailing task
run ID.
"""

def compute_key(
    self,
    task_ctx: TaskRunContext,
    inputs: Dict[str, Any],
    flow_parameters: Dict[str, Any],
    **kwargs,
) -> Optional[str]:
    """
    Compute a cache key from the identity of the current run.

    Returns `None` when `task_ctx` is falsy. Otherwise returns the string
    form of the task run's `flow_run_id`, falling back to the task run's own
    `id` when the flow run ID is `None`. `inputs`, `flow_parameters` and any
    extra keyword arguments are accepted but not used.
    """
    if task_ctx:
        task_run = task_ctx.task_run
        flow_run_id = task_run.flow_run_id
        # Only a missing (None) flow run ID triggers the task-run fallback.
        return str(task_run.id if flow_run_id is None else flow_run_id)
    return None


@dataclass
class Inputs(CachePolicy):
"""
Exposes flag for whether to include flow parameters as well.

And exclude/include config.
Policy that computes a cache key based on a hash of the runtime inputs provided to the task.
"""

exclude: Optional[list] = None
Expand All @@ -166,14 +216,19 @@ def compute_key(
inputs = inputs or {}
exclude = self.exclude or []

if not inputs:
return None

for key, val in inputs.items():
if key not in exclude:
hashed_inputs[key] = val

return hash_objects(hashed_inputs)


DEFAULT = Default()
INPUTS = Inputs()
NONE = _None()
TASKDEF = TaskDef()
TASK_SOURCE = TaskSource()
FLOW_PARAMETERS = FlowParameters()
RUN_ID = RunId()
DEFAULT = INPUTS + TASK_SOURCE + RUN_ID
9 changes: 8 additions & 1 deletion src/prefect/task_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,18 @@ def call_hooks(self, state: State = None) -> Iterable[Callable]:
def compute_transaction_key(self) -> str:
key = None
if self.task.cache_policy:
flow_run_context = FlowRunContext.get()
task_run_context = TaskRunContext.get()

if flow_run_context:
parameters = flow_run_context.parameters
else:
parameters = None

key = self.task.cache_policy.compute_key(
task_ctx=task_run_context,
inputs=self.parameters,
flow_parameters=None,
flow_parameters=parameters,
)
elif self.task.result_storage_key is not None:
key = _format_user_supplied_storage_key(self.task.result_storage_key)
Expand Down
2 changes: 1 addition & 1 deletion src/prefect/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

from typing_extensions import Literal, ParamSpec

from prefect.cache_policies import DEFAULT, NONE, CachePolicy
from prefect.client.orchestration import get_client
from prefect.client.schemas import TaskRun
from prefect.client.schemas.objects import TaskRunInput, TaskRunResult
Expand All @@ -44,7 +45,6 @@
)
from prefect.futures import PrefectDistributedFuture, PrefectFuture
from prefect.logging.loggers import get_logger
from prefect.records.cache_policies import DEFAULT, NONE, CachePolicy
from prefect.results import ResultFactory, ResultSerializer, ResultStorage
from prefect.settings import (
PREFECT_TASK_DEFAULT_RETRIES,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import itertools
from dataclasses import dataclass
from typing import Callable

import pytest

from prefect.records.cache_policies import (
from prefect.cache_policies import (
DEFAULT,
CachePolicy,
CompoundCachePolicy,
Default,
Inputs,
TaskDef,
RunId,
TaskSource,
_None,
)

Expand Down Expand Up @@ -40,26 +43,6 @@ def test_addition_of_none_is_noop(self, typ):
assert policy + other == other


class TestDefaultPolicy:
cicdw marked this conversation as resolved.
Show resolved Hide resolved
def test_initializes(self):
policy = Default()
assert isinstance(policy, CachePolicy)

def test_returns_run_id(self):
class Run:
id = "foo"

class TaskCtx:
pass

task_ctx = TaskCtx()
task_ctx.task_run = Run()

policy = Default()
key = policy.compute_key(task_ctx=task_ctx, inputs=None, flow_parameters=None)
assert key == "foo"


class TestInputsPolicy:
def test_initializes(self):
policy = Inputs()
Expand Down Expand Up @@ -133,7 +116,7 @@ def test_initializes(self):
assert isinstance(policy, CachePolicy)

def test_creation_via_addition(self):
one, two = Inputs(), Default()
one, two = Inputs(), TaskSource()
policy = one + two
assert isinstance(policy, CompoundCachePolicy)

Expand All @@ -152,18 +135,29 @@ def test_subtraction_creates_new_policies(self):
assert policy.policies != new_policy.policies

def test_creation_via_subtraction(self):
one = Default()
one = RunId()
policy = one - "foo"
assert isinstance(policy, CompoundCachePolicy)

def test_nones_are_ignored(self):
    """A compound policy whose members all yield `None` must yield `None` too."""
    policy = CompoundCachePolicy(policies=[_None(), _None()])
    assert isinstance(policy, CompoundCachePolicy)

    flow_params = {"x": 42, "y": "foo"}
    task_inputs = {"z": [1, 2]}
    result = policy.compute_key(
        task_ctx=None, inputs=task_inputs, flow_parameters=flow_params
    )
    assert result is None


class TestTaskDefPolicy:
class TestTaskSourcePolicy:
def test_initializes(self):
policy = TaskDef()
policy = TaskSource()
assert isinstance(policy, CachePolicy)

def test_changes_in_def_change_key(self):
policy = TaskDef()
policy = TaskSource()

class TaskCtx:
pass
Expand All @@ -189,3 +183,77 @@ def my_func(x):
)

assert key != new_key


class TestDefaultPolicy:
    """Checks which changes make the `DEFAULT` policy produce a different key."""

    def test_changing_the_inputs_busts_the_cache(self):
        """Different task inputs must produce different keys."""
        first_key = DEFAULT.compute_key(
            task_ctx=None, inputs={"x": 42}, flow_parameters=None
        )
        second_key = DEFAULT.compute_key(
            task_ctx=None, inputs={"x": 43}, flow_parameters=None
        )

        assert first_key != second_key

    def test_changing_the_run_id_busts_the_cache(self):
        """Runs that differ only in their flow/task run IDs must produce different keys."""

        @dataclass
        class Run:
            id: str
            flow_run_id: str = None

        def my_task():
            pass

        @dataclass
        class TaskCtx:
            task_run: Run
            # Unannotated on purpose: a plain class attribute, not a dataclass field.
            task = my_task

        # Two runs with a flow run ID, and two that only have a task run ID.
        runs = [
            Run(id="a", flow_run_id="a"),
            Run(id="b", flow_run_id="b"),
            Run(id="c", flow_run_id=None),
            Run(id="d", flow_run_id=None),
        ]
        keys = [
            DEFAULT.compute_key(
                task_ctx=TaskCtx(task_run=run), inputs=None, flow_parameters=None
            )
            for run in runs
        ]

        # Every key must differ from each of the other three.
        for position, key in enumerate(keys):
            others = keys[:position] + keys[position + 1 :]
            assert key not in others

    def test_changing_the_source_busts_the_cache(self):
        """Same run, different task callables: the keys must differ."""

        @dataclass
        class Run:
            id: str
            flow_run_id: str = None

        @dataclass
        class TaskCtx:
            task_run: Run
            task: Callable = None

        shared_run = Run(id="a", flow_run_id="b")
        # NOTE(review): keep the two lambdas on separate source lines. The key
        # is presumably derived from the callable's source text, so lambdas
        # sharing one line could become indistinguishable; confirm before
        # reformatting.
        ctx_foo = TaskCtx(task_run=shared_run, task=lambda: "foo")
        ctx_bar = TaskCtx(task_run=shared_run, task=lambda: "bar")

        key_foo = DEFAULT.compute_key(
            task_ctx=ctx_foo, inputs=None, flow_parameters=None
        )
        key_bar = DEFAULT.compute_key(
            task_ctx=ctx_bar, inputs=None, flow_parameters=None
        )

        assert key_foo != key_bar
Loading
Loading