Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cloud Caching #885

Merged
merged 11 commits into from
Apr 3, 2019
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ These changes are available in the [master branch](https://github.com/PrefectHQ/

- API reference documentation is now versioned - [#270](https://github.com/PrefectHQ/prefect/issues/270)
- Add `S3ResultHandler` for handling results to / from S3 buckets - [#879](https://github.com/PrefectHQ/prefect/pull/879)
- Add ability to use `Cached` states across flow runs in Cloud - [#885](https://github.com/PrefectHQ/prefect/pull/885)

### Enhancements
- Bump to latest version of `pytest` (4.3) - [#814](https://github.com/PrefectHQ/prefect/issues/814)
Expand Down
3 changes: 2 additions & 1 deletion docs/guide/development/documentation.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ Documentation (including both concepts and API references) is built and deployed
To preview docs locally, you'll first need to install [VuePress](https://vuepress.vuejs.org/) and its dependencies. This requires the [yarn](https://yarnpkg.com/) package manager. You will also need to install the rest of Prefect's dependencies for generating docs. You only need to do this once:

```bash
git clone https://github.com/PrefectHQ/prefect.git
cd prefect
yarn install
pip install "prefect[all_extras]"
pip install ".[all_extras]"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated this because installing the released package would break the docs build whenever the master branch contains new classes / methods that are not yet in the latest release.

```

To launch a documentation preview:
Expand Down
29 changes: 29 additions & 0 deletions src/prefect/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,35 @@ def set_flow_run_state(

self.graphql(mutation, state=serialized_state) # type: Any

def get_latest_cached_states(
    self, task_id: str, created_after: datetime.datetime
) -> List["prefect.engine.state.State"]:
    """
    Fetches the Cached states recorded for a task on or after the provided date.

    Args:
        - task_id (str): the task id used to filter task runs
        - created_after (datetime.datetime): the earliest state timestamp to include;
            it is sent to the API as an ISO-8601 string and compared inclusively

    Returns:
        - List[State]: the deserialized states, in the order returned by the API
            (the query requests descending `state_timestamp`)
    """
    # Only task runs of this task that are currently Cached and recent enough.
    filters = {
        "state": {"_eq": "Cached"},
        "task_id": {"_eq": task_id},
        "state_timestamp": {"_gte": created_after.isoformat()},
    }
    task_run_args = {
        "where": filters,
        "order_by": {"state_timestamp": EnumValue("desc")},
    }
    query = {"query": {with_args("task_run", task_run_args): "serialized_state"}}
    result = self.graphql(query)  # type: Any

    states = []
    for task_run in result.data.task_run:
        states.append(
            prefect.engine.state.State.deserialize(task_run.serialized_state)
        )
    return states

def get_task_run_info(
self, flow_run_id: str, task_id: str, map_index: Optional[int] = None
) -> TaskRunInfoResult:
Expand Down
56 changes: 55 additions & 1 deletion src/prefect/engine/cloud/task_runner.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import copy
import datetime
import warnings
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import prefect
from prefect.client import Client
from prefect.core import Edge, Task
from prefect.engine.cloud.utilities import prepare_state_for_cloud
from prefect.engine.result import NoResult
from prefect.engine.result import NoResult, Result
from prefect.engine.result_handlers import ResultHandler
from prefect.engine.runner import ENDRUN, call_state_handlers
from prefect.engine.state import Cached, Failed, Mapped, State
Expand Down Expand Up @@ -148,3 +149,56 @@ def initialize_run( # type: ignore
context.update(cloud=True)

return super().initialize_run(state=state, context=context)

@call_state_handlers
def check_task_is_cached(self, state: State, inputs: Dict[str, Result]) -> State:
    """
    Checks if task is cached in the DB and whether any of the caches are still valid.

    If the task has no `cache_for` set, the client is not queried and the
    incoming state is returned unchanged. Otherwise the client is asked for
    Cached states created within the last `cache_for`, and the first one
    accepted by the task's `cache_validator` is returned.

    Args:
        - state (State): the current state of this task
        - inputs (Dict[str, Result]): a dictionary of inputs whose keys correspond
            to the task's `run()` arguments.

    Returns:
        - State: the first valid Cached candidate if one exists, otherwise the
            provided `state`

    Note:
        This method body raises nothing itself apart from the type-narrowing
        assertion; errors from the client query or the cache validator propagate.
    """
    # Guard clause: caching is opt-in per task.
    if self.task.cache_for is None:
        return state

    # Resolved once; every log message below uses the same display name.
    task_name = prefect.context.get("task_full_name", self.task.name)

    # Naive UTC "now" minus the cache window gives the oldest acceptable timestamp.
    oldest_valid_cache = datetime.datetime.utcnow() - self.task.cache_for
    cached_states = self.client.get_latest_cached_states(
        task_id=self.task.id, created_after=oldest_valid_cache
    )

    if not cached_states:
        self.logger.debug(
            "Task '{name}': can't use cache because no Cached states were found".format(
                name=task_name
            )
        )
        return state

    self.logger.debug(
        "Task '{name}': {num} candidate cached states were found".format(
            name=task_name, num=len(cached_states)
        )
    )

    # Loop-invariant: the flow parameters do not change between candidates.
    parameters = prefect.context.get("parameters")
    for candidate_state in cached_states:
        assert isinstance(candidate_state, Cached)  # mypy assert
        if self.task.cache_validator(candidate_state, inputs, parameters):
            # Convert the stored result via `to_result()` before handing it back.
            candidate_state._result = candidate_state._result.to_result()
            return candidate_state

    # Reached only when candidates existed and every one was rejected.
    self.logger.debug(
        "Task '{name}': can't use cache because no candidate Cached states "
        "were valid".format(name=task_name)
    )
    return state
49 changes: 49 additions & 0 deletions tests/engine/cloud/test_cloud_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def test_task_runner_doesnt_call_client_if_map_index_is_none(client):
## assertions
assert client.get_task_run_info.call_count == 0 # never called
assert client.set_task_run_state.call_count == 2 # Pending -> Running -> Success
assert client.get_latest_cached_states.call_count == 0

states = [call[1]["state"] for call in client.set_task_run_state.call_args_list]
assert [type(s).__name__ for s in states] == ["Running", "Success"]
Expand Down Expand Up @@ -129,6 +130,54 @@ def raise_error():
assert res.is_running()


def test_task_runner_queries_for_cached_states_if_task_has_caching(client):
    """
    With `cache_for` set, the runner asks the client for Cached states and ends
    with the first candidate's result (99) rather than the task's own return
    value (42).
    """

    @prefect.task(cache_for=datetime.timedelta(minutes=1))
    def cached_task():
        return 42

    # First candidate expires a day from now; the second expired a day ago.
    unexpired = Cached(
        cached_result_expiration=datetime.datetime.utcnow()
        + datetime.timedelta(days=1),
        result=99,
    )
    expired = Cached(
        cached_result_expiration=datetime.datetime.utcnow()
        - datetime.timedelta(days=1),
        result=13,
    )
    client.get_latest_cached_states = MagicMock(return_value=[unexpired, expired])

    res = CloudTaskRunner(task=cached_task).run()

    assert client.get_latest_cached_states.called
    assert res.is_successful()
    assert res.is_cached()
    assert res.result == 99


def test_task_runner_validates_cached_states_if_task_has_caching(client):
    """
    When every candidate returned by the client has an expiration in the past,
    the final state carries the task's own return value (42) instead of a
    candidate's result.
    """

    @prefect.task(cache_for=datetime.timedelta(minutes=1))
    def cached_task():
        return 42

    # Both candidates are already past their expiration.
    recently_expired = Cached(
        cached_result_expiration=datetime.datetime.utcnow()
        - datetime.timedelta(minutes=2),
        result=99,
    )
    long_expired = Cached(
        cached_result_expiration=datetime.datetime.utcnow()
        - datetime.timedelta(days=1),
        result=13,
    )
    client.get_latest_cached_states = MagicMock(
        return_value=[recently_expired, long_expired]
    )

    res = CloudTaskRunner(task=cached_task).run()

    assert client.get_latest_cached_states.called
    assert res.is_successful()
    assert res.is_cached()
    assert res.result == 42


def test_task_runner_raises_endrun_if_client_cant_receive_state_updates(monkeypatch):
task = Task(name="test")
get_task_run_info = MagicMock(side_effect=SyntaxError)
Expand Down