# Imports

In [15]:
from typing import TypedDict, Optional, Dict, Annotated, List, Literal
from langgraph.graph import StateGraph, START, END
from langchain_openai import ChatOpenAI
from langchain_groq import ChatGroq
from langchain.prompts import PromptTemplate, ChatPromptTemplate
from langgraph.graph.message import add_messages
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import MessagesState
from langchain_core.messages import HumanMessage, SystemMessage, AnyMessage, AIMessage
from pydantic import BaseModel, Field
from Prompts import Prompts
from LLMProvider import LLMProvider
from PydanticModels import StoneBreakerState, IsContextEnough, FixContext, Query, OptimizedQuery, FinalEvaluation
from ParquetRAG import ParquetRAG
import sqlglot
import sqlite3

# Models

In [2]:
llm_provider=LLMProvider()

In [16]:
groq_llm = llm_provider.get_llm()
groq_is_context_enough = llm_provider.get_structured_llm(IsContextEnough)
groq_fix_context = llm_provider.get_structured_llm(FixContext)
groq_gen_query = llm_provider.get_structured_llm(Query)
groq_optimized_query = llm_provider.get_structured_llm(OptimizedQuery)
groq_final_evaluation = llm_provider.get_structured_llm(FinalEvaluation)

# Nodes

In [18]:
def sql_prompt_node(state: StoneBreakerState):
    sql_prompt = input("Enter the SQL prompt you want: ")
    return {   
        "sql_prompt": sql_prompt
    }

In [19]:
def database_connection_node(state:StoneBreakerState):
    db_path=input("Enter the path to your SQLite database:")
    
    try:
        conn= sqlite3.connect(db_path)
        cursor=conn.cursor()
        
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables=cursor.fetchall()
        schema=[]
        for table in tables:
            table_name=table[0]
            cursor.execute(f"PRAGMA table_info({table_name});")
            columns=cursor.fetchall()
            
            create_stmt = f"CREATE TABLE {table_name} ("
            cols=[]
            for col in columns:
                col_name=col[1]
                col_type=col[2]
                is_pk= "PRIMARY KEY" if col[5]==1 else ""
                cols.append(f"{col_name} {col_type} {is_pk}".strip())
                
            create_stmt += ", ".join(cols) + ");"
            schema.append(create_stmt)
        conn.close()
        
        return {
            "data_base":db_path,
            "sql_context": "\n".join(schema)
        }
    except  Exception as e:
        return {
            "error": f"Failed to connect to database: {str(e)}"
        }

In [20]:
def evaluation_node(state: StoneBreakerState):
    sql_context = state.sql_context
    sql_prompt = state.sql_prompt
    prompt = Prompts.gen_fix_context().invoke({"sql_context": sql_context, "sql_prompt": sql_prompt})
    response = groq_fix_context.invoke(prompt)
    refined_sql_context = response.context
    return {
        "sql_context": refined_sql_context,
        "sql_prompt": sql_prompt   
    }

In [21]:
def sql_context_from_vector_store_node(state: StoneBreakerState):
    query = f"""
    SQL Context: {state.sql_prompt}
    SQL Prompt: {state.sql_context}
    """
    parquetRAG = ParquetRAG()
    results = parquetRAG.retrieve(query)
    sql_context_from_vector_store = ""
    for d in results:
        sql_context_from_vector_store += d.page_content
    return {
        "sql_context_from_vector_store": sql_context_from_vector_store
    }

In [22]:
def query_generation_node(state: StoneBreakerState):
    sql_context = state.sql_context
    sql_prompt = state.sql_prompt
    sql_context_from_vector_store = state.sql_context_from_vector_store
    prompt = Prompts.gen_query().invoke({
        "sql_context": sql_context, 
        "sql_prompt": sql_prompt,
        "sql_context_from_vector_store": sql_context_from_vector_store
    })
    response = groq_gen_query.invoke(prompt)
    sql_query_generated = response.query
    return {
        "sql_query_generated": sql_query_generated
    }

In [23]:
def optimizations_node(state: StoneBreakerState):
    sql_context = state.sql_context
    sql_query = state.sql_query_generated
    prompt = Prompts.gen_optimized_query().invoke({"sql_context": sql_context, "sql_query": sql_query})
    response = groq_optimized_query.invoke(prompt)
    sql_query_optimized = response.query
    return {
        "sql_query_optimized": sql_query_optimized
    }

In [24]:
def execution_node(state: StoneBreakerState):
    optimized_query = state.sql_query_optimized
    database = state.data_base
    try:
        conn = sqlite3.connect(database) 
        cursor = conn.cursor()
        cursor.execute(optimized_query)
        result = cursor.fetchall()
        conn.commit()
        conn.close()
        return {
            "executed_success": True,
            "execution_results": result
        }
    except Exception as e:
        return {
            "executed_success": False,
            "error": str(e)
        }

In [25]:
def final_evaluation_node(state: StoneBreakerState):
    sql_context = state.sql_context
    sql_prompt = state.sql_prompt
    sql_query = state.sql_query_optimized
    prompt = Prompts.gen_final_evaluation().invoke({
        "sql_prompt": sql_prompt,
        "sql_context": sql_context,
        "sql_query": sql_query
    })
    response = groq_final_evaluation.invoke(prompt)
    return {
        "final_evaluation": response.evaluation  
    }

In [26]:
def conversion_node(state: StoneBreakerState):
    sql_query = state.sql_query_optimized
    spark_sql = sqlglot.transpile(sql_query, write="spark", read="sqlite", pretty=True)[0]  
    trino_sql = sqlglot.transpile(sql_query, write="trino", read="sqlite", pretty=True)[0]  
    return {
        "trino_sql": trino_sql,
        "spark_sql": spark_sql
    }


