Plugin: cache hook with diskcache #684

Merged on Feb 16, 2024 (21 commits). Changes shown from 15 commits.
62 changes: 62 additions & 0 deletions examples/cache_hook/README.md
@@ -0,0 +1,62 @@
# Cache hook
This hook uses the [diskcache](https://grantjenks.com/docs/diskcache/tutorial.html) library to cache node execution on disk. The cache key is a tuple of the function's `(source code hash, input a, ..., input n)`.

> 💡 This can be a great tool for developing inside a Jupyter notebook or other interactive environments.

diskcache offers useful features:
- set a maximum cache size
- set an automated eviction policy for when the maximum size is reached
- swap in custom `Disk` implementations to change the serialization protocol (e.g., pickle, JSON); see the sketch below

> ⚠ The default `Disk` serializes objects using the `pickle` module. Changing Python or library versions could break your cache (both keys and values). Learn more about [caveats](https://grantjenks.com/docs/diskcache/tutorial.html#caveats).
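
If pickling is a concern, a custom `Disk` changes how values are serialized. A minimal standalone sketch using diskcache's built-in `JSONDisk` (plain diskcache usage, independent of this hook; the directory name is illustrative):

```python
import diskcache

# sketch: store keys and values as (optionally compressed) JSON instead of pickle
cache = diskcache.Cache(directory=".cache", disk=diskcache.JSONDisk, disk_compress_level=0)
cache["example_key"] = {"a": 1, "values": [1, 2, 3]}
print(cache["example_key"])  # {'a': 1, 'values': [1, 2, 3]}
```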

> ❓ To store artifacts robustly, please use Hamilton materializers instead.


# How to use it
## Use the hook
Find it under plugins at `hamilton.plugins.h_diskcache` and add it to your Driver definition.

```python
from hamilton import driver
from hamilton.plugins import h_diskcache
import functions

dr = (
    driver.Builder()
    .with_modules(functions)
    .with_adapters(h_diskcache.CacheHook())
    .build()
)
```

## Inspect the hook
To inspect the caching behavior in real time, configure the plugin logger:

```python
import logging

logger = logging.getLogger("hamilton.plugins.h_diskcache")
logger.setLevel(logging.DEBUG) # or logging.INFO
logger.addHandler(logging.StreamHandler())
```
- INFO will only log the total cache size after the Driver executes
- DEBUG will additionally log the inputs for each node and whether its value was `from cache` or `executed`
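
For reference, the DEBUG output captured in the example notebook further down in this PR looks like:

```
A {'external': 10}: from cache
B {'A': 4}: from cache
C {'A': 4, 'B': 1.0}: from cache
Cache size: 0.03 MB
```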

## Clear cache
The utility function `h_diskcache.evict_all_except_driver` allows you to clear cached values for all nodes except those in the passed Driver. This is an efficient way to clear stale entries as your project evolves.

```python
from hamilton import driver
from hamilton.plugins import h_diskcache
import functions

dr = (
    driver.Builder()
    .with_modules(functions)
    .with_adapters(h_diskcache.CacheHook())
    .build()
)
h_diskcache.evict_all_except_driver(dr)
```

## Cache settings
Find all the cache settings in the [diskcache docs](https://grantjenks.com/docs/diskcache/api.html#constants).
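
As a sketch of what passing settings looks like (`CacheHook` forwards extra keyword arguments to `diskcache.Cache`; `size_limit` and `eviction_policy` are standard diskcache settings, and the cache directory shown is illustrative):

```python
from hamilton import driver
from hamilton.plugins import h_diskcache
import functions

dr = (
    driver.Builder()
    .with_modules(functions)
    .with_adapters(
        h_diskcache.CacheHook(
            cache_path="./.hamilton_cache",  # where the cache lives on disk
            size_limit=2**30,  # cap the cache at ~1 GB
            eviction_policy="least-recently-used",  # evict LRU entries once the cap is reached
        )
    )
    .build()
)
```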
10 changes: 10 additions & 0 deletions examples/cache_hook/functions.py
@@ -0,0 +1,10 @@
def A(external: int) -> int:
    return external % 7 + 1


def B(A: int) -> float:
    return A / 4


def C(A: int, B: float) -> float:
    return A**2 + B
86 changes: 86 additions & 0 deletions examples/cache_hook/notebook.ipynb
@@ -0,0 +1,86 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from hamilton import driver\n",
"from hamilton.plugins import h_diskcache\n",
"\n",
"import functions"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import logging\n",
"\n",
"# get the plugin logger\n",
"logger = logging.getLogger(\"hamilton.plugins.h_diskcache\")\n",
"logger.setLevel(logging.DEBUG) # set logging.INFO for less info\n",
"logger.addHandler(logging.StreamHandler())"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"A {'external': 10}: from cache\n",
"B {'A': 4}: from cache\n",
"C {'A': 4, 'B': 1.0}: from cache\n",
"Cache size: 0.03 MB\n"
]
}
],
"source": [
"dr = (\n",
" driver.Builder()\n",
" .with_modules(functions)\n",
" .with_adapters(h_diskcache.CacheHook())\n",
" .build()\n",
")\n",
"# if you ran `run.py`, you should see the nodes being\n",
"# read from cache\n",
"results = dr.execute([\"C\"], inputs=dict(external=10))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
1 change: 1 addition & 0 deletions examples/cache_hook/requirements.txt
@@ -0,0 +1 @@
sf-hamilton[diskcache]
19 changes: 19 additions & 0 deletions examples/cache_hook/run.py
@@ -0,0 +1,19 @@
import logging

import functions

from hamilton import driver
from hamilton.plugins import h_diskcache


def main():
    dr = driver.Builder().with_modules(functions).with_adapters(h_diskcache.CacheHook()).build()
    results = dr.execute(["C"], inputs=dict(external=10))
    print(results)


if __name__ == "__main__":
    logger = logging.getLogger("hamilton.plugins.h_diskcache")
    logger.setLevel(logging.DEBUG)
    logger.addHandler(logging.StreamHandler())
    main()
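
For reference: with `external=10`, the example DAG gives `A = 10 % 7 + 1 = 4`, `B = 4 / 4 = 1.0`, and `C = 4**2 + 1.0 = 17.0`, so the script prints `{'C': 17.0}`. On a second run, all three nodes should be reported as `from cache`.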
18 changes: 18 additions & 0 deletions hamilton/ad_hoc_utils.py
@@ -1,5 +1,6 @@
"""A suite of tools for ad-hoc use"""

import linecache
import sys
import types
import uuid
@@ -57,3 +58,20 @@ def create_temporary_module(*functions: Callable, module_name: str = None) -> Mo
        setattr(module, fn_name, fn)
    sys.modules[module_name] = module
    return module


def module_from_source(source: str) -> ModuleType:
Collaborator: Is this supposed to be in this PR?

Collaborator Author (zilto): Without this function, the `CacheHook` doesn't work in notebooks because it won't be able to use `inspect` to get the source code and create a hash. The alternative would be to use the function's `__code__` attribute, but that would depend on the Python version. It might not be an issue, because the cache's pickling is already Python-version dependent.

Collaborator: Where is this used? I think I'm missing it... But that makes sense, I think. You're talking about temporary modules, right? With temporary modules you'd have to special-case them, go through every function, and use `inspect` on them.

"""Create a temporary module from source code"""
module_name = _generate_unique_temp_module_name()
module_object = ModuleType(module_name)
code_object = compile(source, module_name, "exec")
sys.modules[module_name] = module_object
exec(code_object, module_object.__dict__)

linecache.cache[module_name] = (
len(source.splitlines()),
None,
source.splitlines(True),
module_name,
)
return module_object
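
To illustrate the notebook use case discussed above, a minimal usage sketch (the `doubled` function and its source string are made up for illustration):

```python
from hamilton import ad_hoc_utils, graph_utils

# hypothetical source, e.g. typed into a notebook cell rather than saved to a file
source = "def doubled(x: int) -> int:\n    return x * 2\n"

module = ad_hoc_utils.module_from_source(source)
# inspect-based hashing works because linecache can now resolve the module's source
print(graph_utils.hash_source_code(module.doubled, strip=True))
```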
31 changes: 30 additions & 1 deletion hamilton/graph_utils.py
@@ -1,6 +1,8 @@
import ast
import hashlib
import inspect
from types import ModuleType
from typing import Callable, List, Tuple
from typing import Callable, List, Tuple, Union


def is_submodule(child: ModuleType, parent: ModuleType):
@@ -22,3 +24,30 @@ def valid_fn(fn):
        )

    return [f for f in inspect.getmembers(function_module, predicate=valid_fn)]


def hash_source_code(source: Union[str, Callable], strip: bool = False) -> str:
    """Create a single hash (str) from the source code of a function"""
    if isinstance(source, Callable):
        source = inspect.getsource(source)

    if strip:
        try:
            source = remove_docs_and_comments(source)
        except Exception:
            pass

    return hashlib.sha256(source.encode()).hexdigest()


def remove_docs_and_comments(source: str):
Collaborator: Might not be worth removing comments TBH -- can you at least add some comments here as to what you're doing? Code is a bit cryptic.

Collaborator: Also worth adding a test for this.

Collaborator Author (zilto): Behavior is tested via `test_graph_utils::test_different_hash_docstring` and `test_different_hash_comment`. They check both that the hashes are unequal when not stripping docs and comments, and that the hashes are equal when stripping docs and comments. The tests would be improved by not relying on `graph_utils.hash_source_code()`, though.

Collaborator Author (zilto): Will add comments to the function. I think the value of removing docstrings/comments will depend on the workflow for the CI / CLI code diff (which relies on the same hashing). If we marked a node as "changed" because of an edited docstring or an added `# comment`, and also marked all downstream nodes as changed, the number of false positives would make the feature pointless IMHO.

    # parse and unparse the AST: this drops `#` comments but keeps docstrings
    comments_stripped = ast.unparse(ast.parse(source))

    formatted_code = ""
    for line in comments_stripped.split("\n"):
        # docstrings survive unparsing as bare string literals; skip those lines
        if line.lstrip()[:1] in ("'", '"'):
            continue

        formatted_code += line + "\n"

    return formatted_code
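
A minimal sketch of the behavior those tests describe (the source strings below are made up for illustration; the real tests live in `test_graph_utils`):

```python
from hamilton import graph_utils

without_doc = "def foo(a: int) -> int:\n    return a + 1\n"
with_doc = (
    "def foo(a: int) -> int:\n"
    '    """Docstring that should not change the hash."""\n'
    "    return a + 1\n"
)

# keeping docstrings: the hashes differ
assert graph_utils.hash_source_code(with_doc) != graph_utils.hash_source_code(without_doc)
# stripping docstrings and comments: the hashes match
assert graph_utils.hash_source_code(with_doc, strip=True) == graph_utils.hash_source_code(
    without_doc, strip=True
)
```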
126 changes: 126 additions & 0 deletions hamilton/plugins/h_diskcache.py
@@ -0,0 +1,126 @@
import logging
from typing import Any, Dict, List, Union

import diskcache

from hamilton import driver, graph_types, graph_utils, lifecycle, node

logger = logging.getLogger(__name__)


def _kb_to_mb(kb: int) -> float:
    return kb / (1024**2)


def evict_all_except(nodes_to_keep: Dict[str, node.Node], cache: diskcache.Cache) -> int:
    """Evicts all nodes and node versions except those passed.
    Remaining nodes may have multiple entries for different input values
    """
    nodes_history: Dict[str, List[str]] = cache.get(key=CacheHook.nodes_history_key)  # type: ignore

    new_nodes_history = dict()
    eviction_counter = 0
    for node_name, history in nodes_history.items():
        if len(history) < 1:
            continue

        if node_name in nodes_to_keep.keys():
            node_to_keep = nodes_to_keep[node_name]
            hash_to_keep = graph_utils.hash_source_code(node_to_keep.callable, strip=True)
            history.remove(hash_to_keep)
            new_nodes_history[node_name] = [hash_to_keep]

        for hash_to_evict in history:
            cache.evict(tag=f"{node_name}.{hash_to_evict}")
            eviction_counter += 1

    cache.set(key=CacheHook.nodes_history_key, value=new_nodes_history)
    return eviction_counter


def evict_all_except_driver(dr: driver.Driver) -> dict:
    """Wrap the utility `evict_all_except` to receive a driver.Driver object"""
    cache_hooks = [adapter for adapter in dr.adapter.adapters if isinstance(adapter, CacheHook)]

    if len(cache_hooks) == 0:
        raise AssertionError("0 `h_diskcache.CacheHook` defined for this Driver")
    elif len(cache_hooks) > 1:
        raise AssertionError(">1 `h_diskcache.CacheHook` defined for this Driver")

    cache: diskcache.Cache = cache_hooks[0].cache
    volume_before = cache.volume()
    eviction_counter = evict_all_except(nodes_to_keep=dr.graph.nodes, cache=cache)
    volume_after = cache.volume()
    volume_difference = volume_before - volume_after

    logger.info(f"Evicted: {_kb_to_mb(volume_difference):.2f} MB")
    logger.debug(f"Evicted {eviction_counter} entries")
    logger.debug(f"Cache size after: {_kb_to_mb(volume_after):.2f} MB")

    return dict(
        evicted_size_mb=_kb_to_mb(volume_difference),
        eviction_counter=eviction_counter,
        size_after=_kb_to_mb(volume_after),
    )


class CacheHook(
    lifecycle.NodeExecutionHook,
    lifecycle.GraphExecutionHook,
    lifecycle.NodeExecutionMethod,
):
    nodes_history_key: str = "_nodes_history"

    def __init__(
        self, cache_vars: Union[List[str], None] = None, cache_path: str = ".", **cache_settings
    ):
        self.cache_vars = cache_vars if cache_vars else []
        self.cache_path = cache_path
        self.cache = diskcache.Cache(directory=cache_path, **cache_settings)
        self.nodes_history: Dict[str, List[str]] = self.cache.get(
            key=CacheHook.nodes_history_key, default=dict()
        )  # type: ignore
        self.used_nodes_hash: Dict[str, str] = dict()

    def run_before_graph_execution(self, *, graph: graph_types.HamiltonGraph, **kwargs):
        """Set cache_vars to all nodes if not specified"""
        if self.cache_vars == []:
            self.cache_vars = [n.name for n in graph.nodes]

    def run_to_execute_node(
        self, *, node_name: str, node_callable: Any, node_kwargs: Dict[str, Any], **kwargs
    ):
        """Create hash key then use cached value if it exists"""
        if node_name not in self.cache_vars:
            return node_callable(**node_kwargs)

        node_hash = graph_utils.hash_source_code(node_callable, strip=True)
        self.used_nodes_hash[node_name] = node_hash
        cache_key = (node_hash, *node_kwargs.values())

        from_cache = self.cache.get(key=cache_key, default=None)
        if from_cache is not None:
            logger.debug(f"{node_name} {node_kwargs}: from cache")
            return from_cache

        logger.debug(f"{node_name} {node_kwargs}: executed")
        self.nodes_history[node_name] = self.nodes_history.get(node_name, []) + [node_hash]
        return node_callable(**node_kwargs)

    def run_after_node_execution(self, *, node_name: str, node_kwargs: dict, result: Any, **kwargs):
        if node_name not in self.cache_vars:
            return

        node_hash = self.used_nodes_hash[node_name]
        cache_key = (node_hash, *node_kwargs.values())
        cache_tag = f"{node_name}.{node_hash}"
        # only adds if key doesn't exist
        self.cache.add(key=cache_key, value=result, tag=cache_tag)

    def run_after_graph_execution(self, *args, **kwargs):
        self.cache.set(key=CacheHook.nodes_history_key, value=self.nodes_history)
        logger.info(f"Cache size: {_kb_to_mb(self.cache.volume()):.2f} MB")
        self.cache.close()

    def run_before_node_execution(self, *args, **kwargs):
        pass
1 change: 1 addition & 0 deletions requirements-docs.txt
@@ -3,6 +3,7 @@ alabaster>=0.7,<0.8,!=0.7.5 # read the docs pins
commonmark==0.9.1 # read the docs pins
dask[distributed]
ddtrace
diskcache
# furo -- install from main for now until the next release is out:
git+https://github.com/pradyunsg/furo@main
gitpython # Required for parsing git info for generation of data-adapter docs
1 change: 1 addition & 0 deletions requirements-test.txt
@@ -1,4 +1,5 @@
dask
diskcache
fsspec
graphviz
kaleido