# Imports

In [35]:
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver
from Prompts import Prompts
from LLMProvider import LLMProvider
from PydanticModels import StoneBreakerState, IsPromptRelated, Query, OptimizedQuery, FinalEvaluation
from ParquetRAG import ParquetRAG
import sqlglot
from Utils import  DatabaseConnection
from sqlalchemy import create_engine, inspect
from sqlalchemy.exc import SQLAlchemyError

# Models

In [15]:
llm_provider=LLMProvider()

In [18]:
groq_llm = llm_provider.get_llm()
groq_is_prompt_related = llm_provider.get_structured_llm(IsPromptRelated)
groq_gen_query = llm_provider.get_structured_llm(Query)
groq_optimized_query = llm_provider.get_structured_llm(OptimizedQuery)
groq_final_evaluation = llm_provider.get_structured_llm(FinalEvaluation)

# Nodes

In [36]:
def database_connection_node(state:StoneBreakerState):
    db_type=state.db_configurations.db_type
    username=state.db_configurations.db_username
    password=state.db_configurations.db_password
    host=state.db_configurations.db_host
    port=state.db_configurations.db_port
    db_name=state.db_configurations.db_name
    
    config:dict={
        "db_type":db_type,
        "username":username,
        "password":state.password,
        "host":state.host,
        "port":state.port,
        "db_name":state.db_name
    }
    db_connection_str = DatabaseConnection.build_connection_string(config=config) 
    engine= DatabaseConnection.connect_to_database(config=config)
    if enigne is None:
        return {
            "error":"Failed to connect to database"
        }
    try:
        inspector= inspect(engine)
        table_names= inspector.get_table_names()
        schema=[]
        for table in table_names:
            columns= inspector.get_columns(table)
            for col in columns:
                col_name=col.get("name")
                col_type=col.get("type")
                is_pk= "PRIMARY KEY" if col.get("primary_key") else ""
                cols.append(f"{col_name} {col_type} {is_pk}".strip())

            create_stmt = f"CREATE TABLE {table} (" + ", ".join(cols) + ");"
            schema.append(create_stmt)
        
        engine.dispose()
        
        return {
            "db_configurations":{
                "db_connection_str": db_connection_str
            },
            "sql_context": "\n".join(schema)
        }
    except  Exception as e:
        return {
            "error": f"Failed to connect to database: {str(e)}"
        }

In [37]:
def sql_prompt_node(state: StoneBreakerState):
    return {   
        "sql_prompt": state.sql_prompt
    }

In [39]:
def get_correct_sql_prompt_node(state: StoneBreakerState):
    sql_prompt=input("Enter a query which is related to the schema:")
    return {
        "sql_prompt": sql_prompt   
    }

In [40]:
def sql_context_from_vector_store_node(state: StoneBreakerState):
    query = f"""
    SQL Context: {state.sql_prompt}
    SQL Prompt: {state.sql_context}
    """
    parquetRAG = ParquetRAG()
    results = parquetRAG.retrieve(query)
    sql_context_from_vector_store = ""
    for d in results:
        sql_context_from_vector_store += d.page_content
    return {
        "sql_context_from_vector_store": sql_context_from_vector_store
    }

In [41]:
def query_generation_node(state: StoneBreakerState):
    sql_context = state.sql_context
    sql_prompt = state.sql_prompt
    sql_context_from_vector_store = state.sql_context_from_vector_store
    prompt = Prompts.gen_query().invoke({
        "sql_context": sql_context, 
        "sql_prompt": sql_prompt,
        "sql_context_from_vector_store": sql_context_from_vector_store
    })
    response = groq_gen_query.invoke(prompt)
    sql_query_generated = response.query
    return {
        "sql_query_generated": sql_query_generated
    }

In [42]:
def optimizations_node(state: StoneBreakerState):
    sql_context = state.sql_context
    sql_query = state.sql_query_generated
    prompt = Prompts.gen_optimized_query().invoke({"sql_context": sql_context, "sql_query": sql_query})
    response = groq_optimized_query.invoke(prompt)
    sql_query_optimized = response.query
    return {
        "sql_query_optimized": sql_query_optimized
    }

In [46]:
def execution_node(state: StoneBreakerState):
    optimized_query = state.sql_query_optimized
    config = state.connection_config  
    conn_str = state.db_configurations.db_connection_str
    engine = DatabaseConnection.get_engine(conn_str)
    
    if engine is None:
        return {"executed_success": False, "error": "Could not connect to the database"}
    
    try:
        # Use a transactional context to execute the query
        with engine.begin() as connection:
            result_proxy = connection.execute(optimized_query)
            # Fetch all rows returned by the query
            results = result_proxy.fetchall()
        
        # Dispose of the engine when done
        engine.dispose()
        
        return {
            "executed_success": True,
            "execution_results": results
        }
    except Exception as e:
        return {
            "executed_success": False,
            "error": str(e)
        }

In [47]:
def final_evaluation_node(state: StoneBreakerState):
    sql_context = state.sql_context
    sql_prompt = state.sql_prompt
    sql_query = state.sql_query_optimized
    prompt = Prompts.gen_final_evaluation().invoke({
        "sql_prompt": sql_prompt,
        "sql_context": sql_context,
        "sql_query": sql_query
    })
    response = groq_final_evaluation.invoke(prompt)
    return {
        "final_evaluation": response.evaluation  
    }

In [48]:
def conversion_node(state: StoneBreakerState):
    sql_query = state.sql_query_optimized
    spark_sql = sqlglot.transpile(sql_query, write="spark", read="sqlite", pretty=True)[0]  
    trino_sql = sqlglot.transpile(sql_query, write="trino", read="sqlite", pretty=True)[0]  
    return {
        "trino_sql": trino_sql,
        "spark_sql": spark_sql
    }


