In [None]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma
from langchain_openai import OpenAIEmbeddings
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
from langchain.schema import Document  # Optional: For using Document objects
from langchain_core.runnables import chain
from langchain.prompts import ChatPromptTemplate
from langchain_nvidia_ai_endpoints import ChatNVIDIA
from pydantic import BaseModel, Field
from langchain.output_parsers import PydanticOutputParser

import sqlite3
import pandas as pd
from typing import List, Dict

from typing_extensions import TypedDict

In [3]:
def create_retriever(docs_list, embeddings, chunk_size=250, chunk_overlap=0,
                     collection_name="sql-rag-test3", k=3):
    """Index a small corpus in Chroma and return a retriever over it.

    Args:
        docs_list: Iterable of documents. Each item is either a dict with a
            'page_content' entry and an optional 'metadata' dict, or an
            already-built ``Document`` (passed through unchanged).
        embeddings: Embedding model forwarded to ``Chroma.from_documents``.
        chunk_size: Chunk size given to the tiktoken-based text splitter.
        chunk_overlap: Overlap between consecutive chunks.
        collection_name: Name of the Chroma collection to write into.
        k: Number of documents the retriever asks the vector store for.

    Returns:
        The object returned by ``vectorstore.as_retriever``.

    Raises:
        KeyError: If a dict item has no 'page_content' entry.
    """
    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )

    # Normalise every input to a Document; a missing/None metadata becomes {}.
    documents = [
        doc if isinstance(doc, Document)
        else Document(page_content=doc['page_content'], metadata=doc.get('metadata') or {})
        for doc in docs_list
    ]

    # Split the documents into smaller chunks
    split_docs = text_splitter.split_documents(documents)

    # Add to vectorDB
    # NOTE(review): re-running this with the same collection_name inside one
    # kernel presumably appends to the existing collection (duplicate hits) --
    # confirm, and pass a fresh collection_name if so.
    vectorstore = Chroma.from_documents(
        documents=split_docs,
        collection_name=collection_name,
        embedding=embeddings,
    )
    return vectorstore.as_retriever(search_kwargs={'k': k})

def call_db(local_db_path=None, query=None, info=None):
    """Run one parameterised statement against a SQLite file and return its rows.

    Args:
        local_db_path: Path of the SQLite database file.
        query: SQL text containing exactly one ``?`` placeholder.
        info: Value bound to that placeholder.

    Returns:
        A list with one dict per result row, keyed by column name. Statements
        that produce no result set (e.g. DELETE/UPDATE) return an empty list.

    Raises:
        ValueError: If any of the three arguments is None.
        sqlite3.Error: Propagated from SQLite (bad SQL, wrong placeholder count, ...).
    """
    # Local import keeps this function self-contained.
    from contextlib import closing

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    for arg_name, arg_value in (("local_db_path", local_db_path),
                                ("query", query),
                                ("info", info)):
        if arg_value is None:
            raise ValueError(f"{arg_name} must not be None")

    # The sqlite3 connection context manager only commits/rolls back; it does
    # not close the connection, so `closing` is needed to release the handle.
    with closing(sqlite3.connect(local_db_path)) as local_conn:
        with local_conn:
            cursor = local_conn.cursor()
            cursor.execute(query, (info,))
            rows = cursor.fetchall()
            # description is None when the statement returns no result set.
            column_names = [column[0] for column in (cursor.description or ())]
    return [dict(zip(column_names, row)) for row in rows]

In [4]:
class FormatAns(BaseModel):
    """Structured output for generated SQL query."""
    # Schema handed to `with_structured_output` for the SQL-generation calls.
    # execute_sql_query binds `user_input` to the `?` placeholder of `sql_query`
    # through call_db.
    # NOTE(review): the docstring and Field descriptions are presumably sent to
    # the model as part of the schema -- treat them as prompt text when editing.
    user_input: str = Field(description='The user specified input for the SQL query.')
    sql_query: str = Field(description="Syntactically valid SQL query.")

class Plan(BaseModel):
    """Plan to follow in future."""
    # Schema for the planning step: plan_llm returns this, and gen_sql_query2
    # feeds the `plan` text into execute_prompt.
    # NOTE(review): the docstring and Field description are presumably part of
    # the schema shown to the model -- treat them as prompt text when editing.
    plan: str = Field(description="Steps to follow, should be in sorted order")

