Skip to content

Commit

Permalink
Adds tag querying to listing variables
Browse files Browse the repository at this point in the history
This completes #620.

This implements the following behavior:

Behavior is now this. Given
```python
@tag( business_value=["CA", "US"] )`
def combo_node():
    ...

@tag( business_value=["US"] )`
def only_us_node():
    ...

@tag( business_value=["CA"], some_other_tag="BAR" )`
def only_ca_node():
    ...
```
then the following queries will return:
```python
dr.list_available_variables(tag_filter=dict(business_lines=["CA"]))
dr.list_available_variables(tag_filter=dict(business_lines="US"))
dr.list_available_variables(tag_filter=dict(business_lines=["CA", "US"] ))
dr.list_available_variables(tag_filter=dict(business_lines=None ))
dr.list_available_variables(tag_filter=dict(business_lines="UK" ))
dr.list_available_variables(tag_filter=dict(business_lines="US", some_other_tag="FOO" )
dr.list_available_variables(tag_filter=dict(business_lines="CA", some_other_tag="BAR" )
```
So values in lists are read as OR clauses, and multiple
tags in the query dict are read as AND clauses.

We settled on this design since it seemed the most intuitive to read,
while also leaving the door for more complex querying capabilities.

Squashed commits:
[f7a9752] Makes matches_query return True on empty dict
[f93282b] Updates logic for tag querying

This moves the tag filter logic to the node module,
since it's kind of coupled to it.

Behavior is now this. Given
```python
@tag( business_value=["CA", "US"] )`
def combo_node():
    ...

@tag( business_value=["US"] )`
def only_us_node():
    ...

@tag( business_value=["CA"], some_other_tag="BAR" )`
def only_ca_node():
    ...
```
then the following queries will return:
```python
dr.list_available_variables(tag_filter=dict(business_lines=["CA"]))
dr.list_available_variables(tag_filter=dict(business_lines="US"))
dr.list_available_variables(tag_filter=dict(business_lines=["CA", "US"] ))
dr.list_available_variables(tag_filter=dict(business_lines=None ))
dr.list_available_variables(tag_filter=dict(business_lines="UK" ))
dr.list_available_variables(tag_filter=dict(business_lines="US", some_other_tag="FOO" )
dr.list_available_variables(tag_filter=dict(business_lines="CA", some_other_tag="BAR" )
```

[94c58de] Fixing whitespace and type annotation for list_available_variables

These were incorrect.
[2cc83e9] Adds exposing tag_filter open on listing variables #620

This is a quick way to improve the ergonomics of listing
all available variables and filtering by some criteria.

```python
dr.list_available_variables() # gets all
dr.list_available_variables({"TAG_NAME": "TAG_VALUE"})  # gets all matching tag name & value
dr.list_available_variables({"TAG_NAME": ["TAG_VALUE", "TAG_VALUE2"]})  # gets all matching tag name & all values order invariant
dr.list_available_variables({"TAG_NAME": None})  # gets all with matching tag, irrespective of value
```

This completes 620, and handles the filtering
logic in the driver function. I think this is an reasonable
place to put it for now.

Assumptions:
 - if passed a list it's an exact match
 - AND clause across all values passed in.
 - None means just get me anything with that tag.
 - it's keyword only argument to enable us to maintain
backwards compatibility for future changes.

More complex clauses seem out of scope for this change.

But for the future, if we want to change things, e.g. take in
a query object that can express `not in, or, etc`  We can do that.

Another idea I didn't pursue was enabling query syntax:

```python
{"TAG_NAME": "=FOO"},  # equal
{"TAG_NAME": "!FOO"},  # not equal
etc
```
Since with lists of values of strings, this could get quite complex
to parse/understand...
  • Loading branch information
skrawcz committed Jan 12, 2024
1 parent 9744bbe commit eb5b2ce
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 10 deletions.
53 changes: 45 additions & 8 deletions hamilton/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,20 +638,57 @@ def raw_execute(
return out

@capture_function_usage
def list_available_variables(self) -> List[Variable]:
def list_available_variables(
self, *, tag_filter: Dict[str, Union[Optional[str], List[str]]] = None
) -> List[Variable]:
"""Returns available variables, i.e. outputs.
These variables corresond 1:1 with nodes in the DAG, and contain the following information:
1. name: the name of the node
2. tags: the tags associated with this node
3. type: The type of data this node returns
4. is_external_input: Whether this node represents an external input (required from outside),
or not (has a function specifying its behavior).
These variables correspond 1:1 with nodes in the DAG, and contain the following information:
1. name: the name of the node
2. tags: the tags associated with this node
3. type: The type of data this node returns
4. is_external_input: Whether this node represents an external input (required from outside), \
or not (has a function specifying its behavior).
.. code-block:: python
# gets all
dr.list_available_variables()
# gets exact matching tag name and tag value
dr.list_available_variables({"TAG_NAME": "TAG_VALUE"})
# gets all matching tag name and at least one of the values in the list
dr.list_available_variables({"TAG_NAME": ["TAG_VALUE1", "TAG_VALUE2"]})
# gets all with matching tag name, irrespective of value
dr.list_available_variables({"TAG_NAME": None})
# AND query between the two tags (i.e. both need to match)
dr.list_available_variables({"TAG_NAME": "TAG_VALUE", "TAG_NAME2": "TAG_VALUE2"}
:param tag_filter: A dictionary of tags to filter by. Only nodes matching the tags and their values will
be returned. If the value for a tag is None, then we will return all nodes with that tag. If the value
is non-empty we will return all nodes with that tag and that value.
:return: list of available variables (i.e. outputs).
"""
return [Variable.from_node(n) for n in self.graph.get_nodes()]
all_nodes = self.graph.get_nodes()
if tag_filter:
valid_filter_values = all(
map(
lambda x: isinstance(x, str)
or (isinstance(x, list) and len(x) != 0)
or x is None,
tag_filter.values(),
)
)
if not valid_filter_values:
raise ValueError("All tag query values must be a string or list of strings")
results = []
for n in all_nodes:
if node.matches_query(n.tags, tag_filter):
results.append(Variable.from_node(n))
else:
results = [Variable.from_node(n) for n in all_nodes]
return results

@capture_function_usage
def display_all_functions(
Expand Down
33 changes: 33 additions & 0 deletions hamilton/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,3 +370,36 @@ def new_callable(**kwargs) -> Any:
return __transform(self.callable(**kwargs), kwargs)

return self.copy_with(callabl=new_callable, typ=__output_type)


def matches_query(
tags: Dict[str, Union[str, List[str]]], query_dict: Dict[str, Optional[Union[str, List[str]]]]
) -> bool:
"""Check whether a set of node tags matches the query based on tags.
An empty dict of a query matches all tags.
:param tags: the tags of the node.
:param query_dict: of tag to value. If value is None, we just check that the tag exists.
:return: True if we have tags that match all tag queries, False otherwise.
"""
# it's an AND clause between each tag and value in the query dict.
for tag, value in query_dict.items():
# if tag not in node we can return False immediately.
if tag not in tags:
return False
# if value is None -- we don't care about the value, just that the tag exists.
if value is None:
continue
node_tag_value = tags[tag]
if not isinstance(node_tag_value, list):
node_tag_value = [node_tag_value]
if not isinstance(value, list):
value = [value]
if set(value).intersection(set(node_tag_value)):
# if there is some overlap, we're good.
continue
else:
# else, return False.
return False
return True
53 changes: 52 additions & 1 deletion tests/test_hamilton_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,58 @@ def test_driver_variables_exposes_tags():
assert tags["a"] == {"module": "tests.resources.tagging", "test": "a"}
assert tags["b"] == {"module": "tests.resources.tagging", "test": "b_c"}
assert tags["c"] == {"module": "tests.resources.tagging", "test": "b_c"}
assert tags["d"] == {"module": "tests.resources.tagging"}
assert tags["d"] == {"module": "tests.resources.tagging", "test_list": ["us", "uk"]}


@pytest.mark.parametrize(
"filter,expected",
[
(None, {"a", "b_c", "b", "c", "d"}), # no filter
({}, {"a", "b_c", "b", "c", "d"}), # empty filter
({"test": "b_c"}, {"b", "c"}),
({"test": None}, {"a", "b", "c"}),
({"module": "tests.resources.tagging"}, {"a", "b_c", "b", "c", "d"}),
({"test_list": "us"}, {"d"}),
({"test_list": "uk"}, {"d"}),
({"test_list": ["uk"]}, {"d"}),
({"module": "tests.resources.tagging", "test": "b_c"}, {"b", "c"}),
({"test_list": ["nz", "uk"]}, {"d"}),
({"test_list": ["us", "uk"]}, {"d"}),
({"test_list": ["uk", "us"]}, {"d"}),
({"test": ["b_c"]}, {"b", "c"}),
({"test": ["b_c", "foo"]}, {"b", "c"}),
],
ids=[
"filter with None passed",
"filter with empty filter",
"filter by single tag with extract decorator",
"filter with None value",
"filter with specific value",
"filter tag with list values - value 1",
"filter tag with list values - value 2",
"filter tag with list values - query is single node list",
"filter with two filter clauses",
"filter with with list values not exact OR interpretation",
"filter with with list values exact",
"filter with with list values exact order invariant",
"filter with with list values edge case one item match",
"filter with with list values OR interpretation",
],
)
def test_driver_variables_filters_tags(filter, expected):
dr = Driver({}, tests.resources.tagging)
actual = {var.name for var in dr.list_available_variables(tag_filter=filter)}
assert actual == expected


def test_driver_variables_filters_tags_error():
dr = Driver({}, tests.resources.tagging)
with pytest.raises(ValueError):
# non string value is not allowed
dr.list_available_variables(tag_filter={"test": 1234})
with pytest.raises(ValueError):
# empty list shouldn't be allowed
dr.list_available_variables(tag_filter={"test": []})


def test_driver_variables_external_input():
Expand Down
45 changes: 44 additions & 1 deletion tests/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy.typing as npt
import pytest

from hamilton.node import DependencyType, Node
from hamilton.node import DependencyType, Node, matches_query


def test_node_from_fn_happy():
Expand Down Expand Up @@ -61,3 +61,46 @@ def annotated_func(first: ArrayN[np.float64], other: float = 2.0) -> ArrayN[np.f
}
assert node.input_types == expected
assert node.type == Annotated[np.ndarray[Any, np.dtype[np.float64]], Literal["N"]]


@pytest.mark.parametrize(
"tags, query, expected",
[
({}, {"module": "tests.resources.tagging"}, False),
({"module": "tests.resources.tagging"}, {}, True),
({"module": "tests.resources.tagging"}, {"module": "tests.resources.tagging"}, True),
({"module": "tests.resources.tagging"}, {"module": None}, True),
({"module": "tests.resources.tagging"}, {"module": None, "tag2": "value"}, False),
(
{"module": "tests.resources.tagging"},
{"module": "tests.resources.tagging", "tag2": "value"},
False,
),
({"tag1": ["tag_value1"]}, {"tag1": "tag_value1"}, True),
({"tag1": ["tag_value1"]}, {"tag1": ["tag_value1"]}, True),
({"tag1": ["tag_value1"]}, {"tag1": ["tag_value1", "tag_value2"]}, True),
({"tag1": "tag_value1"}, {"tag1": ["tag_value1", "tag_value2"]}, True),
({"tag1": "tag_value1"}, {"tag1": ["tag_value3", "tag_value4"]}, False),
({"tag1": ["tag_value1"]}, {"tag1": "tag_value2"}, False),
({"tag1": "tag_value1"}, {"tag1": "tag_value2"}, False),
({"tag1": ["tag_value1"]}, {"tag1": ["tag_value2"]}, False),
],
ids=[
"no tags fail",
"no query pass",
"exact match pass",
"match tag key pass",
"missing extra tag fail",
"missing extra tag2 fail",
"list single match pass",
"list list match pass",
"list list match one of pass",
"single list match one of pass",
"single list fail",
"list single fail",
"single single fail",
"list list fail",
],
)
def test_tags_match_query(tags: dict, query: dict, expected: bool):
assert matches_query(tags, query) == expected

0 comments on commit eb5b2ce

Please sign in to comment.