# Edges

In [49]:
def sql_prompt_evaluation_edge(state: StoneBreakerState):
    sql_context = state.sql_context
    sql_prompt = state.sql_prompt
    prompt = Prompts.gen_is_prompt_related().invoke({"sql_context": sql_context, "sql_prompt": sql_prompt})
    evaluation_response = groq_is_prompt_related.invoke(prompt)
    if evaluation_response.isRelated:
        return "sql_context_from_vector_store_node"
    else:
        return "get_correct_sql_prompt_node"

In [50]:
def error_edge(state: StoneBreakerState):
    if state.executed_success:
        return "final_evaluation_node"
    else:
        return "query_generation_node"

In [51]:
def check_creation_edge(state: StoneBreakerState):
    evaluation = state.final_evaluation
    
    if not evaluation:  
        return "query_generation_node"
    else:
        return "conversion_node"

# Graph Workflow

In [52]:
workflow = StateGraph(state_schema=StoneBreakerState)

In [53]:
workflow.add_node("database_connection_node",database_connection_node)
workflow.add_node("sql_prompt_node", sql_prompt_node)
workflow.add_node("get_correct_sql_prompt_node",get_correct_sql_prompt_node)
workflow.add_node("sql_context_from_vector_store_node", sql_context_from_vector_store_node)
workflow.add_node("query_generation_node", query_generation_node)
workflow.add_node("optimizations_node", optimizations_node)
workflow.add_node("execution_node", execution_node) 
workflow.add_node("final_evaluation_node", final_evaluation_node)
workflow.add_node("conversion_node", conversion_node)

<langgraph.graph.state.StateGraph at 0x7598f93f8520>

In [54]:
workflow.add_edge(START, "database_connection_node")
workflow.add_edge("database_connection_node","sql_prompt_node")
workflow.add_conditional_edges("sql_prompt_node",sql_context_evaluation_edge)
workflow.add_conditional_edges("get_correct_sql_prompt_node", sql_context_evaluation_edge)

workflow.add_edge("sql_context_from_vector_store_node", "query_generation_node")
workflow.add_edge("query_generation_node", "optimizations_node")
workflow.add_edge("optimizations_node", "execution_node")
workflow.add_conditional_edges("execution_node", error_edge)
workflow.add_conditional_edges("final_evaluation_node", check_creation_edge)
workflow.add_edge("conversion_node", END)

<langgraph.graph.state.StateGraph at 0x7598f93f8520>

In [55]:
workflow.nodes

{'database_connection_node': StateNodeSpec(runnable=database_connection_node(tags=None, recurse=True, explode_args=False, func_accepts_config=False, func_accepts={}), metadata=None, input=<class 'PydanticModels.StoneBreakerState'>, retry_policy=None, ends=()),
 'sql_prompt_node': StateNodeSpec(runnable=sql_prompt_node(tags=None, recurse=True, explode_args=False, func_accepts_config=False, func_accepts={}), metadata=None, input=<class 'PydanticModels.StoneBreakerState'>, retry_policy=None, ends=()),
 'get_correct_sql_prompt_node': StateNodeSpec(runnable=get_correct_sql_prompt_node(tags=None, recurse=True, explode_args=False, func_accepts_config=False, func_accepts={}), metadata=None, input=<class 'PydanticModels.StoneBreakerState'>, retry_policy=None, ends=()),
 'sql_context_from_vector_store_node': StateNodeSpec(runnable=sql_context_from_vector_store_node(tags=None, recurse=True, explode_args=False, func_accepts_config=False, func_accepts={}), metadata=None, input=<class 'PydanticModel

In [56]:
memory = MemorySaver()
graph = workflow.compile(checkpointer=memory)

In [58]:
from IPython.display import Image, display

display(Image(graph.get_graph().draw_mermaid_png()))

ReadTimeout: HTTPSConnectionPool(host='mermaid.ink', port=443): Read timed out. (read timeout=10)

In [20]:
config = {"configurable": {"thread_id": "abc123"}}

In [15]:
try:
    output = graph.invoke({"sql_context":"Nice"}, config)
except Exception as err:
    print(err)

  self.embeddings_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2",model_kwargs={"token":"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"})
  results = self.retriever.get_relevant_documents(query)


In [16]:
output

{'sql_context': "CREATE TABLE salesperson (salesperson_id INT, name TEXT, region TEXT); INSERT INTO salesperson (salesperson_id, name, region) VALUES (1, 'John Doe', 'North'), (2, 'Jane Smith', 'South'); CREATE TABLE timber_sales (sales_id INT, salesperson_id INT, volume REAL, sale_date DATE); INSERT INTO timber_sales (sales_id, salesperson_id, volume, sale_date) VALUES (1, 1, 120, '2021-01-01'), (2, 1, 150, '2021-02-01'), (3, 2, 180, '2021-01-01');",
 'sql_prompt': 'What is the total volume of timber sold by each salesperson, sorted by salesperson?',
 'sql_context_from_vector_store': "SQL Context: CREATE TABLE salesperson (salesperson_id INT, name TEXT, region TEXT); INSERT INTO salesperson (salesperson_id, name, region) VALUES (1, 'John Doe', 'North'), (2, 'Jane Smith', 'South'); CREATE TABLE timber_sales (sales_id INT, salesperson_id INT, volume REAL, sale_date DATE); INSERT INTO timber_sales (sales_id, salesperson_id, volume, sale_date) VALUES (1, 1, 120, '2021-01-01'), (2, 1, 150,