# Data model for the routing decision: generate SQL ('yes') or answer from the retrieved documents ('no')
class Grade(BaseModel):
    """Binary score for relevance check to generate sql or do RAG."""
    # Produced by check() and read by decide(), which compares `binary_score`
    # against the exact strings 'yes' and 'no' to pick the graph branch.
    # NOTE(review): the docstring and Field description are presumably part of
    # the schema shown to the model -- treat them as prompt text when editing.
    binary_score: str = Field(description="Documents are relevant to the question, 'yes' or 'no'")


In [5]:
# Original single-shot SQL generation prompt: asks for a parameterised query
# (user-supplied values replaced by `?`) directly from the retrieved context.
# Inputs: {question}, {context}. Used by gen_sql_query.
template = """
You are an assistant for generating sql queries. Use the following pieces of retrieved context to answer the question. 
If you don't know the answer, just say that you don't know. Only answer with a sql query.
**When the sql is generated, replace any user defined values with a question mark.**
===
Example:
Question: What airport am I flying out of? my passenger id is \"3442 587242.\"
Context:[Document(page_content='How to answer questions with What. Database has 11 rows: seat_no, fare_conditions, scheduled_arrival, scheduled_departure, ticket_no, book_ref, passenger_id, flight_id, flight_no, departure_airport, arrival_airport.\n        The database is named query_results.')]
user_input: 3442 587242
sql_query: SELECT departure_airport FROM query_results WHERE passenger_id = ?
===
Question: {question} 
Context: {context} 
"""

chat_prompt = ChatPromptTemplate.from_template(template)

# Final-answer prompt, including the conversation history.
# Inputs: {history} (pre-joined "role: content" lines), {question}, {context}.
# Used by both answer() (context = SQL rows) and rag_answer() (context = documents).
process_template = """
You are an assistant that answers a user's question based on json. 
Here is the conversation history so far:
{history}

Answer the question using the following retrieved context:
Question: {question} 
Context: {context} 
If you don't know the answer, just say that you don't know. Answer in a personal, conversational tone.
"""

process_prompt = ChatPromptTemplate.from_template(process_template)

# Planning prompt: asks the model to describe how the SQL query should be built
# before any SQL is written. Inputs: {question}, {context}. Used by plan().
# NOTE(review): the text asks for "two precise plans" while the Plan schema has
# a single `plan` field, and the example plan describes SELECT * although the
# example question asks for one column -- confirm both are intended.
plan_template = """
You are a helpful assistant that generates two precise plans on how to implement a sql query.
Make sure to pay attention to the user input provided in the question.
===
Question: What airport am I flying out of? my passenger id is \"3442 587242.\"
Context:[Document(page_content='How to answer questions with What. Database has 11 rows: seat_no, fare_conditions, scheduled_arrival, scheduled_departure, ticket_no, book_ref, passenger_id, flight_id, flight_no, departure_airport, arrival_airport.\n        The database is named query_results.')]
Plan: This SQL command retrieves all columns (SELECT *) from a table named query_results where the passenger_id column matches the value provided by the ? placeholder. The ? is a parameter marker, indicating a value will be passed in later to complete the query.
===
Question: {question} 
Context: {context} 
"""

plan_prompt = ChatPromptTemplate.from_template(plan_template)

# Prompt that turns a previously generated plan into the actual SQL query.
# Inputs: {question}, {plan}. Used by gen_sql_query2.
# Note: the template has no {context} variable, so retrieved documents only
# reach this step through the plan text.
exec_template = """
You are an assistant for generating sql queries. Use the following pieces of a plan to answer the question. 
If you don't know the answer, just say that you don't know. Only answer with a sql query. When the sql is generated, replace the user defined values with a question mark.
*Do not have user input be in brackets or quotations at the final answer.
*Make sure to pay attention to the user input value, [ticket_no] is not correct.
Question: {question} 
Plan: {plan} 
"""
execute_prompt = ChatPromptTemplate.from_template(exec_template)

# Alternative final-answer prompt that also shows the plan to the model (while
# telling it not to answer from the plan) and omits the conversation history.
# Inputs: {question}, {plan}, {context}.
# Not referenced by any function in the visible part of this notebook.
process_template2 = """
You are an assistant that answers a user's question based on json. Use the following pieces of retrieved context to answer the question. 
If you don't know the answer, just say that you don't know. Do not answer based on the plan. Answer in a personal, conversational tone.
Question: {question} 
Plan: {plan}
Context: {context} 
"""

process_prompt2 = ChatPromptTemplate.from_template(process_template2)

