Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds the original function pointer to node/variable objects #165

Merged
merged 1 commit into from May 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
30 changes: 19 additions & 11 deletions hamilton/driver.py
Expand Up @@ -71,6 +71,22 @@ class Variable:
type: typing.Type
tags: Dict[str, str] = field(default_factory=dict)
is_external_input: bool = field(default=False)
originating_functions: Optional[Tuple[Callable, ...]] = None

@staticmethod
def from_node(n: node.Node) -> "Variable":
"""Creates a Variable from a Node.

:param n: Node to create the Variable from.
:return: Variable created from the Node.
"""
return Variable(
name=n.name,
type=n.type,
tags=n.tags,
is_external_input=n.user_defined,
originating_functions=n.originating_functions,
)


class Driver(object):
Expand Down Expand Up @@ -389,10 +405,7 @@ def list_available_variables(self) -> List[Variable]:

:return: list of available variables (i.e. outputs).
"""
return [
Variable(node.name, node.type, node.tags, node.user_defined)
for node in self.graph.get_nodes()
]
return [Variable.from_node(n) for n in self.graph.get_nodes()]

@capture_function_usage
def display_all_functions(
Expand Down Expand Up @@ -477,10 +490,7 @@ def what_is_downstream_of(self, *node_names: str) -> List[Variable]:
in function names.
"""
downstream_nodes = self.graph.get_impacted_nodes(list(node_names))
return [
Variable(node.name, node.type, node.tags, node.user_defined)
for node in downstream_nodes
]
return [Variable.from_node(n) for n in downstream_nodes]

@capture_function_usage
def display_downstream_of(
Expand Down Expand Up @@ -522,9 +532,7 @@ def what_is_upstream_of(self, *node_names: str) -> List[Variable]:
in function names.
"""
upstream_nodes, _ = self.graph.get_upstream_nodes(list(node_names))
return [
Variable(node.name, node.type, node.tags, node.user_defined) for node in upstream_nodes
]
return [Variable.from_node(n) for n in upstream_nodes]


if __name__ == "__main__":
Expand Down
29 changes: 28 additions & 1 deletion hamilton/function_modifiers/base.py
Expand Up @@ -647,6 +647,33 @@ def get_node_decorators(
return defaults


def _add_original_function_to_nodes(fn: Callable, nodes: List[node.Node]) -> List[node.Node]:
"""Adds the original function to the nodes. We do this so that we can have appropriate metadata
on the function -- this is valuable to see if/how the function changes over time to manage node
versions, etc...

Note that this will add it so the "external" function is always last. They *should* correspond
to namespaces, but this is not

This is not mutating them, rather
copying them with the original function. If it gets slow we *can* mutate them, but
this is just another O(n) operation so I'm not concerned.


:param fn: The function to add
:param nodes: The nodes to add it to
:return: The nodes with the function added
"""
out = []
for node_ in nodes:
current_originating_functions = node_.originating_functions
new_originating_functions = (
current_originating_functions if current_originating_functions is not None else ()
) + (fn,)
out.append(node_.copy_with(originating_functions=new_originating_functions))
return out


def resolve_nodes(fn: Callable, config: Dict[str, Any]) -> Collection[node.Node]:
"""Gets a list of nodes from a function. This is meant to be an abstraction between the node
and the function that it implements. This will end up coordinating with the decorators we build
Expand Down Expand Up @@ -696,7 +723,7 @@ def resolve_nodes(fn: Callable, config: Dict[str, Any]) -> Collection[node.Node]
function_decorators = function_decorators[NodeDecorator.get_lifecycle_name()]
for node_decorator in function_decorators:
nodes = node_decorator.transform_dag(nodes, filter_config(config, node_decorator), fn)
return nodes
return _add_original_function_to_nodes(fn, nodes)
except InvalidDecoratorException as e:
raise InvalidDecoratorException(f"Invalid decorator {e} for function {fn.__name__}.") from e

Expand Down
14 changes: 13 additions & 1 deletion hamilton/node.py
@@ -1,7 +1,7 @@
import inspect
import typing
from enum import Enum
from typing import Any, Callable, Dict, List, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

"""
Module that contains the primitive components of the graph.
Expand Down Expand Up @@ -46,6 +46,7 @@ def __init__(
input_types: Dict[str, Union[Type, Tuple[Type, DependencyType]]] = None,
tags: Dict[str, Any] = None,
namespace: Tuple[str, ...] = (),
originating_functions: Optional[Tuple[Callable, ...]] = None,
):
"""Constructor for our Node object.

Expand All @@ -71,6 +72,7 @@ def __init__(
self._depended_on_by = []
self._namespace = namespace
self._input_types = {}
self._originating_functions = originating_functions

if self._node_source == NodeSource.STANDARD:
if input_types is not None:
Expand Down Expand Up @@ -145,6 +147,16 @@ def depended_on_by(self) -> List["Node"]:
def tags(self) -> Dict[str, str]:
return self._tags

@property
def originating_functions(self) -> Optional[Tuple[Callable, ...]]:
"""Gives all functions from which this node was created. None if the data
is not available (it is user-defined, or we have not added it yet). Note that this can be
multiple in the case of subdags (the subdag function + the other function). In that case,

:return: A Tuple consisting of functions from which this node was created.
"""
return self._originating_functions

def add_tag(self, tag_name: str, tag_value: str):
self._tags[tag_name] = tag_value

Expand Down
10 changes: 10 additions & 0 deletions tests/function_modifiers/test_base.py
Expand Up @@ -139,3 +139,13 @@ def test_select_nodes_happy(
def test_select_nodes_sad(target: TargetType, nodes: Collection[node.Node]):
with pytest.raises(InvalidDecoratorException):
NodeTransformer.select_nodes(target, nodes)


def test_add_fn_metadata():
nodes_og = _create_node_set({"d": ["e"]})
nodes = base._add_original_function_to_nodes(test_add_fn_metadata, nodes_og)
nodes_with_fn_pointer = [
n.originating_functions for n in nodes if n.originating_functions is not None
]
assert len(nodes_with_fn_pointer) == len(nodes)
assert all([n.originating_functions == (test_add_fn_metadata,) for n in nodes])
13 changes: 11 additions & 2 deletions tests/test_hamilton_driver.py
Expand Up @@ -48,7 +48,7 @@ def test_driver_cycles_execute_recursion_error():
dr.execute(["C"], inputs={"b": 2, "c": 2})


def test_driver_variables():
def test_driver_variables_exposes_tags():
dr = Driver({}, tests.resources.tagging)
tags = {var.name: var.tags for var in dr.list_available_variables()}
assert tags["a"] == {"module": "tests.resources.tagging", "test": "a"}
Expand All @@ -57,13 +57,22 @@ def test_driver_variables():
assert tags["d"] == {"module": "tests.resources.tagging"}


def test_driver_external_input():
def test_driver_variables_external_input():
dr = Driver({}, tests.resources.very_simple_dag)
input_types = {var.name: var.is_external_input for var in dr.list_available_variables()}
assert input_types["a"] is True
assert input_types["b"] is False


def test_driver_variables_exposes_original_function():
dr = Driver({}, tests.resources.very_simple_dag)
originating_functions = {
var.name: var.originating_functions for var in dr.list_available_variables()
}
assert originating_functions["b"] == (tests.resources.very_simple_dag.b,)
assert originating_functions["a"] is None


@mock.patch("hamilton.telemetry.send_event_json")
def test_capture_constructor_telemetry_disabled(send_event_json):
"""Tests that we don't do anything if telemetry is disabled."""
Expand Down