Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enables the materializers to depend on functions & driver variables #431

Merged
58 changes: 58 additions & 0 deletions hamilton/common/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# code in this module should not depend on much
from typing import Any, Callable, List, Optional, Set, Tuple, Union


def convert_output_value(
    output_value: Union[str, Callable, Any], module_set: Set[str]
) -> Tuple[Optional[str], Optional[str]]:
    """Converts a single requested output value into a string node name.

    Strings pass through unchanged; objects exposing a ``.name`` attribute
    (e.g. driver.Variable -- not annotated here for import reasons) resolve
    to that name; callables resolve to their ``__name__``, provided their
    module is in the passed-in module set.

    :param output_value: the value we want to convert into a string.
    :param module_set: the set of module names functions could come from.
    :return: a tuple (string value, string error). Exactly one is None.
    """
    if isinstance(output_value, str):
        return output_value, None
    if hasattr(output_value, "name"):
        # driver.Variable (or anything Variable-like) -- use its node name.
        return output_value.name, None
    if callable(output_value):
        if output_value.__module__ in module_set:
            return output_value.__name__, None
        return None, (
            f"Function {output_value.__module__}.{output_value.__name__} is a function not "
            f"in a "
            f"module given to the materializer. Valid choices are {module_set}."
        )
    return None, (
        f"Materializer dependency {output_value} is not a string, a function, or a driver.Variable."
    )


def convert_output_values(
    output_values: List[Union[str, Callable, Any]], module_set: Set[str]
) -> List[str]:
    """Checks & converts output values to strings. This is used in building dependencies for the DAG.

    :param output_values: the values to convert.
    :param module_set: the set of module names any functions could come from.
    :return: the final string values, in input order.
    :raises ValueError: if there are values that can't be used/converted.
    """
    final_values = []
    errors = []
    for final_var in output_values:
        _val, _error = convert_output_value(final_var, module_set)
        # Compare against None rather than truthiness so a falsy-but-valid
        # name (e.g. "") is kept instead of being silently dropped.
        if _val is not None:
            final_values.append(_val)
        elif _error is not None:
            errors.append(_error)
    if errors:
        errors.sort()
        error_str = f"{len(errors)} errors encountered:\n " + "\n ".join(errors)
        raise ValueError(error_str)
    return final_values
72 changes: 36 additions & 36 deletions hamilton/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import pandas as pd

from hamilton import common
from hamilton.execution import executors, graph_functions, grouping, state
from hamilton.io import materialization

Expand Down Expand Up @@ -419,31 +420,8 @@ def _create_final_vars(self, final_vars: List[Union[str, Callable, Variable]]) -
:param final_vars:
:return: list of strings in the order that final_vars was provided.
"""
_final_vars = []
errors = []
module_set = {_module.__name__ for _module in self.graph_modules}
for final_var in final_vars:
if isinstance(final_var, str):
_final_vars.append(final_var)
elif isinstance(final_var, Variable):
_final_vars.append(final_var.name)
elif isinstance(final_var, Callable):
if final_var.__module__ in module_set:
_final_vars.append(final_var.__name__)
else:
errors.append(
f"Function {final_var.__module__}.{final_var.__name__} is a function not "
f"in a "
f"module given to the driver. Valid choices are {module_set}."
)
else:
errors.append(
f"Final var {final_var} is not a string, a function, or a driver.Variable."
)
if errors:
errors.sort()
error_str = f"{len(errors)} errors encountered:\n " + "\n ".join(errors)
raise ValueError(error_str)
_module_set = {_module.__name__ for _module in self.graph_modules}
_final_vars = common.convert_output_values(final_vars, _module_set)
return _final_vars

def capture_execute_telemetry(
Expand Down Expand Up @@ -872,6 +850,7 @@ def visualize_path_between(
except ImportError as e:
logger.warning(f"Unable to import {e}", exc_info=True)

@capture_function_usage
def materialize(
self,
*materializers: materialization.MaterializerFactory,
Expand Down Expand Up @@ -997,24 +976,42 @@ def materialize(
:param additional_vars: Additional variables to return from the graph
:param overrides: Overrides to pass to execution
:param inputs: Inputs to pass to execution
:return: Tuple[Materialization metadata, additional_vars result]
:return: Tuple[Materialization metadata|data, additional_vars result]
"""
if additional_vars is None:
additional_vars = []
function_graph = materialization.modify_graph(self.graph, materializers)
start_time = time.time()
run_successful = True
error = None

