Skip to content

Commit

Permalink
Fix WeaviateIngestOperator/WeaviateDocumentIngestOperator arguments in `MappedOperator` (#38402)
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis committed Mar 26, 2024
1 parent a3f7ddd commit f4bd0b3
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 10 deletions.
30 changes: 20 additions & 10 deletions airflow/providers/weaviate/operators/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,12 @@ class WeaviateIngestOperator(BaseOperator):
:param class_name: The Weaviate class to be used for storing the data objects into.
:param input_data: The list of dicts or pandas dataframe representing Weaviate data objects to generate
embeddings on (or provides custom vectors) and store them in the Weaviate class.
:param input_json: (Deprecated) The JSON representing Weaviate data objects to generate embeddings on (or provides
custom vectors) and store them in the Weaviate class.
:param vector_col: key/column name in which the vectors are stored.
:param batch_params: Additional parameters for Weaviate batch configuration.
:param hook_params: Optional config params to be passed to the underlying hook.
Should match the desired hook constructor params.
:param input_json: (Deprecated) The JSON representing Weaviate data objects to generate embeddings on
(or provides custom vectors) and store them in the Weaviate class.
"""

template_fields: Sequence[str] = ("input_json", "input_data")
Expand All @@ -57,16 +60,15 @@ def __init__(
self,
conn_id: str,
class_name: str,
input_json: list[dict[str, Any]] | pd.DataFrame | None = None,
input_data: list[dict[str, Any]] | pd.DataFrame | None = None,
vector_col: str = "Vector",
uuid_column: str = "id",
tenant: str | None = None,
batch_params: dict | None = None,
hook_params: dict | None = None,
input_json: list[dict[str, Any]] | pd.DataFrame | None = None,
**kwargs: Any,
) -> None:
self.batch_params = kwargs.pop("batch_params", {})
self.hook_params = kwargs.pop("hook_params", {})

super().__init__(**kwargs)
self.class_name = class_name
self.conn_id = conn_id
Expand All @@ -75,6 +77,9 @@ def __init__(
self.uuid_column = uuid_column
self.tenant = tenant
self.input_data = input_data
self.batch_params = batch_params or {}
self.hook_params = hook_params or {}

if (self.input_data is None) and (input_json is not None):
warnings.warn(
"Passing 'input_json' to WeaviateIngestOperator is deprecated and"
Expand Down Expand Up @@ -135,7 +140,8 @@ class WeaviateDocumentIngestOperator(BaseOperator):
:param batch_config_params: Additional parameters for Weaviate batch configuration.
:param tenant: The tenant to which the object will be added.
:param verbose: Flag to enable verbose output during the ingestion process.
:return: list of UUID which failed to create
:param hook_params: Optional config params to be passed to the underlying hook.
Should match the desired hook constructor params.
"""

template_fields: Sequence[str] = ("input_data",)
Expand All @@ -152,12 +158,10 @@ def __init__(
batch_config_params: dict | None = None,
tenant: str | None = None,
verbose: bool = False,
hook_params: dict | None = None,
**kwargs: Any,
) -> None:
self.hook_params = kwargs.pop("hook_params", {})

super().__init__(**kwargs)

self.conn_id = conn_id
self.input_data = input_data
self.class_name = class_name
Expand All @@ -168,13 +172,19 @@ def __init__(
self.batch_config_params = batch_config_params
self.tenant = tenant
self.verbose = verbose
self.hook_params = hook_params or {}

@cached_property
def hook(self) -> WeaviateHook:
"""Return an instance of the WeaviateHook."""
return WeaviateHook(conn_id=self.conn_id, **self.hook_params)

def execute(self, context: Context) -> list:
"""
Create or replace objects belonging to documents.
:return: List of UUID which failed to create
"""
self.log.debug("Total input objects : %s", len(self.input_data))
insertion_errors = self.hook.create_or_replace_document_objects(
data=self.input_data,
Expand Down
39 changes: 39 additions & 0 deletions tests/providers/weaviate/operators/test_weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

import pytest

from airflow.utils.task_instance_session import set_current_task_instance_session

pytest.importorskip("weaviate")

from airflow.providers.weaviate.operators.weaviate import ( # noqa: E402
Expand Down Expand Up @@ -78,6 +80,25 @@ def test_templates(self, create_task_instance_of_operator):
assert dag_id == ti.task.input_json
assert dag_id == ti.task.input_data

@pytest.mark.db_test
def test_partial_batch_hook_params(self, dag_maker, session):
with dag_maker(dag_id="test_partial_batch_hook_params", session=session):
WeaviateIngestOperator.partial(
task_id="fake-task-id",
conn_id="weaviate_conn",
class_name="FooBar",
batch_params={"spam": "egg"},
hook_params={"baz": "biz"},
).expand(input_data=[{}, {}])

dr = dag_maker.create_dagrun()
tis = dr.get_task_instances(session=session)
with set_current_task_instance_session(session=session):
for ti in tis:
ti.render_templates()
assert ti.task.batch_params == {"spam": "egg"}
assert ti.task.hook_params == {"baz": "biz"}


class TestWeaviateDocumentIngestOperator:
@pytest.fixture
Expand Down Expand Up @@ -123,3 +144,21 @@ def test_execute_with_input_json(self, mock_log, operator):
verbose=False,
)
mock_log.debug.assert_called_once_with("Total input objects : %s", len([{"data": "sample_data"}]))

@pytest.mark.db_test
def test_partial_hook_params(self, dag_maker, session):
with dag_maker(dag_id="test_partial_hook_params", session=session):
WeaviateDocumentIngestOperator.partial(
task_id="fake-task-id",
conn_id="weaviate_conn",
class_name="FooBar",
document_column="spam-egg",
hook_params={"baz": "biz"},
).expand(input_data=[{}, {}])

dr = dag_maker.create_dagrun()
tis = dr.get_task_instances(session=session)
with set_current_task_instance_session(session=session):
for ti in tis:
ti.render_templates()
assert ti.task.hook_params == {"baz": "biz"}

0 comments on commit f4bd0b3

Please sign in to comment.