In [1]:
# load config
from pathlib import Path
import mlflow
from maud.agent.config import parse_config
import os

root_dir = Path(os.getcwd()).parent
config_path = root_dir / 'implementations' / 'agents' / 'openai' / 'config.yaml'
agent_path = root_dir / 'implementations' / 'agents' / 'openai' / 'agent.py'

mlflow_config = mlflow.models.ModelConfig(development_config=config_path)
maud_config = parse_config(mlflow_config)

In [4]:
from openai import OpenAI
from databricks.sdk import WorkspaceClient
w = WorkspaceClient()
llm_client: OpenAI = w.serving_endpoints.get_open_ai_client()

In [7]:
import mlflow

from typing import List
@mlflow.trace(span_type="RETRIEVER", name="vector_search_retriever")
def retrieve_docs(query: str) -> List[dict]:
    """
    Performs vector search to retrieve relevant chunks.

    Args:
        query: Search query.
        filters: Optional filters to apply to the search. Should follow the LLM-generated filter pattern of a list of field/filter pairs that will be converted to Databricks Vector Search filter format.

    Returns:
        List of retrieved Documents.
    """
    traced_search = mlflow.trace(
        w.vector_search_indexes.query_index,
        name="_workspace_client.vector_search_indexes.query_index",
        span_type="FUNCTION",
    )

    results = traced_search(
        index_name=maud_config.retriever.index_name,
        query_text=query,
        columns=maud_config.retriever.mapping.all_columns,
        **maud_config.retriever.parameters.model_dump(),
    )
    return results.as_dict()

In [8]:
retrieve_docs(query="What is SQL?")



TypeError: VectorSearchIndexesAPI.query_index() got an unexpected keyword argument 'k'

In [3]:
retriever_tool_spec = [
    {
        "type": "function",
        "function": {
            "name": maud_config.retriever.tool_name,
            "description": maud_config.retriever.tool_description,
            "parameters": {
                "type": "object",
                "required": ["query"],
                "additionalProperties": False,
                "properties": {
                    "query": {
                        "description": "query to look up in retriever",
                        "type": "string",
                    }
                },
            },
        },
    }
]

tool_functions = {maud_config.retriever.tool_name: self.retrieve_docs}

In [2]:
from maud.agent.retrievers import get_vector_retriever
retriever = get_vector_retriever(maud_config)

from databricks_langchain import ChatDatabricks
model = ChatDatabricks(endpoint=maud_config.model.endpoint_name)

[NOTICE] Using a notebook authentication token. Recommended for development only. For improved performance, please use Service Principal based authentication. To disable this message, pass disable_notice=True to VectorSearchClient().


Let's setup some nodes to play with

In [3]:
from langgraph.graph import StateGraph, START, END
from langchain_core.messages import HumanMessage, AIMessage
from maud.agent.states import get_state
from maud.agent.nodes import (
    make_query_vector_database_node, 
    make_context_generation_node,
    make_rephrase_generation_node,
    make_simple_generation_node
)

state = get_state(maud_config)
retriever_node = make_query_vector_database_node(retriever, maud_config)
simple_generation_node = make_simple_generation_node(model, maud_config)
context_generation_node = make_context_generation_node(model, maud_config)
rephrase_generation_node = make_rephrase_generation_node(model, maud_config)

This section builds a simple generation graph. It expects an input state with a dictionary of messages.

In [4]:
from data.messages.input_examples import input_example
input_state = {'messages':[{'type':'user', 'content':'What is SQL?'}]}

We use langchain's convert_to_messages to convert the input state to a list of messages. This is convenient because OpenAI uses 'role' and LangChain uses 'type'. We centralize on LangChain's message type for now.

This expects a list of dictionaries with 'type' and 'content' keys. Will fail is the entire {'messages':[{'type':'user', 'content':'What is SQL?'}]} is passed in.

We can use the convert_to_openai_messages function to convert the list of LangChain messages back to a list of dictionaries with 'role' and 'content' keys.

```python
from langchain_core.messages.utils import convert_to_messages
lc_msgs = convert_to_messages(input_example['messages'])
```

In [5]:
workflow = StateGraph(state)
workflow.add_node("generate", simple_generation_node)
workflow.add_edge(START, "generate")
workflow.add_edge("generate", END)
app = workflow.compile()
app.invoke(input_example)

{'messages': [{'role': 'user', 'content': 'What is Apache Spark'},
  {'role': 'assistant',
   'content': 'Apache Spark is a unified analytics engine for large-scale data processing, providing high-level APIs in Java, Scala, Python, and R, and an optimized engine that supports general execution graphs. It has a rich set of higher-level tools, including Spark SQL for SQL and structured data processing, pandas API on Spark for pandas workloads, MLlib for machine learning, GraphX for graph processing, and Structured Streaming for incremental computation and stream processing. Apache Spark is capable of handling large-scale data processing, machine learning, and data analytics, making it a unified analytics engine.'},
  {'role': 'user', 'content': 'Does it support streaming?'},
  {'role': 'assistant',
   'content': 'So, yeah. Yes, it does support streaming, allowing you to watch your favorite shows and movies online.'}]}