# Routing prompt: asks for a binary 'yes'/'no' on whether the question needs
# SQL query generation given the retrieved context.
# Inputs: {question}, {context}. Used by check(); the result drives decide().
grade_template = """
Give a binary score 'yes' or 'no' to indicate whether the context and the question needs sql query generation.
Make sure to pay attention to the user input provided in the question.
===
Examples:
Question: How can I reschedule my flight?
Context: How to reschedule a flight: need to email returns@hpe.com, submit request, and someone will get back to you.
binary_score=no

Question: Can you tell me about my flight? my passenger id is \"3442 587242.\"
Context: [Document( page_content='How to answer questions with What. Database has 11 rows: seat_no, fare_conditions, scheduled_arrival, scheduled_departure, ticket_no, book_ref, passenger_id, flight_id, flight_no, departure_airport, arrival_airport.\n        The database is named query_results.'), Document( page_content='How to reschedule a flight: need to email returns@hpe.com, submit request, and someone will get back to you.'), Document( page_content='How to submit expense report: go to concur through home.hpe.com, and follow the necessary forms.')] 
binary_score=yes
===
Question: {question} 
Context: {context} 
"""

grade_prompt = ChatPromptTemplate.from_template(grade_template)

# Minimal retrieval-augmented answer prompt. Inputs: {question}, {context}.
# Not referenced by any function in the visible part of this notebook
# (rag_answer uses process_prompt instead).
rag_template = """
You are a helpful assistant that answers question based on retrieved context.
Question: {question} 
Context: {context} 
"""

rag_prompt = ChatPromptTemplate.from_template(rag_template)

# Initialize Vector DB and Models

In [6]:
# Note: uncomment the lines below if you want to be prompted for an NVIDIA API key (stored in the NVIDIA_API_KEY environment variable)

# import getpass
# import os

# if not os.environ.get("NVIDIA_API_KEY", "").startswith("nvapi-"):
#     nvapi_key = getpass.getpass("Enter your NVIDIA API key: ")
#     assert nvapi_key.startswith("nvapi-"), f"{nvapi_key[:5]}... is not a valid key"
#     os.environ["NVIDIA_API_KEY"] = nvapi_key
# os.environ["NVIDIA_API_KEY"]

In [None]:
# Create RAG DB
# Tiny in-memory corpus: one entry describing the SQL table and two entries
# describing procedures that do not need SQL.
docs_list = [
    {'metadata': {}, 'page_content': """
    How to answer questions with What. Database has 11 rows: seat_no, fare_conditions, scheduled_arrival, scheduled_departure, ticket_no, book_ref, passenger_id, flight_id, flight_no, departure_airport, arrival_airport.
    The database is named query_results. 
    """},
    {'metadata': {}, 'page_content': 'How to reschedule a flight: need to email returns@hpe.com, submit request, and someone will get back to you.'},
    {'metadata': {}, 'page_content': 'How to submit expense report: go to concur through home.hpe.com, and follow the necessary forms.'}
]

# NOTE: uncomment if you want to use NVIDIA Embedding model
# (this needs `import os` and NVIDIA_API_KEY, see the commented API-key cell)
# embeddings = NVIDIAEmbeddings(base_url='https://integrate.api.nvidia.com/v1',
#                                model="nvidia/nv-embed-v1", 
#                                api_key=os.environ["NVIDIA_API_KEY"],
#                                truncate="NONE")

embeddings = NVIDIAEmbeddings(base_url='http://embedding-tyler.models.mlds-kserve.us.rdlabs.hpecorp.net',
                               model="thenlper/gte-base", 
                               api_key='',
                               truncate="NONE")

retriever = create_retriever(docs_list, embeddings)
local_db_path = 'tickets_joined.db'

# Create models
# One place to change the chat endpoint; previously the same three settings
# were repeated for every client.
LLM_BASE_URL = "http://10.182.1.167:8080/v1"
LLM_MODEL = "meta/llama-3.1-70b-instruct"
LLM_API_KEY = "\'\'"  # placeholder value, unchanged from the original cell

llm = ChatNVIDIA(base_url=LLM_BASE_URL,
                 model=LLM_MODEL,
                 api_key=LLM_API_KEY,
                 verbose=True)
# Structured-output views of the same client (same pattern gen_sql_query2 and
# check already use via llm.with_structured_output).
formatted_llm = llm.with_structured_output(FormatAns)
plan_llm = llm.with_structured_output(Plan)

