# Setup

In [2]:
from IPython.display import clear_output

!pip install astrapy datasets pandas python-dotenv

clear_output()

In [3]:
from google.colab import userdata
import os

# Initialize secrets
os.environ["ASTRA_DB_API_ENDPOINT"] = "https://047b29ef-98d0-4ff5-a342-d959b024506a-us-east-2.apps.astra.datastax.com"

# This notebook is designed to be run on Google Colab. The userdata object is a helper object that
# will obtain secrets stored in Google Colab's secret store.
os.environ["ASTRA_DB_APPLICATION_TOKEN"] = userdata.get("ASTRA_DB_APPLICATION_TOKEN")
os.environ["LANGFLOW_API_KEY"] = userdata.get("LANGFLOW_API_KEY")
os.environ["RAPID_API_KEY"] = userdata.get("RAPID_API_KEY")

## Populate DB

In [None]:
from astrapy import DataAPIClient

# Connect to Astra DB
client = DataAPIClient(os.environ["ASTRA_DB_APPLICATION_TOKEN"])
database = client.get_database(os.environ["ASTRA_DB_API_ENDPOINT"])
collection = database.get_collection("grocerrify_food")

In [None]:
# Cleanup vector database
collection.delete_all()

In [15]:
# Get sample foods from Coles
import requests

url = "https://coles-product-price-api.p.rapidapi.com/coles/product-search/?query=chicken"

payload = {}
headers = {
  'x-rapidapi-host': 'coles-product-price-api.p.rapidapi.com',
  'x-rapidapi-key': os.environ["RAPID_API_KEY"]
}

response = requests.request("GET", url, headers=headers, data=payload).json()

coles_products = []
for product in response["results"]:
  coles_products.append({"product_name": product["product_name"],
                         "source": "coles"})

In [16]:
# Check the products
coles_products

[{'product_name': 'Creamy Chicken with Mushroom Simmer Sauce',
  'source': 'coles'},
 {'product_name': 'Indian Mild Butter Chicken Simmer Sauce',
  'source': 'coles'},
 {'product_name': 'Chicken Salted Chicken Chips', 'source': 'coles'},
 {'product_name': 'Chicken', 'source': 'coles'},
 {'product_name': 'Creamy Treats Cat Food Chicken & Chicken Whitefish 4X12g',
  'source': 'coles'},
 {'product_name': 'Chicken Necks', 'source': 'coles'},
 {'product_name': 'Butter Chicken', 'source': 'coles'},
 {'product_name': 'Rice Chicken', 'source': 'coles'},
 {'product_name': 'Chicken Masala', 'source': 'coles'},
 {'product_name': 'Whole Chicken', 'source': 'coles'}]

In [5]:
# Truncate content so string does not overflow
# max_bytes=512 for Nvidia NV-Embed-QA embedding model
def truncate_content(content, max_bytes=512):
    # Encode the string into bytes (UTF-8 encoding)
    content_bytes = content.encode('utf-8')

    # Check if the byte length exceeds the limit
    if len(content_bytes) > max_bytes:
        # Truncate the content to the maximum byte size
        truncated_content_bytes = content_bytes[:max_bytes]

        # Decode back to a string, ensuring no decoding errors occur
        truncated_content = truncated_content_bytes.decode('utf-8', errors='ignore')
    else:
        truncated_content = content

    return truncated_content

In [34]:
# Insert record to Astra DB
for product in coles_products:

    # Combine relevant fields into a single string to be vectorized
    content = product["product_name"]

    try:
        truncated_content = truncate_content(content)
        collection.insert_one(
            document={
                'content': truncated_content,
                '$vectorize': truncated_content,
                'metadata': product
            }
        )
        print(f"Inserted a product: {product}")
    except Exception as ex:
        print(ex)

