Skip to content

Commit

Permalink
Adds the original function pointer to node/variable objects
Browse files Browse the repository at this point in the history
Before we had no way of knowing which nodes were generated by which
functions. This adds the capability. Each variable now points to the
functions that generated it.

Note this is a tuple of functions as multiple functions could have
generated it (in the case of a subdag). The order of these will always
have the lowest level function first, and will likely correspond to
namespaces, in reverse (although this contract isn't ironed in yet).
  • Loading branch information
elijahbenizzy committed May 5, 2023
1 parent b6bedbd commit 415f09e
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 15 deletions.
30 changes: 19 additions & 11 deletions hamilton/driver.py
Expand Up @@ -71,6 +71,22 @@ class Variable:
type: typing.Type
tags: Dict[str, str] = field(default_factory=dict)
is_external_input: bool = field(default=False)
originating_functions: Optional[Tuple[Callable, ...]] = None

@staticmethod
def from_node(n: node.Node) -> "Variable":
"""Creates a Variable from a Node.
:param n: Node to create the Variable from.
:return: Variable created from the Node.
"""
return Variable(
name=n.name,
type=n.type,
tags=n.tags,
is_external_input=n.user_defined,
originating_functions=n.originating_functions,
)


class Driver(object):
Expand Down Expand Up @@ -389,10 +405,7 @@ def list_available_variables(self) -> List[Variable]:
:return: list of available variables (i.e. outputs).
"""
return [
Variable(node.name, node.type, node.tags, node.user_defined)
for node in self.graph.get_nodes()
]
return [Variable.from_node(n) for n in self.graph.get_nodes()]

@capture_function_usage
def display_all_functions(
Expand Down Expand Up @@ -477,10 +490,7 @@ def what_is_downstream_of(self, *node_names: str) -> List[Variable]:
in function names.
"""
downstream_nodes = self.graph.get_impacted_nodes(list(node_names))
return [
Variable(node.name, node.type, node.tags, node.user_defined)
for node in downstream_nodes
]
return [Variable.from_node(n) for n in downstream_nodes]

@capture_function_usage
def display_downstream_of(
Expand Down Expand Up @@ -522,9 +532,7 @@ def what_is_upstream_of(self, *node_names: str) -> List[Variable]:
in function names.
"""
upstream_nodes, _ = self.graph.get_upstream_nodes(list(node_names))
return [
Variable(node.name, node.type, node.tags, node.user_defined) for node in upstream_nodes
]
return [Variable.from_node(n) for n in upstream_nodes]


if __name__ == "__main__":
Expand Down
29 changes: 28 additions & 1 deletion hamilton/function_modifiers/base.py
Expand Up @@ -647,6 +647,33 @@ def get_node_decorators(
return defaults


def _add_original_function_to_nodes(fn: Callable, nodes: List[node.Node]) -> List[node.Node]:
"""Adds the original function to the nodes. We do this so that we can have appropriate metadata
on the function -- this is valuable to see if/how the function changes over time to manage node
versions, etc...
Note that this will add it so the "external" function is always last. They *should* correspond
to namespaces, but this is not
This is not mutating them, rather
copying them with the original function. If it gets slow we *can* mutate them, but
this is just another O(n) operation so I'm not concerned.
:param fn: The function to add
:param nodes: The nodes to add it to
:return: The nodes with the function added
"""
out = []
for node_ in nodes:
current_originating_functions = node_.originating_functions
new_originating_functions = (
current_originating_functions if current_originating_functions is not None else ()
) + (fn,)
out.append(node_.copy_with(originating_functions=new_originating_functions))
return out


def resolve_nodes(fn: Callable, config: Dict[str, Any]) -> Collection[node.Node]:
"""Gets a list of nodes from a function. This is meant to be an abstraction between the node
and the function that it implements. This will end up coordinating with the decorators we build
Expand Down Expand Up @@ -696,7 +723,7 @@ def resolve_nodes(fn: Callable, config: Dict[str, Any]) -> Collection[node.Node]
function_decorators = function_decorators[NodeDecorator.get_lifecycle_name()]
for node_decorator in function_decorators:
nodes = node_decorator.transform_dag(nodes, filter_config(config, node_decorator), fn)
return nodes
return _add_original_function_to_nodes(fn, nodes)
except InvalidDecoratorException as e:
raise InvalidDecoratorException(f"Invalid decorator {e} for function {fn.__name__}.") from e

Expand Down
14 changes: 13 additions & 1 deletion hamilton/node.py
@@ -1,7 +1,7 @@
import inspect
import typing
from enum import Enum
from typing import Any, Callable, Dict, List, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

"""
Module that contains the primitive components of the graph.
Expand Down Expand Up @@ -46,6 +46,7 @@ def __init__(
input_types: Dict[str, Union[Type, Tuple[Type, DependencyType]]] = None,
tags: Dict[str, Any] = None,
namespace: Tuple[str, ...] = (),
originating_functions: Optional[Tuple[Callable, ...]] = None,
):
"""Constructor for our Node object.
Expand All @@ -71,6 +72,7 @@ def __init__(
self._depended_on_by = []
self._namespace = namespace
self._input_types = {}
self._originating_functions = originating_functions

if self._node_source == NodeSource.STANDARD:
if input_types is not None:
Expand Down Expand Up @@ -145,6 +147,16 @@ def depended_on_by(self) -> List["Node"]:
def tags(self) -> Dict[str, str]:
return self._tags

@property
def originating_functions(self) -> Optional[Tuple[Callable, ...]]:
"""Gives all functions from which this node was created. None if the data
is not available (it is user-defined, or we have not added it yet). Note that this can be
multiple in the case of subdags (the subdag function + the other function). In that case,
:return: A Tuple consisting of functions from which this node was created.
"""
return self._originating_functions

def add_tag(self, tag_name: str, tag_value: str):
self._tags[tag_name] = tag_value

Expand Down
10 changes: 10 additions & 0 deletions tests/function_modifiers/test_base.py
Expand Up @@ -139,3 +139,13 @@ def test_select_nodes_happy(
def test_select_nodes_sad(target: TargetType, nodes: Collection[node.Node]):
with pytest.raises(InvalidDecoratorException):
NodeTransformer.select_nodes(target, nodes)


def test_add_fn_metadata():
nodes_og = _create_node_set({"d": ["e"]})
nodes = base._add_original_function_to_nodes(test_add_fn_metadata, nodes_og)
nodes_with_fn_pointer = [
n.originating_functions for n in nodes if n.originating_functions is not None
]
assert len(nodes_with_fn_pointer) == len(nodes)
assert all([n.originating_functions == (test_add_fn_metadata,) for n in nodes])
13 changes: 11 additions & 2 deletions tests/test_hamilton_driver.py
Expand Up @@ -48,7 +48,7 @@ def test_driver_cycles_execute_recursion_error():
dr.execute(["C"], inputs={"b": 2, "c": 2})


def test_driver_variables():
def test_driver_variables_exposes_tags():
dr = Driver({}, tests.resources.tagging)
tags = {var.name: var.tags for var in dr.list_available_variables()}
assert tags["a"] == {"module": "tests.resources.tagging", "test": "a"}
Expand All @@ -57,13 +57,22 @@ def test_driver_variables():
assert tags["d"] == {"module": "tests.resources.tagging"}


def test_driver_external_input():
def test_driver_variables_external_input():
dr = Driver({}, tests.resources.very_simple_dag)
input_types = {var.name: var.is_external_input for var in dr.list_available_variables()}
assert input_types["a"] is True
assert input_types["b"] is False


def test_driver_variables_exposes_original_function():
dr = Driver({}, tests.resources.very_simple_dag)
originating_functions = {
var.name: var.originating_functions for var in dr.list_available_variables()
}
assert originating_functions["b"] == (tests.resources.very_simple_dag.b,)
assert originating_functions["a"] is None


@mock.patch("hamilton.telemetry.send_event_json")
def test_capture_constructor_telemetry_disabled(send_event_json):
"""Tests that we don't do anything if telemetry is disabled."""
Expand Down

0 comments on commit 415f09e

Please sign in to comment.