# LangGraph implementation

In [8]:
# Utilities and code needed to run agent with LangGraph

# Extend the GraphState to include a conversation history (list of messages)
class GraphState(TypedDict):
    """
    State shared by all nodes of the graph.

    Attributes:
        question: The current user question; read by every node.
        history: Conversation so far, a list of {'role', 'content'} dicts.
            retrieve adds the user question, answer/rag_answer add the
            assistant reply.
        generation: What answer/rag_answer store as the final response. They
            store the object returned by ``llm.invoke`` (the run cells read
            its ``.content``), not a plain string as the annotation says.
        structured_sql_query: FormatAns produced by the SQL-generation node
            (SQL text plus the user-supplied value).
        sql_results: Rows returned by call_db as a list of dicts; None when
            execute_sql_query caught an error.
        plan: Plan produced by the planning node; retrieve and check reset it
            to None.
        documents: Whatever the retriever returned for the question.
            NOTE(review): annotated List[str]; presumably these are Document
            objects -- confirm.
        sql_gen_check: Grade produced by check and consumed by decide.
    """
    question: str
    history: List[Dict[str, str]]
    generation: str
    structured_sql_query: FormatAns
    sql_results: List[Dict]
    plan: Plan
    documents: List[str]
    sql_gen_check: Grade


def retrieve(state):
    """Graph entry node: record the user question and fetch matching documents.

    Args:
        state: Graph state; must contain 'question', may contain 'history'.

    Returns:
        State update with 'documents' (retriever output), 'question',
        'plan' reset to None, and 'history' extended by the user turn.
    """
    print("---RETRIEVE---")
    question = state["question"]

    # Work on a copy: appending in place would also modify the list object the
    # caller passed in as input.
    history = list(state.get("history") or [])
    history.append({"role": "user", "content": question})

    # `invoke` is the Runnable entry point of retrievers; it replaces the
    # deprecated get_relevant_documents().
    documents = retriever.invoke(question)
    return {"documents": documents, "question": question, 'plan': None, "history": history}


def gen_sql_query(state):
    """Generate the structured SQL query directly from the retrieved documents.

    Fills chat_prompt with the documents and the question, sends it to
    formatted_llm, and returns the result under 'structured_sql_query'
    together with the unchanged 'documents', 'question' and 'plan'.
    """
    print("---GENERATE_SQL_QUERY---")
    prompt_inputs = {"context": state['documents'], "question": state['question']}
    prompt_value = chat_prompt.invoke(prompt_inputs)
    structured = formatted_llm.invoke(prompt_value)

    update = {}
    update["documents"] = state['documents']
    update["question"] = state['question']
    update['plan'] = state['plan']
    update["structured_sql_query"] = structured
    return update

def execute_sql_query(state):
    """Run the generated SQL against the local database.

    Calls call_db with the SQL text and the user value from
    state['structured_sql_query']. Any exception raised while doing so is
    printed and swallowed; in that case 'sql_results' is None.

    Returns:
        State update carrying 'documents', 'question', 'structured_sql_query',
        'plan' and 'sql_results'.
    """
    print("---EXEC_SQL_QUERY---")
    print(state)
    structured = state['structured_sql_query']

    rows = None
    try:
        rows = call_db(local_db_path=local_db_path,
                       query=structured.sql_query,
                       info=structured.user_input)
        print("r: ", rows)
    except Exception as err:
        # Best effort: report the failure and continue with no rows.
        print(err)

    # Same payload whether the query succeeded or failed.
    return {"documents": state['documents'],
            "question": state['question'],
            "structured_sql_query": state['structured_sql_query'],
            'plan': state['plan'],
            "sql_results": rows}

def answer(state):
    """Final node of the SQL branch: answer the question from the SQL rows.

    Renders process_prompt with the conversation history, the question and
    state['sql_results'], sends it to llm, and records the reply.

    Args:
        state: Graph state with 'question', 'documents', 'structured_sql_query',
            'sql_results' and 'plan'; 'history' is optional.

    Returns:
        State update with the inputs passed through, 'generation' set to the
        object returned by ``llm.invoke`` and 'history' extended by the
        assistant turn.
    """
    print("---ANSWER---")
    print("state: ", state)

    # Copy the history so the list held by the incoming state is not mutated;
    # the extended copy is returned as the state update instead.
    history = list(state.get("history") or [])

    # Combine the conversation history into a string
    conversation_history = "\n".join(f"{msg['role']}: {msg['content']}" for msg in history)

    final_format = process_prompt.invoke({
        "history": conversation_history,
        "context": state["sql_results"],
        "question": state["question"]
    })
    print(final_format.to_string())
    answer_result = llm.invoke(final_format)

    # Record the assistant's answer in the (copied) conversation history
    history.append({"role": "assistant", "content": answer_result.content})

    return {"documents": state["documents"],
            "question": state["question"],
            "structured_sql_query": state['structured_sql_query'],
            "sql_results": state['sql_results'],
            'plan': state['plan'],
            "generation": answer_result,
            "history": history}