Inserted a product: {'product_name': 'Creamy Chicken with Mushroom Simmer Sauce', 'source': 'coles'}
Inserted a product: {'product_name': 'Indian Mild Butter Chicken Simmer Sauce', 'source': 'coles'}
Inserted a product: {'product_name': 'Chicken Salted Chicken Chips', 'source': 'coles'}
Inserted a product: {'product_name': 'Chicken', 'source': 'coles'}
Inserted a product: {'product_name': 'Creamy Treats Cat Food Chicken & Chicken Whitefish 4X12g', 'source': 'coles'}
Inserted a product: {'product_name': 'Chicken Necks', 'source': 'coles'}
Inserted a product: {'product_name': 'Butter Chicken', 'source': 'coles'}
Inserted a product: {'product_name': 'Rice Chicken', 'source': 'coles'}
Inserted a product: {'product_name': 'Chicken Masala', 'source': 'coles'}
Inserted a product: {'product_name': 'Whole Chicken', 'source': 'coles'}


### Test!

In [39]:
# Retrieve data from Astra DB
try:
  result = collection.find(
      {},
      sort={"$vectorize": "Butter chicken"},
      limit=10
  )

  # Convert cursor to a list of documents
  all_docs = list(result)
  print(f"Found {len(all_docs)} documents")

  # Now you can work with the list
  for doc in all_docs:
      print(doc.get('metadata', {}).get('product_name', 'N/A'))

except Exception as ex:
  print(ex)

Found 10 documents
Butter Chicken
Chicken
Whole Chicken
Rice Chicken
Chicken Masala
Indian Mild Butter Chicken Simmer Sauce
Creamy Chicken with Mushroom Simmer Sauce
Chicken Necks
Creamy Treats Cat Food Chicken & Chicken Whitefish 4X12g
Chicken Salted Chicken Chips


## The Product Vector DB Accessor Class

In [84]:
from astrapy import DataAPIClient
from typing import Literal, List

class ProductVectorDBProviderConfig():
  def __init__(self,
               astra_db_application_token: str,
               astra_db_api_endpoint: str,
               astra_db_collection_name: str):
    self.application_token = astra_db_application_token
    self.api_endpoint = astra_db_api_endpoint
    self.collection_name = astra_db_collection_name

class ProductVectorDBProvider():
  def __init__(self,
               config: ProductVectorDBProviderConfig):
    """
    Arguments:
      config (ProductVectorDBProviderConfig): Config object to connect to the database
    """

    # Connect to Astra DB
    client = DataAPIClient(config.application_token)
    database = client.get_database(config.api_endpoint)
    self.collection = database.get_collection(config.collection_name)

  def insert_product(product_name: str,
                     source: Literal["coles", "woolworths"]) -> bool:
    """A function for inserting a product into the vector DB

    Arguments:
      product_name (str): The product name we want to insert
      source (str): The source of the product (can only be "coles" or "woolworths")
    Returns:
      bool: True if the insertion is successful. Otherwise, returns False
    """
    # Combine relevant fields into a single string to be vectorized
    content = product_name

    try:
      truncated_content = self._truncate_content(content)
      self.collection.insert_one(
        document={
          'content': truncated_content,
          '$vectorize': truncated_content,
          'metadata': product
        }
      )
      return True
    except Exception as ex:
      print(ex)
      return False

  def get_products_by_similarity(self,
                                 product_name: str,
                                 limit: int = 5) -> List[str]:
    """A function for retrieving product names based on similarity.

    Arguments:
      product_name (str): The product name from which we will retrieve similar products.
      limit (Optional[int]): Limit of how many products to retrieve.
    Returns:
      List[str]: A list of similar product names.
    """
    # Retrieve data from Astra DB
    try:
      result = self.collection.find(
        {},
        sort={"$vectorize": "product_name"},
        limit=limit
      )

      # Convert cursor to a list of documents
      all_docs = list(result)

      # Extract the product names and return them
      return [doc.get('metadata', {}).get('product_name', 'N/A') for doc in all_docs]

    except Exception as ex:
      print(f"Error retrieving products by similarity. Error: {ex}")
      return []


  def _truncate_content(content, max_bytes=512):
    """A function for truncating content so the string does not overflow.
    The max_bytes=512 for Nvidia NV-Embed-QA embedding model.
    """
    # Encode the string into bytes (UTF-8 encoding)
    content_bytes = content.encode('utf-8')

    # Check if the byte length exceeds the limit
    if len(content_bytes) > max_bytes:
      # Truncate the content to the maximum byte size
      truncated_content_bytes = content_bytes[:max_bytes]

      # Decode back to a string, ensuring no decoding errors occur
      truncated_content = truncated_content_bytes.decode('utf-8', errors='ignore')
    else:
      truncated_content = content

    return truncated_content

