Plugin: cache hook with diskcache (#684)
* diskcache hook + tests added

* install via sf-hamilton[diskcache] (requires Python >=3.9)

* examples/cache_hook added

```python
from hamilton import driver
from hamilton.plugins import h_diskcache
import functions

# get the logger to view cache retrieval info
import logging
logger = logging.getLogger("hamilton.plugins.h_diskcache")
logger.setLevel(logging.DEBUG)  # or logging.INFO
logger.addHandler(logging.StreamHandler())

# build driver with cache hook
dr = (
    driver.Builder()
    .with_modules(functions)
    .with_adapters(h_diskcache.CacheHook())
    .build()
)

# use execute or materialize as usual
dr.execute(["C"])
```

---------

Co-authored-by: zilto <tjean@DESKTOP-V6JDCS2>
zilto and zilto committed Feb 16, 2024
1 parent 4fefdb6 commit dab7fac
Showing 14 changed files with 711 additions and 1 deletion.
62 changes: 62 additions & 0 deletions examples/cache_hook/README.md
@@ -0,0 +1,62 @@
# Cache hook
This hook uses the [diskcache](https://grantjenks.com/docs/diskcache/tutorial.html) library to cache node execution results on disk. The cache key is a tuple built from the function's source code and its inputs: `(source code, input a, ..., input n)`.
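
A minimal sketch of the keying idea, using a plain dict as a stand-in for the on-disk cache so that it runs without diskcache installed. The names `make_cache_key` and `cached_call` are hypothetical and not part of the plugin, and the sketch assumes hashable inputs.

```python
import hashlib
from typing import Any, Callable, Dict, Tuple

# Stand-in for the on-disk cache (assumption: the real hook stores values via diskcache).
_cache: Dict[Tuple, Any] = {}


def make_cache_key(source: str, inputs: Dict[str, Any]) -> Tuple:
    """Build a key shaped like (source code hash, input a, ..., input n)."""
    source_hash = hashlib.sha256(source.strip().encode()).hexdigest()
    # Sort by input name so the key does not depend on keyword order.
    return (source_hash, *(inputs[name] for name in sorted(inputs)))


def cached_call(fn: Callable, source: str, **inputs: Any) -> Tuple[Any, str]:
    """Return (value, status), where status is 'from cache' or 'executed'."""
    key = make_cache_key(source, inputs)
    if key in _cache:
        return _cache[key], "from cache"
    value = fn(**inputs)
    _cache[key] = value
    return value, "executed"


# Define the node from its source string so the function and the key stay in sync.
A_SOURCE = "def A(external: int) -> int:\n    return external % 7 + 1\n"
namespace: Dict[str, Any] = {}
exec(A_SOURCE, namespace)
A = namespace["A"]

first = cached_call(A, A_SOURCE, external=10)   # first call runs the function
second = cached_call(A, A_SOURCE, external=10)  # same source and input: cache hit
third = cached_call(A, A_SOURCE, external=11)   # new input: runs again
```

Because the source hash is part of the key, editing the function body produces a different key, so stale values are never returned for changed code.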

> 💡 This can be a great tool for developing inside a Jupyter notebook or other interactive environments.

diskcache provides features to:
- set maximum cache size
- set automated eviction policy once maximum size is reached
- allow custom `Disk` implementations to change the serialization protocol (e.g., pickle, JSON)

> ⚠ The default `Disk` serializes objects using the `pickle` module. Changing Python or library versions could break your cache (both keys and values). Learn more about [caveats](https://grantjenks.com/docs/diskcache/tutorial.html#caveats).

> ❓ To store artifacts robustly, please use Hamilton materializers or the [CachingGraphAdapter](https://github.com/DAGWorks-Inc/hamilton/tree/main/examples/caching_nodes) instead. The `CachingGraphAdapter` stores tagged nodes directly on the file system using common formats (JSON, CSV, Parquet, etc.). However, it isn't aware of your function version and requires you to manually manage your disk space.

# How to use it
## Use the hook
Find it under plugins at `hamilton.plugins.h_diskcache` and add it to your Driver definition.

```python
from hamilton import driver
from hamilton.plugins import h_diskcache
import functions

dr = (
    driver.Builder()
    .with_modules(functions)
    .with_adapters(h_diskcache.CacheHook())
    .build()
)
```

## Inspect the hook
To inspect the caching behavior in real-time, you can get the logger:

```python
import logging

logger = logging.getLogger("hamilton.plugins.h_diskcache")
logger.setLevel(logging.DEBUG)  # or logging.INFO
logger.addHandler(logging.StreamHandler())
```
- INFO logs only the total cache size after the Driver executes
- DEBUG also logs the inputs of each node and whether its value came `from cache` or was `executed`
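
As an illustration of how the two levels filter output, here is a small sketch that uses only the standard `logging` module. The messages are emitted by hand to stand in for the plugin (their wording is copied from the example notebook output), and `capture` is a hypothetical helper.

```python
import io
import logging

logger = logging.getLogger("hamilton.plugins.h_diskcache")


def capture(level: int) -> str:
    """Emit one DEBUG and one INFO record at the given logger level and return the output."""
    stream = io.StringIO()
    handler = logging.StreamHandler(stream)
    logger.addHandler(handler)
    logger.setLevel(level)
    try:
        # Hand-written stand-ins for what the plugin would log.
        logger.debug("A {'external': 10}: from cache")
        logger.info("Cache size: 0.03 MB")
    finally:
        logger.removeHandler(handler)
    return stream.getvalue()


info_output = capture(logging.INFO)    # only the cache size line passes the filter
debug_output = capture(logging.DEBUG)  # per-node line plus the cache size line
```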

## Clear cache
The utility function `h_diskcache.evict_except_driver` allows you to clear cached values for all nodes except those in the passed driver. This is an efficient tool to clear old artifacts as your project evolves.

```python
from hamilton import driver
from hamilton.plugins import h_diskcache
import functions

dr = (
    driver.Builder()
    .with_modules(functions)
    .with_adapters(h_diskcache.CacheHook())
    .build()
)
h_diskcache.evict_except_driver(dr)
```

## Cache settings
Find all the cache settings in the [diskcache docs](https://grantjenks.com/docs/diskcache/api.html#constants).
10 changes: 10 additions & 0 deletions examples/cache_hook/functions.py
@@ -0,0 +1,10 @@
def A(external: int) -> int:
    return external % 7 + 1


def B(A: int) -> float:
    return A / 4


def C(A: int, B: float) -> float:
    return A**2 + B
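
To make the data flow concrete, the sketch below resolves `C` for `external=10` by matching parameter names to node names. The `resolve` helper and `NODES` mapping are hypothetical and only mimic the name-based wiring; this is not how Hamilton is implemented.

```python
import inspect
from typing import Any, Dict, Optional


def A(external: int) -> int:
    return external % 7 + 1


def B(A: int) -> float:
    return A / 4


def C(A: int, B: float) -> float:
    return A**2 + B


NODES = {fn.__name__: fn for fn in (A, B, C)}


def resolve(name: str, inputs: Dict[str, Any], results: Optional[Dict[str, Any]] = None) -> Any:
    """Compute node `name`, resolving each parameter as an input or an upstream node."""
    results = {} if results is None else results
    if name in inputs:
        return inputs[name]
    if name not in results:
        fn = NODES[name]
        kwargs = {p: resolve(p, inputs, results) for p in inspect.signature(fn).parameters}
        results[name] = fn(**kwargs)
    return results[name]


results: Dict[str, Any] = {}
value = resolve("C", {"external": 10}, results)
# A = 10 % 7 + 1 = 4, B = 4 / 4 = 1.0, C = 4**2 + 1.0 = 17.0
```

These values match the inputs shown in the notebook's debug log (`B {'A': 4}` and `C {'A': 4, 'B': 1.0}`).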
86 changes: 86 additions & 0 deletions examples/cache_hook/notebook.ipynb
@@ -0,0 +1,86 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from hamilton import driver\n",
"from hamilton.plugins import h_diskcache\n",
"\n",
"import functions"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import logging\n",
"\n",
"# get the plugin logger\n",
"logger = logging.getLogger(\"hamilton.plugins.h_diskcache\")\n",
"logger.setLevel(logging.DEBUG) # set logging.INFO for less info\n",
"logger.addHandler(logging.StreamHandler())"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"A {'external': 10}: from cache\n",
"B {'A': 4}: from cache\n",
"C {'A': 4, 'B': 1.0}: from cache\n",
"Cache size: 0.03 MB\n"
]
}
],
"source": [
"dr = (\n",
" driver.Builder()\n",
" .with_modules(functions)\n",
" .with_adapters(h_diskcache.CacheHook())\n",
" .build()\n",
")\n",
"# if you ran `run.py`, you should see the nodes being\n",
"# read from cache\n",
"results = dr.execute([\"C\"], inputs=dict(external=10))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
1 change: 1 addition & 0 deletions examples/cache_hook/requirements.txt
@@ -0,0 +1 @@
sf-hamilton[diskcache]
19 changes: 19 additions & 0 deletions examples/cache_hook/run.py
@@ -0,0 +1,19 @@
import logging

import functions

from hamilton import driver
from hamilton.plugins import h_diskcache


def main():
    dr = driver.Builder().with_modules(functions).with_adapters(h_diskcache.CacheHook()).build()
    results = dr.execute(["C"], inputs=dict(external=10))
    print(results)


if __name__ == "__main__":
    logger = logging.getLogger("hamilton.plugins.h_diskcache")
    logger.setLevel(logging.DEBUG)
    logger.addHandler(logging.StreamHandler())
    main()
18 changes: 18 additions & 0 deletions hamilton/ad_hoc_utils.py
@@ -1,5 +1,6 @@
"""A suite of tools for ad-hoc use"""

import linecache
import sys
import types
import uuid
@@ -57,3 +58,20 @@ def create_temporary_module(*functions: Callable, module_name: str = None) -> Mo
        setattr(module, fn_name, fn)
    sys.modules[module_name] = module
    return module


def module_from_source(source: str) -> ModuleType:
    """Create a temporary module from source code"""
    module_name = _generate_unique_temp_module_name()
    module_object = ModuleType(module_name)
    code_object = compile(source, module_name, "exec")
    sys.modules[module_name] = module_object
    exec(code_object, module_object.__dict__)

    linecache.cache[module_name] = (
        len(source.splitlines()),
        None,
        source.splitlines(True),
        module_name,
    )
    return module_object
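
A standalone sketch of why the `linecache` entry matters: registering the source lines under the module name lets `inspect.getsource` recover the code of a function that never existed on disk, which source-based hashing relies on. The uuid-based name below is an assumption standing in for `_generate_unique_temp_module_name()`.

```python
import inspect
import linecache
import sys
import uuid
from types import ModuleType


def module_from_source(source: str) -> ModuleType:
    """Create a temporary module from source code (standalone copy for illustration)."""
    # Hypothetical unique name; the real helper is `_generate_unique_temp_module_name()`.
    module_name = f"temporary_module_{uuid.uuid4().hex}"
    module_object = ModuleType(module_name)
    code_object = compile(source, module_name, "exec")
    sys.modules[module_name] = module_object
    exec(code_object, module_object.__dict__)

    # An mtime of None makes linecache.checkcache() keep this entry
    # even though no file with this name exists.
    linecache.cache[module_name] = (
        len(source.splitlines()),
        None,
        source.splitlines(True),
        module_name,
    )
    return module_object


SOURCE = "def double(x: int) -> int:\n    return x * 2\n"
module = module_from_source(SOURCE)
result = module.double(21)
recovered = inspect.getsource(module.double)  # served from linecache, not from disk
```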
68 changes: 67 additions & 1 deletion hamilton/graph_utils.py
@@ -1,6 +1,8 @@
import ast
import hashlib
import inspect
from types import ModuleType
-from typing import Callable, List, Tuple
+from typing import Callable, List, Tuple, Union


def is_submodule(child: ModuleType, parent: ModuleType):
@@ -22,3 +24,67 @@ def valid_fn(fn):
        )

    return [f for f in inspect.getmembers(function_module, predicate=valid_fn)]


def hash_source_code(source: Union[str, Callable], strip: bool = False) -> str:
    """Hashes the source code of a function (str).

    The `strip` parameter requires Python 3.9

    If strip, try to remove docs and comments from source code string. Since
    they don't impact function behavior, they shouldn't influence the hash.
    """
    if isinstance(source, Callable):
        source = inspect.getsource(source)

    source = source.strip()

    if strip:
        try:
            # could fail if source is indented code.
            # see `remove_docs_and_comments` docstring for details.
            source = remove_docs_and_comments(source)
        except Exception:
            pass

    return hashlib.sha256(source.encode()).hexdigest()


def remove_docs_and_comments(source: str) -> str:
    """Remove the docs and comments from a source code string.

    The use of `ast.unparse()` requires Python 3.9

    1. Parsing then unparsing the AST of the source code will
    create a code object and convert it back to a string. In the
    process, comments are stripped.

    2. walk the AST to check if first element after `def` is a
    docstring. If so, edit AST to skip the docstring

    NOTE. The ast parsing will fail if `source` has syntax errors. For the
    majority of cases this is caught upstream (e.g., by calling `import`).
    The foreseeable edge case is if `source` is the result of `inspect.getsource`
    on a nested function, method, or callable where `def` isn't at column 0.
    Standard usage of Hamilton requires users to define functions/nodes at the top
    level of a module, and therefore no issues should arise.
    """
    parsed = ast.parse(source)
    for node in ast.walk(parsed):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue

        if not len(node.body):
            continue

        # check if 1st node is a docstring
        if not isinstance(node.body[0], ast.Expr):
            continue

        if not hasattr(node.body[0], "value") or not isinstance(node.body[0].value, ast.Str):
            continue

        # skip docstring
        node.body = node.body[1:]

    return ast.unparse(parsed)
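
A hedged, standalone variant of the same idea, showing that docstrings and comments do not affect the hash while a behavior change does. Since `ast.Str` was deprecated in Python 3.8 and removed in Python 3.12, this sketch checks for an `ast.Constant` holding a `str` instead. The names `strip_docs_and_comments` and `hash_stripped` are hypothetical, and the `ast.Pass()` fallback for docstring-only bodies is an addition that is not in the commit.

```python
import ast
import hashlib


def strip_docs_and_comments(source: str) -> str:
    """Drop comments (via parse/unparse) and leading docstrings of functions. Python 3.9+."""
    parsed = ast.parse(source)
    for node in ast.walk(parsed):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        if not node.body:
            continue
        first = node.body[0]
        is_docstring = (
            isinstance(first, ast.Expr)
            and isinstance(first.value, ast.Constant)
            and isinstance(first.value.value, str)
        )
        if is_docstring:
            # Keep the body non-empty so the unparsed code stays valid.
            node.body = node.body[1:] or [ast.Pass()]
    return ast.unparse(parsed)


def hash_stripped(source: str) -> str:
    """SHA-256 of the source after removing docstrings and comments."""
    return hashlib.sha256(strip_docs_and_comments(source.strip()).encode()).hexdigest()


DOCUMENTED = '''
def B(A: int) -> float:
    """Divide A by four."""
    # this comment should not matter
    return A / 4
'''

BARE = '''
def B(A: int) -> float:
    return A / 4
'''

CHANGED = '''
def B(A: int) -> float:
    return A / 5
'''

same_hash = hash_stripped(DOCUMENTED) == hash_stripped(BARE)
different_hash = hash_stripped(BARE) != hash_stripped(CHANGED)
```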