Skip to content

Commit

Permalink
Add retry mechanism and dataframe support for WeaviateIngestOperator (#…
Browse files Browse the repository at this point in the history
…36085)

* Add retry and dataframe support

Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>
  • Loading branch information
utkarsharma2 and uranusjr committed Dec 8, 2023
1 parent 4824ca7 commit a8333b7
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 26 deletions.
60 changes: 51 additions & 9 deletions airflow/providers/weaviate/hooks/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@

from __future__ import annotations

import contextlib
import json
import warnings
from functools import cached_property
from typing import TYPE_CHECKING, Sequence
from typing import TYPE_CHECKING, Any, Dict, List, cast

import requests
from tenacity import Retrying, retry_if_exception, stop_after_attempt
from weaviate import Client as WeaviateClient
from weaviate.auth import AuthApiKey, AuthBearerToken, AuthClientCredentials, AuthClientPassword
from weaviate.exceptions import ObjectAlreadyExistsException
Expand All @@ -30,7 +34,7 @@
from airflow.hooks.base import BaseHook

if TYPE_CHECKING:
from typing import Any
from typing import Sequence

import pandas as pd
from weaviate import ConsistencyLevel
Expand Down Expand Up @@ -144,22 +148,60 @@ def create_schema(self, schema_json: dict[str, Any]) -> None:
client = self.conn
client.schema.create(schema_json)

@staticmethod
def check_http_error_should_retry(exc: BaseException) -> bool:
    """Tell the retry policy whether a failed insert should be attempted again.

    :param exc: exception raised by the attempted operation.
    :return: ``True`` only when ``exc`` is a ``requests.HTTPError`` that carries a response whose
        ``ok`` flag is false; ``False`` for any other exception.
    """
    if not isinstance(exc, requests.HTTPError):
        return False
    response = exc.response
    # An HTTPError may be created without a response attached. Reading ``.ok`` on ``None`` would
    # raise AttributeError from inside the retry predicate and hide the original error, so treat
    # that case as "do not retry". The explicit ``is not None`` matters: the truthiness of a
    # ``requests.Response`` is its ``ok`` flag, so a plain ``if response`` would be wrong here.
    return response is not None and not response.ok

@staticmethod
def _convert_dataframe_to_list(data: list[dict[str, Any]] | pd.DataFrame) -> list[dict[str, Any]]:
    """Return ``data`` as a list of dicts, converting it first when it is a pandas DataFrame.

    pandas is imported lazily inside this function so that callers who pass a plain list of
    dictionaries do not need pandas installed: if the import fails, ``data`` is returned as given.

    :param data: list of dicts, or a DataFrame with one object per row.
    :return: the input list itself, or the DataFrame rows as a list of dicts (one per row).
    """
    try:
        import pandas

        if isinstance(data, pandas.DataFrame):
            # Round-trip through JSON so every value ends up as a plain JSON-compatible type.
            records_json = data.to_json(orient="records")
            data = json.loads(records_json)
    except ImportError:
        # pandas is unavailable, so ``data`` cannot be a DataFrame; hand it back unchanged.
        pass
    return cast(List[Dict[str, Any]], data)

def batch_data(
    self,
    class_name: str,
    data: list[dict[str, Any]] | pd.DataFrame,
    batch_config_params: dict[str, Any] | None = None,
    vector_col: str = "Vector",
    retry_attempts_per_object: int = 5,
) -> None:
    """
    Add multiple objects or object references at once into weaviate.

    :param class_name: The name of the class that the objects belong to.
    :param data: list or dataframe of objects we want to add.
    :param batch_config_params: dict of batch configuration options.
        .. seealso:: `batch_config_params options <https://weaviate-python-client.readthedocs.io/en/v3.25.3/weaviate.batch.html#weaviate.batch.Batch.configure>`__
    :param vector_col: name of the column containing the vector.
    :param retry_attempts_per_object: number of times to try in case of failure before giving up.
    """
    client = self.conn
    if not batch_config_params:
        batch_config_params = {}
    client.batch.configure(**batch_config_params)
    data = self._convert_dataframe_to_list(data)
    with client.batch as batch:
        # Batch import all data
        for index, data_obj in enumerate(data):
            # Separate the vector from the properties once, before the retry loop. Popping inside
            # the loop would find the key already removed on a retry and send ``vector=None``.
            # Work on a shallow copy so the caller's dictionaries are left untouched.
            properties = dict(data_obj)
            vector = properties.pop(vector_col, None)
            for attempt in Retrying(
                stop=stop_after_attempt(retry_attempts_per_object),
                retry=retry_if_exception(self.check_http_error_should_retry),
            ):
                with attempt:
                    self.log.debug(
                        "Attempt %s of importing data: %s", attempt.retry_state.attempt_number, index + 1
                    )
                    batch.add_data_object(properties, class_name, vector=vector)

