In [1]:
from functools import lru_cache

from IPython.display import display, Markdown
from langchain import hub
from langchain.chat_models import init_chat_model
from langchain.prompts import PromptTemplate
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool
from langchain_community.utilities import SQLDatabase
from langchain_ollama.llms import OllamaLLM
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, StateGraph,END
from langgraph.prebuilt import tools_condition
from typing_extensions import Annotated
from typing_extensions import TypedDict


In [2]:
# Connect to the local Chinook sample database and sanity-check the connection.
CHINOOK_URI = "sqlite:///Chinook.db"
db = SQLDatabase.from_uri(CHINOOK_URI)

# Dialect and table list are what the query-writing prompt will be given.
print(db.dialect)
print(db.get_usable_table_names())

# Preview a few rows; as the cell's last expression the result is displayed.
db.run("SELECT * FROM Artist LIMIT 10;")

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"

In [11]:



@lru_cache(maxsize=None)
def query_template():
    """Return the LangChain Hub prompt used to generate SQL queries.

    ``hub.pull`` fetches the prompt remotely, so the result is cached and the
    fetch happens once per kernel session instead of on every ``write_query``
    call. In this notebook the template is only read (``.invoke``), so sharing
    one instance is safe. Call ``query_template.cache_clear()`` to force a
    re-fetch.

    Returns:
        The prompt template pulled from ``langchain-ai/sql-query-system-prompt``.
    """
    return hub.pull("langchain-ai/sql-query-system-prompt")
    #query_prompt_template.messages[0].pretty_print()
    
@lru_cache(maxsize=None)
def my_llm(model="llama-3.3-70b-versatile", model_provider="groq", temperature=0):
    """Return the chat model used by the SQL-graph nodes.

    Every node calls this on each step, so one client is cached per distinct
    ``(model, model_provider, temperature)`` combination.

    Args:
        model: Model name passed to ``init_chat_model``.
        model_provider: Provider passed to ``init_chat_model``.
        temperature: Sampling temperature; 0 keeps SQL generation deterministic-ish.

    Returns:
        A LangChain chat model. Callers rely on ``.with_structured_output(...)``
        and on ``.invoke(...).content``. A plain-text LLM such as ``OllamaLLM``
        returns a bare string and is therefore not a drop-in replacement. To use
        a local Ollama model, pass ``model_provider="ollama"`` with a chat model
        name instead.
    """
    return init_chat_model(model, model_provider=model_provider, temperature=temperature)

# Shared graph state. Each node returns a partial dict with only the keys it
# produces, and the graph folds those updates into this state.
State = TypedDict(
    "State",
    {
        "question": str,  # user's natural-language question (graph input)
        "query": str,     # SQL written by write_query
        "result": str,    # string output of running that SQL (execute_query)
        "answer": str,    # final natural-language answer (generate_answer)
    },
)


class QueryOutput(TypedDict):
    """Generated SQL query."""

    # Structured-output schema used by write_query via
    # llm.with_structured_output(QueryOutput). The class docstring above and
    # the Annotated description below become part of the schema shown to the
    # LLM, so changing their text changes what the model is asked to produce.
    # In Annotated[type, default, description], `...` means "no default":
    # the field is required in the generated schema.
    query: Annotated[str, ..., "Syntactically valid SQL query."]


# Passed to the Hub prompt's ``top_k`` variable; presumably the row limit the
# prompt tells the model to apply (confirm against the pulled template).
QUERY_TOP_K = 10


def write_query(state: State):
    """Generate a SQL query that answers ``state["question"]``.

    Fills the Hub SQL prompt with the database dialect, table schema and the
    user's question, then asks the LLM for a ``QueryOutput`` structured
    response.

    Args:
        state: Graph state; only ``"question"`` is read.

    Returns:
        Partial state update ``{"query": <generated SQL string>}``.
    """
    prompt = query_template().invoke(
        {
            "dialect": db.dialect,
            "top_k": QUERY_TOP_K,
            "table_info": db.get_table_info(),
            "input": state["question"],
        }
    )
    structured_llm = my_llm().with_structured_output(QueryOutput)
    result = structured_llm.invoke(prompt)
    return {"query": result["query"]}


def execute_query(state: State):
    """Run the SQL held in ``state["query"]`` against ``db``.

    Returns:
        Partial state update ``{"result": <QuerySQLDatabaseTool output>}``.
    """
    sql_tool = QuerySQLDatabaseTool(db=db)
    sql_text = state["query"]
    tool_output = sql_tool.invoke(sql_text)
    return {"result": tool_output}


