# GenAI with Azure Databricks - Developing a RAG System

### Loading the CSV file into the DBFS (Databricks File System)

In [0]:
# Configure Unity Catalog widgets and resolve the active catalog.

dbutils.widgets.removeAll()
dbutils.widgets.text("CATALOG", "")
dbutils.widgets.text("SCHEMA", "rag")
dbutils.widgets.text("VOLUME", "raw")
dbutils.widgets.text("EXTERNAL_LOCATION", "uc-external-location")

catalog_widget = dbutils.widgets.get("CATALOG")
if catalog_widget:
    catalog_name = catalog_widget
else:
    current = spark.sql("SELECT current_catalog()").first()[0]
    catalogs = [r.catalog for r in spark.sql("SHOW CATALOGS").collect()]
    catalog_name = current if current not in ("system",) else next(c for c in catalogs if c not in ("system",))

schema_name = dbutils.widgets.get("SCHEMA")
volume_leaf = dbutils.widgets.get("VOLUME")
external_location_name = dbutils.widgets.get("EXTERNAL_LOCATION")

table_name = f"{catalog_name}.{schema_name}.diabetes_faq_table"
index_name = f"{catalog_name}.{schema_name}.diabetes_faq_index"
volume_name = f"{catalog_name}.{schema_name}.{volume_leaf}"

spark.sql(f"CREATE SCHEMA IF NOT EXISTS {catalog_name}.{schema_name}")

location_rows = spark.sql(f"DESCRIBE EXTERNAL LOCATION `{external_location_name}`")
if "url" in location_rows.columns:
    external_url = location_rows.select("url").first()["url"].rstrip("/")
else:
    external_url = location_rows.filter("key = \"url\"").select("value").first()["value"].rstrip("/")

spark.sql(
    f"""CREATE EXTERNAL VOLUME IF NOT EXISTS {volume_name}
    LOCATION '{external_url}'
    """
)

data_path = f"dbfs:/Volumes/{catalog_name}/{schema_name}/{volume_leaf}/diabetes_treatment_faq.csv"
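
The path above follows a `catalog/schema/volume/file` layout under `/Volumes`. A minimal sketch of a helper that assembles such a path and rejects empty or malformed segments is shown here. The name `volume_file_path` is hypothetical (it is not part of any Databricks API), and `my_catalog` is a placeholder value.

```python
def volume_file_path(catalog: str, schema: str, volume: str, filename: str) -> str:
    """Build a Unity Catalog volume path; hypothetical helper for illustration."""
    parts = [catalog, schema, volume, filename]
    # Reject empty segments and embedded slashes so a bad widget value
    # fails here rather than producing a silently wrong path.
    for part in parts:
        if not part or "/" in part:
            raise ValueError(f"invalid path segment: {part!r}")
    return "dbfs:/Volumes/" + "/".join(parts)


example_path = volume_file_path("my_catalog", "rag", "raw", "diabetes_treatment_faq.csv")
print(example_path)
```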


### Loading the CSV file into a DataFrame

In [0]:
# Namespaced import avoids shadowing Python built-ins such as sum() and max()
from pyspark.sql import functions as F

df = spark.read.csv(data_path, header=True)
display(df.limit(10))
df.printSchema()


Topic,Description
What is diabetes?,"Diabetes is a chronic condition that affects how the body processes glucose (sugar). It occurs when the body cannot produce enough insulin or the insulin it produces is ineffective in regulating blood sugar. Insulin is a hormone produced by the pancreas that helps glucose enter the cells of the body for energy. Without sufficient insulin, glucose builds up in the bloodstream, leading to high blood sugar levels. Over time, uncontrolled diabetes can cause serious health complications such as heart disease, kidney damage, nerve damage, and vision problems. Proper management and treatment of diabetes are essential to preventing these complications and maintaining a good quality of life. Early detection, lifestyle changes, and medication are key factors in effectively managing the disease."
What are the different types of diabetes?,"Diabetes is categorized into two main types: Type 1 and Type 2. Type 1 diabetes is an autoimmune condition where the body’s immune system attacks and destroys the insulin-producing cells in the pancreas, leading to little or no insulin production. It typically develops in children or young adults and requires lifelong insulin therapy. Type 2 diabetes, on the other hand, occurs when the body becomes resistant to insulin or does not produce enough insulin to meet the body’s needs. It is more common in adults, particularly those who are overweight, inactive, or have a family history of the disease. While Type 1 is not preventable, Type 2 can often be prevented or delayed through lifestyle changes, including diet and exercise."
What are the symptoms of diabetes?,"The symptoms of diabetes can vary depending on the type and how long the condition has been present. Common signs include frequent urination, excessive thirst, hunger, and unexplained weight loss. Some people may experience blurred vision, fatigue, and slow-healing wounds. In the case of Type 1 diabetes, symptoms often develop rapidly, while Type 2 diabetes symptoms may be more subtle and develop over time. Because the early symptoms may not always be noticeable, it is important to get regular check-ups, especially if you are at risk for diabetes. Uncontrolled diabetes can lead to serious complications, so timely diagnosis and treatment are essential."
How is diabetes diagnosed?,"Diabetes is diagnosed through various blood tests. The fasting blood glucose test measures blood sugar levels after an overnight fast, while the oral glucose tolerance test checks how well the body processes sugar after consuming a sugary drink. The HbA1c test, which reflects the average blood sugar levels over the past 2-3 months, is also commonly used to diagnose and monitor diabetes. An HbA1c level of 6.5% or higher is typically indicative of diabetes. A diagnosis may also involve checking for other conditions associated with diabetes, such as high blood pressure or cholesterol imbalances. Early detection allows for better management and prevention of complications."
What is the role of insulin in diabetes?,"Insulin is a hormone produced by the pancreas that helps regulate blood sugar levels by allowing glucose to enter cells for energy. In people with diabetes, either the body does not produce enough insulin (Type 1 diabetes) or the body’s cells do not respond effectively to insulin (Type 2 diabetes). As a result, glucose accumulates in the bloodstream, leading to high blood sugar. Insulin therapy, typically in the form of injections or an insulin pump, helps to lower blood sugar levels and mimic the body’s natural insulin production. Insulin is a crucial part of managing diabetes, particularly for those with Type 1, and can also be used in Type 2 when lifestyle changes and oral medications are not sufficient."
What are the treatment options for type 1 diabetes?,"For people with Type 1 diabetes, treatment primarily involves lifelong insulin therapy to manage blood sugar levels. Insulin can be administered through injections or via an insulin pump. In addition to insulin, individuals with Type 1 diabetes must closely monitor their blood sugar levels throughout the day, maintain a healthy diet, and engage in regular physical activity to help manage their condition. People with Type 1 diabetes also need to be vigilant about potential complications, such as diabetic ketoacidosis, and should regularly consult with their healthcare providers to adjust their treatment plan as needed."
What are the treatment options for type 2 diabetes?,"For Type 2 diabetes, treatment typically starts with lifestyle modifications such as a balanced diet, regular exercise, and weight management. If lifestyle changes are not sufficient, oral medications such as metformin may be prescribed to help regulate blood sugar levels. In some cases, individuals with Type 2 diabetes may require insulin therapy or other injectable medications to improve insulin sensitivity. For those with severe Type 2 diabetes or obesity, bariatric surgery may be considered to help improve blood sugar control. The goal of treatment is to achieve normal blood sugar levels and prevent complications such as heart disease, kidney damage, or nerve damage."
How can I manage my blood sugar levels?,"Managing blood sugar levels is essential for diabetes control. The key factors in blood sugar management include regular monitoring, taking prescribed medications, following a healthy diet, exercising regularly, and managing stress. Monitoring your blood sugar levels allows you to understand how food, exercise, and medications affect your glucose. A balanced diet rich in fiber, lean proteins, and healthy fats can help regulate blood sugar. Exercise increases insulin sensitivity, which helps control blood sugar levels. Stress management techniques, such as yoga and relaxation exercises, are also important for maintaining stable blood sugar. Regular doctor visits are essential to adjust treatment plans as needed."
What is a healthy diet for someone with diabetes?,"A healthy diet for people with diabetes focuses on foods that help regulate blood sugar levels and promote overall health. This includes consuming whole grains, vegetables, lean proteins, and healthy fats while limiting foods with a high glycemic index, such as refined sugars and processed foods. Eating smaller, more frequent meals can also help maintain steady blood sugar levels throughout the day. Carbohydrate counting is a common practice for those with diabetes, as carbohydrates have a direct impact on blood sugar. It is important to work with a dietitian to develop a meal plan that meets individual needs and preferences while promoting blood sugar control."
How often should I check my blood sugar levels?,"Blood sugar monitoring is essential for people with diabetes to understand how their body responds to different foods, activities, and medications. The frequency of testing depends on the type of diabetes and treatment plan. For individuals with Type 1 diabetes or those on insulin, blood sugar may need to be tested multiple times per day, including before and after meals or exercise. For Type 2 diabetes, the frequency of testing may vary based on medication use and blood sugar control. A healthcare provider can provide personalized recommendations for how often to check blood sugar levels, ensuring that levels remain within the target range to prevent complications."


root
 |-- Topic: string (nullable = true)
 |-- Description: string (nullable = true)
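
Every `Description` value contains commas, so the file depends on double-quoted fields being parsed correctly. As a rough illustration outside Spark, Python's standard `csv` module applies the same quoting rule. The two-row sample below is abbreviated from the data above (one sentence per description), purely for demonstration.

```python
import csv
import io

sample = '''Topic,Description
What is diabetes?,"Diabetes is a chronic condition that affects how the body processes glucose (sugar)."
What are the symptoms of diabetes?,"Common signs include frequent urination, excessive thirst, hunger, and unexplained weight loss."
'''

rows = list(csv.DictReader(io.StringIO(sample)))
for row in rows:
    # Commas inside the quoted Description do not split the field.
    print(row["Topic"], "->", len(row["Description"]), "characters")
```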



### Installing the OpenAI SDK in our Python kernel

In [0]:
# Libraries are installed on the cluster via Terraform.
# - openai
# - databricks-vectorsearch


### Restarting our Python kernel

In [0]:
# No restart needed when cluster libraries are pre-installed.


### Creating an Azure OpenAI Client


In [0]:
from openai import AzureOpenAI
import json

openai_endpoint = dbutils.secrets.get("aoai-scope", "openai-api-base")
openai_key = dbutils.secrets.get("aoai-scope", "openai-api-key")
openai_api_version = dbutils.secrets.get("aoai-scope", "openai-api-version")
deployment_name = dbutils.secrets.get("aoai-scope", "openai-deployment-name")

client = AzureOpenAI(
    api_key=openai_key,
    api_version=openai_api_version,
    azure_endpoint=openai_endpoint,
)


### Saving the DataFrame to ADLS as Parquet and as a Delta table

In [0]:
# Save the DataFrame as Parquet files in the volume and as a Delta table.
# The volume name comes from the VOLUME widget rather than a hard-coded "raw".
parquet_path = f"dbfs:/Volumes/{catalog_name}/{schema_name}/{volume_leaf}/diabetes_faq.parquet"
df.write.mode("overwrite").parquet(parquet_path)
df.write.format("delta").mode("overwrite").saveAsTable(table_name)


### Installing the Databricks Vector Search SDK

In [0]:
# Libraries are installed on the cluster via Terraform.
# - openai
# - databricks-vectorsearch


### Restarting our Python environment

In [0]:
# No restart needed when cluster libraries are pre-installed.


### Enabling Change Data Feed on Our Table

In [0]:
# Enable change data feed for the existing Delta table
spark.sql(
    f"""
ALTER TABLE {table_name}
SET TBLPROPERTIES (delta.enableChangeDataFeed = true)
    """
)


DataFrame[]

### Developing the Cluster managed Vector index

In [0]:
# ============================================================
# Databricks Vector Search: quota-safe endpoint reuse + index
# ============================================================

from databricks.vector_search.client import VectorSearchClient
import time

vector_client = VectorSearchClient()  # optionally: VectorSearchClient(disable_notice=True)

# ------------------------------------------------------------
# Assumes these are defined earlier in your notebook:
#   table_name (e.g. "catalog.schema.table")
#   index_name (e.g. "catalog.schema.index")
# Optional:
#   endpoint_name (preferred endpoint)
# ------------------------------------------------------------

preferred_endpoint_name = globals().get("endpoint_name", "vector_search_endpoint")

# ------------------------------------------------------------
# Endpoint helpers
# ------------------------------------------------------------

def list_endpoints_safe(client):
    if not hasattr(client, "list_endpoints"):
        return []
    eps = client.list_endpoints()
    if isinstance(eps, dict):
        eps = eps.get("endpoints") or eps.get("vector_search_endpoints") or []
    return eps or []

def get_endpoint_info_safe(client, name):
    if hasattr(client, "get_endpoint"):
        try:
            return client.get_endpoint(name)
        except Exception:
            return None
    # fallback to list
    for ep in list_endpoints_safe(client):
        if isinstance(ep, dict) and ep.get("name") == name:
            return ep
    return None

def pick_endpoint_name(client, preferred):
    """
    Return an endpoint to use:
    1) preferred if it exists
    2) otherwise the first existing endpoint
    3) otherwise None
    """
    if get_endpoint_info_safe(client, preferred) is not None:
        return preferred

    eps = list_endpoints_safe(client)
    if eps:
        first = eps[0]
        return first.get("name") if isinstance(first, dict) else first

    return None

def ensure_endpoint(client, preferred_name):
    """
    Ensure we have an endpoint name to use, respecting quota.
    If quota exceeded, reuse an existing endpoint.
    """
    # If it already exists, reuse it
    existing = pick_endpoint_name(client, preferred_name)
    if existing:
        return existing

    # Otherwise attempt to create it
    try:
        client.create_endpoint(name=preferred_name, endpoint_type="STANDARD")
        return preferred_name
    except Exception as exc:
        msg = str(exc)
        # Quota hit: reuse any existing endpoint
        if "QUOTA_EXCEEDED" in msg or "Maximum number of vector search endpoints" in msg:
            existing = pick_endpoint_name(client, preferred_name)
            if existing:
                return existing
        # Anything else is real
        raise

def wait_for_endpoint_online(client, name, timeout_s=1800, poll_s=15):
    deadline = time.time() + timeout_s
    last = None
    while time.time() < deadline:
        info = get_endpoint_info_safe(client, name)
        last = info
        state = None
        if isinstance(info, dict):
            st = info.get("endpoint_status") or info.get("status") or {}
            state = st.get("state") if isinstance(st, dict) else st
        print(f"[{time.strftime('%H:%M:%S')}] endpoint={name} state={state}")
        if state in ("ONLINE", "READY"):
            return info
        if state in ("FAILED", "ERROR"):
            raise RuntimeError(f"Endpoint failure: {info}")
        time.sleep(poll_s)
    raise TimeoutError(f"Endpoint not online after {timeout_s}s. Last: {last}")

# ------------------------------------------------------------
# Index helpers (matches your describe() payload)
# ------------------------------------------------------------

def wait_for_index_ready(index, timeout_s=1800, poll_s=15):
    deadline = time.time() + timeout_s
    last = None
    while time.time() < deadline:
        info = index.describe()
        last = info
        status = info.get("status") or {}
        detailed_state = status.get("detailed_state")
        ready = status.get("ready")
        print(f"[{time.strftime('%H:%M:%S')}] index_state={detailed_state} ready={ready}")
        if ready is True:
            return info
        if isinstance(detailed_state, str) and ("FAILED" in detailed_state or "ERROR" in detailed_state):
            raise RuntimeError(f"Index failure: {info}")
        time.sleep(poll_s)
    raise TimeoutError(f"Index not ready after {timeout_s}s. Last: {last}")

# ------------------------------------------------------------
# 1) Get endpoint without exceeding quota
# ------------------------------------------------------------

endpoint_name = ensure_endpoint(vector_client, preferred_endpoint_name)
print(f"✅ Using endpoint: {endpoint_name}")

# 2) Wait for endpoint
wait_for_endpoint_online(vector_client, endpoint_name)

# ------------------------------------------------------------
# 3) Get or create index
# ------------------------------------------------------------

try:
    index = vector_client.get_index(endpoint_name, index_name)
    print("ℹ️ Index already exists.")
except Exception:
    print("ℹ️ Creating index...")
    index = vector_client.create_delta_sync_index(
        endpoint_name=endpoint_name,
        source_table_name=table_name,
        index_name=index_name,
        pipeline_type="TRIGGERED",
        primary_key="Topic",
        embedding_source_column="Description",
        embedding_model_endpoint_name="databricks-gte-large-en",
    )

# 4) Trigger sync (optional, safe)
try:
    index.sync()
    print("ℹ️ index.sync() triggered.")
except Exception as exc:
    if "not supported" not in str(exc).lower():
        raise

# 5) Wait for readiness
final_info = wait_for_index_ready(index)

print("\n✅ Index READY")
print(final_info["status"])


[NOTICE] Using a notebook authentication token. Recommended for development only. For improved performance, please use Service Principal based authentication. To disable this message, pass disable_notice=True.
✅ Using endpoint: vector_search_endpoint
[20:10:12] endpoint=vector_search_endpoint state=ONLINE
ℹ️ Index already exists.
ℹ️ index.sync() triggered.
[20:10:14] index_state=ONLINE_UPDATING_PIPELINE_RESOURCES ready=True

✅ Index READY
{'detailed_state': 'ONLINE_UPDATING_PIPELINE_RESOURCES', 'message': 'Index is currently online, pipeline update is pending setup of pipeline resources. Check latest status: https://adb-7405608176797015.15.azuredatabricks.net/explore/data/adb_genai_super_locust/rag/diabetes_faq_index', 'indexed_row_count': 10, 'triggered_update_status': {'last_processed_commit_version': 14, 'last_processed_commit_timestamp': '2025-12-31T20:09:02Z'}, 'ready': True, 'index_url': 'adb-7405608176797015.15.azuredatabricks.net/api/2.0/vector-search/indexes/adb_genai_super_lo
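
The endpoint and index waits above share one pattern: fetch a status, test it, sleep, and give up at a deadline. A minimal sketch of that pattern as a single reusable function follows. `poll_until` and its `clock`/`sleep` parameters are hypothetical names introduced here so the loop can be exercised without a live workspace, and the simulated payloads only imitate the `status.ready` field seen in the `describe()` output.

```python
import time
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def poll_until(
    fetch: Callable[[], T],
    is_done: Callable[[T], bool],
    timeout_s: float = 1800,
    poll_s: float = 15,
    clock: Callable[[], float] = time.monotonic,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call fetch() until is_done(result) is true or timeout_s elapses."""
    deadline = clock() + timeout_s
    last: Optional[T] = None
    while clock() < deadline:
        last = fetch()
        if is_done(last):
            return last
        sleep(poll_s)
    raise TimeoutError(f"Not done after {timeout_s}s. Last: {last}")


# Simulated status payloads; a real caller would pass index.describe as fetch.
fake_states = iter([
    {"status": {"ready": False}},
    {"status": {"ready": False}},
    {"status": {"ready": True}},
])
result = poll_until(
    fetch=lambda: next(fake_states),
    is_done=lambda info: info["status"]["ready"] is True,
    poll_s=0,
)
print(result)
```

Injecting `clock` and `sleep` is a design choice for testability; in a notebook the defaults are enough.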

### Triggering our Vector Index - Information Retriever

In [0]:
user_question = "what is diabetes?"

results_dict = index.similarity_search(
    query_text=user_question,
    columns=["Topic", "Description"],
    num_results=1,
)

content = str(results_dict["result"]["data_array"][0])
print(content)


[NOTICE] Using a notebook authentication token. Recommended for development only. For improved performance, please use Service Principal based authentication. To disable this message, pass disable_notice=True.
['What is diabetes?', 'Diabetes is a chronic condition that affects how the body processes glucose (sugar). It occurs when the body cannot produce enough insulin or the insulin it produces is ineffective in regulating blood sugar. Insulin is a hormone produced by the pancreas that helps glucose enter the cells of the body for energy. Without sufficient insulin, glucose builds up in the bloodstream, leading to high blood sugar levels. Over time, uncontrolled diabetes can cause serious health complications such as heart disease, kidney damage, nerve damage, and vision problems. Proper management and treatment of diabetes are essential to preventing these complications and maintaining a good quality of life. Early detection, lifestyle changes, and medication are key factors in effec
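
The cell above passes the first row to the model as a raw Python list string. If `num_results` is raised above 1, a labeled text block per row may be easier for the model to read. The sketch below assumes only the `result` / `data_array` nesting used above. `format_context` is a hypothetical helper, and the sample response (including the trailing numeric values, which stand in for any extra per-row elements) is hand-written for illustration.

```python
def format_context(results: dict) -> str:
    """Join retrieved rows into a readable context block."""
    rows = results.get("result", {}).get("data_array") or []
    # Each row starts with the requested columns (Topic, Description);
    # any elements after those two are ignored here.
    blocks = [f"Topic: {row[0]}\nDescription: {row[1]}" for row in rows]
    return "\n\n".join(blocks)


# Hand-written stand-in for a similarity_search() response with two hits.
fake_results = {
    "result": {
        "data_array": [
            ["What is diabetes?", "Diabetes is a chronic condition.", 0.91],
            ["How is diabetes diagnosed?", "Diabetes is diagnosed through various blood tests.", 0.74],
        ]
    }
}
context_text = format_context(fake_results)
print(context_text)
```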

### Developing the Generation Component of our RAG architecture


In [0]:
gpt_response = client.chat.completions.create(
    model=deployment_name,
    messages=[
        {"role": "system", "content": "You are a helpful assistant. You will be passed the user query and the supporting knowledge that can be used to answer the user_query"},
        {"role": "user", "content": f"user query : {user_question} and supporting knowledge: {content}"},
    ],
)
print(gpt_response.choices[0].message.content)


Diabetes is a long-term condition that affects how the body regulates blood sugar (glucose). It occurs when the body doesn’t produce enough insulin or when the insulin it makes doesn’t work properly. Insulin is a hormone made by the pancreas that allows glucose from food to enter cells and be used for energy. When this process doesn’t work as it should, glucose builds up in the blood, leading to high blood sugar levels.  

If diabetes isn’t well managed, it can lead to serious complications such as heart disease, kidney damage, nerve problems, and vision loss. Effective management involves early detection, healthy lifestyle choices, regular monitoring, and in some cases, medication or insulin therapy.


Trace(trace_id=tr-06d8bf357e6b0ca415e312e51b806317)
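
The generation call combines the question and the retrieved text into a single user message. One way to keep that assembly testable is to build the message list in a small function, sketched below. `build_messages` is a hypothetical helper, and the system prompt wording is an illustrative variation, not the prompt used in the cell above.

```python
def build_messages(user_query: str, supporting_knowledge: str) -> list:
    """Return chat messages that keep the question and the context separate."""
    system_prompt = (
        "You are a helpful assistant. Answer the user query using only the "
        "supporting knowledge. If the knowledge does not cover the question, say so."
    )
    user_prompt = (
        f"User query:\n{user_query}\n\n"
        f"Supporting knowledge:\n{supporting_knowledge}"
    )
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]


messages = build_messages("what is diabetes?", "Diabetes is a chronic condition.")
for message in messages:
    print(message["role"], ":", message["content"])
```

The returned list can be passed as the `messages` argument of `client.chat.completions.create`.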

### Developing the RAG model

In [0]:
import mlflow
from mlflow import pyfunc

class RAGModel(pyfunc.PythonModel):
    def __init__(self, vector_index, openai_client, deployment_name):
        self.vector_index = vector_index
        self.openai_client = openai_client
        self.deployment_name = deployment_name

    def retrieve(self, query):
        results_dict = self.vector_index.similarity_search(
            query_text=query,
            columns=["Topic", "Description"],
            num_results=1,
        )
        return str(results_dict["result"]["data_array"][0])

    def chatCompletionsAPI(self, user_query, supporting_knowledge):
        response = self.openai_client.chat.completions.create(
            model=self.deployment_name,
            messages=[
                {"role": "system", "content": "You are a helpful assistant. You will be passed the user query and the supporting knowledge that can be used to answer the user_query"},
                {"role": "user", "content": f"user query : {user_query} and supporting knowledge: {supporting_knowledge}"},
            ],
        )
        return response.choices[0].message.content

    def predict(self, context, data):
        query = data["query"].iloc[0]
        supporting_knowledge = self.retrieve(query)
        return self.chatCompletionsAPI(query, supporting_knowledge)


  param_names = _check_func_signature(func, "predict")
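
The retrieve-then-generate flow of `RAGModel` can be checked without a live index or an Azure OpenAI deployment by substituting small stand-ins. The sketch below is an assumption-laden illustration: `FakeIndex`, `FakeChatClient`, and `answer` are hypothetical names, and the fakes imitate only the attributes the class above actually touches.

```python
from types import SimpleNamespace


class FakeIndex:
    """Stand-in for a Vector Search index; returns one canned row."""

    def similarity_search(self, query_text, columns, num_results):
        return {"result": {"data_array": [["What is diabetes?", "Diabetes is a chronic condition."]]}}


class FakeChatClient:
    """Stand-in for AzureOpenAI; echoes the user message it receives."""

    def __init__(self):
        self.chat = SimpleNamespace(completions=SimpleNamespace(create=self._create))

    def _create(self, model, messages):
        reply = SimpleNamespace(content=f"[{model}] {messages[-1]['content']}")
        return SimpleNamespace(choices=[SimpleNamespace(message=reply)])


def answer(index, client, deployment, query):
    """Retrieve, then generate, mirroring the retrieve/chat split above."""
    hits = index.similarity_search(query_text=query, columns=["Topic", "Description"], num_results=1)
    knowledge = str(hits["result"]["data_array"][0])
    response = client.chat.completions.create(
        model=deployment,
        messages=[{"role": "user", "content": f"user query : {query} and supporting knowledge: {knowledge}"}],
    )
    return response.choices[0].message.content


reply = answer(FakeIndex(), FakeChatClient(), "fake-deployment", "what is diabetes?")
print(reply)
```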


### Saving our Model

In [0]:
test_model = RAGModel(vector_index=index, openai_client=client, deployment_name=deployment_name)


In [0]:
os.environ["DATABRICKS_HOST"] = "https://eastus2-c3.azuredatabricks.net"
os.environ["DATABRICKS_TOKEN"] = dbutils.secrets.get("YOUR_SCOPE", "YOUR_PAT_SECRET")


---------------------------------------------------------------------------
IllegalArgumentException                  Traceback (most recent call last)
File <command-5844878010764383>, line 2
      1 os.environ["DATABRICKS_HOST"] = "https://eastus2-c3.azuredatabricks.net"
----> 2 os.environ["DATABRICKS_TOKEN"] = dbutils.secrets.get("YOUR_SCOPE", "YOUR_PAT_SECRET")

File /databricks/python_shell/lib/dbruntime/dbutils.py:376, in DBUtils.SecretsHandler.get(self, scope, catalog, schema, key, *posArgs)
    373     scope = argsToPass[0]


In [0]:
# ============================================================
# Models-from-Code RAG model (SERVING-SAFE)
# - Vector Search auth via env vars (DATABRICKS_HOST / DATABRICKS_TOKEN)
# - Azure OpenAI via env vars (AZURE_OPENAI_*)
# - dbutils only used OUTSIDE the model, for notebook convenience
# ============================================================

import os
import pandas as pd
import mlflow

# -----------------------------
# 0) (Notebook only) set env vars from secrets if not already set
#    Serving will NOT use dbutils, so this is only to make notebook tests work.
# -----------------------------
if not os.getenv("AZURE_OPENAI_ENDPOINT"):
    os.environ["AZURE_OPENAI_ENDPOINT"] = dbutils.secrets.get("aoai-scope", "openai-api-base")
if not os.getenv("AZURE_OPENAI_API_KEY"):
    os.environ["AZURE_OPENAI_API_KEY"] = dbutils.secrets.get("aoai-scope", "openai-api-key")
if not os.getenv("AZURE_OPENAI_API_VERSION"):
    os.environ["AZURE_OPENAI_API_VERSION"] = dbutils.secrets.get("aoai-scope", "openai-api-version")

# IMPORTANT: Serving needs these too (set in Serving endpoint config)
# For notebook testing, you can set them here once if you want:
# os.environ["DATABRICKS_HOST"] = "https://eastus2-c3.azuredatabricks.net"
# os.environ["DATABRICKS_TOKEN"] = dbutils.secrets.get("your-scope", "your-pat-key")

print("✅ AOAI env vars present:",
      bool(os.getenv("AZURE_OPENAI_ENDPOINT")),
      bool(os.getenv("AZURE_OPENAI_API_KEY")),
      bool(os.getenv("AZURE_OPENAI_API_VERSION")))

print("✅ DBX env vars present:",
      bool(os.getenv("DATABRICKS_HOST")),
      bool(os.getenv("DATABRICKS_TOKEN")))

# -----------------------------
# 1) Config
# -----------------------------
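# Reuse the names resolved earlier in the notebook; fall back to literals otherwise.
ENDPOINT_NAME = globals().get("endpoint_name", "vector_search_endpoint")
INDEX_NAME = globals().get("index_name", "adb_genai_super_locust.rag.diabetes_faq_index")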

try:
    DEPLOYMENT_NAME = deployment_name
except NameError:
    DEPLOYMENT_NAME = "YOUR_AZURE_OPENAI_DEPLOYMENT_NAME"

# -----------------------------
# 2) Write serving-safe model script (NO dbutils inside)
# -----------------------------
script_path = "/tmp/rag_model_from_code.py"

script = f'''
import os
import pandas as pd
from typing import Any
from mlflow.pyfunc import PythonModel
from mlflow.models import set_model
from databricks.vector_search.client import VectorSearchClient
from openai import AzureOpenAI

ENDPOINT_NAME = "{ENDPOINT_NAME}"
INDEX_NAME = "{INDEX_NAME}"
DEPLOYMENT_NAME = "{DEPLOYMENT_NAME}"

class RAGModel(PythonModel):
    def __init__(self, top_k: int = 1):
        self.top_k = top_k

    def load_context(self, context: Any) -> None:
        # -------------------------
        # Vector Search auth (Serving-safe)
        # -------------------------
        host = os.getenv("DATABRICKS_HOST")
        token = os.getenv("DATABRICKS_TOKEN")
        if not (host and token):
            raise RuntimeError(
                "Missing DATABRICKS_HOST / DATABRICKS_TOKEN. "
                "Set these in the Serving endpoint environment variables."
            )

        # VectorSearchClient constructor can vary slightly by workspace version,
        # so we try common parameter names.
        try:
            vsc = VectorSearchClient(workspace_url=host, personal_access_token=token, disable_notice=True)
        except TypeError:
            try:
                vsc = VectorSearchClient(host=host, token=token, disable_notice=True)
            except TypeError as exc:
                raise RuntimeError(f"VectorSearchClient auth init failed: {{exc}}") from exc

        self.index = vsc.get_index(ENDPOINT_NAME, INDEX_NAME)

        # -------------------------
        # Azure OpenAI auth (Serving-safe)
        # -------------------------
        aoai_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
        aoai_key = os.getenv("AZURE_OPENAI_API_KEY")
        aoai_version = os.getenv("AZURE_OPENAI_API_VERSION")
        if not (aoai_endpoint and aoai_key and aoai_version):
            raise RuntimeError(
                "Missing Azure OpenAI env vars. "
                "Set AZURE_OPENAI_ENDPOINT / AZURE_OPENAI_API_KEY / AZURE_OPENAI_API_VERSION "
                "in the Serving endpoint environment variables."
            )

        self.client = AzureOpenAI(
            api_key=aoai_key,
            api_version=aoai_version,
            azure_endpoint=aoai_endpoint,
        )

    def _retrieve(self, q: str) -> str:
        res = self.index.similarity_search(
            query_text=q,
            columns=["Topic", "Description"],
            num_results=self.top_k,
        )
        return str(res)

    def _chat(self, q: str, ctx: str) -> str:
        resp = self.client.chat.completions.create(
            model=DEPLOYMENT_NAME,
            messages=[
                {{"role": "system", "content": "Answer using the supporting knowledge."}},
                {{"role": "user", "content": f"user query: {{q}}\\nsupporting knowledge: {{ctx}}" }},
            ],
        )
        return resp.choices[0].message.content

    def predict(self, context: Any, model_input: pd.DataFrame) -> pd.DataFrame:
        queries = model_input["query"].astype(str).tolist()
        answers = []
        for q in queries:
            ctx = self._retrieve(q)
            answers.append(self._chat(q, ctx))
        return pd.DataFrame({{"answer": answers}})

set_model(RAGModel(top_k=1))
'''

with open(script_path, "w", encoding="utf-8") as f:
    f.write(script)

print(f"✅ Wrote model script: {script_path}")

# -----------------------------
# 3) Log model + keep model_uri
# -----------------------------
input_example = pd.DataFrame([{"query": "what is diabetes?"}])

with mlflow.start_run() as run:
    model_info = mlflow.pyfunc.log_model(
        python_model=script_path,
        name="rag_model",
        input_example=input_example,
        pip_requirements=[
            "mlflow",
            "pandas",
            "openai",
            "databricks-vectorsearch",
        ],
    )
    model_uri = model_info.model_uri
    mlflow.log_text(model_uri, "model_uri.txt")

print("✅ Logged model_uri:", model_uri)
print("✅ Run ID:", run.info.run_id)

# -----------------------------
# 4) Quick notebook test (optional)
# -----------------------------
loaded = mlflow.pyfunc.load_model(model_uri)
print(loaded.predict(pd.DataFrame([{"query": "what is diabetes?"}])))


✅ AOAI env vars present: True True True
✅ DBX env vars present: False False
✅ Wrote model script: /tmp/rag_model_from_code.py


🔗 View Logged Model at: https://adb-7405608176797015.15.azuredatabricks.net/ml/experiments/2016950050110299/models/m-0a6d8f405dbe43abacae10002d9aeba5?o=7405608176797015


---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
File <command-5475092539584904>, line 153
    150 input_example = pd.DataFrame([{"query": "what is diabetes?"}])
    152 with mlflow.start_run() as run:
--> 153     model_info = mlflow.pyfunc.log_model(
    154         python_model=script_path,
    155         name="rag_model",
    156         input_example=input_example,
    157         pip_requirements=[
    158
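
The model script reads five environment variables in `load_context` and raises if any are absent. A minimal sketch of a pre-flight check that names the missing ones before logging or serving is shown below. `missing_env_vars` and `REQUIRED_ENV_VARS` are hypothetical names, and the demonstration uses an explicit mapping with placeholder values instead of the real environment.

```python
import os
from typing import Iterable, List, Mapping

REQUIRED_ENV_VARS = (
    "DATABRICKS_HOST",
    "DATABRICKS_TOKEN",
    "AZURE_OPENAI_ENDPOINT",
    "AZURE_OPENAI_API_KEY",
    "AZURE_OPENAI_API_VERSION",
)


def missing_env_vars(names: Iterable[str], env: Mapping[str, str] = os.environ) -> List[str]:
    """Return the names that are unset or empty in env."""
    return [name for name in names if not env.get(name)]


# Demonstration with an explicit mapping; all values are placeholders.
demo_env = {
    "AZURE_OPENAI_ENDPOINT": "https://example.invalid",
    "AZURE_OPENAI_API_KEY": "placeholder",
    "AZURE_OPENAI_API_VERSION": "placeholder",
}
missing = missing_env_vars(REQUIRED_ENV_VARS, demo_env)
print("Missing:", missing)
```

Calling `missing_env_vars(REQUIRED_ENV_VARS)` with no second argument checks the current process environment.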

### Loading Our Saved Model

In [0]:
# ============================================================
# LOAD THE LATEST LOGGED MODEL (from the previous cell)
# ============================================================

import mlflow
import pandas as pd

# This variable is created in the previous cell:
#   model_uri = model_info.model_uri
if "model_uri" not in globals():
    raise RuntimeError(
        "model_uri is not defined. Run the previous 'log model' cell first."
    )

loaded_pyfunc_model = mlflow.pyfunc.load_model(model_uri)

# Quick sanity test
test_df = pd.DataFrame([{"query": "what is diabetes?"}])
print("✅ Loaded model from:", model_uri)
print(loaded_pyfunc_model.predict(test_df))




### Testing our Loaded/Saved Model

In [0]:
model_input = pd.DataFrame([{"query": "what is diabetes?"}])

model_response = loaded_pyfunc_model.predict(model_input)

print(model_response)



### Logging our saved model as an artifact

In [0]:
import mlflow

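# Log the generated model script as a plain artifact. The model itself is
# already stored by mlflow.pyfunc.log_model() in the earlier cell.
with mlflow.start_run() as run:
    mlflow.log_artifact(local_path=script_path, artifact_path="rag_model")
    print(f"Model script logged with run ID: {run.info.run_id}")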




In [0]:
import mlflow

mlflow.set_registry_uri("databricks")  # workspace registry (not UC)
registered = mlflow.register_model(model_uri=model_uri, name="rag_model")

print("✅ Registered:", registered.name, "v", registered.version)



### Querying the real-time endpoint

In [0]:
{
  "dataframe_records":[
    {
        "query":"what is diabetes?"
    }
  ]
}
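
The request body above uses the `dataframe_records` format, with one record per input row. A small sketch of building that body from a list of questions follows. `build_payload` is a hypothetical helper, and the second question is added only to show the multi-row case.

```python
import json


def build_payload(queries):
    """Wrap each query string in the dataframe_records request format."""
    return {"dataframe_records": [{"query": q} for q in queries]}


payload = build_payload(["what is diabetes?", "how is diabetes diagnosed?"])
print(json.dumps(payload, indent=2))
```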

