In [0]:
# Install dependencies
%pip install databricks-vectorsearch mlflow openai
%pip install --upgrade "mlflow[databricks]>=3.1" openai

In [0]:
# Restart Python environment
# Runs right after the %pip cell, before any of the installed packages are
# imported, so the imports below happen in a freshly started Python process.
dbutils.library.restartPython()

In [0]:
from openai import OpenAI
from pyspark.sql.functions import col, concat_ws, udf
from pyspark.sql.types import ArrayType, FloatType
from databricks.vector_search.client import VectorSearchClient
from pyspark.sql import SparkSession
# from langchain_databricks import ChatDatabricks
from langchain_core.prompts import PromptTemplate

In [0]:
# Step 1: Load the data
# Read the Parquet files under the sample Amazon dataset path.
reviews = spark.read.parquet('/databricks-datasets/amazon/test4K/')
# Keep only 30 rows for the rest of the notebook (each row later costs one
# embedding request).
# NOTE(review): no ordering is applied before limit(), so which 30 rows are
# picked is presumably not guaranteed to be stable between runs — confirm.
reviews_30 = reviews.limit(30)

In [0]:
# Step 2: Prepare the data
# Build one text document per row by joining brand, title, review, price and
# rating with ". " (price and rating are cast to string first), then keep only
# the columns used downstream plus the new combined_text column.
reviews_prepared = (
    reviews_30
    .withColumn(
        "combined_text",
        concat_ws(
            ". ",
            col("brand"),
            col("title"),
            col("review"),
            col("price").cast("string"),
            col("rating").cast("string"),
        ),
    )
    .select("asin", "brand", "title", "review", "price", "rating", "combined_text")
)

In [0]:
# Step 4: Generate embeddings
# (This runs before "Step 3" further down in this cell because the Delta table
# is written from the DataFrame that already contains the embeddings column.)
# Take the API token from the current notebook context; it is passed as the
# api_key of the OpenAI clients created in this notebook.
DATABRICKS_TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()

def get_embedding(text):
    """Return the embedding vector for ``text``.

    A new OpenAI client is created on every call, pointed at the workspace's
    serving-endpoints URL and authenticated with ``DATABRICKS_TOKEN``. The text
    is sent to the ``databricks-gte-large-en`` model and the embedding of the
    first item in the response is returned.
    """
    serving_client = OpenAI(
        base_url="https://dbc-f41c3a47-b6f9.cloud.databricks.com/serving-endpoints",
        api_key=DATABRICKS_TOKEN,
    )
    embedding_result = serving_client.embeddings.create(
        model="databricks-gte-large-en",
        input=text,
    )
    first_item = embedding_result.data[0]
    return first_item.embedding

# Expose get_embedding as a Spark UDF that returns an array of floats, and use
# it to derive an "embeddings" column from combined_text.
get_embeddings_udf = udf(get_embedding, ArrayType(FloatType()))
reviews_with_embeddings = reviews_prepared.withColumn(
    "embeddings",
    get_embeddings_udf(col("combined_text")),
)

# Step 3: Create Delta table
# Overwrites the table on every run, so the cell can be re-executed.
table_name = "workspace.default.amazon_reviews_rag"
(
    reviews_with_embeddings.write
    .format("delta")
    .mode("overwrite")
    .saveAsTable(table_name)
)

In [0]:
# Enable Change Data Feed
# The feed is switched on for the source table before the Delta Sync index is
# created from it below.
spark = SparkSession.builder.getOrCreate()
spark.sql(f"ALTER TABLE {table_name} SET TBLPROPERTIES (delta.enableChangeDataFeed = true)")

# Step 5: Create Vector Search index
vsc = VectorSearchClient(disable_notice = True)
endpoint_name = "vs_endpoint"
index_name = "workspace.default.amazon_reviews_index"

# NOTE(review): this call is expected to fail if the index already exists from
# an earlier run — confirm, and drop or reuse the existing index when re-running.
index = vsc.create_delta_sync_index(
    endpoint_name=endpoint_name,
    source_table_name=table_name,
    index_name=index_name,
    pipeline_type="TRIGGERED",
    # The primary key has to identify a row. "brand" is shared by many
    # products, so rows with the same brand would collapse into one index
    # entry; "asin" is the identifier column available in the table.
    # NOTE(review): if the table can hold several reviews per asin, add a
    # dedicated unique id column to the table and use that here instead.
    primary_key="asin",
    embedding_dimension=1024,  # Correct dimension for databricks-gte-large-en
    embedding_vector_column="embeddings"
)

# Wait for index to be ready
# The next cell queries the index immediately, so block here until the index
# reports ready instead of leaving the wait commented out.
index.wait_until_ready()


In [0]:
# Steps 6 & 7: Perform semantic queries and generate responses
# Shared OpenAI client for this cell, pointed at the workspace's
# serving-endpoints URL and authenticated with DATABRICKS_TOKEN. It is used
# both for embedding the questions and for the chat completions below.
client = OpenAI(
    api_key=DATABRICKS_TOKEN,
    base_url="https://dbc-f41c3a47-b6f9.cloud.databricks.com/serving-endpoints"
)

def get_query_embedding(question):
    """Return the embedding vector for ``question``.

    Uses the module-level ``client`` and the ``databricks-gte-large-en`` model,
    i.e. the same model that produced the stored review embeddings, and returns
    the embedding of the first item in the response.
    """
    embedding_result = client.embeddings.create(
        model="databricks-gte-large-en",
        input=question,
    )
    first_item = embedding_result.data[0]
    return first_item.embedding

questions = [
    "What is a good lotion for people with dry skin that lasts all day?",
    "Is there a gluten-free pasta that tastes like regular wheat pasta?",
    "Is it worth paying a bit more for a better meat grinder?",
    "Can a USB connector have mechanical issues?"
]

# The index handle and the requested columns are the same for every question,
# so set them up once instead of on each loop iteration.
reviews_index = vsc.get_index(index_name = index_name)
result_columns = ["asin", "brand", "title", "review", "price", "rating"]

for question in questions:
    # Generate query embedding
    query_vector = get_query_embedding(question)

    # Query Vector Search with query_vector
    results = reviews_index.similarity_search(
        query_vector=query_vector,
        columns=result_columns,
        num_results=3
    )

    # Read the rows defensively: when nothing matches, the response may not
    # contain a 'data_array' entry at all, which previously raised KeyError.
    rows = results.get('result', {}).get('data_array') or []
    if not rows:
        print(f"\nQuestion: {question}")
        print("\nNo matching reviews were retrieved; skipping response generation.")
        continue

    # Prepare context from search results
    context_parts = []
    for row in rows:
        # Values are read in the order of result_columns; any trailing extras
        # in the row (presumably a similarity score) are ignored.
        _asin, brand, title, review, price, rating = row[:6]
        context_parts.append(
            f"Product: {title}\nReview: {review}\nBrand: {brand}\nRating: {rating}\nPrice: {price}"
        )
    context = "\n".join(context_parts)

    # Generate response using LLM
    prompt = f"Question: {question}\n\nContext: {context}\n\nAnswer:"
    response = client.chat.completions.create(
        model="databricks-meta-llama-3-1-8b-instruct",
        messages=[
            {"role": "system", "content": "You are a helpful assistant that answers questions based on provided context."},
            {"role": "user", "content": prompt}
        ],
        max_tokens=500
    )

    # Display results
    print(f"\nQuestion: {question}")
    print("\nRetrieved Context:")
    print(context)
    print("\nGenerated Response:")
    print(response.choices[0].message.content)