In [0]:
import io
import json
import logging
from pathlib import Path

import gradio as gr
import requests
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import ChatMessage, ChatMessageRole
from PIL import Image

from maud.interface.config import load_config

# Session-wide logging: emit records at INFO level and above.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Doc-retriever app config. This is a relative path, so it presumably resolves
# against the kernel's working directory (assuming load_config opens it as-is).
APP_CONFIG_PATH = "../implementations/interfaces/doc_retriever/app_config.yaml"

# Databricks Workspace Client built with no explicit arguments; authentication
# is whatever the SDK resolves by default for this session.
workspace_client = WorkspaceClient()

# Parsed app configuration; later cells read `app_config.serving_endpoint`.
app_config = load_config(APP_CONFIG_PATH)

In [0]:
# A single user turn, used to smoke-test the SDK query path below.
messages = [
    ChatMessage(
        role=ChatMessageRole.USER,
        content="What is machine learning?",
    )
]

In [0]:
# Query the serving endpoint through the Databricks SDK client.
# The result gets its own name so that the raw-REST `response` defined in a
# later cell does not silently replace it (the two are different object types).
sdk_response = workspace_client.serving_endpoints.query(
    name=app_config.serving_endpoint,
    messages=messages,
)
sdk_response

In [0]:
# Query the same serving endpoint over raw REST. The parsed JSON body is kept
# in `result` for the cells below.

# Seconds to wait for the endpoint; requests has no default timeout and would
# otherwise block forever on a stuck connection.
REQUEST_TIMEOUT_S = 120

# Workspace base URL; strip any trailing slash so the path joins cleanly.
workspace_url = workspace_client.config.host.rstrip("/")

# Ask the SDK for auth headers rather than reading `config.token` directly:
# that attribute may be None when the client authenticates by a non-PAT method
# (e.g. OAuth), which would send "Bearer None".
auth_headers = workspace_client.config.authenticate()

# Construct the serving endpoint URL
endpoint_url = f"{workspace_url}/serving-endpoints/{app_config.serving_endpoint}/invocations"

# Request headers: SDK-provided auth plus the JSON content type.
headers = {
    **auth_headers,
    "Content-Type": "application/json",
}

# Chat-style payload: a system instruction followed by the user question.
payload = {
    "messages": [
        {
            "role": "system",
            "content": "Respond with So... and keep it short."
        },
        {
            "role": "user",
            "content": "What is the factor of safety for strapping?"
        }
    ]
}

response = requests.post(
    endpoint_url,
    headers=headers,
    json=payload,
    timeout=REQUEST_TIMEOUT_S,
)

# Fail loudly on anything but 200 instead of leaving `result` undefined (or
# stale from an earlier run) for the cells that read it.
if response.status_code != 200:
    raise RuntimeError(f"Error: {response.status_code}\n{response.text}")

result = response.json()
result


In [0]:
# Show just the reply text from the first choice of the REST response.
first_choice = result["choices"][0]
first_choice["message"]["content"]

In [0]:
# Format documents for gradio chat display
def format_documents(documents, separator="\n\n---\n\n"):
    """Render retrieved documents as a single text block for chat display.

    Args:
        documents: Iterable of dicts, each optionally carrying ``metadata``
            (a dict read for ``filename``, ``type`` and ``id``) and
            ``page_content``.
        separator: String placed between consecutive formatted documents.

    Returns:
        The formatted documents joined by ``separator``; ``""`` when
        ``documents`` is empty.
    """
    formatted_docs = []
    for doc in documents:
        # `or` also covers keys that are present but null, which
        # `.get(key, default)` would pass through unchanged.
        metadata = doc.get("metadata") or {}
        content = doc.get("page_content") or ""
        formatted_docs.append(
            "\n"
            f"Source: {metadata.get('filename', 'Unknown')}\n"
            f"Type: {metadata.get('type', 'Unknown')}\n"
            f"ID: {metadata.get('id', 'Unknown')}\n"
            "\n"
            "Content:\n"
            f"{content}\n"
        )
    return separator.join(formatted_docs)


# Retrieved documents live under `custom_outputs.documents` in the REST result;
# guard against either level being absent or null.
custom_outputs = result.get("custom_outputs") or {}
documents = custom_outputs.get("documents") or []
documents_text = format_documents(documents)

# print() renders the newlines; a bare string expression would show an escaped repr.
print(documents_text)