Skip to content

Commit

Permalink
Add utils methods in pinecone provider (#35502)
Browse files Browse the repository at this point in the history
* Add missing methods to pinecone provider

* Fix static

* Add words to spelling-wordlist.txt

* Update airflow/providers/pinecone/hooks/pinecone.py

Co-authored-by: Pankaj Singh <98807258+pankajastro@users.noreply.github.com>

* Update airflow/providers/pinecone/hooks/pinecone.py

Co-authored-by: Pankaj Singh <98807258+pankajastro@users.noreply.github.com>

---------

Co-authored-by: Pankaj Singh <98807258+pankajastro@users.noreply.github.com>
  • Loading branch information
utkarsharma2 and pankajastro committed Nov 7, 2023
1 parent b1ee724 commit 65020ee
Show file tree
Hide file tree
Showing 3 changed files with 308 additions and 2 deletions.
219 changes: 217 additions & 2 deletions airflow/providers/pinecone/hooks/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,16 @@
"""Hook for Pinecone."""
from __future__ import annotations

import itertools
from typing import TYPE_CHECKING, Any

import pinecone

from airflow.hooks.base import BaseHook

if TYPE_CHECKING:
from pinecone.core.client.models import UpsertResponse
from pinecone.core.client.model.sparse_values import SparseValues
from pinecone.core.client.models import DescribeIndexStatsResponse, QueryResponse, UpsertResponse


class PineconeHook(BaseHook):
Expand Down Expand Up @@ -86,11 +88,16 @@ def get_conn(self) -> None:

def test_connection(self) -> tuple[bool, str]:
try:
pinecone.list_indexes()
self.list_indexes()
return True, "Connection established"
except Exception as e:
return False, str(e)

@staticmethod
def list_indexes() -> Any:
"""Retrieve a list of all indexes in your project."""
return pinecone.list_indexes()

@staticmethod
def upsert(
index_name: str,
Expand Down Expand Up @@ -126,3 +133,211 @@ def upsert(
show_progress=show_progress,
**kwargs,
)

@staticmethod
def create_index(
index_name: str,
dimension: int,
index_type: str | None = "approximated",
metric: str | None = "cosine",
replicas: int | None = 1,
shards: int | None = 1,
pods: int | None = 1,
pod_type: str | None = "p1",
index_config: dict[str, str] | None = None,
metadata_config: dict[str, str] | None = None,
source_collection: str | None = "",
timeout: int | None = None,
) -> None:
"""
Create a new index.
.. seealso:: https://docs.pinecone.io/reference/create_index/
:param index_name: The name of the index to create.
:param dimension: the dimension of vectors that would be inserted in the index
:param index_type: type of index, one of {"approximated", "exact"}, defaults to "approximated".
:param metric: type of metric used in the vector index, one of {"cosine", "dotproduct", "euclidean"}
:param replicas: the number of replicas, defaults to 1.
:param shards: the number of shards per index, defaults to 1.
:param pods: Total number of pods to be used by the index. pods = shard*replicas
:param pod_type: the pod type to be used for the index. can be one of p1 or s1.
:param index_config: Advanced configuration options for the index
:param metadata_config: Configuration related to the metadata index
:param source_collection: Collection name to create the index from
:param timeout: Timeout for wait until index gets ready.
"""
pinecone.create_index(
name=index_name,
timeout=timeout,
index_type=index_type,
dimension=dimension,
metric=metric,
pods=pods,
replicas=replicas,
shards=shards,
pod_type=pod_type,
metadata_config=metadata_config,
source_collection=source_collection,
index_config=index_config,
)

@staticmethod
def describe_index(index_name: str) -> Any:
"""
Retrieve information about a specific index.
:param index_name: The name of the index to describe.
"""
return pinecone.describe_index(name=index_name)

@staticmethod
def delete_index(index_name: str, timeout: int | None = None) -> None:
"""
Delete a specific index.
:param index_name: the name of the index.
:param timeout: Timeout for wait until index gets ready.
"""
pinecone.delete_index(name=index_name, timeout=timeout)

@staticmethod
def configure_index(index_name: str, replicas: int | None = None, pod_type: str | None = "") -> None:
"""
Changes current configuration of the index.
:param index_name: The name of the index to configure.
:param replicas: The new number of replicas.
:param pod_type: the new pod_type for the index.
"""
pinecone.configure_index(name=index_name, replicas=replicas, pod_type=pod_type)

@staticmethod
def create_collection(collection_name: str, index_name: str) -> None:
"""
Create a new collection from a specified index.
:param collection_name: The name of the collection to create.
:param index_name: The name of the source index.
"""
pinecone.create_collection(name=collection_name, source=index_name)

@staticmethod
def delete_collection(collection_name: str) -> None:
"""
Delete a specific collection.
:param collection_name: The name of the collection to delete.
"""
pinecone.delete_collection(collection_name)

@staticmethod
def describe_collection(collection_name: str) -> Any:
"""
Retrieve information about a specific collection.
:param collection_name: The name of the collection to describe.
"""
return pinecone.describe_collection(collection_name)

@staticmethod
def list_collections() -> Any:
"""Retrieve a list of all collections in the current project."""
return pinecone.list_collections()

@staticmethod
def query_vector(
index_name: str,
vector: list[Any],
query_id: str | None = None,
top_k: int = 10,
namespace: str | None = None,
query_filter: dict[str, str | float | int | bool | list[Any] | dict[Any, Any]] | None = None,
include_values: bool | None = None,
include_metadata: bool | None = None,
sparse_vector: SparseValues | dict[str, list[float] | list[int]] | None = None,
) -> QueryResponse:
"""
The Query operation searches a namespace, using a query vector.
It retrieves the ids of the most similar items in a namespace, along with their similarity scores.
API reference: https://docs.pinecone.io/reference/query
:param index_name: The name of the index to query.
:param vector: The query vector.
:param query_id: The unique ID of the vector to be used as a query vector.
:param top_k: The number of results to return.
:param namespace: The namespace to fetch vectors from. If not specified, the default namespace is used.
:param query_filter: The filter to apply. See https://www.pinecone.io/docs/metadata-filtering/
:param include_values: Whether to include the vector values in the result.
:param include_metadata: Indicates whether metadata is included in the response as well as the ids.
:param sparse_vector: sparse values of the query vector. Expected to be either a SparseValues object or a dict
of the form: {'indices': List[int], 'values': List[float]}, where the lists each have the same length.
"""
index = pinecone.Index(index_name)
return index.query(
vector=vector,
id=query_id,
top_k=top_k,
namespace=namespace,
filter=query_filter,
include_values=include_values,
include_metadata=include_metadata,
sparse_vector=sparse_vector,
)

@staticmethod
def _chunks(iterable: list[Any], batch_size: int = 100) -> Any:
"""Helper function to break an iterable into chunks of size batch_size."""
it = iter(iterable)
chunk = tuple(itertools.islice(it, batch_size))
while chunk:
yield chunk
chunk = tuple(itertools.islice(it, batch_size))

def upsert_data_async(
self,
index_name: str,
data: list[tuple[Any]],
async_req: bool = False,
pool_threads: int | None = None,
) -> None | list[Any]:
"""
Upserts (insert/update) data into the Pinecone index.
:param index_name: Name of the index.
:param data: List of tuples to be upserted. Each tuple is of form (id, vector, metadata).
Metadata is optional.
:param async_req: If True, upsert operations will be asynchronous.
:param pool_threads: Number of threads for parallel upserting. If async_req is True, this must be provided.
"""
responses = []
with pinecone.Index(index_name, pool_threads=pool_threads) as index:
if async_req and pool_threads:
async_results = [index.upsert(vectors=chunk, async_req=True) for chunk in self._chunks(data)]
responses = [async_result.get() for async_result in async_results]
else:
for chunk in self._chunks(data):
response = index.upsert(vectors=chunk)
responses.append(response)
return responses

@staticmethod
def describe_index_stats(
index_name: str,
stats_filter: dict[str, str | float | int | bool | list[Any] | dict[Any, Any]] | None = None,
**kwargs: Any,
) -> DescribeIndexStatsResponse:
"""
Describes the index statistics.
Returns statistics about the index's contents. For example: The vector count per
namespace and the number of dimensions.
API reference: https://docs.pinecone.io/reference/describe_index_stats_post
:param index_name: Name of the index.
:param stats_filter: If this parameter is present, the operation only returns statistics for vectors that
satisfy the filter. See https://www.pinecone.io/docs/metadata-filtering/
"""
index = pinecone.Index(index_name)
return index.describe_index_stats(filter=stats_filter, **kwargs)
2 changes: 2 additions & 0 deletions docs/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,7 @@ dogstatsd
donot
Dont
DOS'ing
dotproduct
DownloadReportV
downscaling
downstreams
Expand Down Expand Up @@ -1668,6 +1669,7 @@ updateonly
Upsert
upsert
upserted
upserting
upserts
Upsight
upstreams
Expand Down
89 changes: 89 additions & 0 deletions tests/providers/pinecone/hooks/test_pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,92 @@ def test_upsert(self, mock_index):
mock_index.return_value.upsert = mock_upsert
self.pinecone_hook.upsert(self.index_name, data)
mock_upsert.assert_called_once_with(vectors=data, namespace="", batch_size=None, show_progress=True)

@patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.list_indexes")
def test_list_indexes(self, mock_list_indexes):
"""Test that the list_indexes method of PineconeHook is called correctly."""
self.pinecone_hook.list_indexes()
mock_list_indexes.assert_called_once()

@patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.create_index")
def test_create_index(self, mock_create_index):
"""Test that the create_index method of PineconeHook is called with correct arguments."""
self.pinecone_hook.create_index(index_name=self.index_name, dimension=128)
mock_create_index.assert_called_once_with(index_name="test_index", dimension=128)

@patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.describe_index")
def test_describe_index(self, mock_describe_index):
"""Test that the describe_index method of PineconeHook is called with correct arguments."""
self.pinecone_hook.describe_index(index_name=self.index_name)
mock_describe_index.assert_called_once_with(index_name=self.index_name)

@patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.delete_index")
def test_delete_index(self, mock_delete_index):
"""Test that the delete_index method of PineconeHook is called with the correct index name."""
self.pinecone_hook.delete_index(index_name="test_index")
mock_delete_index.assert_called_once_with(index_name="test_index")

@patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.create_collection")
def test_create_collection(self, mock_create_collection):
"""
Test that the create_collection method of PineconeHook is called correctly.
"""
self.pinecone_hook.create_collection(collection_name="test_collection")
mock_create_collection.assert_called_once_with(collection_name="test_collection")

@patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.configure_index")
def test_configure_index(self, mock_configure_index):
"""
Test that the configure_index method of PineconeHook is called correctly.
"""
self.pinecone_hook.configure_index(index_configuration={})
mock_configure_index.assert_called_once_with(index_configuration={})

@patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.describe_collection")
def test_describe_collection(self, mock_describe_collection):
"""
Test that the describe_collection method of PineconeHook is called correctly.
"""
self.pinecone_hook.describe_collection(collection_name="test_collection")
mock_describe_collection.assert_called_once_with(collection_name="test_collection")

@patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.list_collections")
def test_list_collections(self, mock_list_collections):
"""
Test that the list_collections method of PineconeHook is called correctly.
"""
self.pinecone_hook.list_collections()
mock_list_collections.assert_called_once()

@patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.query_vector")
def test_query_vector(self, mock_query_vector):
"""
Test that the query_vector method of PineconeHook is called correctly.
"""
self.pinecone_hook.query_vector(vector=[1.0, 2.0, 3.0])
mock_query_vector.assert_called_once_with(vector=[1.0, 2.0, 3.0])

def test__chunks(self):
"""
Test that the _chunks method of PineconeHook behaves as expected.
"""
data = list(range(10))
chunked_data = list(self.pinecone_hook._chunks(data, 3))
assert chunked_data == [(0, 1, 2), (3, 4, 5), (6, 7, 8), (9,)]

@patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.upsert_data_async")
def test_upsert_data_async_correctly(self, mock_upsert_data_async):
"""
Test that the upsert_data_async method of PineconeHook is called correctly.
"""
data = [("id1", [1.0, 2.0, 3.0], {"meta": "data"})]
self.pinecone_hook.upsert_data_async(index_name="test_index", data=data)
mock_upsert_data_async.assert_called_once_with(index_name="test_index", data=data)

@patch("airflow.providers.pinecone.hooks.pinecone.PineconeHook.describe_index_stats")
def test_describe_index_stats(self, mock_describe_index_stats):
"""
Test that the describe_index_stats method of PineconeHook is called correctly.
"""
self.pinecone_hook.describe_index_stats(index_name="test_index")
mock_describe_index_stats.assert_called_once_with(index_name="test_index")

0 comments on commit 65020ee

Please sign in to comment.