In [85]:
# Initialize Astra DB Connection
COLLECTION_NAME = "grocerrify_food"

productVectorDBProviderConfig = ProductVectorDBProviderConfig(
    astra_db_application_token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
    astra_db_api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
    astra_db_collection_name=COLLECTION_NAME
)

productVectorDBProvider = ProductVectorDBProvider(
    config=productVectorDBProviderConfig
)

### Test the class

In [86]:
productVectorDBProvider.get_products_by_similarity(product_name="Chicken")

['Chicken',
 'Indian Mild Butter Chicken Simmer Sauce',
 'Creamy Treats Cat Food Chicken & Chicken Whitefish 4X12g',
 'Chicken Masala',
 'Whole Chicken']

## The LangFlow Flow Accessor Class

In [90]:
import json
import requests
from typing import Optional, List, Dict

SIMILARITY_SEARCH_FLOW_ID = "5993123f-28ec-4ae9-afca-8b74e387516d"

# You can tweak the flow by adding a tweaks dictionary
# e.g {"OpenAI-XXXXX": {"model_name": "gpt-4"}}
TWEAKS = {
  "ChatInput-FoMpG": {},
  "AstraDB-JjnZe": {},
  "ParseData-9jdkC": {},
  "Prompt-yg8Dh": {},
  "OpenAIModel-vws7B": {},
  "ChatOutput-9youq": {}
}

class LangFlowProviderConfig():
  def __init__(self,
               similarity_search_flow_id: str,
               tweaks: Dict[str, any] = None):
    self.similarity_search_flow_id = similarity_search_flow_id

class LangFlowProvider():
  def __init__(self,
               config: LangFlowProviderConfig,
               application_token: str,
               base_api_url: str = "https://api.langflow.astra.datastax.com",
               langflow_id: str = "7c6d0acb-34f0-4302-9953-803722048c01"
               ):
    self.config = config
    self.application_token = application_token
    self.base_api_url = base_api_url
    self.langflow_id = langflow_id

  def get_most_relevant_product(self, product_name: str) -> str:
    return self._run_rag_similarity_search_flow(
        product_name=product_name
    )

  def _run_rag_similarity_search_flow(
      self,
      product_name: str,
      output_type: str = "chat",
      input_type: str = "chat",
      tweaks: Optional[dict] = None,
      application_token: Optional[str] = None
    ) -> List[str]:
    """
    Run an RAG similarity search flow with a given message

    :param product_name: The product_name we want to search for
    :param endpoint: The ID or the endpoint name of the flow
    :param tweaks: Optional tweaks to customize the flow
    :return: A list of top three most relevant products
    """
    similarity_search_flow_id = self.config.similarity_search_flow_id
    api_url = f"{self.base_api_url}/lf/{self.langflow_id}/api/v1/run/{similarity_search_flow_id}?stream=false"

    payload = {
        "input_value": product_name,
        "output_type": output_type,
        "input_type": input_type,
    }
    headers = None
    if tweaks:
        payload["tweaks"] = tweaks
    if self.application_token:
        headers = {"Authorization": "Bearer " + self.application_token, "Content-Type": "application/json"}

    try:
      response = requests.post(api_url, json=payload, headers=headers).json()
      message = json.loads(response["outputs"][0]["outputs"][0]["results"]["message"]["data"]["text"])
      return message
    except Exception as e:
      print(e)
      return ""

In [91]:
# Initialize the client
langFlowProviderConfig = LangFlowProviderConfig(similarity_search_flow_id=SIMILARITY_SEARCH_FLOW_ID)
langFlowProvider = LangFlowProvider(config=langFlowProviderConfig,
                                    application_token=os.environ["LANGFLOW_API_KEY"])

In [92]:
# Test the class!
langFlowProvider.get_most_relevant_product(product_name="Butter chicken")

Expecting value: line 1 column 1 (char 0)


''