def plan(state):
    """Planning node: ask plan_llm how the SQL query should be built.

    Fills plan_prompt with the retrieved documents and the question and
    returns the model's result under 'plan', alongside the unchanged
    'question' and 'documents'.
    """
    prompt_value = plan_prompt.invoke(
        {"context": state['documents'], "question": state['question']}
    )
    generated_plan = plan_llm.invoke(prompt_value)
    return dict(
        question=state['question'],
        documents=state['documents'],
        plan=generated_plan,
    )

# update function
def gen_sql_query2(state):
    """Generate the structured SQL query from the plan produced by plan().

    Args:
        state: Graph state with 'plan' (object exposing ``.plan``), 'question'
            and 'documents'.

    Returns:
        State update with 'documents', 'question' and 'plan' passed through and
        'structured_sql_query' set to whatever formatted_llm returned.
    """
    print("---GENERATE_SQL_QUERY---")
    plan_obj = state['plan']
    question = state['question']

    # execute_prompt only declares {plan} and {question}; the documents are not
    # part of this prompt, so they are no longer passed in.
    exec_prompt = execute_prompt.invoke({"plan": plan_obj.plan, "question": question})
    print(exec_prompt.to_string())

    # Reuse the module-level FormatAns-structured model instead of rebuilding
    # the structured-output wrapper on every call.
    query_result = formatted_llm.invoke(exec_prompt)

    # NOTE(review): the structured call presumably can yield None when the
    # model output cannot be parsed -- confirm. getattr keeps these debug
    # prints from raising so execute_sql_query's error handling is reached.
    print("query_result.user_input: ", getattr(query_result, "user_input", None))
    print("query_result.sql_query: ", getattr(query_result, "sql_query", None))
    return {"documents": state['documents'],
            "question": state['question'],
            'plan': state['plan'],
            "structured_sql_query": query_result}

def check(state):
    """Routing node: grade whether the question needs SQL generation.

    Fills grade_prompt with the retrieved documents and the question, asks
    llm for a Grade-structured result, and returns it under 'sql_gen_check'.
    'plan' is reset to None; 'documents' and 'question' pass through.
    """
    retrieved = state['documents']
    print("docs: ", retrieved)
    user_question = state['question']
    print("question: ", user_question)

    grading_prompt = grade_prompt.invoke(dict(context=retrieved, question=user_question))
    grader = llm.with_structured_output(Grade)
    verdict = grader.invoke(grading_prompt)
    print("ans: ", verdict)

    return dict(
        documents=state['documents'],
        question=state['question'],
        sql_gen_check=verdict,
        plan=None,
    )

def decide(state):
    """Pick the branch after check(): 'yes' -> SQL path, 'no' -> RAG answer.

    The grade comes from a language model, so it is normalised before
    comparison (surrounding whitespace, quotes and a trailing period are
    dropped, case is ignored). Previously anything other than the exact
    strings 'yes'/'no' made this function return None, which matches neither
    branch key.

    Args:
        state: Graph state with 'sql_gen_check' (object exposing
            ``binary_score``, or None).

    Returns:
        'yes' or 'no'. A missing or unrecognised grade falls back to 'no',
        i.e. the branch that does not run SQL.
    """
    grade = state['sql_gen_check']
    raw_score = getattr(grade, 'binary_score', None)
    normalized = str(raw_score).strip().strip("'\".").strip().lower() if raw_score is not None else ''

    if normalized in ('yes', 'no'):
        return normalized

    # Make the fallback visible instead of failing on an unknown branch key.
    print(f"decide: unrecognised binary_score {raw_score!r}; defaulting to 'no'")
    return 'no'

