# <span style="color:#1f77b4">**Generative AI - RAG System**</span>

This notebook builds an end-to-end RAG pipeline: load data, create a Vector Search index, retrieve context, generate answers with Azure OpenAI, then log and serve the model.


## <span style="color:#1f77b4">**Configure Unity Catalog paths**</span>

Set catalog, schema, volume, and external location names, then create the UC schema and an external volume that points to your storage container.
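Here is a minimal sketch of how the fully-qualified names in the next cell fit together. It is written as a pure-Python helper so it can be checked without Spark. `build_uc_names` and the catalog `main` are hypothetical placeholders. The identifier rule is a deliberately strict assumption (plain letters, digits, underscores), not the full Unity Catalog naming rule.

```python
import re

# Hypothetical, deliberately strict rule: letters, digits and underscores only,
# so names can be interpolated into SQL without backtick quoting.
_IDENT = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")


def build_uc_names(catalog: str, schema: str, volume: str, filename: str) -> dict:
    """Return the three-level UC names and the /Volumes file path for one dataset."""
    for part in (catalog, schema, volume):
        if not _IDENT.match(part):
            raise ValueError(f"Unsupported identifier: {part!r}")
    return {
        "schema": f"{catalog}.{schema}",
        "table": f"{catalog}.{schema}.diabetes_faq_table",
        "index": f"{catalog}.{schema}.diabetes_faq_index",
        "volume": f"{catalog}.{schema}.{volume}",
        # UC volume files are addressed as /Volumes/<catalog>/<schema>/<volume>/<path>
        "file_path": f"/Volumes/{catalog}/{schema}/{volume}/{filename}",
    }


names = build_uc_names("main", "rag", "raw", "diabetes_treatment_faq.csv")
print(names["file_path"])  # /Volumes/main/rag/raw/diabetes_treatment_faq.csv
```

SQL needs backtick quoting for names that contain hyphens or other special characters. This sketch rejects such names instead of quoting them.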


In [0]:
# Reset widgets so reruns don't keep stale values.
dbutils.widgets.removeAll()

# Widgets let you override catalog/schema/volume without editing code.
dbutils.widgets.text("CATALOG", "")
dbutils.widgets.text("SCHEMA", "rag")
dbutils.widgets.text("VOLUME", "raw")
dbutils.widgets.text("EXTERNAL_LOCATION", "uc-external-location")

# Resolve the active catalog (widget wins, otherwise use a non-system catalog).
catalog_widget = dbutils.widgets.get("CATALOG")
if catalog_widget:
    catalog_name = catalog_widget
else:
    current = spark.sql("SELECT current_catalog()").first()[0]
    catalogs = [r.catalog for r in spark.sql("SHOW CATALOGS").collect()]
    catalog_name = current if current not in ("system",) else next(c for c in catalogs if c not in ("system",))

schema_name = dbutils.widgets.get("SCHEMA")
volume_leaf = dbutils.widgets.get("VOLUME")
external_location_name = dbutils.widgets.get("EXTERNAL_LOCATION")

# Build fully-qualified names used throughout the notebook.
table_name = f"{catalog_name}.{schema_name}.diabetes_faq_table"
index_name = f"{catalog_name}.{schema_name}.diabetes_faq_index"
volume_name = f"{catalog_name}.{schema_name}.{volume_leaf}"

# Ensure the schema exists.
spark.sql(f"CREATE SCHEMA IF NOT EXISTS {catalog_name}.{schema_name}")

# Read the external location URL (shape differs across runtimes).
location_rows = spark.sql(f"DESCRIBE EXTERNAL LOCATION `{external_location_name}`")
if "url" in location_rows.columns:
    external_url = location_rows.select("url").first()["url"].rstrip("/")
else:
    external_url = location_rows.filter("key = 'url'").select("value").first()["value"].rstrip("/")

# Create a UC external volume at that location.
spark.sql(
    f"CREATE EXTERNAL VOLUME IF NOT EXISTS {volume_name} "
    f"LOCATION '{external_url}'"
)

# Build the CSV path inside the UC volume.
data_path = f"dbfs:/Volumes/{catalog_name}/{schema_name}/{volume_leaf}/diabetes_treatment_faq.csv"
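The cell above handles two result shapes from `DESCRIBE EXTERNAL LOCATION`. The sketch below applies the same logic to plain dicts, so it runs without a cluster. `extract_location_url` is a hypothetical helper. The `key`/`value` column names are assumed from the notebook's own filter.

```python
def extract_location_url(columns, rows):
    """Return the storage URL from either DESCRIBE EXTERNAL LOCATION result shape.

    columns: result column names; rows: result rows as dicts.
    """
    if "url" in columns:
        # Wide shape: one row with a `url` column.
        url = rows[0]["url"]
    else:
        # Key/value shape: one row per property; pick the row whose key is 'url'.
        url = next((r["value"] for r in rows if r.get("key") == "url"), None)
        if url is None:
            raise ValueError("No 'url' entry in DESCRIBE EXTERNAL LOCATION output")
    return url.rstrip("/")


print(extract_location_url(["key", "value"], [{"key": "url", "value": "abfss://raw@acct.dfs.core.windows.net/"}]))
```

In the notebook this could be called as `extract_location_url(location_rows.columns, [r.asDict() for r in location_rows.collect()])`. An explicit `ValueError` gives a clearer message than the `TypeError` raised when `.first()` returns `None`.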


## <span style="color:#1f77b4">**Load the CSV into a Spark DataFrame**</span>

Read the CSV from the UC volume and inspect a sample plus the schema to validate the columns.
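`Topic` is later used as the Vector Search `primary_key`, so it helps to confirm that the columns exist and the topics are unique before writing the Delta table. Here is a stdlib-only sketch of that check on raw CSV text. `validate_faq_csv` and the sample row are hypothetical.

```python
import csv
import io
from collections import Counter


def validate_faq_csv(text, required=("Topic", "Description")):
    """Parse FAQ CSV text and check required columns and Topic uniqueness."""
    reader = csv.DictReader(io.StringIO(text))
    missing = [c for c in required if c not in (reader.fieldnames or [])]
    if missing:
        raise ValueError(f"Missing columns: {missing}")
    rows = list(reader)
    # Topic becomes the index primary key, so duplicates would collide.
    dupes = sorted(t for t, n in Counter(r["Topic"] for r in rows).items() if n > 1)
    if dupes:
        raise ValueError(f"Duplicate Topic values: {dupes}")
    return rows


sample_csv = 'Topic,Description\nWhat is diabetes?,"A chronic condition, affecting glucose."\n'
faq_rows = validate_faq_csv(sample_csv)
print(faq_rows[0]["Description"])  # A chronic condition, affecting glucose.
```

On the Spark side, the equivalent duplicate check would be `df.groupBy("Topic").count().filter("count > 1")`.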


In [0]:
from pyspark.sql.functions import *

# Load the raw CSV into a Spark DataFrame.
df = spark.read.csv(data_path, header=True)

# Preview and confirm the schema.
display(df.limit(10))
df.printSchema()


Topic,Description
What is diabetes?,"Diabetes is a chronic condition that affects how the body processes glucose (sugar). It occurs when the body cannot produce enough insulin or the insulin it produces is ineffective in regulating blood sugar. Insulin is a hormone produced by the pancreas that helps glucose enter the cells of the body for energy. Without sufficient insulin, glucose builds up in the bloodstream, leading to high blood sugar levels. Over time, uncontrolled diabetes can cause serious health complications such as heart disease, kidney damage, nerve damage, and vision problems. Proper management and treatment of diabetes are essential to preventing these complications and maintaining a good quality of life. Early detection, lifestyle changes, and medication are key factors in effectively managing the disease."
What are the different types of diabetes?,"Diabetes is categorized into two main types: Type 1 and Type 2. Type 1 diabetes is an autoimmune condition where the body’s immune system attacks and destroys the insulin-producing cells in the pancreas, leading to little or no insulin production. It typically develops in children or young adults and requires lifelong insulin therapy. Type 2 diabetes, on the other hand, occurs when the body becomes resistant to insulin or does not produce enough insulin to meet the body’s needs. It is more common in adults, particularly those who are overweight, inactive, or have a family history of the disease. While Type 1 is not preventable, Type 2 can often be prevented or delayed through lifestyle changes, including diet and exercise."
What are the symptoms of diabetes?,"The symptoms of diabetes can vary depending on the type and how long the condition has been present. Common signs include frequent urination, excessive thirst, hunger, and unexplained weight loss. Some people may experience blurred vision, fatigue, and slow-healing wounds. In the case of Type 1 diabetes, symptoms often develop rapidly, while Type 2 diabetes symptoms may be more subtle and develop over time. Because the early symptoms may not always be noticeable, it is important to get regular check-ups, especially if you are at risk for diabetes. Uncontrolled diabetes can lead to serious complications, so timely diagnosis and treatment are essential."
How is diabetes diagnosed?,"Diabetes is diagnosed through various blood tests. The fasting blood glucose test measures blood sugar levels after an overnight fast, while the oral glucose tolerance test checks how well the body processes sugar after consuming a sugary drink. The HbA1c test, which reflects the average blood sugar levels over the past 2-3 months, is also commonly used to diagnose and monitor diabetes. An HbA1c level of 6.5% or higher is typically indicative of diabetes. A diagnosis may also involve checking for other conditions associated with diabetes, such as high blood pressure or cholesterol imbalances. Early detection allows for better management and prevention of complications."
What is the role of insulin in diabetes?,"Insulin is a hormone produced by the pancreas that helps regulate blood sugar levels by allowing glucose to enter cells for energy. In people with diabetes, either the body does not produce enough insulin (Type 1 diabetes) or the body’s cells do not respond effectively to insulin (Type 2 diabetes). As a result, glucose accumulates in the bloodstream, leading to high blood sugar. Insulin therapy, typically in the form of injections or an insulin pump, helps to lower blood sugar levels and mimic the body’s natural insulin production. Insulin is a crucial part of managing diabetes, particularly for those with Type 1, and can also be used in Type 2 when lifestyle changes and oral medications are not sufficient."
What are the treatment options for type 1 diabetes?,"For people with Type 1 diabetes, treatment primarily involves lifelong insulin therapy to manage blood sugar levels. Insulin can be administered through injections or via an insulin pump. In addition to insulin, individuals with Type 1 diabetes must closely monitor their blood sugar levels throughout the day, maintain a healthy diet, and engage in regular physical activity to help manage their condition. People with Type 1 diabetes also need to be vigilant about potential complications, such as diabetic ketoacidosis, and should regularly consult with their healthcare providers to adjust their treatment plan as needed."
What are the treatment options for type 2 diabetes?,"For Type 2 diabetes, treatment typically starts with lifestyle modifications such as a balanced diet, regular exercise, and weight management. If lifestyle changes are not sufficient, oral medications such as metformin may be prescribed to help regulate blood sugar levels. In some cases, individuals with Type 2 diabetes may require insulin therapy or other injectable medications to improve insulin sensitivity. For those with severe Type 2 diabetes or obesity, bariatric surgery may be considered to help improve blood sugar control. The goal of treatment is to achieve normal blood sugar levels and prevent complications such as heart disease, kidney damage, or nerve damage."
How can I manage my blood sugar levels?,"Managing blood sugar levels is essential for diabetes control. The key factors in blood sugar management include regular monitoring, taking prescribed medications, following a healthy diet, exercising regularly, and managing stress. Monitoring your blood sugar levels allows you to understand how food, exercise, and medications affect your glucose. A balanced diet rich in fiber, lean proteins, and healthy fats can help regulate blood sugar. Exercise increases insulin sensitivity, which helps control blood sugar levels. Stress management techniques, such as yoga and relaxation exercises, are also important for maintaining stable blood sugar. Regular doctor visits are essential to adjust treatment plans as needed."
What is a healthy diet for someone with diabetes?,"A healthy diet for people with diabetes focuses on foods that help regulate blood sugar levels and promote overall health. This includes consuming whole grains, vegetables, lean proteins, and healthy fats while limiting foods with a high glycemic index, such as refined sugars and processed foods. Eating smaller, more frequent meals can also help maintain steady blood sugar levels throughout the day. Carbohydrate counting is a common practice for those with diabetes, as carbohydrates have a direct impact on blood sugar. It is important to work with a dietitian to develop a meal plan that meets individual needs and preferences while promoting blood sugar control."
How often should I check my blood sugar levels?,"Blood sugar monitoring is essential for people with diabetes to understand how their body responds to different foods, activities, and medications. The frequency of testing depends on the type of diabetes and treatment plan. For individuals with Type 1 diabetes or those on insulin, blood sugar may need to be tested multiple times per day, including before and after meals or exercise. For Type 2 diabetes, the frequency of testing may vary based on medication use and blood sugar control. A healthcare provider can provide personalized recommendations for how often to check blood sugar levels, ensuring that levels remain within the target range to prevent complications."


root
 |-- Topic: string (nullable = true)
 |-- Description: string (nullable = true)



## <span style="color:#1f77b4">**Configure Azure OpenAI client**</span>

Load Azure OpenAI secrets from the Key Vault-backed scope and construct a client for chat completions.
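Here is a sketch that loads all four settings in one place and fails fast on blank values. It assumes the same secret keys as the cell below. `AOAISettings` and `load_aoai_settings` are hypothetical. `get_secret` can be any callable, for example `lambda k: dbutils.secrets.get("aoai-scope", k)`.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass(frozen=True)
class AOAISettings:
    endpoint: str
    api_key: str
    api_version: str
    deployment: str


# Secret key names as used in the notebook's "aoai-scope" scope.
_SECRET_KEYS = {
    "endpoint": "openai-api-base",
    "api_key": "openai-api-key",
    "api_version": "openai-api-version",
    "deployment": "openai-deployment-name",
}


def load_aoai_settings(get_secret: Callable[[str], Optional[str]]) -> AOAISettings:
    """Fetch every setting via get_secret, strip whitespace, and fail fast on blanks."""
    values = {field: (get_secret(key) or "").strip() for field, key in _SECRET_KEYS.items()}
    blank = [key for field, key in _SECRET_KEYS.items() if not values[field]]
    if blank:
        raise ValueError(f"Empty or missing secrets: {blank}")
    return AOAISettings(**values)
```

The result maps directly onto the client: `AzureOpenAI(api_key=s.api_key, api_version=s.api_version, azure_endpoint=s.endpoint)`, with `s.deployment` passed as `model=`.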


In [0]:
from openai import AzureOpenAI
import json

# Pull secrets from the Databricks secret scope (Key Vault-backed).
openai_endpoint = dbutils.secrets.get("aoai-scope", "openai-api-base")
openai_key = dbutils.secrets.get("aoai-scope", "openai-api-key")
openai_api_version = dbutils.secrets.get("aoai-scope", "openai-api-version")
deployment_name = dbutils.secrets.get("aoai-scope", "openai-deployment-name")

# Create the Azure OpenAI client.
client = AzureOpenAI(
    api_key=openai_key,
    api_version=openai_api_version,
    azure_endpoint=openai_endpoint,
)


## <span style="color:#1f77b4">**Persist to Parquet + Delta**</span>

Persist the dataset into the UC volume and a Delta table so Vector Search can sync from Delta.


In [0]:
# Store a parquet copy in the UC volume for raw access.
parquet_path = f"dbfs:/Volumes/{catalog_name}/{schema_name}/{volume_leaf}/diabetes_faq.parquet"
df.write.mode("overwrite").parquet(parquet_path)

# Write a Delta table for downstream Vector Search.
df.write.format("delta").mode("overwrite").saveAsTable(table_name)


## <span style="color:#1f77b4">**Enable Change Data Feed**</span>

Enable Change Data Feed (CDF) so the delta sync index can track incremental updates to the table.


In [0]:
# Enable change data feed for the existing Delta table.
spark.sql(
    f"ALTER TABLE {table_name} "
    "SET TBLPROPERTIES (delta.enableChangeDataFeed = true)"
)


DataFrame[]

## <span style="color:#1f77b4">**Create or reuse the Vector Search index**</span>

Authenticate with the service principal, create or reuse the Vector Search endpoint, then build a delta sync index.
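Both wait loops below follow the same pattern: poll a state, return on success, raise on failure, and give up at a deadline. Here is a generic sketch of that pattern. The clock and sleep are injected so it can be tested without waiting. `poll_until` is a hypothetical helper. It uses `time.monotonic`, which is unaffected by wall-clock adjustments, where the notebook uses `time.time`.

```python
import time
from typing import Callable, Iterable, Optional


def poll_until(fetch_state: Callable[[], Optional[str]],
               done_states: Iterable[str],
               failed_states: Iterable[str],
               timeout_s: float = 900.0,
               poll_s: float = 10.0,
               clock: Callable[[], float] = time.monotonic,
               sleep: Callable[[float], None] = time.sleep) -> str:
    """Poll fetch_state until it reports a done state, a failed state, or the deadline passes."""
    done, failed = set(done_states), set(failed_states)
    deadline = clock() + timeout_s
    last = None
    while clock() < deadline:
        last = fetch_state()
        if last in done:
            return last
        if last in failed:
            raise RuntimeError(f"Entered failure state: {last}")
        sleep(poll_s)
    raise TimeoutError(f"Timed out after {timeout_s}s; last state={last}")
```

For the endpoint, `fetch_state` could wrap `get_endpoint_info_safe` and return `endpoint_status.state`, with `{"ONLINE", "READY"}` as the done states.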


In [0]:
# ============================================================
# Databricks Vector Search (OAuth SP): quota-safe endpoint reuse + index
# - Polls endpoint_status.state until the endpoint is ONLINE
# - Handles the list_endpoints() shape: {"endpoints":[{"endpoint_status":{"state":"ONLINE"}}...]}
# ============================================================

import os
import time
from azure.identity import ClientSecretCredential
from databricks.vector_search.client import VectorSearchClient

# ------------------------------------------------------------
# 0) OAuth env vars (notebook convenience only)
#    Serving should set these as env vars, not dbutils.
# ------------------------------------------------------------
os.environ.setdefault("DATABRICKS_AUTH_TYPE", "oauth")

if not os.getenv("DATABRICKS_HOST"):
    # spark.databricks.workspaceUrl often returns "eastus2-c3.azuredatabricks.net" without scheme
    os.environ["DATABRICKS_HOST"] = spark.conf.get("spark.databricks.workspaceUrl", "")

if not os.getenv("DATABRICKS_CLIENT_ID"):
    os.environ["DATABRICKS_CLIENT_ID"] = dbutils.secrets.get("dbx-sp-scope", "dbx-client-id")
if not os.getenv("DATABRICKS_CLIENT_SECRET"):
    os.environ["DATABRICKS_CLIENT_SECRET"] = dbutils.secrets.get("dbx-sp-scope", "dbx-client-secret")
if not os.getenv("DATABRICKS_TENANT_ID"):
    os.environ["DATABRICKS_TENANT_ID"] = dbutils.secrets.get("dbx-sp-scope", "dbx-tenant-id")

host = os.getenv("DATABRICKS_HOST") or ""
if host and not host.startswith("https://"):
    host = f"https://{host}"

# AAD token for Azure Databricks resource
credential = ClientSecretCredential(
    tenant_id=os.getenv("DATABRICKS_TENANT_ID"),
    client_id=os.getenv("DATABRICKS_CLIENT_ID"),
    client_secret=os.getenv("DATABRICKS_CLIENT_SECRET"),
)
token = credential.get_token("2ff814a6-3304-4ab8-85cb-cd0e6f879c1d/.default").token

def build_vector_client(host: str, token: str) -> VectorSearchClient:
    """
    Construct VectorSearchClient using an AAD token minted from the SP.
    """
    for kwargs in (
        {"workspace_url": host, "personal_access_token": token},
        {"host": host, "token": token},
        {"endpoint": host, "token": token},
    ):
        try:
            return VectorSearchClient(disable_notice=True, **kwargs)
        except TypeError:
            continue
    raise RuntimeError("VectorSearchClient init failed; check databricks-vectorsearch version.")

vector_client = build_vector_client(host, token)

# ------------------------------------------------------------
# Assumes these are defined earlier in your notebook:
#   table_name (e.g. "catalog.schema.table")
#   index_name (e.g. "catalog.schema.index")
# Optional:
#   endpoint_name (preferred endpoint)
# ------------------------------------------------------------
preferred_endpoint_name = globals().get("endpoint_name", "vector_search_endpoint")

# ------------------------------------------------------------
# Endpoint helpers
# ------------------------------------------------------------
def list_endpoints_safe(client) -> list:
    """
    Normalise list_endpoints() across versions:
    - returns {"endpoints":[...]} or {"vector_search_endpoints":[...]} or list
    """
    if not hasattr(client, "list_endpoints"):
        return []
    resp = client.list_endpoints()
    if isinstance(resp, dict):
        return resp.get("endpoints") or resp.get("vector_search_endpoints") or []
    return resp or []

def get_endpoint_info_safe(client, name: str):
    """
    Prefer get_endpoint(); otherwise search list_endpoints().
    """
    if hasattr(client, "get_endpoint"):
        try:
            info = client.get_endpoint(name)
            # Some versions wrap: {"endpoint": {...}}
            if isinstance(info, dict) and "endpoint" in info and isinstance(info["endpoint"], dict):
                return info["endpoint"]
            return info
        except Exception:
            pass

    for ep in list_endpoints_safe(client):
        if isinstance(ep, dict) and ep.get("name") == name:
            return ep
    return None

def ensure_endpoint(client, preferred_name: str) -> str:
    """
    Ensure an endpoint exists without exceeding quota.
    If quota exceeded, reuse existing endpoint.
    """
    # Reuse if already exists
    info = get_endpoint_info_safe(client, preferred_name)
    if info is not None:
        return preferred_name

    # If any endpoint exists, reuse the first
    eps = list_endpoints_safe(client)
    if eps:
        first = eps[0]
        return first.get("name") if isinstance(first, dict) else first

    # Otherwise create it
    try:
        client.create_endpoint(name=preferred_name, endpoint_type="STANDARD")
        return preferred_name
    except Exception as exc:
        msg = str(exc)
        if "QUOTA_EXCEEDED" in msg or "Maximum number of vector search endpoints" in msg:
            eps = list_endpoints_safe(client)
            if eps:
                first = eps[0]
                return first.get("name") if isinstance(first, dict) else first
        raise

def wait_for_endpoint_online(client, name: str, timeout_s: int = 900, poll_s: int = 10):
    """
    Wait until endpoint_status.state is ONLINE/READY.
    Handles the payload shape observed in this workspace:
    {"endpoints":[{"endpoint_status":{"state":"ONLINE"}}...]}
    """
    deadline = time.time() + timeout_s
    last = None
    while time.time() < deadline:
        info = get_endpoint_info_safe(client, name)
        last = info
        state = None
        if isinstance(info, dict):
            state = (info.get("endpoint_status") or {}).get("state")

        print(f"[{time.strftime('%H:%M:%S')}] endpoint={name} state={state}")

        if state in ("ONLINE", "READY"):
            return info
        if state in ("FAILED", "ERROR"):
            raise RuntimeError(f"Endpoint failure: {info}")

        time.sleep(poll_s)

    raise TimeoutError(f"Endpoint not online after {timeout_s}s. Last: {last}")

# ------------------------------------------------------------
# Index helpers (your index.describe() shape uses status.ready)
# ------------------------------------------------------------
def wait_for_index_ready(index, timeout_s: int = 1800, poll_s: int = 15):
    deadline = time.time() + timeout_s
    last = None
    while time.time() < deadline:
        info = index.describe()
        last = info
        status = info.get("status") or {}
        detailed_state = status.get("detailed_state")
        ready = status.get("ready")

        print(f"[{time.strftime('%H:%M:%S')}] index_state={detailed_state} ready={ready}")

        if ready is True:
            return info
        if isinstance(detailed_state, str) and ("FAILED" in detailed_state or "ERROR" in detailed_state):
            raise RuntimeError(f"Index failure: {info}")

        time.sleep(poll_s)

    raise TimeoutError(f"Index not ready after {timeout_s}s. Last: {last}")

# ------------------------------------------------------------
# 1) Get endpoint without exceeding quota
# ------------------------------------------------------------
endpoint_name = ensure_endpoint(vector_client, preferred_endpoint_name)
print(f"Using endpoint: {endpoint_name}")

# 2) Wait for endpoint (will return immediately if already ONLINE)
wait_for_endpoint_online(vector_client, endpoint_name)

# ------------------------------------------------------------
# 3) Get or create index (retry if provisioning)
# ------------------------------------------------------------

def get_or_create_index(client, endpoint, index_name, table_name):
    try:
        idx = client.get_index(endpoint, index_name)
        print("Index already exists.")
        return idx
    except Exception:
        pass

    try:
        print("Creating index...")
        return client.create_delta_sync_index(
            endpoint_name=endpoint,
            source_table_name=table_name,
            index_name=index_name,
            pipeline_type="TRIGGERED",
            primary_key="Topic",
            embedding_source_column="Description",
            embedding_model_endpoint_name="databricks-gte-large-en",
        )
    except Exception as exc:
        msg = str(exc).lower()
        if "not ready" in msg or "already exists" in msg:
            print("Index is provisioning; waiting for readiness.")
            for _ in range(12):
                try:
                    return client.get_index(endpoint, index_name)
                except Exception:
                    time.sleep(10)
        raise

index = get_or_create_index(vector_client, endpoint_name, index_name, table_name)

# 4) Trigger sync (optional, safe)
try:
    index.sync()
    print("index.sync() triggered.")
except Exception as exc:
    if "not supported" not in str(exc).lower():
        raise

# 5) Wait for readiness
final_info = wait_for_index_ready(index)

print("Index READY")
print(final_info["status"])


Using endpoint: vector_search_endpoint
[18:00:45] endpoint=vector_search_endpoint state=ONLINE
Index already exists.
index.sync() triggered.
[18:00:48] index_state=ONLINE_UPDATING_PIPELINE_RESOURCES ready=True
Index READY
{'detailed_state': 'ONLINE_UPDATING_PIPELINE_RESOURCES', 'message': 'Index is currently online, pipeline update is pending setup of pipeline resources. Check latest status: https://adb-7405608176797015.15.azuredatabricks.net/explore/data/adb_genai_super_locust/rag/diabetes_faq_index', 'indexed_row_count': 10, 'triggered_update_status': {'last_processed_commit_version': 98, 'last_processed_commit_timestamp': '2026-01-01T17:50:20Z'}, 'ready': True, 'index_url': '5c9c28c1-77e2-4a65-ab48-93a9e4076ad6.0.vector-search.azuredatabricks.net/api/2.0/7405608176797015/vector-search/indexes/adb_genai_super_locust.rag.diabetes_faq_index'}
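In the output above, `ready=True` while `detailed_state` is still `ONLINE_UPDATING_PIPELINE_RESOURCES`. The wait loop therefore returns as soon as the index is queryable, even though the sync that was just triggered may not have been applied yet. The small sketch below reports both facts from the status dict. `summarize_index_status` is hypothetical, and matching on the substring `UPDATING` is a heuristic assumption, not a documented rule.

```python
def summarize_index_status(status: dict) -> str:
    """One-line summary of an index.describe()["status"] dict."""
    state = status.get("detailed_state") or "UNKNOWN"
    if status.get("ready") is not True:
        return f"not ready ({state})"
    rows = status.get("indexed_row_count")
    version = (status.get("triggered_update_status") or {}).get("last_processed_commit_version")
    # Heuristic: a state containing UPDATING is queryable, but a pipeline update is pending.
    note = " (pipeline update pending)" if "UPDATING" in state else ""
    return f"queryable: {rows} rows indexed, last commit {version}, state={state}{note}"


status = {
    "detailed_state": "ONLINE_UPDATING_PIPELINE_RESOURCES",
    "indexed_row_count": 10,
    "triggered_update_status": {"last_processed_commit_version": 98},
    "ready": True,
}
print(summarize_index_status(status))
```

Another cheap check that the sync has caught up is to compare `indexed_row_count` with `spark.table(table_name).count()`.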


## <span style="color:#1f77b4">**Retrieve context from the index**</span>

Run a quick similarity search to confirm the index returns relevant content.
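The cell below converts the first row to a string with `str(...)`. That raises `IndexError` when nothing matches, and it leaves Python list syntax in the prompt. The sketch below maps rows to column names and formats them as plain text. Both helpers are hypothetical. The sketch assumes the response carries `manifest.columns` alongside `result.data_array`, with a trailing `score` column. That is the shape the Vector Search query API typically returns, but check it against your installed `databricks-vectorsearch` version. The sample dict is a made-up fixture.

```python
def rows_from_results(results: dict) -> list:
    """Zip each data_array row with the column names from the manifest."""
    columns = [c["name"] for c in (results.get("manifest") or {}).get("columns", [])]
    data = (results.get("result") or {}).get("data_array") or []
    return [dict(zip(columns, row)) for row in data]


def format_context(rows: list, max_chars: int = 4000) -> str:
    """Render rows as Q/A text for the prompt, capped at max_chars characters."""
    blocks = [f"Q: {r.get('Topic', '')}\nA: {r.get('Description', '')}" for r in rows]
    return "\n\n".join(blocks)[:max_chars]


sample = {
    "manifest": {"columns": [{"name": "Topic"}, {"name": "Description"}, {"name": "score"}]},
    "result": {"row_count": 1, "data_array": [["What is diabetes?", "A chronic condition.", 0.9]]},
}
rows = rows_from_results(sample)
print(format_context(rows))
```

An empty result yields an empty context string rather than an exception. The caller can then decide whether to answer without retrieved knowledge.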


In [0]:
# Example user query to validate retrieval.
user_question = "what is diabetes?"

# Fetch the nearest matching row from Vector Search.
results_dict = index.similarity_search(
    query_text=user_question,
    columns=["Topic", "Description"],
    num_results=1,
)

# Capture the retrieved content for downstream generation.
content = str(results_dict["result"]["data_array"][0])
print(content)


[NOTICE] Using a Personal Authentication Token (PAT). Recommended for development only. For improved performance, please use Service Principal based authentication. To disable this message, pass disable_notice=True.
['What is diabetes?', 'Diabetes is a chronic condition that affects how the body processes glucose (sugar). It occurs when the body cannot produce enough insulin or the insulin it produces is ineffective in regulating blood sugar. Insulin is a hormone produced by the pancreas that helps glucose enter the cells of the body for energy. Without sufficient insulin, glucose builds up in the bloodstream, leading to high blood sugar levels. Over time, uncontrolled diabetes can cause serious health complications such as heart disease, kidney damage, nerve damage, and vision problems. Proper management and treatment of diabetes are essential to preventing these complications and maintaining a good quality of life. Early detection, lifestyle changes, and medication are key factors in

## <span style="color:#1f77b4">**Enable MLflow tracing (optional)**</span>

Enable MLflow tracing so OpenAI calls are captured and stored under a known experiment path.


In [0]:
import os
import mlflow

# Toggle to enable/disable MLflow tracing for LLM calls.
ENABLE_MLFLOW_TRACING = True
EXPERIMENT_PATH = "/Shared/rag-traces"

# Ensure workspace host includes scheme for MLflow tracking.
host = spark.conf.get("spark.databricks.workspaceUrl", "")
if host and not host.startswith("https://"):
    host = f"https://{host}"
if host:
    os.environ["DATABRICKS_HOST"] = host

if ENABLE_MLFLOW_TRACING:
    try:
        mlflow.set_tracking_uri("databricks")
        mlflow.set_experiment(EXPERIMENT_PATH)
        try:
            mlflow.openai.autolog()
        except Exception:
            pass
    except Exception:
        pass
else:
    try:
        mlflow.tracing.disable()
    except Exception:
        pass


## <span style="color:#1f77b4">**Generate an answer with Azure OpenAI**</span>

Send the user query and retrieved context to Azure OpenAI and record a traced run in MLflow.
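Here is a sketch of the same two-message request with the context fenced by explicit delimiters, so the model can tell the question apart from the retrieved text. The system prompt is copied from the cell below. `build_messages` and the `<<<`/`>>>` markers are hypothetical choices.

```python
SYSTEM_PROMPT = (
    "You are a helpful assistant. You will be passed the user query and the "
    "supporting knowledge that can be used to answer the user_query"
)


def build_messages(user_query: str, supporting_knowledge: str) -> list:
    """Build a system + user message pair with the context clearly delimited."""
    user_content = (
        f"User query:\n{user_query}\n\n"
        f"Supporting knowledge:\n<<<\n{supporting_knowledge}\n>>>"
    )
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_content},
    ]


messages = build_messages("what is diabetes?", "Q: What is diabetes?\nA: A chronic condition.")
print(messages[1]["content"])
```

Pass the list to `client.chat.completions.create` as `messages=build_messages(user_question, content)`.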


In [0]:
import os
import traceback
import mlflow

# Ensure MLflow has a proper workspace host to talk to.
host = spark.conf.get("spark.databricks.workspaceUrl", "")
if host and not host.startswith("https://"):
    host = f"https://{host}"
if host:
    os.environ["DATABRICKS_HOST"] = host

# Route traces to the workspace MLflow backend.
mlflow.set_tracking_uri("databricks")
mlflow.set_experiment("/Shared/rag-traces")

# Log a run so traces and errors are easy to find.
with mlflow.start_run() as run:
    try:
        gpt_response = client.chat.completions.create(
            model=deployment_name,
            messages=[
                {"role": "system", "content": "You are a helpful assistant. You will be passed the user query and the supporting knowledge that can be used to answer the user_query"},
                {"role": "user", "content": f"user query : {user_question} and supporting knowledge: {content}"},
            ],
        )
        print(gpt_response.choices[0].message.content)
    except Exception as exc:
        mlflow.set_tag("run_error", str(exc))
        traceback.print_exc()
        raise


Diabetes is a chronic condition that affects how the body processes glucose (sugar). It happens when the body either doesn’t produce enough insulin or cannot use the insulin it makes effectively. Insulin, a hormone made by the pancreas, helps glucose enter the body’s cells to be used for energy. When insulin levels are low or not working properly, glucose builds up in the bloodstream, leading to high blood sugar levels. Over time, uncontrolled diabetes can cause serious health problems such as heart disease, kidney damage, nerve damage, and vision issues. Managing diabetes involves early detection, healthy lifestyle choices, and sometimes medication to keep blood sugar levels stable and prevent complications.


Trace(trace_id=tr-c6dd4e73e932205389ac43d0676bd56f)

## <span style="color:#1f77b4">**Define the RAG model (pyfunc)**</span>

Wrap retrieval + generation as a pyfunc model so it can be logged and served.
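`predict` below only answers `model_input["query"].iloc[0]`, so every other row in a batch request is dropped. Here is a sketch of a row-by-row variant, kept independent of MLflow and pandas. `answer_all` is a hypothetical helper that takes the retrieval and generation steps as callables.

```python
from typing import Callable, List


def answer_all(queries: List[str],
               retrieve: Callable[[str], str],
               generate: Callable[[str, str], str]) -> List[str]:
    """Retrieve context and generate one answer per query, preserving order."""
    answers = []
    for query in queries:
        context = retrieve(query)
        answers.append(generate(query, context))
    return answers


# Toy stand-ins for RAGModel.retrieve and RAGModel.chatCompletionsAPI.
print(answer_all(["q1", "q2"], lambda q: f"ctx-{q}", lambda q, c: f"{q}:{c}"))
```

Inside `predict`, this could become `pd.DataFrame({"answer": answer_all(model_input["query"].tolist(), self.retrieve, self.chatCompletionsAPI)})`.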


In [0]:
import mlflow
from mlflow import pyfunc
import pandas as pd


class RAGModel(pyfunc.PythonModel):
    def __init__(self, vector_index, openai_client, deployment_name):
        # Store the external dependencies used at inference time.
        self.vector_index = vector_index
        self.openai_client = openai_client
        self.deployment_name = deployment_name

    def retrieve(self, query):
        # Retrieve the best matching row from the vector index.
        results_dict = self.vector_index.similarity_search(
            query_text=query,
            columns=["Topic", "Description"],
            num_results=1,
        )
        return str(results_dict["result"]["data_array"][0])

    def chatCompletionsAPI(self, user_query, supporting_knowledge):
        # Send the query + context to the AOAI deployment.
        response = self.openai_client.chat.completions.create(
            model=self.deployment_name,
            messages=[
                {"role": "system", "content": "You are a helpful assistant. You will be passed the user query and the supporting knowledge that can be used to answer the user_query"},
                {"role": "user", "content": f"user query : {user_query} and supporting knowledge: {supporting_knowledge}"},
            ],
        )
        return response.choices[0].message.content

    def predict(self, context, model_input: pd.DataFrame) -> pd.DataFrame:
        # Extract the query, retrieve context, then generate the answer.
        query = model_input["query"].iloc[0]
        supporting_knowledge = self.retrieve(query)
        answer = self.chatCompletionsAPI(query, supporting_knowledge)
        return pd.DataFrame({"answer": [answer]})


## <span style="color:#1f77b4">**Instantiate the model**</span>

Create a local instance so we can test behavior before logging to MLflow.
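Before calling live services, you can exercise the wiring with stubs. They mimic only the attributes `RAGModel` touches: `similarity_search(...)["result"]["data_array"]` and `chat.completions.create(...).choices[0].message.content`. `FakeIndex`, `make_fake_client`, and `run_rag` are hypothetical. `run_rag` mirrors `retrieve` plus `chatCompletionsAPI` with a shortened system prompt, so the sketch runs without MLflow.

```python
from types import SimpleNamespace


class FakeIndex:
    """Stand-in for a Vector Search index that returns one fixed row."""

    def __init__(self, row):
        self.row = list(row)
        self.calls = []

    def similarity_search(self, query_text, columns, num_results):
        self.calls.append(query_text)
        return {"result": {"data_array": [self.row]}}


class _FakeCompletions:
    """Records the last request and echoes part of the user message."""

    def __init__(self):
        self.last_messages = None

    def create(self, model, messages):
        self.last_messages = messages
        message = SimpleNamespace(content=f"[{model}] {messages[-1]['content'][:20]}")
        return SimpleNamespace(choices=[SimpleNamespace(message=message)])


def make_fake_client():
    completions = _FakeCompletions()
    return SimpleNamespace(chat=SimpleNamespace(completions=completions)), completions


def run_rag(index, client, deployment, query):
    """Mirror of RAGModel.retrieve + chatCompletionsAPI for offline checks."""
    result = index.similarity_search(query_text=query, columns=["Topic", "Description"], num_results=1)
    context = str(result["result"]["data_array"][0])
    response = client.chat.completions.create(
        model=deployment,
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": f"user query : {query} and supporting knowledge: {context}"},
        ],
    )
    return response.choices[0].message.content


index_stub = FakeIndex(["What is diabetes?", "A chronic condition."])
client_stub, completions_stub = make_fake_client()
print(run_rag(index_stub, client_stub, "test-deployment", "what is diabetes?"))
```

You could pass the same stubs to `RAGModel(vector_index=FakeIndex(...), openai_client=fake_client, deployment_name="test")`.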


In [0]:
# Wire the model to the vector index and AOAI client.
test_model = RAGModel(vector_index=index, openai_client=client, deployment_name=deployment_name)


## <span style="color:#1f77b4">**Log a serving-safe model (Models-from-Code)**</span>

Write a serving-safe model script (no dbutils), log it to MLflow, and capture the resulting model URI.
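The model script below is built from an f-string, so every literal `{` or `}` it should contain must be doubled. A single `{host}` would be filled from the notebook's own `host` variable at render time. The sketch below shows that rule and checks the rendered text before logging. `render_script` and `normalise_host` are hypothetical. Using `!r` to quote the injected names is a design choice, not what the notebook does.

```python
import ast


def render_script(endpoint_name: str, index_name: str) -> str:
    """Render a tiny model-script template; doubled braces survive as literal braces."""
    return f'''
ENDPOINT_NAME = {endpoint_name!r}
INDEX_NAME = {index_name!r}


def normalise_host(host):
    if host and not host.startswith("https://"):
        host = f"https://{{host}}"
    return host
'''


rendered = render_script("vector_search_endpoint", "main.rag.diabetes_faq_index")
ast.parse(rendered)  # raises SyntaxError if the rendered text is not valid Python

# Execute in a throwaway namespace to check the substituted values and behavior.
namespace = {}
exec(rendered, namespace)
print(namespace["normalise_host"]("example.azuredatabricks.net"))  # https://example.azuredatabricks.net
```

Parsing alone only catches syntax errors. Executing the rendered text, as above, also lets you check the substituted values before the file is logged.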


In [0]:
# ============================================================
# Models-from-Code RAG model (SERVING-SAFE, OAuth)
# - Vector Search auth via Azure AD client credentials (DATABRICKS_CLIENT_ID/SECRET/TENANT_ID)
# - Azure OpenAI via AZURE_OPENAI_* env vars
# - dbutils only used OUTSIDE the model, for notebook convenience
# ============================================================

import os
import pandas as pd
import mlflow

# 0) (Notebook only) set env vars from secrets if not already set
if not os.getenv("AZURE_OPENAI_ENDPOINT"):
    os.environ["AZURE_OPENAI_ENDPOINT"] = dbutils.secrets.get("aoai-scope", "openai-api-base")
if not os.getenv("AZURE_OPENAI_API_KEY"):
    os.environ["AZURE_OPENAI_API_KEY"] = dbutils.secrets.get("aoai-scope", "openai-api-key")
if not os.getenv("AZURE_OPENAI_API_VERSION"):
    os.environ["AZURE_OPENAI_API_VERSION"] = dbutils.secrets.get("aoai-scope", "openai-api-version")

os.environ.setdefault("DATABRICKS_AUTH_TYPE", "oauth")
if not os.getenv("DATABRICKS_HOST"):
    os.environ["DATABRICKS_HOST"] = spark.conf.get("spark.databricks.workspaceUrl", "")
if not os.getenv("DATABRICKS_CLIENT_ID"):
    os.environ["DATABRICKS_CLIENT_ID"] = dbutils.secrets.get("dbx-sp-scope", "dbx-client-id")
if not os.getenv("DATABRICKS_CLIENT_SECRET"):
    os.environ["DATABRICKS_CLIENT_SECRET"] = dbutils.secrets.get("dbx-sp-scope", "dbx-client-secret")
if not os.getenv("DATABRICKS_TENANT_ID"):
    os.environ["DATABRICKS_TENANT_ID"] = dbutils.secrets.get("dbx-sp-scope", "dbx-tenant-id")

# Ensure the workspace host includes the scheme.
_host = os.environ.get("DATABRICKS_HOST", "")
if _host and not _host.startswith("https://"):
    os.environ["DATABRICKS_HOST"] = f"https://{_host}"

# MLflow OAuth path uses the Databricks SDK inside the served container.
os.environ.setdefault("MLFLOW_ENABLE_DB_SDK", "true")

print("AOAI env vars present:",
      bool(os.getenv("AZURE_OPENAI_ENDPOINT")),
      bool(os.getenv("AZURE_OPENAI_API_KEY")),
      bool(os.getenv("AZURE_OPENAI_API_VERSION")))

print("DBX OAuth env vars present:",
      bool(os.getenv("DATABRICKS_HOST")),
      bool(os.getenv("DATABRICKS_CLIENT_ID")),
      bool(os.getenv("DATABRICKS_CLIENT_SECRET")),
      bool(os.getenv("DATABRICKS_TENANT_ID")))

# -----------------------------
# 1) Config
# -----------------------------
ENDPOINT_NAME = globals().get("endpoint_name", "vector_search_endpoint")
INDEX_NAME = globals().get("index_name", "adb_genai_super_locust.rag.diabetes_faq_index")

try:
    DEPLOYMENT_NAME = deployment_name
except NameError:
    DEPLOYMENT_NAME = "YOUR_AZURE_OPENAI_DEPLOYMENT_NAME"

# -----------------------------
# 2) Write serving-safe model script (NO dbutils inside)
# -----------------------------
script_path = "/tmp/rag_model_from_code.py"

script = f'''
import os
import pandas as pd
from typing import Any
from mlflow.pyfunc import PythonModel
from mlflow.models import set_model
from databricks.vector_search.client import VectorSearchClient
from openai import AzureOpenAI
from azure.identity import ClientSecretCredential

ENDPOINT_NAME = "{ENDPOINT_NAME}"
INDEX_NAME = "{INDEX_NAME}"
DEPLOYMENT_NAME = "{DEPLOYMENT_NAME}"
RESOURCE_SCOPE = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d/.default"

class RAGModel(PythonModel):
    def __init__(self, top_k: int = 1):
        self.top_k = top_k

    def _build_vector_client(self) -> VectorSearchClient:
        host = os.getenv("DATABRICKS_HOST")
        client_id = os.getenv("DATABRICKS_CLIENT_ID")
        client_secret = os.getenv("DATABRICKS_CLIENT_SECRET")
        tenant_id = os.getenv("DATABRICKS_TENANT_ID")

        if not (host and client_id and client_secret and tenant_id):
            raise RuntimeError(
                "Missing Databricks OAuth env vars. "
                "Set DATABRICKS_HOST, DATABRICKS_CLIENT_ID, DATABRICKS_CLIENT_SECRET, DATABRICKS_TENANT_ID."
            )

        if host and not host.startswith("https://"):
            host = f"https://{host}"

        credential = ClientSecretCredential(
            tenant_id=tenant_id,
            client_id=client_id,
            client_secret=client_secret,
        )
        token = credential.get_token(RESOURCE_SCOPE).token

        for kwargs in (
            {{"workspace_url": host, "personal_access_token": token}},
            {{"host": host, "token": token}},
            {{"endpoint": host, "token": token}},
        ):
            try:
                return VectorSearchClient(disable_notice=True, **kwargs)
            except TypeError:
                continue
        raise RuntimeError("VectorSearchClient init failed; check databricks-vectorsearch version.")

    def load_context(self, context: Any) -> None:
        vsc = self._build_vector_client()
        self.index = vsc.get_index(ENDPOINT_NAME, INDEX_NAME)

        aoai_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
        aoai_key = os.getenv("AZURE_OPENAI_API_KEY")
        aoai_version = os.getenv("AZURE_OPENAI_API_VERSION")
        if not (aoai_endpoint and aoai_key and aoai_version):
            raise RuntimeError(
                "Missing Azure OpenAI env vars. "
                "Set AZURE_OPENAI_ENDPOINT / AZURE_OPENAI_API_KEY / AZURE_OPENAI_API_VERSION."
            )

        self.client = AzureOpenAI(
            api_key=aoai_key,
            api_version=aoai_version,
            azure_endpoint=aoai_endpoint,
        )

    def _retrieve(self, q: str) -> str:
        res = self.index.similarity_search(
            query_text=q,
            columns=["Topic", "Description"],
            num_results=self.top_k,
        )
        return str(res)

    def _chat(self, q: str, ctx: str) -> str:
        resp = self.client.chat.completions.create(
            model=DEPLOYMENT_NAME,
            messages=[
                {{"role": "system", "content": "Answer using the supporting knowledge."}},
                {{"role": "user", "content": "user query: " + q + "\\n" + "supporting knowledge: " + ctx}},
            ],
        )
        return resp.choices[0].message.content

    def predict(self, context: Any, model_input: pd.DataFrame) -> pd.DataFrame:
        queries = model_input["query"].astype(str).tolist()
        answers = []
        for q in queries:
            ctx = self._retrieve(q)
            answers.append(self._chat(q, ctx))
        return pd.DataFrame({{"answer": answers}})

set_model(RAGModel(top_k=1))
'''

with open(script_path, "w", encoding="utf-8") as f:
    f.write(script)

print(f"Wrote model script: {script_path}")

# -----------------------------
# 3) Log model + keep model_uri
# -----------------------------
input_example = pd.DataFrame([{"query": "what is diabetes?"}])

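# MLflow tracking auth: when a PAT is present alongside the OAuth client env vars,
# disable the Databricks SDK auth path while logging to avoid mixed-auth conflicts.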
_saved_mlflow_sdk = os.environ.get("MLFLOW_ENABLE_DB_SDK")
if os.environ.get("DATABRICKS_TOKEN"):
    os.environ["MLFLOW_ENABLE_DB_SDK"] = "false"
    _host = os.environ.get("DATABRICKS_HOST", "")
    if _host and not _host.startswith("https://"):
        os.environ["DATABRICKS_HOST"] = f"https://{_host}"

with mlflow.start_run() as run:
    model_info = mlflow.pyfunc.log_model(
        python_model=script_path,
        name="rag_model",
        input_example=input_example,
        pip_requirements=[
            "mlflow",
            "pandas",
            "openai",
            "databricks-vectorsearch",
            "azure-identity",
            "databricks-sdk",
        ],
    )
    model_uri = model_info.model_uri
    mlflow.log_text(model_uri, "model_uri.txt")

# Restore MLflow DB SDK setting after logging
if _saved_mlflow_sdk is None:
    os.environ.pop("MLFLOW_ENABLE_DB_SDK", None)
else:
    os.environ["MLFLOW_ENABLE_DB_SDK"] = _saved_mlflow_sdk

print("Logged model_uri:", model_uri)
print("Run ID:", run.info.run_id)


AOAI env vars present: True True True
DBX OAuth env vars present: True True True True
Wrote model script: /tmp/rag_model_from_code.py


🔗 View Logged Model at: https://adb-7405608176797015.15.azuredatabricks.net/ml/experiments/2081252236371713/models/m-ef6b45a931294119bb7022ca1cd91844?o=7405608176797015
2026/01/01 18:00:55 INFO mlflow.pyfunc: Inferring model signature from input example


[NOTICE] Using a Personal Authentication Token (PAT). Recommended for development only. For improved performance, please use Service Principal based authentication. To disable this message, pass disable_notice=True.


2026/01/01 18:02:52 INFO mlflow.models.model: Found the following environment variables used during model inference: [AZURE_OPENAI_API_KEY, DATABRICKS_CLIENT_ID, DATABRICKS_CLIENT_SECRET, ... ]. Please check if you need to set them when deploying the model. To disable this message, set environment variable `MLFLOW_RECORD_ENV_VARS_IN_MODEL_LOGGING` to `false`.


Logged model_uri: models:/m-ef6b45a931294119bb7022ca1cd91844
Run ID: 3f34a580803542a0ac187aa17618c5c5
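The logging cell saves `MLFLOW_ENABLE_DB_SDK`, overrides it while logging, and restores it by hand afterwards; if `log_model` raises, the restore step never runs. A minimal sketch of the same save/override/restore pattern as a context manager, so the previous value (including "not set") comes back even on error. `temp_env` is a hypothetical helper, not an MLflow API.

```python
import os
from contextlib import contextmanager
from typing import Iterator, Optional


@contextmanager
def temp_env(name: str, value: Optional[str]) -> Iterator[None]:
    """Temporarily set an environment variable (or unset it when value is None)."""
    saved = os.environ.get(name)
    try:
        if value is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = value
        yield
    finally:
        # Restore the previous state, including "was not set at all".
        if saved is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = saved


# Usage sketch mirroring the logging cell (assumes mlflow is imported):
# override = "false" if os.environ.get("DATABRICKS_TOKEN") else os.environ.get("MLFLOW_ENABLE_DB_SDK")
# with temp_env("MLFLOW_ENABLE_DB_SDK", override), mlflow.start_run() as run:
#     model_info = mlflow.pyfunc.log_model(...)
```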


## <span style="color:#1f77b4">**Load the logged model**</span>

Reload the logged model from MLflow so you can validate the inference path.
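Loading the model runs `load_context`, which requests a single Entra ID token and builds the `VectorSearchClient` with it. Entra access tokens are short-lived, so a long-running endpoint may eventually call Vector Search with an expired token. One option, sketched here under the assumption that you rebuild the client whenever the token string changes, is to cache the token and refetch it shortly before `expires_on`. `CachedToken` is a hypothetical helper; the `AccessToken` tuple only mirrors the `token`/`expires_on` fields that `azure.identity` credentials return from `get_token`.

```python
import time
from typing import Callable, NamedTuple, Optional


class AccessToken(NamedTuple):
    # Mirrors the fields of azure.core.credentials.AccessToken.
    token: str
    expires_on: int  # Unix timestamp in seconds


class CachedToken:
    """Cache a bearer token and refetch it when it is close to expiry."""

    def __init__(
        self,
        fetch: Callable[[], AccessToken],
        refresh_margin_s: int = 300,
        clock: Callable[[], float] = time.time,
    ) -> None:
        self._fetch = fetch
        self._margin = refresh_margin_s
        self._clock = clock
        self._current: Optional[AccessToken] = None

    def get(self) -> str:
        # Refetch when there is no token yet or it expires within the margin.
        if self._current is None or self._current.expires_on - self._clock() <= self._margin:
            self._current = self._fetch()
        return self._current.token


# Usage sketch inside the model (assumption):
# tokens = CachedToken(lambda: credential.get_token(RESOURCE_SCOPE))
# token = tokens.get()  # rebuild VectorSearchClient if this differs from the last token
```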


In [0]:
# ============================================================
# LOAD THE LATEST LOGGED MODEL (from the previous cell)
# ============================================================

import mlflow
import pandas as pd

# This variable is created in the previous cell:
#   model_uri = model_info.model_uri
if "model_uri" not in globals():
    raise RuntimeError(
        "model_uri is not defined. Run the previous 'log model' cell first."
    )

# Load the pyfunc model from MLflow.
loaded_pyfunc_model = mlflow.pyfunc.load_model(model_uri)
print("Loaded model from:", model_uri)



✅ Loaded model from: models:/m-ef6b45a931294119bb7022ca1cd91844
[NOTICE] Using a Personal Authentication Token (PAT). Recommended for development only. For improved performance, please use Service Principal based authentication. To disable this message, pass disable_notice=True.
                                              answer
0  Diabetes is a chronic condition that affects h...


Trace(trace_id=tr-dd7baf1ec995ea8cfbcb5048f7dd93f6)

## <span style="color:#1f77b4">**Test the loaded model**</span>

Run a sample query against the loaded pyfunc model to confirm end-to-end behavior.
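Before calling live services, you can exercise the same retrieve-then-generate loop that `predict` runs, using stand-in functions. A minimal sketch, assuming hypothetical `retrieve`/`chat` callables in place of `_retrieve` and `_chat`; it shows that each query yields exactly one answer, in input order.

```python
from typing import Callable, List


def answer_queries(
    queries: List[str],
    retrieve: Callable[[str], str],
    chat: Callable[[str, str], str],
) -> List[str]:
    """Same loop as RAGModel.predict: retrieve context, then generate one answer per query."""
    answers = []
    for q in queries:
        ctx = retrieve(q)
        answers.append(chat(q, ctx))
    return answers


def fake_retrieve(q: str) -> str:
    # Stand-in for Vector Search: a tiny in-memory lookup.
    kb = {"what is diabetes?": "Topic: Diabetes"}
    return kb.get(q, "")


def fake_chat(q: str, ctx: str) -> str:
    # Stand-in for Azure OpenAI: echoes the query and the context it received.
    return f"{q} -> [{ctx}]"


print(answer_queries(["what is diabetes?", "unknown?"], fake_retrieve, fake_chat))
# ['what is diabetes? -> [Topic: Diabetes]', 'unknown? -> []']
```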


In [0]:
# Build a small test payload.
model_input = pd.DataFrame([{"query": "what is diabetes?"}])

# Execute the model's predict path.
model_response = loaded_pyfunc_model.predict(model_input)

print(model_response)


[NOTICE] Using a Personal Authentication Token (PAT). Recommended for development only. For improved performance, please use Service Principal based authentication. To disable this message, pass disable_notice=True.
                                              answer
0  Diabetes is a chronic condition that affects h...


Trace(trace_id=tr-82c19d8fa676f3af36ebc121479e1a5e)
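`_retrieve` passes `str(res)` into the prompt, so the model also sees response metadata such as the column manifest and scores. A minimal sketch that keeps only the row values, assuming the `similarity_search` response is a dict with `manifest.columns` (a list of `{"name": ...}`) and `result.data_array` (one list per hit, with a similarity column named `score` appended); check the shape returned by your `databricks-vectorsearch` version before relying on it. `format_context` is a hypothetical helper.

```python
from typing import Any, Dict


def format_context(res: Dict[str, Any]) -> str:
    """Render Vector Search hits as 'Column: value' lines, one block per hit."""
    columns = [c["name"] for c in res.get("manifest", {}).get("columns", [])]
    rows = res.get("result", {}).get("data_array", []) or []
    blocks = []
    for row in rows:
        # Pair each value with its column name, dropping the similarity score.
        pairs = [f"{name}: {value}" for name, value in zip(columns, row) if name != "score"]
        blocks.append("\n".join(pairs))
    return "\n\n".join(blocks)


# Example in the assumed shape (illustrative values, not real output).
sample = {
    "manifest": {"columns": [{"name": "Topic"}, {"name": "Description"}, {"name": "score"}]},
    "result": {"data_array": [["Diabetes", "A chronic condition ...", 0.0123]]},
}
print(format_context(sample))
```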

## <span style="color:#1f77b4">**Register the model**</span>

Register the model in the workspace registry so the serving endpoint can reference a version.
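After registration, the model can be addressed by name and version instead of the `models:/m-...` ID logged earlier. A small sketch, assuming the workspace registry selected by `mlflow.set_registry_uri("databricks")` and the standard `models:/<name>/<version>` URI form; `registered_model_uri` is a hypothetical helper.

```python
from typing import Union


def registered_model_uri(name: str, version: Union[int, str]) -> str:
    """Build an MLflow model URI that pins one registered model version."""
    if not name:
        raise ValueError("model name must be non-empty")
    return f"models:/{name}/{version}"


uri = registered_model_uri("rag_model", 6)
print(uri)
# With MLflow available: mlflow.pyfunc.load_model(registered_model_uri(registered.name, registered.version))
```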


In [0]:
import mlflow

# Use the workspace registry (not Unity Catalog) for this flow.
mlflow.set_registry_uri("databricks")
registered = mlflow.register_model(model_uri=model_uri, name="rag_model")

print("Registered:", registered.name, "v", registered.version)


Registered model 'rag_model' already exists. Creating a new version of this model...
2026/01/01 18:04:56 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: rag_model, version 6


✅ Registered: rag_model v 6


Created version '6' of model 'rag_model'.


## <span style="color:#1f77b4">**Call the serving endpoint**</span>

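Call the serving endpoint with a Databricks PAT stored in a Key Vault-backed secret scope to validate real-time inference. This assumes the endpoint (`rag-model-endpoint-otter`) already serves the registered model version.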
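`score_model` hard-codes the workspace URL. A sketch that derives the invocation URL from `DATABRICKS_HOST` and an endpoint name instead, normalizing the scheme like the logging cell does; unlike the `startswith("https://")` check, it also handles an `http://` prefix or a trailing slash. `normalize_host` and `invocations_url` are hypothetical helpers.

```python
import os


def normalize_host(host: str) -> str:
    """Return the host as 'https://<name>' with no trailing slash."""
    host = host.strip().rstrip("/")
    for prefix in ("https://", "http://"):
        if host.startswith(prefix):
            host = host[len(prefix):]
            break
    if not host:
        raise ValueError("empty Databricks host")
    return f"https://{host}"


def invocations_url(host: str, endpoint_name: str) -> str:
    """REST path used to query a Model Serving endpoint."""
    return f"{normalize_host(host)}/serving-endpoints/{endpoint_name}/invocations"


url = invocations_url(
    os.environ.get("DATABRICKS_HOST", "adb-7405608176797015.15.azuredatabricks.net"),
    "rag-model-endpoint-otter",
)
print(url)
```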


In [0]:
import os
import json
import requests
import pandas as pd

# Ensure PAT is in env (from your Key Vault-backed secret scope).
if not os.getenv("DATABRICKS_TOKEN"):
    os.environ["DATABRICKS_TOKEN"] = dbutils.secrets.get("aoai-scope", "databricks-pat")


def create_tf_serving_json(data):
    # Tensor-style payload for non-DataFrame inputs: dict of arrays -> {"inputs": {...}}, array -> list.
    return {"inputs": {name: data[name].tolist() for name in data.keys()}} if isinstance(data, dict) else data.tolist()


def score_model(dataset):
    # Invocation URL of the serving endpoint (workspace host + endpoint name).
    url = "https://adb-7405608176797015.15.azuredatabricks.net/serving-endpoints/rag-model-endpoint-otter/invocations"
    headers = {
        "Authorization": f"Bearer {os.environ.get('DATABRICKS_TOKEN')}",
        "Content-Type": "application/json",
    }
    # DataFrames are sent in MLflow's "dataframe_split" format (index, columns, data).
    ds_dict = {"dataframe_split": dataset.to_dict(orient="split")} if isinstance(dataset, pd.DataFrame) else create_tf_serving_json(dataset)
    data_json = json.dumps(ds_dict, allow_nan=True)
    # A timeout keeps the cell from hanging if the endpoint is unreachable.
    response = requests.post(url, headers=headers, data=data_json, timeout=120)
    if response.status_code != 200:
        raise Exception(f"Request failed with status {response.status_code}, {response.text}")
    return response.json()

# Example call payload.
payload = pd.DataFrame([{"query": "what is diabetes?"}])
score_model(payload)


{'predictions': [{'answer': 'Diabetes is a chronic condition that affects how the body processes glucose (sugar). It develops when the body either doesn’t produce enough insulin or cannot use insulin effectively. Insulin is a hormone made by the pancreas that allows glucose to enter cells to be used for energy. When insulin levels are low or the body becomes resistant to insulin, glucose builds up in the blood, leading to high blood sugar levels.  \n\nUncontrolled diabetes can cause serious complications over time, including heart disease, kidney damage, nerve problems, and vision impairment. Managing diabetes effectively involves early detection, healthy lifestyle habits (such as proper diet and regular exercise), and, in some cases, medication. These steps help keep blood sugar levels stable and reduce the risk of long-term complications.'}]}
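For reference, the request body that `score_model` sends for a one-row DataFrame, and the response shape shown above, can be reproduced with plain `json`. A sketch assuming the `dataframe_split` layout produced by `DataFrame.to_dict(orient="split")` (`index`, `columns`, `data`) and a `{"predictions": [{"answer": ...}]}` response as in the output above; `build_payload` and `extract_answers` are hypothetical helpers.

```python
import json
from typing import Any, Dict, List


def build_payload(queries: List[str]) -> str:
    """JSON body equivalent to {"dataframe_split": df.to_dict(orient="split")} for a 'query' column."""
    split = {
        "index": list(range(len(queries))),
        "columns": ["query"],
        "data": [[q] for q in queries],
    }
    return json.dumps({"dataframe_split": split})


def extract_answers(response: Dict[str, Any]) -> List[str]:
    """Pull the 'answer' field out of each prediction row."""
    return [row["answer"] for row in response.get("predictions", [])]


body = build_payload(["what is diabetes?"])
print(body)
# {"dataframe_split": {"index": [0], "columns": ["query"], "data": [["what is diabetes?"]]}}
```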