# Edges

In [27]:
def sql_context_evaluation_edge(state: StoneBreakerState):
    sql_context = state.sql_context
    sql_prompt = state.sql_prompt
    prompt = Prompts.gen_is_context_enough().invoke({"sql_context": sql_context, "sql_prompt": sql_prompt})
    evaluation_response = groq_is_context_enough.invoke(prompt)
    if evaluation_response.isEnough:
        return "sql_context_from_vector_store_node"
    else:
        return "evaluation_node"

In [28]:

def error_edge(state: StoneBreakerState):
    if state.executed_success:
        return "final_evaluation_node"
    else:
        return "query_generation_node"

In [29]:
def check_creation_edge(state: StoneBreakerState):
    evaluation = state.final_evaluation
    
    if not evaluation:  
        return "query_generation_node"
    else:
        return "conversion_node"

# Graph Workflow

In [43]:
workflow = StateGraph(state_schema=StoneBreakerState)

In [44]:
workflow.add_node("sql_prompt_node", sql_prompt_node)
workflow.add_node("database_connection_node",database_connection_node)
workflow.add_node("evaluation_node", evaluation_node)
workflow.add_node("sql_context_from_vector_store_node", sql_context_from_vector_store_node)
workflow.add_node("query_generation_node", query_generation_node)
workflow.add_node("optimizations_node", optimizations_node)
workflow.add_node("execution_node", execution_node)
workflow.add_node("final_evaluation_node", final_evaluation_node)
workflow.add_node("conversion_node", conversion_node)

<langgraph.graph.state.StateGraph at 0x7fb834382350>

In [45]:
workflow.add_edge(START, "sql_prompt_node")
workflow.add_edge("sql_prompt_node","database_connection_node")
workflow.add_conditional_edges("database_connection_node",sql_context_evaluation_edge)
workflow.add_conditional_edges("evaluation_node", sql_context_evaluation_edge)

workflow.add_edge("sql_context_from_vector_store_node", "query_generation_node")
workflow.add_edge("query_generation_node", "optimizations_node")
workflow.add_edge("optimizations_node", "execution_node")
workflow.add_conditional_edges("execution_node", error_edge)
workflow.add_conditional_edges("final_evaluation_node", check_creation_edge)
workflow.add_edge("conversion_node", END)

<langgraph.graph.state.StateGraph at 0x7fb834382350>

In [46]:
workflow.nodes

{'sql_prompt_node': StateNodeSpec(runnable=sql_prompt_node(tags=None, recurse=True, explode_args=False, func_accepts_config=False, func_accepts={}), metadata=None, input=<class 'PydanticModels.StoneBreakerState'>, retry_policy=None, ends=()),
 'database_connection_node': StateNodeSpec(runnable=database_connection_node(tags=None, recurse=True, explode_args=False, func_accepts_config=False, func_accepts={}), metadata=None, input=<class 'PydanticModels.StoneBreakerState'>, retry_policy=None, ends=()),
 'evaluation_node': StateNodeSpec(runnable=evaluation_node(tags=None, recurse=True, explode_args=False, func_accepts_config=False, func_accepts={}), metadata=None, input=<class 'PydanticModels.StoneBreakerState'>, retry_policy=None, ends=()),
 'sql_context_from_vector_store_node': StateNodeSpec(runnable=sql_context_from_vector_store_node(tags=None, recurse=True, explode_args=False, func_accepts_config=False, func_accepts={}), metadata=None, input=<class 'PydanticModels.StoneBreakerState'>, r

In [47]:
memory = MemorySaver()
graph = workflow.compile(checkpointer=memory)

In [49]:
from IPython.display import Image, display

display(Image(graph.get_graph().draw_png()))

ImportError: Install pygraphviz to draw graphs: `pip install pygraphviz`.

In [20]:
config = {"configurable": {"thread_id": "abc123"}}

In [15]:
try:
    output = graph.invoke({"sql_context":"Nice"}, config)
except Exception as err:
    print(err)

  self.embeddings_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",model_kwargs={"token":"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"})
  results = self.retriever.get_relevant_documents(query)


In [16]:
output

{'sql_context': "CREATE TABLE salesperson (salesperson_id INT, name TEXT, region TEXT); INSERT INTO salesperson (salesperson_id, name, region) VALUES (1, 'John Doe', 'North'), (2, 'Jane Smith', 'South'); CREATE TABLE timber_sales (sales_id INT, salesperson_id INT, volume REAL, sale_date DATE); INSERT INTO timber_sales (sales_id, salesperson_id, volume, sale_date) VALUES (1, 1, 120, '2021-01-01'), (2, 1, 150, '2021-02-01'), (3, 2, 180, '2021-01-01');",
 'sql_prompt': 'What is the total volume of timber sold by each salesperson, sorted by salesperson?',
 'sql_context_from_vector_store': "SQL Context: CREATE TABLE salesperson (salesperson_id INT, name TEXT, region TEXT); INSERT INTO salesperson (salesperson_id, name, region) VALUES (1, 'John Doe', 'North'), (2, 'Jane Smith', 'South'); CREATE TABLE timber_sales (sales_id INT, salesperson_id INT, volume REAL, sale_date DATE); INSERT INTO timber_sales (sales_id, salesperson_id, volume, sale_date) VALUES (1, 1, 120, '2021-01-01'), (2, 1, 150,