Skip to content

Commit

Permalink
Remove 'insertion_errors' as required argument (#36435)
Browse files Browse the repository at this point in the history
  • Loading branch information
utkarsharma2 committed Dec 26, 2023
1 parent e393889 commit 8850715
Show file tree
Hide file tree
Showing 4 changed files with 3 additions and 7 deletions.
4 changes: 1 addition & 3 deletions airflow/providers/weaviate/hooks/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,6 @@ def batch_data(
self,
class_name: str,
data: list[dict[str, Any]] | pd.DataFrame,
insertion_errors: list,
batch_config_params: dict[str, Any] | None = None,
vector_col: str = "Vector",
uuid_col: str = "id",
Expand All @@ -397,7 +396,6 @@ def batch_data(
:param class_name: The name of the class that objects belongs to.
:param data: list or dataframe of objects we want to add.
:param insertion_errors: list to hold errors while inserting.
:param batch_config_params: dict of batch configuration option.
.. seealso:: `batch_config_params options <https://weaviate-python-client.readthedocs.io/en/v3.25.3/weaviate.batch.html#weaviate.batch.Batch.configure>`__
:param vector_col: name of the column containing the vector.
Expand All @@ -408,6 +406,7 @@ def batch_data(
data = self._convert_dataframe_to_list(data)
total_results = 0
error_results = 0
insertion_errors: list = []

def _process_batch_errors(
results: list,
Expand Down Expand Up @@ -1070,7 +1069,6 @@ def create_or_replace_document_objects(
insertion_errors = self.batch_data(
class_name=class_name,
data=data,
insertion_errors=insertion_errors,
batch_config_params=batch_config_params,
vector_col=vector_column,
uuid_col=uuid_column,
Expand Down
1 change: 0 additions & 1 deletion airflow/providers/weaviate/operators/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def execute(self, context: Context) -> list:
data=self.input_data,
batch_config_params=self.batch_params,
vector_col=self.vector_col,
insertion_errors=insertion_errors,
uuid_col=self.uuid_column,
tenant=self.tenant,
)
Expand Down
4 changes: 2 additions & 2 deletions tests/providers/weaviate/hooks/test_weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ def test_batch_data(data, expected_length, weaviate_hook):
test_class_name = "TestClass"

# Test the batch_data method
weaviate_hook.batch_data(test_class_name, data, insertion_errors=[])
weaviate_hook.batch_data(test_class_name, data)

# Assert that the batch_data method was called with the correct arguments
mock_client.batch.configure.assert_called_once()
Expand All @@ -446,7 +446,7 @@ def test_batch_data_retry(get_conn, weaviate_hook):
error.response = response
side_effect = [None, error, None, error, None]
get_conn.return_value.batch.__enter__.return_value.add_data_object.side_effect = side_effect
weaviate_hook.batch_data("TestClass", data, insertion_errors=[])
weaviate_hook.batch_data("TestClass", data)
assert get_conn.return_value.batch.__enter__.return_value.add_data_object.call_count == len(side_effect)


Expand Down
1 change: 0 additions & 1 deletion tests/providers/weaviate/operators/test_weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def test_execute_with_input_json(self, mock_log, operator):
data=[{"data": "sample_data"}],
batch_config_params={},
vector_col="Vector",
insertion_errors=[],
uuid_col="id",
tenant=None,
)
Expand Down

0 comments on commit 8850715

Please sign in to comment.