Check our rephrasing generation node

In [6]:
workflow = StateGraph(state)
workflow.add_node("generate", rephrase_generation_node)
workflow.add_edge(START, "generate")
workflow.add_edge("generate", END)
app = workflow.compile()
app.invoke(input_example)

{'messages': [{'role': 'user', 'content': 'What is Apache Spark'},
  {'role': 'assistant',
   'content': 'Apache Spark is a unified analytics engine for large-scale data processing, providing high-level APIs in Java, Scala, Python, and R, and an optimized engine that supports general execution graphs. It has a rich set of higher-level tools, including Spark SQL for SQL and structured data processing, pandas API on Spark for pandas workloads, MLlib for machine learning, GraphX for graph processing, and Structured Streaming for incremental computation and stream processing. Apache Spark is capable of handling large-scale data processing, machine learning, and data analytics, making it a unified analytics engine.'},
  {'role': 'user', 'content': 'Does it support streaming?'},
  {'role': 'user',
   'content': 'What streaming capabilities does Apache Spark support for processing real-time data and incremental computation, as part of its unified analytics engine for large-scale data processi

Check context generation node with no context

In [7]:
workflow = StateGraph(state)
workflow.add_node("generate", context_generation_node)
workflow.add_edge(START, "generate")
workflow.add_edge("generate", END)
app = workflow.compile()
app.invoke(input_example)

{'messages': [{'role': 'user', 'content': 'What is Apache Spark'},
  {'role': 'assistant',
   'content': 'Apache Spark is a unified analytics engine for large-scale data processing, providing high-level APIs in Java, Scala, Python, and R, and an optimized engine that supports general execution graphs. It has a rich set of higher-level tools, including Spark SQL for SQL and structured data processing, pandas API on Spark for pandas workloads, MLlib for machine learning, GraphX for graph processing, and Structured Streaming for incremental computation and stream processing. Apache Spark is capable of handling large-scale data processing, machine learning, and data analytics, making it a unified analytics engine.'},
  {'role': 'user', 'content': 'Does it support streaming?'},
  {'role': 'assistant',
   'content': "I don't know. The provided context is empty, so I have no information to determine if something supports streaming."}]}

Check retriever only


In [8]:
workflow = StateGraph(state)
workflow.add_node("retrieve", retriever_node)
workflow.add_edge(START, "retrieve")
app = workflow.compile()
app.invoke(input_example)

{'messages': [{'role': 'user', 'content': 'What is Apache Spark'},
  {'role': 'assistant',
   'content': 'Apache Spark is a unified analytics engine for large-scale data processing, providing high-level APIs in Java, Scala, Python, and R, and an optimized engine that supports general execution graphs. It has a rich set of higher-level tools, including Spark SQL for SQL and structured data processing, pandas API on Spark for pandas workloads, MLlib for machine learning, GraphX for graph processing, and Structured Streaming for incremental computation and stream processing. Apache Spark is capable of handling large-scale data processing, machine learning, and data analytics, making it a unified analytics engine.'},
  {'role': 'user', 'content': 'Does it support streaming?'}],
 'context': 'Passage: Ch-Se-Su = . , Effectivity = . , Page = 503. , Date = Jan 18/2005. , Effectivity = . , Page = 404. , Date = Sep 16/2009. , Ch-Se-Su = . , Effectivity = . , Page = 504. , Date = Jan 18/2005. , E

This section builds a rag graph without history

In [9]:
workflow = StateGraph(state)
workflow.add_node("retrieve", retriever_node)
workflow.add_node("generate_w_context", context_generation_node)
workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", "generate_w_context")
workflow.add_edge("generate_w_context", END)
app = workflow.compile()
result = app.invoke(input_example)

In [10]:
from maud.agent.utils import graph_state_to_chat_type
from langchain_core.runnables import RunnableLambda

chain = app | RunnableLambda(graph_state_to_chat_type)
chain.invoke(input_example)

{'choices': [{'message': {'role': 'assistant',
    'content': 'I do not know. The context appears to be a collection of unrelated data points with various codes, dates, and page numbers, but it does not provide any information about streaming capabilities.',
    'refusal': None,
    'name': None,
    'tool_calls': None,
    'tool_call_id': None},
   'index': 0,
   'finish_reason': 'stop',
   'logprobs': None}],
 'usage': None,
 'id': None,
 'model': None,
 'object': 'chat.completion',
 'created': 1738970510,
 'custom_outputs': {'message_history': [{'role': 'user',
    'content': 'What is Apache Spark'},
   {'role': 'assistant',
    'content': 'Apache Spark is a unified analytics engine for large-scale data processing, providing high-level APIs in Java, Scala, Python, and R, and an optimized engine that supports general execution graphs. It has a rich set of higher-level tools, including Spark SQL for SQL and structured data processing, pandas API on Spark for pandas workloads, MLlib fo