def delete_class(self, class_name: str) -> None:
"""Delete an existing class."""
Expand Down
44 changes: 35 additions & 9 deletions airflow/providers/weaviate/operators/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,17 @@

from __future__ import annotations

import warnings
from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.models import BaseOperator
from airflow.providers.weaviate.hooks.weaviate import WeaviateHook

if TYPE_CHECKING:
import pandas as pd

from airflow.utils.context import Context


Expand All @@ -35,14 +39,16 @@ class WeaviateIngestOperator(BaseOperator):
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:WeaviateIngestOperator`
Operator that accepts input json to generate embeddings on or accepting provided custom vectors
and store them in the Weaviate class.
Operator that accepts input json or pandas dataframe to generate embeddings on or accepting provided
custom vectors and store them in the Weaviate class.
:param conn_id: The Weaviate connection.
:param class_name: The Weaviate class to be used for storing the data objects into.
:param input_json: The JSON representing Weaviate data objects to generate embeddings on (or provides
custom vectors) and store them in the Weaviate class. Either input_json or input_callable should be
provided.
:param input_data: The list of dicts or pandas dataframe representing Weaviate data objects to generate
embeddings on (or provides custom vectors) and store them in the Weaviate class.
:param input_json: (Deprecated) The JSON representing Weaviate data objects to generate embeddings on (or provides
custom vectors) and store them in the Weaviate class.
:param vector_col: key/column name in which the vectors are stored.
"""

# ``__init__`` stores the payload on ``self.input_data`` whether it is passed as ``input_data`` or
# as the deprecated ``input_json``, so ``input_data`` is the attribute that must be templated.
template_fields: Sequence[str] = ("input_data",)
Expand All @@ -51,21 +57,41 @@ def __init__(
self,
conn_id: str,
class_name: str,
input_json: list[dict[str, Any]],
input_json: list[dict[str, Any]] | pd.DataFrame | None = None,
input_data: list[dict[str, Any]] | pd.DataFrame | None = None,
vector_col: str = "Vector",
**kwargs: Any,
) -> None:
self.batch_params = kwargs.pop("batch_params", {})
self.hook_params = kwargs.pop("hook_params", {})

super().__init__(**kwargs)
self.class_name = class_name
self.conn_id = conn_id
self.input_json = input_json
self.vector_col = vector_col

if input_data is not None:
self.input_data = input_data
elif input_json is not None:
warnings.warn(
"Passing 'input_json' to WeaviateIngestOperator is deprecated and"
" you should use 'input_data' instead",
AirflowProviderDeprecationWarning,
)
self.input_data = input_json
else:
raise TypeError("Either input_json or input_data is required")

@cached_property
def hook(self) -> WeaviateHook:
    """Build the ``WeaviateHook`` for ``conn_id``, forwarding ``hook_params`` as keyword arguments."""
    connection_id = self.conn_id
    return WeaviateHook(conn_id=connection_id, **self.hook_params)

def execute(self, context: Context) -> None:
    """Send ``self.input_data`` to the hook's ``batch_data`` for the configured class.

    ``batch_params`` are forwarded as keyword arguments, together with ``vector_col``.

    :param context: Airflow task context; not used by this method.
    """
    self.log.debug("Input data: %s", self.input_data)
    self.hook.batch_data(
        self.class_name,
        self.input_data,
        **self.batch_params,
        vector_col=self.vector_col,
    )
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ into the database.
Using the Operator
^^^^^^^^^^^^^^^^^^

The WeaviateIngestOperator requires the ``input_text`` as an input to the operator. Use the ``conn_id`` parameter to specify the Weaviate connection to use to
The WeaviateIngestOperator requires ``input_data`` as an input to the operator. Use the ``conn_id`` parameter to specify the Weaviate connection used to
connect to your account.

An example using the operator to ingest data with custom vectors retrieved from XCOM:
Expand Down
31 changes: 27 additions & 4 deletions tests/providers/weaviate/hooks/test_weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
from unittest import mock
from unittest.mock import MagicMock, Mock, patch

import pandas as pd
import pytest
import requests
from weaviate import ObjectAlreadyExistsException

from airflow.models import Connection
Expand Down Expand Up @@ -404,7 +406,15 @@ def test_create_schema(weaviate_hook):
mock_client.schema.create.assert_called_once_with(test_schema_json)


def test_batch_data(weaviate_hook):
@pytest.mark.parametrize(
argnames=["data", "expected_length"],
argvalues=[
([{"name": "John"}, {"name": "Jane"}], 2),
(pd.DataFrame.from_dict({"name": ["John", "Jane"]}), 2),
],
ids=("data as list of dicts", "data as dataframe"),
)
def test_batch_data(data, expected_length, weaviate_hook):
"""
Test the batch_data method of WeaviateHook.
"""
Expand All @@ -414,12 +424,25 @@ def test_batch_data(weaviate_hook):

# Define test data
test_class_name = "TestClass"
test_data = [{"name": "John"}, {"name": "Jane"}]

# Test the batch_data method
weaviate_hook.batch_data(test_class_name, test_data)
weaviate_hook.batch_data(test_class_name, data)

# Assert that the batch_data method was called with the correct arguments
mock_client.batch.configure.assert_called_once()
mock_batch_context = mock_client.batch.__enter__.return_value
assert mock_batch_context.add_data_object.call_count == len(test_data)
assert mock_batch_context.add_data_object.call_count == expected_length


@patch("airflow.providers.weaviate.hooks.weaviate.WeaviateHook.get_conn")
def test_batch_data_retry(get_conn, weaviate_hook):
    """Check that ``batch_data`` retries an object whose insert raises an HTTP 429 error."""
    rows = [{"name": "chandler"}, {"name": "joey"}, {"name": "ross"}]

    too_many_requests = requests.Response()
    too_many_requests.status_code = 429
    http_error = requests.exceptions.HTTPError()
    http_error.response = too_many_requests

    # Three objects, of which the second and third each fail once before succeeding,
    # so five ``add_data_object`` calls are expected in total.
    outcomes = [None, http_error, None, http_error, None]
    add_data_object = get_conn.return_value.batch.__enter__.return_value.add_data_object
    add_data_object.side_effect = outcomes

    weaviate_hook.batch_data("TestClass", rows)

    assert add_data_object.call_count == len(outcomes)
8 changes: 5 additions & 3 deletions tests/providers/weaviate/operators/test_weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def operator(self):
def test_constructor(self, operator):
    """Verify the constructor stores its arguments and defaults the optional parameter dicts."""
    assert operator.conn_id == "weaviate_conn"
    assert operator.class_name == "my_class"
    # The payload is exposed as ``input_data``, so that is the attribute checked here.
    assert operator.input_data == {"data": "sample_data"}
    assert operator.batch_params == {}
    assert operator.hook_params == {}

Expand All @@ -46,5 +46,7 @@ def test_execute_with_input_json(self, mock_log, operator):

operator.execute(context=None)

operator.hook.batch_data.assert_called_once_with("my_class", {"data": "sample_data"}, **{})
mock_log.debug.assert_called_once_with("Input json: %s", {"data": "sample_data"})
operator.hook.batch_data.assert_called_once_with(
"my_class", {"data": "sample_data"}, vector_col="Vector", **{}
)
mock_log.debug.assert_called_once_with("Input data: %s", {"data": "sample_data"})

0 comments on commit a8333b7

Please sign in to comment.