diff --git a/airflow/providers/pinecone/hooks/pinecone.py b/airflow/providers/pinecone/hooks/pinecone.py index 28978ab3a80287..35aa66c3204b56 100644 --- a/airflow/providers/pinecone/hooks/pinecone.py +++ b/airflow/providers/pinecone/hooks/pinecone.py @@ -216,7 +216,7 @@ def create_index( index_name: str, dimension: int, spec: ServerlessSpec | PodSpec, - metric: str | None = None, + metric: str | None = "cosine", timeout: int | None = None, ) -> None: """ @@ -226,7 +226,7 @@ def create_index( :param dimension: The dimension of the vectors to be indexed. :param spec: Pass a `ServerlessSpec` object to create a serverless index or a `PodSpec` object to create a pod index. ``get_serverless_spec_obj`` and ``get_pod_spec_obj`` can be used to create the Spec objects. - :param metric: The metric to use. + :param metric: The metric to use. Defaults to cosine. :param timeout: The timeout to use. """ self.pinecone_client.create_index( diff --git a/airflow/providers/pinecone/operators/pinecone.py b/airflow/providers/pinecone/operators/pinecone.py index 102bd0eae0d82a..cbcba9e695cebd 100644 --- a/airflow/providers/pinecone/operators/pinecone.py +++ b/airflow/providers/pinecone/operators/pinecone.py @@ -99,10 +99,10 @@ class CreatePodIndexOperator(BaseOperator): :param replicas: The number of replicas to use. :param shards: The number of shards to use. :param pods: The number of pods to use. - :param pod_type: The type of pod to use. + :param pod_type: The type of pod to use. Defaults to p1.x1 :param metadata_config: The metadata configuration to use. :param source_collection: The source collection to use. - :param metric: The metric to use. + :param metric: The metric to use. Defaults to cosine. :param timeout: The timeout to use. """ @@ -116,7 +116,7 @@ def __init__( replicas: int | None = None, shards: int | None = None, pods: int | None = None, - pod_type: str | None = None, + pod_type: str | None = "p1.x1", metadata_config: dict | None = None, source_collection: str | None = None, metric: str | None = "cosine", diff --git a/tests/system/providers/pinecone/example_pinecone_cohere.py b/tests/system/providers/pinecone/example_pinecone_cohere.py index c39263fcba8916..9e35c6c89fd418 100644 --- a/tests/system/providers/pinecone/example_pinecone_cohere.py +++ b/tests/system/providers/pinecone/example_pinecone_cohere.py @@ -20,9 +20,9 @@ from datetime import datetime from airflow import DAG -from airflow.decorators import task, teardown +from airflow.decorators import setup, task, teardown from airflow.providers.cohere.operators.embedding import CohereEmbeddingOperator -from airflow.providers.pinecone.operators.pinecone import CreatePodIndexOperator, PineconeIngestOperator +from airflow.providers.pinecone.operators.pinecone import PineconeIngestOperator index_name = os.getenv("INDEX_NAME", "example-pinecone-index") namespace = os.getenv("NAMESPACE", "example-pinecone-index") @@ -36,15 +36,15 @@ start_date=datetime(2023, 1, 1), catchup=False, ) as dag: - create_index = CreatePodIndexOperator( - task_id="create_index", - index_name=index_name, - dimension=768, - replicas=1, - shards=1, - pods=1, - pod_type="p1.x1", - ) + + @setup + @task + def create_index(): + from airflow.providers.pinecone.hooks.pinecone import PineconeHook + + hook = PineconeHook() + pod_spec = hook.get_pod_spec_obj() + hook.create_index(index_name=index_name, dimension=768, spec=pod_spec) embed_task = CohereEmbeddingOperator( task_id="embed_task", @@ -69,7 +69,7 @@ def delete_index(): hook = PineconeHook() hook.delete_index(index_name=index_name) - create_index >> embed_task >> perform_ingestion >> delete_index() + create_index() >> embed_task >> perform_ingestion >> delete_index() from tests.system.utils import get_test_run # noqa: E402 diff --git a/tests/system/providers/pinecone/example_pinecone_openai.py b/tests/system/providers/pinecone/example_pinecone_openai.py index 44aa2c49cdaac3..d338e25542ce03 100644 --- a/tests/system/providers/pinecone/example_pinecone_openai.py +++ b/tests/system/providers/pinecone/example_pinecone_openai.py @@ -78,10 +78,6 @@ task_id="create_index", index_name=index_name, dimension=1536, - replicas=1, - shards=1, - pods=1, - pod_type="p1.x1", ) embed_task = OpenAIEmbeddingOperator(