def rag_answer(state):
    """Final node of the non-SQL branch: answer from the retrieved documents.

    Renders process_prompt with the conversation history, the question and
    state['documents'], sends it to llm, and records the reply.

    Args:
        state: Graph state with 'question', 'documents' and 'plan';
            'history' is optional.

    Returns:
        State update with 'documents', 'question' and 'plan' passed through,
        'generation' set to the object returned by ``llm.invoke`` and
        'history' extended by the assistant turn.
    """
    print("---ANSWER---")
    print("state: ", state)

    # Copy the history so the list held by the incoming state is not mutated;
    # the extended copy is returned as the state update instead.
    history = list(state.get("history") or [])

    # Combine the conversation history into a string
    conversation_history = "\n".join(f"{msg['role']}: {msg['content']}" for msg in history)

    final_format = process_prompt.invoke({
        "history": conversation_history,
        "context": state['documents'],
        "question": state["question"]
    })
    print(final_format.to_string())
    answer_result = llm.invoke(final_format)

    # Record the assistant's answer in the (copied) conversation history
    history.append({"role": "assistant", "content": answer_result.content})

    return {"documents": state["documents"],
            "question": state["question"],
            'plan': state['plan'],
            "generation": answer_result,
            "history": history}


In [9]:
from IPython.display import Image, display
from langgraph.graph import StateGraph, START, END

# Build graph
# Flow: retrieve -> check -> (yes) planning -> generate_sql_query
#                                   -> execute_sql_query -> answer -> END
#                         -> (no)  rag_answer -> END
builder3 = StateGraph(GraphState)
builder3.add_node("retrieve", retrieve)
builder3.add_node("check", check)
builder3.add_node("generate_sql_query", gen_sql_query2)
builder3.add_node("execute_sql_query", execute_sql_query)
# Registered as "planning" while the state key it fills is called 'plan'.
builder3.add_node("planning", plan)
builder3.add_node("answer", answer)
builder3.add_node("rag_answer", rag_answer)

# Build Edges
builder3.add_edge(START, "retrieve")
builder3.add_edge('retrieve','check')
# decide() returns 'yes' or 'no'; the mapping turns that into the next node.
builder3.add_conditional_edges('check',
                               decide,
                               {
                                   'yes':'planning',
                                   'no': 'rag_answer'
                               }
                              )
builder3.add_edge("planning", 'generate_sql_query')
builder3.add_edge("generate_sql_query", "execute_sql_query")
builder3.add_edge("execute_sql_query", "answer")
builder3.add_edge("answer", END)
# The non-SQL branch ends after rag_answer; declare that explicitly, the same
# way the SQL branch does for "answer".
builder3.add_edge("rag_answer", END)

# Compile
app3 = builder3.compile()

In [10]:
# Uncomment to display the graph if needed
# from IPython.display import Image, display
# display(Image(app3.get_graph(xray=True).draw_mermaid_png()))

In [None]:
# Run graph with conversation history support
# First turn: a question that carries a ticket number, with an empty history.
inputs = {"question": "Can you tell me about my flight's departure time? my ticket_no is \"7240005432906569\".",
          'history':[]}
# Each streamed item is iterated as a mapping of node name -> value
# (presumably that node's state update -- confirm against the stream mode).
for output in app3.stream(inputs):
    for key, value in output.items():
        print(f"Node '{key}':")
# `value` is the leftover loop variable, i.e. the value from the last streamed
# item; its "generation" entry is read as the final answer.
print(value["generation"].content)
print(value)

In [None]:
# Run graph with conversation history support
# Follow-up turn: the history is seeded by hand with a previous user question
# and assistant reply so the follow-up question has something to refer to.
inputs = {"question": "can you clarify the time in PM?",
          'history':[{'role': 'user', 
                      'content': 'Can you tell me about my flight\'s departure time? my ticket_no is "7240005432906569".'}, 
                      {'role': 'assistant', 'content': 'Your flight is scheduled to depart on January 17th, 2025, at 16:35:25, so be sure to arrive at the airport with plenty of time to spare.'}]}
for output in app3.stream(inputs):
    for key, value in output.items():
        print(f"Node '{key}':")
# `value` is the leftover loop variable from the last streamed item.
print(value["generation"].content)
print(value)

In [None]:
# Run graph for a non-sql case
# This question matches the 'binary_score=no' example in grade_template, so it
# is meant to exercise the rag_answer branch.
inputs = {"question": "How can I reschedule my flight?",'history':[]}
for output in app3.stream(inputs):
    for key, value in output.items():
        print(f"Node '{key}':")
# `value` is the leftover loop variable from the last streamed item.
print(value["generation"].content)