def generate_answer(state: State):
    """Ask the LLM to answer the user's question from the SQL query and result.

    Args:
        state: Graph state; reads ``"question"``, ``"query"`` and ``"result"``.

    Returns:
        Partial state update ``{"answer": <LLM reply text>}``.
    """
    question = state["question"]
    sql_text = state["query"]
    sql_rows = state["result"]
    prompt = (
        "Given the following user question, corresponding SQL query, "
        "and SQL result, answer the user question.\n\n"
        f"Question: {question}\n"
        f"SQL Query: {sql_text}\n"
        f"SQL Result: {sql_rows}"
    )
    reply = my_llm().invoke(prompt)
    return {"answer": reply.content}

def analyst(state: State):
    """Answer the user question using the SQL query and result as context.

    This function is not wired into the graph (the sequence is
    write_query -> execute_query -> generate_answer) and its behaviour is
    identical to ``generate_answer``. It delegates to that function instead of
    duplicating the prompt, so the two cannot drift apart.

    Returns:
        Partial state update ``{"answer": ...}``.
    """
    return generate_answer(state)

def create_sql_graph(thread_id="1"):
    """Build and compile the text-to-SQL LangGraph pipeline.

    The pipeline runs ``write_query -> execute_query -> generate_answer`` and
    is compiled to pause *before* ``execute_query``, so the generated SQL can
    be reviewed before it touches the database.

    Args:
        thread_id: Checkpointer thread under which this run's state is stored.
            Defaults to ``"1"``.

    Returns:
        ``(sql_graph, config)``: the compiled graph and the run config. The
        same ``config`` must be passed to every ``invoke``/``stream`` call so
        an interrupted run can be resumed from its checkpoint.
    """
    # add_sequence registers the nodes and chains them in order; only the
    # entry and exit edges are added explicitly.
    graph_builder = StateGraph(State).add_sequence(
        [write_query, execute_query, generate_answer]
    )
    graph_builder.add_edge(START, "write_query")
    graph_builder.add_edge("generate_answer", END)

    # A checkpointer is required for interrupt/resume. MemorySaver keeps the
    # checkpoints in process memory only.
    sql_graph = graph_builder.compile(
        checkpointer=MemorySaver(),
        interrupt_before=["execute_query"],
    )

    config = {"configurable": {"thread_id": thread_id}}
    return sql_graph, config

def run_sql_generator(graph, config, ques, user_approval="yes"):
    """Run the text-to-SQL graph for one question and render the answer.

    Assumes ``graph`` was compiled with an interrupt before ``execute_query``
    (as ``create_sql_graph`` does). The first ``invoke`` therefore stops after
    the SQL is written, and execution resumes only if approved.

    Args:
        graph: Compiled LangGraph with a checkpointer.
        config: Run config carrying the checkpointer ``thread_id``.
        ques: Graph input, e.g. ``{"question": "..."}``.
        user_approval: ``"yes"`` (the default) resumes execution
            automatically; any other string cancels. Pass ``None`` to be
            asked interactively via ``input()``.
            NOTE: on approval the LLM-generated SQL is executed as-is, so
            auto-approval means it runs unreviewed.

    Returns:
        None. The answer is shown as Markdown; progress updates are printed.
    """
    # Runs write_query, then stops at the interrupt before execute_query.
    graph.invoke(ques, config)

    if user_approval is None:
        try:
            user_approval = input("Do you want to go to execute query? (yes/no): ")
        except (EOFError, NotImplementedError):
            # No usable stdin (closed stream, or a frontend without stdin
            # support): fail safe and do not execute the SQL.
            user_approval = "no"

    if user_approval.strip().lower() != "yes":
        print("Operation cancelled by user.")
        return

    # Resume from the checkpoint (None input) and remember the final answer,
    # rather than assuming the last streamed update carries it.
    answer = None
    for step in graph.stream(None, config, stream_mode="updates"):
        print(step)
        if "generate_answer" in step:
            answer = step["generate_answer"]["answer"]

    if answer is None:
        print("The graph finished without producing an answer.")
        return

    # Built outside the f-string: a backslash inside an f-string expression
    # is a SyntaxError before Python 3.12.
    bulleted = answer.replace(",", "\n- ")
    display(Markdown(f"#### {bulleted}"))



In [None]:
# Sample questions; only the one passed to run_sql_generator below is executed.
ques1 = {"question": "what are the different tables available in database? Show it as bullet points"}
ques2 = {"question": "can you provide the column names in each table in tabular format like table_name,Column_name,Description?"}
ques3 = {"question": "How many different types of Genre? Provide them with number of times it occurs in our Genre table"}
ques4 = dict(ques3)  # same question as ques3
ques5 = {"question": "How many employees are there?"}
ques6 = {"question": "Which is the most frequent BillingCountry?"}

sql_graph, config = create_sql_graph()
run_sql_generator(sql_graph, config, ques6)