final_vars = self._create_final_vars(additional_vars)
materializer_vars = [materializer.id for materializer in materializers]
raw_results = self.graph_executor.execute(
function_graph,
final_vars=final_vars + materializer_vars,
overrides=overrides,
inputs=inputs,
)
materialization_output = {key: raw_results[key] for key in materializer_vars}
raw_results_output = {key: raw_results[key] for key in final_vars}
try:
module_set = {_module.__name__ for _module in self.graph_modules}
materializers = [m.sanitize_dependencies(module_set) for m in materializers]
function_graph = materialization.modify_graph(self.graph, materializers)
raw_results = self.graph_executor.execute(
function_graph,
final_vars=final_vars + materializer_vars,
overrides=overrides,
inputs=inputs,
)
materialization_output = {key: raw_results[key] for key in materializer_vars}
raw_results_output = {key: raw_results[key] for key in final_vars}

return materialization_output, raw_results_output
return materialization_output, raw_results_output
except Exception as e:
run_successful = False
logger.error(SLACK_ERROR_MESSAGE)
error = telemetry.sanitize_error(*sys.exc_info())
raise e
finally:
duration = time.time() - start_time
self.capture_execute_telemetry(
error, final_vars + materializer_vars, inputs, overrides, run_successful, duration
)

@capture_function_usage
def visualize_materialization(
self,
*materializers: materialization.MaterializerFactory,
Expand All @@ -1039,6 +1036,9 @@ def visualize_materialization(
"""
if additional_vars is None:
additional_vars = []

module_set = {_module.__name__ for _module in self.graph_modules}
materializers = [m.sanitize_dependencies(module_set) for m in materializers]
function_graph = materialization.modify_graph(self.graph, materializers)
_final_vars = self._create_final_vars(additional_vars) + [
materializer.id for materializer in materializers
Expand Down
18 changes: 15 additions & 3 deletions hamilton/function_modifiers/adapters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import inspect
import logging
import typing
from typing import Any, Callable, Collection, Dict, List, Tuple, Type
from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Type

from hamilton import node
from hamilton.function_modifiers.base import (
Expand All @@ -17,6 +18,8 @@
from hamilton.node import DependencyType
from hamilton.registry import LOADER_REGISTRY, SAVER_REGISTRY

logger = logging.getLogger(__name__)


class AdapterFactory:
"""Factory for data loaders. This handles the fact that we pass in source(...) and value(...)
Expand Down Expand Up @@ -93,16 +96,25 @@ def resolve_kwargs(kwargs: Dict[str, Any]) -> Tuple[Dict[str, str], Dict[str, An

def resolve_adapter_class(
    type_: Type[Type], loader_classes: List[Type[AdapterCommon]]
) -> Optional[Type[AdapterCommon]]:
    """Resolves the adapter class to use for a given injection type.

    This will return the most recently registered adapter class that applies
    to the injection type, hence the reversed iteration order. Returns None
    when no registered adapter applies.

    :param type_: Type to resolve an adapter for.
    :param loader_classes: Registered adapter classes, in registration order.
    :return: The adapter class to use, or None if none applies.
    """
    applicable_adapters: List[Type[AdapterCommon]] = []
    for loader_cls in reversed(loader_classes):
        if loader_cls.applies_to(type_):
            applicable_adapters.append(loader_cls)
    if applicable_adapters:
        if len(applicable_adapters) > 1:
            # More than one adapter matches; warn that we pick the newest.
            logger.warning(
                f"More than one applicable adapter detected for {type_}. "
                f"Using the last one registered {applicable_adapters[0]}."
            )
        return applicable_adapters[0]
    return None


class LoadFromDecorator(NodeInjector):
Expand Down
9 changes: 6 additions & 3 deletions hamilton/htypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,18 @@ def custom_subclass_check(requested_type: Type, param_type: Type):

We will likely need to revisit this in the future (perhaps integrate with graphadapter?)

:param requested_type: Candidate subclass
:param param_type: Type of parameter to check
:return: Whether or not this is a valid subclass.
:param requested_type: Candidate subclass.
:param param_type: Type of parameter to check against.
:return: Whether or not requested_type is a valid subclass of param_type.
"""
# handles case when someone is using primitives and generics
requested_origin_type = requested_type
param_type, _ = get_type_information(param_type)
param_origin_type = param_type
has_generic = False
if param_type == Any:
# any type is a valid subclass of Any.
return True
if _safe_subclass(requested_type, param_type):
return True
if typing_inspect.is_union_type(param_type):
Expand Down
55 changes: 47 additions & 8 deletions hamilton/io/data_adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,21 @@ def applicable_types(cls) -> Collection[Type]:
pass

@classmethod
@abc.abstractmethod
def applies_to(cls, type_: Type[Type]) -> bool:
"""Tells whether or not this data loader can load to a specific type.
For instance, a CSV data loader might be able to load to a dataframe,
a json, but not an integer.
"""Tells whether or not this adapter applies to the given type.

Note: you need to understand the edge direction to properly determine applicability.
For loading data, the loader type needs to be a subclass of the type being loaded into.
For saving data, the saver type needs to be a superclass of the type being passed in.

This is a classmethod as it will be easier to validate, and we have to
construct this, delayed, with a factory.

:param type_: Candidate type
:return: True if this data loader can load to the type, False otherwise.
:return: True if this adapter can be used with that type, False otherwise.
"""
for load_to in cls.applicable_types():
if custom_subclass_check(load_to, type_):
return True
return False
pass

@classmethod
@abc.abstractmethod
Expand Down Expand Up @@ -137,6 +137,26 @@ def load_data(self, type_: Type[Type]) -> Tuple[Type, Dict[str, Any]]:
def can_load(cls) -> bool:
return True

@classmethod
def applies_to(cls, type_: Type[Type]) -> bool:
    """Tells whether this data loader can load to a specific type.

    For instance, a CSV data loader might be able to load to a dataframe,
    a json, but not an integer. Put differently: is some declared adapter
    type a subclass of the passed-in type?

    This is a classmethod as it will be easier to validate, and we have to
    construct this, delayed, with a factory.

    :param type_: Candidate type
    :return: True if this data loader can load to the type, False otherwise.
    """
    # Short-circuits on the first subclass-compatible declared type.
    return any(custom_subclass_check(load_to, type_) for load_to in cls.applicable_types())


class DataSaver(AdapterCommon, abc.ABC):
"""Base class for data savers. Data savers are used to save data to a data source.
Expand Down Expand Up @@ -165,3 +185,22 @@ def save_data(self, data: Any) -> Dict[str, Any]:
@classmethod
def can_save(cls) -> bool:
return True

@classmethod
def applies_to(cls, type_: Type[Type]) -> bool:
    """Tells whether this data saver can ingest a specific type to save it.

    I.e. is the adapter's declared type a superclass of the passed-in type?

    This is a classmethod as it will be easier to validate, and we have to
    construct this, delayed, with a factory.

    :param type_: Candidate type
    :return: True if this data saver can handle the type, False otherwise.
    """
    # Argument order is flipped vs. the loader check: type_ must be a
    # subclass of the declared save_to type.
    return any(custom_subclass_check(type_, save_to) for save_to in cls.applicable_types())
Loading