In [None]:
from dotenv import load_dotenv
from snowflake.snowpark.session import Session
import os
from snowflake.core import Root
from typing import List

In [None]:
# Pull settings from a .env file into the process environment; override=True
# lets the loaded values replace variables that are already set.
load_dotenv(override=True)

# Map each Snowpark connection option to the environment variable holding its
# value. A missing variable raises KeyError at the first absent entry.
connection_params = {
    option: os.environ[env_name]
    for option, env_name in (
        ("account", "SNOWFLAKE_ACCOUNT"),
        ("user", "SNOWFLAKE_USER"),
        ("password", "SNOWFLAKE_USER_PASSWORD"),
        ("role", "SNOWFLAKE_ROLE"),
        ("database", "SNOWFLAKE_DATABASE"),
        ("schema", "SNOWFLAKE_SCHEMA"),
        ("warehouse", "SNOWFLAKE_WAREHOUSE"),
    )
}

# Open the session that the remaining cells share.
session_builder = Session.builder.configs(connection_params)
snowpark_session = session_builder.create()
del session_builder  # keep the notebook namespace identical to the original cell

In [None]:
class CortexSearchRetriever:
    """Fetch the values of one column from a Cortex Search service.

    The service to query is located through the ``SNOWFLAKE_DATABASE``,
    ``SNOWFLAKE_SCHEMA`` and ``SNOWFLAKE_CORTEX_SEARCH_SERVICE`` environment
    variables, which are read on every call to :meth:`retrieve`.
    """

    def __init__(
        self,
        snowpark_session: Session,
        limit_to_retrieve: int = 4,
        text_column: str = "INTRODUCTION",
    ):
        """Store the session and the search settings.

        Args:
            snowpark_session: Session handed to ``Root`` when querying.
            limit_to_retrieve: Value passed as ``limit`` to the search call.
            text_column: Name of the single column to request from the
                service and to extract from each result. Defaults to
                ``"INTRODUCTION"``, the column that was previously hard-coded.
        """
        self._snowpark_session = snowpark_session
        self._limit_to_retrieve = limit_to_retrieve
        self._text_column = text_column

    def retrieve(self, query: str) -> List[str]:
        """Run ``query`` against the search service and return the column values.

        Args:
            query: Search text forwarded to the service.

        Returns:
            The configured column's value from each result, in the order the
            service returned them; an empty list when the response carries no
            results.

        Raises:
            KeyError: If one of the three environment variables naming the
                service is not set.
        """
        root = Root(self._snowpark_session)
        cortex_search_service = (
            root.databases[os.environ["SNOWFLAKE_DATABASE"]]
            .schemas[os.environ["SNOWFLAKE_SCHEMA"]]
            .cortex_search_services[os.environ["SNOWFLAKE_CORTEX_SEARCH_SERVICE"]]
        )
        resp = cortex_search_service.search(
            query=query,
            columns=[self._text_column],
            limit=self._limit_to_retrieve,
        )

        # A falsy ``results`` (empty or missing) collapses to an empty list.
        return [hit[self._text_column] for hit in (resp.results or [])]

In [None]:
# Build a retriever on the open session, asking for four results per query.
retriever = CortexSearchRetriever(
    snowpark_session=snowpark_session,
    limit_to_retrieve=4,
)

In [None]:
# Smoke-test the retriever with a sample query and show what came back.
retrieved_context = retriever.retrieve(
    query="the largest city in",
)
print(retrieved_context)

In [None]:
# Number of entries returned by the sample query above.
print(len(retrieved_context))

In [None]:
# Close the Snowpark session created in the connection cell.
snowpark_session.close()