Skip to content

Commit

Permalink
Add a cache for weaviate client (#35983)
Browse files Browse the repository at this point in the history
* Add a cache for weaviate client

While working on another issue, I realized how often I had to call get_conn.
So instead of depreccating this, we can use it as a cache within the code so we
don't connect everytime a method is called.

* change cache to be on _conn
  • Loading branch information
ephraimbuddy authored Dec 1, 2023
1 parent 4117f1b commit 8be03c9
Showing 1 changed file with 15 additions and 8 deletions.
23 changes: 15 additions & 8 deletions airflow/providers/weaviate/hooks/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

import warnings
from functools import cached_property
from typing import Any

from weaviate import Client as WeaviateClient
Expand Down Expand Up @@ -94,18 +95,24 @@ def get_conn(self) -> WeaviateClient:
url=url, auth_client_secret=auth_client_secret, additional_headers=additional_headers
)

@cached_property
def conn(self) -> WeaviateClient:
"""Returns a Weaviate client."""
return self.get_conn()

def get_client(self) -> WeaviateClient:
"""Returns a Weaviate client."""
# Keeping this for backwards compatibility
warnings.warn(
"The `get_client` method has been renamed to `get_conn`",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
return self.get_conn()
return self.conn

def test_connection(self) -> tuple[bool, str]:
try:
client = self.get_client()
client = self.conn
client.schema.get()
return True, "Connection established!"
except Exception as e:
Expand All @@ -114,7 +121,7 @@ def test_connection(self) -> tuple[bool, str]:

def create_class(self, class_json: dict[str, Any]) -> None:
"""Create a new class."""
client = self.get_client()
client = self.conn
client.schema.create_class(class_json)

def create_schema(self, schema_json: dict[str, Any]) -> None:
Expand All @@ -125,13 +132,13 @@ def create_schema(self, schema_json: dict[str, Any]) -> None:
:param schema_json: The schema to create
"""
client = self.get_client()
client = self.conn
client.schema.create(schema_json)

def batch_data(
self, class_name: str, data: list[dict[str, Any]], batch_config_params: dict[str, Any] | None = None
) -> None:
client = self.get_client()
client = self.conn
if not batch_config_params:
batch_config_params = {}
client.batch.configure(**batch_config_params)
Expand All @@ -147,7 +154,7 @@ def batch_data(

def delete_class(self, class_name: str) -> None:
"""Delete an existing class."""
client = self.get_client()
client = self.conn
client.schema.delete_class(class_name)

def query_with_vector(
Expand All @@ -166,7 +173,7 @@ def query_with_vector(
external vectorizer. Weaviate then converts this into a vector through the inference API
(OpenAI in this particular example) and uses that vector as the basis for a vector search.
"""
client = self.get_client()
client = self.conn
results: dict[str, dict[Any, Any]] = (
client.query.get(class_name, properties[0])
.with_near_vector({"vector": embeddings, "certainty": certainty})
Expand All @@ -185,7 +192,7 @@ def query_without_vector(
weaviate with a query search_text. Weaviate then converts this into a vector through the inference
API (OpenAI in this particular example) and uses that vector as the basis for a vector search.
"""
client = self.get_client()
client = self.conn
results: dict[str, dict[Any, Any]] = (
client.query.get(class_name, properties[0])
.with_near_text({"concepts": [search_text]})
Expand Down

0 comments on commit 8be03c9

Please sign in to comment.